Info
**U-Net: Convolutional Networks for Biomedical Image Segmentation**
*Olaf Ronneberger, Philipp Fischer, Thomas Brox*
MICCAI 2015
-
In the original U-Net, the convolutions use no padding, so the spatial size of the feature maps shrinks step by step. A present-day implementation has no need to shrink the input this way (if anything, it is potentially a disadvantage in several respects), but this post reproduces the original U-Net model exactly as described in the paper.
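As a quick illustration (a minimal sketch, not part of the implementation below): an unpadded 3x3 convolution removes one pixel from each border, while padding=1 preserves the spatial size.

import torch
import torch.nn as nn

x = torch.randn(1, 1, 572, 572)
print(nn.Conv2d(1, 64, kernel_size=3, padding=0)(x).shape)  # torch.Size([1, 64, 570, 570]) - "valid" conv, 2 pixels lost
print(nn.Conv2d(1, 64, kernel_size=3, padding=1)(x).shape)  # torch.Size([1, 64, 572, 572]) - spatial size preserved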
5 Advantages of U-Net
- Efficient use of parameters: The U-Net architecture uses skip connections to concatenate the feature maps from the encoder and decoder, which helps to reduce the number of parameters needed to perform image segmentation.
- Good performance on small datasets: The U-Net architecture is effective even when working with limited amounts of training data, which is a common issue in biomedical image analysis.
- Flexibility: The U-Net can be easily modified to handle different types of image segmentation tasks by changing the size and number of layers in the network.
- End-to-end training: The U-Net architecture allows for end-to-end training, which means that the entire network can be trained in a single step.
- Interpretable feature maps: The U-Net's encoder-decoder structure produces feature maps that can be easily visualized, allowing for better understanding of the network's inner workings and potential troubleshooting of issues during training.
The encoder network - Contracting Path
- The repeated application of two 3x3 convolutions (unpadded convolutions)
- Each of the two convolutions is followed by a rectified linear unit (ReLU)
- A 2x2 max pooling operation with stride 2 for downsampling
- The number of feature channels is doubled at each downsampling step (see the size trace below)
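For the paper's 572x572 input, each unpadded 3x3 convolution removes 2 pixels and each 2x2 max pooling halves the size, so the spatial resolution evolves as 572 -> 570 -> 568 -> (pool) 284 -> 282 -> 280 -> (pool) 140 -> 138 -> 136 -> (pool) 68 -> 66 -> 64 -> (pool) 32 -> 30 -> 28, while the number of channels grows 64 -> 128 -> 256 -> 512 -> 1024.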
import torch
import torch.nn as nn


# encoder
class ContractingPath(nn.Module):
def __init__(self, args=None) -> None:
super(ContractingPath, self).__init__()
# input dim = 1 : gray scale image
# 572x572 input size
self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=0),  # the paper uses unpadded ("valid") convolutions in the contracting path, hence padding=0
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
nn.ReLU()
)
self.conv2 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2), # 2x2 max pooling
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=0),
nn.ReLU(),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0),
nn.ReLU()
)
self.conv3 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=0),
nn.ReLU(),
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0),
nn.ReLU()
)
self.conv4 = nn.Sequential(
nn.MaxPool2d(2,2),
nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
nn.ReLU(),
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
nn.ReLU()
)
self.conv5 = nn.Sequential(
nn.MaxPool2d(2,2),
nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=0),
nn.ReLU(),
nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0),
nn.ReLU()
)
# 28x28 output size
def forward(self, x):
h1 = self.conv1(x)
h2 = self.conv2(h1)
h3 = self.conv3(h2)
h4 = self.conv4(h3)
h5 = self.conv5(h4)
        layer_outputs = [h1, h2, h3, h4]  # U-Net reuses the output of each double-conv block (except the 5th) as a skip connection
return h5, layer_outputs
`layer_outputs` stores the outputs of each convolution block and is passed into the Expansive Path (decoder).
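As a quick sanity check (a sketch assuming a single gray-scale 572x572 input), the returned shapes match the paper's figure:

enc = ContractingPath()
x = torch.randn(1, 1, 572, 572)              # one gray-scale 572x572 tile
bottleneck, skips = enc(x)
print(bottleneck.shape)                      # torch.Size([1, 1024, 28, 28])
print([tuple(s.shape[-2:]) for s in skips])  # [(568, 568), (280, 280), (136, 136), (64, 64)]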
The decoder network - Expansive Path
- An upsampling of the feature map followed by a 2x2 convolution ("up-convolution"; `nn.ConvTranspose2d`, not `torch.nn.Upsample`) that halves the number of feature channels
- A concatenation with the correspondingly cropped feature map from the contracting path (a generic crop helper is sketched after this list)
- Two 3x3 convolutions, each followed by a ReLU
- At the final layer, a 1x1 convolution maps each 64-component feature vector to the desired number of classes
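The implementation below hard-codes the crop offsets (4, 16, 40, 88 pixels) for a 572x572 input. A generic alternative, shown only as a sketch and not used in the code below, derives the offset from the size difference between the encoder feature map and the upsampled decoder feature map:

def center_crop(enc_feat, target):
    # crop the encoder feature map to the spatial size of the upsampled decoder feature map
    dh = (enc_feat.shape[-2] - target.shape[-2]) // 2
    dw = (enc_feat.shape[-1] - target.shape[-1]) // 2
    return enc_feat[..., dh:dh + target.shape[-2], dw:dw + target.shape[-1]]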
# decoder
class ExpansivePath(nn.Module):
def __init__(self) -> None:
super(ExpansivePath, self).__init__()
# they used "up-convolutions",
# -> Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2 convolution ("up-convolution")
# -> up-convolution halves the # of feature channels.
self.upConv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
self.conv1 = nn.Sequential(
nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=0),
nn.ReLU(),
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
nn.ReLU()
)
self.upConv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
self.conv2 = nn.Sequential(
nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=0),
nn.ReLU(),
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0),
nn.ReLU()
)
self.upConv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.conv3 = nn.Sequential(
nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=0),
nn.ReLU(),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0),
nn.ReLU()
)
self.upConv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.conv4 = nn.Sequential(
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=0),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
nn.ReLU()
)
self.outConv = nn.Conv2d(64, 2, kernel_size=1, stride=1)
def forward(self, enc_out, layer_outputs):
h = self.upConv1(enc_out)
        cropped = layer_outputs[-1][..., 4:4 + h.shape[-2], 4:4 + h.shape[-1]]  # center-crop 64x64 -> 56x56
h = torch.cat((h, cropped), dim=1)
h = self.conv1(h)
h = self.upConv2(h)
        cropped = layer_outputs[-2][..., 16:16 + h.shape[-2], 16:16 + h.shape[-1]]  # center-crop 136x136 -> 104x104
h = torch.cat((h, cropped), dim=1)
h = self.conv2(h)
h = self.upConv3(h)
        cropped = layer_outputs[-3][..., 40:40 + h.shape[-2], 40:40 + h.shape[-1]]  # center-crop 280x280 -> 200x200
h = torch.cat((h, cropped), dim=1)
h = self.conv3(h)
h = self.upConv4(h)
        cropped = layer_outputs[-4][..., 88:88 + h.shape[-2], 88:88 + h.shape[-1]]  # center-crop 568x568 -> 392x392
h = torch.cat((h, cropped), dim=1)
h = self.conv4(h)
output = self.outConv(h)
return output
U-Net architecture
class Unet(nn.Module):
def __init__(self) -> None:
super(Unet, self).__init__()
self.encoder = ContractingPath()
self.decoder = ExpansivePath()
def forward(self, x):
enc_out, layer_outputs = self.encoder(x)
output = self.decoder(enc_out, layer_outputs)
return output
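Putting it together (a minimal sketch with a random input), the model reproduces the paper's input/output geometry: a 572x572 tile yields a 2-channel 388x388 segmentation map.

model = Unet()
x = torch.randn(1, 1, 572, 572)
out = model(x)
print(out.shape)  # torch.Size([1, 2, 388, 388])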