def testing():
    """Run the test phase of a trained UNet experiment from a stored checkpoint.

    Loads the shared config, points it at an existing checkpoint directory,
    creates a fresh timestamped result directory, and runs the experiment's
    test pass (with setup).
    """
    c = get_config()

    # Resume from a previously trained model rather than training from scratch.
    c.do_load_checkpoint = True
    c.checkpoint_dir = os.path.join(
        c.base_dir, '20190906-085449_unet_experiment', 'checkpoint', 'checkpoint_current')

    # Result directory name: <dataset>_<batch_size><result_suffix>_<timestamp>
    timestamp = datetime.datetime.now().strftime("_%Y%m%d-%H%M%S")
    cross_vali_result_all_dir = os.path.join(
        c.base_dir,
        c.dataset_name + '_' + str(c.batch_size) + c.cross_vali_result_all_dir + timestamp)
    if not os.path.exists(cross_vali_result_all_dir):
        os.makedirs(cross_vali_result_all_dir)
        print('Created ' + cross_vali_result_all_dir + '...')
        # Redirect the experiment's output into the new result directory.
        c.base_dir = cross_vali_result_all_dir
        c.cross_vali_result_all_dir = os.path.join(cross_vali_result_all_dir, "results")
        os.makedirs(c.cross_vali_result_all_dir)

    exp = UNetExperiment(config=c, name='unet_test', n_epochs=c.n_epochs,
                         seed=42, globs=globals())
    exp.run_test(setup=True)
def training():
    """Train a UNet experiment end-to-end: preprocess (if needed), train, test.

    Preprocessing is skipped when the dataset's 'preprocessed' folder already
    exists; delete that folder to force a re-run.
    """
    c = get_config()

    dataset_name = 'Task04_Hippocampus'
    # dataset_name = 'Task02_Heart'
    # download_dataset(dest_path=c.data_root_dir, dataset=dataset_name, id=c.google_drive_id)

    dataset_root = os.path.join(c.data_root_dir, dataset_name)
    if not exists(os.path.join(dataset_root, 'preprocessed')):
        print('Preprocessing data. [STARTED]')
        preprocess_data(root_dir=dataset_root)
        create_splits(output_dir=c.split_dir, image_dir=c.data_dir)
        print('Preprocessing data. [DONE]')
    else:
        print('The data has already been preprocessed. It will not be preprocessed again. Delete the folder to enforce it.')

    exp = UNetExperiment(config=c, name='unet_experiment', n_epochs=c.n_epochs,
                         seed=42, append_rnd_to_name=c.append_rnd_string)

    exp.run()
    # Reuse the experiment state from training; no fresh setup needed.
    exp.run_test(setup=False)
def preprocess_data(root_dir):
    """Normalize, reorient, crop and save each training case as a .npy stack.

    For every image in ``<root_dir>/imagesTr`` (paired with its label in
    ``labelsTr``), min-max normalizes the image, swaps axes 0 and 2, crops to
    ``c.patch_size`` via ``reshape``, and saves a ``(2, ...)`` image/label
    stack to ``<root_dir>/preprocessed``. Per-class voxel statistics are
    printed at the end.

    Args:
        root_dir: Dataset root containing 'imagesTr' and 'labelsTr'.
    """
    c = get_config()
    image_dir = os.path.join(root_dir, 'imagesTr')
    label_dir = os.path.join(root_dir, 'labelsTr')
    output_dir = os.path.join(root_dir, 'preprocessed')
    classes = c.num_classes

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print('Created ' + output_dir + '...')

    class_stats = defaultdict(int)
    total = 0

    nii_files = subfiles(image_dir, suffix=".nii.gz", join=False)
    # Strip macOS AppleDouble ("._") prefixes so lookups hit the real files.
    nii_files = [f[2:] if f.startswith("._") else f for f in nii_files]

    for f in nii_files:
        image, _ = load(os.path.join(image_dir, f))
        # Label files use the same name without the modality suffix '_0000'.
        label, _ = load(os.path.join(label_dir, f.replace('_0000', '')))

        print(f)

        # Per-class voxel counts (each np.sum evaluated once).
        for i in range(classes):
            count = np.sum(label == i)
            class_stats[i] += count
            total += count

        # Min-max normalize; guard against constant images (division by zero).
        img_min, img_max = image.min(), image.max()
        if img_max > img_min:
            image = (image - img_min) / (img_max - img_min)
        else:
            image = np.zeros_like(image, dtype=np.float64)

        # Reorient so the slice axis comes first.
        image = np.swapaxes(image, 0, 2)
        label = np.swapaxes(label, 0, 2)

        result = reshape(np.stack([image, label], axis=0),
                         crop_size=c.patch_size)

        np.save(os.path.join(output_dir, f.split('.')[0] + '.npy'), result)
        print(f)

    print(total)
    if total > 0:
        for i in range(classes):
            print(class_stats[i], class_stats[i] / total)
class LabelTensorToColor(object):
    """Map a 2-D tensor of class indices to a (3, H, W) byte color image.

    Uses the module-level ``class_color`` palette: pixels equal to class ``i``
    are painted with ``class_color[i]`` (an RGB triple).
    """

    def __call__(self, label):
        label = label.squeeze()
        height, width = label.size(0), label.size(1)
        colored = torch.zeros(3, height, width).byte()

        for class_idx, rgb in enumerate(class_color):
            class_mask = label.eq(class_idx)
            # Paint each channel of the matching pixels with the palette value.
            for channel in range(3):
                colored[channel].masked_fill_(class_mask, rgb[channel])

        return colored

# Shared converter instance used when visualizing predicted label maps below.
color_class_converter = LabelTensorToColor()

# Load data
c = get_config()
exp = UNetExperiment(config=c, name=c.name, n_epochs=c.n_epochs,
                         seed=42, append_rnd_to_name=c.append_rnd_string, globs=globals())
exp.setup()
# Reshuffle so the visualized test batches vary between runs.
exp.test_data_loader.do_reshuffle = True

# Load checkpoint
# NOTE(review): hard-coded path to a specific fine-tuned run — update when reusing.
checkpoint = torch.load('./output_experiment/20190604-154315_fine_tune_spleen_for_heart/checkpoint/checkpoint_last.pth.tar')
exp.model.load_state_dict(checkpoint['model'])

# Inference mode: freezes dropout / batch-norm statistics.
exp.model.eval()

batch_counter = 0
with torch.no_grad():
    for data_batch in exp.test_data_loader:
        # Get data_batches