+-
PyTorch版深度残差收缩网络的代码

本文转载自知乎:https://zhuanlan.zhihu.com/p/...

原文:Deep Residual Shrinkage Networks for Fault Diagnosis

作者:Minghang Zhao , Shisheng Zhong, Xuyun Fu

时间:2019年9月

1. Introduction(介绍)

本文主要是针对于判断机械传动系统中,噪声带来的影响而导致故障诊断失误的情况,为了解决这样一个问题,而提出了一种残差收缩网络,通过机器学习的方法来自适应确定软阈值去消除噪声的影响。

并且该文提出来两种情况下的残差收缩网络:DRSN-CS和DRSN-CW。

2. Architecture of the DRSN(DRSN的架构)

本文的模型架构是基于深度残差卷积网络而提出来的,其残差单元是Resnet的基本组成部分,与一般Convnet不同的是,Resnet有一个shortcut(捷径),能够使网络更深。示意图如下图所示:

在这里插入图片描述
(a)中表示在两次卷积后,输入的大小依旧没变,所以只需要将第一次卷积的输入和后两次的卷积相加即可。

(b)中两次卷积后对图像的宽度减半,所以在shortcut部分也需要通过一次卷积缩减为一样大小。(c)对通道的扩充也同理。

(d)RUB表示的就是(a)(b)(c)中的残差模块。

2.1 Soft thresholding(软阈值)

软阈值通常用来过滤掉噪声,但是对于不同的情况,需要专业知识去确定阈值大小,下图是软阈值的函数图以及导数图:
在这里插入图片描述
(a)表示软阈值的曲线,将0附近的值变为0,其他值不变。(b)表示其梯度,两边为1,中间为0。

这样一个经典的软阈值函数,通常只适用于部分情况,对于大多数情况,还是需要特定的估计测试才能确定好适用的阈值函数。所以一种自适应阈值的方式将可解决这样一种情况。

2.2 Architecture of the Developed DRSN-CS(DRSN-CS的架构)

DRSN-CS是Resnet的一种变体,使用了自适应软阈值层的方式(残差收缩层)去消除噪声,这个残差收缩层的结构如下图所示:
在这里插入图片描述
其中经过残差的两次卷积后,将结果取绝对值,之后经过GAP(全局平均池化),将W缩减为1,之后再经过两次全连接层得到z,该结果进行Sigmoid得到 [公式] ,最后将GAP的结果取平均后与 [公式] 相乘,得到软阈值的结果。最后与原输入进行阈值化。(PS. 该论文中有keras的源代码和tf的源代码),我写了下pytorch收缩层的源代码,如下:

class Shrinkage(nn.Module):
    def __init__(self, gap_size, channel):
        super(Shrinkage, self).__init__()
        self.gap = nn.AdaptiveAvgPool2d(gap_size)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel),
            nn.BatchNorm1d(channel),
            nn.ReLU(inplace=True),
            nn.Linear(channel, 1),
            nn.Sigmoid(),
        )
   def forward(self, x):
        x_raw = x
        x = torch.abs(x)
        x_abs = x
        x = self.gap(x)
        x = torch.flatten(x, 1)
        average = torch.mean(x, dim=1, keepdim=True)
        # average = x
        x = self.fc(x)
        x = torch.mul(average, x)
        x = x.unsqueeze(2).unsqueeze(2)
        # 软阈值化
        sub = x_abs - x
        zeros = sub - sub
        n_sub = torch.max(sub, zeros)
        x = torch.mul(torch.sign(x_raw), n_sub)
        return x

经过测试,在图像分类任务中,这个模型效果不太行,与baseline相比,几乎没有提升,还有点下降,对于图像来讲这个还是慎用,因为它将channel给平均下来了。

2.3 Architecture of the Developed DRSN-CW(DRSN-CW的架构)

RSN-CW与DRSN-CS的架构类似,不同点在于它没有将channel给平均掉,结构图如下:

在这里插入图片描述
从图中可知,Average运算消失了,这是他们唯一不同的点,程序修改也很简单,代码如下:

 class Shrinkage(nn.Module):
    def __init__(self, gap_size, channel):
        super(Shrinkage, self).__init__()
        self.gap = nn.AdaptiveAvgPool2d(gap_size)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel),
            nn.BatchNorm1d(channel),
            nn.ReLU(inplace=True),
            nn.Linear(channel, 1),# 可能应该是nn.Linear(channel, channel)
            nn.Sigmoid(),
        )
    def forward(self, x):
        x_raw = x
        x = torch.abs(x)
        x_abs = x
        x = self.gap(x)
        x = torch.flatten(x, 1)
        # average = torch.mean(x, dim=1, keepdim=True)
        average = x
        x = self.fc(x)
        x = torch.mul(average, x)
        x = x.unsqueeze(2).unsqueeze(2)
        # 软阈值化
        sub = x_abs - x
        zeros = sub - sub
        n_sub = torch.max(sub, zeros)
        x = torch.mul(torch.sign(x_raw), n_sub)
        return x

同样,在图像分类任务中,这个模型效果效果提升十分明显,在测试时,准确率提升3%左右,训练时,提升的效果8%左右,对于图像来说,这个效果还是可以的。

3. Experimental Results(实验结果)

该文使用的数据集是机械传动诊断仿真器中的噪声,通过与Convnet,Resnet,DRSN-CS和DRSN-CW的准确率进行比较,得到了如下结果:

在这里插入图片描述
TABLE III中手动添加了噪声,整体性能下降,在训练集和测试集中,DRSN-CW性能表现突出。在TABLE IV中,没有添加手动噪声,整体性能都较高,其中DRSN-CW表现优秀。

4. Conclusion(结论)

创新点:

在Resnet中的残差模块,增加了一个收缩层(自适应软阈值层)去消除无关噪声——残差收缩网络,这样一个网络有两种变体,对于图像来讲,建议使用DRSN-CW,有3%左右的性能提升。