Info
-
U-Net++와 마찬가지로, U-Net3+ 모델은 Skip connection의 개선을 목표로 두고 설계된 모델이다. 본문에서 저자는 U-Net의 skip connection과 U-Net++의 Skip Pathways (skip connection)이 아직 개선될 여지가 많다는 점을 말하며 모델의 전개를 시작한다.
U-Net3+에서 사용하는 skip connection은 Full-scale skip connection이라는 이름을 가진 네트워크 모듈이다. 3가지의 서로 다른 스케일의 이미지 임베딩을 입력으로 받고 새로운 이미지 임베딩을 생성한다. 1) fine-grained인 얕은 인코더의 출력이미지, 2) 동일한 해상도를 가지고 있는(동일한 층) 인코더 출력이미지, 3) 더 낮은 해상도를 가진 디코더 출력이미지.
그리고 그 외에 classification guidance module을 두어 이미지에 target organ이 포함되어 있는지를 구별하는 classifier를 배치시켜놓아 학습에 도움되게 하였다고 한다. 특이한 점은 3D volumne데이터를 빠른 처리를 위해서 3 channel을 가진 2d 이미지 데이터로 전처리를 하여 입력을 하였다는 점이다.
또한, U-Net++에서 그러했듯이 deep supervision기술이 사용된다. U-Net3+는 각 디코더에서 deep supervision 학습이 진행된다.
Basic operation Modules
Hyperparameter configuration
- 본문에서 특이한 점은, 3D 이미지 처리를 수행하는 모델의 '속도 향상'을 위해서 voxel 이미지를 위 아래로 인접한 이미지를 합쳐서 3 channel의 2D 이미지로 만들어서 처리했다는 점이다. 그래서 args.in_dim = 3 으로 설정된다.
import torch
import torch.nn as nn
import torch.nn.functional as F
from easydict import EasyDict as edict
args = edict()
# net dim
args.in_dim = 3 # 3d img -> 2d img with 3 channels
args.init_dim = 16
args.enc_depth = 5
args.net_dim = [args.init_dim*2**x for x in range(args.enc_depth)] # [16, 32, 64, 128, 256]
args.out_dim = 2
# net operator
conv_kwargs = edict()
conv_kwargs.kernel_size = 3
conv_kwargs.stride = 1
conv_kwargs.padding = 1
down_kwargs = edict()
down_kwargs.kernel_size = 2
down_kwargs.stride = 2
down_kwargs.padding = 0
up_kwargs = edict()
up_kwargs.kernel_size = 2
up_kwargs.stride = 2
up_kwargs.padding = 0
up_kwargs.scale = 2
args.conv_kwargs = conv_kwargs
args.down_kwargs = down_kwargs
args.up_kwargs = up_kwargs
Convolution Layer
- 컨볼루션 레이어 한장
class ConvLayer(nn.Module):
def __init__(self, in_dim, out_dim,
conv_type=nn.Conv2d, conv_kwargs=None,
norm_type=nn.BatchNorm2d, act_type=nn.ReLU) -> None:
super(ConvLayer, self).__init__()
if conv_kwargs is None:
conv_kwargs = edict()
conv_kwargs.kernel_size = 3
conv_kwargs.stride = 1
conv_kwargs.padding = 1
if norm_type is None:
self.conv = nn.Sequential(
conv_type(in_dim, out_dim, **conv_kwargs),
act_type()
)
else:
self.conv = nn.Sequential(
conv_type(in_dim, out_dim, **conv_kwargs),
norm_type(out_dim),
act_type()
)
def forward(self, inputs):
return self.conv(inputs)
Convolution Block
class ConvBlock(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim=None, num_conv=2,
conv_type=nn.Conv2d, conv_kwargs=None,
norm_type=nn.BatchNorm2d,
act_type=nn.ReLU) -> None:
assert num_conv > 0
super(ConvBlock, self).__init__()
if hidden_dim is None:
hidden_dim = out_dim
if conv_kwargs is None:
conv_kwargs = edict()
conv_kwargs.kernel_size = 3
conv_kwargs.stride = 1
conv_kwargs.padding = 1
if num_conv == 1:
self.blocks = ConvLayer(in_dim, out_dim,
conv_type=nn.Conv2d, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU)
else:
self.blocks = nn.Sequential(
*(
[ConvLayer(in_dim, hidden_dim,
conv_type=nn.Conv2d, conv_kwargs=conv_kwargs, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU)]
+ [ConvLayer(hidden_dim, hidden_dim,
conv_type=nn.Conv2d, conv_kwargs=conv_kwargs, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU) for _ in range(num_conv - 2)]
+ [ConvLayer(hidden_dim, out_dim,
conv_type=nn.Conv2d, conv_kwargs=conv_kwargs, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU)]
)
)
def forward(self, inputs):
return self.blocks(inputs)
UpsamplingLayer
class UpsamplingLayer(nn.Module):
def __init__(self, in_dim, out_dim, is_deconv=True, mode='bilinear', up_kwargs=None) -> None:
super(UpsamplingLayer, self).__init__()
if up_kwargs is None:
up_kwargs = edict()
up_kwargs.kernel_size = 2
up_kwargs.stride = 2
up_kwargs.padding = 0
up_kwargs.scale = 2
if is_deconv:
self.upsampler = nn.ConvTranspose2d(in_dim, out_dim, **up_kwargs)
else:
# dim doesn't be changed
self.upsampler = nn.Upsample(mode=mode, **up_kwargs)
def forward(self, x):
return self.upsampler(x)
Encoder
- U-Net, U-Net++와 동일하다.
class Encoder(nn.Module):
def __init__(self, args) -> None:
super(Encoder, self).__init__()
self.conv1 = ConvBlock(args.in_dim, args.net_dim[0], num_conv=2)
self.conv2 = ConvBlock(args.net_dim[0], args.net_dim[1], num_conv=2)
self.conv3 = ConvBlock(args.net_dim[1], args.net_dim[2], num_conv=2)
self.conv4 = ConvBlock(args.net_dim[2], args.net_dim[3], num_conv=2)
self.conv5 = ConvBlock(args.net_dim[3], args.net_dim[4], num_conv=2)
self.pool = nn.MaxPool2d(**args.down_kwargs)
def forward(self, inputs):
conv1_out = self.conv1(inputs)
h1 = self.pool(conv1_out)
conv2_out = self.conv2(h1)
h2 = self.pool(conv2_out)
conv3_out = self.conv3(h2)
h3 = self.pool(conv3_out)
conv4_out = self.conv4(h3)
h4 = self.pool(conv4_out)
conv5_out = self.conv5(h4)
h5 = self.pool(conv5_out)
stage_outputs = [conv1_out, conv2_out, conv3_out, conv4_out]
return h5, stage_outputs
Encoder - VGG16 backbone
- 본문에서는 encoder를 초기화 된 상태에서 시작하지 않고 backbone으로 pretrained model을 사용했다고 한다. 하나는 VGG16 모델이고, 다른 하나는 ResNet101 모델이다.
- VGG16은 features 라는 이름으로 CNN encoder가 하나로 묶여 있기 때문에 내부에서 추가 메소드를 가진다.
class VGG16_backbone(nn.modules):
def __init__(self) -> None:
super(VGG16_backbone, self).__init__()
import torchvision.models as models
self.CNN_encoder = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features # weights=VGG16_Weights.IMAGENET1K_V1
def vgg_layer_forward(self, x, indices):
output = x
start_idx, end_idx = indices
for idx in range(start_idx, end_idx):
if idx == (end_idx-1):
pooling = self.CNN_encoder[idx](output)
else:
output = self.CNN_encoder[idx](output)
return pooling, output
def vgg_forward(self, x):
out = {}
depth = 5
layer_indices = [0, 5, 10, 15, 20, 24] #
for layer_num in range(len(depth)-1):
pooling, output = self.vgg_layer_forward(x, layer_indices[layer_num:layer_num+2])
out[f'pool{layer_num+1}'] = pooling
out[f'conv{layer_num+1}'] = output
return out
def forward(self, inputs):
vgg_enc_out = self.CNN_encoder(inputs)
vgg_conv1 = vgg_enc_out['conv1'].detach()
vgg_conv2 = vgg_enc_out['conv2'].detach()
vgg_conv3 = vgg_enc_out['conv3'].detach()
vgg_conv4 = vgg_enc_out['conv4'].detach()
vgg_conv5 = vgg_enc_out['conv5'].detach()
stage_outputs = [vgg_conv4, vgg_conv3, vgg_conv2, vgg_conv1]
return vgg_conv5, stage_outputs
Encoder - ResNet101 backbone
- ResNet101은 모듈마다 이름으로 구별을 해주셔서, backbone 생성이 어렵지 않다.
class ResNet101_backbone(nn.modules):
def __init__(self) -> None:
super(ResNet101_backbone, self).__init__()
import torchvision.models as models
self.ResNet101 = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
self.conv1 = nn.Sequential(
self.ResNet101.conv1, # (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
self.ResNet101.bn1, # (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.ResNet101.relu # (relu): ReLU(inplace=True)
)
self.init_pool = self.ResNet101.maxpool # (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
self.conv2 = self.ResNet101.layer1
self.conv3 = self.ResNet101.layer2
self.conv4 = self.ResNet101.layer3
self.conv5 = self.ResNet101.layer4
def forward(self, inputs):
h1 = self.conv1(inputs)
p1 = self.init_pool(h1)
h2 = self.conv2(p1)
h3 = self.conv3(h2)
h4 = self.conv4(h3)
h5 = self.conv5(h4)
stage_outputs = [h4, h3, h2, h1]
return h5, stage_outputs
Skip connection module - full-scale skip connection
- U-Net 3+의 핵심, full-scale skip connection 이다. 3+ 라는 이름인 것도 서로 다른 3개의 스케일의 이미지를 하나로 수집하기 때문에 붙여낸 이름 같다.
- 얕은(더 큰 이미지 크기를 가진) 층에 위치한 인코더의 출력은 maxpooling을 통해 down-scaling을 수행
- 깊은(더 작은 이미지 크기를 가진) 층에 위치한 디코더의 출력은 bilinear upsampling을 통해 up-scaling을 수행
- 해당 디코더와 같은 층에 위치한 인코더의 출력은 스케일 변화 없이 weight를 통과시킨다.
- 각 층을 전담하는 가중치를 통과시킨 후, 출력된 이미지 임베딩을 채널기준으로 concat시킨다.
- concat 시킨 임베딩은 디코더에 전달 되어 디코더의 가중치를 통과시킨다 $X^3_{De}$
- 위의 과정을 통해 디코더는 모든 스케일의 이미지를 받으며 (U-Net과 U-Net++의 단점보완), 각 인코더 디코더의 semantical gap, 즉 dissimilarity도 최소화한다 (U-Net++의 장점).
class FullScale_SkipConnection(nn.Module):
def __init__(self, enc_init_dim=64, skip_hidden_dim=64, decoding_dim=320, dec_depth=4, level=0) -> None:
assert level > 0
super(FullScale_SkipConnection, self).__init__()
self.feature_aggregator = nn.ModuleList([])
# X5_en (dim=enc_init_dim* # pooling(=dec_depth))
self.feature_aggregator.append(
nn.Sequential(
UpsamplingLayer(-1, -1, is_deconv=False, mode='bilinear'),
ConvLayer(enc_init_dim*2**dec_depth, skip_hidden_dim, norm_type=None)
)
)
for L in range(dec_depth, 0, -1): # l, l-1, ... , 1
if L > level: # lower level (needed to up scale)
self.feature_aggregator.append(
nn.Sequential(
UpsamplingLayer(-1, -1, is_deconv=False, mode='bilinear'),
ConvLayer(decoding_dim, skip_hidden_dim, norm_type=None)
)
)
elif L == level: # same level
self.feature_aggregator.append(
ConvLayer(enc_init_dim*2**(L-1), skip_hidden_dim, norm_type=None) # norm_type=None -> weight-ReLU
)
elif L < level: # upper level (needed to down sampling)
self.feature_aggregator.append(
nn.Sequential(
nn.MaxPool2d(kernel_size=dec_depth//L, stride=dec_depth//L), # if level=1 -> Maxpooling(4,4)
ConvLayer(enc_init_dim*2**(L-1), skip_hidden_dim, norm_type=None)
)
)
def forward(self, stage_outputs):
skip_out1 = self.feature_aggregator[0](stage_outputs[0]) # x5_de
skip_out2 = self.feature_aggregator[1](stage_outputs[1])
skip_out3 = self.feature_aggregator[2](stage_outputs[2])
skip_out4 = self.feature_aggregator[3](stage_outputs[3])
skip_out5 = self.feature_aggregator[4](stage_outputs[4]) # x1_en
skip_out = [skip_out1, skip_out2, skip_out3, skip_out4, skip_out5]
skip_out = torch.concat(skip_out, dim=1)
return skip_out
Decoder & Deep supervision
- 각 디코더의 출력에 classification 예측 결과를 곱해준다. 이 방법을 사용하는 이유는 3D volumne 데이터를 3 channel 2d 데이터로 만들었기 때문이다. segmentation을 수행할 장기자체가 존재하지 않는 이미지도 존재할 수 있기 때문의 디코더의 학습을 조절한다.
- 각각 디코더의 출력에 classification prediction을 곱하고, deep supervision을 사용하여 gradient가 디코더로 직접 흐를 수 있도록 설계되어 있다.
class Decoder(nn.Module):
def __init__(self) -> None:
super(Decoder, self).__init__()
self.decoding_dim = 320
self.skip1 = FullScale_SkipConnection(dec_depth=4, level=4)
self.conv1 = ConvBlock(self.decoding_dim, self.decoding_dim, num_conv=1)
self.skip2 = FullScale_SkipConnection(dec_depth=4, level=3)
self.conv2 = ConvBlock(self.decoding_dim, self.decoding_dim, num_conv=1)
self.skip3 = FullScale_SkipConnection(dec_depth=4, level=2)
self.conv3 = ConvBlock(self.decoding_dim, self.decoding_dim, num_conv=1)
self.skip4 = FullScale_SkipConnection(dec_depth=4, level=1)
self.conv4 = ConvBlock(self.decoding_dim, self.decoding_dim, num_conv=1)
self.dsv1 = nn.Sequential(
nn.Upsample(scale_factor=16, mode='bilinear'),
nn.Conv2d(1024, 2, kernel_size=1, stride=1, padding=0)
)
self.dsv2 = nn.Sequential(
nn.Upsample(scale_factor=8, mode='bilinear'),
nn.Conv2d(self.decoding_dim, 2, kernel_size=1, stride=1, padding=0)
)
self.dsv3 = nn.Sequential(
nn.Upsample(scale_factor=4, mode='bilinear'),
nn.Conv2d(self.decoding_dim, 2, kernel_size=1, stride=1, padding=0)
)
self.dsv4 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Conv2d(self.decoding_dim, 2, kernel_size=1, stride=1, padding=0)
)
self.dsv5 = nn.Conv2d(self.decoding_dim, 2, kernel_size=1, stride=1, padding=0)
def forward(self, enc_out, stage_outputs, organ_flag=True):
stage_outputs = [enc_out] + stage_outputs
skip_out = self.skip1(stage_outputs)
x4_de = self.conv1(skip_out) # x4_de
stage_outputs[1] = x4_de
skip_out = self.skip2(stage_outputs)
x3_de = self.conv2(skip_out)
stage_outputs[2] = x3_de
skip_out = self.skip3(stage_outputs)
x2_de = self.conv3(skip_out)
stage_outputs[3] = x2_de
skip_out = self.skip4(x2_de, stage_outputs)
x1_de = self.conv4(skip_out)
enc_out *= organ_flag
x4_de *= organ_flag
x3_de *= organ_flag
x2_de *= organ_flag
x1_de *= organ_flag
dsv1_out = self.dsv1(enc_out)
dsv2_out = self.dsv2(x4_de)
dsv3_out = self.dsv3(x3_de)
dsv4_out = self.dsv4(x2_de)
out = self.dsv5(x1_de)
return out, dsv4_out, dsv3_out, dsv2_out, dsv1_out
U-Net 3+ & classification-guided module
class UNet3p(nn.Module):
def __init__(self, args) -> None:
super(UNet3p, self).__init__()
self.encoder = Encoder(args)
# self.encoder = VGG16_backbone()
# self.encoder = ResNet101_backbone()
self.decoder = Decoder()
self.classification_guide = nn.Sequential(
nn.Dropout2d(),
nn.Conv2d(1024, 2, kernel_size=1, stride=1, padding=0),
nn.AdaptiveAvgPool2d(1),
nn.Sigmoid()
)
def forward(self, inputs):
enc_out, stage_outputs = self.encoder(inputs)
organ_flag = torch.argmax(self.classification_guide(enc_out), dim=1)
out, dsv4_out, dsv3_out, dsv2_out, dsv1_out = self.Decoder(enc_out, stage_outputs, organ_flag)
return out, dsv4_out, dsv3_out, dsv2_out, dsv1_out