
**V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation** 
*Fausto Milletari, Nassir Navab, Seyed-Ahmad Ahmadi*  
3DV 2016

Below is my PyTorch implementation of V-Net, written after reading the paper. It covers the encoder (compression path), the decoder (decompression path), and the Dice loss, following the model description in the body of the paper.

Summary: four advantages of V-Net

V-Net is a fully convolutional neural network designed for volumetric medical image segmentation. 

  1. Effective segmentation of 3D medical images: The V-Net architecture is designed specifically for 3D medical image segmentation tasks, making it highly effective at segmenting volumetric images such as CT or MRI scans.
  2. Efficient use of parameters: Similar to U-Net, V-Net utilizes skip connections to reduce the number of parameters required for segmentation, resulting in a more efficient network.
  3. Consistent performance on small datasets: V-Net is capable of producing consistent results on small datasets, which is a common issue in medical imaging due to the limited availability of annotated data.
  4. Reduced risk of overfitting: The V-Net architecture incorporates dropout and batch normalization layers to help prevent overfitting and improve the generalization performance of the network.

V-Net architecture summary

Encoder: the 'compression path'

  • Residual connections
  • 5x5x5 3D convolution kernels
  • 2x2x2 3D convolution kernels with stride 2, replacing the pooling layers
  • PReLU activation functions
  • 'Same' padding so the spatial size is preserved (cf. U-Net used no padding)
import torch
import torch.nn as nn
import torch.nn.functional as F


# Encoder: 'compression path'
# "The left part of the network consists of a compression path ..."
class CompressionPath(nn.Module):
    def __init__(self) -> None:
        super(CompressionPath, self).__init__()

        self.conv1 = nn.Sequential(
                nn.Conv3d(1, 16, kernel_size=5, stride=1, padding=2),
                nn.PReLU()
            )
        self.down1 = nn.Sequential(
                nn.Conv3d(16, 32, kernel_size=2, stride=2, padding=0),
                nn.PReLU()
            )

        self.conv2 = nn.Sequential(
                nn.Conv3d(32, 32, kernel_size=5, stride=1, padding=2),
                nn.PReLU(),
                nn.Conv3d(32, 32, kernel_size=5, stride=1, padding=2),
                nn.PReLU()
            )
        self.down2 = nn.Sequential(
                nn.Conv3d(32, 64, kernel_size=2, stride=2, padding=0),
                nn.PReLU()
            )
        
        self.conv3 = nn.Sequential(
                nn.Conv3d(64, 64, kernel_size=5, stride=1, padding=2),
                nn.PReLU(),
                nn.Conv3d(64, 64, kernel_size=5, stride=1, padding=2),
                nn.PReLU(),
                nn.Conv3d(64, 64, kernel_size=5, stride=1, padding=2),
                nn.PReLU()
            )
        self.down3 = nn.Sequential(
                nn.Conv3d(64, 128, kernel_size=2, stride=2, padding=0),
                nn.PReLU()
            )

        self.conv4 = nn.Sequential(
                nn.Conv3d(128, 128, kernel_size=5, stride=1, padding=2),
                nn.PReLU(),
                nn.Conv3d(128, 128, kernel_size=5, stride=1, padding=2),
                nn.PReLU(),
                nn.Conv3d(128, 128, kernel_size=5, stride=1, padding=2),
                nn.PReLU()
            )
        self.down4 = nn.Sequential(
                nn.Conv3d(128, 256, kernel_size=2, stride=2, padding=0),
                nn.PReLU()
            )

        self.conv5 = nn.Sequential(
                nn.Conv3d(256, 256, kernel_size=5, stride=1, padding=2),
                nn.PReLU(),
                nn.Conv3d(256, 256, kernel_size=5, stride=1, padding=2),
                nn.PReLU(),
                nn.Conv3d(256, 256, kernel_size=5, stride=1, padding=2),
                nn.PReLU()
            )

        

    def forward(self, x):

        id1 = x                 # the 1-channel input broadcasts over the 16 feature maps
        h1  = self.conv1(x)
        h1  = h1 + id1          # residual connection within each stage
        d1  = self.down1(h1)

        id2 = d1
        h2  = self.conv2(d1)
        h2  = h2 + id2
        d2  = self.down2(h2)

        id3 = d2 
        h3  = self.conv3(d2)
        h3  = h3 + id3 
        d3  = self.down3(h3)

        id4 = d3 
        h4  = self.conv4(d3)
        h4  = h4 + id4
        d4  = self.down4(h4)

        id5 = d4
        h5  = self.conv5(d4)
        h5  = h5 + id5

        stage_outputs = [h1, h2, h3, h4] # forward the features extracted from early stages of the left part of the CNN to the right part

        return h5, stage_outputs
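
A minimal sanity check of the encoder's output shapes, assuming a hypothetical 64x64x64 input volume (any cube whose sides are divisible by 16 works):

x = torch.randn(1, 1, 64, 64, 64)
encoder = CompressionPath()
out, skips = encoder(x)
print(out.shape)                     # torch.Size([1, 256, 4, 4, 4])
print([s.shape[1] for s in skips])   # [16, 32, 64, 128]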

Decoder: the 'decompression path'

  • Residual connections
  • Concatenation with the output of the corresponding encoder stage ('fine-grained feature forwarding')
  • 5x5x5 3D convolution kernels
  • 2x2x2 3D transposed convolution (de-convolution) kernels with stride 2 for upsampling
  • PReLU activation functions
  • 'Same' padding so the spatial size is preserved (cf. U-Net used no padding)
# decoder; 'decompression path'
# while the right part decompresses the signal until its original size is reached.
class DecompressionPath(nn.Module):
    def __init__(self) -> None:
        super(DecompressionPath, self).__init__()

        self.up1 = nn.Sequential(
                nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2, padding=0),
                nn.PReLU()
            )
        self.conv1 = nn.Sequential(
                nn.Conv3d(256, 256, kernel_size=5, stride=1, padding=2),
                nn.PReLU(),
                nn.Conv3d(256, 256, kernel_size=5, stride=1, padding=2),
                nn.PReLU(),
                nn.Conv3d(256, 256, kernel_size=5, stride=1, padding=2),
                nn.PReLU()
            )

        self.up2 = nn.Sequential(
                # the previous stage outputs 256 channels, not 128
                nn.ConvTranspose3d(256, 64, kernel_size=2, stride=2, padding=0),
                nn.PReLU()
            )
        self.conv2 = nn.Sequential(
                nn.Conv3d(128, 128, kernel_size=5, stride=1, padding=2),
                nn.PReLU(),
                nn.Conv3d(128, 128, kernel_size=5, stride=1, padding=2),
                nn.PReLU(),
                nn.Conv3d(128, 128, kernel_size=5, stride=1, padding=2),
                nn.PReLU()
            )
        
        self.up3 = nn.Sequential(
                # the previous stage outputs 128 channels, not 64
                nn.ConvTranspose3d(128, 32, kernel_size=2, stride=2, padding=0),
                nn.PReLU()
            )
        self.conv3 = nn.Sequential(
                nn.Conv3d(64, 64, kernel_size=5, stride=1, padding=2),
                nn.PReLU(),
                nn.Conv3d(64, 64, kernel_size=5, stride=1, padding=2),
                nn.PReLU(),
                nn.Conv3d(64, 64, kernel_size=5, stride=1, padding=2),
                nn.PReLU()
            )

        self.up4 = nn.Sequential(
                # the previous stage outputs 64 channels, not 32
                nn.ConvTranspose3d(64, 16, kernel_size=2, stride=2, padding=0),
                nn.PReLU()
            )
        self.conv4 = nn.Sequential(
                nn.Conv3d(32, 32, kernel_size=5, stride=1, padding=2),
                nn.PReLU(),
                nn.Conv3d(32, 32, kernel_size=5, stride=1, padding=2),
                nn.PReLU(),
                nn.Conv3d(32, 32, kernel_size=5, stride=1, padding=2),
                nn.PReLU()
            )

        # final 1x1x1 convolution maps to two output channels
        # (foreground / background), as in the paper; it emits raw logits
        self.conv5 = nn.Conv3d(32, 2, kernel_size=1, stride=1, padding=0)

    def forward(self, enc_out, stage_outputs):

        u1  = self.up1(enc_out)
        h1  = torch.cat([u1, stage_outputs[-1]], dim=1)  # fine-grained feature forwarding
        id1 = h1       # identity taken after the concat so the channel counts match
        h1  = self.conv1(h1)
        h1  = h1 + id1

        u2  = self.up2(h1)
        h2  = torch.cat([u2, stage_outputs[-2]], dim=1)
        id2 = h2
        h2  = self.conv2(h2)
        h2  = h2 + id2

        u3  = self.up3(h2)
        h3  = torch.cat([u3, stage_outputs[-3]], dim=1)
        id3 = h3
        h3  = self.conv3(h3)
        h3  = h3 + id3

        u4  = self.up4(h3)
        h4  = torch.cat([u4, stage_outputs[-4]], dim=1)
        id4 = h4
        h4  = self.conv4(h4)
        h4  = h4 + id4

        out = self.conv5(h4)   # per-voxel class logits

        return out
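
Continuing the shape check from the encoder section (same hypothetical 64x64x64 input), the decoder restores the original resolution:

decoder = DecompressionPath()
seg = decoder(out, skips)
print(seg.shape)   # torch.Size([1, 2, 64, 64, 64]): input size recovered, 2-channel logits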

Implementation of V-Net

An encoder-decoder network architecture combining the two paths above:

class VNet(nn.Module):
    def __init__(self) -> None:
        super(VNet, self).__init__()

        self.encoder = CompressionPath()
        self.decoder = DecompressionPath()

    
    def forward(self, x, activate=False):

        enc_out, stage_outputs = self.encoder(x)
        dec_out = self.decoder(enc_out, stage_outputs)

        if activate:
            # voxel-wise class probabilities via softmax over the channel dimension
            return F.softmax(dec_out, dim=1)
        return dec_out
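
A quick end-to-end check of the assembled network, again with a hypothetical 64x64x64 input:

model  = VNet()
x      = torch.randn(1, 1, 64, 64, 64)
logits = model(x)                   # (1, 2, 64, 64, 64), raw per-voxel scores
probs  = model(x, activate=True)    # softmax over the channel dimension
print(probs.sum(dim=1).mean())      # ~1.0: probabilities sum to one per voxel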

Dice Loss

The Dice coefficient D between the predicted binary volume and the ground-truth volume can be written as:

D = \frac{2 \sum_{i}^{N} p_i g_i}{\sum_{i}^{N} p_i^2 + \sum_{i}^{N} g_i^2}

where the sums run over the N voxels, p_i is the predicted probability for voxel i, and g_i ∈ {0, 1} is its ground-truth label.

# Dice loss: 1 - Dice coefficient, computed on soft predictions
def DiceLoss(inputs, targets, smooth=1.0):

    # map raw logits to probabilities (F.sigmoid is deprecated in favor of torch.sigmoid)
    inputs  = torch.sigmoid(inputs)

    # flatten prediction and ground truth to 1D
    inputs  = inputs.view(-1)
    targets = targets.view(-1)

    # squared-denominator form of the Dice coefficient, as in the paper;
    # `smooth` avoids division by zero on empty masks
    dice_coef = (2.0 * (inputs * targets).sum() + smooth) / ((inputs ** 2).sum() + (targets ** 2).sum() + smooth)

    return 1 - dice_coef
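
A minimal sketch of how the loss might plug into a training step. The shapes and optimizer settings are hypothetical; with the two-channel softmax head above, the foreground-channel logits are what the sigmoid-based loss expects:

model     = VNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

x = torch.randn(2, 1, 64, 64, 64)                    # batch of input volumes
y = torch.randint(0, 2, (2, 1, 64, 64, 64)).float()  # binary ground-truth masks

optimizer.zero_grad()
logits = model(x)                      # (2, 2, 64, 64, 64)
loss   = DiceLoss(logits[:, 1:2], y)   # foreground channel vs. binary mask
loss.backward()
optimizer.step()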