Coding, Filming, and Nothing
article thumbnail

Info

**UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation**    
*Huimin Huang, Lanfen Lin, Ruofeng Tong, Hongjie Hu, Qiaowei Zhang, Yutaro Iwamoto, Xianhua Han, Yen-Wei Chen, Jian Wu*  
ICASSP(IEEE) 2020  

-

U-Net++와 마찬가지로, U-Net3+ 모델은 Skip connection의 개선을 목표로 두고 설계된 모델이다. 본문에서 저자는 U-Net의 skip connection과 U-Net++의 Skip Pathways (skip connection)이 아직 개선될 여지가 많다는 점을 말하며 모델의 전개를 시작한다. 

 

U-Net3+에서 사용하는 skip connection은 Full-scale skip connection이라는 이름을 가진 네트워크 모듈이다. 3가지의 서로 다른 스케일의 이미지 임베딩을 입력으로 받고 새로운 이미지 임베딩을 생성한다. 1) fine-grained인 얕은 인코더의 출력이미지, 2) 동일한 해상도를 가지고 있는(동일한 층) 인코더 출력이미지, 3) 더 낮은 해상도를 가진 디코더 출력이미지. 

 

그리고 그 외에 classification guidance module을 두어 이미지에 target organ이 포함되어 있는지를 구별하는 classifier를 배치시켜놓아 학습에 도움되게 하였다고 한다. 특이한 점은 3D volumne데이터를 빠른 처리를 위해서 3 channel을 가진 2d 이미지 데이터로 전처리를 하여 입력을 하였다는 점이다. 

 

또한, U-Net++에서 그러했듯이 deep supervision기술이 사용된다. U-Net3+는 각 디코더에서 deep supervision 학습이 진행된다. 


 

U-Net schemes

 

Basic operation Modules

Hyperparameter configuration

- 본문에서 특이한 점은, 3D 이미지 처리를 수행하는 모델의 '속도 향상'을 위해서 voxel 이미지를 위 아래로 인접한 이미지를 합쳐서 3 channel의 2D 이미지로 만들어서 처리했다는 점이다. 그래서 args.in_dim = 3 으로 설정된다.

import torch
import torch.nn as nn
import torch.nn.functional as F


from easydict import EasyDict as edict

args = edict() 

# net dim 
args.in_dim     = 3  # 3d img -> 2d img with 3 channels
args.init_dim   = 16
args.enc_depth  = 5
args.net_dim    = [args.init_dim*2**x for x in range(args.enc_depth)] # [16, 32, 64, 128, 256]
args.out_dim    = 2

# net operator
conv_kwargs = edict()
conv_kwargs.kernel_size = 3
conv_kwargs.stride      = 1
conv_kwargs.padding     = 1

down_kwargs = edict()
down_kwargs.kernel_size = 2
down_kwargs.stride      = 2
down_kwargs.padding     = 0

up_kwargs = edict()
up_kwargs.kernel_size = 2
up_kwargs.stride      = 2
up_kwargs.padding     = 0
up_kwargs.scale       = 2

args.conv_kwargs = conv_kwargs
args.down_kwargs = down_kwargs
args.up_kwargs   = up_kwargs

 

 

Convolution Layer

  • 컨볼루션 레이어 한장
class ConvLayer(nn.Module):
    def __init__(self, in_dim, out_dim,
                    conv_type=nn.Conv2d, conv_kwargs=None,
                    norm_type=nn.BatchNorm2d, act_type=nn.ReLU) -> None:
        
        super(ConvLayer, self).__init__()

        if conv_kwargs is None:
            conv_kwargs = edict()
            conv_kwargs.kernel_size = 3
            conv_kwargs.stride      = 1
            conv_kwargs.padding     = 1            
        

        if norm_type is None:
            self.conv = nn.Sequential(
                    conv_type(in_dim, out_dim, **conv_kwargs),
                    act_type()
                )
        else:
            self.conv = nn.Sequential(
                    conv_type(in_dim, out_dim, **conv_kwargs),
                    norm_type(out_dim),
                    act_type()
                )

        
    def forward(self, inputs):
        return self.conv(inputs)

 

 

Convolution Block 

class ConvBlock(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=None, num_conv=2,
                    conv_type=nn.Conv2d, conv_kwargs=None,
                    norm_type=nn.BatchNorm2d, 
                    act_type=nn.ReLU) -> None:
        assert num_conv > 0

        super(ConvBlock, self).__init__()

        if hidden_dim is None:
            hidden_dim = out_dim

        if conv_kwargs is None:
            conv_kwargs = edict()
            conv_kwargs.kernel_size = 3
            conv_kwargs.stride      = 1
            conv_kwargs.padding     = 1

        if num_conv == 1:
            self.blocks = ConvLayer(in_dim, out_dim,
                                    conv_type=nn.Conv2d, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU)
        else:
            self.blocks = nn.Sequential(
                    *(
                        [ConvLayer(in_dim, hidden_dim,
                                    conv_type=nn.Conv2d, conv_kwargs=conv_kwargs, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU)]
                        + [ConvLayer(hidden_dim, hidden_dim, 
                                     conv_type=nn.Conv2d, conv_kwargs=conv_kwargs, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU) for _ in range(num_conv - 2)]
                        + [ConvLayer(hidden_dim, out_dim,
                                    conv_type=nn.Conv2d, conv_kwargs=conv_kwargs, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU)]
                    )
                )

    def forward(self, inputs):
        return self.blocks(inputs)

 

 

UpsamplingLayer

class UpsamplingLayer(nn.Module):
    def __init__(self, in_dim, out_dim, is_deconv=True, mode='bilinear', up_kwargs=None) -> None:
        super(UpsamplingLayer, self).__init__()

        if up_kwargs is None:
            up_kwargs = edict()
            up_kwargs.kernel_size = 2
            up_kwargs.stride      = 2
            up_kwargs.padding     = 0
            up_kwargs.scale       = 2

        if is_deconv:
            self.upsampler = nn.ConvTranspose2d(in_dim, out_dim, **up_kwargs)
        else:
            # dim doesn't be changed
            self.upsampler = nn.Upsample(mode=mode, **up_kwargs)

    def forward(self, x):
        return self.upsampler(x)

 

Encoder

  • U-Net, U-Net++와 동일하다.
class Encoder(nn.Module):
    def __init__(self, args) -> None:
        super(Encoder, self).__init__()

        

        self.conv1 = ConvBlock(args.in_dim,     args.net_dim[0], num_conv=2)
        self.conv2 = ConvBlock(args.net_dim[0], args.net_dim[1], num_conv=2)
        self.conv3 = ConvBlock(args.net_dim[1], args.net_dim[2], num_conv=2)
        self.conv4 = ConvBlock(args.net_dim[2], args.net_dim[3], num_conv=2)
        self.conv5 = ConvBlock(args.net_dim[3], args.net_dim[4], num_conv=2)

        self.pool  = nn.MaxPool2d(**args.down_kwargs)

    def forward(self, inputs):

        conv1_out = self.conv1(inputs)
        h1 = self.pool(conv1_out)

        conv2_out = self.conv2(h1)
        h2 = self.pool(conv2_out)

        conv3_out = self.conv3(h2)
        h3 = self.pool(conv3_out)

        conv4_out = self.conv4(h3)
        h4 = self.pool(conv4_out)

        conv5_out = self.conv5(h4)
        h5 = self.pool(conv5_out)

        stage_outputs = [conv1_out, conv2_out, conv3_out, conv4_out]

        return h5, stage_outputs

 

Encoder - VGG16 backbone

  • 본문에서는 encoder를 초기화 된 상태에서 시작하지 않고 backbone으로 pretrained model을 사용했다고 한다. 하나는 VGG16 모델이고, 다른 하나는 ResNet101 모델이다. 
  • VGG16은 features 라는 이름으로 CNN encoder가 하나로 묶여 있기 때문에 내부에서 추가 메소드를 가진다.
class VGG16_backbone(nn.modules):
    def __init__(self) -> None:
        super(VGG16_backbone, self).__init__()
        import torchvision.models as models

        self.CNN_encoder = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features # weights=VGG16_Weights.IMAGENET1K_V1

    def vgg_layer_forward(self, x, indices):
        output = x
        start_idx, end_idx = indices
        for idx in range(start_idx, end_idx):
            if idx == (end_idx-1):
                pooling = self.CNN_encoder[idx](output)
            else:
                output = self.CNN_encoder[idx](output)
        return pooling, output

    def vgg_forward(self, x):
        out = {}
        depth = 5
        layer_indices = [0, 5, 10, 15, 20, 24] # 
        for layer_num in range(len(depth)-1):
            pooling, output = self.vgg_layer_forward(x, layer_indices[layer_num:layer_num+2])
            out[f'pool{layer_num+1}'] = pooling
            out[f'conv{layer_num+1}'] = output
        return out

    def forward(self, inputs):

        vgg_enc_out = self.CNN_encoder(inputs)

        vgg_conv1 = vgg_enc_out['conv1'].detach()
        vgg_conv2 = vgg_enc_out['conv2'].detach()
        vgg_conv3 = vgg_enc_out['conv3'].detach()
        vgg_conv4 = vgg_enc_out['conv4'].detach()
        vgg_conv5 = vgg_enc_out['conv5'].detach()

        stage_outputs = [vgg_conv4, vgg_conv3, vgg_conv2, vgg_conv1]

        return vgg_conv5, stage_outputs

 

Encoder - ResNet101 backbone

  • ResNet101은 모듈마다 이름으로 구별을 해주셔서, backbone 생성이 어렵지 않다. 
class ResNet101_backbone(nn.modules):
    def __init__(self) -> None:
        super(ResNet101_backbone, self).__init__()
        import torchvision.models as models

        self.ResNet101 = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
        self.conv1     = nn.Sequential(
                            self.ResNet101.conv1, # (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
                            self.ResNet101.bn1,   # (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                            self.ResNet101.relu   # (relu): ReLU(inplace=True)
                        )
        self.init_pool = self.ResNet101.maxpool # (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

        self.conv2     = self.ResNet101.layer1
        self.conv3     = self.ResNet101.layer2
        self.conv4     = self.ResNet101.layer3
        self.conv5     = self.ResNet101.layer4

    def forward(self, inputs):

        h1 = self.conv1(inputs)
        p1 = self.init_pool(h1)

        h2 = self.conv2(p1)
        h3 = self.conv3(h2)
        h4 = self.conv4(h3)
        h5 = self.conv5(h4)

        stage_outputs = [h4, h3, h2, h1]

        return h5, stage_outputs

 

 

Skip connection module - full-scale skip connection

explanation of the full-scale skip connection module

  • U-Net 3+의 핵심, full-scale skip connection 이다. 3+ 라는 이름인 것도 서로 다른 3개의 스케일의 이미지를 하나로 수집하기 때문에 붙여낸 이름 같다. 
  1. 얕은(더 큰 이미지 크기를 가진) 층에 위치한 인코더의 출력은 maxpooling을 통해 down-scaling을 수행
  2. 깊은(더 작은 이미지 크기를 가진) 층에 위치한 디코더의 출력은 bilinear upsampling을 통해 up-scaling을 수행
  3. 해당 디코더와 같은 층에 위치한 인코더의 출력은 스케일 변화 없이 weight를 통과시킨다. 
  4. 각 층을 전담하는 가중치를 통과시킨 후, 출력된 이미지 임베딩을 채널기준으로 concat시킨다. 
  5. concat 시킨 임베딩은 디코더에 전달 되어 디코더의 가중치를 통과시킨다 $X^3_{De}$
  • 위의 과정을 통해 디코더는 모든 스케일의 이미지를 받으며 (U-Net과 U-Net++의 단점보완), 각 인코더 디코더의 semantical gap, 즉 dissimilarity도 최소화한다 (U-Net++의 장점). 
class FullScale_SkipConnection(nn.Module):
    def __init__(self, enc_init_dim=64, skip_hidden_dim=64, decoding_dim=320, dec_depth=4, level=0) -> None:
        assert level > 0
        super(FullScale_SkipConnection, self).__init__()

        
        self.feature_aggregator = nn.ModuleList([])


        # X5_en (dim=enc_init_dim* # pooling(=dec_depth))
        self.feature_aggregator.append(
            nn.Sequential(
                UpsamplingLayer(-1, -1, is_deconv=False, mode='bilinear'),
                ConvLayer(enc_init_dim*2**dec_depth, skip_hidden_dim, norm_type=None)
            )
        )
        for L in range(dec_depth, 0, -1): # l, l-1, ... , 1
            
            if L > level: # lower level (needed to up scale)
                self.feature_aggregator.append(
                    nn.Sequential(
                        UpsamplingLayer(-1, -1, is_deconv=False, mode='bilinear'),
                        ConvLayer(decoding_dim, skip_hidden_dim, norm_type=None)
                    )
                )
            elif L == level: # same level
                self.feature_aggregator.append(
                    ConvLayer(enc_init_dim*2**(L-1), skip_hidden_dim, norm_type=None) # norm_type=None -> weight-ReLU
                )
            elif L < level: # upper level (needed to down sampling)
                self.feature_aggregator.append(
                    nn.Sequential(
                        nn.MaxPool2d(kernel_size=dec_depth//L, stride=dec_depth//L), # if level=1 -> Maxpooling(4,4)
                        ConvLayer(enc_init_dim*2**(L-1), skip_hidden_dim, norm_type=None)
                    )
                )

    def forward(self, stage_outputs):

        
        skip_out1 = self.feature_aggregator[0](stage_outputs[0]) # x5_de
        skip_out2 = self.feature_aggregator[1](stage_outputs[1])
        skip_out3 = self.feature_aggregator[2](stage_outputs[2])
        skip_out4 = self.feature_aggregator[3](stage_outputs[3])
        skip_out5 = self.feature_aggregator[4](stage_outputs[4]) # x1_en

        skip_out = [skip_out1, skip_out2, skip_out3, skip_out4, skip_out5]
        skip_out = torch.concat(skip_out, dim=1)

        return skip_out

 

 

Decoder & Deep supervision

U-Net 3+ 디코더, classification-guided module (GCM) 동작방식

  • 각 디코더의 출력에 classification 예측 결과를 곱해준다. 이 방법을 사용하는 이유는 3D volumne 데이터를 3 channel 2d 데이터로 만들었기 때문이다. segmentation을 수행할 장기자체가 존재하지 않는 이미지도 존재할 수 있기 때문의 디코더의 학습을 조절한다.
  • 각각 디코더의 출력에 classification prediction을 곱하고, deep supervision을 사용하여 gradient가 디코더로 직접 흐를 수 있도록 설계되어 있다. 
class Decoder(nn.Module):
    def __init__(self) -> None:
        super(Decoder, self).__init__()

        self.decoding_dim = 320

        self.skip1 = FullScale_SkipConnection(dec_depth=4, level=4)
        self.conv1 = ConvBlock(self.decoding_dim, self.decoding_dim, num_conv=1)

        self.skip2 = FullScale_SkipConnection(dec_depth=4, level=3)
        self.conv2 = ConvBlock(self.decoding_dim, self.decoding_dim, num_conv=1)

        self.skip3 = FullScale_SkipConnection(dec_depth=4, level=2)
        self.conv3 = ConvBlock(self.decoding_dim, self.decoding_dim, num_conv=1)

        self.skip4 = FullScale_SkipConnection(dec_depth=4, level=1)
        self.conv4 = ConvBlock(self.decoding_dim, self.decoding_dim, num_conv=1)

        self.dsv1  = nn.Sequential(
                        nn.Upsample(scale_factor=16, mode='bilinear'),
                        nn.Conv2d(1024, 2, kernel_size=1, stride=1, padding=0)
                    )
        self.dsv2  = nn.Sequential(
                        nn.Upsample(scale_factor=8, mode='bilinear'),
                        nn.Conv2d(self.decoding_dim, 2, kernel_size=1, stride=1, padding=0)
                    )
        self.dsv3  = nn.Sequential(
                        nn.Upsample(scale_factor=4, mode='bilinear'),
                        nn.Conv2d(self.decoding_dim, 2, kernel_size=1, stride=1, padding=0)
                    )
        self.dsv4  = nn.Sequential(
                        nn.Upsample(scale_factor=2, mode='bilinear'),
                        nn.Conv2d(self.decoding_dim, 2, kernel_size=1, stride=1, padding=0)
                    )
        self.dsv5  = nn.Conv2d(self.decoding_dim, 2, kernel_size=1, stride=1, padding=0)

        
    def forward(self, enc_out, stage_outputs, organ_flag=True):

        stage_outputs = [enc_out] + stage_outputs
        skip_out = self.skip1(stage_outputs) 
        x4_de    = self.conv1(skip_out) # x4_de

        stage_outputs[1] = x4_de
        skip_out = self.skip2(stage_outputs)
        x3_de    = self.conv2(skip_out)

        stage_outputs[2] = x3_de
        skip_out = self.skip3(stage_outputs)
        x2_de    = self.conv3(skip_out)

        stage_outputs[3] = x2_de
        skip_out = self.skip4(x2_de, stage_outputs)
        x1_de    = self.conv4(skip_out)
        
        enc_out  *= organ_flag
        x4_de    *= organ_flag
        x3_de    *= organ_flag
        x2_de    *= organ_flag
        x1_de    *= organ_flag

        dsv1_out = self.dsv1(enc_out)
        dsv2_out = self.dsv2(x4_de)
        dsv3_out = self.dsv3(x3_de)
        dsv4_out = self.dsv4(x2_de)
        out = self.dsv5(x1_de)

        return out, dsv4_out, dsv3_out, dsv2_out, dsv1_out

 

 

U-Net 3+ & classification-guided module

class UNet3p(nn.Module):
    def __init__(self, args) -> None:
        super(UNet3p, self).__init__()

        self.encoder = Encoder(args)
        # self.encoder = VGG16_backbone()
        # self.encoder = ResNet101_backbone()
        self.decoder = Decoder()
        self.classification_guide = nn.Sequential(
                                        nn.Dropout2d(),
                                        nn.Conv2d(1024, 2, kernel_size=1, stride=1, padding=0),
                                        nn.AdaptiveAvgPool2d(1),
                                        nn.Sigmoid()
                                    )

    def forward(self, inputs):

        enc_out, stage_outputs = self.encoder(inputs)
        organ_flag = torch.argmax(self.classification_guide(enc_out), dim=1)
        out, dsv4_out, dsv3_out, dsv2_out, dsv1_out = self.Decoder(enc_out, stage_outputs, organ_flag)

        return out, dsv4_out, dsv3_out, dsv2_out, dsv1_out
profile

Coding, Filming, and Nothing

@_안쑤

포스팅이 좋았다면 "좋아요❤️" 또는 "구독👍🏻" 해주세요!