train_load = torch.utils.data.DataLoader(dataset=train, num_workers=0,
                                         batch_size=4, shuffle=True)
val_load = torch.utils.data.DataLoader(dataset=val, num_workers=0,
                                       batch_size=1, shuffle=False)
# Dataloader end

# Model
model = NestedUNet(in_channels=1, out_channels=5)
# model = U_Net()
model = torch.nn.DataParallel(
    model, device_ids=list(range(torch.cuda.device_count()))).cuda()
# model.load_state_dict(torch.load("E:\\yuanxiaohan\\Cardic_segmentation\\my project\\LV_seg_all\\Histories\\02\\saved_models\\model_epoch_checkpoint_60.pth"))

# Loss function
criterion = LovaszLossSoftmax()

# Optimizer
# optimizer = torch.optim.RMSprop(model.module.parameters(), lr=1e-4)
optimizer = torch.optim.Adam(model.module.parameters(), lr=1e-4,
                             betas=(0.9, 0.999), eps=1e-8, weight_decay=0)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=[20, 50],
                                                 gamma=0.9)

# Parameters
epoch_start = 0
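# Optional resume sketch (not in the original script): restores the weights
# that the commented-out load_state_dict above pointed at. It assumes the
# file is a plain state_dict saved from the DataParallel-wrapped model;
# `resume_path` and the epoch value are illustrative assumptions.
resume_path = None  # e.g. the checkpoint path commented out above
if resume_path is not None:
    model.load_state_dict(torch.load(resume_path, map_location='cuda'))
    epoch_start = 60  # hypothetical: match the epoch encoded in the filename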
class SoftDiceLoss(nn.Module):
    # Reconstructed class header: the __init__ defaults for `smooth` and
    # `dims` are assumptions inferred from the no-argument call
    # `SoftDiceLoss()` below, not taken from the original file.
    def __init__(self, smooth=1e-8, dims=(-2, -1)):
        super().__init__()
        self.smooth = smooth
        self.dims = dims

    def forward(self, x, y):
        tp = (x * y).sum(self.dims)
        fp = (x * (1 - y)).sum(self.dims)
        fn = ((1 - x) * y).sum(self.dims)
        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)
        dc = dc.mean()
        return 1 - dc


bce_fn = nn.BCEWithLogitsLoss()
# bce_fn = nn.BCELoss()
dice_fn = SoftDiceLoss()
cross_loss = nn.CrossEntropyLoss()
loss_f = LovaszLossSoftmax()


def loss_fn(y_pred, y_true, ratio=0.8, hard=False):
    bce = bce_fn(y_pred, y_true)
    if hard:
        # The original applied .float() before thresholding, handing Dice a
        # bool tensor; threshold first, then cast, so Dice sees 0/1 floats.
        dice = dice_fn((y_pred.sigmoid() > 0.5).float(), y_true)
    else:
        dice = dice_fn(y_pred.sigmoid(), y_true)
    return ratio * bce + (1 - ratio) * dice


EPOCHES = 70
BATCH_SIZE = 8
NUM_WORKERS = 4
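# Quick sanity check of the combined loss (a sketch, not from the original
# file): BCEWithLogitsLoss expects raw logits and float targets of the same
# shape, and loss_fn blends BCE with soft Dice at an 0.8 / 0.2 ratio. The
# dummy 64x64 shapes are illustrative assumptions.
if __name__ == '__main__':
    import torch

    logits = torch.randn(BATCH_SIZE, 1, 64, 64)             # raw outputs
    targets = torch.randint(0, 2, (BATCH_SIZE, 1, 64, 64)).float()
    print(loss_fn(logits, targets).item())                  # soft Dice term
    print(loss_fn(logits, targets, hard=True).item())       # thresholded Dice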
def train_net(net, cfg):
    dataset = BasicDataset(cfg.images_dir, cfg.masks_dir, cfg.scale)
    val_percent = cfg.validation / 100
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=cfg.batch_size, shuffle=True,
                              num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=cfg.batch_size, shuffle=False,
                            num_workers=8, pin_memory=True)

    writer = SummaryWriter(
        comment=f'LR_{cfg.lr}_BS_{cfg.batch_size}_SCALE_{cfg.scale}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {cfg.epochs}
        Batch size:      {cfg.batch_size}
        Learning rate:   {cfg.lr}
        Optimizer:       {cfg.optimizer}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {cfg.save_cp}
        Device:          {device.type}
        Images scaling:  {cfg.scale}
    ''')

    if cfg.optimizer == 'Adam':
        optimizer = optim.Adam(net.parameters(), lr=cfg.lr)
    elif cfg.optimizer == 'RMSprop':
        optimizer = optim.RMSprop(net.parameters(), lr=cfg.lr,
                                  weight_decay=cfg.weight_decay)
    else:
        optimizer = optim.SGD(net.parameters(), lr=cfg.lr,
                              momentum=cfg.momentum,
                              weight_decay=cfg.weight_decay,
                              nesterov=cfg.nesterov)
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=cfg.lr_decay_milestones,
        gamma=cfg.lr_decay_gamma)

    if cfg.n_classes > 1:
        criterion = LovaszLossSoftmax()
        # criterion = nn.CrossEntropyLoss()
    else:
        criterion = LovaszLossHinge()
        # criterion = nn.BCEWithLogitsLoss()

    for epoch in range(cfg.epochs):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{cfg.epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                batch_imgs = batch['image']
                batch_masks = batch['mask']
                assert batch_imgs.shape[1] == cfg.n_channels, \
                    f'Network has been defined with {cfg.n_channels} input channels, ' \
                    f'but loaded images have {batch_imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'
                batch_imgs = batch_imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if cfg.n_classes == 1 else torch.long
                batch_masks = batch_masks.to(device=device, dtype=mask_type)

                inference_masks = net(batch_imgs)

                if cfg.n_classes == 1:
                    inferences = inference_masks.squeeze(1)
                    masks = batch_masks.squeeze(1)
                else:
                    inferences = inference_masks
                    masks = batch_masks

                if cfg.deepsupervision:
                    # Average the loss over every deep-supervision head.
                    loss = 0
                    for inference_mask in inferences:
                        loss += criterion(inference_mask, masks)
                    loss /= len(inferences)
                else:
                    loss = criterion(inferences, masks)

                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)
                writer.add_scalar('model/lr',
                                  optimizer.param_groups[0]['lr'],
                                  global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.update(batch_imgs.shape[0])
                global_step += 1
                # Evaluate roughly ten times per epoch; the max(1, ...) guard
                # avoids a zero modulus on small datasets.
                if global_step % max(1, len(dataset) // (10 * cfg.batch_size)) == 0:
                    val_score = eval_net(net, val_loader, device, n_val, cfg)
                    if cfg.n_classes > 1:
                        logging.info(
                            'Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('CrossEntropy/test', val_score,
                                          global_step)
                    else:
                        logging.info(
                            'Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)

                    writer.add_images('images', batch_imgs, global_step)
                    if cfg.deepsupervision:
                        # Log only the final deep-supervision head.
                        inference_masks = inference_masks[-1]
                    if cfg.n_classes == 1:
                        # writer.add_images('masks/true', batch_masks, global_step)
                        inference_mask = torch.sigmoid(
                            inference_masks) > cfg.out_threshold
                        writer.add_images('masks/inference', inference_mask,
                                          global_step)
                    else:
                        # writer.add_images('masks/true', batch_masks, global_step)
                        ids = inference_masks.shape[1]  # N x C x H x W
                        inference_masks = torch.chunk(inference_masks, ids,
                                                      dim=1)
                        for idx in range(0, len(inference_masks)):
                            inference_mask = torch.sigmoid(
                                inference_masks[idx]) > cfg.out_threshold
                            writer.add_images('masks/inference_' + str(idx),
                                              inference_mask, global_step)

        # MultiStepLR milestones here are epoch counts, so step the scheduler
        # once per epoch; the original stepped it after every batch, which
        # would trigger the decay at batches 20 and 50 instead.
        scheduler.step()

        if cfg.save_cp:
            try:
                os.mkdir(cfg.checkpoints_dir)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            ckpt_name = 'epoch_' + str(epoch + 1) + '.pth'
            torch.save(net.state_dict(),
                       osp.join(cfg.checkpoints_dir, ckpt_name))
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
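# A minimal driver sketch (not part of the original file) showing the cfg
# fields train_net reads. Every value below is an illustrative assumption,
# and the module-level `device` that train_net uses is assumed to be set
# already, e.g. device = torch.device('cuda' if torch.cuda.is_available()
# else 'cpu').
if __name__ == '__main__':
    from types import SimpleNamespace

    logging.basicConfig(level=logging.INFO)
    cfg = SimpleNamespace(
        images_dir='data/imgs', masks_dir='data/masks', scale=1.0,
        validation=10.0, batch_size=8, epochs=70, lr=1e-4,
        optimizer='Adam', weight_decay=0, momentum=0.9, nesterov=False,
        lr_decay_milestones=[20, 50], lr_decay_gamma=0.9,
        n_channels=1, n_classes=5, deepsupervision=False,
        out_threshold=0.5, save_cp=True, checkpoints_dir='checkpoints')

    net = NestedUNet(in_channels=cfg.n_channels, out_channels=cfg.n_classes)
    net.to(device=device)
    train_net(net, cfg)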