Esempio n. 1
0
def main(args):
    # Create output directory
    # ===================================== #
    args.output_dir = os.path.join(args.output_dir, args.model_name, args.run)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    else:
        if args.force_rm_dir:
            import shutil
            shutil.rmtree(args.output_dir, ignore_errors=True)
            print("Removed '{}'".format(args.output_dir))
        else:
            raise ValueError("Output directory '{}' existed. 'force_rm_dir' "
                             "must be set to True!".format(args.output_dir))
        os.mkdir(args.output_dir)

    save_args(os.path.join(args.output_dir, 'config.json'), args)
    # pp = pprint.PrettyPrinter(indent=4)
    # pp.pprint(args.__dict__)
    # ===================================== #

    # Specify data
    # ===================================== #
    if args.dataset == "mnist":
        x_shape = [28, 28, 1]
    elif args.dataset == "mnist_3" or args.dataset == "mnistm":
        x_shape = [28, 28, 3]
    elif args.dataset == "svhn" or args.dataset == "cifar10" or args.dataset == "cifar100":
        x_shape = [32, 32, 3]
    else:
        raise ValueError("Do not support dataset '{}'!".format(args.dataset))

    if args.dataset == "cifar100":
        num_classes = 100
    else:
        num_classes = 10

    print("x_shape: {}".format(x_shape))
    print("num_classes: {}".format(num_classes))
    # ===================================== #

    # Load data
    # ===================================== #
    print("Loading {}!".format(args.dataset))
    train_loader = SimpleDataset4SSL()
    train_loader.load_npz_data(args.train_file)
    train_loader.create_ssl_data(args.num_labeled,
                                 num_classes=num_classes,
                                 shuffle=True,
                                 seed=args.seed)
    if args.input_norm != "applied":
        train_loader.x = uint8_to_binary_float(train_loader.x)
    else:
        print("Input normalization has been applied on train data!")

    test_loader = SimpleDataset()
    test_loader.load_npz_data(args.test_file)
    if args.input_norm != "applied":
        test_loader.x = uint8_to_binary_float(test_loader.x)
    else:
        print("Input normalization has been applied on test data!")

    print("train_l/train_u/test: {}/{}/{}".format(
        train_loader.num_labeled_data, train_loader.num_unlabeled_data,
        test_loader.num_data))

    # import matplotlib.pyplot as plt
    # print("train_l.y[:10]: {}".format(train_l.y[:10]))
    # print("train_u.y[:10]: {}".format(train_u.y[:10]))
    # print("test.y[:10]: {}".format(test.y[:10]))
    # fig, axes = plt.subplots(3, 5)
    # for i in range(5):
    #     axes[0][i].imshow(train_l.x[i])
    #     axes[1][i].imshow(train_u.x[i])
    #     axes[2][i].imshow(test.x[i])
    # plt.show()

    if args.dataset == "mnist":
        train_loader.x = np.expand_dims(train_loader.x, axis=-1)
        test_loader.x = np.expand_dims(test_loader.x, axis=-1)
    elif args.dataset == "mnist_3":
        train_loader.x = np.stack(
            [train_loader.x, train_loader.x, train_loader.x], axis=-1)
        test_loader.x = np.stack([test_loader.x, test_loader.x, test_loader.x],
                                 axis=-1)

    # Data Preprocessing + Augmentation
    # ------------------------------------- #
    if args.input_norm == 'none' or args.input_norm == 'applied':
        print("Do not apply any normalization!")
    elif args.input_norm == 'zca':
        print("Apply ZCA whitening on data!")
        normalizer = ZCA()
        normalizer.fit(train_loader.x, eps=1e-5)
        train_loader.x = normalizer.transform(train_loader.x)
        test_loader.x = normalizer.transform(test_loader.x)
    elif args.input_norm == 'standard':
        print("Apply Standardization on data!")
        normalizer = Standardization()
        normalizer.fit(train_loader.x)
        train_loader.x = normalizer.transform(train_loader.x)
        test_loader.x = normalizer.transform(test_loader.x)
    else:
        raise ValueError("Do not support 'input_norm'={}".format(
            args.input_norm))
    # ------------------------------------- #
    # ===================================== #

    # Hyperparameters
    # ===================================== #
    hyper_updater = HyperParamUpdater(
        ['lr', 'ema_momentum', 'cent_u_coeff', 'cons_coeff'], [
            args.lr_max, args.ema_momentum_init, args.cent_u_coeff_max,
            args.cons_coeff_max
        ],
        scope='moving_hyperparams')
    # ===================================== #

    # Build model
    # ===================================== #
    # IMPORTANT: Remember to test with No Gaussian Noise
    print("args.gauss_noise: {}".format(args.gauss_noise))

    if args.model_name == "9310gaurav":
        main_classifier = MainClassifier_9310gaurav(
            num_classes=num_classes, use_gauss_noise=args.gauss_noise)
    else:
        raise ValueError("Do not support model_name='{}'!".format(
            args.model_name))

    # Input Perturber
    # ------------------------------------- #
    # Input perturber only performs 'translating_pixels' (Both CIFAR-10 and SVHN) here
    input_perturber = InputPerturber(
        normalizer=None,  # We do not use normalizer here!
        flip_horizontally=args.flip_horizontally,
        flip_vertically=False,  # We do not flip images vertically!
        translating_pixels=args.translating_pixels,
        noise_std=0.0)  # We do not add noise here!
    # ------------------------------------- #

    # Main model
    # ------------------------------------- #
    model = MeanTeacher(x_shape=x_shape,
                        y_shape=num_classes,
                        main_classifier=main_classifier,
                        input_perturber=input_perturber,
                        cons_mode=args.cons_mode,
                        ema_momentum=hyper_updater.variables['ema_momentum'],
                        cons_4_unlabeled_only=args.cons_4_unlabeled_only,
                        weight_decay=args.weight_decay)

    loss_coeff_dict = {
        'cross_ent_l': args.cross_ent_l,
        'cond_ent_u': hyper_updater.variables['cent_u_coeff'],
        'cons': hyper_updater.variables['cons_coeff'],
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_list(trainable_only=False)
    # ------------------------------------- #
    # ===================================== #

    # Build optimizer
    # ===================================== #
    losses = model.get_loss()
    train_params = model.get_train_params()
    opt_AE = tf.train.MomentumOptimizer(
        learning_rate=hyper_updater.variables['lr'],
        momentum=args.lr_momentum,
        use_nesterov=True)

    # Contain both batch norm update and teacher param update
    update_ops = model.get_all_update_ops()
    print("update_ops: {}".format(update_ops))
    with tf.control_dependencies(update_ops):
        train_op_AE = opt_AE.minimize(loss=losses['loss'],
                                      var_list=train_params['loss'])
    # ===================================== #

    # Create directories
    # ===================================== #
    asset_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "asset"))
    img_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img"))
    log_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "log"))
    train_log_file = os.path.join(log_dir, "train.log")

    summary_dir = make_dir_if_not_exist(
        os.path.join(args.output_dir, "summary_tf"))
    model_dir = make_dir_if_not_exist(os.path.join(args.output_dir,
                                                   "model_tf"))
    # ===================================== #

    # Create session
    # ===================================== #
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    train_helper = SimpleTrainHelper(log_dir=summary_dir,
                                     save_dir=model_dir,
                                     max_to_keep=args.num_save,
                                     max_to_keep_best=args.num_save_best)
    train_helper.initialize(sess,
                            init_variables=True,
                            create_summary_writer=True)
    # ===================================== #

    # Start training
    # ===================================== #
    # Summarizer
    # ------------------------------------- #
    fetch_keys_AE_l = ['acc_y_l', 'cross_ent_l']
    fetch_keys_AE_u = ['acc_y_u', 'cond_ent_u', 'cons']
    # To compare between MDL loss and xent+consistency to see whether MDL loss
    # is a better indicator for generalization compared to xent+consistency or not
    fetch_keys_AE = fetch_keys_AE_l + fetch_keys_AE_u
    train_summarizer = ScalarSummarizer([(key, 'mean')
                                         for key in fetch_keys_AE])

    fetch_keys_test = ['acc_y', 'acc_y_stu']
    eval_summarizer = ScalarSummarizer([(key, 'mean')
                                        for key in fetch_keys_test])
    # ------------------------------------- #

    # Data sampler
    # ------------------------------------- #
    # The number of labeled data varies during training
    if args.batch_size_labeled <= 0:
        sampler = ContinuousIndexSampler(train_loader.num_data,
                                         args.batch_size,
                                         shuffle=True)
        sampling_separately = False
        print("batch_size_l, batch_size_u vary but their sum={}!".format(
            args.batch_size))

    elif 0 < args.batch_size_labeled < args.batch_size:
        batch_size_l = args.batch_size_labeled
        batch_size_u = args.batch_size - args.batch_size_labeled
        print("batch_size_l/batch_size_u: {}/{}".format(
            batch_size_l, batch_size_u))

        # IMPORTANT: Here we must use 'train_loader.labeled_ids' and 'train_loader.unlabeled_ids',
        # NOT 'train_loader.num_labeled_data' and 'train_loader.num_unlabeled_data'
        sampler_l = ContinuousIndexSampler(train_loader.labeled_ids,
                                           batch_size_l,
                                           shuffle=True)
        sampler_u = ContinuousIndexSampler(train_loader.unlabeled_ids,
                                           batch_size_u,
                                           shuffle=True)
        sampler = ContinuousIndexSamplerGroup(sampler_l, sampler_u)
        sampling_separately = True

    else:
        raise ValueError(
            "'args.batch_size_labeled' must be in ({}, {})!".format(
                0, args.batch_size))
    # ------------------------------------- #

    # Annealer
    # ------------------------------------- #
    step_rampup_annealer = StepAnnealing(args.rampup_len_step,
                                         value_0=0,
                                         value_1=1)
    sigmoid_rampup_annealer = SigmoidRampup(args.rampup_len_step)
    sigmoid_rampdown_annealer = SigmoidRampdown(args.rampdown_len_step,
                                                args.steps)
    # ------------------------------------- #

    # Results Tracker
    # ------------------------------------- #
    tracker = BestResultsTracker([('acc_y', 'greater')],
                                 num_best=args.num_save_best)
    # ------------------------------------- #

    import math
    batches_per_epoch = int(math.ceil(train_loader.num_data / args.batch_size))
    global_step = 0
    log_time_start = time()

    for epoch in range(args.epochs):
        if global_step >= args.steps:
            break

        for batch in range(batches_per_epoch):
            if global_step >= args.steps:
                break

            global_step += 1

            # Update hyper parameters
            # ---------------------------------- #
            step_rampup = step_rampup_annealer.get_value(global_step)
            sigmoid_rampup = sigmoid_rampup_annealer.get_value(global_step)
            sigmoid_rampdown = sigmoid_rampdown_annealer.get_value(global_step)

            lr = sigmoid_rampup * sigmoid_rampdown * args.lr_max
            ema_momentum = (
                1.0 - step_rampup
            ) * args.ema_momentum_init + step_rampup * args.ema_momentum_final
            cent_u_coeff = sigmoid_rampup * args.cent_u_coeff_max
            cons_coeff = sigmoid_rampup * args.cons_coeff_max

            hyper_updater.update(sess,
                                 feed_dict={
                                     'lr': lr,
                                     'ema_momentum': ema_momentum,
                                     'cent_u_coeff': cent_u_coeff,
                                     'cons_coeff': cons_coeff
                                 })
            hyper_vals = hyper_updater.get_value(sess)
            hyper_vals['sigmoid_rampup'] = sigmoid_rampup
            hyper_vals['sigmoid_rampdown'] = sigmoid_rampdown
            hyper_vals['step_rampup'] = step_rampup
            # ---------------------------------- #

            # Train model
            # ---------------------------------- #
            if sampling_separately:
                # print("Sample separately!")
                batch_ids_l, batch_ids_u = sampler.sample_group_of_ids()

                xl, yl, label_flag_l = train_loader.fetch_batch(batch_ids_l)
                xu, yu, label_flag_u = train_loader.fetch_batch(batch_ids_u)
                assert np.all(label_flag_l), "'label_flag_l: {}'".format(
                    label_flag_l)
                assert not np.any(label_flag_u), "'label_flag_u: {}'".format(
                    label_flag_u)

                x = np.concatenate([xl, xu], axis=0)
                y = np.concatenate([yl, yu], axis=0)
                label_flag = np.concatenate([label_flag_l, label_flag_u],
                                            axis=0)
            else:
                # print("Sample jointly!")
                batch_ids = sampler.sample_ids()
                x, y, label_flag = train_loader.fetch_batch(batch_ids)

            _, AEm = sess.run(
                [train_op_AE,
                 model.get_output(fetch_keys_AE, as_dict=True)],
                feed_dict={
                    model.is_train: True,
                    model.x_ph: x,
                    model.y_ph: y,
                    model.label_flag_ph: label_flag
                })

            batch_results = AEm
            train_summarizer.accumulate(batch_results, args.batch_size)
            # ---------------------------------- #

            if global_step % args.save_freq == 0:
                train_helper.save(sess, global_step)

            if global_step % args.log_freq == 0:
                log_time_end = time()
                log_time_gap = (log_time_end - log_time_start)
                log_time_start = log_time_end

                summaries, results = train_summarizer.get_summaries_and_reset(
                    summary_prefix='train')
                train_helper.add_summaries(summaries, global_step)
                train_helper.add_summaries(
                    custom_tf_scalar_summaries(hyper_vals,
                                               prefix="moving_hyper"),
                    global_step)

                log_str = "\n[MeanTeacher ({})/{}, {}], " \
                          "Epoch {}/{}, Batch {}/{} Step {} ({:.2f}s) (train)".format(
                              args.dataset, args.model_name, args.run, epoch, args.epochs,
                              batch, batches_per_epoch, global_step-1, log_time_gap) + \
                          "\n" + ", ".join(["{}: {:.4f}".format(key, results[key])
                                            for key in fetch_keys_AE_l]) + \
                          "\n" + ", ".join(["{}: {:.4f}".format(key, results[key])
                                            for key in fetch_keys_AE_u]) + \
                          "\n" + ", ".join(["{}: {:.4f}".format(key, hyper_vals[key])
                                           for key in hyper_vals])

                print(log_str)
                with open(train_log_file, "a") as f:
                    f.write(log_str)
                    f.write("\n")
                f.close()

            if global_step % args.eval_freq == 0:
                for batch_ids in iterate_data(test_loader.num_data,
                                              args.batch_size,
                                              shuffle=False,
                                              include_remaining=True):
                    x, y = test_loader.fetch_batch(batch_ids)

                    batch_results = sess.run(model.get_output(fetch_keys_test,
                                                              as_dict=True),
                                             feed_dict={
                                                 model.is_train: False,
                                                 model.x_ph: x,
                                                 model.y_ph: y
                                             })

                    eval_summarizer.accumulate(batch_results, len(batch_ids))

                summaries, results = eval_summarizer.get_summaries_and_reset(
                    summary_prefix='test')
                train_helper.add_summaries(summaries, global_step)

                log_str = "Epoch {}/{}, Batch {}/{} (test), acc_y: {:.4f}, acc_y_stu: {:.4f}".format(
                    epoch, args.epochs, batch, batches_per_epoch,
                    results['acc_y'], results['acc_y_stu'])

                print(log_str)
                with open(train_log_file, "a") as f:
                    f.write(log_str)
                    f.write("\n")
                f.close()

                is_better = tracker.check_and_update(results, global_step)
                if is_better['acc_y']:
                    train_helper.save_best(sess, global_step=global_step)

    # Last save
    train_helper.save(sess, global_step)
Esempio n. 2
0
def main(args):
    # =====================================
    # Preparation
    # =====================================
    celebA_loader = TFCelebALoader(root_dir=args.celebA_root_dir)
    num_train = celebA_loader.num_train_data
    num_test = celebA_loader.num_test_data

    img_height, img_width = args.celebA_resize_size, args.celebA_resize_size
    celebA_loader.build_transformation_flow_tf(
        *celebA_loader.get_transform_fns("1Konny",
                                         resize_size=args.celebA_resize_size))

    args.output_dir = os.path.join(args.output_dir, args.enc_dec_model,
                                   args.run)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    else:
        if args.force_rm_dir:
            import shutil
            shutil.rmtree(args.output_dir, ignore_errors=True)
            print("Removed '{}'".format(args.output_dir))
        else:
            raise ValueError("Output directory '{}' existed. 'force_rm_dir' "
                             "must be set to True!".format(args.output_dir))
        os.mkdir(args.output_dir)

    save_args(os.path.join(args.output_dir, 'config.json'), args)
    # pp.pprint(args.__dict__)

    # =====================================
    # Instantiate models
    # =====================================
    # Only use activation for encoder and decoder
    if args.activation == "relu":
        activation = tf.nn.relu
    elif args.activation == "leaky_relu":
        activation = tf.nn.leaky_relu
    else:
        raise ValueError("Do not support '{}' activation!".format(
            args.activation))

    if args.enc_dec_model == "1Konny":
        # assert args.z_dim == 65, "For 1Konny, z_dim must be 65. Found {}!".format(args.z_dim)

        encoder = Encoder_1Konny(args.z_dim,
                                 stochastic=True,
                                 activation=activation)
        decoder = Decoder_1Konny([img_height, img_width, 3],
                                 activation=activation,
                                 output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    else:
        raise ValueError("Do not support encoder/decoder model '{}'!".format(
            args.enc_dec_model))

    model = FactorVAE([img_height, img_width, 3],
                      args.z_dim,
                      encoder=encoder,
                      decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True,
                      gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'Dz_tc_loss': args.Dz_tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_list()

    loss = model.get_loss()
    train_params = model.get_train_params()

    opt_Dz = tf.train.AdamOptimizer(learning_rate=args.lr_Dz,
                                    beta1=args.beta1_Dz,
                                    beta2=args.beta2_Dz)
    opt_vae = tf.train.AdamOptimizer(learning_rate=args.lr_vae,
                                     beta1=args.beta1_vae,
                                     beta2=args.beta2_vae)

    with tf.control_dependencies(model.get_all_update_ops()):
        train_op_Dz = opt_Dz.minimize(loss=loss['Dz_loss'],
                                      var_list=train_params['Dz_loss'])
        train_op_D = train_op_Dz

        train_op_vae = opt_vae.minimize(loss=loss['vae_loss'],
                                        var_list=train_params['vae_loss'])

    # =====================================
    # TF Graph Handler
    asset_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "asset"))
    img_gen_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_gen"))
    img_rec_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_rec"))
    img_itpl_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_itpl"))

    log_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "log"))
    train_log_file = os.path.join(log_dir, "train.log")

    summary_dir = make_dir_if_not_exist(
        os.path.join(args.output_dir, "summary_tf"))
    model_dir = make_dir_if_not_exist(os.path.join(args.output_dir,
                                                   "model_tf"))

    train_helper = SimpleTrainHelper(
        log_dir=summary_dir,
        save_dir=model_dir,
        max_to_keep=3,
        max_to_keep_best=1,
    )
    # =====================================

    # =====================================
    # Training Loop
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)
    train_helper.initialize(sess,
                            init_variables=True,
                            create_summary_writer=True)

    Dz_fetch_keys = [
        'Dz_loss', 'Dz_tc_loss', 'Dz_loss_normal', 'Dz_loss_factor',
        'Dz_avg_prob_normal', 'Dz_avg_prob_factor', 'gp0_z_tc'
    ]
    D_fetch_keys = Dz_fetch_keys
    vae_fetch_keys = ['vae_loss', 'rec_x', 'kld_loss', 'tc_loss']

    train_sampler = ContinuousIndexSampler(num_train,
                                           args.batch_size,
                                           shuffle=True)

    import math
    num_batch_per_epochs = int(math.ceil(num_train / args.batch_size))

    global_step = 0
    for epoch in range(args.epochs):
        for _ in range(num_batch_per_epochs):
            global_step += 1

            batch_ids = train_sampler.sample_ids()
            x = celebA_loader.sample_images_from_dataset(
                sess, 'train', batch_ids)

            z = np.random.randn(len(x), args.z_dim)

            batch_ids_2 = np.random.choice(num_train, size=len(batch_ids))
            xa = celebA_loader.sample_images_from_dataset(
                sess, 'train', batch_ids_2)

            for i in range(args.D_steps):
                _, Dm = sess.run(
                    [train_op_D,
                     model.get_output(D_fetch_keys, as_dict=True)],
                    feed_dict={
                        model.is_train: True,
                        model.x_ph: x,
                        model.z_ph: z,
                        model.xa_ph: xa
                    })

            for i in range(args.vae_steps):
                _, VAEm = sess.run(
                    [
                        train_op_vae,
                        model.get_output(vae_fetch_keys, as_dict=True)
                    ],
                    feed_dict={
                        model.is_train: True,
                        model.x_ph: x,
                        model.z_ph: z,
                        model.xa_ph: xa
                    })

            if global_step % args.save_freq == 0:
                train_helper.save(sess, global_step)

            if global_step % args.log_freq == 0:
                log_str = "\n[FactorVAE (celebA)/{}, {}]".format(args.enc_dec_model, args.run) + \
                          "\nEpoch {}/{}, Step {}, vae_loss: {:.4f}, Dz_loss: {:.4f}, Dz_tc_loss: {:.4f}".format(
                              epoch, args.epochs, global_step, VAEm['vae_loss'], Dm['Dz_loss'], Dm['Dz_tc_loss']) + \
                          "\nrec_x: {:.4f}, kld_loss: {:.4f}, tc_loss: {:.4f}".format(
                              VAEm['rec_x'], VAEm['kld_loss'], VAEm['tc_loss']) + \
                          "\nDz_loss_normal: {:.4f}, Dz_loss_factor: {:.4f}".format(
                              Dm['Dz_loss_normal'], Dm['Dz_loss_factor']) + \
                          "\nDz_avg_prob_normal: {:.4f}, Dz_avg_prob_factor: {:.4f}".format(
                              Dm['Dz_avg_prob_normal'], Dm['Dz_avg_prob_factor']) + \
                          "\ngp0_z_tc_coeff: {:.4f}, gp0_z_tc: {:.4f}".format(args.gp0_z_tc_coeff, Dm['gp0_z_tc'])

                print(log_str)
                with open(train_log_file, "a") as f:
                    f.write(log_str)
                    f.write("\n")
                f.close()

                train_helper.add_summary(
                    custom_tf_scalar_summary('vae_loss',
                                             VAEm['vae_loss'],
                                             prefix='train'), global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('rec_x',
                                             VAEm['rec_x'],
                                             prefix='train'), global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('kld_loss',
                                             VAEm['kld_loss'],
                                             prefix='train'), global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('tc_loss',
                                             VAEm['tc_loss'],
                                             prefix='train'), global_step)

                train_helper.add_summary(
                    custom_tf_scalar_summary('Dz_tc_loss',
                                             Dm['Dz_tc_loss'],
                                             prefix='train'), global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('Dz_loss_normal',
                                             Dm['Dz_loss_normal'],
                                             prefix='train'), global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('Dz_loss_factor',
                                             Dm['Dz_loss_factor'],
                                             prefix='train'), global_step)

                train_helper.add_summary(
                    custom_tf_scalar_summary('Dz_prob_normal',
                                             Dm['Dz_avg_prob_normal'],
                                             prefix='train'), global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('Dz_prob_factor',
                                             Dm['Dz_avg_prob_factor'],
                                             prefix='train'), global_step)

            if global_step % args.viz_gen_freq == 0:
                # Generate images
                # ------------------------- #
                z = np.random.randn(64, args.z_dim)
                img_file = os.path.join(img_gen_dir,
                                        'step[%d]_gen_test.png' % global_step)

                model.generate_images(
                    img_file,
                    sess,
                    z,
                    block_shape=[8, 8],
                    batch_size=args.batch_size,
                    dec_output_2_img_func=binary_float_to_uint8)
                # ------------------------- #

            if global_step % args.viz_rec_freq == 0:
                # Reconstruct images
                # ------------------------- #
                x = celebA_loader.sample_images_from_dataset(
                    sess, 'test',
                    np.random.choice(num_test, size=64, replace=False))

                img_file = os.path.join(img_rec_dir,
                                        'step[%d]_rec_test.png' % global_step)

                model.reconstruct_images(
                    img_file,
                    sess,
                    x,
                    block_shape=[8, 8],
                    batch_size=args.batch_size,
                    dec_output_2_img_func=binary_float_to_uint8)
                # ------------------------- #

            if global_step % args.viz_itpl_freq == 0:
                # Interpolate images
                # ------------------------- #
                x1 = celebA_loader.sample_images_from_dataset(
                    sess, 'test',
                    np.random.choice(num_test, size=12, replace=False))
                x2 = celebA_loader.sample_images_from_dataset(
                    sess, 'test',
                    np.random.choice(num_test, size=12, replace=False))

                img_file = os.path.join(img_itpl_dir,
                                        'step[%d]_itpl_test.png' % global_step)

                model.interpolate_images(
                    img_file,
                    sess,
                    x1,
                    x2,
                    num_itpl_points=12,
                    batch_on_row=True,
                    batch_size=args.batch_size,
                    dec_output_2_img_func=binary_float_to_uint8)
                # ------------------------- #

        if epoch % 100 == 0:
            train_helper.save_separately(
                sess,
                model_name="model_epoch{}".format(epoch),
                global_step=global_step)

    # Last save
    train_helper.save(sess, global_step)