Coding, Filming, and Nothing
article thumbnail

Info

**UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation**    
*Huimin Huang, Lanfen Lin, Ruofeng Tong, Hongjie Hu, Qiaowei Zhang, Yutaro Iwamoto, Xianhua Han, Yen-Wei Chen, Jian Wu*  
ICASSP(IEEE) 2020  

-

U-Net++와 마찬가지로, U-Net3+ 모델은 Skip connection의 개선을 목표로 두고 설계된 모델이다. 본문에서 저자는 U-Net의 skip connection과 U-Net++의 Skip Pathways (skip connection)이 아직 개선될 여지가 많다는 점을 말하며 모델의 전개를 시작한다. 

 

U-Net3+에서 사용하는 skip connection은 Full-scale skip connection이라는 이름을 가진 네트워크 모듈이다. 3가지의 서로 다른 스케일의 이미지 임베딩을 입력으로 받고 새로운 이미지 임베딩을 생성한다. 1) fine-grained인 얕은 인코더의 출력이미지, 2) 동일한 해상도를 가지고 있는(동일한 층) 인코더 출력이미지, 3) 더 낮은 해상도를 가진 디코더 출력이미지. 

 

그리고 그 외에 classification guidance module을 두어 이미지에 target organ이 포함되어 있는지를 구별하는 classifier를 배치시켜놓아 학습에 도움되게 하였다고 한다. 특이한 점은 3D volumne데이터를 빠른 처리를 위해서 3 channel을 가진 2d 이미지 데이터로 전처리를 하여 입력을 하였다는 점이다. 

 

또한, U-Net++에서 그러했듯이 deep supervision기술이 사용된다. U-Net3+는 각 디코더에서 deep supervision 학습이 진행된다. 


 

U-Net schemes

 

1. Basic operation Modules

1.1. Hyperparameter configuration

- 본문에서 특이한 점은, 3D 이미지 처리를 수행하는 모델의 '속도 향상'을 위해서 voxel 이미지를 위 아래로 인접한 이미지를 합쳐서 3 channel의 2D 이미지로 만들어서 처리했다는 점이다. 그래서 args.in_dim = 3 으로 설정된다.

<python />
import torch import torch.nn as nn import torch.nn.functional as F from easydict import EasyDict as edict args = edict() # net dim args.in_dim = 3 # 3d img -> 2d img with 3 channels args.init_dim = 16 args.enc_depth = 5 args.net_dim = [args.init_dim*2**x for x in range(args.enc_depth)] # [16, 32, 64, 128, 256] args.out_dim = 2 # net operator conv_kwargs = edict() conv_kwargs.kernel_size = 3 conv_kwargs.stride = 1 conv_kwargs.padding = 1 down_kwargs = edict() down_kwargs.kernel_size = 2 down_kwargs.stride = 2 down_kwargs.padding = 0 up_kwargs = edict() up_kwargs.kernel_size = 2 up_kwargs.stride = 2 up_kwargs.padding = 0 up_kwargs.scale = 2 args.conv_kwargs = conv_kwargs args.down_kwargs = down_kwargs args.up_kwargs = up_kwargs

 

 

1.2. Convolution Layer

  • 컨볼루션 레이어 한장
<python />
class ConvLayer(nn.Module): def __init__(self, in_dim, out_dim, conv_type=nn.Conv2d, conv_kwargs=None, norm_type=nn.BatchNorm2d, act_type=nn.ReLU) -> None: super(ConvLayer, self).__init__() if conv_kwargs is None: conv_kwargs = edict() conv_kwargs.kernel_size = 3 conv_kwargs.stride = 1 conv_kwargs.padding = 1 if norm_type is None: self.conv = nn.Sequential( conv_type(in_dim, out_dim, **conv_kwargs), act_type() ) else: self.conv = nn.Sequential( conv_type(in_dim, out_dim, **conv_kwargs), norm_type(out_dim), act_type() ) def forward(self, inputs): return self.conv(inputs)

 

 

1.3. Convolution Block 

<python />
class ConvBlock(nn.Module): def __init__(self, in_dim, out_dim, hidden_dim=None, num_conv=2, conv_type=nn.Conv2d, conv_kwargs=None, norm_type=nn.BatchNorm2d, act_type=nn.ReLU) -> None: assert num_conv > 0 super(ConvBlock, self).__init__() if hidden_dim is None: hidden_dim = out_dim if conv_kwargs is None: conv_kwargs = edict() conv_kwargs.kernel_size = 3 conv_kwargs.stride = 1 conv_kwargs.padding = 1 if num_conv == 1: self.blocks = ConvLayer(in_dim, out_dim, conv_type=nn.Conv2d, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU) else: self.blocks = nn.Sequential( *( [ConvLayer(in_dim, hidden_dim, conv_type=nn.Conv2d, conv_kwargs=conv_kwargs, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU)] + [ConvLayer(hidden_dim, hidden_dim, conv_type=nn.Conv2d, conv_kwargs=conv_kwargs, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU) for _ in range(num_conv - 2)] + [ConvLayer(hidden_dim, out_dim, conv_type=nn.Conv2d, conv_kwargs=conv_kwargs, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU)] ) ) def forward(self, inputs): return self.blocks(inputs)

 

 

1.4. UpsamplingLayer

<python />
class UpsamplingLayer(nn.Module): def __init__(self, in_dim, out_dim, is_deconv=True, mode='bilinear', up_kwargs=None) -> None: super(UpsamplingLayer, self).__init__() if up_kwargs is None: up_kwargs = edict() up_kwargs.kernel_size = 2 up_kwargs.stride = 2 up_kwargs.padding = 0 up_kwargs.scale = 2 if is_deconv: self.upsampler = nn.ConvTranspose2d(in_dim, out_dim, **up_kwargs) else: # dim doesn't be changed self.upsampler = nn.Upsample(mode=mode, **up_kwargs) def forward(self, x): return self.upsampler(x)

 

2. Encoder

  • U-Net, U-Net++와 동일하다.
<python />
class Encoder(nn.Module): def __init__(self, args) -> None: super(Encoder, self).__init__() self.conv1 = ConvBlock(args.in_dim, args.net_dim[0], num_conv=2) self.conv2 = ConvBlock(args.net_dim[0], args.net_dim[1], num_conv=2) self.conv3 = ConvBlock(args.net_dim[1], args.net_dim[2], num_conv=2) self.conv4 = ConvBlock(args.net_dim[2], args.net_dim[3], num_conv=2) self.conv5 = ConvBlock(args.net_dim[3], args.net_dim[4], num_conv=2) self.pool = nn.MaxPool2d(**args.down_kwargs) def forward(self, inputs): conv1_out = self.conv1(inputs) h1 = self.pool(conv1_out) conv2_out = self.conv2(h1) h2 = self.pool(conv2_out) conv3_out = self.conv3(h2) h3 = self.pool(conv3_out) conv4_out = self.conv4(h3) h4 = self.pool(conv4_out) conv5_out = self.conv5(h4) h5 = self.pool(conv5_out) stage_outputs = [conv1_out, conv2_out, conv3_out, conv4_out] return h5, stage_outputs

 

3. Encoder - VGG16 backbone

  • 본문에서는 encoder를 초기화 된 상태에서 시작하지 않고 backbone으로 pretrained model을 사용했다고 한다. 하나는 VGG16 모델이고, 다른 하나는 ResNet101 모델이다. 
  • VGG16은 features 라는 이름으로 CNN encoder가 하나로 묶여 있기 때문에 내부에서 추가 메소드를 가진다.
<python />
class VGG16_backbone(nn.modules): def __init__(self) -> None: super(VGG16_backbone, self).__init__() import torchvision.models as models self.CNN_encoder = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features # weights=VGG16_Weights.IMAGENET1K_V1 def vgg_layer_forward(self, x, indices): output = x start_idx, end_idx = indices for idx in range(start_idx, end_idx): if idx == (end_idx-1): pooling = self.CNN_encoder[idx](output) else: output = self.CNN_encoder[idx](output) return pooling, output def vgg_forward(self, x): out = {} depth = 5 layer_indices = [0, 5, 10, 15, 20, 24] # for layer_num in range(len(depth)-1): pooling, output = self.vgg_layer_forward(x, layer_indices[layer_num:layer_num+2]) out[f'pool{layer_num+1}'] = pooling out[f'conv{layer_num+1}'] = output return out def forward(self, inputs): vgg_enc_out = self.CNN_encoder(inputs) vgg_conv1 = vgg_enc_out['conv1'].detach() vgg_conv2 = vgg_enc_out['conv2'].detach() vgg_conv3 = vgg_enc_out['conv3'].detach() vgg_conv4 = vgg_enc_out['conv4'].detach() vgg_conv5 = vgg_enc_out['conv5'].detach() stage_outputs = [vgg_conv4, vgg_conv3, vgg_conv2, vgg_conv1] return vgg_conv5, stage_outputs

 

4. Encoder - ResNet101 backbone

  • ResNet101은 모듈마다 이름으로 구별을 해주셔서, backbone 생성이 어렵지 않다. 
<python />
class ResNet101_backbone(nn.modules): def __init__(self) -> None: super(ResNet101_backbone, self).__init__() import torchvision.models as models self.ResNet101 = models.resnet101(weights=models.ResNet101_Weights.DEFAULT) self.conv1 = nn.Sequential( self.ResNet101.conv1, # (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) self.ResNet101.bn1, # (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) self.ResNet101.relu # (relu): ReLU(inplace=True) ) self.init_pool = self.ResNet101.maxpool # (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) self.conv2 = self.ResNet101.layer1 self.conv3 = self.ResNet101.layer2 self.conv4 = self.ResNet101.layer3 self.conv5 = self.ResNet101.layer4 def forward(self, inputs): h1 = self.conv1(inputs) p1 = self.init_pool(h1) h2 = self.conv2(p1) h3 = self.conv3(h2) h4 = self.conv4(h3) h5 = self.conv5(h4) stage_outputs = [h4, h3, h2, h1] return h5, stage_outputs

 

 

5. Skip connection module - full-scale skip connection

explanation of the full-scale skip connection module

  • U-Net 3+의 핵심, full-scale skip connection 이다. 3+ 라는 이름인 것도 서로 다른 3개의 스케일의 이미지를 하나로 수집하기 때문에 붙여낸 이름 같다. 
  1. 얕은(더 큰 이미지 크기를 가진) 층에 위치한 인코더의 출력은 maxpooling을 통해 down-scaling을 수행
  2. 깊은(더 작은 이미지 크기를 가진) 층에 위치한 디코더의 출력은 bilinear upsampling을 통해 up-scaling을 수행
  3. 해당 디코더와 같은 층에 위치한 인코더의 출력은 스케일 변화 없이 weight를 통과시킨다. 
  4. 각 층을 전담하는 가중치를 통과시킨 후, 출력된 이미지 임베딩을 채널기준으로 concat시킨다. 
  5. concat 시킨 임베딩은 디코더에 전달 되어 디코더의 가중치를 통과시킨다 $X^3_{De}$
  • 위의 과정을 통해 디코더는 모든 스케일의 이미지를 받으며 (U-Net과 U-Net++의 단점보완), 각 인코더 디코더의 semantical gap, 즉 dissimilarity도 최소화한다 (U-Net++의 장점). 
<python />
class FullScale_SkipConnection(nn.Module): def __init__(self, enc_init_dim=64, skip_hidden_dim=64, decoding_dim=320, dec_depth=4, level=0) -> None: assert level > 0 super(FullScale_SkipConnection, self).__init__() self.feature_aggregator = nn.ModuleList([]) # X5_en (dim=enc_init_dim* # pooling(=dec_depth)) self.feature_aggregator.append( nn.Sequential( UpsamplingLayer(-1, -1, is_deconv=False, mode='bilinear'), ConvLayer(enc_init_dim*2**dec_depth, skip_hidden_dim, norm_type=None) ) ) for L in range(dec_depth, 0, -1): # l, l-1, ... , 1 if L > level: # lower level (needed to up scale) self.feature_aggregator.append( nn.Sequential( UpsamplingLayer(-1, -1, is_deconv=False, mode='bilinear'), ConvLayer(decoding_dim, skip_hidden_dim, norm_type=None) ) ) elif L == level: # same level self.feature_aggregator.append( ConvLayer(enc_init_dim*2**(L-1), skip_hidden_dim, norm_type=None) # norm_type=None -> weight-ReLU ) elif L < level: # upper level (needed to down sampling) self.feature_aggregator.append( nn.Sequential( nn.MaxPool2d(kernel_size=dec_depth//L, stride=dec_depth//L), # if level=1 -> Maxpooling(4,4) ConvLayer(enc_init_dim*2**(L-1), skip_hidden_dim, norm_type=None) ) ) def forward(self, stage_outputs): skip_out1 = self.feature_aggregator[0](stage_outputs[0]) # x5_de skip_out2 = self.feature_aggregator[1](stage_outputs[1]) skip_out3 = self.feature_aggregator[2](stage_outputs[2]) skip_out4 = self.feature_aggregator[3](stage_outputs[3]) skip_out5 = self.feature_aggregator[4](stage_outputs[4]) # x1_en skip_out = [skip_out1, skip_out2, skip_out3, skip_out4, skip_out5] skip_out = torch.concat(skip_out, dim=1) return skip_out

 

 

6. Decoder & Deep supervision

U-Net 3+ 디코더, classification-guided module (GCM) 동작방식

  • 각 디코더의 출력에 classification 예측 결과를 곱해준다. 이 방법을 사용하는 이유는 3D volumne 데이터를 3 channel 2d 데이터로 만들었기 때문이다. segmentation을 수행할 장기자체가 존재하지 않는 이미지도 존재할 수 있기 때문의 디코더의 학습을 조절한다.
  • 각각 디코더의 출력에 classification prediction을 곱하고, deep supervision을 사용하여 gradient가 디코더로 직접 흐를 수 있도록 설계되어 있다. 
<python />
class Decoder(nn.Module): def __init__(self) -> None: super(Decoder, self).__init__() self.decoding_dim = 320 self.skip1 = FullScale_SkipConnection(dec_depth=4, level=4) self.conv1 = ConvBlock(self.decoding_dim, self.decoding_dim, num_conv=1) self.skip2 = FullScale_SkipConnection(dec_depth=4, level=3) self.conv2 = ConvBlock(self.decoding_dim, self.decoding_dim, num_conv=1) self.skip3 = FullScale_SkipConnection(dec_depth=4, level=2) self.conv3 = ConvBlock(self.decoding_dim, self.decoding_dim, num_conv=1) self.skip4 = FullScale_SkipConnection(dec_depth=4, level=1) self.conv4 = ConvBlock(self.decoding_dim, self.decoding_dim, num_conv=1) self.dsv1 = nn.Sequential( nn.Upsample(scale_factor=16, mode='bilinear'), nn.Conv2d(1024, 2, kernel_size=1, stride=1, padding=0) ) self.dsv2 = nn.Sequential( nn.Upsample(scale_factor=8, mode='bilinear'), nn.Conv2d(self.decoding_dim, 2, kernel_size=1, stride=1, padding=0) ) self.dsv3 = nn.Sequential( nn.Upsample(scale_factor=4, mode='bilinear'), nn.Conv2d(self.decoding_dim, 2, kernel_size=1, stride=1, padding=0) ) self.dsv4 = nn.Sequential( nn.Upsample(scale_factor=2, mode='bilinear'), nn.Conv2d(self.decoding_dim, 2, kernel_size=1, stride=1, padding=0) ) self.dsv5 = nn.Conv2d(self.decoding_dim, 2, kernel_size=1, stride=1, padding=0) def forward(self, enc_out, stage_outputs, organ_flag=True): stage_outputs = [enc_out] + stage_outputs skip_out = self.skip1(stage_outputs) x4_de = self.conv1(skip_out) # x4_de stage_outputs[1] = x4_de skip_out = self.skip2(stage_outputs) x3_de = self.conv2(skip_out) stage_outputs[2] = x3_de skip_out = self.skip3(stage_outputs) x2_de = self.conv3(skip_out) stage_outputs[3] = x2_de skip_out = self.skip4(x2_de, stage_outputs) x1_de = self.conv4(skip_out) enc_out *= organ_flag x4_de *= organ_flag x3_de *= organ_flag x2_de *= organ_flag x1_de *= organ_flag dsv1_out = self.dsv1(enc_out) dsv2_out = self.dsv2(x4_de) dsv3_out = self.dsv3(x3_de) dsv4_out = self.dsv4(x2_de) out = self.dsv5(x1_de) return out, dsv4_out, dsv3_out, dsv2_out, dsv1_out

 

 

7. U-Net 3+ & classification-guided module

<python />
class UNet3p(nn.Module): def __init__(self, args) -> None: super(UNet3p, self).__init__() self.encoder = Encoder(args) # self.encoder = VGG16_backbone() # self.encoder = ResNet101_backbone() self.decoder = Decoder() self.classification_guide = nn.Sequential( nn.Dropout2d(), nn.Conv2d(1024, 2, kernel_size=1, stride=1, padding=0), nn.AdaptiveAvgPool2d(1), nn.Sigmoid() ) def forward(self, inputs): enc_out, stage_outputs = self.encoder(inputs) organ_flag = torch.argmax(self.classification_guide(enc_out), dim=1) out, dsv4_out, dsv3_out, dsv2_out, dsv1_out = self.Decoder(enc_out, stage_outputs, organ_flag) return out, dsv4_out, dsv3_out, dsv2_out, dsv1_out
profile

Coding, Filming, and Nothing

@_안쑤

포스팅이 좋았다면 "좋아요❤️" 또는 "구독👍🏻" 해주세요!