def _write_warped_source(images_dict, reference_image, out_path):
    """Write ``images_dict['warped_source']`` to ``out_path`` as an 8-bit image.

    The tensor is moved to the CPU, squeezed, cast to ``uint8`` and given the
    image meta-information of ``reference_image`` (via ``CopyInformation``)
    before being written with ``sitk.WriteImage``.
    """
    # NOTE(review): astype(np.uint8) truncates/wraps any value outside
    # [0, 255]; if the image normalizer maps intensities to a small range
    # (e.g. [0, 1] or zero-mean) most of the output will be 0 -- confirm the
    # intended intensity range.
    array = images_dict['warped_source'].cpu().squeeze().numpy().astype(np.uint8)
    image = sitk.GetImageFromArray(array)
    image.CopyInformation(reference_image)
    sitk.WriteImage(image, out_path)


def test(args):
    """Evaluate a trained ITN/STN pair on the test set and save warped images.

    Loads ``itn.pt`` and ``stn.pt`` from ``args.model``. Runs every test pair
    through ``process_batch`` without gradients and logs the returned values.
    Writes ``sample_<i>_warped_source.png`` into ``<args.out>/test``.

    Unless ``args.no_refine`` is set, each test pair is then processed again.
    A copy of the STN (from the ``'REFINEMENT'`` config) is reset to the trained
    STN weights and fine-tuned on that pair alone, using
    ``config.config['refine']`` Adam steps on ``values_dict['loss_stn_r']``.
    The result is written as ``sample_<i>_warped_source_refined.png``.

    Args:
        args: parsed command-line arguments. This function reads ``test``,
            ``test_seg``, ``test_msk``, ``out``, ``model`` and ``no_refine``
            directly; ``set_up_model_and_preprocessing`` may read more.
    """
    config = set_up_model_and_preprocessing('TESTING', args)

    dataset_test = ImageSegRegDataset(
        args.test, args.test_seg, args.test_msk,
        normalizer_img=config.normalizer_img,
        normalizer_seg=config.normalizer_seg,
        resampler_img=config.resampler_img,
        resampler_seg=config.resampler_seg)
    dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, shuffle=False)

    # Note: Must match those used in process_batch()
    loss_names = [
        'loss_itn', 'loss_stn_u', 'loss_stn_s', 'loss_stn_i', 'loss_stn_r',
        'loss', 'metric_dice', 'metric_hd', 'metric_asd', 'metric_precision',
        'metric_recall'
    ]
    test_logger = mira_metrics.Logger('TEST', loss_names)

    # Create output directory
    out_dir = os.path.join(args.out, 'test')
    os.makedirs(out_dir, exist_ok=True)

    config.itn.load_state_dict(torch.load(os.path.join(args.model, 'itn.pt')))
    config.itn.eval()
    # Read the STN checkpoint once. It is reused below to reset the refinement
    # STN for every image. load_state_dict copies values into the module's
    # parameters, so this dict itself is never modified.
    stn_state = torch.load(os.path.join(args.model, 'stn.pt'))
    config.stn.load_state_dict(stn_state)
    config.stn.eval()

    with torch.no_grad():
        for index, batch_samples in enumerate(dataloader_test):
            _, images_dict, values_dict = process_batch(
                config, config.itn, config.stn, batch_samples)
            test_logger.update_epoch_logger(values_dict)
            _write_warped_source(
                images_dict,
                dataset_test.get_sample(index)['target'],
                os.path.join(out_dir, 'sample_' + str(index) + '_warped_source.png'))

    if args.no_refine:
        return

    refine_config = set_up_model_and_preprocessing('REFINEMENT', args)
    config.itn.eval()

    for index, batch_samples in enumerate(dataloader_test):
        print('Processing image ' + str(index + 1) + ' of ' + str(len(dataset_test)))

        # Reset the STN to the trained weights and give it a fresh optimizer,
        # so each image is refined independently of the previous ones.
        refine_config.stn.load_state_dict(stn_state)
        refine_config.stn.train()
        optimizer = torch.optim.Adam(
            refine_config.stn.parameters(),
            lr=refine_config.config['learning_rate'])

        # Fine-tune the STN on this single pair. Only STN parameters are in
        # the optimizer, so the ITN weights are never stepped.
        # NOTE(review): if the ITN's parameters require grad, their .grad
        # presumably accumulates here since nothing zeroes it -- harmless for
        # the output but wastes memory; consider freezing the ITN.
        for _ in range(config.config['refine']):
            optimizer.zero_grad()
            _, images_dict, values_dict = process_batch(
                config, config.itn, refine_config.stn, batch_samples)
            refine_loss = values_dict['loss_stn_r']
            refine_loss.backward()
            optimizer.step()

        # NOTE(review): the STN is still in train() mode for this final pass,
        # and the refined values are logged into the same test_logger as the
        # unrefined ones above -- confirm both are intended.
        with torch.no_grad():
            _, images_dict, values_dict = process_batch(
                config, config.itn, refine_config.stn, batch_samples)
            test_logger.update_epoch_logger(values_dict)
            _write_warped_source(
                images_dict,
                dataset_test.get_sample(index)['target'],
                os.path.join(out_dir, 'sample_' + str(index) + '_warped_source_refined.png'))
def _save_training_samples(dataset, temp_dir):
    """Write every pre-processed training sample in ``dataset`` to ``temp_dir``.

    For sample ``i`` this writes ``sample_<i>_source.nii.gz``,
    ``sample_<i>_target.nii.gz``, ``sample_<i>_source_seg.nii.gz`` and
    ``sample_<i>_target_seg.nii.gz``.
    """
    os.makedirs(temp_dir, exist_ok=True)
    for idx in range(len(dataset)):
        sample = dataset.get_sample(idx)
        for key in ('source', 'target', 'source_seg', 'target_seg'):
            sitk.WriteImage(
                sample[key],
                os.path.join(temp_dir, 'sample_' + str(idx) + '_' + key + '.nii.gz'))


def train(args):
    """Jointly train the ITN and STN, with periodic validation and checkpoints.

    TensorBoard scalars and images go to ``<args.out>/tensorboard``. Model
    weights go to ``<args.out>/train/model``:

    - ``itn_<epoch>.pt`` / ``stn_<epoch>.pt`` after every validation epoch;
    - ``itn.pt`` / ``stn.pt`` at the end of training.

    Validation runs on epoch 1 and every ``config.config['val_interval']``
    epochs, but only when ``args.val`` is given. When training finishes, the
    per-epoch ``loss`` curve(s) are plotted with matplotlib.

    Args:
        args: parsed command-line arguments. This function reads ``train``,
            ``train_seg``, ``train_msk``, ``val``, ``val_seg``, ``val_msk``,
            ``out``, ``save_temp`` and ``mode3d`` directly;
            ``set_up_model_and_preprocessing`` may read more.
    """
    config = set_up_model_and_preprocessing('TRAINING', args)
    writer = SummaryWriter('{}/tensorboard'.format(args.out))
    global_step = 0

    print(separator)
    print('TRAINING data...')
    print(separator)

    dataset_train = ImageSegRegDataset(
        args.train, args.train_seg, args.train_msk,
        normalizer_img=config.normalizer_img,
        normalizer_seg=config.normalizer_seg,
        resampler_img=config.resampler_img,
        resampler_seg=config.resampler_seg)
    dataloader_train = torch.utils.data.DataLoader(
        dataset_train, batch_size=config.config['batch_size'], shuffle=True)

    if args.val is not None:
        print(separator)
        print('VALIDATION data...')
        print(separator)
        dataset_val = ImageSegRegDataset(
            args.val, args.val_seg, args.val_msk,
            normalizer_img=config.normalizer_img,
            normalizer_seg=config.normalizer_seg,
            resampler_img=config.resampler_img,
            resampler_seg=config.resampler_seg)
        dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False)

    # Create output directory
    out_dir = os.path.join(args.out, 'train')
    os.makedirs(out_dir, exist_ok=True)

    if args.save_temp:
        _save_training_samples(dataset_train, os.path.join(out_dir, 'temp'))

    print(separator)

    # Note: Must match those used in process_batch()
    loss_names = [
        'loss_itn', 'loss_stn_u', 'loss_stn_s', 'loss_stn_i', 'loss_stn_r',
        'loss', 'metric_dice', 'metric_hd', 'metric_asd', 'metric_precision',
        'metric_recall'
    ]
    train_logger = mira_metrics.Logger('TRAIN', loss_names)
    validation_logger = mira_metrics.Logger('VALID', loss_names)

    model_dir = os.path.join(out_dir, 'model')
    os.makedirs(model_dir, exist_ok=True)

    for epoch in range(1, config.config['epochs'] + 1):
        config.stn.train()
        config.itn.train()

        # Training
        for batch_samples in tqdm(dataloader_train, desc='Epoch {}'.format(epoch)):
            global_step += 1
            config.optimizer.zero_grad()
            loss, images_dict, values_dict = process_batch(
                config, config.itn, config.stn, batch_samples)
            loss.backward()
            config.optimizer.step()
            train_logger.update_epoch_logger(values_dict)

        train_logger.update_epoch_summary(epoch)
        write_values(writer, 'train', value_dict=train_logger.get_latest_dict(), n_iter=global_step)
        write_images(writer, 'train', image_dict=images_dict, n_iter=global_step, mode3d=args.mode3d)

        # Validation (epoch 1 and every val_interval epochs)
        if args.val is not None and (epoch == 1 or epoch % config.config['val_interval'] == 0):
            config.stn.eval()
            config.itn.eval()

            with torch.no_grad():
                for batch_samples in dataloader_val:
                    _, images_dict, values_dict = process_batch(
                        config, config.itn, config.stn, batch_samples)
                    validation_logger.update_epoch_logger(values_dict)

            validation_logger.update_epoch_summary(epoch)
            write_values(writer, phase='val', value_dict=validation_logger.get_latest_dict(), n_iter=global_step)
            write_images(writer, phase='val', image_dict=images_dict, n_iter=global_step, mode3d=args.mode3d)

            print(separator)
            train_logger.print_latest()
            validation_logger.print_latest()
            print(separator)

            torch.save(config.itn.state_dict(), os.path.join(model_dir, 'itn_' + str(epoch) + '.pt'))
            torch.save(config.stn.state_dict(), os.path.join(model_dir, 'stn_' + str(epoch) + '.pt'))

    torch.save(config.itn.state_dict(), os.path.join(model_dir, 'itn.pt'))
    torch.save(config.stn.state_dict(), os.path.join(model_dir, 'stn.pt'))

    # Flush any pending TensorBoard events to disk and release the event file.
    writer.close()

    print(separator)
    print('Finished TRAINING... Plotting Graphs\n\n')
    for loss_name, colour in zip(['loss'], ['b']):
        plt.plot(train_logger.epoch_number_logger, train_logger.epoch_summary[loss_name],
                 c=colour, label='train {}'.format(loss_name))
        # Only plot a validation curve if validation actually ran.
        if args.val is not None:
            plt.plot(validation_logger.epoch_number_logger, validation_logger.epoch_summary[loss_name],
                     c=colour, linestyle=':', label='val {}'.format(loss_name))
    plt.legend(loc='upper right')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()