Code that trains a segmentation model and evaluates it with the Dice score.
It runs on the BraTS2020 dataset.
Code
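For reference, the Dice score measures the overlap between a predicted mask and the ground-truth mask. The runner below relies on MONAI's DiceMetric, but a minimal NumPy sketch of the underlying formula (illustration only) looks like this:

import numpy as np

# Minimal sketch of the Dice score on binary masks (illustration only;
# the runner below uses MONAI's DiceMetric instead).
def dice_score(pred: np.ndarray, target: np.ndarray, eps: float = 1e-8) -> float:
    """Dice = 2 * |A ∩ B| / (|A| + |B|)."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)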
# lighting.py
import torch
import torch.nn as nn
import torch.optim
import pytorch_lightning as pl
from monai.losses.dice import DiceLoss, DiceFocalLoss
from monai.metrics import DiceMetric, ConfusionMatrixMetric, HausdorffDistanceMetric
from monai.utils.enums import MetricReduction
from monai.data.utils import decollate_batch
from torchmetrics.functional import dice, f1_score
from monai.transforms import (
    Activations,
    Compose,
    AsDiscrete,
)
from monai.inferers import sliding_window_inference
from models.apollo import Apollo
from models.cosine_anealing_warmup import CosineAnnealingWarmupRestarts
import numpy as np
from functools import partial
class AverageMeter(object):
    """Keeps a running (possibly per-class) average of metric values."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        # np.where handles per-class counts (n may be an array of not-NaN counts)
        self.avg = np.where(self.count > 0, self.sum / self.count, self.sum)
class LightningRunner(pl.LightningModule):
    def __init__(self, network, args) -> None:
        super().__init__()
        self.model = network
        self.loss = DiceFocalLoss(sigmoid=True)
        self.args = args
        # Sigmoid + 0.5 threshold: the BraTS channels (TC/WT/ET) overlap, so no argmax
        self.post_trans = Compose([Activations(sigmoid=True), AsDiscrete(argmax=False, threshold=0.5)])
        self.dice_acc = DiceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, get_not_nans=True)
        self.confusion = ConfusionMatrixMetric(include_background=True, metric_name=['sensitivity', 'specificity', 'precision'], get_not_nans=True, reduction=MetricReduction.MEAN_BATCH)
        self.hausdorff = HausdorffDistanceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, get_not_nans=True, percentile=95.0)
        self.run_acc = AverageMeter()
        # self.inferer = partial(
        #     sliding_window_inference,
        #     roi_size=[128, 128, 128],
        #     sw_batch_size=1,
        #     predictor=self.model,
        #     overlap=0.6,
        # )

    def configure_optimizers(self):
        optimizer = Apollo(params=self.parameters(), lr=self.args.init_lr, beta=0.9, eps=1e-4, rebound='constant', warmup=10, init_lr=None, weight_decay=0, weight_decay_type=None)
        lr_scheduler = CosineAnnealingWarmupRestarts(optimizer=optimizer, first_cycle_steps=100, max_lr=self.args.init_lr, min_lr=1e-7, warmup_steps=20, gamma=0.9)
        return [optimizer], [lr_scheduler]
    def training_step(self, batch, batch_idx):
        x, y = batch['image'], batch['label']
        y_hat = self.model(x)
        if isinstance(y_hat, list):  # deep-supervision models (e.g. UNet++) return a list
            y_hat = y_hat[0]
        loss = self.loss(y_hat, y)
        # Log the epoch-averaged loss to the progress bar and logger
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.args.batch_size, sync_dist=True)
        return loss

    def training_step_end(self, step_output):
        # Average the losses gathered from parallel devices
        return torch.mean(step_output)
    def validation_step(self, batch, batch_idx):
        metrics = self._shared_eval_step(batch, batch_idx)
        return metrics

    def validation_step_end(self, batch_parts):
        avg_dice = np.mean(batch_parts['avg_dice'])
        dice_tc = np.mean(batch_parts['dice_tc'])
        dice_wt = np.mean(batch_parts['dice_wt'])
        dice_et = np.mean(batch_parts['dice_et'])
        return {'avg_dice': avg_dice, 'dice_tc': dice_tc, 'dice_wt': dice_wt, 'dice_et': dice_et}

    def validation_epoch_end(self, outputs) -> None:
        avg_dice = np.mean(np.stack([output['avg_dice'] for output in outputs]))
        dice_tc = np.mean(np.stack([output['dice_tc'] for output in outputs]))
        dice_wt = np.mean(np.stack([output['dice_wt'] for output in outputs]))
        dice_et = np.mean(np.stack([output['dice_et'] for output in outputs]))
        self.log_dict({'avg_dice': avg_dice, 'dice_tc': dice_tc, 'dice_wt': dice_wt, 'dice_et': dice_et}, prog_bar=True, sync_dist=True, logger=True)
        return
    def _shared_eval_step(self, batch, batch_idx):
        x, y = batch['image'], batch['label']
        y_hat = self.model(x)
        # y_hat = self.inferer(x)
        if isinstance(y_hat, list):  # for UNet++
            y_hat = y_hat[0]
        # Split the batch into per-sample tensors before post-processing
        labels_list = [one_label for one_label in y]
        preds_list = [one_pred for one_pred in y_hat]
        preds_converted = [self.post_trans(pred) for pred in preds_list]
        self.dice_acc.reset()
        self.dice_acc(y_pred=preds_converted, y=labels_list)
        acc, not_nans = self.dice_acc.aggregate()
        self.run_acc.reset()
        self.run_acc.update(acc.cpu().numpy(), n=not_nans.cpu().numpy())
        # Channel order follows the BraTS convention: TC, WT, ET
        dice_tc = self.run_acc.avg[0]
        dice_wt = self.run_acc.avg[1]
        dice_et = self.run_acc.avg[2]
        avg_dice = np.average([dice_tc, dice_wt, dice_et])
        return {'avg_dice': avg_dice, 'dice_tc': dice_tc, 'dice_wt': dice_wt, 'dice_et': dice_et}
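To sanity-check what post_trans does to a raw network output, here is a small standalone snippet; the tensor shape is illustrative, not the actual BraTS volume size:

import torch
from monai.transforms import Activations, AsDiscrete, Compose

# Standalone check of the post-processing used above: sigmoid, then 0.5 threshold.
# The (3, 8, 8, 8) shape is illustrative; a real decollated BraTS prediction
# has shape (3, H, W, D) with channels TC, WT, ET.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
logits = torch.randn(3, 8, 8, 8)
mask = post_trans(logits)
print(mask.unique())  # values are now binary: 0.0 / 1.0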
Usage example
In the launch script, import the class, create an instance, and pass it to a Trainer as shown below.
Multi-GPU environment with a DDP strategy and 16-bit (AMP) precision.
from lighting import LightningRunner
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy
from monai.networks import nets
...
model = nets.BasicUNet(
    spatial_dims=3,
    in_channels=4,   # 4 MRI modalities: T1, T1ce, T2, FLAIR
    out_channels=3,  # 3 BraTS channels: TC, WT, ET
)
pl_runner = LightningRunner(model, args)
trainer = Trainer(
    max_epochs=args.epochs,
    devices=[2, 3],
    accelerator='gpu',
    precision=16,
    strategy=DDPStrategy(find_unused_parameters=False),
    callbacks=[lr_monitor, checkpoint_callback, lr_finder],  # defined elsewhere
    check_val_every_n_epoch=10,
)
trainer.fit(
    model=pl_runner,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
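Note that the three reported scores (dice_tc, dice_wt, dice_et) assume the raw BraTS integer labels were already converted to overlapping TC/WT/ET channels in the data pipeline. A sketch of that step using MONAI's dictionary transforms, with illustrative keys and surrounding transforms:

from monai.transforms import (
    Compose,
    LoadImaged,
    ConvertToMultiChannelBasedOnBratsClassesd,
)

# Sketch of the label preparation this runner assumes: raw BraTS labels
# (1 = NCR/NET, 2 = edema, 4 = enhancing tumor) become three overlapping
# channels TC, WT, ET. Keys and the rest of the pipeline are illustrative.
train_transform = Compose([
    LoadImaged(keys=["image", "label"]),
    ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
    # ... cropping, intensity normalization, augmentation ...
])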
For background on PyTorch Lightning itself, see the PyTorch Lightning introduction post.