Code example #1
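Loads a FactorVAE trained on dSprites and renders conditional latent traversals ranked by the per-latent SEP/INFO scores read from a precomputed SEPIN results file.

All five examples omit their import preamble. A minimal sketch follows; only the standard-library, NumPy, Matplotlib, and TensorFlow 1.x imports are certain, while the project-specific helpers (FactorVAE, the 1Konny encoder/decoder/discriminator, SimpleTrainHelper, SimpleParamPrinter, make_dir_if_not_exist, remove_dir_if_exist, print_both, binary_float_to_uint8, iterate_data, the CelebA loaders, and RAW_DATA_DIR) come from the author's own utility code, so the commented import paths for them are placeholders.

import functools
import json
from os.path import join

import matplotlib
import numpy as np
import tensorflow as tf  # TF 1.x API (tf.ConfigProto / tf.Session)

# Placeholder paths for the project-specific helpers -- adjust to your checkout:
# from models import FactorVAE
# from nets import Encoder_1Konny, Decoder_1Konny, DiscriminatorZ_1Konny
# from helpers import SimpleTrainHelper, SimpleParamPrinter
# from utils import (make_dir_if_not_exist, remove_dir_if_exist, print_both,
#                    binary_float_to_uint8, iterate_data)
# from settings import RAW_DATA_DIR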
def main(args):
    # Load config
    # =========================================== #
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)
    # =========================================== #

    # Load dataset
    # =========================================== #
    data_file = join(RAW_DATA_DIR, "ComputerVision", "dSprites",
                     "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")

    # It is already in the range [0, 1]
    with np.load(data_file, encoding="latin1") as f:
        x_train = f['imgs']

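    # Index by the dSprites generative factors:
    # shape (3), scale (6), orientation (40), posX (32), posY (32)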
    x_train = np.reshape(x_train, [3, 6, 40, 32, 32, 64, 64, 1])
    # =========================================== #

    # Build model
    # =========================================== #
    if args.enc_dec_model == "1Konny":
        encoder = Encoder_1Konny(args.z_dim, stochastic=True)
        decoder = Decoder_1Konny()
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    else:
        raise ValueError("Do not support enc_dec_model='{}'!".format(
            args.enc_dec_model))

    model = FactorVAE([64, 64, 1],
                      args.z_dim,
                      encoder=encoder,
                      decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True,
                      gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
        'Dz_tc_loss_coeff': args.Dz_tc_loss_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()
    # =========================================== #

    # Initialize session
    # =========================================== #
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)
    train_helper.load(sess, load_step=args.load_step)

    save_dir = make_dir_if_not_exist(
        join(args.save_dir, "{}_{}".format(args.enc_dec_model, args.run)))
    # =========================================== #

    # Load result file
    # =========================================== #
    result_file = join(args.SEPIN_dir, "{}_{}".format(args.enc_dec_model,
                                                      args.run),
                       "results[num_samples={}].npz".format(args.num_samples))

    results = np.load(result_file, "r")
    print("results.keys: {}".format(list(results.keys())))

    np.set_printoptions(threshold=np.inf,  # np.nan is rejected by modern NumPy (use np.inf or sys.maxsize)
                        linewidth=1000,
                        precision=3,
                        suppress=True)
    # =========================================== #

    # Plotting
    # =========================================== #
    # One sample per shape, at mid scale, rotation, and position
    data = [x_train[i, 3, 20, 16, 16] for i in range(3)]

    gt_factors = ['Shape', 'Scale', 'Rotation', 'Pos_x', 'Pos_y']

    # (num_latents,)
    MI_zi_x = results['MI_zi_x']
    SEP_zi = results['SEP_zi']
    ids_sorted = np.argsort(SEP_zi, axis=0)[::-1]

    print("")
    print("MI_zi_x: {}".format(MI_zi_x))
    print("SEP_zi: {}".format(SEP_zi))
    print("ids_sorted: {}".format(ids_sorted))

    span = 3
    points_one_side = 5

    for n in range(len(data)):
        img_file = join(
            save_dir, "sep_x{}_num_samples={}.png".format(n, args.num_samples))
        model.cond_all_latents_traverse_v2(
            img_file,
            sess,
            data[n],
            z_comps=ids_sorted,
            z_comp_labels=[
                "z[{}] (SEP={:.4f}, INFO={:.4f})".format(
                    idx, SEP_zi[idx], MI_zi_x[idx]) for idx in ids_sorted
            ],
            span=span,
            points_1_side=points_one_side,
            hl_x=True,
            font_size=9,
            title_font_scale=1.5,
            subplot_adjust={
                'left': 0.55,
                'right': 0.99,
                'bottom': 0.01,
                'top': 0.99
            },
            size_inches=(4.0, 1.7),
            batch_size=args.batch_size,
            dec_output_2_img_func=binary_float_to_uint8)

    results.close()  # `f` was already closed by its `with` block; only `results` stays open
Code example #2
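Loads a FactorVAE trained on CelebA and plots conditional traversals of the top-k latents ranked by raw and normalized mutual information with the input, read from a precomputed informativeness-metrics results file; the plot layout switches on top_k.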
def main(args):
    # =====================================
    # Load config
    # =====================================
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Dataset
    # =====================================
    celebA_loader = TFCelebALoader(root_dir=args.celebA_root_dir)

    img_height, img_width = args.celebA_resize_size, args.celebA_resize_size
    celebA_loader.build_transformation_flow_tf(
        *celebA_loader.get_transform_fns("1Konny", resize_size=args.celebA_resize_size))
    num_train = celebA_loader.num_train_data

    # =====================================
    # Instantiate model
    # =====================================
    if args.activation == "relu":
        activation = tf.nn.relu
    elif args.activation == "leaky_relu":
        activation = tf.nn.leaky_relu
    else:
        raise ValueError("Do not support '{}' activation!".format(args.activation))

    if args.enc_dec_model == "1Konny":
        # assert args.z_dim == 65, "For 1Konny, z_dim must be 65. Found {}!".format(args.z_dim)

        encoder = Encoder_1Konny(args.z_dim, stochastic=True, activation=activation)
        decoder = Decoder_1Konny([img_height, img_width, 3], activation=activation,
                                 output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    else:
        raise ValueError("Do not support encoder/decoder model '{}'!".format(args.enc_dec_model))

    model = FactorVAE([img_height, img_width, 3], args.z_dim,
                      encoder=encoder, decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True, gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # =====================================
    # Load model
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)

    # Load model
    train_helper.load(sess, load_step=args.load_step)

    # =====================================
    # Experiments
    # save_dir = remove_dir_if_exist(join(args.save_dir, "FactorVAE_{}".format(args.run)), ask_4_permission=True)
    # save_dir = make_dir_if_not_exist(save_dir)

    save_dir = make_dir_if_not_exist(join(args.save_dir, "FactorVAE_{}".format(args.run)))
    # =====================================

    np.set_printoptions(threshold=np.inf, linewidth=1000, precision=3, suppress=True)

    num_bins = args.num_bins
    data_proportion = args.data_proportion
    bin_limits = tuple([float(s) for s in args.bin_limits.split(";")])
    top_k = args.top_k

    f = open(join(save_dir, 'log[bins={},bin_limits={},data={}].txt'.
                  format(num_bins, bin_limits, data_proportion)), mode='w')
    print_ = functools.partial(print_both, file=f)

    result_file = join(args.informativeness_metrics_dir, "FactorVAE_{}".format(args.run),
                       'results[bins={},bin_limits={},data={}].npz'.
                       format(num_bins, bin_limits, data_proportion))

    results = np.load(result_file, "r")

    print_("")
    print_("num_bins: {}".format(num_bins))
    print_("bin_limits: {}".format(bin_limits))
    print_("data_proportion: {}".format(data_proportion))
    print_("top_k: {}".format(top_k))

    # Plotting
    # =========================================== #
    # seed = 389
    # num_samples = 30
    seed = 398
    num_samples = 1

    ids = list(range(seed, seed + num_samples))
    print_("\nids: {}".format(ids))

    data = celebA_loader.sample_images_from_dataset(sess, 'train', ids)

    span = 3
    points_one_side = 5

    print_("sorted_MI: {}".format(results["sorted_MI_z_x"]))
    print_("sorted_z_ids: {}".format(results["sorted_z_comps"]))
    print_("sorted_norm_MI: {}".format(results["sorted_norm_MI_z_x"]))
    print_("sorted_norm_z_ids: {}".format(results["sorted_norm_z_comps"]))

    top_MI = results["sorted_MI_z_x"][:top_k]
    top_z_ids = results["sorted_z_comps"][:top_k]
    top_norm_MI = results["sorted_norm_MI_z_x"][:top_k]
    top_norm_z_ids = results["sorted_norm_z_comps"][:top_k]

    print("Matplotlib font size: {}".format(matplotlib.rcParams['font.size'],))
    for n in range(len(ids)):
        if top_k == 10:
            print("Plot conditional all comps z traverse with train sample {}!".format(ids[n]))

            img_file = join(save_dir, "x_train[{}][bins={},bin_limits={},data={}].png".
                            format(ids[n], num_bins, bin_limits, data_proportion))
            model.cond_all_latents_traverse_v2(img_file, sess, data[n],
                                               z_comps=top_z_ids,
                                               z_comp_labels=["z[{}] ({:.2f})".format(comp, mi)
                                                              for comp, mi in zip(top_z_ids, top_MI)],
                                               span=span, points_1_side=points_one_side,
                                               hl_x=True,
                                               font_size=matplotlib.rcParams['font.size'],
                                               subplot_adjust={'left': 0.15, 'right': 0.99,
                                                               'bottom': 0.01, 'top': 0.99},
                                               size_inches=(6.3, 4.9),
                                               batch_size=args.batch_size,
                                               dec_output_2_img_func=binary_float_to_uint8)

            img_file = join(save_dir, "x_train[{}][bins={},bin_limits={},data={},norm].png".
                            format(ids[n], num_bins, bin_limits, data_proportion))
            model.cond_all_latents_traverse_v2(img_file, sess, data[n],
                                               z_comps=top_norm_z_ids,
                                               z_comp_labels=["z[{}] ({:.2f})".format(comp, mi)
                                                              for comp, mi in zip(top_norm_z_ids, top_norm_MI)],
                                               span=span, points_1_side=points_one_side,
                                               hl_x=True,
                                               font_size=matplotlib.rcParams['font.size'],
                                               subplot_adjust={'left': 0.15, 'right': 0.99,
                                                               'bottom': 0.01, 'top': 0.99},
                                               size_inches=(6.3, 4.9),
                                               batch_size=args.batch_size,
                                               dec_output_2_img_func=binary_float_to_uint8)
        else:
            print("Plot conditional all comps z traverse with train sample {}!".format(ids[n]))

            img_file = join(save_dir, "x_train[{}][bins={},bin_limits={},data={}].png".
                            format(ids[n], num_bins, bin_limits, data_proportion))
            model.cond_all_latents_traverse_v2(img_file, sess, data[n],
                                               z_comps=top_z_ids,
                                               z_comp_labels=["z[{}] ({:.2f})".format(comp, mi)
                                                              for comp, mi in zip(top_z_ids, top_MI)],
                                               span=span, points_1_side=points_one_side,
                                               hl_x=True,
                                               font_size=5,
                                               subplot_adjust={'left': 0.19, 'right': 0.99,
                                                               'bottom': 0.01, 'top': 0.99},
                                               size_inches=(2.98, 9.85),
                                               batch_size=args.batch_size,
                                               dec_output_2_img_func=binary_float_to_uint8)

            img_file = join(save_dir, "x_train[{}][bins={},bin_limits={},data={},norm].png".
                            format(ids[n], num_bins, bin_limits, data_proportion))
            model.cond_all_latents_traverse_v2(img_file, sess, data[n],
                                               z_comps=top_norm_z_ids,
                                               z_comp_labels=["z[{}] ({:.2f})".format(comp, mi)
                                                              for comp, mi in zip(top_norm_z_ids, top_norm_MI)],
                                               span=span, points_1_side=points_one_side,
                                               hl_x=True,
                                               font_size=5,
                                               subplot_adjust={'left': 0.19, 'right': 0.99,
                                                               'bottom': 0.01, 'top': 0.99},
                                               size_inches=(2.98, 9.85),
                                               batch_size=args.batch_size,
                                               dec_output_2_img_func=binary_float_to_uint8)
    # =========================================== #

    f.close()
Code example #3
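dSprites again: loads a JEMMIG sampling results file, prints per-factor RMIG and JEMMIG summaries, and plots traversals of the top-3 latents for each ground-truth factor.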
def main(args):
    # =====================================
    # Load config
    # =====================================
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Dataset
    # =====================================
    data_file = join(RAW_DATA_DIR, "ComputerVision", "dSprites",
                     "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")

    # It is already in the range [0, 1]
    with np.load(data_file, encoding="latin1") as f:
        x_train = f['imgs']

    x_train = np.reshape(x_train, [3, 6, 40, 32, 32, 64, 64, 1])

    # =====================================
    # Instantiate model
    # =====================================
    if args.enc_dec_model == "1Konny":
        encoder = Encoder_1Konny(args.z_dim, stochastic=True)
        decoder = Decoder_1Konny()
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    else:
        raise ValueError("Do not support enc_dec_model='{}'!".format(
            args.enc_dec_model))

    model = FactorVAE([64, 64, 1],
                      args.z_dim,
                      encoder=encoder,
                      decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True,
                      gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
        'Dz_tc_loss_coeff': args.Dz_tc_loss_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # =====================================
    # Load model
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)

    # Load model
    train_helper.load(sess, load_step=args.load_step)

    # =====================================
    # Experiments
    save_dir = make_dir_if_not_exist(
        join(args.save_dir, "{}_{}".format(args.enc_dec_model, args.run)))
    # =====================================

    np.set_printoptions(threshold=np.inf,
                        linewidth=1000,
                        precision=3,
                        suppress=True)

    result_file = join(args.JEMMIG_sampling_dir,
                       "{}_{}".format(args.enc_dec_model, args.run),
                       "results[num_samples={}].npz".format(args.num_samples))

    results = np.load(result_file, "r")
    print("results.keys: {}".format(list(results.keys())))

    # Plotting
    # =========================================== #
    # One sample per shape, at mid scale, rotation, and position
    data = [x_train[i, 3, 20, 16, 16] for i in range(3)]

    gt_factors = ['Shape', 'Scale', 'Rotation', 'Pos_x', 'Pos_y']
    ids_sorted = results['id_sorted']

    MI_zi_yk_sorted = results['MI_zi_yk_sorted']
    H_zi_yk_sorted = results['H_zi_yk_sorted']

    H_yk = results['H_yk']
    RMIG_yk = results['RMIG_yk']
    RMIG_norm_yk = results['RMIG_norm_yk']
    JEMMIG_yk = results['JEMMIG_yk']

    print("MI_zi_yk_sorted:\n{}".format(MI_zi_yk_sorted))

    print("\nShow RMIG!")
    for k in range(len(gt_factors)):
        print(
            "{}, RMIG: {:.4f}, RMIG (norm): {:.4f}, H: {:.4f}, I1: {:.4f}, I2: {:.4f}"
            .format(gt_factors[k], RMIG_yk[k], RMIG_norm_yk[k], H_yk[k],
                    MI_zi_yk_sorted[0, k], MI_zi_yk_sorted[1, k]))

    print("\nShow JEMMIG!")
    for k in range(len(gt_factors)):
        print(
            "{}, JEMMIG: {:.4f}, H1: {:.4f}, H1-I1: {:.4f}, I2: {:.4f}, top2 ids: z{}, z{}"
            .format(gt_factors[k], JEMMIG_yk[k], H_zi_yk_sorted[0, k],
                    H_zi_yk_sorted[0, k] - MI_zi_yk_sorted[0, k],
                    MI_zi_yk_sorted[1, k], ids_sorted[0, k], ids_sorted[1, k]))

    span = 3
    points_one_side = 5

    for n in range(len(data)):
        for k in range(len(gt_factors)):
            print("x={}, y={}!".format(n, gt_factors[k]))

            img_file = join(
                save_dir,
                "{}-x{}_num_samples={}].png".format(gt_factors[k], n,
                                                    args.num_samples))

            ids_top3 = ids_sorted[:3, k]
            MI_top3 = MI_zi_yk_sorted[:3, k]
            model.cond_all_latents_traverse_v2(
                img_file,
                sess,
                data[n],
                z_comps=ids_top3,
                z_comp_labels=[
                    "z[{}] ({:.4f})".format(comp, mi)
                    for comp, mi in zip(ids_top3, MI_top3)
                ],
                span=span,
                points_1_side=points_one_side,
                hl_x=True,
                font_size=9,
                title="{} (RMIG={:.4f}, JEMMIG={:.4f}, H={:.4f})".format(
                    gt_factors[k], RMIG_yk[k], JEMMIG_yk[k], H_yk[k]),
                title_font_scale=1.5,
                subplot_adjust={
                    'left': 0.16,
                    'right': 0.99,
                    'bottom': 0.01,
                    'top': 0.88
                },
                size_inches=(6.2, 1.7),
                batch_size=args.batch_size,
                dec_output_2_img_func=binary_float_to_uint8)

    results.close()  # `f` was already closed by its `with` block; only `results` stays open
Code example #4
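CelebA with attribute labels: prints per-attribute RMIG and JEMMIG, draws a bar chart of both metrics over all attributes, and renders traversals of the top-5 latents per attribute.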
def main(args):
    # =====================================
    # Load config
    # =====================================
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Dataset
    # =====================================
    celebA_loader = TFCelebAWithAttrLoader(root_dir=args.celebA_root_dir)

    img_height, img_width = args.celebA_resize_size, args.celebA_resize_size
    celebA_loader.build_transformation_flow_tf(
        *celebA_loader.get_transform_fns("1Konny",
                                         resize_size=args.celebA_resize_size))
    num_train = celebA_loader.num_train_data

    # =====================================
    # Instantiate model
    # =====================================
    if args.activation == "relu":
        activation = tf.nn.relu
    elif args.activation == "leaky_relu":
        activation = tf.nn.leaky_relu
    else:
        raise ValueError("Do not support '{}' activation!".format(
            args.activation))

    if args.enc_dec_model == "1Konny":
        # assert args.z_dim == 65, "For 1Konny, z_dim must be 65. Found {}!".format(args.z_dim)

        encoder = Encoder_1Konny(args.z_dim,
                                 stochastic=True,
                                 activation=activation)
        decoder = Decoder_1Konny([img_height, img_width, 3],
                                 activation=activation,
                                 output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    else:
        raise ValueError("Do not support encoder/decoder model '{}'!".format(
            args.enc_dec_model))

    model = FactorVAE([img_height, img_width, 3],
                      args.z_dim,
                      encoder=encoder,
                      decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True,
                      gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # =====================================
    # Load model
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)

    # Load model
    train_helper.load(sess, load_step=args.load_step)

    # =====================================
    # Experiments
    # save_dir = remove_dir_if_exist(join(args.save_dir, "FactorVAE_{}".format(args.run)), ask_4_permission=True)
    # save_dir = make_dir_if_not_exist(save_dir)

    save_dir = make_dir_if_not_exist(
        join(args.save_dir, "FactorVAE_{}".format(args.run)))
    # =====================================

    np.set_printoptions(threshold=np.inf,
                        linewidth=1000,
                        precision=3,
                        suppress=True)

    num_bins = args.num_bins
    bin_limits = tuple([float(s) for s in args.bin_limits.split(";")])
    data_proportion = args.data_proportion

    f = open(join(
        save_dir, 'log[bins={},bin_limits={},data={}].txt'.format(
            num_bins, bin_limits, data_proportion)),
             mode='w')
    print_ = functools.partial(print_both, file=f)

    result_file = join(
        args.interpretability_metrics_dir, "FactorVAE_{}".format(args.run),
        "results[bins={},bin_limits={},data={}].npz".format(
            num_bins, bin_limits, data_proportion))

    results = np.load(result_file, "r")

    print_("")
    print_("num_bins: {}".format(num_bins))
    print_("bin_limits: {}".format(bin_limits))
    print_("data_proportion: {}".format(data_proportion))

    # Plotting
    # =========================================== #
    # seed = 389
    # num_samples = 30
    seed = 398
    num_samples = 1

    ids = list(range(seed, seed + num_samples))
    print_("\nids: {}".format(ids))

    data = celebA_loader.sample_images_from_dataset(sess, 'train', ids)

    span = 3
    points_one_side = 5

    attr_names = celebA_loader.attributes
    print_("#attrs: {}".format(len(attr_names)))
    print_("attr_names: {}".format(attr_names))
    print_("results.keys: {}".format(list(results.keys())))

    # (z_dim, num_attrs)
    MI_ids_sorted = results['MI_ids_sorted']
    MI_sorted = results['MI_sorted']

    MI_gap_y = results['MI_gap_y']
    H_y = results['H_y_4_diff_z'][:, 0]
    assert MI_ids_sorted.shape[1] == len(attr_names) == len(MI_gap_y) == len(H_y), \
        "MI_ids_sorted.shape: {}, len(attr_names): {}, len(MI_gap_y): {}, len(H_y): {}".format(
            MI_ids_sorted.shape, len(attr_names), len(MI_gap_y), len(H_y))

    print_("\nShow RMIG!")
    for i in range(len(attr_names)):
        print("{}: RMIG: {:.4f}, RMIG (unnorm): {:.4f}, H: {:.4f}".format(
            attr_names[i], MI_gap_y[i], MI_gap_y[i] * H_y[i], H_y[i]))

    print_("\nShow JEMMIG!")
    H_z_y = results['H_z_y']
    MI_z_y = results['MI_z_y']

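    # Rank latents per attribute by MI with that attribute (descending);
    # np.take_along_axis requires NumPy >= 1.15. The lines below compute
    # JEMMIG_k = H(z*, y_k) - I(z*; y_k) + I(z_2nd; y_k),
    # normalized by log(num_bins) + H(y_k).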
    ids_sorted_by_MI = np.argsort(MI_z_y, axis=0)[::-1]
    MI_z_y_sorted = np.take_along_axis(MI_z_y, ids_sorted_by_MI, axis=0)
    H_z_y_sorted = np.take_along_axis(H_z_y, ids_sorted_by_MI, axis=0)

    H_diff = H_z_y_sorted[0, :] - MI_z_y_sorted[0, :]
    JEMMIG_unnorm = H_diff + MI_z_y_sorted[1, :]
    JEMMIG_norm = JEMMIG_unnorm / (np.log(num_bins) + H_y)

    for i in range(len(attr_names)):
        print_(
            "{}: JEMMIG: {:.4f}, JEMMIG (unnorm): {:.4f}, H_diff: {:.4f}, I2: {:.4f}, top 2 latents: z{}, z{}"
            .format(attr_names[i], JEMMIG_norm[i], JEMMIG_unnorm[i], H_diff[i],
                    MI_z_y_sorted[1, i], ids_sorted_by_MI[0, i],
                    ids_sorted_by_MI[1, i]))

    # Plot JEMMIG/RMIG of all attributes
    # ========================================= #
    import matplotlib.pyplot as plt
    import matplotlib
    from matplotlib.backends.backend_pdf import PdfPages

    # Note: 'normal' is not a valid matplotlib font family; set only the size
    matplotlib.rc('font', size=12)

    width = 0.3
    plt.bar(np.arange(len(attr_names)),
            JEMMIG_norm,
            width=width,
            align='center',
            edgecolor='black',
            label='JEMMIG')
    plt.bar(np.arange(len(attr_names)) + width,
            MI_gap_y,
            width=width,
            align='center',
            edgecolor='black',
            label='RMIG')

    plt.xticks(np.arange(len(attr_names)) + 0.5 * width,
               attr_names,
               rotation=90)
    plt.xlabel("model")
    plt.ylabel("JEMMIG/RMIG")

    # subplot_adjust = {'left': 0.08, 'right': 0.99, 'bottom': 0.17, 'top': 0.98}
    # plt.subplots_adjust(**subplot_adjust)
    plt.gcf().set_size_inches(12, 3)

    save_dir = make_dir_if_not_exist(save_dir)
    save_file = join(save_dir, "JEMMIG_tc_beta.pdf")

    with PdfPages(save_file) as pdf_file:
        plt.savefig(pdf_file, format='pdf')

    plt.show()
    plt.close()
    # ========================================= #

    # Uncomment if you want
    '''
    for n in range(len(ids)):
        for k in range(len(attr_names)):
            MI_ids_top10 = MI_ids_sorted[:10, k]
            MI_top10 = MI_sorted[:10, k]
            print("Plot top 10 latents for factor '{}'!".format(attr_names[k]))

            img_file = join(save_dir, "x_train[{}][attr={}][bins={},bin_limits={},data={}].png".
                            format(ids[n], attr_names[k], num_bins, bin_limits, data_proportion))

            model.cond_all_latents_traverse_v2(img_file, sess, data[n],
                                               z_comps=MI_ids_top10,
                                               z_comp_labels=["z[{}] ({:.4f})".format(comp, mi)
                                                              for comp, mi in zip(MI_ids_top10, MI_top10)],
                                               span=span, points_1_side=points_one_side,
                                               hl_x=True,
                                               font_size=9,
                                               title="{} (MI gap={:.4f}, H={:.4f})".format(
                                                   attr_names[k], MI_gap_y[k], H_y[k]),
                                               title_font_scale=1.5,
                                               subplot_adjust={'left': 0.16, 'right': 0.99,
                                                               'bottom': 0.01, 'top': 0.95},
                                               size_inches=(6.5, 5.2),
                                               batch_size=args.batch_size,
                                               dec_output_2_img_func=binary_float_to_uint8)
    #'''

    # Top 5 only
    for n in range(len(ids)):
        for k in range(len(attr_names)):
            MI_ids_top5 = MI_ids_sorted[:5, k]
            MI_top5 = MI_sorted[:5, k]
            print("Plot top 5 latents for factor '{}'!".format(attr_names[k]))

            img_file = join(
                save_dir, "train{}_attr={}_bins={}_data={}.png".format(
                    ids[n], attr_names[k], num_bins, data_proportion))

            model.cond_all_latents_traverse_v2(
                img_file,
                sess,
                data[n],
                z_comps=MI_ids_top5,
                z_comp_labels=[
                    "z[{}] ({:.4f})".format(comp, mi)
                    for comp, mi in zip(MI_ids_top5, MI_top5)
                ],
                ],
                span=span,
                points_1_side=points_one_side,
                hl_x=True,
                font_size=9,
                title="{} (MI gap={:.4f}, H={:.4f})".format(
                    attr_names[k], MI_gap_y[k], H_y[k]),
                title_font_scale=1.5,
                subplot_adjust={
                    'left': 0.16,
                    'right': 0.99,
                    'bottom': 0.005,
                    'top': 0.93
                },
                size_inches=(6.5, 2.8),
                batch_size=args.batch_size,
                dec_output_2_img_func=binary_float_to_uint8)
Code example #5
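CelebA: scores latent activity from posterior statistics gathered over 5% of the training set, then plots traversals of the ten most active latents and of every latent whose mean posterior stddev marks it as informative.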
def main(args):
    # =====================================
    # Load config
    # =====================================
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Dataset
    # =====================================
    celebA_loader = TFCelebALoader(root_dir=args.celebA_root_dir)

    img_height, img_width = args.celebA_resize_size, args.celebA_resize_size
    celebA_loader.build_transformation_flow_tf(
        *celebA_loader.get_transform_fns("1Konny",
                                         resize_size=args.celebA_resize_size))
    num_train = celebA_loader.num_train_data

    # =====================================
    # Instantiate model
    # =====================================
    if args.activation == "relu":
        activation = tf.nn.relu
    elif args.activation == "leaky_relu":
        activation = tf.nn.leaky_relu
    else:
        raise ValueError("Do not support '{}' activation!".format(
            args.activation))

    if args.enc_dec_model == "1Konny":
        # assert args.z_dim == 65, "For 1Konny, z_dim must be 65. Found {}!".format(args.z_dim)

        encoder = Encoder_1Konny(args.z_dim,
                                 stochastic=True,
                                 activation=activation)
        decoder = Decoder_1Konny([img_height, img_width, 3],
                                 activation=activation,
                                 output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    else:
        raise ValueError("Do not support encoder/decoder model '{}'!".format(
            args.enc_dec_model))

    model = FactorVAE([img_height, img_width, 3],
                      args.z_dim,
                      encoder=encoder,
                      decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True,
                      gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # =====================================
    # Load model
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)

    # Load model
    train_helper.load(sess, load_step=args.load_step)

    # =====================================
    # Experiments
    save_dir = remove_dir_if_exist(join(args.save_dir,
                                        "FactorVAE_{}".format(args.run)),
                                   ask_4_permission=False)
    save_dir = make_dir_if_not_exist(save_dir)

    # save_dir = make_dir_if_not_exist(join(args.save_dir, "FactorVAE_{}".format(args.run)))
    # =====================================

    np.set_printoptions(threshold=np.inf,
                        linewidth=1000,
                        precision=3,
                        suppress=True)
    f = open(join(save_dir, 'log.txt'), mode='w')
    print_ = functools.partial(print_both, file=f)

    # z gaussian stddev
    # ======================================= #
    all_z_mean = []
    all_z_stddev = []

    count = 0
    for batch_ids in iterate_data(int(0.05 * num_train),
                                  10 * args.batch_size,
                                  shuffle=False):
        x = celebA_loader.sample_images_from_dataset(sess, 'train', batch_ids)

        z_mean, z_stddev = sess.run(model.get_output(['z_mean', 'z_stddev']),
                                    feed_dict={
                                        model.is_train: False,
                                        model.x_ph: x
                                    })

        all_z_mean.append(z_mean)
        all_z_stddev.append(z_stddev)

        count += len(batch_ids)
        print("\rProcessed {} samples!".format(count), end="")
    print()

    all_z_mean = np.concatenate(all_z_mean, axis=0)
    all_z_stddev = np.concatenate(all_z_stddev, axis=0)
    # ======================================= #

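    # Spread of the posterior means across the data: a simple per-latent
    # activity score (collapsed latents have near-zero spread)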
    z_std_error = np.std(all_z_mean, axis=0, ddof=0)
    z_sorted_comps = np.argsort(z_std_error)[::-1]
    top10_z_comps = z_sorted_comps[:10]

    print_("")
    print_("z_std_error: {}".format(z_std_error))
    print_("z_sorted_std_error: {}".format(z_std_error[z_sorted_comps]))
    print_("z_sorted_comps: {}".format(z_sorted_comps))
    print_("top10_z_comps: {}".format(top10_z_comps))

    z_stddev_mean = np.mean(all_z_stddev, axis=0)
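    # A latent is treated as informative when its average posterior stddev
    # stays well below the N(0, 1) prior's (0.4 is the threshold used here)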
    info_z_comps = [
        idx for idx in range(len(z_stddev_mean)) if z_stddev_mean[idx] < 0.4
    ]
    print_("info_z_comps: {}".format(info_z_comps))
    print_("len(info_z_comps): {}".format(len(info_z_comps)))

    # Plotting
    # =========================================== #
    seed = 389
    num_samples = 30
    ids = list(range(seed, seed + num_samples))
    print("\nids: {}".format(ids))

    data = celebA_loader.sample_images_from_dataset(sess, 'train', ids)

    span = 3
    points_one_side = 5

    # span = 8
    # points_one_side = 12

    for n in range(len(ids)):
        print("Plot conditional all comps z traverse with train sample {}!".
              format(ids[n]))
        img_file = join(save_dir,
                        "x_train[{}]_[span={}]_hl.png".format(ids[n], span))
        # model.cond_all_latents_traverse_v2(img_file, sess, data[n],
        #                                 z_comps=top10_z_comps,
        #                                 z_comp_labels=None,
        #                                 span=span, points_1_side=points_one_side,
        #                                 hl_x=True,
        #                                 batch_size=args.batch_size,
        #                                 dec_output_2_img_func=binary_float_to_uint8)

        img_file = join(
            save_dir,
            "x_train[{}]_[span={}]_hl_labeled.png".format(ids[n], span))
        model.cond_all_latents_traverse_v2(
            img_file,
            sess,
            data[n],
            z_comps=top10_z_comps,
            z_comp_labels=["z[{}]".format(comp) for comp in top10_z_comps],
            span=span,
            points_1_side=points_one_side,
            hl_x=True,
            subplot_adjust={
                'left': 0.09,
                'right': 0.98,
                'bottom': 0.02,
                'top': 0.98
            },
            size_inches=(6, 5),
            batch_size=args.batch_size,
            dec_output_2_img_func=binary_float_to_uint8)

        # img_file = join(save_dir, "x_train[{}]_[span={}].png".format(ids[n], span))
        # model.cond_all_latents_traverse_v2(img_file, sess, data[n],
        #                                    z_comps=top10_z_comps,
        #                                    z_comp_labels=None,
        #                                    span=span, points_1_side=points_one_side,
        #                                    hl_x=False,
        #                                    batch_size=args.batch_size,
        #                                    dec_output_2_img_func=binary_float_to_uint8)
        #
        # img_file = join(save_dir, "x_train[{}]_[span={}]_labeled.png".format(ids[n], span))
        # model.cond_all_latents_traverse_v2(img_file, sess, data[n],
        #                                    z_comps=top10_z_comps,
        #                                    z_comp_labels=["z[{}]".format(comp) for comp in top10_z_comps],
        #                                    span=span, points_1_side=points_one_side,
        #                                    hl_x=False,
        #                                    subplot_adjust={'left': 0.09, 'right': 0.98, 'bottom': 0.02, 'top': 0.98},
        #                                    size_inches=(6, 5),
        #                                    batch_size=args.batch_size,
        #                                    dec_output_2_img_func=binary_float_to_uint8)

        img_file = join(
            save_dir, "x_train[{}]_[span={}]_info_hl.png".format(ids[n], span))
        model.cond_all_latents_traverse_v2(
            img_file,
            sess,
            data[n],
            z_comps=info_z_comps,
            z_comp_labels=None,
            span=span,
            points_1_side=points_one_side,
            hl_x=True,
            batch_size=args.batch_size,
            dec_output_2_img_func=binary_float_to_uint8)
    # =========================================== #

    f.close()
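Each `main` expects an argparse namespace whose attributes config.json then overwrites. A minimal driver sketch for the CelebA examples follows; the flag names are inferred from the attributes the snippets read, so treat them as assumptions.

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # Hypothetical flag names, mirroring the attributes the snippet reads;
    # config.json (loaded inside main) overrides anything set here.
    parser.add_argument("--output_dir", required=True,
                        help="directory holding config.json and model_tf/")
    parser.add_argument("--save_dir", required=True)
    parser.add_argument("--celebA_root_dir", required=True)
    parser.add_argument("--celebA_resize_size", type=int, default=64)
    parser.add_argument("--load_step", type=int, default=-1)  # assumed default
    parser.add_argument("--batch_size", type=int, default=64)
    main(parser.parse_args())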