def inference(loader, model, args): for i, gt_sample in tqdm(enumerate(loader, 1)): (reference, moving) = (gt_sample['reference_irm'], gt_sample['moving_irm']) reference_patients, moving_patients = (gt_sample['reference_patient'], gt_sample['moving_patient']) reference = utils.to_var(args, reference.float()) moving = utils.to_var(args, moving.float()) # compute output pred_sample = model(moving, reference) deformable_grid, integrated_grid, deformed_img = pred_sample[0] moving_mask = utils.to_var(args, gt_sample['moving_mask'].float()) deformed_moving_mask = FrontiersNet.diffeomorphic3D( moving_mask, integrated_grid) deformed_moving_mask = (deformed_moving_mask > 0.5).float() n = reference.shape[0] for batch in range(n): save_pred(deformed_img[batch, ...], deformable_grid[batch, ...], integrated_grid[batch, ...], deformed_moving_mask[batch, ...], moving_patients[batch], reference_patients[batch], args) if i == (args.nb_batch - 1) and not args.all_dataset: break
def calculate_loss(gt_sample, pred_sample, loss_dict, criterion, identity_grid, args): loss = 0 L1, MSE = criterion['L1'], criterion['MSE'] reference = utils.to_var(args, gt_sample['reference_irm'].float()) moving = utils.to_var(args, gt_sample['moving_irm'].float()) mse_loss, regu_deformable_loss, lcc_loss = 0, 0, 0 ground_truths = [reference, moving, moving, reference] # 2 : Moving -> Reference, 3 : Reference -> Moving // Classic Deformation ground_truths = [reference] if args.symmetric_training: ground_truths += [moving] index = range(0, len(ground_truths)) for gt, i in zip(ground_truths, index): (deformed_img, deformable_grid) = (pred_sample[i][2], pred_sample[i][0]) if args.deep_supervision: mse = [MSE(img, gt) for img in deformed_img] mse_loss += torch.mean(torch.stack(mse)) else: mse_loss += MSE(deformed_img, gt) if args.local_cross_correlation_loss: if args.deep_supervision: lcc = [losses.ncc_loss(img, gt) for img in deformed_img] lcc_loss += torch.mean(torch.stack(lcc)) else: lcc_loss += losses.ncc_loss(deformed_img, gt) regu_deformable_loss += L1(deformable_grid, identity_grid) # Reconstruction loss mse_loss /= len(index) loss_dict.update('MSE_loss', mse_loss.mean().item()) loss += args.mse_loss_weight * mse_loss.mean() if args.local_cross_correlation_loss: lcc_loss /= len(index) loss_dict.update('LCC_loss', lcc_loss.mean().item()) loss += lcc_loss.mean() # Regularisation loss regu_deformable_loss /= len(index) loss_dict.update('Regu_Loss', regu_deformable_loss.mean().item()) loss += args.regu_deformable_loss_weight * regu_deformable_loss.mean() loss_dict.update('Loss', loss.item()) return loss, loss_dict, mse_loss
def train(loader, model, criterion, optimizer, writer, logging, epoch, args): end = time.time() loss_dict = utils.MultiAverageMeter() logging_mode = 'Train' if model.training else 'Val' dice_dataframe = pd.DataFrame( columns=['Dice_head', 'Dice_tail', 'Reference']) for i, gt_sample in enumerate(loader, 1): (reference, moving) = (gt_sample['reference_irm'], gt_sample['moving_irm']) # measure data loading time data_time = time.time() - end reference = utils.to_var(args, reference.float()) moving = utils.to_var(args, moving.float()) # compute output pred_sample = model(moving, reference) # Apply predict deformation on ground truth mask integrated_grid = pred_sample[0][1] moving_mask = utils.to_var(args, gt_sample['moving_mask'].float()) deformed_moving_mask = FrontiersNet.diffeomorphic3D( moving_mask, integrated_grid) deformed_reference_mask = None if args.symmetric_training: integrated_grid = pred_sample[1][1] reference_mask = utils.to_var(args, gt_sample['reference_mask'].float()) deformed_reference_mask = FrontiersNet.diffeomorphic3D( reference_mask, integrated_grid) # compute loss identity_grid = args.identity_grid if model.training else args.identity_val_grid loss, loss_dict, mse_loss = calculate_loss(gt_sample, pred_sample, loss_dict, criterion, identity_grid, args, deformed_moving_mask, deformed_reference_mask) if model.training: # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() deformed_moving_mask = (deformed_moving_mask > 0.5).float() if args.symmetric_training: deformed_reference_mask = (deformed_reference_mask > 0.5).float() # Metrics moving_patients = gt_sample['moving_patient'] reference_patients = gt_sample['reference_patient'] reference_mask = utils.to_numpy(args, gt_sample['reference_mask']) deformed_moving_mask = utils.to_numpy(args, deformed_moving_mask) dataframe = losses.evalAllSample(deformed_moving_mask, reference_mask, moving_patients, reference_patients) dice_dataframe = pd.concat([dice_dataframe, dataframe]) loss_dict.update('Dice_head', dataframe.mean()['Dice_head']) loss_dict.update('Dice_tail', dataframe.mean()['Dice_tail']) # measure elapsed time batch_time = time.time() - end end = time.time() loss_dict.update('Batch_time', batch_time) loss_dict.update('Data_time', data_time) if i % args.print_frequency == 0: utils.print_summary(epoch, i, len(loader), loss_dict, logging, logging_mode) if args.tensorboard: step = epoch * len(loader) + i for key in loss_dict.names: writer.add_scalar(logging_mode + '_' + key, loss_dict.get(key).val, step) if args.save_csv > 0 and epoch % args.save_csv == 0: filename = args.csv_path + '/{}_{:02d}.csv'.format(logging_mode, epoch) dice_dataframe.to_csv(filename) if args.tensorboard: writer.add_scalar(logging_mode + '_Dice_head_avg', dice_dataframe.mean()['Dice_head'], epoch) writer.add_scalar(logging_mode + '_Dice_tail_avg', dice_dataframe.mean()['Dice_tail'], epoch) # Add average value to the tensorboard avg_dict = loss_dict.return_all_avg() for key in ['MSE_loss', 'Regu_Loss', 'Deformed_mask_Loss', 'LCC_loss']: if key in avg_dict: writer.add_scalar(logging_mode + '_' + key + '_avg', avg_dict[key], epoch) if epoch % args.image_tensorboard_frequency == 0: # Add images to tensorboard n = reference.shape[0] for batch in range(n): mse_loss = mse_loss if args.plot_loss else None deformed_moving_mask = deformed_moving_mask if args.plot_mask else None fig_registration = ImageTensorboard.plot_registration_results( gt_sample, pred_sample, batch, args, mse_loss, deformed_moving_mask) writer.add_figure(logging_mode + str(batch) + '_Regis', fig_registration, epoch) return loss_dict.get('Loss').avg
def calculate_loss(gt_sample, pred_sample, loss_dict, criterion, identity_grid, args, deformed_moving_mask=None, deformed_reference_mask=None): loss = 0 L1, MSE, Dice = criterion['L1'], criterion['MSE'], criterion['seg'] bbox = utils.to_var(args, gt_sample['bbox']) non_zeros = bbox[:, [1], ...].float().detach() reference = utils.to_var(args, gt_sample['reference_irm'].float()) moving = utils.to_var(args, gt_sample['moving_irm'].float()) (mse_loss, regu_deformable_loss, regu_deformable_MSE_loss, lcc_loss) = 0, 0, 0, 0 # 0 : Moving -> Reference, 1 : Reference -> Moving ground_truths = [reference] if args.symmetric_training: ground_truths += [moving] index = range(0, len(ground_truths)) for gt, i in zip(ground_truths, index): (deformed_img, deformable_grid) = (pred_sample[i][2], pred_sample[i][0]) if args.deep_supervision: mse = [ MSE(img * non_zeros, gt * non_zeros) for img in deformed_img ] mse_loss += torch.mean(torch.stack(mse)) else: mse_loss += MSE(deformed_img * non_zeros, gt * non_zeros) if args.local_cross_correlation_loss: if args.deep_supervision: lcc = [ losses.ncc_loss(img * non_zeros, gt * non_zeros) for img in deformed_img ] lcc_loss += torch.mean(torch.stack(lcc)) else: lcc_loss += losses.ncc_loss(deformed_img * non_zeros, gt * non_zeros) regu_deformable_loss += L1(deformable_grid, identity_grid) regu_deformable_MSE_loss += MSE(deformable_grid, identity_grid) # Reconstruction loss mse_loss /= len(index) loss_dict.update('MSE_loss', mse_loss.mean().item()) loss += args.mse_loss_weight * mse_loss.mean() if args.local_cross_correlation_loss: lcc_loss /= len(index) loss_dict.update('LCC_loss', lcc_loss.mean().item()) loss += lcc_loss.mean() # Regularisation loss regu_deformable_loss /= len(index) loss_dict.update('Regu_L1_Loss', regu_deformable_loss.mean().item()) loss += args.regu_deformable_loss_weight * regu_deformable_loss.mean() regu_deformable_MSE_loss /= len(index) loss_dict.update('Regu_MSE_Loss', regu_deformable_MSE_loss.mean().item()) loss += args.regu_deformable_MSE_loss_weight * regu_deformable_MSE_loss.mean( ) # Supervised Loss (Segmentation) if args.deformed_mask_loss: reference_mask_gt = utils.to_var(args, gt_sample['reference_mask'].float()) deformed_mask_loss = Dice(deformed_moving_mask, reference_mask_gt) if args.symmetric_training: moving_mask_gt = utils.to_var(args, gt_sample['moving_mask'].float()) deformed_mask_loss += Dice(deformed_reference_mask, moving_mask_gt) deformed_mask_loss /= 2 loss_dict.update('Deformed_mask_Loss', deformed_mask_loss.item()) loss += args.deformed_mask_loss_weight * deformed_mask_loss loss_dict.update('Loss', loss.item()) return loss, loss_dict, mse_loss
def main(args): args.main_path = main_path args.test = False # Init of args args.cuda = torch.cuda.is_available() args.data_parallel = args.data_parallel and args.cuda print('CUDA available : {}'.format(args.cuda)) if isinstance(args.crop_size, int): args.crop_size = (args.crop_size, args.crop_size, args.crop_size) if args.channels is None: args.channels = [4, 8, 16, 32, 64, 128, 256] if args.classic_vnet: args.nb_Convs = [1, 2, 3, 2, 2, 2] elif args.nb_Convs is None: args.nb_Convs = [1, 1, 1, 1, 1, 1, 1] args.gpu = 0 if args.session_name == '': args.session_name = args.arch + '_' + time.strftime('%m.%d %Hh%M') else: args.session_name += '_' + time.strftime('%m.%d %Hh%M') if args.debug: args.session_name += '_debug' args.save_path = main_path + 'save/' args.model_path = args.save_path + 'models/' + args.session_name args.dataset_path = main_path + '/datasets/' tensorboard_folder = args.save_path + 'tensorboard_logs/' log_folder = args.save_path + 'logs/' folders = [ args.save_path, args.model_path, tensorboard_folder, log_folder, args.dataset_path ] if args.save_csv > 0: args.csv_path = args.save_path + 'csv/' + args.session_name folders.append(args.csv_path) for folder in folders: if not os.path.isdir(folder): os.makedirs(folder) if args.tensorboard: log_dir = tensorboard_folder + args.session_name + '/' if not os.path.isdir(log_dir): os.makedirs(log_dir) writer = tensorboard.SummaryWriter(log_dir) else: writer = None print('******************* Start training *******************') print('******** Parameter ********') # Log log_path = log_folder + args.session_name + '.log' logging = log.set_logger(log_path) # logs some path info and arguments logging.info('Original command line: {}'.format(' '.join(sys.argv))) logging.info('Arguments:') for arg, value in vars(args).items(): logging.info("%s: %r", arg, value) # Model logging.info("=> creating model '{}'".format(args.arch)) model_kwargs = {} if args.arch in ['FrontiersNet']: params = [ 'channel_multiplication', 'pool_blocks', 'channels', 'last_activation', 'instance_norm', 'batch_norm', 'nb_Convs', 'deep_supervision', 'freeze_registration', 'zeros_init', 'symmetric_training' ] for param in params: model_kwargs[param] = getattr(args, param) if args.model_abspath is not None: (model, model_epoch) = model_loader.load_model(args, model_kwargs) else: model = model_loader.create_model(args, model_kwargs) if args.data_parallel: model = nn.DataParallel(model).cuda(args.gpu) elif args.cuda: model = model.cuda(args.gpu) logging.info('=> Model ready') logging.info(model) # Loss criterion = { 'L1': nn.L1Loss(), 'MSE': nn.MSELoss(reduction='none'), 'seg': losses.mean_dice_loss, 'BCE': nn.BCELoss(reduction='none') } optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) # Data if args.random_crop: crop = transformations.RandomCrop(args.crop_size, do_affine=args.affine_transform) else: crop = transformations.CenterCrop(args.crop_size, do_affine=args.affine_transform) val_crop_size = args.val_crop_size if args.val_crop_size is not None else args.crop_size val_crop = transformations.CenterCrop(val_crop_size, do_affine=False) transforms_list = [crop, transformations.Normalize()] if args.data_augmentation: transforms_list.append(transformations.AxialFlip()) transforms_list.append(transformations.RandomRotation90()) if args.affine_transform: data_aug_kwargs = { 'theta': 20, 'max_translation': 10, 'only_one_rotation': True, 'max_zoom': 0.2, 'isotropique_zoom': True } transforms_list.append( transformations.AffineTransform(data_aug_kwargs)) val_transforms_list = [val_crop, transformations.Normalize()] transformation = torchvision.transforms.Compose(transforms_list) val_transformation = torchvision.transforms.Compose(val_transforms_list) (train_Dataset, val_Dataset) = Dataset.init_datasets(transformation, val_transformation, args) #pin_memory = True if args.on_aws else False pin_memory = False train_loader = torch.utils.data.DataLoader(train_Dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=pin_memory, drop_last=True) val_loader = torch.utils.data.DataLoader(val_Dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=pin_memory, drop_last=True) args.identity_grid = utils.to_var( args, 0.5 * torch.ones([args.batch_size, 3, *args.crop_size])) args.identity_val_grid = utils.to_var( args, 0.5 * torch.ones([args.batch_size, 3, *val_crop_size])) # summary(model, input_size=(4, *args.crop_size)) start_training = time.time() best_loss = 1e9 for epoch in range(args.epochs): # loop over the dataset multiple times print('******** Epoch [{}/{}] ********'.format( epoch + 1, args.epochs)) print(args.session_name) start_epoch = time.time() # train for one epoch model.train() _ = train(train_loader, model, criterion, optimizer, writer, logging, epoch, args) # evaluate on validation set with torch.no_grad(): model.eval() avg_loss = train(val_loader, model, criterion, optimizer, writer, logging, epoch, args) # remember best loss and save checkpoint is_best = best_loss > avg_loss best_loss = min(best_loss, avg_loss) utils.save_checkpoint( args, { 'epoch': epoch, 'state_dict': model.state_dict(), 'val_loss': avg_loss, 'optizmizer': optimizer.state_dict(), }, is_best) logging.info('Epoch time : {} s'.format(time.time() - start_epoch)) args.training_time = time.time() - start_training logging.info('Finished Training') logging.info('Training time : {}'.format(args.training_time)) log.clear_logger(logging) if args.tensorboard: writer.close() return avg_loss
def train(loader, model, criterion, optimizer, writer, logging, epoch, args): end = time.time() loss_dict = utils.MultiAverageMeter() logging_mode = 'Train' if model.training else 'Val' for i, gt_sample in enumerate(loader, 1): (reference, moving) = (gt_sample['reference_irm'], gt_sample['moving_irm']) # measure data loading time data_time = time.time() - end reference = utils.to_var(args, reference.float()) moving = utils.to_var(args, moving.float()) # compute output pred_sample = model(moving, reference) # compute loss identity_grid = args.identity_grid if model.training else args.identity_val_grid loss, loss_dict, mse_loss = calculate_loss(gt_sample, pred_sample, loss_dict, criterion, identity_grid, args) if model.training: # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed time batch_time = time.time() - end end = time.time() loss_dict.update('Batch_time', batch_time) loss_dict.update('Data_time', data_time) if i % args.print_frequency == 0: utils.print_summary(epoch, i, len(loader), loss_dict, logging, logging_mode) if args.tensorboard: step = epoch * len(loader) + i for key in loss_dict.names: writer.add_scalar(logging_mode + '_' + key, loss_dict.get(key).val, step) if args.tensorboard: # Add average value to the tensorboard avg_dict = loss_dict.return_all_avg() for key in ['MSE_loss', 'Regu_Loss', 'LCC_loss']: if key in avg_dict: writer.add_scalar(logging_mode + '_' + key + '_avg', avg_dict[key], epoch) if epoch % args.image_tensorboard_frequency == 0: # Add images to tensorboard for batch in range(args.batch_size): mse_loss = mse_loss if args.plot_loss else None fig_registration = ImageTensorboard.plot_registration_results( gt_sample, pred_sample, batch, args, mse_loss) writer.add_figure(logging_mode + str(batch) + '_Regis', fig_registration, epoch) return loss_dict.get('Loss').avg