def evaluate_saved_model(model_config,
                         split='validation',
                         model_path=None,
                         data_path=None,
                         save_directory=None,
                         save_nii=False,
                         save_npz=False):
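    """Evaluate a saved model on one data split and log per-case statistics.

    Optionally writes each case's inputs, label and prediction as NIfTI
    files (save_nii) and all predictions as one compressed .npz archive
    (save_npz). Results go to save_directory, which defaults to
    '<config dir>/<split>_evaluation'.
    """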
    # Load options
    json_opts = json_file_to_pyobj(model_config)
    train_opts = json_opts.training
    model_opts = json_opts.model
    data_path_opts = json_opts.data_path

    if model_path is not None:
        model_opts = json_opts.model._replace(
            path_pre_trained_model=model_path)

    # Run evaluation on the CPU: clear the GPU id list
    model_opts = model_opts._replace(gpu_ids=[])

    # Setup the NN Model
    model = get_model(model_opts)
    if save_directory is None:
        save_directory = os.path.join(os.path.dirname(model_config),
                                      split + '_evaluation')
    mkdir(save_directory)

    # Setup Dataset and Augmentation
    ds_class = get_dataset(train_opts.arch_type)
    if data_path is None:
        data_path = get_dataset_path(train_opts.arch_type, data_path_opts)
    dataset_transform = get_dataset_transformation(train_opts.arch_type,
                                                   opts=json_opts.augmentation)

    # Setup channels
    channels = json_opts.data_opts.channels
    if len(channels) != json_opts.model.input_nc \
            or len(channels) != getattr(json_opts.augmentation, train_opts.arch_type).scale_size[-1]:
        raise ValueError(
            'Number of data channels must match the model input_nc and the '
            'last dimension of the augmentation scale_size')

    # Setup Data Loader
    split_opts = json_opts.data_split
    dataset = ds_class(data_path,
                       split=split,
                       transform=dataset_transform['valid'],
                       preload_data=train_opts.preloadData,
                       train_size=split_opts.train_size,
                       test_size=split_opts.test_size,
                       valid_size=split_opts.validation_size,
                       split_seed=split_opts.seed,
                       channels=channels)
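    # One case per batch: the per-case metrics below rely on batch_size=1
    # (np.squeeze strips the singleton batch dimension)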
    data_loader = DataLoader(dataset=dataset,
                             num_workers=8,
                             batch_size=1,
                             shuffle=False)

    # Visualisation is not used during evaluation
    # visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)

    # Setup stats logger
    stat_logger = StatLogger()

    if save_npz:
        all_predicted = []

    if save_nii:
        # Lazy import: SimpleITK is only needed when writing NIfTI files
        import SimpleITK as sitk

    # Evaluation loop
    for iteration, data in tqdm(enumerate(data_loader, 1),
                                total=len(data_loader)):
        model.set_input(data[0], data[1])
        model.test()

        input_arr = np.squeeze(data[0].cpu().numpy()).astype(np.float32)
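        # Channel 5 of the input holds the prior map; binarise it below so
        # it can be scored like a segmentation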
        prior_arr = np.squeeze(data[0].cpu().numpy())[5].astype(np.int16)
        prior_arr[prior_arr > 0] = 1
        label_arr = np.squeeze(data[1].cpu().numpy()).astype(np.int16)
        ids = dataset.get_ids(data[2])
        output_arr = np.squeeze(model.pred_seg.cpu().byte().numpy()).astype(
            np.int16)

        # Compute segmentation statistics against the reference label
        dice_vals = dice_score(label_arr, output_arr, n_class=2)
        single_class_dice = single_class_dice_score(label_arr, output_arr)
        md, hd = distance_metric(label_arr, output_arr, dx=2.00, k=1)
        precision, recall = precision_and_recall(label_arr,
                                                 output_arr,
                                                 n_class=2)
        sp = specificity(label_arr, output_arr)
        jaccard = jaccard_score(label_arr.flatten(), output_arr.flatten())

        # Compute the same statistics for the prior that was used
        prior_dice = single_class_dice_score(label_arr, prior_arr)
        prior_precision, prior_recall = precision_and_recall(label_arr,
                                                             prior_arr,
                                                             n_class=2)

        stat_logger.update(split=split,
                           input_dict={
                               'img_name': ids[0],
                               'dice_bg': dice_vals[0],
                               'dice_les': dice_vals[1],
                               'dice2_les': single_class_dice,
                               'prec_les': precision[1],
                               'reca_les': recall[1],
                               'specificity': sp,
                               'md_les': md,
                               'hd_les': hd,
                               'jaccard': jaccard,
                               'dice_prior': prior_dice,
                               'prec_prior': prior_precision[1],
                               'reca_prior': prior_recall[1]
                           })

        if save_nii:
            # Convert each volume to (x, y, z) axis order and write it as a
            # NIfTI image with a fixed direction matrix (x and y flipped)
            direction = [-1, 0, 0, 0, -1, 0, 0, 0, 1]
            named_volumes = [('img', input_arr[0]), ('cbf', input_arr[1]),
                             ('prior', input_arr[5]), ('lbl', label_arr),
                             ('pred', output_arr)]
            for name, arr in named_volumes:
                img = sitk.GetImageFromArray(np.transpose(arr, (2, 1, 0)))
                img.SetDirection(direction)
                sitk.WriteImage(
                    img,
                    os.path.join(save_directory,
                                 '{}_{}.nii.gz'.format(iteration, name)))

        if save_npz:
            all_predicted.append(output_arr)

    stat_logger.statlogger2csv(split=split,
                               out_csv_name=os.path.join(
                                   save_directory, split + '_stats.csv'))
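    # Print mean +- std for every logged metric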
    for key, (mean_val,
              std_val) in stat_logger.get_errors(split=split).items():
        print('-', key, ': \t{0:.3f}+-{1:.3f}'.format(mean_val, std_val), '-')

    if save_npz:
        np.savez_compressed(os.path.join(save_directory, 'predictions.npz'),
                            predicted=np.array(all_predicted))
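

# Example usage (hypothetical paths): evaluate the validation split of a
# trained model and write the NIfTI outputs next to the config file.
# evaluate_saved_model('checkpoints/experiment/config.json',
#                      split='validation',
#                      save_nii=True)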


def train(arguments):
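    """Train a segmentation model from a JSON config.

    Runs a training pass plus validation and test passes each epoch,
    visualises errors and volumes, checkpoints while the validation
    metrics are improving, and stops early when they stop improving.
    """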

    # Parse input arguments
    json_filename = arguments.config
    network_debug = arguments.debug

    # Load options
    json_opts = json_file_to_pyobj(json_filename)
    train_opts = json_opts.training

    # Architecture type
    arch_type = train_opts.arch_type

    # Setup Dataset and Augmentation
    ds_class = get_dataset(arch_type)
    ds_path = get_dataset_path(arch_type, json_opts.data_path)
    ds_transform = get_dataset_transformation(
        arch_type,
        opts=json_opts.augmentation,
        max_output_channels=json_opts.model.output_nc,
        verbose=json_opts.training.verbose)

    # Setup channels
    channels = json_opts.data_opts.channels
    if len(channels) != json_opts.model.input_nc \
            or len(channels) != getattr(json_opts.augmentation, arch_type).scale_size[-1]:
        raise ValueError(
            'Number of data channels must match the model input_nc and the '
            'last dimension of the augmentation scale_size')

    # Setup the NN Model
    model = get_model(json_opts.model)
    if network_debug:
        print('# of pars: ', model.get_number_parameters())
        print('fp time: {0:.3f} sec\tbp time: {1:.3f} sec per sample'.format(
            *model.get_fp_bp_time()))
        return

    # Setup Data Loader
    split_opts = json_opts.data_split
    train_dataset = ds_class(ds_path,
                             split='train',
                             transform=ds_transform['train'],
                             preload_data=train_opts.preloadData,
                             train_size=split_opts.train_size,
                             test_size=split_opts.test_size,
                             valid_size=split_opts.validation_size,
                             split_seed=split_opts.seed,
                             channels=channels)
    valid_dataset = ds_class(ds_path,
                             split='validation',
                             transform=ds_transform['valid'],
                             preload_data=train_opts.preloadData,
                             train_size=split_opts.train_size,
                             test_size=split_opts.test_size,
                             valid_size=split_opts.validation_size,
                             split_seed=split_opts.seed,
                             channels=channels)
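    # The test split reuses the deterministic 'valid' transform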
    test_dataset = ds_class(ds_path,
                            split='test',
                            transform=ds_transform['valid'],
                            preload_data=train_opts.preloadData,
                            train_size=split_opts.train_size,
                            test_size=split_opts.test_size,
                            valid_size=split_opts.validation_size,
                            split_seed=split_opts.seed,
                            channels=channels)
    train_loader = DataLoader(dataset=train_dataset,
                              num_workers=16,
                              batch_size=train_opts.batchSize,
                              shuffle=True)
    valid_loader = DataLoader(dataset=valid_dataset,
                              num_workers=16,
                              batch_size=train_opts.batchSize,
                              shuffle=False)
    test_loader = DataLoader(dataset=test_dataset,
                             num_workers=16,
                             batch_size=train_opts.batchSize,
                             shuffle=False)

    # Visualisation Parameters
    visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)
    error_logger = ErrorLogger()

    # Setup the learning-rate scheduler and early stopping
    model.set_scheduler(train_opts)
    early_stopper = EarlyStopper(json_opts.training.early_stopping,
                                 verbose=json_opts.training.verbose)
    for epoch in range(model.which_epoch, train_opts.n_epochs):
        print('(epoch: %d, total # iters: %d)' % (epoch, len(train_loader)))
        train_volumes = []
        validation_volumes = []

        # Training Iterations
        for epoch_iter, (images, labels,
                         indices) in tqdm(enumerate(train_loader, 1),
                                          total=len(train_loader)):
            # Make a training update
            model.set_input(images, labels)
            model.optimize_parameters()
            # Alternative: accumulate gradients over iterations instead
            # model.optimize_parameters_accumulate_grd(epoch_iter)

            # Error visualisation
            errors = model.get_current_errors()
            error_logger.update(errors, split='train')

            ids = train_dataset.get_ids(indices)
            volumes = model.get_current_volumes()
            visualizer.display_current_volumes(volumes, ids, 'train', epoch)
            train_volumes.append(volumes)

        # Validation and Testing Iterations
        for loader, split, dataset in zip([valid_loader, test_loader],
                                          ['validation', 'test'],
                                          [valid_dataset, test_dataset]):
            for epoch_iter, (images, labels,
                             indices) in tqdm(enumerate(loader, 1),
                                              total=len(loader)):
                ids = dataset.get_ids(indices)

                # Make a forward pass with the model
                model.set_input(images, labels)
                model.validate()

                # Error visualisation
                errors = model.get_current_errors()
                stats = model.get_segmentation_stats()
                error_logger.update({**errors, **stats}, split=split)

                if split == 'validation':  # never use the test split for model selection
                    # Visualise predictions
                    volumes = model.get_current_volumes()
                    visualizer.display_current_volumes(volumes, ids, split,
                                                       epoch)
                    validation_volumes.append(volumes)

                    # Track validation loss values
                    early_stopper.update({**errors, **stats})

        # Update the plots
        for split in ['train', 'validation', 'test']:
            visualizer.plot_current_errors(epoch,
                                           error_logger.get_errors(split),
                                           split_name=split)
            visualizer.print_current_errors(epoch,
                                            error_logger.get_errors(split),
                                            split_name=split)
        visualizer.save_plots(epoch, save_frequency=5)
        error_logger.reset()

        # Save the model parameters while validation performance is improving
        if early_stopper.is_improving is not False:
            model.save(json_opts.model.model_type, epoch)
            save_config(json_opts, json_filename, model, epoch)

        # Update the learning rate, passing the current validation loss for
        # metric-driven schedulers
        model.update_learning_rate(
            metric=early_stopper.get_current_validation_loss())

        if early_stopper.interrogate(epoch):
            break