import os

import numpy as np
import torch
from tqdm import tqdm

# Project-local helpers. The import paths below are assumptions and may need to
# be adapted to the repository layout; `dice_coeff`, `dice_loss`,
# `predict_full_image`, `get_lr` and the constant `SAVE_EVERY` are expected to
# be provided by the project.
from dice_loss import dice_coeff, dice_loss
from predict import predict_full_image
from utils import get_lr, SAVE_EVERY


def eval_net_full(net, loader, device, ratio):
    """Evaluation on full-size images sampled from the underlying dataset."""
    tot = 0
    loss = 0
    dataset = loader.dataset.dataset  # unwrap the Subset to reach the raw dataset
    num_images = int(ratio * dataset.get_real_length())
    # Random subset of full images (numpy samples with replacement by default)
    image_idx = np.random.choice(np.arange(0, dataset.get_real_length()),
                                 num_images)
    with tqdm(total=num_images, desc="Full Validation round", unit="img",
              leave=False) as pbar:
        for i in image_idx:
            img, true_mask = dataset.get_raw_image(i), dataset.get_raw_mask(i)
            prediction = predict_full_image(net, img, device)
            prediction = torch.from_numpy(prediction).to(device=device).float()
            true_mask = torch.from_numpy(np.expand_dims(
                true_mask, 0)).to(device=device).float()
            tot += dice_coeff(((prediction > 0.3) * 1).float(),
                              true_mask).item()
            loss += dice_loss(prediction, true_mask).item()
            pbar.update(1)  # advance by one image (was: pbar.update(i))
    return tot / num_images, loss / num_images, torch.from_numpy(
        img.transpose((2, 0, 1))), true_mask, prediction
def eval_net(net, loader, device, n_val):
    """Evaluation with the Dice coefficient and Dice loss."""
    tot = 0
    loss = 0
    with tqdm(total=n_val, desc='Validation round', unit='img',
              leave=False) as pbar:
        for batch in loader:
            imgs = batch['image']
            true_masks = batch['mask']

            imgs = imgs.to(device=device, dtype=torch.float32)
            true_masks = true_masks.to(device=device, dtype=torch.float32)

            mask_pred = net(imgs)

            for true_mask, pred in zip(true_masks, mask_pred):
                pred = ((torch.sigmoid(pred) > 0.3) * 1).float()
                tot += dice_coeff(pred, true_mask).item()
                loss += dice_loss(pred, true_mask).item()
            pbar.update(imgs.shape[0])

    return tot / n_val, loss / n_val
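# Reference only: `dice_coeff` and `dice_loss` are imported from the project's
# own loss module above. The pair below is a minimal, generic soft-Dice sketch
# of how such helpers are commonly written; it is illustrative, not the
# project's implementation, and is not used by the functions in this file.
def _dice_coeff_reference(pred, target, eps=1e-6):
    """Dice coefficient between flattened prediction and target tensors."""
    pred = pred.contiguous().view(-1)
    target = target.contiguous().view(-1)
    intersection = (pred * target).sum()
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)


def _dice_loss_reference(pred, target, eps=1e-6):
    """Dice loss is one minus the Dice coefficient."""
    return 1.0 - _dice_coeff_reference(pred, target, eps)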
def train_model(epochs, criterion, optimizer, lr_scheduler, net, train_loader,
                val_loader, dir_checkpoint, logger, n_train, n_val, batch_size,
                writer, val_ratio, balance_classes):
    # torch.multiprocessing.set_start_method('spawn')

    # Register device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f'Using device {device}')

    # Create the Network
    net.to(device=device)

    dataset_length = n_val + n_train
    global_step = 0

    for epoch in range(epochs):
        net.train()  # Sets module in training mode
        epoch_loss = []
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                imgs = imgs.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.float32)

                if balance_classes:
                    # Neg / pos ratio to rectify class imbalance
                    pos_weight = torch.sum(
                        torch.abs(true_masks - 1)) / torch.sum(true_masks)
                    criterion.pos_weight = torch.tensor([pos_weight]).to(
                        device=device, dtype=torch.float32)

                # Optimization step
                optimizer.zero_grad()
                masks_pred = net(imgs)  # Make predictions
                loss = criterion(masks_pred, true_masks)  # Evaluate loss
                batch_loss = loss.item()
                loss.backward()
                optimizer.step()

                # Add data to tensorboard
                epoch_loss.append(batch_loss)  # Add loss to epoch
                writer.add_scalar('Train/BCE_loss', batch_loss, global_step)
                d_loss = dice_loss(torch.sigmoid(masks_pred), true_masks)
                writer.add_scalar('Train/Dice_loss', d_loss, global_step)
                pbar.set_postfix(**{'loss (batch)': batch_loss})
                pbar.update(imgs.shape[0])
                global_step += 1

                # Validate roughly ten times per epoch
                # (assumes dataset_length >= 10 * batch_size)
                if global_step % (dataset_length // (10 * batch_size)) == 0 and n_val > 0:
                    net.eval()
                    val_score, val_loss = eval_net(net, val_loader, device, n_val)
                    net.train()  # Reset in training mode
                    logger.info('Validation Dice Coeff: {}'.format(val_score))
                    writer.add_scalar('Validation/Dice_coef', val_score, global_step)
                    writer.add_scalar('Validation/Dice_loss', val_loss, global_step)
                    writer.add_images('images', imgs, global_step)
                    writer.add_images('masks/true', true_masks, global_step)
                    writer.add_images('masks/pred',
                                      torch.sigmoid(masks_pred), global_step)
                    # if lr_scheduler is not None:
                    #     lr_scheduler.step(int(val_loss * 1000))
                    #     writer.add_scalar("LR", get_lr(optimizer), global_step)

                # Full-image validation every 300 steps
                if global_step % 300 == 0 and n_val > 0:
                    net.eval()
                    val_full_score, val_full_loss, img, true_mask, mask_pred = eval_net_full(
                        net, val_loader, device, val_ratio)
                    net.train()
                    logger.info('Full Validation Dice Coeff: {}'.format(
                        val_full_score))
                    writer.add_scalar('Full_Validation/Dice_coef',
                                      val_full_score, global_step)
                    writer.add_scalar('Full_Validation/Dice_loss',
                                      val_full_loss, global_step)
                    writer.add_images('full_images', img[None, :, :, :], global_step)
                    writer.add_images('full_masks/true', true_mask[None, :, :, :], global_step)
                    writer.add_images('full_masks/pred', mask_pred[None, :, :, :], global_step)

                if (global_step + 1) % SAVE_EVERY == 0:
                    torch.save(net.state_dict(),
                               dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
                    logger.info(f'Checkpoint {epoch + 1} saved !')

        if lr_scheduler is not None:
            ep_loss = int(np.mean(epoch_loss) * 1000)
            lr_scheduler.step(ep_loss)
            # Log the mean epoch loss (was: add_scalar called with the raw list)
            writer.add_scalar("LR/epoch_loss", np.mean(epoch_loss), global_step)
            writer.add_scalar("LR", get_lr(optimizer), global_step)

    writer.close()
    torch.save(net.state_dict(), os.path.join(dir_checkpoint, "final.pth"))
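# -----------------------------------------------------------------------------
# Usage sketch (illustrative only): one way to wire `train_model` up with a
# model, data loaders and a TensorBoard writer. `UNet` and `BasicDataset` are
# hypothetical placeholders for the project's own network and dataset classes;
# the remaining calls are standard PyTorch APIs.
# -----------------------------------------------------------------------------
# if __name__ == '__main__':
#     import logging
#     from torch.utils.data import DataLoader, random_split
#     from torch.utils.tensorboard import SummaryWriter
#
#     logging.basicConfig(level=logging.INFO)
#     logger = logging.getLogger(__name__)
#
#     dataset = BasicDataset('data/imgs', 'data/masks')  # hypothetical dataset class
#     n_val = int(len(dataset) * 0.1)
#     n_train = len(dataset) - n_val
#     train_set, val_set = random_split(dataset, [n_train, n_val])
#     train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
#     val_loader = DataLoader(val_set, batch_size=4, shuffle=False)
#
#     net = UNet(n_channels=3, n_classes=1)  # hypothetical model class
#     criterion = torch.nn.BCEWithLogitsLoss()
#     optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
#     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',
#                                                            patience=2)
#     writer = SummaryWriter()
#
#     train_model(epochs=10, criterion=criterion, optimizer=optimizer,
#                 lr_scheduler=scheduler, net=net, train_loader=train_loader,
#                 val_loader=val_loader, dir_checkpoint='checkpoints/',
#                 logger=logger, n_train=n_train, n_val=n_val, batch_size=4,
#                 writer=writer, val_ratio=0.1, balance_classes=True)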