
**Attention U-Net: Learning Where to Look for the Pancreas**    
*Ozan Oktay, et al.*  
MIDL 2018  


This is a PyTorch implementation of Attention U-Net. I wrote it while reading the paper, and got stuck for a while on the gating signal $g$.

I implemented the Attention U-Net encoder, decoder, and the AttentionGate (softmax/additive attention).

*Deep supervision: in the paper, it is applied by upsampling all the stage outputs (including the encoder output) to the segmentation map size and then concatenating them channel-wise; I did not implement it here.

 

The question I had: the tensor used to obtain the gating signal $g$ has size $F_{l+1} \times H_{l+1} \times W_{l+1} \times D_{l+1}$, while each encoder stage output fed into the attention gate has size $F_{l} \times H_{l} \times W_{l} \times D_{l}$ (with $2 F_{l} = F_{l+1}$), so the 3D volumes differ in size and the operation cannot be performed on them directly... yet the paper states this without any particular comment.

Looking at existing implementations, the larger 3D feature map is brought down to the matching size before the operation with a 2x2x2 kernel with stride 2. Apart from that, it is the same as the 2D attention gate.
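A minimal shape check of that size matching (the channel and spatial sizes here are just an illustrative example):

import torch
import torch.nn as nn

# Skip feature x_l and the coarser-scale tensor used for the gating signal g
x_l = torch.randn(1, 128, 16, 16, 16)   # F_l x H_l x W_l x D_l
g   = torch.randn(1, 256,  8,  8,  8)   # F_{l+1} x H_{l+1} x W_{l+1} x D_{l+1}

# A 2x2x2 convolution with stride 2 halves the spatial size of x_l so it lines up with g.
w_x = nn.Conv3d(128, 128, kernel_size=2, stride=2, bias=False)
w_g = nn.Conv3d(256, 128, kernel_size=1, stride=1)

print(w_x(x_l).shape)   # torch.Size([1, 128, 8, 8, 8])
print(w_g(g).shape)     # torch.Size([1, 128, 8, 8, 8])  -> the two can now be added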

 


5 advantages of the Attention U-Net

Attention U-Net is a modified version of the U-Net architecture that includes attention mechanisms for medical image segmentation.

  1. Improved segmentation accuracy: The attention mechanism in the Attention U-Net allows the network to focus on the most relevant features in the image, resulting in improved segmentation accuracy.
  2. Reduced false positives and false negatives: By focusing on the most informative parts of the image, the Attention U-Net can reduce the number of false positives and false negatives in the segmentation output.
  3. Robust to variations in image quality: The Attention U-Net has been shown to be robust to variations in image quality, such as noise or artifacts in the image.
  4. Better visualization of segmentation results: The attention mechanism in the Attention U-Net produces attention maps that can be visualized to better understand which parts of the image the network is focusing on during segmentation.
  5. Transferability to other tasks: The Attention U-Net can be trained on one dataset and applied to another with minimal modification, making it a highly transferable architecture for medical image segmentation tasks.

Attention U-Net architecture summary

 

 

Basic operation modules

3D Convolution

  • 3x3x3 kernel size convolution
  • batch normalisation
  • ReLU activation
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvLayer(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=None) -> None:
        super(ConvLayer, self).__init__()

        if not hidden_dim:
            hidden_dim = out_dim

        # Two (3x3x3 Conv3d -> BatchNorm3d -> ReLU) blocks; padding=1 keeps the spatial size.
        self.conv = nn.Sequential(
                nn.Conv3d(in_dim, hidden_dim, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm3d(hidden_dim),
                nn.ReLU(),
                nn.Conv3d(hidden_dim, out_dim, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm3d(out_dim),
                nn.ReLU()
            )

    def forward(self, x):
        return self.conv(x)
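A quick shape check (the sizes are just for illustration): because the 3x3x3 convolutions use padding=1, only the channel count changes.

block = ConvLayer(in_dim=1, out_dim=32)
x = torch.randn(1, 1, 16, 16, 16)
print(block(x).shape)   # torch.Size([1, 32, 16, 16, 16]) -- spatial size preserved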

UpsamplingLayer

class UpsamplingLayer(nn.Module):
    def __init__(self, in_dim, out_dim, is_deconv=True) -> None:
        super(UpsamplingLayer, self).__init__()

        if is_deconv:
            # Transposed convolution: doubles the spatial size and maps in_dim -> out_dim channels.
            self.upsampler = nn.ConvTranspose3d(in_dim, out_dim, kernel_size=2, stride=2, padding=0)
        else:
            # Trilinear interpolation: doubles the spatial size but keeps the channel count,
            # so a 1x1x1 convolution would be needed afterwards to reach out_dim channels.
            self.upsampler = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)

    def forward(self, x):
        return self.upsampler(x)
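And the default transposed-convolution branch, again with assumed sizes:

up = UpsamplingLayer(in_dim=64, out_dim=32)   # is_deconv=True by default
x = torch.randn(1, 64, 8, 8, 8)
print(up(x).shape)   # torch.Size([1, 32, 16, 16, 16]) -- spatial size doubled, channels in_dim -> out_dim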

 

 

 

 

Encoder 

The flow is the same as U-Net; apart from the pooling, there is no difference from a V-Net implementation.

(Attention U-Net: 2x2x2 max pooling // V-Net: pooling implemented with a strided convolution kernel; a sketch of the V-Net-style alternative follows the encoder code below.)

  • The input image is progressively filtered and downsampled by a factor of 2 at each scale in the encoding part of the network
  • 2x2x2 max pooling
# Encoder
class AttnUNet_Encoder(nn.Module):
    def __init__(self) -> None:
        super(AttnUNet_Encoder, self).__init__()

        # Channel progression: 1 -> 32 -> 64 -> 128 -> 256
        self.hidden_dim = 32

        self.conv1 = ConvLayer(1, self.hidden_dim)

        self.hidden_dim *= 2
        self.conv2 = ConvLayer(self.hidden_dim//2, self.hidden_dim)

        self.hidden_dim *= 2
        self.conv3 = ConvLayer(self.hidden_dim//2, self.hidden_dim)

        self.hidden_dim *= 2
        self.conv4 = ConvLayer(self.hidden_dim//2, self.hidden_dim)

        # 2x2x2 max pooling between scales
        self.pool = nn.MaxPool3d(2, 2)

    def forward(self, x):

        h1  = self.conv1(x)
        p1  = self.pool(h1)

        h2  = self.conv2(p1)
        p2  = self.pool(h2)

        h3  = self.conv3(p2)
        p3  = self.pool(h3)

        h4  = self.conv4(p3)

        # Skip connections handed to the attention gates in the decoder
        stage_outputs = [h1, h2, h3]

        return h4, stage_outputs
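For comparison, a minimal sketch of the V-Net-style alternative mentioned above, where the parameter-free 2x2x2 max pool is swapped for a strided convolution (my own illustration, not part of the paper's model):

class StridedDownsample(nn.Module):
    def __init__(self, in_dim, out_dim) -> None:
        super(StridedDownsample, self).__init__()
        # 2x2x2 convolution with stride 2: halves the spatial size like the max pool,
        # but with learned weights (and it can change the channel count at the same time).
        self.down = nn.Sequential(
                nn.Conv3d(in_dim, out_dim, kernel_size=2, stride=2, padding=0),
                nn.BatchNorm3d(out_dim),
                nn.ReLU()
            )

    def forward(self, x):
        return self.down(x)

Because it has parameters, one instance would be needed per scale, unlike the single shared nn.MaxPool3d above.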

 

 

Attention Gate

Attention gate description from the paper
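As I recall the paper's additive attention formulation (reconstructed from memory, so the notation may differ slightly from the original; $\sigma_1$ is a ReLU and $\sigma_2$ is the sigmoid):

$$ q_{att}^{l} = \psi^{T}\left(\sigma_1\left(W_x^{T} x_i^{l} + W_g^{T} g_i + b_g\right)\right) + b_{\psi} $$

$$ \alpha_i^{l} = \sigma_2\left(q_{att}^{l}\right), \qquad \hat{x}_i^{l} = \alpha_i^{l} \cdot x_i^{l} $$

$W_x$, $W_g$, and $\psi$ correspond to w_x, w_g, and psi in the module below, and $\hat{x}_i^{l}$ is the gated skip connection that the decoder concatenates with the upsampled features.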

 

class AttentionGate(nn.Module):
    def __init__(self, in_dim, coarser_dim, hidden_dim) -> None:
        super(AttentionGate, self).__init__()

        # Gating signal generator: 1x1x1 conv + BN + ReLU on the coarser-scale features.
        self.GridGateSignal_generator = nn.Sequential(
                nn.Conv3d(coarser_dim, coarser_dim, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm3d(coarser_dim),
                nn.ReLU()
            )

        # input feature x // the gating signal from a coarser scale (coarser_dim == in_dim*2)
        # w_x uses a 2x2x2 stride-2 kernel so the skip features match the gating signal's spatial size.
        self.w_x = nn.Conv3d(in_dim, hidden_dim, kernel_size=(2,2,2), stride=(2,2,2), padding=0, bias=False)
        self.w_g = nn.Conv3d(coarser_dim, hidden_dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.psi = nn.Conv3d(hidden_dim, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, inputs, coarser):
        query = self.GridGateSignal_generator(coarser)

        proj_x = self.w_x(inputs)
        proj_g = self.w_g(query)

        additive  = F.relu(proj_x + proj_g)
        attn_coef = torch.sigmoid(self.psi(additive))

        # Resample the coefficients back to the skip feature's spatial size before gating.
        attn_coef = F.interpolate(attn_coef, size=inputs.size()[2:], mode='trilinear', align_corners=False)

        return attn_coef
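A quick shape check with the dimensions used at the coarsest decoder stage (spatial sizes are illustrative):

gate = AttentionGate(in_dim=128, coarser_dim=256, hidden_dim=128)
skip = torch.randn(1, 128, 8, 8, 8)   # encoder stage output x_l
g    = torch.randn(1, 256, 4, 4, 4)   # coarser-scale features used for the gating signal

coef = gate(skip, g)
print(coef.shape)   # torch.Size([1, 1, 8, 8, 8]); broadcast-multiplied onto `skip` in the decoder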

 

 

Decoder

# Decoder
class AttnUNet_Decoder(nn.Module):
    def __init__(self) -> None:
        super(AttnUNet_Decoder, self).__init__()

        self.hidden_dim = 32
        self.out_dim    = 2

        self.attn1  = AttentionGate(self.hidden_dim, self.hidden_dim*2, self.hidden_dim)
        self.up1    = UpsamplingLayer(self.hidden_dim*2, self.hidden_dim)
        self.conv1  = ConvLayer(self.hidden_dim*2, self.out_dim, self.hidden_dim)

        self.hidden_dim *= 2 # 64
        self.attn2  = AttentionGate(self.hidden_dim, self.hidden_dim*2, self.hidden_dim)
        self.up2    = UpsamplingLayer(self.hidden_dim*2, self.hidden_dim)
        self.conv2  = ConvLayer(self.hidden_dim*2, self.hidden_dim, self.hidden_dim)

        self.hidden_dim *= 2 # 128
        self.attn3  = AttentionGate(self.hidden_dim, self.hidden_dim*2, self.hidden_dim)
        self.up3    = UpsamplingLayer(self.hidden_dim*2, self.hidden_dim)
        self.conv3  = ConvLayer(self.hidden_dim*2, self.hidden_dim, self.hidden_dim)

    
    def forward(self, enc_out, stage_outputs):

        # At each scale: gate the encoder skip features with the coarser decoder features,
        # upsample the coarser features, concatenate, and fuse with a ConvLayer.
        attn_g3 = self.attn3(stage_outputs[-1], enc_out) * stage_outputs[-1]
        h3      = self.up3(enc_out)
        h3      = torch.concat([attn_g3, h3], dim=1)
        h3      = self.conv3(h3)

        attn_g2 = self.attn2(stage_outputs[-2], h3) * stage_outputs[-2]
        h2      = self.up2(h3)
        h2      = torch.concat([attn_g2, h2], dim=1)
        h2      = self.conv2(h2)

        attn_g1 = self.attn1(stage_outputs[-3], h2) * stage_outputs[-3]
        h1      = self.up1(h2)
        h1      = torch.concat([attn_g1, h1], dim=1)
        h1      = self.conv1(h1)

        return h1

 

 

Attention U-Net

class Attn_UNet(nn.Module):
    def __init__(self) -> None:
        super(Attn_UNet, self).__init__()

        self.encoder = AttnUNet_Encoder()
        self.decoder = AttnUNet_Decoder()

    def forward(self, inputs):

        enc_out, stage_outputs = self.encoder(inputs)
        out = self.decoder(enc_out, stage_outputs)

        return out
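A minimal end-to-end sanity check (the input size is my own choice; it just has to be divisible by 8 so the three pooling steps work out):

model = Attn_UNet()
x = torch.randn(1, 1, 32, 32, 32)   # (batch, channels, H, W, D), single-channel volume
out = model(x)
print(out.shape)   # torch.Size([1, 2, 32, 32, 32]) -- two-channel segmentation logits at full resolution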

 
