Example No. 1
def load_model(
    inlier_name="cifar10",
    checkpoint=-1,
    save_path="saved_models/",
    filters=128,
    batch_size=1000,
    split="100,0",
    sigma_high=1,
    num_L=10,
    class_label="all",
):
    args = get_command_line_args([
        "--checkpoint_dir=" + save_path,
        "--filters=" + str(filters),
        "--dataset=" + inlier_name,
        "--sigma_low=0.01",
        "--sigma_high=" + str(sigma_high),
        "--resume_from=" + str(checkpoint),
        "--batch_size=" + str(batch_size),
        "--split=" + split,
        "--num_L=" + str(num_L),
        "--class_label=" + str(class_label),
    ])
    configs.config_values = args

    sigmas = utils.get_sigma_levels().numpy()
    # e.g. "longleaf_models/baseline64_fashion_mnist_SL0.001", ""
    save_dir, complete_model_name = utils.get_savemodel_dir()
    model, optimizer, step, _, _ = utils.try_load_model(
        save_dir, step_ckpt=configs.config_values.resume_from, verbose=True)
    return model, args
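
A hypothetical usage sketch for the helper above, assuming the surrounding module already imports configs and utils and that a matching checkpoint exists under saved_models/; the argument values are illustrative only:

# Restore the latest CIFAR-10 checkpoint (checkpoint=-1) and keep the parsed
# args around for downstream evaluation code.
model, args = load_model(inlier_name="cifar10", checkpoint=-1, batch_size=500)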
Example No. 2
def main():
    save_dir, complete_model_name = utils.get_savemodel_dir()
    model, optimizer, step = utils.try_load_model(
        save_dir, step_ckpt=configs.config_values.resume_from, verbose=True)
    start_time = datetime.now().strftime("%y%m%d-%H%M%S")

    sigma_levels = utils.get_sigma_levels()

    samples_directory = './samples/{}_{}_step{}/'.format(start_time, complete_model_name, step)

    if not os.path.exists(samples_directory):
        os.makedirs(samples_directory)

    x0 = utils.get_init_samples()
    sample_and_save(model, sigma_levels, x=x0, n_images=100, T=100, eps=2*1e-5, save_directory=samples_directory)
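
sample_and_save is not shown in these examples. For context, here is a minimal sketch of the annealed Langevin dynamics that NCSN-style samplers typically run for the T and eps arguments above; the ordering of sigma_levels (largest to smallest noise) and the step-size schedule are assumptions, and the repository's actual sampler may differ:

import tensorflow as tf

def annealed_langevin_sketch(model, sigma_levels, x, T=100, eps=2e-5):
    # Hypothetical sketch (Song & Ermon, 2019): at each noise level, take T
    # Langevin steps along the learned score, with step size alpha scaled by
    # (sigma_i / sigma_last)^2 so the signal-to-noise ratio stays roughly constant.
    for i, sigma_i in enumerate(sigma_levels):
        alpha = eps * (sigma_i / sigma_levels[-1]) ** 2
        idx_sigmas = tf.fill([x.shape[0]], i)  # integer sigma index, as the models above expect
        for _ in range(T):
            z = tf.random.normal(x.shape)
            score = model([x, idx_sigmas])
            x = x + 0.5 * alpha * score + tf.sqrt(alpha) * z
    return x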
Example No. 3
def compute_scores(model, xs, masked_input=False):
    scores = []
    sigmas = utils.get_sigma_levels()
    for x in tqdm(xs):
        per_sigma_scores = []
        for idx, sigma_val in enumerate(sigmas):
            sigma = idx * tf.ones([x.shape[0]], dtype=tf.dtypes.int32)
            score = model([x, sigma]) * sigma_val
            # score = score ** 2
            per_sigma_scores.append(score)
        scores.append(tf.stack(per_sigma_scores, axis=1))

    # [N, L, H, W(, C)] tensor of sigma-weighted scores (norms are not taken here)
    scores = tf.squeeze(tf.concat(scores, axis=0))
    return scores
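
compute_scores returns raw sigma-weighted scores rather than norms. One possible reduction to an N x L matrix of per-example norms, assuming model and xs as in the example above (a sketch, not the repository's own helper):

import tensorflow as tf

# Flatten every axis except (example, sigma level) and take the L2 norm,
# yielding one norm per example and noise level.
scores = compute_scores(model, xs)  # [N, L, H, W] or [N, L, H, W, C]
flat = tf.reshape(scores, [scores.shape[0], scores.shape[1], -1])
score_norms = tf.norm(flat, axis=-1)  # [N, L]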
Example No. 4
def main():
    save_dir, complete_model_name = utils.get_savemodel_dir()
    model, optimizer, step = utils.try_load_model(
        save_dir, step_ckpt=configs.config_values.resume_from, verbose=True)
    start_time = datetime.now().strftime("%y%m%d-%H%M%S")

    sigma_levels = utils.get_sigma_levels()

    k = configs.config_values.k
    samples_directory = './samples/{}_{}_step{}_{}nearest/'.format(
        start_time, complete_model_name, step, k)

    if not os.path.exists(samples_directory):
        os.makedirs(samples_directory)

    n_images = 10  # TODO make this not be hard-coded
    samples = sample_many(model,
                          sigma_levels,
                          batch_size=configs.config_values.batch_size,
                          T=100,
                          n_images=n_images)

    if configs.config_values.dataset == 'celeb_a':
        data = get_celeb_a32()
    else:
        data = get_data_k_nearest(configs.config_values.dataset)
        data = data.batch(int(tf.data.experimental.cardinality(data)))
        # data = tf.data.experimental.get_single_element(data)

    images = []
    data_subsets = []
    for i, sample in enumerate(samples):
        for data_batch in data:
            k_closest_images, _ = utils.find_k_closest(sample, k, data_batch)
            data_subsets.append(k_closest_images)
        # data = tf.convert_to_tensor(data_subsets)
        # k_closest_images, smallest_idx = utils.find_k_closest(sample, k, data)
        # save_image(sample[0, :, :, 0], samples_directory + f'sample_{i}')
        # k_closest_images, smallest_idx = utils.find_k_closest(sample, configs.config_values.k, data_as_array)
        # for j, img in enumerate(k_closest_images):
        # save_image(img[0, :, :, 0], samples_directory + f'sample_{i}_closest_{j}')

        # print(smallest_idx)
        images.append([sample, k_closest_images])

    save_as_grid_closest_k(images,
                           samples_directory + "k_closest_grid.png",
                           spacing=5)
Example No. 5
def compute_scores_ncsnv2(model, x_test):

    # Sigma Idx -> Score
    score_dict = []

    sigmas = utils.get_sigma_levels().numpy()
    progress_bar = tqdm(sigmas)
    for idx, sigma in enumerate(progress_bar):

        progress_bar.set_description("Sigma: {:.4f}".format(sigma))
        _logits = []

        for x_batch in x_test:
            sigma_val = tf.ones(
                (x_batch.shape[0], 1, 1, 1), dtype=tf.float32) * sigma
            score = model([x_batch, sigma_val])
            _logits.append(score)

        _logits = tf.concat(_logits, axis=0)
        score_dict.append(tf.identity(_logits))

    return tf.stack(score_dict, axis=0)
Example No. 6
def compute_batched_score_norms(model, x_test, masked_input=False, seed=None):
    # Sigma Idx -> Score
    score_dict = []
    masks_arr = []
    sigmas = utils.get_sigma_levels()
    input_shape = utils.get_dataset_image_size(configs.config_values.dataset)
    channels = input_shape[-1]
    progress_bar = tqdm(sigmas)
    for idx, sigma in enumerate(progress_bar):

        progress_bar.set_description("Sigma: {:.4f}".format(sigma))
        _logits = []
        if seed is not None:  # "if seed:" would silently skip a seed of 0
            tf.random.set_seed(seed)
            np.random.seed(seed)

        for x_batch in x_test:
            idx_sigmas = tf.ones(x_batch.shape[0], dtype=tf.int32) * idx
            score = model([x_batch, idx_sigmas]) * sigma

            if masked_input:
                _, masks = tf.split(x_batch, (channels - 1, 1), axis=-1)
                score = score * masks

            score = reduce_norm(score)
            _logits.append(score)
        score_dict.append(tf.identity(tf.concat(_logits, axis=0)))

    # N x L Matrix of score norms
    scores = tf.squeeze(tf.stack(score_dict, axis=1))

    # if masked_input:
    #     masks_arr = np.concatenate(masks_arr, axis=0)
    #     return dict(scores=scores.numpy(), masks=masks_arr)

    return scores.numpy()
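
A hypothetical usage sketch; x_test is assumed to be a batched tf.data.Dataset (or any iterable of image batches) matching the model's input, and the returned array has one row per example and one column per noise level:

# Hypothetical usage: per-example score norms across all noise levels.
norms = compute_batched_score_norms(model, x_test, seed=42)
print(norms.shape)       # (N, num_L)
print(norms[:, -1][:5])  # norms at the last noise level for the first five examples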
Example No. 7
def main():
    device = utils.get_tensorflow_device()
    tf.random.set_seed(2019)

    # load dataset from tfds (or use downloaded version if exists)
    train_data = get_train_test_data(configs.config_values.dataset)[0]

    # split data into batches
    train_data = train_data.shuffle(buffer_size=10000)
    if configs.config_values.dataset != 'celeb_a':
        train_data = train_data.batch(configs.config_values.batch_size)
    train_data = train_data.repeat()
    train_data = train_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    # path for saving the model(s)
    save_dir, complete_model_name = utils.get_savemodel_dir()

    start_time = datetime.now().strftime("%y%m%d-%H%M%S")

    # array of sigma levels
    # generate geometric sequence of values between sigma_low (0.01) and sigma_high (1.0)
    sigma_levels = utils.get_sigma_levels()

    model, optimizer, step = utils.try_load_model(
        save_dir, step_ckpt=configs.config_values.resume_from, verbose=True)

    total_steps = configs.config_values.steps
    progress_bar = tqdm(train_data, total=total_steps, initial=step + 1)
    progress_bar.set_description('iteration {}/{} | current loss ?'.format(
        step, total_steps))

    loss_history = []
    with tf.device(device):  # For some reason, this makes everything faster
        avg_loss = 0
        for data_batch in progress_bar:
            step += 1
            idx_sigmas = tf.random.uniform([data_batch.shape[0]],
                                           minval=0,
                                           maxval=configs.config_values.num_L,
                                           dtype=tf.dtypes.int32)
            sigmas = tf.gather(sigma_levels, idx_sigmas)
            sigmas = tf.reshape(sigmas, shape=(data_batch.shape[0], 1, 1, 1))
            data_batch_perturbed = data_batch + tf.random.normal(
                shape=data_batch.shape) * sigmas

            current_loss = train_one_step(model, optimizer,
                                          data_batch_perturbed, data_batch,
                                          idx_sigmas, sigmas)
            loss_history.append([step, current_loss.numpy()])

            progress_bar.set_description(
                'iteration {}/{} | current loss {:.3f}'.format(
                    step, total_steps, current_loss))

            avg_loss += current_loss
            if step % configs.config_values.checkpoint_freq == 0:
                # Save checkpoint
                ckpt = tf.train.Checkpoint(step=tf.Variable(0),
                                           optimizer=optimizer,
                                           model=model)
                ckpt.step.assign_add(step)
                ckpt.save(save_dir + "{}_step_{}".format(start_time, step))
                # Append in csv file
                with open(save_dir + 'loss_history.csv', mode='a',
                          newline='') as csv_file:
                    writer = csv.writer(csv_file, delimiter=';')
                    writer.writerows(loss_history)
                print("\nSaved checkpoint. Average loss: {:.3f}".format(
                    avg_loss / configs.config_values.checkpoint_freq))
                loss_history = []
                avg_loss = 0
            if step == total_steps:
                return
Example No. 8
def main():

    policy_name = "float32"

    if configs.config_values.mixed_precision:
        policy_name = "mixed_float16"

    # compare the version numerically; a plain string compare mis-orders e.g. "2.10" vs "2.4"
    if tuple(int(v) for v in tf.__version__.split(".")[:2]) < (2, 4):
        policy = tf.keras.mixed_precision.experimental.Policy(policy_name)
        tf.keras.mixed_precision.experimental.set_policy(policy)
    else:
        policy = mixed_precision.Policy(policy_name)
        mixed_precision.set_global_policy(policy)

    strategy = tf.distribute.MirroredStrategy()

    # device = utils.get_tensorflow_device()
    tf.random.set_seed(2019)
    BATCH_SIZE_PER_REPLICA = configs.config_values.batch_size
    NUM_REPLICAS = strategy.num_replicas_in_sync
    GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * NUM_REPLICAS
    LOG_FREQ = configs.config_values.log_freq
    configs.config_values.global_batch_size = GLOBAL_BATCH_SIZE

    # array of sigma levels
    # generate geometric sequence of values between sigma_low (0.01) and sigma_high (1.0)
    SIGMA_LEVELS = utils.get_sigma_levels()
    NUM_L = configs.config_values.num_L

    # path for saving the model(s)
    save_dir, complete_model_name = utils.get_savemodel_dir()

    # Swapping to EMA weights for evaluation/checkpoints
    def swap_weights():
        if not ema.in_use:
            ema.in_use = True
            ema.training_state = [
                tf.identity(x) for x in model.trainable_variables
            ]
            for var in model.trainable_variables:
                var.assign(ema.average(var))
            LOGGER.info("Swapped to EMA...")
            return

        # Else switch back to training state
        for var, var_train_state in zip(model.trainable_variables,
                                        ema.training_state):
            var.assign(var_train_state)
        ema.in_use = False
        LOGGER.info("Swapped back to training state.")
        return

    print("GPUs in use:", NUM_REPLICAS)

    with strategy.scope():
        # Create an ExponentialMovingAverage object
        ema = tf.train.ExponentialMovingAverage(decay=0.999)
        ema.in_use = False

        step = tf.Variable(0)
        model, optimizer, step, ocnn_model, ocnn_optimizer = utils.try_load_model(
            save_dir,
            step_ckpt=configs.config_values.resume_from,
            verbose=True,
            ocnn=configs.config_values.ocnn,
        )

        if configs.config_values.mixed_precision:
            print("Using mixed-prec optimizer...")
            if OLD_TF:
                optimizer = mixed_precision.experimental.LossScaleOptimizer(
                    optimizer, loss_scale="dynamic")
            else:
                optimizer = mixed_precision.LossScaleOptimizer(optimizer)

        # Checkpoint should also be under strategy
        ckpt = tf.train.Checkpoint(step=tf.Variable(step),
                                   optimizer=optimizer,
                                   model=model)

    manager = tf.train.CheckpointManager(
        ckpt,
        directory=save_dir,
        max_to_keep=configs.config_values.max_to_keep)
    step = int(step)

    ####### Training Steps #######
    train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
    test_loss = tf.keras.metrics.Mean("test_loss", dtype=tf.float32)

    train_step, test_step = utils.build_distributed_trainers(
        strategy,
        model,
        optimizer,
        ema,
        SIGMA_LEVELS,
        NUM_REPLICAS,
        (train_loss, test_loss),
    )

    # FIXME: "test_data" needs to be a val_data = 10% of training data

    # load dataset from tfds (or use downloaded version if exists)
    train_data, test_data = get_train_test_data(configs.config_values.dataset)

    # # split data into batches
    train_data = train_data.repeat()
    train_data = train_data.prefetch(buffer_size=AUTOTUNE)

    test_data = test_data.take(32).cache()

    train_data = strategy.experimental_distribute_dataset(train_data)
    test_data = strategy.experimental_distribute_dataset(test_data)

    start_time = datetime.now().strftime("%y%m%d-%H%M%S")
    basename = "logs/{model}/{dataset}/{time}".format(
        model=configs.config_values.model,
        dataset=configs.config_values.dataset,
        time=start_time,
    )
    train_log_dir = basename + "/train"
    test_log_dir = basename + "/test"

    total_steps = configs.config_values.steps
    progress_bar = tqdm(train_data, total=total_steps, initial=step + 1)
    progress_bar.set_description("current loss ?")

    steps_per_epoch = (
        configs.dataconfig[configs.config_values.dataset]["n_samples"] //
        configs.config_values.batch_size)

    epoch = step // steps_per_epoch

    train_summary_writer = None  # tf.summary.create_file_writer(train_log_dir)
    test_summary_writer = None  # tf.summary.create_file_writer(test_log_dir)

    if configs.config_values.profile:
        tf.profiler.experimental.start(basename + "/profile")

    avg_loss = 0

    for data_batch in progress_bar:

        if step % steps_per_epoch == 0:
            epoch += 1

        step += 1

        # train_step = None
        # if (
        #     configs.config_values.y_cond
        #     or configs.config_values.model == "masked_refinenet"
        # ):
        #     train_step, test_step = train_one_masked_step, test_step_masked
        # else:
        #     train_step, test_step = train_one_step, test_one_step

        current_loss = train_step(data_batch)
        train_loss(current_loss)

        progress_bar.set_description(
            "[epoch {:d}] | current loss {:.3f}".format(
                epoch,
                train_loss.result().numpy()))

        if step % LOG_FREQ == 0:

            if train_summary_writer is None:
                train_summary_writer = tf.summary.create_file_writer(
                    train_log_dir)
                test_summary_writer = tf.summary.create_file_writer(
                    test_log_dir)

            with train_summary_writer.as_default():
                tf.summary.scalar("loss", train_loss.result(), step=step)

            # Swap to EMA
            swap_weights()
            for x_test in test_data:
                _loss = test_step(x_test, NUM_L - 1)
                test_loss(_loss)

            with test_summary_writer.as_default():
                tf.summary.scalar("loss", test_loss.result(), step=step)
            swap_weights()

            # Reset metrics every epoch
            train_loss.reset_states()
            test_loss.reset_states()

        # loss_history.append([step, current_loss.numpy()])
        avg_loss += current_loss

        if step % configs.config_values.checkpoint_freq == 0:
            swap_weights()
            ckpt.step.assign(step)
            manager.save()
            swap_weights()
            # Append in csv file
            # with open(save_dir + 'loss_history.csv', mode='a', newline='') as csv_file:
            #     writer = csv.writer(csv_file, delimiter=';')
            #     writer.writerows(loss_history)

            print("\nSaved checkpoint. Average loss: {:.3f}".format(
                avg_loss / configs.config_values.checkpoint_freq))
            avg_loss = 0

        if step == total_steps:
            if configs.config_values.profile:
                tf.profiler.experimental.stop()
            return
Example No. 9
def main():

    #     policy = mixed_precision.Policy('mixed_float16')
    #     mixed_precision.set_global_policy(policy)

    device = utils.get_tensorflow_device()
    tf.random.set_seed(2019)
    LOG_FREQ = configs.config_values.log_freq
    SIGMA_LEVELS = utils.get_sigma_levels()
    NUM_L = configs.config_values.num_L

    if configs.config_values.y_cond or configs.config_values.model == "masked_refinenet":
        SPLITS = utils.dict_splits[configs.config_values.dataset]

    # Create an ExponentialMovingAverage object
    ema = tf.train.ExponentialMovingAverage(decay=0.999)
    ema.in_use = False

    # Swapping to EMA weights for evaluation/checkpoints
    def swap_weights():
        if not ema.in_use:
            ema.in_use = True
            ema.training_state = [
                tf.identity(x) for x in model.trainable_variables
            ]
            for var in model.trainable_variables:
                var.assign(ema.average(var))
            LOGGER.info("Swapped to EMA...")
            return

        # Else switch back to training state
        for var, var_train_state in zip(model.trainable_variables,
                                        ema.training_state):
            var.assign(var_train_state)
        ema.in_use = False
        LOGGER.info("Swapped back to training state.")
        return

    @tf.function
    def test_one_step(model, data_batch):
        idx_sigmas = (NUM_L - 1) * tf.ones([data_batch.shape[0]],
                                           dtype=tf.dtypes.int32)
        sigmas = tf.gather(SIGMA_LEVELS, idx_sigmas)
        sigmas = tf.reshape(sigmas, shape=(data_batch.shape[0], 1, 1, 1))
        data_batch_perturbed = data_batch + tf.random.normal(
            shape=data_batch.shape) * sigmas
        scores = model([data_batch_perturbed, idx_sigmas])
        current_loss = dsm_loss(scores, data_batch_perturbed, data_batch,
                                sigmas)
        return current_loss

    @tf.function
    def train_one_step(model, optimizer, data_batch):
        idx_sigmas = tf.random.uniform([data_batch.shape[0]],
                                       minval=0,
                                       maxval=NUM_L,
                                       dtype=tf.dtypes.int32)
        sigmas = tf.gather(SIGMA_LEVELS, idx_sigmas)
        sigmas = tf.reshape(sigmas, shape=(data_batch.shape[0], 1, 1, 1))
        data_batch_perturbed = data_batch + tf.random.normal(
            shape=data_batch.shape) * sigmas

        with tf.GradientTape() as t:
            scores = model([data_batch_perturbed, idx_sigmas])
            current_loss = dsm_loss(scores, data_batch_perturbed, data_batch,
                                    sigmas)

        gradients = t.gradient(current_loss, model.trainable_variables)
        opt_op = optimizer.apply_gradients(
            zip(gradients, model.trainable_variables))

        with tf.control_dependencies([opt_op]):
            # Creates the shadow variables, and add ops to maintain moving averages
            # Also creates an op that will update the moving
            # averages after each training step
            training_op = ema.apply(model.trainable_variables)

        return current_loss

    start_time = datetime.now().strftime("%y%m%d-%H%M%S")
    basename = "logs/{model}/{dataset}/{time}".format(
        model=configs.config_values.model,
        dataset=configs.config_values.dataset,
        time=start_time)
    train_log_dir = basename + '/train'
    test_log_dir = basename + '/test'

    train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)

    # load dataset from tfds (or use downloaded version if exists)
    train_data, test_data = get_train_test_data(configs.config_values.dataset)

    # # split data into batches
    train_data = train_data.shuffle(buffer_size=100)
    if configs.config_values.dataset != 'celeb_a':
        train_data = train_data.batch(configs.config_values.batch_size)
    train_data = train_data.repeat()
    train_data = train_data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    test_data = test_data.batch(configs.config_values.batch_size)
    test_data = test_data.take(2).cache()

    # path for saving the model(s)
    save_dir, complete_model_name = utils.get_savemodel_dir()
    # save_dir += "/multichannel/"

    # array of sigma levels
    # generate geometric sequence of values between sigma_low (0.01) and sigma_high (1.0)
    sigma_levels = utils.get_sigma_levels()

    model, optimizer, step, ocnn_model, ocnn_optimizer = utils.try_load_model(
        save_dir,
        step_ckpt=configs.config_values.resume_from,
        verbose=True,
        ocnn=configs.config_values.ocnn)

    # Save checkpoint
    ckpt = None
    if configs.config_values.ocnn:
        ckpt = tf.train.Checkpoint(step=tf.Variable(0),
                                   optimizer=optimizer,
                                   model=model,
                                   ocnn_model=ocnn_model,
                                   ocnn_optimizer=ocnn_optimizer)
    else:
        ckpt = tf.train.Checkpoint(step=tf.Variable(0),
                                   optimizer=optimizer,
                                   model=model)

    manager = tf.train.CheckpointManager(
        ckpt,
        directory=save_dir,
        max_to_keep=configs.config_values.max_to_keep)

    total_steps = configs.config_values.steps
    progress_bar = tqdm(train_data, total=total_steps, initial=step + 1)
    progress_bar.set_description('current loss ?')

    steps_per_epoch = configs.dataconfig[configs.config_values.dataset][
        "n_samples"] // configs.config_values.batch_size

    radius = 1.0
    loss_history = []
    epoch = step // steps_per_epoch

    train_summary_writer = None  #tf.summary.create_file_writer(train_log_dir)
    test_summary_writer = None  #tf.summary.create_file_writer(test_log_dir)

    with tf.device(device):  # For some reason, this makes everything faster
        avg_loss = 0
        for data_batch in progress_bar:

            if step % steps_per_epoch == 0:
                epoch += 1

            step += 1

            train_step = None
            if configs.config_values.y_cond or configs.config_values.model == "masked_refinenet":
                train_step, test_step = train_one_masked_step, test_step_masked
            else:
                train_step, test_step = train_one_step, test_one_step

            current_loss = train_step(model, optimizer, data_batch)
            train_loss(current_loss)

            progress_bar.set_description(
                '[epoch {:d}] | current loss {:.3f}'.format(
                    epoch,
                    train_loss.result().numpy()))

            if step % LOG_FREQ == 0:

                if train_summary_writer is None:
                    train_summary_writer = tf.summary.create_file_writer(
                        train_log_dir)
                    test_summary_writer = tf.summary.create_file_writer(
                        test_log_dir)

                with train_summary_writer.as_default():
                    tf.summary.scalar('loss', train_loss.result(), step=step)

                # Swap to EMA
                swap_weights()
                for x_test in test_data:
                    _loss = test_step(model, x_test)
                    test_loss(_loss)

                with test_summary_writer.as_default():
                    tf.summary.scalar('loss', test_loss.result(), step=step)
                swap_weights()

                # Reset metrics every epoch
                train_loss.reset_states()
                test_loss.reset_states()

            # loss_history.append([step, current_loss.numpy()])
            avg_loss += current_loss

            if step % configs.config_values.checkpoint_freq == 0:
                swap_weights()
                ckpt.step.assign(step)
                manager.save()
                swap_weights()
                # Append in csv file
                # with open(save_dir + 'loss_history.csv', mode='a', newline='') as csv_file:
                #     writer = csv.writer(csv_file, delimiter=';')
                #     writer.writerows(loss_history)

                print("\nSaved checkpoint. Average loss: {:.3f}".format(
                    avg_loss / configs.config_values.checkpoint_freq))
                loss_history = []
                avg_loss = 0

            if step == total_steps:
                return
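
dsm_loss is referenced but not defined in these examples. Below is a minimal sketch of the standard denoising score matching objective for NCSN-style models, matching the call signature dsm_loss(scores, data_batch_perturbed, data_batch, sigmas) used above; the repository's actual weighting and reduction may differ:

import tensorflow as tf

def dsm_loss_sketch(scores, x_perturbed, x, sigmas):
    # Hypothetical sketch (Song & Ermon, 2019): the network output should match
    # grad log q_sigma(x_perturbed | x) = -(x_perturbed - x) / sigma^2, and each
    # noise level is weighted by sigma^2 so all levels contribute comparably.
    # sigmas is expected to have shape (batch, 1, 1, 1), as in the training loops above.
    target = -(x_perturbed - x) / (sigmas ** 2)
    per_example = 0.5 * tf.reduce_sum(tf.square(scores - target), axis=[1, 2, 3])
    return tf.reduce_mean(per_example * tf.squeeze(sigmas) ** 2)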
Example No. 10
def main():
    start_time = datetime.now().strftime("%y%m%d-%H%M%S")

    # load model from checkpoint
    save_dir, complete_model_name = utils.get_savemodel_dir()
    model, optimizer, step = utils.try_load_model(
        save_dir, step_ckpt=configs.config_values.resume_from, verbose=True)

    # construct path and folder
    dataset = configs.config_values.dataset
    # samples_directory = f'./inpainting_results/{dataset}_{start_time}'
    samples_directory = './samples/{}_{}_step{}_inpainting/'.format(
        start_time, complete_model_name, step)

    if not os.path.exists(samples_directory):
        os.makedirs(samples_directory)

    # initialise sigmas
    sigma_levels = utils.get_sigma_levels()

    # TODO add these values to args
    N_to_occlude = 10  # number of images to occlude
    n_reconstructions = 8  # number of samples to generate for each occluded image
    mask_style = 'horizontal_up'  # what kind of occlusion to use
    # mask_style = 'middle'  # what kind of occlusion to use

    # load data for inpainting (currently always N first data points from test data)
    data = get_data_inpainting(configs.config_values.dataset, N_to_occlude)

    images = []

    mask = np.zeros(data.shape[1:])
    if mask_style == 'vertical_split':
        mask[:, :data.shape[2] // 2, :] += 1  # set left side to ones
    elif mask_style == 'middle':
        fifth = data.shape[2] // 5
        mask[:, :2 * fifth, :] += 1  # keep the left two-fifths visible
        mask[:, -(2 * fifth):, :] += 1  # keep the right two-fifths visible; the middle stripe stays occluded
    elif mask_style == 'checkerboard':
        mask[::2, ::2, :] += 1  # set every other value to ones
    elif mask_style == 'horizontal_down':
        mask[:data.shape[1] // 2, :, :] += 1
    elif mask_style == 'horizontal_up':
        mask[data.shape[1] // 2:, :, :] += 1
    elif mask_style == 'centered':
        init_x, init_y = data.shape[1] // 4, data.shape[2] // 4
        mask += 1
        mask[init_x:3 * init_x, init_y:3 * init_y, :] -= 1
    else:
        pass  # TODO add options here

    mask = tf.convert_to_tensor(mask, dtype=tf.float32)

    for i, x in enumerate(data):
        occluded_x = x * mask
        image_path = f'{samples_directory}/image_{i}'
        save_image(x, image_path + '_original')
        save_image(occluded_x, image_path + '_occluded')

    reconstructions = [[] for i in range(N_to_occlude)]
    for j in tqdm(range(n_reconstructions)):
        samples_j = inpaint_x(model, sigma_levels, mask, data, T=100)
        samples_j = _preprocess_image_to_save(samples_j)
        for i, reconstruction in enumerate(samples_j):
            reconstructions[i].append(reconstruction)
            save_image(
                reconstruction,
                samples_directory + 'image_{}-{}_inpainted'.format(i, j))

    for i in range(N_to_occlude):
        images.append([data[i] * mask, reconstructions[i], data[i]])

    save_as_grid(images, samples_directory + '/grid.png', spacing=5)