retinaFace PyTorch 版损失函数部分,SSD目标检测 multibox_loss 部分, match() 函数中先验框(anchor,PriorBox)与目标框匹配过程超详细解析

RetinaFace 因为其速度快,精度高,是工程中常用的人脸检测模型,和 SSD 一样,它在训练和预测阶段也采用了先验框。最近拿到一份 PyTorch 版本的 RetinaFace,在损失函数(loss function)部分,明显采用了 SSD 的 multibox loss 计算方法。

研究某个深度学习模型,除了它的网络设计之外,loss 的计算是最值得研究的了,因此本文将尝试对 multibox loss 计算过程中比较费解的 match() 函数逐行做一个比较深入的分析。

multibox loss 计算中的匹配过程

如前文所述,无论是 RetinaFace 还是 SSD 都采用了先验框,所谓“先验框”,其实就是依据事先的经验提出的一系列可能的目标框,这些先验框遍布整个特征图的位置,狭义的直观上来说,可以认为要检测的图中的每一个像素点都被事先预设了若干先验框,负责预测该像素点可能存在的目标。

请注意“可能”这个词。

在预测和训练阶段,需要先对所有的先验框做分类,区分哪些先验框是背景,哪些是目标(在 RetinaFace 中,目标即人脸),而区分的标准就是看目标框与先验框的 IOU,如果先验框 p 与目标框 t 的 IOU 高于设定的阈值,就让先验框 p 负责预测目标 t。如果某个先验框与所有的目标框的 IOU 都低于设定的阈值,那么这个先验框就属于背景。这样的过程称作“匹配”,mutltibox loss 计算中的 match() 函数就是负责“匹配”过程的。

match() 函数原型及其参数

在我拿到的代码中,match() 函数的原型如下,请看:

def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
    # jaccard index
    overlaps = jaccard(truths, point_form(priors))
    # (Bipartite Matching)
    # best prior for each ground truth
    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)  # [num_obj, 1]

    # ignore hard gt
    valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
    best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]  
    if best_prior_idx_filter.shape[0] <= 0:
        loc_t[idx] = 0
        conf_t[idx] = 0
        return

    # best ground truth for each prior
    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)  
    best_truth_idx.squeeze_(0) 
    best_truth_overlap.squeeze_(0)  

    best_prior_idx.squeeze_(1)
    best_prior_idx_filter.squeeze_(1) 
    best_prior_overlap.squeeze_(1) 

    best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2)  # ensure best prior
    # TODO refactor: index  best_prior_idx with long tensor
    # ensure every gt matches with its prior of max overlap
    for j in range(best_prior_idx.size(0)):   
        best_truth_idx[best_prior_idx[j]] = j
    matches = truths[best_truth_idx]          
    conf = labels[best_truth_idx]             
    conf[best_truth_overlap < threshold] = 0    
    loc = encode(matches, priors, variances)

    matches_landm = landms[best_truth_idx]
    landm = encode_landm(matches_landm, priors, variances)
    loc_t[idx] = loc    
    conf_t[idx] = conf  
    landm_t[idx] = landm

稍稍了解下参数:

  • threshold,即所谓的 IOU 阈值。
  • truths,目标框,形状为 [num_obj, 4],4 为矩形框的参数数目
  • priors,先验框,形状为 [num_priors, 4]
  • variances,每个先验框的参数,用于编码(本文不过于关注)
  • labels,目标类别的标签,形状为 [num_obj]
  • loc_tconf_tlandm_t,输出,分别为匹配好的先验框编码后的位置,置信度,人脸 landmark 的位置。
  • idx,索引,也即 batch 中的第 idx 张图片(元素)。

match() 函数解析

如前文所述,究竟将先验框匹配为目标还是背景取决于先验框与目标框的 IOU,因此 match() 函数首先计算 IOU:

overlaps = jaccard(truths, point_form(priors))

jaccard()函数不是本文的重点,它的目的其实就是计算输入的重叠区域比(也即 IOU),这里关注下 overlaps 的形状:[num_obj, num_priors],也即 overlaps 是一个“目标数”行,“先验框数”列的矩阵,为了便于理解,我们假设要预测的目标有 a, b, c 三类,先验框一共有 5 个,因此 overlaps 的形状如下表所示:

- 0 1 2 3 4
a 0.1 0.5 0.3 0.2 0.25
b 0.2 0.3 0.15 0.1 0.1
c 0.5 0.1 0.2 0.3 0.33

我们的目的是为每个先验框(priorBox)找到最匹配的目标,也就是找到与先验框 overlap 最大的目标,比如,对于第 0 个 priorBox,与它最匹配的显然是目标 c,与第 2 个 priorBox 最匹配的目标则显然是 a。所以,match() 函数中有下面这一句代码:

best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)

overlaps 是 PyTorch 中的 tensor,因此 max() 是 Tensor 类的方法,上面这行的目的是简单的:计算各个 priorBox 的最大 overlap,其中 best_truth_overlap 中存放的是 overlap 值,best_truth_idx 中存放的是最大值对应的下标索引:

- - - - - -
best_truth_overlap 0.5 0.5 0.3 0.3 0.33
best_truth_idx 2 0 0 2 2

请注意,best_truth_idx 中的数最大不会超过目标数 3。

此时 best_truth_idx 中存放的就是匹配的目标索引,比如,第 0 个元素值为 2,意思是第 0 个 priorBox 与第 2 个目标(c)匹配;第 1 个元素值为 0,意思是第 1 个 priorBox 与第 0 个目标(a)匹配。所以,有如下代码:

matches = truths[best_truth_idx]
conf = labels[best_truth_idx]

matches 和 conf 都是 num_priors 行的向量,前者表示num_priors个 priorBox 的匹配目标框,后者则表示各个匹配的置信度。可能有读者发现了,到这里所有的 priorBox 都匹配到目标了,即使相应的 overlap 非常低(只要它是最大值),这是不合理的,因为实际中对于一张待检测的图来说,目标 priorBox 的数目应该是远低于背景 priorBox 数目的。所以,需要设定一个阈值过滤下,如果 priorBox 与目标的 overlap 没有超过该阈值,则把该 priorBox 归为背景类:

conf[best_truth_overlap < threshold] = 0

至此,priorBox 与目标框匹配的基本过程基本完成,但是从 match() 函数的源代码来看,它所做的远不止于此,还有什么别的操作呢?

进一步解析

在实际应用中,一张待检测的图中的目标通常比较少,以人脸检测为例,可能一张高宽为 800x800 图中只有一个人脸。再考虑 priorBox 的数目,即使负责最终检测的特征图尺寸是原图的的 16 分之一(50x50),每个特征像素只有 2 个 priorBox,也会产生 50x50x2 = 5000 个 priorBox。这会导致严重的正负样本失衡。

要缓解正负样本失衡,其中一个符合直觉的做法就是产生尽量多的正样本,所以,尽可能的为每个目标都匹配上至少一个 priorBox,那么哪一个 priorBox 符合“至少”这个词呢?显然,应该是所有的 priorBox 中与该目标的 overlap 最大的那个,所以 match() 函数中有下面这段代码:

best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) 

为了便于理解,我们还是假设要预测的目标有 a, b, c 三类,先验框一共有 5 个,所示:

- 0 1 2 3 4
a 0.1 0.5 0.3 0.2 0.25
b 0.2 0.3 0.15 0.1 0.1
c 0.5 0.1 0.2 0.3 0.33

此时,

- - - - - -
best_prior_overlap 0.5 0.3 0.5
best_prior_idx 1 1 0

best_prior_idx 中存放的就是与目标匹配的 priorBox 索引,例如第 0 个元素为 1,表示第 0 个目标与第 1 个 priorBox 匹配;第 2 个元素为 0,表示第 2 个目标与第 0 个 priorBox 匹配。

应注意,match() 函数用来做最终匹配操作(用来做标签truthslabels下标索引执行选择操作)的是 best_truth_overlap/idx

matches = truths[best_truth_idx]
conf = labels[best_truth_idx]

所以我们需要把 best_priors_overlap/idx 的结果叠加到best_truth_overlap/idx,那么这个叠加过程是怎样的呢?其实就是best_truth_overlap.index_fill_() 函数的使用了,请看:

价格: 1.10 元 (已有24人付款)
温馨提示:
输入邮箱或者手机号码,付款后可永久阅读隐藏的内容,请勿未经本站许可,擅自分享付费内容。
如果您已购买本页内容,输入购买时的手机号码或者邮箱,点击支付,即可查看内容。
电子商品,一经购买,不支持退款,特殊原因,请联系客服。
付费可读

我们就在基础之上,实现了“尽可能为每个目标都匹配至少一个 priorBox”的目的了。

小结

match() 函数做的工作把 priorBox 与目标匹配起来,而是否匹配的标准则是通过计算和对比二者的 overlap。多个目标与多个 priorBox 构成了一个 overlap 矩阵,直觉上来看,一般要挑选一个最大的 overlap 对应的匹配。为一个 priorBox 挑选一个匹配类,最基础就是计算 overlap 矩阵中对应的列矩阵的最大值对应的目标。在基础 之上,我们还想尽可能的多挖掘正样本,因此希望尽量的为每个目标都匹配一个 priorBox,这时就需要计算 overlap 的行向量最大值对应的 priorBox。这其实就是 match() 函数的基本设计思想。

阅读更多:   图像处理基础
已有 8 条评论
  1. Sunny

    懂了 2016kuk.gif

  2. spidermiao

    博主,我想付费看你的这篇文章,但是扫码显示二维码失效了,怎么办呢

    1. 我看看可能处理好 2016shuai.gif

  3. spidermiao

    博主,可以和你邮箱联系吗,我最近就是被这个困扰,希望能快点得到解答,毕竟散养的孩子只能面向csdn github学习

    1. faty2009@foxmail.com,qq 734657539

  4. yy

    扫码显示二维码失效了,怎么办

    1. 我看看可能处理好 2016shuai.gif

  5. 处理好了,可以付费了,不过不明白为什么要加钱才可以。。。 icon_question.gif

添加新评论

icon_redface.gificon_idea.gificon_cool.gif2016kuk.gificon_mrgreen.gif2016shuai.gif2016tp.gif2016db.gif2016ch.gificon_razz.gif2016zj.gificon_sad.gificon_cry.gif2016zhh.gificon_question.gif2016jk.gif2016bs.gificon_lol.gif2016qiao.gificon_surprised.gif2016fendou.gif2016ll.gif