示例#1
0
    def test(epoch):
        """Evaluate on the test split and log sample-weighted average costs.

        Switches ``data_loader`` to eval mode, runs the generator and
        discriminator cost tensors of ``test_outs_dict`` on every batch,
        prints the averages, appends them to ``test_traces.txt`` under
        ``global_args.exp_dir`` and, unless the loader module is the
        RandomManifold or Toy loader, renders visualizations for the last
        batch and for ``fixed_batch_data``.

        If the loader yields no samples, the function reports that and
        returns early instead of failing on a division by zero or on an
        unbound ``batch``.

        Args:
            epoch: Current epoch number. It is fed to the graph together with
                ``float(epoch > 2)`` and used in log lines and file postfixes.
        """
        data_loader.eval()
        test_gen_loss_accum, test_dis_loss_accum, batch_size_accum = 0, 0, 0
        # The hyper-parameter vector depends only on the epoch: build it once.
        hyper_param = np.asarray([epoch, float(epoch > 2)])
        start = time.time()

        for batch_idx, curr_batch_size, batch in data_loader:
            test_generator_cost_np, test_discriminator_cost_np = sess.run(
                [test_outs_dict['generator_cost'], test_outs_dict['discriminator_cost']],
                feed_dict=input_dict_func(batch, hyper_param))

            # Accumulate size-weighted costs; they are divided by the total
            # sample count after the loop.
            test_gen_loss_accum += curr_batch_size * test_generator_cost_np
            test_dis_loss_accum += curr_batch_size * test_discriminator_cost_np
            batch_size_accum += curr_batch_size

        if batch_size_accum == 0:
            # Nothing was evaluated: the averages below would divide by zero
            # and the visualizations would have no batch to work on.
            print('====> Test: Epoch {} skipped, the data loader yielded no samples'.format(epoch))
            return

        end = time.time()
        avg_gen_cost = test_gen_loss_accum / batch_size_accum
        avg_dis_cost = test_dis_loss_accum / batch_size_accum
        print('====> Average Test: Epoch {}\tGenerator Cost: {:.3f}\tDiscriminator Cost: {:.3f}\tTime: {:.3f}'.format(
              epoch, avg_gen_cost, avg_dis_cost, (end - start)))

        with open(global_args.exp_dir + "test_traces.txt", "a") as text_file:
            text_file.write(str(avg_gen_cost) + ', ' + str(avg_dis_cost) + '\n')

        toy_loader_modules = ('datasetLoaders.RandomManifoldDataLoader',
                              'datasetLoaders.ToyDataLoader')
        if data_loader.__module__ not in toy_loader_modules:
            # First visualization uses the last batch produced by the loop.
            distributions.visualizeProductDistribution3(
                sess, input_dict_func(batch, hyper_param), batch,
                inference_obs_dist, transport_dist, rec_dist, generative_dict['obs_dist'],
                save_dir=global_args.exp_dir + 'Visualization/Test/Random/',
                postfix='test_' + str(epoch))
            # Replace the batch's image entry with fixed_batch_data so the
            # second visualization is rendered from that fixed input.
            batch['observed']['data']['image'] = fixed_batch_data
            distributions.visualizeProductDistribution3(
                sess, input_dict_func(batch, hyper_param), batch,
                inference_obs_dist, transport_dist, rec_dist, generative_dict['obs_dist'],
                save_dir=global_args.exp_dir + 'Visualization/Test/Fixed/',
                postfix='test_fixed_' + str(epoch))
示例#2
0
    def train(epoch):
        """Run one training epoch with staggered transport/critic/generator steps.

        Every batch runs the transport step. Batches whose index is a
        multiple of ``critic_rate`` also run the discriminator step, and
        batches whose index is a multiple of ``critic_rate * generator_rate``
        additionally run the generator step and fetch the costs used for
        logging and for the epoch averages. Both rates are fixed at 5.

        After the loop the merged summaries are written for the last batch.
        On checkpoint epochs (every epoch, or every 20th for the
        RandomManifold/Toy loaders) the average costs are printed,
        visualizations are produced and a checkpoint is saved; every 60th
        epoch a second checkpoint is saved to ``checkpoint2/``.

        Args:
            epoch: Current epoch number; used for log lines, file postfixes
                and the checkpoint schedule.
        """
        critic_rate, generator_rate = 5, 5
        # Every feed in this function uses the same constant hyper-parameter.
        hyper_param = np.asarray([0, ])
        uses_toy_loader = data_loader.__module__ in (
            'datasetLoaders.RandomManifoldDataLoader',
            'datasetLoaders.ToyDataLoader')

        data_loader.train()
        train_gen_loss_accum, train_dis_loss_accum, batch_size_accum = 0, 0, 0
        start = time.time()
        for batch_idx, curr_batch_size, batch in data_loader:
            # Transport update: runs on every batch.
            sess.run([train_transport_step_tf],
                     feed_dict=input_dict_func(batch, hyper_param))
            if batch_idx % critic_rate != 0:
                continue

            # Discriminator update: only on every critic_rate-th batch.
            sess.run([train_discriminator_step_tf],
                     feed_dict=input_dict_func(batch, hyper_param))
            if batch_idx % (critic_rate * generator_rate) != 0:
                continue

            # Generator update plus the cost fetches used for reporting.
            _, generator_cost_np, discriminator_cost_np, variational_cost_np, mean_transport_cost_np = \
                sess.run([train_generator_step_tf, train_outs_dict['generator_cost'], train_outs_dict['discriminator_cost'],
                          train_outs_dict['variational_cost'], train_outs_dict['mean_transport_cost']],
                         feed_dict=input_dict_func(batch, hyper_param))

            # The fetched value is not used in this function; the fetch is
            # kept so the session performs the same work as before.
            sess.run(max_abs_discriminator_vars)

            # Only generator-step batches contribute to the epoch averages.
            train_gen_loss_accum += curr_batch_size * generator_cost_np
            train_dis_loss_accum += curr_batch_size * discriminator_cost_np
            batch_size_accum += curr_batch_size

            if batch_idx % global_args.log_interval == 0:
                end = time.time()
                print(
                    'Train: Epoch {} [{:7d} ()]\tGenerator Cost: {:.6f}\tDiscriminator Cost: {:.6f}\tTime: {:.3f}, variational cost {:.3f}, transport cost {:.3f}'
                    .format(epoch, batch_idx * curr_batch_size,
                            generator_cost_np, discriminator_cost_np,
                            (end - start), variational_cost_np,
                            mean_transport_cost_np))

                with open(global_args.exp_dir + "training_traces.txt", "a") as text_file:
                    text_file.write(
                        str(generator_cost_np) + ', ' +
                        str(discriminator_cost_np) + '\n')
                # Restart the timer so the next report covers only its own interval.
                start = time.time()

        # Summaries are evaluated on the last batch seen by the loop.
        summary_str = sess.run(merged_summaries,
                               feed_dict=input_dict_func(batch, hyper_param))
        summary_writer.add_summary(summary_str,
                                   (tf.train.global_step(sess, global_step)))

        checkpoint_time = 20 if uses_toy_loader else 1

        if epoch % checkpoint_time == 0:
            print(
                '====> Average Train: Epoch: {}\tGenerator Cost: {:.6f}\tDiscriminator Cost: {:.6f}'
                .format(epoch, train_gen_loss_accum / batch_size_accum,
                        train_dis_loss_accum / batch_size_accum))

            if uses_toy_loader:
                helper.visualize_datasets(sess,
                                          input_dict_func(batch),
                                          data_loader.dataset,
                                          generative_dict['obs_sample_out'],
                                          generative_dict['latent_sample_out'],
                                          save_dir=global_args.exp_dir +
                                          'Visualization/',
                                          postfix=str(epoch))

                # Evaluate the critic on a dense 250x250 grid over
                # [-3.5, 3.5] x [-3.5, 3.5] to plot it as a 2-D function.
                xmin, xmax, ymin, ymax, X_dense, Y_dense = -3.5, 3.5, -3.5, 3.5, 250, 250
                xlist = np.linspace(xmin, xmax, X_dense)
                ylist = np.linspace(ymin, ymax, Y_dense)
                X, Y = np.meshgrid(xlist, ylist)
                # One (x, y) row per grid point.
                XY = np.concatenate(
                    [X.reshape(-1, 1), Y.reshape(-1, 1)], axis=1)

                # Feed the grid points in place of the batch's 'flat' data.
                batch['observed']['data']['flat'] = XY[:, np.newaxis, :]
                disc_cost_real_np = sess.run(train_outs_dict['critic_real'],
                                             feed_dict=input_dict_func(batch, hyper_param))

                f = np.reshape(disc_cost_real_np[:, 0, 0], [Y_dense, X_dense])
                helper.plot_ffs(X,
                                Y,
                                f,
                                save_dir=global_args.exp_dir +
                                'Visualization/discriminator_function/',
                                postfix='discriminator_function' + str(epoch))
            else:
                # First visualization uses the last training batch.
                distributions.visualizeProductDistribution3(
                    sess,
                    input_dict_func(batch),
                    batch,
                    inference_obs_dist,
                    transport_dist,
                    generative_dict['obs_dist'],
                    save_dir=global_args.exp_dir +
                    'Visualization/Train/Random/',
                    postfix='train_' + str(epoch))
                # Second visualization uses fixed_batch_data as the image input.
                batch['observed']['data']['image'] = fixed_batch_data
                distributions.visualizeProductDistribution3(
                    sess,
                    input_dict_func(batch),
                    batch,
                    inference_obs_dist,
                    transport_dist,
                    generative_dict['obs_dist'],
                    save_dir=global_args.exp_dir +
                    'Visualization/Train/Fixed/',
                    postfix='train_fixed_' + str(epoch))

            checkpoint_path1 = global_args.exp_dir + 'checkpoint/'
            checkpoint_path2 = global_args.exp_dir + 'checkpoint2/'
            print('====> Saving checkpoint. Epoch: ', epoch)
            start_tmp = time.time()
            helper.save_checkpoint(saver, sess, global_step, checkpoint_path1)
            end_tmp = time.time()
            print(
                'Checkpoint path: ' + checkpoint_path1 + '   ====> It took: ',
                end_tmp - start_tmp)
            if epoch % 60 == 0:
                # Backup copy in a second directory every 60 epochs.
                print('====> Saving checkpoint backup. Epoch: ', epoch)
                start_tmp = time.time()
                helper.save_checkpoint(saver, sess, global_step,
                                       checkpoint_path2)
                end_tmp = time.time()
                print(
                    'Checkpoint path: ' + checkpoint_path2 +
                    '   ====> It took: ', end_tmp - start_tmp)
示例#3
0
    def train(epoch):
        """Train for one epoch, rotating between transport, critic and generator phases.

        The loop keeps a ``phase`` marker. A 'trans' batch runs only the
        transport step, a 'disc' batch runs the transport and discriminator
        steps, and a 'gen' batch runs all three. The phase advances
        trans -> disc -> gen -> trans based on the step counters and the
        ``critic_rate`` / ``generator_rate`` settings. The most recently
        fetched generator and discriminator costs are accumulated on every
        batch and logged every ``global_args.log_interval`` batches.

        After the loop the merged summaries are written for the last batch
        and, on checkpoint epochs, the averages are printed and
        visualizations are produced. Checkpoint saving itself is currently
        commented out; only the timing messages are printed.

        Args:
            epoch: Current epoch number. It is fed to the graph together with
                ``float(epoch > 2)`` and used in log lines and file postfixes.
        """
        uses_toy_loader = data_loader.__module__ in (
            'datasetLoaders.RandomManifoldDataLoader',
            'datasetLoaders.ToyDataLoader')
        # The rates depend only on the loader family, not on the epoch.
        critic_rate, generator_rate = (1, 4) if uses_toy_loader else (2, 2)

        def feed_for(current_batch):
            """Build the feed dict for ``current_batch`` with this epoch's hyper-parameters."""
            return input_dict_func(current_batch, np.asarray([epoch, float(epoch > 2)]))

        n_trans = n_disc = n_gen = 0
        phase = 'gen'  # the first batch therefore runs all three updates
        in_between_vis = 3
        n_reports = 0
        data_loader.train()
        gen_cost_total = dis_cost_total = n_samples = 0
        tic = time.time()

        for batch_idx, curr_batch_size, batch in data_loader:
            # Transport update (every phase).
            if phase in ('trans', 'disc', 'gen'):
                _, transporter_cost_np, mean_transport_cost_np = sess.run(
                    [train_transport_step_tf,
                     train_outs_dict['transporter_cost'],
                     train_outs_dict['mean_transport_cost']],
                    feed_dict=feed_for(batch))
                n_trans += 1

            # Discriminator update ('disc' and 'gen' phases).
            if phase in ('disc', 'gen'):
                _, transporter_cost_np, mean_transport_cost_np = sess.run(
                    [train_discriminator_step_tf,
                     train_outs_dict['transporter_cost'],
                     train_outs_dict['mean_transport_cost']],
                    feed_dict=feed_for(batch))
                n_disc += 1

            # Generator update ('gen' phase only); also refreshes the costs
            # used for logging and averaging.
            if phase == 'gen':
                (_, generator_cost_np, discriminator_cost_np, transporter_cost_np,
                 mean_transport_cost_np, convex_mask_np) = sess.run(
                    [train_generator_step_tf,
                     train_outs_dict['generator_cost'],
                     train_outs_dict['discriminator_cost'],
                     train_outs_dict['transporter_cost'],
                     train_outs_dict['mean_transport_cost'],
                     train_outs_dict['convex_mask']],
                    feed_dict=feed_for(batch))
                n_gen += 1

            # Phase transitions, driven by the cumulative step counters.
            if phase == 'trans' and n_trans % (critic_rate * generator_rate) == 0:
                phase = 'disc'
            elif phase == 'disc' and n_disc % generator_rate == 0:
                phase = 'gen'
            elif phase == 'gen' and n_gen % 1 == 0:
                # Always true: a generator batch is always followed by 'trans'.
                phase = 'trans'

            # Fetched but not used further in this function.
            max_discriminator_weight = sess.run(max_abs_discriminator_vars)
            gen_cost_total += curr_batch_size * generator_cost_np
            dis_cost_total += curr_batch_size * discriminator_cost_np
            n_samples += curr_batch_size

            if batch_idx % global_args.log_interval == 0:
                n_reports += 1

                toc = time.time()
                print('Train: Epoch {} [{:7d} ()]\tGenerator Cost: {:.3f}\tDiscriminator Cost: {:.3f}\tTime: {:.3f}, variational cost {:.3f}, transport cost {:.3f}, t {:2d} d {:2d} g {:2d}, t/d {:.1f}, t/g {:.1f}, mask {:.3f}'.format(
                    epoch, batch_idx * curr_batch_size, generator_cost_np, discriminator_cost_np,
                    (toc - tic), transporter_cost_np, mean_transport_cost_np,
                    n_trans, n_disc, n_gen, n_trans / n_disc, n_trans / n_gen,
                    np.mean(convex_mask_np)))

                with open(global_args.exp_dir + "training_traces.txt", "a") as trace_file:
                    trace_file.write(str(generator_cost_np) + ', ' + str(discriminator_cost_np) + '\n')
                tic = time.time()

                # NOTE(review): this is truthy when n_reports is NOT a multiple
                # of in_between_vis, so the visualization is skipped on every
                # in_between_vis-th report only. Confirm whether `== 0` was intended.
                if in_between_vis > 0 and n_reports % in_between_vis:
                    distributions.visualizeProductDistribution3(
                        sess, feed_for(batch), batch,
                        inference_obs_dist, transport_dist, rec_dist, generative_dict['obs_dist'],
                        save_dir=global_args.exp_dir + 'Visualization/Train/Random/',
                        postfix='train_' + str(epoch))
                    # Second rendering uses fixed_batch_data as the image input.
                    batch['observed']['data']['image'] = fixed_batch_data
                    distributions.visualizeProductDistribution3(
                        sess, feed_for(batch), batch,
                        inference_obs_dist, transport_dist, rec_dist, generative_dict['obs_dist'],
                        save_dir=global_args.exp_dir + 'Visualization/Train/Fixed/',
                        postfix='train_fixed_' + str(epoch))

        # Summaries are evaluated on the last batch seen by the loop.
        summary_str = sess.run(merged_summaries, feed_dict=feed_for(batch))
        summary_writer.add_summary(summary_str, (tf.train.global_step(sess, global_step)))

        checkpoint_time = 20 if uses_toy_loader else 1

        if epoch % checkpoint_time == 0:
            print('====> Average Train: Epoch: {}\tGenerator Cost: {:.3f}\tDiscriminator Cost: {:.3f}'.format(
                epoch, gen_cost_total / n_samples, dis_cost_total / n_samples))

            if uses_toy_loader:
                helper.visualize_datasets(
                    sess, input_dict_func(batch), data_loader.dataset,
                    generative_dict['obs_sample_out'], generative_dict['latent_sample_out'],
                    train_outs_dict['transport_sample'], train_outs_dict['input_sample'],
                    save_dir=global_args.exp_dir + 'Visualization/', postfix=str(epoch))

                # Evaluate the critic on a dense 250x250 grid over
                # [-3.5, 3.5] x [-3.5, 3.5] to plot it as a 2-D function.
                xmin, xmax, ymin, ymax, X_dense, Y_dense = -3.5, 3.5, -3.5, 3.5, 250, 250
                X, Y = np.meshgrid(np.linspace(xmin, xmax, X_dense),
                                   np.linspace(ymin, ymax, Y_dense))
                # One (x, y) row per grid point.
                XY = np.concatenate([X.reshape(-1, 1), Y.reshape(-1, 1)], axis=1)

                batch['observed']['data']['flat'] = XY[:, np.newaxis, :]
                disc_cost_real_np = sess.run(train_outs_dict['critic_real'], feed_dict=feed_for(batch))

                f = np.reshape(disc_cost_real_np[:, 0, 0], [Y_dense, X_dense])
                helper.plot_ffs(X, Y, f,
                                save_dir=global_args.exp_dir + 'Visualization/discriminator_function/',
                                postfix='discriminator_function' + str(epoch))
            else:
                distributions.visualizeProductDistribution3(
                    sess, feed_for(batch), batch,
                    inference_obs_dist, transport_dist, rec_dist, generative_dict['obs_dist'],
                    save_dir=global_args.exp_dir + 'Visualization/Train/Random/',
                    postfix='train_' + str(epoch))
                # Second rendering uses fixed_batch_data as the image input.
                batch['observed']['data']['image'] = fixed_batch_data
                distributions.visualizeProductDistribution3(
                    sess, feed_for(batch), batch,
                    inference_obs_dist, transport_dist, rec_dist, generative_dict['obs_dist'],
                    save_dir=global_args.exp_dir + 'Visualization/Train/Fixed/',
                    postfix='train_fixed_' + str(epoch))

            checkpoint_path1 = global_args.exp_dir + 'checkpoint/'
            checkpoint_path2 = global_args.exp_dir + 'checkpoint2/'
            print('====> Saving checkpoint. Epoch: ', epoch)
            save_tic = time.time()
            # Saving is disabled here; only the timing message is emitted.
            # helper.save_checkpoint(saver, sess, global_step, checkpoint_path1)
            save_toc = time.time()
            print('Checkpoint path: ' + checkpoint_path1 + '   ====> It took: ', save_toc - save_tic)
            if epoch % 60 == 0:
                print('====> Saving checkpoint backup. Epoch: ', epoch)
                save_tic = time.time()
                # helper.save_checkpoint(saver, sess, global_step, checkpoint_path2)
                save_toc = time.time()
                print('Checkpoint path: ' + checkpoint_path2 + '   ====> It took: ', save_toc - save_tic)
示例#4
0
    def train(epoch):
        """Train for one epoch with scheduler-selected transport/critic/generator steps.

        For every batch, ``scheduler(epoch, batch_idx)`` returns three flags
        that decide which of the transport, discriminator and generator
        updates run on that batch. All updates on a batch share one feed dict
        built from the batch and a hyper-parameter vector
        ``[epoch, b_identity, curr_meanp, p_real]``. After each batch the
        module-level ``curr_meanp`` / ``curr_stdp`` are updated as exponential
        moving averages (0.9 old, 0.1 new) of the mean/std of the last fetched
        ``critic_gen`` values.

        After the loop the merged summaries are written and, on checkpoint
        epochs, averages are printed and visualizations are produced.
        Checkpoint saving itself is commented out; only the timing messages
        are printed.

        Args:
            epoch: Current epoch number; passed to the scheduler, fed to the
                graph as the first hyper-parameter and used in log lines and
                file postfixes.
        """
        # These module-level values are read below; curr_meanp and curr_stdp
        # are also reassigned after every batch.
        global curr_meanp, curr_stdp, p_real

        trans_steps, disc_steps, gen_steps = 0, 0, 0
        turn = 'gen'
        in_between_vis = 5
        report_count = 0
        data_loader.train()
        train_gen_loss_accum, train_dis_loss_accum, train_likelihood_accum, train_kl_accum, batch_size_accum = 0, 0, 0, 0, 0
        start = time.time()

        # Default hyper-parameters, then updated via the helper from
        # ./hyperparam_file.py (helper semantics not visible here —
        # presumably overrides matching keys; TODO confirm).
        hyperparam_dict = {'b_identity': 0.}
        helper.update_dict_from_file(hyperparam_dict, './hyperparam_file.py')

        for batch_idx, curr_batch_size, batch in data_loader:
            # Flags selecting which updates run on this batch.
            gen_bool, disc_bool, trans_bool = scheduler(epoch, batch_idx)
            hyper_param = np.asarray(
                [epoch, hyperparam_dict['b_identity'], curr_meanp, p_real])
            # One feed dict shared by all updates on this batch.
            curr_feed_dict = input_dict_func(batch, hyper_param)

            if trans_bool:
                trans_train_step_np, transporter_cost_np, mean_transport_cost_np = \
                sess.run([train_transport_step_tf, train_outs_dict['transporter_cost'], train_outs_dict['mean_transport_cost']],
                feed_dict = curr_feed_dict)
                trans_steps = trans_steps + 1
                # p_real = np.exp(-mean_transport_cost_np)

            if disc_bool:
                disc_train_step_np, transporter_cost_np, mean_transport_cost_np = \
                sess.run([train_discriminator_step_tf, train_outs_dict['transporter_cost'], train_outs_dict['mean_transport_cost']],
                feed_dict = curr_feed_dict)
                disc_steps = disc_steps + 1

            # The generator step is the only place where generator_cost_np,
            # discriminator_cost_np, critic_gen_np, convex_mask_np and the
            # expected_log_pdf_* values are assigned.
            if gen_bool:
                gen_train_step_np, generator_cost_np, discriminator_cost_np, transporter_cost_np, mean_transport_cost_np, convex_mask_np, critic_real_np, \
                critic_gen_np, expected_log_pdf_prior_np, expected_log_pdf_agg_post_np, aaa = \
                sess.run([train_generator_step_tf, train_outs_dict['generator_cost'], train_outs_dict['discriminator_cost'],
                train_outs_dict['transporter_cost'], train_outs_dict['mean_transport_cost'], train_outs_dict['convex_mask'], train_outs_dict['critic_real'],
                train_outs_dict['critic_gen'], train_outs_dict['expected_log_pdf_prior'], train_outs_dict['expected_log_pdf_agg_post'], model.AAA],
                feed_dict = curr_feed_dict)
                gen_steps = gen_steps + 1

            # NOTE(review): the lines below use generator_cost_np,
            # discriminator_cost_np and critic_gen_np unconditionally. If the
            # scheduler returns gen_bool False before the first generator step
            # of the epoch they are unbound (NameError); on later non-generator
            # batches the values from the last generator step are reused.
            max_discriminator_weight = sess.run(max_abs_discriminator_vars)
            train_gen_loss_accum += curr_batch_size * generator_cost_np
            train_dis_loss_accum += curr_batch_size * discriminator_cost_np
            batch_size_accum += curr_batch_size

            curr_meanp_delta, curr_stdp_delta = np.mean(critic_gen_np), np.std(
                critic_gen_np)

            # Exponential moving averages of the critic_gen statistics; these
            # update the module-level globals declared at the top.
            curr_meanp = 0.9 * curr_meanp + 0.1 * curr_meanp_delta
            curr_stdp = 0.9 * curr_stdp + 0.1 * curr_stdp_delta

            if batch_idx % global_args.log_interval == 0:
                report_count = report_count + 1

                end = time.time()
                # print('Train: Epoch {} [{:7d} ()]\tGenerator Cost: {:.3f}\tDiscriminator Cost: {:.3f}\tTime: {:.3f}, variational cost {:.3f}, transport cost {:.3f}, t {:2d} d {:2d} g {:2d}, t/d {:.1f}, t/g {:.1f}, mask {:.3f}, m {:.2f}, std {:.2f}'.format(
                #       epoch, batch_idx * curr_batch_size, generator_cost_np, discriminator_cost_np, (end - start), transporter_cost_np, mean_transport_cost_np, trans_steps, disc_steps, gen_steps, trans_steps/disc_steps, trans_steps/gen_steps, np.mean(convex_mask_np[:,0]),  curr_meanp, curr_stdp))
                # NOTE(review): trans_steps / disc_steps and
                # trans_steps / gen_steps divide by zero if no discriminator or
                # generator step has run yet when this report fires.
                print(
                    'Train: Epoch {} [{:7d} ()]\tGenerator Cost: {:.3f}\tDiscriminator Cost: {:.3f}\tTime: {:.3f}, variational cost {:.3f}, transport cost {:.3f}, t {:2d} d {:2d} g {:2d}, t/d {:.1f}, t/g {:.1f}, pri {:.1f}, agg {:.1f}, mask {:.2f}, p_real {:.2f}'
                    .format(epoch, batch_idx * curr_batch_size,
                            generator_cost_np, discriminator_cost_np,
                            (end - start), transporter_cost_np,
                            mean_transport_cost_np, trans_steps, disc_steps,
                            gen_steps, trans_steps / disc_steps,
                            trans_steps / gen_steps, expected_log_pdf_prior_np,
                            expected_log_pdf_agg_post_np,
                            np.mean(convex_mask_np[:, 0]), p_real))

                with open(global_args.exp_dir + "training_traces.txt",
                          "a") as text_file:
                    text_file.write(
                        str(generator_cost_np) + ', ' +
                        str(discriminator_cost_np) + '\n')
                # Restart the timer so the next report covers only its own interval.
                start = time.time()

                # NOTE(review): `report_count % in_between_vis` is truthy when
                # report_count is NOT a multiple of in_between_vis, so the
                # visualization is skipped on every 5th report only. Confirm
                # whether `== 0` was intended.
                if data_loader.__module__ == 'datasetLoaders.CelebA1QueueLoader' and in_between_vis > 0 and report_count % in_between_vis:
                    distributions.visualizeProductDistribution3(
                        sess,
                        curr_feed_dict,
                        batch,
                        inference_obs_dist,
                        transport_dist,
                        rec_dist,
                        generative_dict['obs_dist'],
                        save_dir=global_args.exp_dir +
                        'Visualization/Train/Random/',
                        postfix='train_' + str(epoch))
                    # Second rendering uses fixed_batch_data as the image
                    # input; the feed dict is rebuilt for the modified batch.
                    batch['observed']['data']['image'] = fixed_batch_data
                    curr_feed_dict = input_dict_func(batch, hyper_param)
                    distributions.visualizeProductDistribution3(
                        sess,
                        curr_feed_dict,
                        batch,
                        inference_obs_dist,
                        transport_dist,
                        rec_dist,
                        generative_dict['obs_dist'],
                        save_dir=global_args.exp_dir +
                        'Visualization/Train/Fixed/',
                        postfix='train_fixed_' + str(epoch))

        # Uses the feed dict left over from the last loop iteration (which may
        # be the fixed-batch one if that iteration visualized).
        summary_str = sess.run(merged_summaries, feed_dict=curr_feed_dict)
        summary_writer.add_summary(summary_str,
                                   (tf.train.global_step(sess, global_step)))

        # Report/visualize every epoch, or every 20th epoch for the
        # RandomManifold/Toy loaders.
        checkpoint_time = 1
        if data_loader.__module__ == 'datasetLoaders.RandomManifoldDataLoader' or data_loader.__module__ == 'datasetLoaders.ToyDataLoader':
            checkpoint_time = 20

        if epoch % checkpoint_time == 0:
            print(
                '====> Average Train: Epoch: {}\tGenerator Cost: {:.3f}\tDiscriminator Cost: {:.3f}'
                .format(epoch, train_gen_loss_accum / batch_size_accum,
                        train_dis_loss_accum / batch_size_accum))

            # helper.draw_bar_plot(effective_z_cost_np[:,0,0], thres = [np.mean(effective_z_cost_np), np.max(effective_z_cost_np)], save_dir=global_args.exp_dir+'Visualization/inversion_cost/', postfix='inversion_cost'+str(epoch))
            # helper.draw_bar_plot(disc_cost_gen_np[:,0,0], thres = [0, 0], save_dir=global_args.exp_dir+'Visualization/disc_cost/', postfix='disc_cost'+str(epoch))

            if data_loader.__module__ == 'datasetLoaders.RandomManifoldDataLoader' or data_loader.__module__ == 'datasetLoaders.ToyDataLoader':

                helper.visualize_datasets(sess,
                                          input_dict_func(batch),
                                          data_loader.dataset,
                                          generative_dict['obs_sample_out'],
                                          generative_dict['latent_sample_out'],
                                          train_outs_dict['transport_sample'],
                                          train_outs_dict['input_sample'],
                                          save_dir=global_args.exp_dir +
                                          'Visualization/',
                                          postfix=str(epoch))

                # Dense 250x250 grid over [-2.5, 2.5] x [-2.5, 2.5] on which
                # the critic is evaluated for plotting.
                xmin, xmax, ymin, ymax, X_dense, Y_dense = -2.5, 2.5, -2.5, 2.5, 250, 250
                xlist = np.linspace(xmin, xmax, X_dense)
                ylist = np.linspace(ymin, ymax, Y_dense)
                X, Y = np.meshgrid(xlist, ylist)
                # One (x, y) row per grid point.
                XY = np.concatenate(
                    [X.reshape(-1, 1), Y.reshape(-1, 1)], axis=1)

                # Critic values on the grid points.
                batch['observed']['data']['flat'] = XY[:, np.newaxis, :]
                disc_cost_real_np = sess.run(train_outs_dict['critic_real'],
                                             feed_dict=input_dict_func(
                                                 batch, hyper_param))

                # Critic values on the loader's dataset, used only to pick
                # the clipping range for the plot.
                batch['observed']['data'][
                    'flat'] = data_loader.dataset[:, np.newaxis, :]
                disc_cost_real_real_np = sess.run(
                    train_outs_dict['critic_real'],
                    feed_dict=input_dict_func(batch, hyper_param))

                # disc_mean = disc_cost_real_real_np.max()
                # disc_max = disc_cost_real_real_np.max()
                # disc_min = disc_cost_real_real_np.min()

                # Clip the grid values to mean +/- 2 std of the dataset values.
                disc_mean = disc_cost_real_real_np.mean()
                disc_std = disc_cost_real_real_np.std()
                disc_max = disc_mean + 2 * disc_std
                disc_min = disc_mean - 2 * disc_std

                # In-place clip (out= is the input array itself).
                np.clip(disc_cost_real_np,
                        disc_min,
                        disc_max,
                        out=disc_cost_real_np)
                f = np.reshape(disc_cost_real_np[:, 0, 0], [Y_dense, X_dense])
                helper.plot_ffs(X,
                                Y,
                                f,
                                save_dir=global_args.exp_dir +
                                'Visualization/discriminator_function/',
                                postfix='discriminator_function' + str(epoch))

            else:
                # helper.draw_bar_plot(convex_mask_np, y_min_max = [0,1], save_dir=global_args.exp_dir+'Visualization/convex_mask/', postfix='convex_mask'+str(epoch))
                distributions.visualizeProductDistribution3(
                    sess,
                    curr_feed_dict,
                    batch,
                    inference_obs_dist,
                    transport_dist,
                    rec_dist,
                    generative_dict['obs_dist'],
                    save_dir=global_args.exp_dir +
                    'Visualization/Train/Random/',
                    postfix='train_' + str(epoch))
                # Second rendering uses fixed_batch_data as the image input;
                # the feed dict is rebuilt for the modified batch.
                batch['observed']['data']['image'] = fixed_batch_data
                curr_feed_dict = input_dict_func(batch, hyper_param)
                distributions.visualizeProductDistribution3(
                    sess,
                    curr_feed_dict,
                    batch,
                    inference_obs_dist,
                    transport_dist,
                    rec_dist,
                    generative_dict['obs_dist'],
                    save_dir=global_args.exp_dir +
                    'Visualization/Train/Fixed/',
                    postfix='train_fixed_' + str(epoch))

            checkpoint_path1 = global_args.exp_dir + 'checkpoint/'
            checkpoint_path2 = global_args.exp_dir + 'checkpoint2/'
            # NOTE(review): both save_checkpoint calls below are commented
            # out, so the messages are printed but nothing is saved.
            print('====> Saving checkpoint. Epoch: ', epoch)
            start_tmp = time.time()
            # helper.save_checkpoint(saver, sess, global_step, checkpoint_path1)
            end_tmp = time.time()
            print(
                'Checkpoint path: ' + checkpoint_path1 + '   ====> It took: ',
                end_tmp - start_tmp)
            if epoch % 60 == 0:
                print('====> Saving checkpoint backup. Epoch: ', epoch)
                start_tmp = time.time()
                # helper.save_checkpoint(saver, sess, global_step, checkpoint_path2)
                end_tmp = time.time()
                print(
                    'Checkpoint path: ' + checkpoint_path2 +
                    '   ====> It took: ', end_tmp - start_tmp)