コード例 #1
0
def main(args):
    """Train a FactorVAE on the dSprites image array and checkpoint it.

    The function loads the dSprites ``.npz`` file, prepares the output
    directory, builds the model and its two Adam optimizers (one for the
    z-discriminator, one for the VAE), and then alternates discriminator and
    VAE updates on every mini-batch. Checkpoints, text logs, scalar summaries
    and visualisation images are written at the frequencies given in ``args``.

    Args:
        args: Namespace-like object. Attributes read here: ``output_dir``
            (rewritten in place to ``output_dir/enc_dec_model/run``),
            ``enc_dec_model``, ``run``, ``force_rm_dir``, ``z_dim``,
            ``rec_x_mode``, ``gp0_z_tc_mode``, ``rec_x_coeff``,
            ``kld_loss_coeff``, ``tc_loss_coeff``, ``gp0_z_tc_coeff``,
            ``Dz_tc_loss_coeff``, ``lr_Dz``, ``beta1_Dz``, ``beta2_Dz``,
            ``lr_vae``, ``beta1_vae``, ``beta2_vae``, ``epochs``,
            ``batch_size``, ``D_steps``, ``vae_steps``, ``save_freq``,
            ``log_freq``, ``viz_gen_freq``, ``viz_rec_freq`` and
            ``viz_itpl_freq``.

    Raises:
        ValueError: If the output directory already exists and
            ``args.force_rm_dir`` is falsy, or if ``args.enc_dec_model`` is
            not ``"1Konny"``.
    """
    import sys

    # `threshold=np.nan` is rejected by current NumPy ("threshold must be
    # non-NAN"); sys.maxsize is the documented way to disable summarization.
    np.set_printoptions(threshold=sys.maxsize, linewidth=1000, precision=3)

    # =====================================
    # Preparation
    # =====================================
    data_file = os.path.join(RAW_DATA_DIR, "ComputerVision", "dSprites",
                             "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")

    # Original author's note: the images are already in the range [0, 1].
    with np.load(data_file, encoding="latin1") as f:
        x_train = f['imgs']

    # Append a trailing channel axis and cast to float32 for feeding.
    x_train = np.expand_dims(x_train.astype(np.float32), axis=-1)
    num_train = len(x_train)
    print("x_train: {}".format(num_train))

    args.output_dir = os.path.join(args.output_dir, args.enc_dec_model,
                                   args.run)

    if os.path.exists(args.output_dir):
        # Refuse to clobber a previous run unless explicitly allowed.
        if not args.force_rm_dir:
            raise ValueError("Output directory '{}' existed. 'force_rm_dir' "
                             "must be set to True!".format(args.output_dir))
        import shutil
        shutil.rmtree(args.output_dir, ignore_errors=True)
        print("Removed '{}'".format(args.output_dir))
        os.mkdir(args.output_dir)
    else:
        os.makedirs(args.output_dir)

    save_args(os.path.join(args.output_dir, 'config.json'), args)

    # =====================================
    # Instantiate models
    # =====================================
    if args.enc_dec_model != "1Konny":
        raise ValueError("Do not support enc_dec_model='{}'!".format(
            args.enc_dec_model))

    encoder = Encoder_1Konny(args.z_dim, stochastic=True)
    decoder = Decoder_1Konny()
    disc_z = DiscriminatorZ_1Konny(num_outputs=2)

    model = FactorVAE([64, 64, 1],
                      args.z_dim,
                      encoder=encoder,
                      decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True,
                      gp0_z_tc_mode=args.gp0_z_tc_mode)

    # NOTE(review): the CelebA training main in this file passes this last
    # coefficient under the key 'Dz_tc_loss' rather than 'Dz_tc_loss_coeff'.
    # Confirm which key FactorVAE.build() actually reads; left unchanged here.
    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
        'Dz_tc_loss_coeff': args.Dz_tc_loss_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_list()

    loss = model.get_loss()
    train_params = model.get_train_params()

    opt_Dz = tf.train.AdamOptimizer(learning_rate=args.lr_Dz,
                                    beta1=args.beta1_Dz,
                                    beta2=args.beta2_Dz)
    opt_vae = tf.train.AdamOptimizer(learning_rate=args.lr_vae,
                                     beta1=args.beta1_vae,
                                     beta2=args.beta2_vae)

    # Both train ops are created under the model's update ops so that those
    # ops run as control dependencies of every optimisation step.
    with tf.control_dependencies(model.get_all_update_ops()):
        train_op_D = opt_Dz.minimize(loss=loss['Dz_loss'],
                                     var_list=train_params['Dz_loss'])
        train_op_vae = opt_vae.minimize(loss=loss['vae_loss'],
                                        var_list=train_params['vae_loss'])

    # =====================================
    # TF Graph Handler
    # =====================================
    asset_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "asset"))
    img_gen_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_gen"))
    img_rec_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_rec"))
    img_itpl_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_itpl"))

    log_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "log"))
    train_log_file = os.path.join(log_dir, "train.log")

    summary_dir = make_dir_if_not_exist(
        os.path.join(args.output_dir, "summary_tf"))
    model_dir = make_dir_if_not_exist(os.path.join(args.output_dir,
                                                   "model_tf"))

    train_helper = SimpleTrainHelper(
        log_dir=summary_dir,
        save_dir=model_dir,
        max_to_keep=3,
        max_to_keep_best=1,
    )

    # =====================================
    # Training Loop
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)
    train_helper.initialize(sess,
                            init_variables=True,
                            create_summary_writer=True)

    D_fetch_keys = [
        'Dz_loss', 'Dz_tc_loss', 'Dz_loss_normal', 'Dz_loss_factor',
        'Dz_avg_prob_normal', 'Dz_avg_prob_factor', 'gp0_z_tc'
    ]
    vae_fetch_keys = ['vae_loss', 'rec_x', 'kld_loss', 'tc_loss']

    global_step = 0
    for epoch in range(args.epochs):
        for batch_ids in iterate_data(num_train, args.batch_size,
                                      shuffle=True):
            global_step += 1

            x = x_train[batch_ids]
            z = np.random.normal(size=[len(x), args.z_dim])

            # A second, independently drawn batch fed through `xa_ph`.
            batch_ids_2 = np.random.choice(num_train,
                                           size=len(batch_ids)).tolist()
            xa = x_train[batch_ids_2]

            # The same inputs feed every discriminator and VAE step of this
            # batch, so the feed dict is built once.
            feed = {
                model.is_train: True,
                model.x_ph: x,
                model.z_ph: z,
                model.xa_ph: xa,
            }

            # NOTE: Dm / VAEm are only bound inside these loops, so D_steps
            # and vae_steps must be >= 1 before the first logging step.
            for _ in range(args.D_steps):
                _, Dm = sess.run(
                    [train_op_D,
                     model.get_output(D_fetch_keys, as_dict=True)],
                    feed_dict=feed)

            for _ in range(args.vae_steps):
                _, VAEm = sess.run(
                    [train_op_vae,
                     model.get_output(vae_fetch_keys, as_dict=True)],
                    feed_dict=feed)

            if global_step % args.save_freq == 0:
                train_helper.save(sess, global_step)

            if global_step % args.log_freq == 0:
                log_str = (
                    "\n[FactorVAE/{}/{} (dSprites)], Epoch[{}/{}], Step {}"
                    "\nvae_loss: {:.4f}, Dz_loss: {:.4f}, Dz_tc_loss: {:.4f}"
                    "\nrec_x: {:.4f}, kld_loss: {:.4f}, tc_loss: {:.4f}"
                    "\nDz_loss_normal: {:.4f}, Dz_loss_factor: {:.4f}"
                    "\nDz_avg_prob_normal: {:.4f}, Dz_avg_prob_factor: {:.4f}"
                    "\ngp0_z_tc_coeff: {:.4f}, gp0_z_tc: {:.4f}"
                ).format(
                    args.enc_dec_model, args.run, epoch, args.epochs,
                    global_step,
                    VAEm['vae_loss'], Dm['Dz_loss'], Dm['Dz_tc_loss'],
                    VAEm['rec_x'], VAEm['kld_loss'], VAEm['tc_loss'],
                    Dm['Dz_loss_normal'], Dm['Dz_loss_factor'],
                    Dm['Dz_avg_prob_normal'], Dm['Dz_avg_prob_factor'],
                    args.gp0_z_tc_coeff, Dm['gp0_z_tc'])

                print(log_str)
                # The context manager closes the file; no explicit close().
                with open(train_log_file, "a") as f:
                    f.write(log_str)
                    f.write("\n")

                # (summary tag, value) pairs, emitted in this order.
                scalar_summaries = [
                    ('vae_loss', VAEm['vae_loss']),
                    ('rec_x', VAEm['rec_x']),
                    ('kld_loss', VAEm['kld_loss']),
                    ('tc_loss', VAEm['tc_loss']),
                    ('Dz_tc_loss', Dm['Dz_tc_loss']),
                    ('Dz_loss_normal', Dm['Dz_loss_normal']),
                    ('Dz_loss_factor', Dm['Dz_loss_factor']),
                    ('Dz_prob_normal', Dm['Dz_avg_prob_normal']),
                    ('Dz_prob_factor', Dm['Dz_avg_prob_factor']),
                ]
                for tag, value in scalar_summaries:
                    train_helper.add_summary(
                        custom_tf_scalar_summary(tag, value, prefix='train'),
                        global_step)

            if global_step % args.viz_gen_freq == 0:
                # Generate images from 64 fresh prior samples (8x8 grid).
                z_gen = np.random.normal(size=[64, args.z_dim])
                img_file = os.path.join(img_gen_dir,
                                        'step[%d]_gen_test.png' % global_step)

                model.generate_images(
                    img_file,
                    sess,
                    z_gen,
                    block_shape=[8, 8],
                    batch_size=args.batch_size,
                    dec_output_2_img_func=binary_float_to_uint8)

            if global_step % args.viz_rec_freq == 0:
                # Reconstruct 64 distinct training images (8x8 grid).
                x_rec = x_train[np.random.choice(num_train, size=64,
                                                 replace=False)]
                img_file = os.path.join(img_rec_dir,
                                        'step[%d]_rec_test.png' % global_step)

                model.reconstruct_images(
                    img_file,
                    sess,
                    x_rec,
                    block_shape=[8, 8],
                    batch_size=args.batch_size,
                    dec_output_2_img_func=binary_float_to_uint8)

            if global_step % args.viz_itpl_freq == 0:
                # Interpolate between two sets of 12 training images.
                x1 = x_train[np.random.choice(num_train,
                                              size=12,
                                              replace=False)]
                x2 = x_train[np.random.choice(num_train,
                                              size=12,
                                              replace=False)]

                img_file = os.path.join(img_itpl_dir,
                                        'step[%d]_itpl_test.png' % global_step)

                model.interpolate_images(
                    img_file,
                    sess,
                    x1,
                    x2,
                    num_itpl_points=12,
                    batch_on_row=True,
                    batch_size=args.batch_size,
                    dec_output_2_img_func=binary_float_to_uint8)

    # Last save
    train_helper.save(sess, global_step)
コード例 #2
0
def main(args):
    """Train a FactorVAE on CelebA images and checkpoint it.

    The function sets up the CelebA loader and its transformation flow,
    prepares the output directory, builds the model and its two Adam
    optimizers (one for the z-discriminator, one for the VAE), and then
    alternates discriminator and VAE updates on every mini-batch.
    Checkpoints, text logs, scalar summaries and visualisation images are
    written at the frequencies given in ``args``; a separately named
    checkpoint is also written after every epoch whose index is a multiple
    of 100.

    Args:
        args: Namespace-like object. Attributes read here:
            ``celebA_root_dir``, ``celebA_resize_size``, ``output_dir``
            (rewritten in place to ``output_dir/enc_dec_model/run``),
            ``enc_dec_model``, ``run``, ``force_rm_dir``, ``activation``,
            ``z_dim``, ``rec_x_mode``, ``gp0_z_tc_mode``, ``rec_x_coeff``,
            ``kld_loss_coeff``, ``tc_loss_coeff``, ``Dz_tc_loss_coeff``,
            ``gp0_z_tc_coeff``, ``lr_Dz``, ``beta1_Dz``, ``beta2_Dz``,
            ``lr_vae``, ``beta1_vae``, ``beta2_vae``, ``epochs``,
            ``batch_size``, ``D_steps``, ``vae_steps``, ``save_freq``,
            ``log_freq``, ``viz_gen_freq``, ``viz_rec_freq`` and
            ``viz_itpl_freq``.

    Raises:
        ValueError: If the output directory already exists and
            ``args.force_rm_dir`` is falsy, if ``args.activation`` is neither
            ``"relu"`` nor ``"leaky_relu"``, or if ``args.enc_dec_model`` is
            not ``"1Konny"``.
    """
    # =====================================
    # Preparation
    # =====================================
    celebA_loader = TFCelebALoader(root_dir=args.celebA_root_dir)
    num_train = celebA_loader.num_train_data
    num_test = celebA_loader.num_test_data

    # Images are resized to a square of side `celebA_resize_size`.
    img_height = img_width = args.celebA_resize_size
    celebA_loader.build_transformation_flow_tf(
        *celebA_loader.get_transform_fns("1Konny",
                                         resize_size=args.celebA_resize_size))

    args.output_dir = os.path.join(args.output_dir, args.enc_dec_model,
                                   args.run)

    if os.path.exists(args.output_dir):
        # Refuse to clobber a previous run unless explicitly allowed.
        if not args.force_rm_dir:
            raise ValueError("Output directory '{}' existed. 'force_rm_dir' "
                             "must be set to True!".format(args.output_dir))
        import shutil
        shutil.rmtree(args.output_dir, ignore_errors=True)
        print("Removed '{}'".format(args.output_dir))
        os.mkdir(args.output_dir)
    else:
        os.makedirs(args.output_dir)

    save_args(os.path.join(args.output_dir, 'config.json'), args)

    # =====================================
    # Instantiate models
    # =====================================
    # The activation is only passed to the encoder and decoder.
    if args.activation == "relu":
        activation = tf.nn.relu
    elif args.activation == "leaky_relu":
        activation = tf.nn.leaky_relu
    else:
        raise ValueError("Do not support '{}' activation!".format(
            args.activation))

    if args.enc_dec_model != "1Konny":
        raise ValueError("Do not support encoder/decoder model '{}'!".format(
            args.enc_dec_model))

    encoder = Encoder_1Konny(args.z_dim,
                             stochastic=True,
                             activation=activation)
    decoder = Decoder_1Konny([img_height, img_width, 3],
                             activation=activation,
                             output_activation=tf.nn.sigmoid)
    disc_z = DiscriminatorZ_1Konny(num_outputs=2)

    model = FactorVAE([img_height, img_width, 3],
                      args.z_dim,
                      encoder=encoder,
                      decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True,
                      gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'Dz_tc_loss': args.Dz_tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_list()

    loss = model.get_loss()
    train_params = model.get_train_params()

    opt_Dz = tf.train.AdamOptimizer(learning_rate=args.lr_Dz,
                                    beta1=args.beta1_Dz,
                                    beta2=args.beta2_Dz)
    opt_vae = tf.train.AdamOptimizer(learning_rate=args.lr_vae,
                                     beta1=args.beta1_vae,
                                     beta2=args.beta2_vae)

    # Both train ops are created under the model's update ops so that those
    # ops run as control dependencies of every optimisation step.
    with tf.control_dependencies(model.get_all_update_ops()):
        train_op_D = opt_Dz.minimize(loss=loss['Dz_loss'],
                                     var_list=train_params['Dz_loss'])
        train_op_vae = opt_vae.minimize(loss=loss['vae_loss'],
                                        var_list=train_params['vae_loss'])

    # =====================================
    # TF Graph Handler
    # =====================================
    asset_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "asset"))
    img_gen_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_gen"))
    img_rec_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_rec"))
    img_itpl_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_itpl"))

    log_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "log"))
    train_log_file = os.path.join(log_dir, "train.log")

    summary_dir = make_dir_if_not_exist(
        os.path.join(args.output_dir, "summary_tf"))
    model_dir = make_dir_if_not_exist(os.path.join(args.output_dir,
                                                   "model_tf"))

    train_helper = SimpleTrainHelper(
        log_dir=summary_dir,
        save_dir=model_dir,
        max_to_keep=3,
        max_to_keep_best=1,
    )

    # =====================================
    # Training Loop
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)
    train_helper.initialize(sess,
                            init_variables=True,
                            create_summary_writer=True)

    D_fetch_keys = [
        'Dz_loss', 'Dz_tc_loss', 'Dz_loss_normal', 'Dz_loss_factor',
        'Dz_avg_prob_normal', 'Dz_avg_prob_factor', 'gp0_z_tc'
    ]
    vae_fetch_keys = ['vae_loss', 'rec_x', 'kld_loss', 'tc_loss']

    train_sampler = ContinuousIndexSampler(num_train,
                                           args.batch_size,
                                           shuffle=True)

    # Integer ceiling division: unlike ceil(a / b), this does not depend on
    # `/` being true division, so the final partial batch is always counted.
    num_batch_per_epochs = (num_train + args.batch_size -
                            1) // args.batch_size

    global_step = 0
    for epoch in range(args.epochs):
        for _ in range(num_batch_per_epochs):
            global_step += 1

            batch_ids = train_sampler.sample_ids()
            x = celebA_loader.sample_images_from_dataset(
                sess, 'train', batch_ids)

            z = np.random.randn(len(x), args.z_dim)

            # A second, independently drawn batch fed through `xa_ph`.
            batch_ids_2 = np.random.choice(num_train, size=len(batch_ids))
            xa = celebA_loader.sample_images_from_dataset(
                sess, 'train', batch_ids_2)

            # The same inputs feed every discriminator and VAE step of this
            # batch, so the feed dict is built once.
            feed = {
                model.is_train: True,
                model.x_ph: x,
                model.z_ph: z,
                model.xa_ph: xa,
            }

            # NOTE: Dm / VAEm are only bound inside these loops, so D_steps
            # and vae_steps must be >= 1 before the first logging step.
            for _ in range(args.D_steps):
                _, Dm = sess.run(
                    [train_op_D,
                     model.get_output(D_fetch_keys, as_dict=True)],
                    feed_dict=feed)

            for _ in range(args.vae_steps):
                _, VAEm = sess.run(
                    [train_op_vae,
                     model.get_output(vae_fetch_keys, as_dict=True)],
                    feed_dict=feed)

            if global_step % args.save_freq == 0:
                train_helper.save(sess, global_step)

            if global_step % args.log_freq == 0:
                log_str = (
                    "\n[FactorVAE (celebA)/{}, {}]"
                    "\nEpoch {}/{}, Step {}, vae_loss: {:.4f}, Dz_loss: {:.4f}, Dz_tc_loss: {:.4f}"
                    "\nrec_x: {:.4f}, kld_loss: {:.4f}, tc_loss: {:.4f}"
                    "\nDz_loss_normal: {:.4f}, Dz_loss_factor: {:.4f}"
                    "\nDz_avg_prob_normal: {:.4f}, Dz_avg_prob_factor: {:.4f}"
                    "\ngp0_z_tc_coeff: {:.4f}, gp0_z_tc: {:.4f}"
                ).format(
                    args.enc_dec_model, args.run,
                    epoch, args.epochs, global_step,
                    VAEm['vae_loss'], Dm['Dz_loss'], Dm['Dz_tc_loss'],
                    VAEm['rec_x'], VAEm['kld_loss'], VAEm['tc_loss'],
                    Dm['Dz_loss_normal'], Dm['Dz_loss_factor'],
                    Dm['Dz_avg_prob_normal'], Dm['Dz_avg_prob_factor'],
                    args.gp0_z_tc_coeff, Dm['gp0_z_tc'])

                print(log_str)
                # The context manager closes the file; no explicit close().
                with open(train_log_file, "a") as f:
                    f.write(log_str)
                    f.write("\n")

                # (summary tag, value) pairs, emitted in this order.
                scalar_summaries = [
                    ('vae_loss', VAEm['vae_loss']),
                    ('rec_x', VAEm['rec_x']),
                    ('kld_loss', VAEm['kld_loss']),
                    ('tc_loss', VAEm['tc_loss']),
                    ('Dz_tc_loss', Dm['Dz_tc_loss']),
                    ('Dz_loss_normal', Dm['Dz_loss_normal']),
                    ('Dz_loss_factor', Dm['Dz_loss_factor']),
                    ('Dz_prob_normal', Dm['Dz_avg_prob_normal']),
                    ('Dz_prob_factor', Dm['Dz_avg_prob_factor']),
                ]
                for tag, value in scalar_summaries:
                    train_helper.add_summary(
                        custom_tf_scalar_summary(tag, value, prefix='train'),
                        global_step)

            if global_step % args.viz_gen_freq == 0:
                # Generate images from 64 fresh prior samples (8x8 grid).
                z_gen = np.random.randn(64, args.z_dim)
                img_file = os.path.join(img_gen_dir,
                                        'step[%d]_gen_test.png' % global_step)

                model.generate_images(
                    img_file,
                    sess,
                    z_gen,
                    block_shape=[8, 8],
                    batch_size=args.batch_size,
                    dec_output_2_img_func=binary_float_to_uint8)

            if global_step % args.viz_rec_freq == 0:
                # Reconstruct 64 distinct test-split images (8x8 grid).
                x_rec = celebA_loader.sample_images_from_dataset(
                    sess, 'test',
                    np.random.choice(num_test, size=64, replace=False))

                img_file = os.path.join(img_rec_dir,
                                        'step[%d]_rec_test.png' % global_step)

                model.reconstruct_images(
                    img_file,
                    sess,
                    x_rec,
                    block_shape=[8, 8],
                    batch_size=args.batch_size,
                    dec_output_2_img_func=binary_float_to_uint8)

            if global_step % args.viz_itpl_freq == 0:
                # Interpolate between two sets of 12 test-split images.
                x1 = celebA_loader.sample_images_from_dataset(
                    sess, 'test',
                    np.random.choice(num_test, size=12, replace=False))
                x2 = celebA_loader.sample_images_from_dataset(
                    sess, 'test',
                    np.random.choice(num_test, size=12, replace=False))

                img_file = os.path.join(img_itpl_dir,
                                        'step[%d]_itpl_test.png' % global_step)

                model.interpolate_images(
                    img_file,
                    sess,
                    x1,
                    x2,
                    num_itpl_points=12,
                    batch_on_row=True,
                    batch_size=args.batch_size,
                    dec_output_2_img_func=binary_float_to_uint8)

        # Keep a separately named checkpoint every 100 epochs (including
        # epoch 0).
        if epoch % 100 == 0:
            train_helper.save_separately(
                sess,
                model_name="model_epoch{}".format(epoch),
                global_step=global_step)

    # Last save
    train_helper.save(sess, global_step)
コード例 #3
0
def main(args):
    # =====================================
    # Load config
    # =====================================
    with open(os.path.join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Preparation
    # =====================================
    celebA_loader = TFCelebALoader(root_dir=args.celebA_root_dir)
    num_train = celebA_loader.num_train_data
    num_test = celebA_loader.num_test_data

    img_height, img_width = args.celebA_resize_size, args.celebA_resize_size
    celebA_loader.build_transformation_flow_tf(
        *celebA_loader.get_transform_fns("1Konny",
                                         resize_size=args.celebA_resize_size))

    # =====================================
    # Instantiate models
    # =====================================
    # Only use activation for encoder and decoder
    if args.activation == "relu":
        activation = tf.nn.relu
    elif args.activation == "leaky_relu":
        activation = tf.nn.leaky_relu
    else:
        raise ValueError("Do not support '{}' activation!".format(
            args.activation))

    if args.enc_dec_model == "1Konny":
        # assert args.z_dim == 65, "For 1Konny, z_dim must be 65. Found {}!".format(args.z_dim)

        encoder = Encoder_1Konny(args.z_dim,
                                 stochastic=True,
                                 activation=activation)
        decoder = Decoder_1Konny([img_height, img_width, 3],
                                 activation=activation,
                                 output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    elif args.enc_dec_model == "my":
        # assert args.z_dim == 150, "For my, z_dim must be 150. Found {}!".format(args.z_dim)

        encoder = Encoder_My(args.z_dim,
                             stochastic=True,
                             activation=activation)
        decoder = Decoder_My([img_height, img_width, 3],
                             activation=activation,
                             output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_My(num_outputs=2)
    else:
        raise ValueError("Do not support encoder/decoder model '{}'!".format(
            args.enc_dec_model))

    model = FactorVAE([img_height, img_width, 3],
                      args.z_dim,
                      encoder=encoder,
                      decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True,
                      gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
        # 'Dz_tc_loss': args.Dz_tc_loss,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_list()

    # =====================================
    # TF Graph Handler
    asset_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "asset"))
    img_eval = remove_dir_if_exist(os.path.join(asset_dir, "img_eval"),
                                   ask_4_permission=False)
    img_eval = make_dir_if_not_exist(img_eval)

    img_x_gen = make_dir_if_not_exist(os.path.join(img_eval, "x_gen"))
    img_x_rec = make_dir_if_not_exist(os.path.join(img_eval, "x_rec"))
    img_z_rand_2_traversal = make_dir_if_not_exist(
        os.path.join(img_eval, "z_rand_2_traversal"))
    img_z_cond_all_traversal = make_dir_if_not_exist(
        os.path.join(img_eval, "z_cond_all_traversal"))
    img_z_cond_1_traversal = make_dir_if_not_exist(
        os.path.join(img_eval, "z_cond_1_traversal"))
    img_z_corr = make_dir_if_not_exist(os.path.join(img_eval, "z_corr"))
    img_z_dist = make_dir_if_not_exist(os.path.join(img_eval, "z_dist"))
    img_z_stat_dist = make_dir_if_not_exist(
        os.path.join(img_eval, "z_stat_dist"))
    # img_rec_error_dist = make_dir_if_not_exist(os.path.join(img_eval, "rec_error_dist"))

    model_dir = make_dir_if_not_exist(os.path.join(args.output_dir,
                                                   "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)
    # =====================================

    # =====================================
    # Training Loop
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    # Load model
    train_helper.load(sess, load_step=args.load_step)

    # '''
    # Generation
    # ======================================= #
    z = np.random.randn(64, args.z_dim)

    img_file = os.path.join(img_x_gen, 'x_gen_test.png')
    model.generate_images(img_file,
                          sess,
                          z,
                          block_shape=[8, 8],
                          batch_size=args.batch_size,
                          dec_output_2_img_func=binary_float_to_uint8)
    # ======================================= #
    # '''

    # '''
    # Reconstruction
    # ======================================= #
    seed = 389
    x = celebA_loader.sample_images_from_dataset(sess, 'test',
                                                 list(range(seed, seed + 64)))

    img_file = os.path.join(img_x_rec, 'x_rec_test.png')
    model.reconstruct_images(img_file,
                             sess,
                             x,
                             block_shape=[8, 8],
                             batch_size=-1,
                             dec_output_2_img_func=binary_float_to_uint8)
    # ======================================= #
    # '''

    # '''
    # z random traversal
    # ======================================= #
    # Only run when z_dim <= 5; the loops below visit every pair i < j of
    # latent components.
    if args.z_dim <= 5:
        print("z_dim = {}, perform random traversal!".format(args.z_dim))

        # Plot z cont with z cont
        z_zero = np.zeros([args.z_dim], dtype=np.float32)
        z_rand = np.random.randn(args.z_dim)
        z_start, z_stop = -4, 4
        num_points = 8

        # Each pair is plotted once per default_z: the all-zero vector, then a
        # single random draw shared by every pair. Both axes use the same
        # [z_start, z_stop] range. The 'rand' variant used to pass
        # start2=z_stop, so its second axis started and stopped at the same
        # value; driving both variants through one call removes that
        # copy-paste divergence.
        z_defaults = (('zero', z_zero), ('rand', z_rand))

        for i in range(args.z_dim):
            for j in range(i + 1, args.z_dim):
                print(
                    "Plot random 2 comps z traverse with {} and {} components!"
                    .format(i, j))

                for z_label, z_default in z_defaults:
                    img_file = os.path.join(
                        img_z_rand_2_traversal,
                        'z[{},{},{}].png'.format(i, j, z_label))
                    model.rand_2_latents_traverse(
                        img_file,
                        sess,
                        default_z=z_default,
                        z_comp1=i,
                        start1=z_start,
                        stop1=z_stop,
                        num_points1=num_points,
                        z_comp2=j,
                        start2=z_start,
                        stop2=z_stop,
                        num_points2=num_points,
                        batch_size=args.batch_size,
                        dec_output_2_img_func=binary_float_to_uint8)
    # ======================================= #
    # '''

    # z conditional traversal (all features + one feature)
    # ======================================= #
    # `seed` is the first dataset index to sample (not an RNG seed): the
    # `num_samples` consecutive indices starting at it are requested from the
    # 'train' split.
    seed = 389
    num_samples = 30
    data = celebA_loader.sample_images_from_dataset(
        sess, 'train', list(range(seed, seed + num_samples)))

    # Traversal settings shared by both loops below.
    z_start, z_stop = -4, 4
    num_itpl_points = 8

    # One output image per sampled input, passing that single sample to
    # cond_all_latents_traverse. The log message says "train sample" to match
    # the split the data was drawn from and the 'x_train{}.png' file names.
    for n in range(num_samples):
        print("Plot conditional all comps z traverse with train sample {}!".
              format(n))
        img_file = os.path.join(img_z_cond_all_traversal,
                                'x_train{}.png'.format(n))
        model.cond_all_latents_traverse(
            img_file,
            sess,
            data[n],
            start=z_start,
            stop=z_stop,
            num_itpl_points=num_itpl_points,
            batch_size=args.batch_size,
            dec_output_2_img_func=binary_float_to_uint8)

    # One output image per latent component, passing the whole sampled batch
    # to cond_1_latent_traverse; the file name records the sampled index range.
    for i in range(args.z_dim):
        print("Plot conditional z traverse with comp {}!".format(i))
        img_file = os.path.join(
            img_z_cond_1_traversal,
            'x_train[{},{}]_z{}.png'.format(seed, seed + num_samples, i))
        model.cond_1_latent_traverse(
            img_file,
            sess,
            data,
            z_comp=i,
            start=z_start,
            stop=z_stop,
            num_itpl_points=num_itpl_points,
            batch_size=args.batch_size,
            dec_output_2_img_func=binary_float_to_uint8)
    # ======================================= #
    # '''

    # '''
    # z correlation matrix
    # ======================================= #
    # Encode the training set batch by batch and collect the latent codes.
    # NOTE(review): presumably iterate_data yields index batches covering
    # range(num_train) in order when shuffle=False -- confirm.
    all_z = []
    for batch_ids in iterate_data(num_train, args.batch_size, shuffle=False):
        x = celebA_loader.sample_images_from_dataset(sess, 'train', batch_ids)

        z = model.encode(sess, x)
        # Guard: the encoder output must be 2-D with z_dim columns.
        assert len(
            z.shape) == 2 and z.shape[1] == args.z_dim, "z.shape: {}".format(
                z.shape)

        all_z.append(z)

    # Stack the per-batch codes along the first axis.
    all_z = np.concatenate(all_z, axis=0)

    print("Start plotting!")
    plot_corrmat_with_histogram(os.path.join(img_z_corr, "corr_mat.png"),
                                all_z)
    # NOTE(review): this path pattern has no '.png' suffix, unlike the
    # 'z_mean_{}.png' / 'z_stddev_{}.png' patterns passed to plot_comp_dist
    # further down -- confirm whether that is intentional.
    plot_comp_dist(os.path.join(img_z_dist, 'z_{}'), all_z, x_lim=(-5, 5))
    print("Done!")
    # ======================================= #
    # '''

    # '''
    # z gaussian stddev
    # ======================================= #
    print("\nPlot z mean and stddev!")
    all_z_mean = []
    all_z_stddev = []

    # Second batched pass over the training set, this time fetching the
    # model's 'z_mean' and 'z_stddev' outputs instead of encoded codes.
    for batch_ids in iterate_data(num_train, args.batch_size, shuffle=False):
        x = celebA_loader.sample_images_from_dataset(sess, 'train', batch_ids)

        # is_train is fed as False for this evaluation run.
        z_mean, z_stddev = sess.run(model.get_output(['z_mean', 'z_stddev']),
                                    feed_dict={
                                        model.is_train: False,
                                        model.x_ph: x
                                    })

        all_z_mean.append(z_mean)
        all_z_stddev.append(z_stddev)

    # Stack the per-batch results along the first axis.
    all_z_mean = np.concatenate(all_z_mean, axis=0)
    all_z_stddev = np.concatenate(all_z_stddev, axis=0)

    # Both plots go to img_z_stat_dist; the '{}' in each file pattern is
    # presumably filled in by plot_comp_dist -- confirm against that helper.
    plot_comp_dist(os.path.join(img_z_stat_dist, 'z_mean_{}.png'),
                   all_z_mean,
                   x_lim=(-5, 5))
    plot_comp_dist(os.path.join(img_z_stat_dist, 'z_stddev_{}.png'),
                   all_z_stddev,
                   x_lim=(-0.5, 3))