示例#1
0
def main(args):
    """Plot conditional latent traversals of a trained AAE on CelebA.

    The latents to traverse are the first ``args.top_k`` entries of the
    precomputed informativeness rankings stored in
    ``<informativeness_metrics_dir>/AAE_<run>/results[...].npz``. One figure
    uses the ranking in ``sorted_z_comps``/``sorted_MI_z_x``. A second figure,
    whose file name ends in ``,norm``, uses
    ``sorted_norm_z_comps``/``sorted_norm_MI_z_x``.

    Steps:
      1. Merge ``<output_dir>/config.json`` into ``args``. Config values
         overwrite existing attributes of the same name.
      2. Build the CelebA loader, the 1Konny encoder/decoder/discriminator and
         the AAE, then restore the checkpoint selected by ``args.load_step``.
      3. Log the settings and rankings to ``<save_dir>/AAE_<run>/log[...].txt``
         (through ``print_both``), and save the traversal figures to the same
         directory.

    Figure layouts exist only for ``top_k`` in {10, 45}. For any other value
    a notice is logged and no figures are written.

    Args:
        args: Namespace-like object with the command-line options. After
            step 1 it must expose every attribute read below.

    Raises:
        ValueError: If ``args.activation`` or ``args.enc_dec_model`` is not
            supported.
    """
    import sys  # sys.maxsize: untruncated numpy printing

    # =====================================
    # Load config
    # =====================================
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Dataset
    # =====================================
    celebA_loader = TFCelebALoader(root_dir=args.celebA_root_dir)

    img_height, img_width = args.celebA_resize_size, args.celebA_resize_size
    celebA_loader.build_transformation_flow_tf(
        *celebA_loader.get_transform_fns("1Konny", resize_size=args.celebA_resize_size))

    # =====================================
    # Instantiate model
    # =====================================
    if args.activation == "relu":
        activation = tf.nn.relu
    elif args.activation == "leaky_relu":
        activation = tf.nn.leaky_relu
    else:
        raise ValueError("Do not support '{}' activation!".format(args.activation))

    if args.enc_dec_model == "1Konny":
        # assert args.z_dim == 65, "For 1Konny, z_dim must be 65. Found {}!".format(args.z_dim)
        encoder = Encoder_1Konny(args.z_dim, stochastic=True, activation=activation)
        decoder = Decoder_1Konny([img_height, img_width, 3], activation=activation,
                                 output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_1Konny()
    else:
        raise ValueError("Do not support encoder/decoder model '{}'!".format(args.enc_dec_model))

    model = AAE([img_height, img_width, 3], args.z_dim,
                encoder=encoder, decoder=decoder,
                discriminator_z=disc_z,
                rec_x_mode=args.rec_x_mode,
                stochastic_z=args.stochastic_z,
                use_gp0_z=True, gp0_z_mode=args.gp0_z_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'G_loss_z1_gen': args.G_loss_z1_gen_coeff,
        'D_loss_z1_gen': args.D_loss_z1_gen_coeff,
        'gp0_z': args.gp0_z_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # =====================================
    # Load model
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)
    train_helper.load(sess, load_step=args.load_step)

    # =====================================
    # Experiments
    # =====================================
    save_dir = make_dir_if_not_exist(join(args.save_dir, "AAE_{}".format(args.run)))

    # Recent NumPy versions raise ValueError for threshold=np.nan.
    # sys.maxsize is the supported way to disable summarization.
    np.set_printoptions(threshold=sys.maxsize, linewidth=1000, precision=3, suppress=True)

    num_bins = args.num_bins
    data_proportion = args.data_proportion
    bin_limits = tuple(float(s) for s in args.bin_limits.split(";"))
    top_k = args.top_k

    log_path = join(save_dir, 'log[bins={},bin_limits={},data={}].txt'.format(
        num_bins, bin_limits, data_proportion))
    result_file = join(args.informativeness_metrics_dir, "AAE_{}".format(args.run),
                       'results[bins={},bin_limits={},data={}].npz'.format(
                           num_bins, bin_limits, data_proportion))

    # Layout per supported top_k: (font size, left subplot margin, figure size in inches).
    plot_layouts = {
        10: (9, 0.15, (6.3, 4.9)),
        45: (5, 0.19, (2.98, 9.85)),
    }

    # The log file is closed even if loading or plotting fails.
    with open(log_path, mode='w') as log_file:
        print_ = functools.partial(print_both, file=log_file)

        # Read the rankings eagerly so the .npz archive can be closed right away.
        with np.load(result_file, "r") as results:
            sorted_MI = results["sorted_MI_z_x"]
            sorted_z_ids = results["sorted_z_comps"]
            sorted_norm_MI = results["sorted_norm_MI_z_x"]
            sorted_norm_z_ids = results["sorted_norm_z_comps"]

        print_("")
        print_("num_bins: {}".format(num_bins))
        print_("bin_limits: {}".format(bin_limits))
        print_("data_proportion: {}".format(data_proportion))
        print_("top_k: {}".format(top_k))

        # Plotting
        # =========================================== #
        # Alternative setting used before: seed = 389, num_samples = 30
        seed = 398
        num_samples = 1

        ids = list(range(seed, seed + num_samples))
        print_("\nids: {}".format(ids))

        data = celebA_loader.sample_images_from_dataset(sess, 'train', ids)

        span = 3
        points_one_side = 5

        print_("sorted_MI: {}".format(sorted_MI))
        print_("sorted_z_ids: {}".format(sorted_z_ids))
        print_("sorted_norm_MI: {}".format(sorted_norm_MI))
        print_("sorted_norm_z_ids: {}".format(sorted_norm_z_ids))

        # Each entry: (file-name suffix, latent ids to traverse, MI values shown in labels).
        rankings = [
            ("", sorted_z_ids[:top_k], sorted_MI[:top_k]),
            (",norm", sorted_norm_z_ids[:top_k], sorted_norm_MI[:top_k]),
        ]

        layout = plot_layouts.get(top_k)
        if layout is None:
            print_("No figure layout for top_k={} (supported: {}); skipping traversal plots.".format(
                top_k, sorted(plot_layouts)))
            return
        font_size, left_margin, size_inches = layout

        for n, sample_id in enumerate(ids):
            print("Plot conditional all comps z traverse with train sample {}!".format(sample_id))

            for suffix, z_ids, mi_values in rankings:
                img_file = join(save_dir, "x_train[{}][bins={},bin_limits={},data={}{}].png".format(
                    sample_id, num_bins, bin_limits, data_proportion, suffix))
                model.cond_all_latents_traverse_v2(
                    img_file, sess, data[n],
                    z_comps=z_ids,
                    z_comp_labels=["z[{}] ({:.2f})".format(comp, mi)
                                   for comp, mi in zip(z_ids, mi_values)],
                    span=span, points_1_side=points_one_side,
                    hl_x=True,
                    font_size=font_size,
                    # A fresh dict per call, as in the original per-call literals.
                    subplot_adjust={'left': left_margin, 'right': 0.99,
                                    'bottom': 0.01, 'top': 0.99},
                    size_inches=size_inches,
                    batch_size=args.batch_size,
                    dec_output_2_img_func=binary_float_to_uint8)
        # =========================================== #
def main(args):
    """Report RMIG/JEMMI per CelebA attribute and plot top-latent traversals.

    Steps:
      1. Merge ``<output_dir>/config.json`` into ``args``. Config values
         overwrite existing attributes of the same name.
      2. Rebuild the AAE on the attribute-aware CelebA loader and restore the
         checkpoint selected by ``args.load_step``.
      3. Read ``<interpretability_metrics_dir>/AAE_<run>/results[...].npz``.
      4. Log to ``<save_dir>/AAE_<run>/log[...].txt`` (through
         ``print_both``), for every attribute:
           * RMIG, which is ``MI_gap_y``, and its unnormalized value
             ``MI_gap_y * H_y``.
           * JEMMI, computed here. Rank latents by ``MI_z_y``, take the top
             two z1 and z2, and form ``(H_z_y[z1] - MI_z_y[z1]) + MI_z_y[z2]``.
             The normalized value divides this by ``log(num_bins) + H_y``.
      5. For each sampled training image and each attribute, save a traversal
         figure over the 5 latents ranked highest in ``MI_ids_sorted`` /
         ``MI_sorted``.

    Args:
        args: Namespace-like object with the command-line options. After
            step 1 it must expose every attribute read below.

    Raises:
        ValueError: If ``args.activation`` or ``args.enc_dec_model`` is not
            supported.
        AssertionError: If the stored result arrays disagree with the number
            of attributes exposed by the loader.
    """
    import sys  # sys.maxsize: untruncated numpy printing

    # =====================================
    # Load config
    # =====================================
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Dataset
    # =====================================
    celebA_loader = TFCelebAWithAttrLoader(root_dir=args.celebA_root_dir)

    img_height, img_width = args.celebA_resize_size, args.celebA_resize_size
    celebA_loader.build_transformation_flow_tf(
        *celebA_loader.get_transform_fns("1Konny",
                                         resize_size=args.celebA_resize_size))

    # =====================================
    # Instantiate model
    # =====================================
    if args.activation == "relu":
        activation = tf.nn.relu
    elif args.activation == "leaky_relu":
        activation = tf.nn.leaky_relu
    else:
        raise ValueError("Do not support '{}' activation!".format(
            args.activation))

    if args.enc_dec_model == "1Konny":
        # assert args.z_dim == 65, "For 1Konny, z_dim must be 65. Found {}!".format(args.z_dim)
        encoder = Encoder_1Konny(args.z_dim,
                                 stochastic=True,
                                 activation=activation)
        decoder = Decoder_1Konny([img_height, img_width, 3],
                                 activation=activation,
                                 output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_1Konny()
    else:
        raise ValueError("Do not support encoder/decoder model '{}'!".format(
            args.enc_dec_model))

    model = AAE([img_height, img_width, 3],
                args.z_dim,
                encoder=encoder,
                decoder=decoder,
                discriminator_z=disc_z,
                rec_x_mode=args.rec_x_mode,
                stochastic_z=args.stochastic_z,
                use_gp0_z=True,
                gp0_z_mode=args.gp0_z_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'G_loss_z1_gen': args.G_loss_z1_gen_coeff,
        'D_loss_z1_gen': args.D_loss_z1_gen_coeff,
        'gp0_z': args.gp0_z_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # =====================================
    # Load model
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)
    train_helper.load(sess, load_step=args.load_step)

    # =====================================
    # Experiments
    # =====================================
    save_dir = make_dir_if_not_exist(
        join(args.save_dir, "AAE_{}".format(args.run)))

    # Recent NumPy versions raise ValueError for threshold=np.nan.
    # sys.maxsize is the supported way to disable summarization.
    np.set_printoptions(threshold=sys.maxsize,
                        linewidth=1000,
                        precision=3,
                        suppress=True)

    num_bins = args.num_bins
    bin_limits = tuple(float(s) for s in args.bin_limits.split(";"))
    data_proportion = args.data_proportion

    log_path = join(
        save_dir, 'log[bins={},bin_limits={},data={}].txt'.format(
            num_bins, bin_limits, data_proportion))
    result_file = join(
        args.interpretability_metrics_dir, "AAE_{}".format(args.run),
        "results[bins={},bin_limits={},data={}].npz".format(
            num_bins, bin_limits, data_proportion))

    # The log file is closed even if loading or plotting fails.
    # The original never closed it.
    with open(log_path, mode='w') as log_file:
        print_ = functools.partial(print_both, file=log_file)

        # Read all needed arrays eagerly so the .npz archive can be closed.
        with np.load(result_file, "r") as results:
            result_keys = list(results.keys())
            # (z_dim, num_attrs)
            MI_ids_sorted = results['MI_ids_sorted']
            MI_sorted = results['MI_sorted']
            MI_gap_y = results['MI_gap_y']
            H_y = results['H_y_4_diff_z'][:, 0]
            H_z_y = results['H_z_y']
            MI_z_y = results['MI_z_y']

        print_("")
        print_("num_bins: {}".format(num_bins))
        print_("bin_limits: {}".format(bin_limits))
        print_("data_proportion: {}".format(data_proportion))

        # Plotting
        # =========================================== #
        # Alternative setting used before: seed = 389, num_samples = 30
        seed = 398
        num_samples = 1

        ids = list(range(seed, seed + num_samples))
        print_("\nids: {}".format(ids))

        data = celebA_loader.sample_images_from_dataset(sess, 'train', ids)

        span = 3
        points_one_side = 5

        attr_names = celebA_loader.attributes
        print_("attr_names: {}".format(attr_names))
        print_("results.keys: {}".format(result_keys))

        assert MI_ids_sorted.shape[1] == len(attr_names) == len(MI_gap_y) == len(H_y), \
            "MI_ids_sorted.shape: {}, len(attr_names): {}, len(MI_gap_y): {}, len(H_y): {}".format(
                MI_ids_sorted.shape, len(attr_names), len(MI_gap_y), len(H_y))

        # Use print_ (not print) so the values reach the log file, like their headers.
        print_("\nShow RMIG!")
        for i, attr_name in enumerate(attr_names):
            print_("{}: RMIG: {:.4f}, RMIG (unnorm): {:.4f}, H: {:.4f}".format(
                attr_name, MI_gap_y[i], MI_gap_y[i] * H_y[i], H_y[i]))

        print_("\nShow JEMMI!")
        # Rank latents (axis 0) by MI, descending, separately for each attribute.
        ids_sorted_by_MI = np.argsort(MI_z_y, axis=0)[::-1]
        MI_z_y_sorted = np.take_along_axis(MI_z_y, ids_sorted_by_MI, axis=0)
        H_z_y_sorted = np.take_along_axis(H_z_y, ids_sorted_by_MI, axis=0)

        H_diff = H_z_y_sorted[0, :] - MI_z_y_sorted[0, :]
        JEMMI_unnorm = H_diff + MI_z_y_sorted[1, :]
        JEMMI_norm = JEMMI_unnorm / (np.log(num_bins) + H_y)

        for i, attr_name in enumerate(attr_names):
            print_(
                "{}: JEMMI: {:.4f}, JEMMI (unnorm): {:.4f}, H_diff: {:.4f}, I2: {:.4f}, top 2 latents: z{}, z{}"
                .format(attr_name, JEMMI_norm[i], JEMMI_unnorm[i], H_diff[i],
                        MI_z_y_sorted[1, i], ids_sorted_by_MI[0, i],
                        ids_sorted_by_MI[1, i]))

        # Traverse the top latents for every attribute.
        # A former top-10 variant used size_inches=(6.5, 5.2) and
        # subplot_adjust bottom=0.01, top=0.95.
        num_top_latents = 5
        for n, sample_id in enumerate(ids):
            for k, attr_name in enumerate(attr_names):
                top_ids = MI_ids_sorted[:num_top_latents, k]
                top_MI = MI_sorted[:num_top_latents, k]
                print("Plot top {} latents for factor '{}'!".format(
                    num_top_latents, attr_name))

                img_file = join(
                    save_dir, "train{}_attr={}_bins={}_data={}.png".format(
                        sample_id, attr_name, num_bins, data_proportion))

                model.cond_all_latents_traverse_v2(
                    img_file,
                    sess,
                    data[n],
                    z_comps=top_ids,
                    z_comp_labels=[
                        "z[{}] ({:.4f})".format(comp, mi)
                        for comp, mi in zip(top_ids, top_MI)
                    ],
                    span=span,
                    points_1_side=points_one_side,
                    hl_x=True,
                    font_size=9,
                    title="{} (MI gap={:.4f}, H={:.4f})".format(
                        attr_name, MI_gap_y[k], H_y[k]),
                    title_font_scale=1.5,
                    subplot_adjust={
                        'left': 0.16,
                        'right': 0.99,
                        'bottom': 0.005,
                        'top': 0.93
                    },
                    size_inches=(6.5, 2.8),
                    batch_size=args.batch_size,
                    dec_output_2_img_func=binary_float_to_uint8)
        # =========================================== #
示例#3
0
def main(args):
    """Report RMIG/JEMMI per dSprites factor and plot top-latent traversals.

    Steps:
      1. Merge ``<output_dir>/config.json`` into ``args``. Config values
         overwrite existing attributes of the same name.
      2. Load the dSprites images from ``RAW_DATA_DIR`` and keep three
         reference images, one per index of the first factor axis.
      3. Rebuild the AAE with the 1Konny networks and restore the checkpoint
         selected by ``args.load_step``.
      4. Read ``<interpretability_metrics_dir>/<enc_dec_model>_<run>/results[...].npz``.
      5. Log to ``<save_dir>/<enc_dec_model>_<run>/log[...].txt`` (through
         ``print_both``) the stored RMIG and JEMMI per factor, both normalized
         and unnormalized.
      6. Save one traversal figure per (reference image, factor) over the
         3 latents ranked highest in ``ids_sorted`` / ``MI_z_y_sorted``.

    Args:
        args: Namespace-like object with the command-line options. After
            step 1 it must expose every attribute read below.

    Raises:
        ValueError: If ``args.enc_dec_model`` is not supported.
    """
    import sys  # sys.maxsize: untruncated numpy printing

    # =====================================
    # Load config
    # =====================================
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Dataset
    # =====================================
    data_file = join(RAW_DATA_DIR, "ComputerVision", "dSprites",
                     "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")

    # It is already in the range [0, 1]
    with np.load(data_file, encoding="latin1") as f:
        x_train = f['imgs']

    # Factor axes (sizes 3, 6, 40, 32, 32), presumably in the order of
    # gt_factors below, followed by a 64x64x1 image.
    x_train = np.reshape(x_train, [3, 6, 40, 32, 32, 64, 64, 1])

    # Only three images are used: factor indices (i, 3, 20, 16, 16) for
    # i = 0, 1, 2. np.array copies them so the full array can be released
    # before the model is built, instead of being kept alive by views.
    data = [np.array(x_train[i, 3, 20, 16, 16]) for i in range(3)]
    del x_train

    # =====================================
    # Instantiate model
    # =====================================
    if args.enc_dec_model == "1Konny":
        encoder = Encoder_1Konny(args.z_dim, stochastic=True)
        decoder = Decoder_1Konny()
        disc_z = DiscriminatorZ_1Konny()
    else:
        raise ValueError("Do not support enc_dec_model='{}'!".format(
            args.enc_dec_model))

    model = AAE([64, 64, 1],
                args.z_dim,
                encoder=encoder,
                decoder=decoder,
                discriminator_z=disc_z,
                rec_x_mode=args.rec_x_mode,
                stochastic_z=args.stochastic_z,
                use_gp0_z=True,
                gp0_z_mode=args.gp0_z_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'G_loss_z1_gen': args.G_loss_z1_gen_coeff,
        'D_loss_z1_gen': args.D_loss_z1_gen_coeff,
        'gp0_z': args.gp0_z_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # =====================================
    # Load model
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)
    train_helper.load(sess, load_step=args.load_step)

    # =====================================
    # Experiments
    # =====================================
    save_dir = make_dir_if_not_exist(
        join(args.save_dir, "{}_{}".format(args.enc_dec_model, args.run)))

    # Recent NumPy versions raise ValueError for threshold=np.nan.
    # sys.maxsize is the supported way to disable summarization.
    np.set_printoptions(threshold=sys.maxsize,
                        linewidth=1000,
                        precision=3,
                        suppress=True)

    num_bins = args.num_bins
    bin_limits = tuple(float(s) for s in args.bin_limits.split(";"))
    data_proportion = args.data_proportion

    log_path = join(
        save_dir, 'log[bins={},bin_limits={},data={}].txt'.format(
            num_bins, bin_limits, data_proportion))

    # The log file is closed even if loading or plotting fails.
    with open(log_path, mode='w') as log_file:
        print_ = functools.partial(print_both, file=log_file)

        print_("")
        print_("num_bins: {}".format(num_bins))
        print_("bin_limits: {}".format(bin_limits))
        print_("data_proportion: {}".format(data_proportion))

        # Results
        result_file = join(
            args.interpretability_metrics_dir,
            "{}_{}".format(args.enc_dec_model, args.run),
            "results[bins={},bin_limits={},data={}].npz".format(
                num_bins, bin_limits, data_proportion))

        # Read all needed arrays eagerly so the .npz archive can be closed.
        with np.load(result_file, "r") as results:
            print_("results.keys: {}".format(list(results.keys())))
            ids_sorted = results['ids_sorted']
            MI_z_y_sorted = results['MI_z_y_sorted']
            H_z_y_sorted = results['H_z_y_sorted']
            H_y = results['H_y']
            RMIG = results['RMIG']
            JEMMI = results['JEMMI']

        # Plotting
        # =========================================== #
        gt_factors = ['Shape', 'Scale', 'Rotation', 'Pos_x', 'Pos_y']

        print_("MI_z_y_sorted:\n{}".format(MI_z_y_sorted))

        print_("\nShow RMIG!")
        for k, factor in enumerate(gt_factors):
            print_(
                "{}, RMIG: {:.4f}, RMIG (unnorm): {:.4f}, H: {:.4f}, I1: {:.4f}, I2: {:.4f}"
                .format(factor, RMIG[k], RMIG[k] * H_y[k], H_y[k],
                        MI_z_y_sorted[0, k], MI_z_y_sorted[1, k]))

        print_("\nShow JEMMI!")
        for k, factor in enumerate(gt_factors):
            # Unnormalized JEMMI = normalized JEMMI * (H_y + log(num_bins)).
            print_(
                "{}, JEMMI: {:.4f}, JEMMI (unnorm): {:.4f}, H1: {:.4f}, H1-I1: {:.4f}, I2: {:.4f}, "
                "top2 ids: z{}, z{}".format(
                    factor, JEMMI[k],
                    JEMMI[k] * (H_y[k] + np.log(num_bins)), H_z_y_sorted[0, k],
                    H_z_y_sorted[0, k] - MI_z_y_sorted[0, k], MI_z_y_sorted[1, k],
                    ids_sorted[0, k], ids_sorted[1, k]))

        span = 3
        points_one_side = 5

        # A former top-10 variant used size_inches=(6.5, 5.2) and
        # subplot_adjust bottom=0.01, top=0.95.
        num_top_latents = 3
        for n, x in enumerate(data):
            for k, factor in enumerate(gt_factors):
                print("x={}, y={}!".format(n, factor))

                img_file = join(
                    save_dir, "{}[x={},bins={},bin_limits={},data={}].png".format(
                        factor, n, num_bins, bin_limits, data_proportion))

                top_ids = ids_sorted[:num_top_latents, k]
                top_MI = MI_z_y_sorted[:num_top_latents, k]
                model.cond_all_latents_traverse_v2(
                    img_file,
                    sess,
                    x,
                    z_comps=top_ids,
                    z_comp_labels=[
                        "z[{}] ({:.4f})".format(comp, mi)
                        for comp, mi in zip(top_ids, top_MI)
                    ],
                    span=span,
                    points_1_side=points_one_side,
                    hl_x=True,
                    font_size=9,
                    title="{} (RMIG={:.4f}, JEMMI={:.4f}, H={:.4f})".format(
                        factor, RMIG[k], JEMMI[k], H_y[k]),
                    title_font_scale=1.5,
                    subplot_adjust={
                        'left': 0.16,
                        'right': 0.99,
                        'bottom': 0.01,
                        'top': 0.88
                    },
                    size_inches=(6.2, 1.7),
                    batch_size=args.batch_size,
                    dec_output_2_img_func=binary_float_to_uint8)