Code example #1
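Both examples assume the imports below. Project-local helpers such as set_up_model_and_preprocessing, ImageSegRegDataset, process_batch, mira_metrics, write_values, write_images, and separator are defined elsewhere in the repository and are not shown here.

import os

import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
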
def test(args):
    config = set_up_model_and_preprocessing('TESTING', args)

    dataset_test = ImageSegRegDataset(args.test,
                                      args.test_seg,
                                      args.test_msk,
                                      normalizer_img=config.normalizer_img,
                                      normalizer_seg=config.normalizer_seg,
                                      resampler_img=config.resampler_img,
                                      resampler_seg=config.resampler_seg)
    dataloader_test = torch.utils.data.DataLoader(dataset_test,
                                                  batch_size=1,
                                                  shuffle=False)
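    # Note: Must match those used in process_batch()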
    loss_names = [
        'loss_itn', 'loss_stn_u', 'loss_stn_s', 'loss_stn_i', 'loss_stn_r',
        'loss', 'metric_dice', 'metric_hd', 'metric_asd', 'metric_precision',
        'metric_recall'
    ]
    test_logger = mira_metrics.Logger('TEST', loss_names)

    # Create output directory
    out_dir = os.path.join(args.out, 'test')
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    config.itn.load_state_dict(torch.load(args.model + '/itn.pt'))
    config.itn.eval()

    config.stn.load_state_dict(torch.load(args.model + '/stn.pt'))
    config.stn.eval()

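    # Run the trained ITN + STN over the test set without tracking gradients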
    with torch.no_grad():
        for index, batch_samples in enumerate(dataloader_test):
            loss, images_dict, values_dict = process_batch(
                config, config.itn, config.stn, batch_samples)
            test_logger.update_epoch_logger(values_dict)

            # source_transformed = sitk.GetImageFromArray(images_dict['source_prime'].cpu().squeeze().numpy().astype(np.uint8))
            # print(images_dict['source_prime'].cpu().squeeze().numpy().astype(np.uint8).shape)
            # print(images_dict['source_prime'].cpu().squeeze().numpy().astype(np.uint8).dtype)
            # print(images_dict['source_prime'].cpu().squeeze().numpy().max())
            # source_transformed.CopyInformation(dataset_test.get_sample(index)['source'])
            # sitk.WriteImage(source_transformed,
            #                 os.path.join(out_dir, 'sample_' + str(index) + '_source_prime.nii.gz'))
            # sitk.WriteImage(source_transformed,
            #                 os.path.join(out_dir, 'sample_' + str(index) + '_source_prime.png'))
            # target_transformed = sitk.GetImageFromArray(images_dict['target_prime'].cpu().squeeze().numpy())
            # target_transformed.CopyInformation(dataset_test.get_sample(index)['target'])
            # sitk.WriteImage(target_transformed,
            #                 os.path.join(out_dir, 'sample_' + str(index) + '_target_prime.nii.gz'))
            # sitk.WriteImage(target_transformed,
            #                 os.path.join(out_dir, 'sample_' + str(index) + '_target_prime.png'))
            warped_source = sitk.GetImageFromArray(
                images_dict['warped_source'].cpu().squeeze().numpy().astype(
                    np.uint8))
            warped_source.CopyInformation(
                dataset_test.get_sample(index)['target'])
            # sitk.WriteImage(warped_source,
            #                 os.path.join(out_dir, 'sample_' + str(index) + '_warped_source.nii.gz'))
            sitk.WriteImage(
                warped_source,
                os.path.join(out_dir,
                             'sample_' + str(index) + '_warped_source.png'))
            # warped_source_seg = sitk.GetImageFromArray(images_dict['warped_source_seg'].cpu().squeeze().numpy())
            # warped_source_seg.CopyInformation(dataset_test.get_sample(index)['target'])
            # sitk.WriteImage(warped_source_seg,
            #                 os.path.join(out_dir, 'sample_' + str(index) + '_warped_source_seg.nii.gz'))
            # sitk.WriteImage(warped_source_seg,
            #                 os.path.join(out_dir, 'sample_' + str(index) + '_warped_source_seg.png'))
            # sitk.WriteImage(dataset_test.get_sample(index)['source'],
            #                 os.path.join(out_dir, 'sample_' + str(index) + '_source.nii.gz'))
            # sitk.WriteImage(dataset_test.get_sample(index)['source'],
            #                 os.path.join(out_dir, 'sample_' + str(index) + '_source.png'))
            # sitk.WriteImage(dataset_test.get_sample(index)['target'],
            #                 os.path.join(out_dir, 'sample_' + str(index) + '_target.nii.gz'))
            # sitk.WriteImage(dataset_test.get_sample(index)['target'],
            #                 os.path.join(out_dir, 'sample_' + str(index) + '_target.png'))
            # sitk.WriteImage(dataset_test.get_sample(index)['source_seg'],
            #                 os.path.join(out_dir, 'sample_' + str(index) + '_source_seg.nii.gz'))
            # sitk.WriteImage(dataset_test.get_sample(index)['source_seg'],
            #                 os.path.join(out_dir, 'sample_' + str(index) + '_source_seg.png'))
            # sitk.WriteImage(dataset_test.get_sample(index)['target_seg'],
            #                 os.path.join(out_dir, 'sample_' + str(index) + '_target_seg.nii.gz'))
            # sitk.WriteImage(dataset_test.get_sample(index)['target_seg'],
            #                 os.path.join(out_dir, 'sample_' + str(index) + '_target_seg.png'))
    #     with open(os.path.join(out_dir,'test_results.yml'), 'w') as outfile:
    #         yaml.dump(test_logger.get_epoch_logger(), outfile)
    # test_logger.update_epoch_summary(0)

    if not args.no_refine:
        refine_config = set_up_model_and_preprocessing('REFINEMENT', args)
        config.itn.eval()

        for index, batch_samples in enumerate(dataloader_test):

            print('Processing image ' + str(index + 1) + ' of ' +
                  str(len(dataset_test)))

            # Reload the pretrained STN and fine-tune it on this image; the ITN stays frozen
            refine_config.stn.load_state_dict(
                torch.load(args.model + '/stn.pt'))
            refine_config.stn.train()

            optimizer = torch.optim.Adam(
                refine_config.stn.parameters(),
                lr=refine_config.config['learning_rate'])

            # Fine tune STN
            for epoch in range(1, config.config['refine'] + 1):
                optimizer.zero_grad()
                _loss, images_dict, values_dict = process_batch(
                    config, config.itn, refine_config.stn, batch_samples)
                loss = values_dict['loss_stn_r']
                loss.backward()
                optimizer.step()

            with torch.no_grad():
                loss, images_dict, values_dict = process_batch(
                    config, config.itn, refine_config.stn, batch_samples)
                test_logger.update_epoch_logger(values_dict)

                warped_source = sitk.GetImageFromArray(
                    images_dict['warped_source'].cpu().squeeze().numpy(
                    ).astype(np.uint8))
                warped_source.CopyInformation(
                    dataset_test.get_sample(index)['target'])
                # sitk.WriteImage(warped_source,
                #                 os.path.join(out_dir, 'sample_' + str(index) + '_warped_source_refined.nii.gz'))
                sitk.WriteImage(
                    warped_source,
                    os.path.join(
                        out_dir,
                        'sample_' + str(index) + '_warped_source_refined.png'))
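
The .astype(np.uint8) cast before each PNG write is load-bearing: SimpleITK's PNG writer only accepts integer pixel types, while the NIfTI writer also handles floating point. A minimal sketch of the pattern, assuming a 2D single-channel array arr already scaled to [0, 255]:

import numpy as np
import SimpleITK as sitk

arr = np.random.rand(64, 64) * 255                   # stand-in for a network output
img = sitk.GetImageFromArray(arr.astype(np.uint8))   # PNG requires an integer pixel type
sitk.WriteImage(img, 'example.png')                  # writing a float64 image here would raise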
Code example #2
def train(args):
    config = set_up_model_and_preprocessing('TRAINING', args)

    writer = SummaryWriter('{}/tensorboard'.format(args.out))
    global_step = 0

    print(separator)
    print('TRAINING data...')
    print(separator)

    dataset_train = ImageSegRegDataset(args.train,
                                       args.train_seg,
                                       args.train_msk,
                                       normalizer_img=config.normalizer_img,
                                       normalizer_seg=config.normalizer_seg,
                                       resampler_img=config.resampler_img,
                                       resampler_seg=config.resampler_seg)
    dataloader_train = torch.utils.data.DataLoader(
        dataset_train, batch_size=config.config['batch_size'], shuffle=True)

    if args.val is not None:
        print(separator)
        print('VALIDATION data...')
        print(separator)
        dataset_val = ImageSegRegDataset(args.val,
                                         args.val_seg,
                                         args.val_msk,
                                         normalizer_img=config.normalizer_img,
                                         normalizer_seg=config.normalizer_seg,
                                         resampler_img=config.resampler_img,
                                         resampler_seg=config.resampler_seg)
        dataloader_val = torch.utils.data.DataLoader(dataset_val,
                                                     batch_size=1,
                                                     shuffle=False)

    # Create output directory
    out_dir = os.path.join(args.out, 'train')
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

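    # Optionally dump the preprocessed training samples for visual inspection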
    if args.save_temp:
        temp_dir = os.path.join(out_dir, 'temp')
        if not os.path.exists(temp_dir):
            os.makedirs(temp_dir)
        for idx in range(0, len(dataset_train)):
            sample = dataset_train.get_sample(idx)
            sitk.WriteImage(
                sample['source'],
                os.path.join(temp_dir,
                             'sample_' + str(idx) + '_source.nii.gz'))
            sitk.WriteImage(
                sample['target'],
                os.path.join(temp_dir,
                             'sample_' + str(idx) + '_target.nii.gz'))
            sitk.WriteImage(
                sample['source_seg'],
                os.path.join(temp_dir,
                             'sample_' + str(idx) + '_source_seg.nii.gz'))
            sitk.WriteImage(
                sample['target_seg'],
                os.path.join(temp_dir,
                             'sample_' + str(idx) + '_target_seg.nii.gz'))

    print(separator)

    # Note: Must match those used in process_batch()
    loss_names = [
        'loss_itn', 'loss_stn_u', 'loss_stn_s', 'loss_stn_i', 'loss_stn_r',
        'loss', 'metric_dice', 'metric_hd', 'metric_asd', 'metric_precision',
        'metric_recall'
    ]
    train_logger = mira_metrics.Logger('TRAIN', loss_names)
    validation_logger = mira_metrics.Logger('VALID', loss_names)

    model_dir = os.path.join(out_dir, 'model')
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    for epoch in range(1, config.config['epochs'] + 1):
        config.stn.train()
        config.itn.train()

        # Training
        for batch_idx, batch_samples in enumerate(
                tqdm(dataloader_train, desc='Epoch {}'.format(epoch))):
            global_step += 1
            config.optimizer.zero_grad()
            loss, images_dict, values_dict = process_batch(
                config, config.itn, config.stn, batch_samples)
            loss.backward()
            config.optimizer.step()
            train_logger.update_epoch_logger(values_dict)

        train_logger.update_epoch_summary(epoch)
        write_values(writer,
                     'train',
                     value_dict=train_logger.get_latest_dict(),
                     n_iter=global_step)
        write_images(writer,
                     'train',
                     image_dict=images_dict,
                     n_iter=global_step,
                     mode3d=args.mode3d)

        # Validation
        if args.val is not None and (epoch == 1 or epoch %
                                     config.config['val_interval'] == 0):
            config.stn.eval()
            config.itn.eval()

            with torch.no_grad():
                for batch_idx, batch_samples in enumerate(dataloader_val):
                    loss, images_dict, values_dict = process_batch(
                        config, config.itn, config.stn, batch_samples)
                    validation_logger.update_epoch_logger(values_dict)

            validation_logger.update_epoch_summary(epoch)
            write_values(writer,
                         phase='val',
                         value_dict=validation_logger.get_latest_dict(),
                         n_iter=global_step)
            write_images(writer,
                         phase='val',
                         image_dict=images_dict,
                         n_iter=global_step,
                         mode3d=args.mode3d)

            print(separator)
            train_logger.print_latest()
            validation_logger.print_latest()
            print(separator)

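            # Checkpoint the models at each validation step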
            torch.save(config.itn.state_dict(),
                       model_dir + '/itn_' + str(epoch) + '.pt')
            torch.save(config.stn.state_dict(),
                       model_dir + '/stn_' + str(epoch) + '.pt')

    torch.save(config.itn.state_dict(), model_dir + '/itn.pt')
    torch.save(config.stn.state_dict(), model_dir + '/stn.pt')

    print(separator)
    print('Finished TRAINING... Plotting Graphs\n\n')
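    # Only the total loss is plotted; extend both lists to plot further metrics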
    for loss_name, colour in zip(['loss'], ['b']):
        plt.plot(train_logger.epoch_number_logger,
                 train_logger.epoch_summary[loss_name],
                 c=colour,
                 label='train {}'.format(loss_name))
        plt.plot(validation_logger.epoch_number_logger,
                 validation_logger.epoch_summary[loss_name],
                 c=colour,
                 linestyle=':',
                 label='val {}'.format(loss_name))

    plt.legend(loc='upper right')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()
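
Neither snippet shows how args is built. A hypothetical argparse entry point wiring up the flags the two functions actually read (option names inferred from the attribute accesses above; the real CLI may differ):

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='ITN/STN registration training and testing (hypothetical CLI)')
    parser.add_argument('--train', help='training images')
    parser.add_argument('--train_seg', help='training segmentations')
    parser.add_argument('--train_msk', help='training masks')
    parser.add_argument('--val', default=None, help='validation images (optional)')
    parser.add_argument('--val_seg', default=None)
    parser.add_argument('--val_msk', default=None)
    parser.add_argument('--test', help='test images')
    parser.add_argument('--test_seg')
    parser.add_argument('--test_msk')
    parser.add_argument('--model', help='directory containing itn.pt and stn.pt')
    parser.add_argument('--out', required=True, help='output directory')
    parser.add_argument('--no_refine', action='store_true',
                        help='skip per-image STN refinement during testing')
    parser.add_argument('--save_temp', action='store_true',
                        help='dump preprocessed samples during training')
    parser.add_argument('--mode3d', action='store_true',
                        help='log 3D volumes instead of 2D slices')
    args = parser.parse_args()

    if args.train is not None:
        train(args)
    if args.test is not None:
        test(args)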