retinaFace PyTorch 版损失函数部分,SSD目标检测 multibox_loss 部分, match() 函数中先验框(anchor,PriorBox)与目标框匹配过程超详细解析
发表于: 2020-09-30 08:27:00 | 已被阅读: 133 | 分类于: 图像处理基础
RetinaFace 因为其速度快,精度高,是工程中常用的人脸检测模型,和 SSD 一样,它在训练和预测阶段也采用了先验框。最近拿到一份 PyTorch 版本的 RetinaFace,在损失函数(loss function)部分,明显采用了 SSD 的 multibox loss 计算方法。
研究某个深度学习模型,除了它的网络设计之外,loss 的计算是最值得研究的了,因此本文将尝试对 multibox loss 计算过程中比较费解的 match() 函数逐行做一个比较深入的分析。
multibox loss 计算中的匹配过程
如前文所述,无论是 RetinaFace 还是 SSD 都采用了先验框,所谓“先验框”,其实就是依据事先的经验提出的一系列可能的目标框,这些先验框遍布整个特征图的位置,狭义的直观上来说,可以认为要检测的图中的每一个像素点都被事先预设了若干先验框,负责预测该像素点可能存在的目标。
请注意“可能”这个词。
在预测和训练阶段,需要先对所有的先验框做分类,区分哪些先验框是背景,哪些是目标(在 RetinaFace 中,目标即人脸),而区分的标准就是看目标框与先验框的 IOU,如果先验框 p 与目标框 t 的 IOU 高于设定的阈值,就让先验框 p 负责预测目标 t。如果某个先验框与所有的目标框的 IOU 都低于设定的阈值,那么这个先验框就属于背景。这样的过程称作“匹配”,mutltibox loss 计算中的 match() 函数就是负责“匹配”过程的。
match() 函数原型及其参数
在我拿到的代码中,match() 函数的原型如下,请看:
def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
# jaccard index
overlaps = jaccard(truths, point_form(priors))
# (Bipartite Matching)
# best prior for each ground truth
best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) # [num_obj, 1]
# ignore hard gt
valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
if best_prior_idx_filter.shape[0] <= 0:
loc_t[idx] = 0
conf_t[idx] = 0
return
# best ground truth for each prior
best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
best_truth_idx.squeeze_(0)
best_truth_overlap.squeeze_(0)
best_prior_idx.squeeze_(1)
best_prior_idx_filter.squeeze_(1)
best_prior_overlap.squeeze_(1)
best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior
# TODO refactor: index best_prior_idx with long tensor
# ensure every gt matches with its prior of max overlap
for j in range(best_prior_idx.size(0)):
best_truth_idx[best_prior_idx[j]] = j
matches = truths[best_truth_idx]
conf = labels[best_truth_idx]
conf[best_truth_overlap < threshold] = 0
loc = encode(matches, priors, variances)
matches_landm = landms[best_truth_idx]
landm = encode_landm(matches_landm, priors, variances)
loc_t[idx] = loc
conf_t[idx] = conf
landm_t[idx] = landm
稍稍了解下参数:
threshold
,即所谓的 IOU 阈值。truths
,目标框,形状为 [num_obj, 4],4 为矩形框的参数数目priors
,先验框,形状为 [num_priors, 4]variances
,每个先验框的参数,用于编码(本文不过于关注)labels
,目标类别的标签,形状为 [num_obj]loc_t
,conf_t
,landm_t
,输出,分别为匹配好的先验框编码后的位置,置信度,人脸 landmark 的位置。idx
,索引,也即 batch 中的第 idx 张图片(元素)。
match() 函数解析
如前文所述,究竟将先验框匹配为目标还是背景取决于先验框与目标框的 IOU,因此 match() 函数首先计算 IOU:
overlaps = jaccard(truths, point_form(priors))
jaccard()
函数不是本文的重点,它的目的其实就是计算输入的重叠区域比(也即 IOU),这里关注下 overlaps 的形状:[num_obj, num_priors]
,也即 overlaps 是一个“目标数”行,“先验框数”列的矩阵,为了便于理解,我们假设要预测的目标有 a, b, c
三类,先验框一共有 5 个,因此 overlaps 的形状如下表所示:
- | 0 | 1 | 2 | 3 | 4 |
---|---|---|---|---|---|
a | 0.1 | 0.5 | 0.3 | 0.2 | 0.25 |
b | 0.2 | 0.3 | 0.15 | 0.1 | 0.1 |
c | 0.5 | 0.1 | 0.2 | 0.3 | 0.33 |
我们的目的是为每个先验框(priorBox)找到最匹配的目标,也就是找到与先验框 overlap 最大的目标,比如,对于第 0 个 priorBox,与它最匹配的显然是目标 c,与第 2 个 priorBox 最匹配的目标则显然是 a。所以,match() 函数中有下面这一句代码:
best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
overlaps
是 PyTorch 中的 tensor,因此 max() 是 Tensor 类的方法,上面这行的目的是简单的:计算各个 priorBox 的最大 overlap,其中 best_truth_overlap
中存放的是 overlap 值,best_truth_idx
中存放的是最大值对应的下标索引:
- | - | - | - | - | - |
---|---|---|---|---|---|
best_truth_overlap | 0.5 | 0.5 | 0.3 | 0.3 | 0.33 |
best_truth_idx | 2 | 0 | 0 | 2 | 2 |
请注意,
best_truth_idx
中的数最大不会超过目标数 3。
此时 best_truth_idx
中存放的就是匹配的目标索引,比如,第 0 个元素值为 2,意思是第 0 个 priorBox 与第 2 个目标(c)匹配;第 1 个元素值为 0,意思是第 1 个 priorBox 与第 0 个目标(a)匹配。所以,有如下代码:
matches = truths[best_truth_idx]
conf = labels[best_truth_idx]
matches 和 conf 都是 num_priors
行的向量,前者表示num_priors
个 priorBox 的匹配目标框,后者则表示各个匹配的置信度。可能有读者发现了,到这里所有的 priorBox 都匹配到目标了,即使相应的 overlap 非常低(只要它是最大值),这是不合理的,因为实际中对于一张待检测的图来说,目标 priorBox 的数目应该是远低于背景 priorBox 数目的。所以,需要设定一个阈值过滤下,如果 priorBox 与目标的 overlap 没有超过该阈值,则把该 priorBox 归为背景类:
conf[best_truth_overlap < threshold] = 0
至此,priorBox 与目标框匹配的基本过程基本完成,但是从 match() 函数的源代码来看,它所做的远不止于此,还有什么别的操作呢?
进一步解析
在实际应用中,一张待检测的图中的目标通常比较少,以人脸检测为例,可能一张高宽为 800x800 图中只有一个人脸。再考虑 priorBox 的数目,即使负责最终检测的特征图尺寸是原图的的 16 分之一(50x50),每个特征像素只有 2 个 priorBox,也会产生 50x50x2 = 5000 个 priorBox。这会导致严重的正负样本失衡。
要缓解正负样本失衡,其中一个符合直觉的做法就是产生尽量多的正样本,所以,尽可能的为每个目标都匹配上至少一个 priorBox,那么哪一个 priorBox 符合“至少”这个词呢?显然,应该是所有的 priorBox 中与该目标的 overlap 最大的那个,所以 match() 函数中有下面这段代码:
best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
为了便于理解,我们还是假设要预测的目标有 a, b, c
三类,先验框一共有 5 个,所示:
- | 0 | 1 | 2 | 3 | 4 |
---|---|---|---|---|---|
a | 0.1 | 0.5 | 0.3 | 0.2 | 0.25 |
b | 0.2 | 0.3 | 0.15 | 0.1 | 0.1 |
c | 0.5 | 0.1 | 0.2 | 0.3 | 0.33 |
此时,
- | - | - | - |
---|---|---|---|
best_prior_overlap | 0.5 | 0.3 | 0.5 |
best_prior_idx | 1 | 1 | 0 |
best_prior_idx
中存放的就是与目标最匹配的 priorBox 索引,例如第 0 个元素为 1,表示第 0 个目标与第 1 个 priorBox 匹配;第 2 个元素为 0,表示第 2 个目标与第 0 个 priorBox 匹配。
应注意,match() 函数用来做最终匹配操作(用来做标签truths
和labels
下标索引执行选择操作)的是 best_truth_overlap/idx
:
matches = truths[best_truth_idx]
conf = labels[best_truth_idx]
所以我们需要把 best_priors_overlap/idx
的结果叠加到best_truth_overlap/idx
,那么这个叠加过程是怎样的呢?其实就是best_truth_overlap.index_fill_()
函数的使用了,请看:
我们就在基础之上,实现了“尽可能为每个目标都匹配至少一个 priorBox”的目的了。
小结
match() 函数做的工作把 priorBox 与目标匹配起来,而是否匹配的标准则是通过计算和对比二者的 overlap。多个目标与多个 priorBox 构成了一个 overlap 矩阵,直觉上来看,一般要挑选一个最大的 overlap 对应的匹配。为一个 priorBox 挑选一个匹配类,最基础就是计算 overlap 矩阵中对应的列矩阵的最大值对应的目标。在基础 之上,我们还想尽可能的多挖掘正样本,因此希望尽量的为每个目标都匹配一个 priorBox,这时就需要计算 overlap 的行向量最大值对应的 priorBox。这其实就是 match() 函数的基本设计思想。