1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > coco数据集大小分类_如何处理不平衡数据集的分类任务

coco数据集大小分类_如何处理不平衡数据集的分类任务

时间:2022-01-04 22:07:17

相关推荐

coco数据集大小分类_如何处理不平衡数据集的分类任务

在情感分类任务中,数据集的标签分布往往是极度不平衡的。以我目前手上的这个二分类任务来说,正例样本14.4万个:负例样本166.1万 = 1 :11.5。很显然这是一个极度不平衡的数据集,假设我把样本全部预测为负,那准确率也高达92%,但这么做没有意义。

那么我们如何处理这个不平衡数据集呢?

因为我用的是神经网络,我不希望减少训练样本,因此我不会采用下采样的方式。有三个方向可以尝试:

使用自定义的loss函数设置class weight设置sample weight

这里,我将尝试三种不同的loss函数并进行对比。

(一)三种损失函数

表示样本的真实标签,则 。令 表示sigmoid输出的预测类别为1的概率,显然 。下面给出三种损失函数的定义:

1. Binary crossentropy

大名鼎鼎的交叉熵损失函数,这里就不赘述了。

2. 修正的Binary crossentropy

这个损失函数来自苏神的两篇文章:

【1】文本情感分类(四):更好的损失函数

【2】何恺明大神的「Focal Loss」,如何更好地理解?

引入单位跃阶函数

取定阈值

(可调超参数,原则上大于0.5均可),则:

(这里我稍微修改了一点点,以使得损失函数更加对称)

这个损失函数跟Binary Crossentropy比起来,就是多了

这个调节因子,我们来分析一下这个公式:当正样本的预测概率大于m时,根据 的定义,这一项的损失就变为了0;当正样本的预测概率小于m时,保持这一项损失不变;当负样本的预测概率小于1-m时,这一项的损失也变为了0;当负样本的预测概率大于1-m时,保持这一项损失不变。

也就是说,这个损失函数将焦点放在了分类错误的样本上面,希望能够把更多的样本正确分类。

3. Focal Loss

来自论文Focal Loss for Dense Object Detection

其中

为权重因子, 为调节参数。

我们来分析一下这个函数,考虑

的情形:当正样本的预测概率 接近1时(我们希望的),则 接近0,则 就变得很小很小。也就是说,当某个样本分类合理时,函数会对其损失进行打折(down weighting),而打折的幅度依赖于参数 。当正样本的预测概率 接近0时(我们不希望的),则 接近1,因此 会变小一点点,但跟上面的情况比起来,其实相当于是放大了,因为大小是相对的。

对于负样本,同理可分析,此处略过。

再来考虑

这个参数,它其实是一个权重的调节因子,用于平衡正负样本的损失贡献。但由于 的存在,我们很难从实际数据中得到指导来设置这个参数,更多可能要去尝试和调参。一般情况下,先令 。

Focal Loss函数对容易分类的样本进行down weighting,聚焦于难分类的样本上。跟苏神的那个思路类似,却更加高明。

那么在类别不均衡的分类任务中,这个损失函数到底怎么起作用呢?

我们知道

,而我的任务中负样本占了绝大多数,对模型来说,它们绝大部分是很好分类的样本,因此它们的损失贡献会大打折扣,从而使模型聚焦在难分类的样本上面,包括绝大部分的正样本。同时, 这个参数也能起到平衡正负样本损失的作用。

(二)损失函数代码

1. 修正的Binary crossentropy(keras版本)

import keras.backend as Kmargin = 0.8theta = lambda t: (K.sign(t)+1.)/2.def variant_crossentropy_loss(y_true, y_pred):return - theta(margin - y_pred) * y_true * K.log(y_pred + 1e-9)- theta(y_pred - 1 + m) * (1 - y_true) * K.log(1 - y_pred + 1e-9))

2. Focal Loss(tensorflow版本)

由于网上的代码都是多分类的(基于softmax输出的),这里我写了一个二分类的(基于sigmoid输出),同时我还加了一个rescale的flag来控制损失函数的量级,单任务学习中,这个flag按照默认的False即可。

import tensorflow as tfdef variant_focal_loss(gamma=2., alpha=0.5, rescale = False):gamma = float(gamma)alpha = float(alpha)def focal_loss_fixed(y_true, y_pred):"""Focal loss for bianry-classificationFL(p_t)=-rescaled_factor*alpha_t*(1-p_t)^{gamma}log(p_t)Notice: y_pred is probability after sigmoidArguments:y_true {tensor} -- groud truth label, shape of [batch_size, 1]y_pred {tensor} -- predicted label, shape of [batch_size, 1]Keyword Arguments:gamma {float} -- (default: {2.0}) alpha {float} -- (default: {0.5})Returns:[tensor] -- loss."""epsilon = 1.e-9 y_true = tf.convert_to_tensor(y_true, tf.float32)y_pred = tf.convert_to_tensor(y_pred, tf.float32)model_out = tf.clip_by_value(y_pred, epsilon, 1.-epsilon) # to advoid numeric underflow# compute cross entropy ce = ce_0 + ce_1 = - (1-y)*log(1-y_hat) - y*log(y_hat)ce_0 = tf.multiply(tf.subtract(1., y_true), -tf.log(tf.subtract(1., model_out)))ce_1 = tf.multiply(y_true, -tf.log(model_out))# compute focal loss fl = fl_0 + fl_1# obviously fl < ce because of the down-weighting, we can fix it by rescaling# fl_0 = -(1-y_true)*(1-alpha)*((y_hat)^gamma)*log(1-y_hat) = (1-alpha)*((y_hat)^gamma)*ce_0fl_0 = tf.multiply(tf.pow(model_out, gamma), ce_0)fl_0 = tf.multiply(1.-alpha, fl_0)# fl_1= -y_true*alpha*((1-y_hat)^gamma)*log(y_hat) = alpha*((1-y_hat)^gamma*ce_1fl_1 = tf.multiply(tf.pow(tf.subtract(1., model_out), gamma), ce_1)fl_1 = tf.multiply(alpha, fl_1)fl = tf.add(fl_0, fl_1)f1_avg = tf.reduce_mean(fl)if rescale:# rescale f1 to keep the quantity as cece = tf.add(ce_0, ce_1)ce_avg = tf.reduce_mean(ce)rescaled_factor = tf.divide(ce_avg, f1_avg + epsilon)f1_avg = tf.multiply(rescaled_factor, f1_avg)return f1_avgreturn focal_loss_fixed

(三)结果对比

我采用了双向GRU模型,在保持模型以及数据不变的情况下,仅改变损失函数以对比不同损失函数在我任务上的表现,由于Local CV是我关注的最终metric,我使用Local CV作为early stopping的依据,若两个epoch后Local CV没有得到提升,则模型停止训练,并取Local CV最高的模型作为预测模型。

以下是各个损失函数在验证集上表现,我取了三个维度:

Accuracy(阈值为0.5)AUC(Area Under the Curve)Local CV(本质是多个AUC的加权平均)

根据以上数据,我们可以得到如下结论:

Focal Loss在我的任务上获得了最大的Local CV,比带class weight的Binary Crossentropy损失高出3个千分点。Focal Loss真是一个优秀的损失函数!Variant Crossentropy这个损失函数获得了最高的Accuracy,但是AUC和Local CV都很低,显然不适合我手中的任务。根据Variant Crossentropy的公式,其实也可推断,这个损失函数是在优化Accuracy。对于关注正确率的任务,这个损失函数应该是不错的选择。

参考资料:

【1】非平衡数据集 focal loss 多类分类

【2】Focal Loss for Dense Object Detection

【3】文本情感分类(四):更好的损失函数

【4】何恺明大神的「Focal Loss」,如何更好地理解?

(注:本文同步发布于简书Littletree_Zou。)

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。