**V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation**
*Fausto Milletari, Nassir Navab, Seyed-Ahmad Ahmadi*
3DV 2016
This is a PyTorch implementation written after reading the V-Net paper. It covers the encoder (compression path), the decoder (decompression path), and the Dice loss, following the model description in the body of the paper.
Summary: four advantages of V-Net
V-Net is a fully convolutional neural network designed for volumetric medical image segmentation.
- Effective segmentation of 3D medical images: The V-Net architecture is designed specifically for 3D medical image segmentation tasks, making it highly effective at segmenting volumetric images such as CT or MRI scans.
- Efficient use of parameters: Similar to U-Net, V-Net utilizes skip connections to reduce the number of parameters required for segmentation, resulting in a more efficient network.
- Consistent performance on small datasets: V-Net is capable of producing consistent results on small datasets, which is a common issue in medical imaging due to the limited availability of annotated data.
- Reduced risk of overfitting: The V-Net architecture incorporates dropout and batch normalization layers to help prevent overfitting and improve the generalization performance of the network.
Encoder - 'compression path'
- Residual connections
- 5x5x5 3D convolution kernels
- 2x2x2 3D convolution kernels with stride 2, replacing the pooling layers
- PReLU activation functions
- Appropriate padding (cf. U-Net did not use padding)
import torch
import torch.nn as nn
import torch.nn.functional as F

# Encoder; 'compression path'
# "The left part of the network consists of a compression path, ..."
class CompressionPath(nn.Module):
    def __init__(self) -> None:
        super(CompressionPath, self).__init__()
        # stage 1: one 5x5x5 convolution, 1 -> 16 channels
        self.conv1 = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.PReLU()
        )
        # down-convolution: 2x2x2 kernel with stride 2 replaces pooling
        self.down1 = nn.Sequential(
            nn.Conv3d(16, 32, kernel_size=2, stride=2, padding=0),
            nn.PReLU()
        )
        # stage 2: two 5x5x5 convolutions
        self.conv2 = nn.Sequential(
            nn.Conv3d(32, 32, kernel_size=5, stride=1, padding=2),
            nn.PReLU(),
            nn.Conv3d(32, 32, kernel_size=5, stride=1, padding=2),
            nn.PReLU()
        )
        self.down2 = nn.Sequential(
            nn.Conv3d(32, 64, kernel_size=2, stride=2, padding=0),
            nn.PReLU()
        )
        # stages 3-5: three 5x5x5 convolutions each
        self.conv3 = nn.Sequential(
            nn.Conv3d(64, 64, kernel_size=5, stride=1, padding=2),
            nn.PReLU(),
            nn.Conv3d(64, 64, kernel_size=5, stride=1, padding=2),
            nn.PReLU(),
            nn.Conv3d(64, 64, kernel_size=5, stride=1, padding=2),
            nn.PReLU()
        )
        self.down3 = nn.Sequential(
            nn.Conv3d(64, 128, kernel_size=2, stride=2, padding=0),
            nn.PReLU()
        )
        self.conv4 = nn.Sequential(
            nn.Conv3d(128, 128, kernel_size=5, stride=1, padding=2),
            nn.PReLU(),
            nn.Conv3d(128, 128, kernel_size=5, stride=1, padding=2),
            nn.PReLU(),
            nn.Conv3d(128, 128, kernel_size=5, stride=1, padding=2),
            nn.PReLU()
        )
        self.down4 = nn.Sequential(
            nn.Conv3d(128, 256, kernel_size=2, stride=2, padding=0),
            nn.PReLU()
        )
        self.conv5 = nn.Sequential(
            nn.Conv3d(256, 256, kernel_size=5, stride=1, padding=2),
            nn.PReLU(),
            nn.Conv3d(256, 256, kernel_size=5, stride=1, padding=2),
            nn.PReLU(),
            nn.Conv3d(256, 256, kernel_size=5, stride=1, padding=2),
            nn.PReLU()
        )

    def forward(self, x):
        # stage 1 residual: the single input channel broadcasts across the
        # 16 feature maps, equivalent to the input tiling used in the paper
        id1 = x
        h1 = self.conv1(x)
        h1 = h1 + id1
        d1 = self.down1(h1)
        id2 = d1
        h2 = self.conv2(d1)
        h2 = h2 + id2
        d2 = self.down2(h2)
        id3 = d2
        h3 = self.conv3(d2)
        h3 = h3 + id3
        d3 = self.down3(h3)
        id4 = d3
        h4 = self.conv4(d3)
        h4 = h4 + id4
        d4 = self.down4(h4)
        id5 = d4
        h5 = self.conv5(d4)
        h5 = h5 + id5
        stage_outputs = [h1, h2, h3, h4]  # forward the features extracted from early stages of the left part of the CNN to the right part
        return h5, stage_outputs
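To sanity-check the channel and spatial bookkeeping, here is a minimal sketch assuming a single-channel 64x64x64 input volume (any cube whose side is divisible by 16 works, since the path downsamples four times):

# quick shape check for the encoder (hypothetical 64^3 input)
encoder = CompressionPath()
x = torch.randn(1, 1, 64, 64, 64)    # (batch, channels, D, H, W)
bottleneck, skips = encoder(x)
print(bottleneck.shape)              # torch.Size([1, 256, 4, 4, 4])
print([s.shape[1] for s in skips])   # [16, 32, 64, 128]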
Decoder - 'decompression path'
- Residual connections
- Concatenation with the output of the corresponding encoder stage ('fine-grained feature forwarding')
- 5x5x5 3D convolution kernels
- 2x2x2 3D transposed convolution (deconvolution) kernels with stride 2 for upsampling
- PReLU activation functions
- Appropriate padding (cf. U-Net did not use padding)
# decoder; 'decompression path'
# "... while the right part decompresses the signal until its original size is reached."
class DecompressionPath(nn.Module):
    def __init__(self) -> None:
        super(DecompressionPath, self).__init__()
        # up-convolution: 2x2x2 transposed conv with stride 2; concatenating
        # the 128-channel encoder features restores 256 channels
        self.up1 = nn.Sequential(
            nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2, padding=0),
            nn.PReLU()
        )
        self.conv1 = nn.Sequential(
            nn.Conv3d(256, 256, kernel_size=5, stride=1, padding=2),
            nn.PReLU(),
            nn.Conv3d(256, 256, kernel_size=5, stride=1, padding=2),
            nn.PReLU(),
            nn.Conv3d(256, 256, kernel_size=5, stride=1, padding=2),
            nn.PReLU()
        )
        # note the in-channels: each stage output carries twice the skip
        # channels, so up2 maps 256 -> 64, up3 128 -> 32, up4 64 -> 16
        self.up2 = nn.Sequential(
            nn.ConvTranspose3d(256, 64, kernel_size=2, stride=2, padding=0),
            nn.PReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv3d(128, 128, kernel_size=5, stride=1, padding=2),
            nn.PReLU(),
            nn.Conv3d(128, 128, kernel_size=5, stride=1, padding=2),
            nn.PReLU(),
            nn.Conv3d(128, 128, kernel_size=5, stride=1, padding=2),
            nn.PReLU()
        )
        self.up3 = nn.Sequential(
            nn.ConvTranspose3d(128, 32, kernel_size=2, stride=2, padding=0),
            nn.PReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv3d(64, 64, kernel_size=5, stride=1, padding=2),
            nn.PReLU(),
            nn.Conv3d(64, 64, kernel_size=5, stride=1, padding=2),
            nn.PReLU(),
            nn.Conv3d(64, 64, kernel_size=5, stride=1, padding=2),
            nn.PReLU()
        )
        self.up4 = nn.Sequential(
            nn.ConvTranspose3d(64, 16, kernel_size=2, stride=2, padding=0),
            nn.PReLU()
        )
        self.conv4 = nn.Sequential(
            nn.Conv3d(32, 32, kernel_size=5, stride=1, padding=2),
            nn.PReLU(),
            nn.Conv3d(32, 32, kernel_size=5, stride=1, padding=2),
            nn.PReLU(),
            nn.Conv3d(32, 32, kernel_size=5, stride=1, padding=2),
            nn.PReLU()
        )
        # final 1x1x1 convolution producing the two output feature maps
        # (foreground / background) that the soft-max turns into probabilities
        self.conv5 = nn.Conv3d(32, 2, kernel_size=1, stride=1, padding=0)

    def forward(self, enc_out, stage_outputs):
        # at each stage: upsample, concatenate the forwarded fine-grained
        # encoder features, then use the concatenated tensor as the residual
        # identity so the channel counts match
        u1 = self.up1(enc_out)
        h1 = torch.cat([u1, stage_outputs[-1]], dim=1)
        id1 = h1
        h1 = self.conv1(h1) + id1
        u2 = self.up2(h1)
        h2 = torch.cat([u2, stage_outputs[-2]], dim=1)
        id2 = h2
        h2 = self.conv2(h2) + id2
        u3 = self.up3(h2)
        h3 = torch.cat([u3, stage_outputs[-3]], dim=1)
        id3 = h3
        h3 = self.conv3(h3) + id3
        u4 = self.up4(h3)
        h4 = torch.cat([u4, stage_outputs[-4]], dim=1)
        id4 = h4
        h4 = self.conv4(h4) + id4
        out = self.conv5(h4)
        return out
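The decoder can be checked in isolation by feeding tensors with the shapes the encoder would produce; this is a hypothetical sketch that again assumes a 64^3 input volume:

# quick shape check for the decoder (fake encoder outputs for a 64^3 volume)
decoder = DecompressionPath()
enc_out = torch.randn(1, 256, 4, 4, 4)
skips = [torch.randn(1, 16, 64, 64, 64),
         torch.randn(1, 32, 32, 32, 32),
         torch.randn(1, 64, 16, 16, 16),
         torch.randn(1, 128, 8, 8, 8)]
print(decoder(enc_out, skips).shape)  # torch.Size([1, 2, 64, 64, 64])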
Implementation of V-Net
The full V-Net combines the compression and decompression paths into an encoder-decoder architecture.
class VNet(nn.Module):
    def __init__(self) -> None:
        super(VNet, self).__init__()
        self.encoder = CompressionPath()
        self.decoder = DecompressionPath()

    def forward(self, x, activate=False):
        enc_out, stage_outputs = self.encoder(x)
        dec_out = self.decoder(enc_out, stage_outputs)
        if activate:
            # softmax over the two output channels gives per-voxel
            # foreground/background probabilities
            return F.softmax(dec_out, dim=1)
        return dec_out
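Putting the two paths together, an end-to-end forward pass looks like the following sketch (the input size is an assumption for illustration, not a requirement from the paper):

# end-to-end sanity check (hypothetical input size)
model = VNet()
volume = torch.randn(1, 1, 64, 64, 64)
logits = model(volume)                # raw scores, shape (1, 2, 64, 64, 64)
probs = model(volume, activate=True)  # per-voxel softmax probabilities
print(logits.shape, probs.sum(dim=1).mean())  # channel probabilities sum to 1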
Dice Loss
The Dice coefficient $D$ between two binary volumes, computed over the $N$ predicted voxel values $p_i$ and the ground-truth voxel values $g_i$, can be written as:

$$D = \frac{2\sum_i^N p_i g_i}{\sum_i^N p_i^2 + \sum_i^N g_i^2}$$

The loss is $1 - D$; the implementation below adds a small smoothing constant to the numerator and denominator to avoid division by zero.
# dice loss
def DiceLoss(inputs, targets, smooth=1):
    # sigmoid assumes single-channel logits; for the two-channel V-Net
    # output, pass the foreground channel (or use the softmax probability)
    inputs = torch.sigmoid(inputs)
    # reshape rather than view so non-contiguous (sliced) inputs also work
    inputs = inputs.reshape(-1)
    targets = targets.reshape(-1)
    # squared terms in the denominator follow the formulation in the paper
    dice_coef = (2.0 * (inputs * targets).sum() + smooth) / ((inputs ** 2).sum() + (targets ** 2).sum() + smooth)
    return 1 - dice_coef
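As a usage example, a minimal training step could combine the model's foreground channel with DiceLoss. This is a hypothetical sketch: the optimizer choice, learning rate, and random data are placeholders, not values from the paper.

# hypothetical training step: optimizer settings and data are placeholders
model = VNet()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
volume = torch.randn(2, 1, 64, 64, 64)
mask = (torch.rand(2, 1, 64, 64, 64) > 0.5).float()  # binary ground truth
logits = model(volume)[:, 1:2]  # foreground-channel logits for the Dice loss
loss = DiceLoss(logits, mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())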