Info
* The paper already provides a GitHub link to a PyTorch implementation. That implementation is more carefully engineered for practical use, so if your goal is application rather than understanding, studying the original repository is a better choice than this post.
* The core idea of U-Net++ is to minimize the semantic gap (dissimilarity) at the skip connections between the encoder and decoder; it does this by placing several convolution blocks along each skip connection, as the formula below shows.
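Following the notation of the paper, each node $x_{i,j}$ on a skip pathway (level $i$, $j$-th convolution along the pathway) gathers every earlier node at the same level together with the upsampled node from the level below:

$$
x_{i,j} = H\left(\left[x_{i,0}, \dots, x_{i,j-1},\; U\!\left(x_{i+1,j-1}\right)\right]\right), \quad j \geq 1
$$

where $H(\cdot)$ is a convolution followed by an activation, $U(\cdot)$ is upsampling, and $[\cdot]$ denotes channel-wise concatenation. The SkipPathways module below follows this dense scheme for the intermediate nodes on each pathway.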
A PyTorch implementation of the paper UNet++: A Nested U-Net Architecture for Medical Image Segmentation.
You can find this implementation via my GitHub link.
Basic Operation Module
ConvLayer
- A single convolution layer (weight, normalization, activation function).
import torch
import torch.nn as nn

class ConvLayer(nn.Module):
    def __init__(self, in_dim, out_dim,
                 conv_type=nn.Conv2d, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU) -> None:
        super(ConvLayer, self).__init__()
        # one convolution layer = weight + normalization + activation
        self.conv = nn.Sequential(
            conv_type(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
            norm_type(out_dim),
            act_type()
        )

    def forward(self, inputs):
        return self.conv(inputs)
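As a quick sanity check (a minimal sketch assuming a single-channel 64x64 input): the 3x3 convolution with stride 1 and padding 1 keeps the spatial size and only changes the channel count.

layer = ConvLayer(1, 32)
x = torch.randn(1, 1, 64, 64)             # (batch, channels, H, W)
print(layer(x).shape)                      # torch.Size([1, 32, 64, 64])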
ConvBlock
class ConvBlock(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=None, num_conv=2,
                 conv_type=nn.Conv2d, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU) -> None:
        super(ConvBlock, self).__init__()
        assert num_conv > 0
        if hidden_dim is None:
            hidden_dim = out_dim
        if num_conv == 1:
            self.blocks = ConvLayer(in_dim, out_dim,
                                    conv_type=conv_type, norm_type=norm_type, act_type=act_type)
        else:
            # in_dim -> hidden_dim, (num_conv - 2) layers at hidden_dim, then hidden_dim -> out_dim
            self.blocks = nn.Sequential(
                *(
                    [ConvLayer(in_dim, hidden_dim,
                               conv_type=conv_type, norm_type=norm_type, act_type=act_type)]
                    + [ConvLayer(hidden_dim, hidden_dim,
                                 conv_type=conv_type, norm_type=norm_type, act_type=act_type)
                       for _ in range(num_conv - 2)]
                    + [ConvLayer(hidden_dim, out_dim,
                                 conv_type=conv_type, norm_type=norm_type, act_type=act_type)]
                )
            )

    def forward(self, inputs):
        return self.blocks(inputs)
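The hidden_dim argument controls the width of the intermediate layers; a minimal example with three convolutions (1 -> 16 -> 16 -> 32 channels):

block = ConvBlock(1, 32, hidden_dim=16, num_conv=3)
print(block(torch.randn(1, 1, 64, 64)).shape)     # torch.Size([1, 32, 64, 64])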
UpsamplingLayer
class UpsamplingLayer(nn.Module):
    def __init__(self, in_dim, out_dim, is_deconv=True, mode='bilinear') -> None:
        super(UpsamplingLayer, self).__init__()
        if is_deconv:
            # transposed convolution: doubles H and W and maps in_dim -> out_dim channels
            self.upsampler = nn.ConvTranspose2d(in_dim, out_dim, kernel_size=2, stride=2, padding=0)
        else:
            # interpolation: doubles H and W but keeps the channel count (in_dim / out_dim are ignored)
            self.upsampler = nn.Upsample(scale_factor=2, mode=mode)

    def forward(self, x):
        return self.upsampler(x)
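With is_deconv=True the layer both doubles the spatial resolution and remaps the channel count, which is how it is used everywhere below; a minimal check:

up = UpsamplingLayer(64, 32, is_deconv=True)
print(up(torch.randn(1, 64, 16, 16)).shape)       # torch.Size([1, 32, 32, 32])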
Encoder
- The Encoder class is the same as in a plain U-Net.
class Encoder(nn.Module):
    def __init__(self) -> None:
        super(Encoder, self).__init__()
        self.dim = 32
        self.conv1 = ConvBlock(1, self.dim, num_conv=2)
        self.conv2 = ConvBlock(self.dim, self.dim*2, num_conv=2)
        self.conv3 = ConvBlock(self.dim*2, self.dim*4, num_conv=2)
        self.conv4 = ConvBlock(self.dim*4, self.dim*8, num_conv=2)
        self.conv5 = ConvBlock(self.dim*8, self.dim*16, num_conv=2)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, inputs):
        # max pooling between blocks halves the resolution at every stage
        h1 = self.conv1(inputs)               # x0_0
        h2 = self.conv2(self.pool(h1))        # x1_0
        h3 = self.conv3(self.pool(h2))        # x2_0
        h4 = self.conv4(self.pool(h3))        # x3_0
        h5 = self.conv5(self.pool(h4))        # x4_0
        stage_outputs = [h1, h2, h3, h4]
        return h5, stage_outputs
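As a quick sanity check (a minimal sketch assuming a single-channel 256x256 input, so the four poolings divide evenly):

enc = Encoder()
h5, stages = enc(torch.randn(1, 1, 256, 256))
print(h5.shape)                           # torch.Size([1, 512, 16, 16])  -> x4_0
print([s.shape[1] for s in stages])       # [32, 64, 128, 256]            -> x0_0 .. x3_0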
SkipPathways
- Upsample the outputs coming from the level below ($x_{i+1,j}$).
- Pass $x_{i,0}$ through a skip path made of several convolution layers to realize the skip connection.
- For the dense connections, the embeddings before and after each pass through $H$ (= convolution) are concatenated into id_x.
- The output of every convolution layer, xi_inter, is stored in xi_j because it has to be handed to the skip pathway of the level above.
class SkipPathways(nn.Module):
    def __init__(self, out_dim, lower_dim, path_length=1) -> None:
        super(SkipPathways, self).__init__()
        self.path_length = path_length
        self.conv_path = nn.ModuleList([])
        self.upSamplers = nn.ModuleList([])
        for idx in range(path_length):
            # node idx+1 receives (idx + 1) earlier nodes at this level (out_dim channels each)
            # plus the upsampled lower-level node (lower_dim -> out_dim channels via deconv)
            self.conv_path.append(ConvLayer(out_dim * (idx + 2), out_dim,
                                            conv_type=nn.Conv2d, norm_type=nn.BatchNorm2d, act_type=nn.LeakyReLU))
            self.upSamplers.append(UpsamplingLayer(lower_dim, out_dim, is_deconv=True))

    def forward(self, xi_0, lowers=[]):
        '''
        xi_0:   the same-level encoder output (x_{i,0})
        lowers: the outputs of the level below (x_{i+1,0}, ..., x_{i+1,path_length-1})
        '''
        assert self.path_length == len(lowers)
        xi_j = [xi_0]
        id_x = xi_0                                       # dense concat of all nodes computed so far
        xi_inter = xi_0
        for idx in range(self.path_length):
            upsampled_lower = self.upSamplers[idx](lowers[idx])
            concated = torch.concat((id_x, upsampled_lower), dim=1)
            xi_inter = self.conv_path[idx](concated)      # x_{i,idx+1} = H([...])
            xi_j.append(xi_inter)                         # keep every node for the level above
            id_x = torch.concat((id_x, xi_inter), dim=1)  # grow the dense connection
        return xi_inter, xi_j
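To make the channel bookkeeping concrete, here is a minimal trace of my own mirroring how the Decoder below uses skip3 at level 1 (out_dim = 64, two nodes of 128 channels coming from level 2):

skip = SkipPathways(out_dim=64, lower_dim=128, path_length=2)
x1_0 = torch.randn(1, 64, 128, 128)                        # level-1 encoder output
x2_j = [torch.randn(1, 128, 64, 64) for _ in range(2)]     # x2_0 and x2_1 from the level below
x1_2, x1_j = skip(x1_0, x2_j)
print(x1_2.shape)                                           # torch.Size([1, 64, 128, 128])
print(len(x1_j))                                            # 3 -> x1_0, x1_1, x1_2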
Decoder
1x1 convolutions are used for 'deep supervision'. In the presentation of the paper, the authors showed that using only L4 performs worse than using L3, and that with deep supervision the performance becomes much better than the alternatives; a sketch of the corresponding loss follows the Decoder code.
class Decoder(nn.Module):
    def __init__(self) -> None:
        super(Decoder, self).__init__()
        self.dim = 32
        # X4_0 && X3_0 -> X3_1
        self.up1 = UpsamplingLayer(self.dim*16, self.dim*8, is_deconv=True)
        self.conv1 = ConvBlock(self.dim*16, self.dim*8, num_conv=2)
        # X3_1 && X2_0 (skip pathway) -> X2_2
        self.up2 = UpsamplingLayer(self.dim*8, self.dim*4, is_deconv=True)
        self.skip2 = SkipPathways(out_dim=self.dim*4, lower_dim=self.dim*8, path_length=1)
        self.conv2 = ConvBlock(self.dim*8, self.dim*4, num_conv=2)
        # X2_2 && X1_0 (skip pathway) -> X1_3
        self.up3 = UpsamplingLayer(self.dim*4, self.dim*2, is_deconv=True)
        self.skip3 = SkipPathways(out_dim=self.dim*2, lower_dim=self.dim*4, path_length=2)
        self.conv3 = ConvBlock(self.dim*4, self.dim*2, num_conv=2)
        # X1_3 && X0_0 (skip pathway) -> X0_4
        self.up4 = UpsamplingLayer(self.dim*2, self.dim, is_deconv=True)
        self.skip4 = SkipPathways(out_dim=self.dim, lower_dim=self.dim*2, path_length=3)
        self.conv4 = ConvBlock(self.dim*2, self.dim, num_conv=2)
        # 1x1 convolutions for deep supervision (one head per output X0_1 .. X0_4)
        self.dsp_l1 = nn.Conv2d(self.dim, 2, kernel_size=1, stride=1, padding=0)
        self.dsp_l2 = nn.Conv2d(self.dim, 2, kernel_size=1, stride=1, padding=0)
        self.dsp_l3 = nn.Conv2d(self.dim, 2, kernel_size=1, stride=1, padding=0)
        self.dsp_l4 = nn.Conv2d(self.dim, 2, kernel_size=1, stride=1, padding=0)

    def forward(self, enc_out, stage_outputs):
        x3_0 = stage_outputs[-1]
        x3_1 = self.up1(enc_out)
        x3_1 = torch.concat((x3_1, x3_0), dim=1)      # [up(x4_0), x3_0]
        x3_1 = self.conv1(x3_1)

        x2_0 = stage_outputs[-2]
        x2_2 = self.up2(x3_1)
        x2_1, x2_j = self.skip2(x2_0, [x3_0])
        x2_2 = torch.concat((x2_2, x2_1), dim=1)      # [up(x3_1), x2_1]
        x2_2 = self.conv2(x2_2)

        x1_0 = stage_outputs[-3]
        x1_3 = self.up3(x2_2)
        x1_2, x1_j = self.skip3(x1_0, x2_j)
        x1_3 = torch.concat((x1_3, x1_2), dim=1)      # [up(x2_2), x1_2]
        x1_3 = self.conv3(x1_3)

        x0_0 = stage_outputs[-4]
        x0_4 = self.up4(x1_3)
        x0_3, x0_j = self.skip4(x0_0, x1_j)
        x0_4 = torch.concat((x0_4, x0_3), dim=1)      # [up(x1_3), x0_3]
        x0_4 = self.conv4(x0_4)

        # deep supervision heads on x0_1, x0_2, x0_3 (x0_j = [x0_0, x0_1, x0_2, x0_3]) and x0_4
        out_l1 = self.dsp_l1(x0_j[1])
        out_l2 = self.dsp_l2(x0_j[2])
        out_l3 = self.dsp_l3(x0_j[3])
        out_l4 = self.dsp_l4(x0_4)
        return out_l1, out_l2, out_l3, out_l4
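For training, the paper combines binary cross-entropy and a Dice term over the deep-supervision outputs; as a minimal sketch of mine (plain cross-entropy for brevity, not the paper's exact formulation), averaging the loss over the four heads looks like this:

def deep_supervision_loss(outputs, target, criterion=nn.CrossEntropyLoss()):
    # outputs: the tuple (out_l1, out_l2, out_l3, out_l4), each of shape (B, 2, H, W)
    # target:  class-index map of shape (B, H, W)
    losses = [criterion(out, target) for out in outputs]
    return sum(losses) / len(losses)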
U-Net++
class UNetpp(nn.Module):
    def __init__(self) -> None:
        super(UNetpp, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, inputs):
        enc_out, stage_outputs = self.encoder(inputs)
        out_l1, out_l2, out_l3, out_l4 = self.decoder(enc_out, stage_outputs)
        # for deep supervision the four outputs can also be averaged:
        # out = (out_l1 + out_l2 + out_l3 + out_l4) / 4
        return out_l1, out_l2, out_l3, out_l4
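Finally, an end-to-end check under the same single-channel 256x256 input assumption:

model = UNetpp()
outs = model(torch.randn(1, 1, 256, 256))
print([o.shape for o in outs])        # four outputs, each torch.Size([1, 2, 256, 256])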