**UNet++: A Nested U-Net Architecture for Medical Image Segmentation**    
*Zongwei Zhou, Md Mahfuzur Rahman Siddiquee, Nima Tajbakhsh, Jianming Liang*  
DLMIA 2018  



* 본 논문은 이미 PyTorch 구현의 GitHub 링크를 제공하고 있다. 실제 운용 면에서 더 섬세하게 구현되어 있으므로, 이해가 아닌 적용이 목적이라면 이 글보다는 원본 GitHub 저장소를 참조하면서 공부하는 편이 낫다.


UNet++의 핵심은 encoder와 decoder를 연결해 주는 skip connection에서 발생하는 semantic gap(두 feature map 사이의 의미적 차이)을 최소화하는 것이며, 이를 위해 skip connection 위에 convolution block을 여러 장 두는 방식으로 접근한다.


A PyTorch implementation of the model from the paper "UNet++: A Nested U-Net Architecture for Medical Image Segmentation".

You can check this implementation version via my GitHub link.

Basic Operation Module


- 컨볼루션 레이어 한장! (weight, normalization, activation function)

class ConvLayer(nn.Module):
    """Single convolution unit: 3x3 convolution -> normalization -> activation.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        conv_type: convolution layer class, called as
            ``conv_type(in_dim, out_dim, kernel_size=3, stride=1, padding=1)``.
        norm_type: normalization layer class, called as ``norm_type(out_dim)``.
        act_type: activation layer class, called without arguments.
    """

    def __init__(self, in_dim, out_dim,
                    conv_type=nn.Conv2d, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU) -> None:
        super(ConvLayer, self).__init__()

        # kernel 3 / stride 1 / padding 1 leaves the spatial size unchanged,
        # so only the channel count goes from in_dim to out_dim.
        self.conv = nn.Sequential(
            conv_type(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
            norm_type(out_dim),
            act_type(),
        )

    def forward(self, inputs):
        """Return ``inputs`` after convolution, normalization and activation."""
        return self.conv(inputs)


class ConvBlock(nn.Module):
    """Stack of ``num_conv`` ConvLayer units mapping ``in_dim`` to ``out_dim`` channels.

    Args:
        in_dim: number of input channels of the first layer.
        out_dim: number of output channels of the last layer.
        hidden_dim: channel count between layers; defaults to ``out_dim``.
        num_conv: number of ConvLayer units, must be positive.
        conv_type: convolution class forwarded to every ConvLayer.
        norm_type: normalization class forwarded to every ConvLayer.
        act_type: activation class forwarded to every ConvLayer.

    Raises:
        AssertionError: if ``num_conv`` is not positive.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=None, num_conv=2,
                    conv_type=nn.Conv2d, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU) -> None:
        assert num_conv > 0, "num_conv must be a positive integer"

        super(ConvBlock, self).__init__()

        if hidden_dim is None:
            hidden_dim = out_dim

        # Forward the caller's layer choices instead of re-hardcoding the defaults.
        layer_kwargs = dict(conv_type=conv_type, norm_type=norm_type, act_type=act_type)

        if num_conv == 1:
            self.blocks = ConvLayer(in_dim, out_dim, **layer_kwargs)
        else:
            # Channel plan: in_dim -> hidden_dim (num_conv - 1 times) -> out_dim.
            dims = [in_dim] + [hidden_dim] * (num_conv - 1) + [out_dim]
            # nn.Sequential takes modules as positional arguments, not a list.
            self.blocks = nn.Sequential(
                *(ConvLayer(src, dst, **layer_kwargs) for src, dst in zip(dims, dims[1:]))
            )

    def forward(self, inputs):
        """Run ``inputs`` through the stacked layers."""
        return self.blocks(inputs)


class UpsamplingLayer(nn.Module):
    """Double the spatial resolution of a feature map.

    Args:
        in_dim: number of input channels (used by the transposed convolution).
        out_dim: number of output channels (used by the transposed convolution).
        is_deconv: if True, upsample with a learnable ``nn.ConvTranspose2d``;
            otherwise use parameter-free ``nn.Upsample`` interpolation.
        mode: interpolation mode passed to ``nn.Upsample`` when
            ``is_deconv`` is False.

    Note:
        ``nn.Upsample`` does not change the channel count, so ``in_dim`` and
        ``out_dim`` have no effect when ``is_deconv`` is False.
    """

    def __init__(self, in_dim, out_dim, is_deconv=True, mode='bilinear') -> None:
        super(UpsamplingLayer, self).__init__()

        if is_deconv:
            # kernel 2 / stride 2 maps every input pixel to a 2x2 output patch.
            self.upsampler = nn.ConvTranspose2d(in_dim, out_dim, kernel_size=2, stride=2, padding=0)
        else:
            self.upsampler = nn.Upsample(scale_factor=2, mode=mode)

    def forward(self, x):
        """Return ``x`` upsampled by a factor of two in both spatial dimensions."""
        return self.upsampler(x)



- encoder class는 U-Net과 동일

class Encoder(nn.Module):
    """U-Net style contracting path with five convolution stages.

    The channel count starts at ``self.dim`` (32) for a 1-channel input and
    doubles at every stage; a 2x2 max-pooling halves the resolution between
    consecutive stages.
    """

    def __init__(self) -> None:
        super(Encoder, self).__init__()

        self.dim = 32

        self.conv1 = ConvBlock(1, self.dim, num_conv=2)
        self.conv2 = ConvBlock(self.dim, self.dim*2, num_conv=2)
        self.conv3 = ConvBlock(self.dim*2, self.dim*4, num_conv=2)
        self.conv4 = ConvBlock(self.dim*4, self.dim*8, num_conv=2)
        self.conv5 = ConvBlock(self.dim*8, self.dim*16, num_conv=2)

        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, inputs):
        """Encode ``inputs``.

        Returns:
            A tuple ``(h5, stage_outputs)`` where ``h5`` is the output of the
            last stage and ``stage_outputs`` is ``[h1, h2, h3, h4]``, the
            outputs of the four earlier stages from shallowest to deepest.
        """
        h1 = self.conv1(inputs)
        # Downsample before each deeper stage; without the pooling every stage
        # would keep the input resolution and the decoder's 2x upsampling
        # could not be concatenated with these features.
        h2 = self.conv2(self.pool(h1))
        h3 = self.conv3(self.pool(h2))
        h4 = self.conv4(self.pool(h3))
        h5 = self.conv5(self.pool(h4))

        stage_outputs = [h1, h2, h3, h4]

        return h5, stage_outputs



  • 아래층의 입력들 ($x_{i+1,j}$)에 upsampling을 수행
  • $x_{i,0}$를 몇 개의 컨볼루션 레이어로 구성된 skip path에 통과시켜 skip connection을 수행
  • 또한 dense connection을 위해 $H$(=convolution)를 통과하기 전후의 임베딩을 id_x에 concat
  • 각 컨볼루션 레이어를 거친 출력 xi_inter는 위층의 skip pathway로 전달되어야 하므로 저장
class SkipPathways(nn.Module):
    """Nested skip pathway of one resolution level.

    Starting from the same-level encoder output, each step upsamples one
    feature map of the next lower (deeper) level, concatenates it with the
    previous node of this level and applies a ConvLayer.

    Args:
        in_dim: channel count of the tensor fed to each ConvLayer, i.e. the
            channels of the previous node plus ``out_dim`` (the upsampled
            lower-level feature).
        out_dim: channel count produced by every node of this pathway.
        lower_dim: channel count of the lower-level feature maps.
        path_length: number of intermediate nodes to compute.
    """

    def __init__(self, in_dim, out_dim, lower_dim, path_length=1) -> None:
        super(SkipPathways, self).__init__()

        self.path_length = path_length

        self.conv_path = nn.ModuleList([])
        self.upSamplers = nn.ModuleList([])

        for _ in range(path_length):
            # Every step sees the same input width (previous node + upsampled
            # lower feature), so the width must not be scaled by the step index.
            self.conv_path.append(ConvLayer(in_dim, out_dim,
                             conv_type=nn.Conv2d, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU))
            self.upSamplers.append(UpsamplingLayer(lower_dim, out_dim, is_deconv=True))

    def forward(self, xi_0, lowers=None):
        """Compute the intermediate nodes of this level.

        Args:
            xi_0: the same-level encoder output (X_{i,0}).
            lowers: outputs of the next lower level (X_{i+1,j}), one per step;
                its length must equal ``path_length``.

        Returns:
            A tuple ``(id_x, xi_j)``: ``id_x`` is the channel-wise
            concatenation of ``xi_0`` and every intermediate node, and
            ``xi_j`` is the list ``[xi_0, X_{i,1}, ..., X_{i,path_length}]``.

        Raises:
            AssertionError: if ``len(lowers)`` differs from ``path_length``.
        """
        lowers = [] if lowers is None else lowers
        assert self.path_length == (len(lowers))

        # xi_0 is kept as the first entry: the level above needs it as the
        # lower-level input of its own first step.
        xi_j = [xi_0]

        id_x = xi_0
        xi_inter = xi_0
        for conv, upsampler, lower in zip(self.conv_path, self.upSamplers, lowers):
            upsampled_lower = upsampler(lower)
            concated = torch.cat((xi_inter, upsampled_lower), dim=1)
            xi_inter = conv(concated)
            xi_j.append(xi_inter)
            # Accumulate the dense connection along the channel axis.
            id_x = torch.cat((id_x, xi_inter), dim=1)

        return id_x, xi_j



1x1 conv for 'deep supervision': 본 논문은 발표에서 L4만 사용하면 L3를 사용할 때보다 성능이 떨어짐을 보여주었는데, deep supervision을 사용하는 경우에는 성능이 다른 설정보다 크게 좋아진다는 것도 함께 보여주었다.

class Decoder(nn.Module):
    """Expanding path with nested skip pathways and deep-supervision heads.

    Assumes ``SkipPathways(...)(x, lowers)`` returns ``(dense, nodes)`` where
    ``dense`` is the channel-wise concatenation of ``x`` with every
    intermediate node (``(path_length + 1) * out_dim`` channels) and the last
    ``path_length`` entries of ``nodes`` are those intermediate nodes.
    """

    def __init__(self) -> None:
        super(Decoder, self).__init__()

        self.dim = 32
        d = self.dim

        # X4_0 -> X3_1
        self.up1    = UpsamplingLayer(d*16, d*8, is_deconv=True)
        self.conv1  = ConvBlock(d*16, d*8, num_conv=2)

        # X3_1 && X2_0 -> X2_2
        self.up2    = UpsamplingLayer(d*8, d*4, is_deconv=True)
        # Each skip step concatenates a same-level node (out_dim channels) with
        # an upsampled lower feature (also out_dim channels): in_dim = 2 * out_dim.
        self.skip2  = SkipPathways(in_dim=d*4 + d*4, out_dim=d*4, lower_dim=d*8, path_length=1)
        # Input: up2 output (d*4) + dense skip output (2 nodes of d*4 channels).
        self.conv2  = ConvBlock(d*4 + d*4*2, d*4, num_conv=2)

        # X2_2 && X1_0 -> X1_3
        self.up3    = UpsamplingLayer(d*4, d*2, is_deconv=True)
        self.skip3  = SkipPathways(in_dim=d*2 + d*2, out_dim=d*2, lower_dim=d*4, path_length=2)
        # Input: up3 output (d*2) + dense skip output (3 nodes of d*2 channels).
        self.conv3  = ConvBlock(d*2 + d*2*3, d*2, num_conv=2)

        # X1_3 && X0_0 -> X0_4
        self.up4    = UpsamplingLayer(d*2, d, is_deconv=True)
        self.skip4  = SkipPathways(in_dim=d + d, out_dim=d, lower_dim=d*2, path_length=3)
        # Input: up4 output (d) + dense skip output (4 nodes of d channels).
        self.conv4  = ConvBlock(d + d*4, d, num_conv=2)

        # for deep supervision: one 1x1 conv head per top-level node X0_1..X0_4
        self.dsp_l1 = nn.Conv2d(d, 2, kernel_size=1, stride=1, padding=0)
        self.dsp_l2 = nn.Conv2d(d, 2, kernel_size=1, stride=1, padding=0)
        self.dsp_l3 = nn.Conv2d(d, 2, kernel_size=1, stride=1, padding=0)
        self.dsp_l4 = nn.Conv2d(d, 2, kernel_size=1, stride=1, padding=0)

    def forward(self, enc_out, stage_outputs):
        """Decode the encoder features into four 2-channel outputs.

        Args:
            enc_out: deepest encoder feature (X4_0).
            stage_outputs: encoder features ``[X0_0, X1_0, X2_0, X3_0]``.

        Returns:
            ``(out_l1, out_l2, out_l3, out_l4)``: the 1x1-conv outputs of the
            top-level nodes X0_1, X0_2, X0_3 and X0_4.
        """
        x3_0 = stage_outputs[-1]
        x3_1 = self.up1(enc_out)
        x3_1 = torch.cat((x3_1, x3_0), dim=1)  # up(x4_0) with x3_0
        x3_1 = self.conv1(x3_1)

        x2_0 = stage_outputs[-2]
        x2_up = self.up2(x3_1)
        x2_dense, x2_j = self.skip2(x2_0, [x3_0])
        x2_2 = self.conv2(torch.cat((x2_up, x2_dense), dim=1))

        x1_0 = stage_outputs[-3]
        x1_up = self.up3(x2_2)
        x1_dense, x1_j = self.skip3(x1_0, x2_j)
        x1_3 = self.conv3(torch.cat((x1_up, x1_dense), dim=1))

        x0_0 = stage_outputs[-4]
        x0_up = self.up4(x1_3)
        x0_dense, x0_j = self.skip4(x0_0, x1_j)
        x0_4 = self.conv4(torch.cat((x0_up, x0_dense), dim=1))

        # Supervise the intermediate nodes, not the raw encoder feature X0_0:
        # take the last three entries of the pathway's node list.
        x0_1, x0_2, x0_3 = x0_j[-3:]

        out_l1 = self.dsp_l1(x0_1)
        out_l2 = self.dsp_l2(x0_2)
        out_l3 = self.dsp_l3(x0_3)
        out_l4 = self.dsp_l4(x0_4)

        return out_l1, out_l2, out_l3, out_l4



class UNetpp(nn.Module):
    """UNet++ model: an Encoder followed by a Decoder with deep supervision."""

    def __init__(self) -> None:
        super(UNetpp, self).__init__()

        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, inputs):
        """Run the network on ``inputs``.

        Returns:
            The four deep-supervision outputs ``(out_l1, out_l2, out_l3,
            out_l4)`` exactly as produced by the decoder. They are returned
            separately rather than averaged, so the caller decides how to
            combine them.
        """
        enc_out, stage_outputs = self.encoder(inputs)

        out_l1, out_l2, out_l3, out_l4 = self.decoder(enc_out, stage_outputs)

        return out_l1, out_l2, out_l3, out_l4

