def visualize():
    """Restore a trained VAE and plot its 2-D latent space and a sample grid.

    Restores the checkpoint ``'{SAVE_PATH}vae-epoch-{FLAGS.load}'``. It then
    calls ``Visualizer.viz_2Dlatent_variable`` on the MNIST test split. Finally
    it decodes a ``plot_size x plot_size`` grid of latent codes produced by
    ``distribution.interpolate`` through ``Generator.generate_samples``.

    Raises:
        ValueError: if ``FLAGS.ncode`` is not 2. The latent grid below is
            reshaped to 2 columns, so only a 2-D code can be visualized here
            (``train()`` applies the same ``ncode == 2`` condition).
    """
    FLAGS = get_args()
    if FLAGS.ncode != 2:
        raise ValueError(
            'visualize() needs a 2-D latent code (ncode == 2), got ncode={}'
            .format(FLAGS.ncode))

    plot_size = 20
    # The generate model decodes the whole grid in one batch, so its batch
    # size must equal the number of grid cells.
    n_grid = plot_size * plot_size

    valid_data = MNISTData('test',
                           data_dir=DATA_PATH,
                           shuffle=True,
                           pf=preprocess_im,
                           batch_dict_name=['im', 'label'])
    valid_data.setup(epoch_val=0, batch_size=FLAGS.bsize)

    with tf.variable_scope('VAE'):
        model = VAE(n_code=FLAGS.ncode, wd=0)
        model.create_train_model()

    # Re-enter the same scope with reuse so the generate model shares the
    # variables created by the train model above.
    with tf.variable_scope('VAE') as scope:
        scope.reuse_variables()
        valid_model = VAE(n_code=FLAGS.ncode, wd=0)
        valid_model.create_generate_model(b_size=n_grid)

    visualizer = Visualizer(model, save_path=SAVE_PATH)
    generator = Generator(generate_model=valid_model, save_path=SAVE_PATH)

    z = distribution.interpolate(plot_size=plot_size)
    z = np.reshape(z, (n_grid, 2))

    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, '{}vae-epoch-{}'.format(SAVE_PATH, FLAGS.load))
        visualizer.viz_2Dlatent_variable(sess, valid_data)
        generator.generate_samples(sess, plot_size=plot_size, z=z)
def train():
    """Train the VAE on MNIST, checkpointing and visualizing periodically.

    Hyper-parameters are read from ``get_args()``: ``bsize``, ``ncode``,
    ``lr`` and ``maxepoch``. Each epoch runs ``Trainer.train_epoch`` and then
    ``Trainer.valid_epoch``, both logging to a ``tf.summary.FileWriter`` on
    ``SAVE_PATH``.

    The session is saved as ``'{SAVE_PATH}vae-epoch-{epoch_id}'`` every
    ``save_every`` epochs and after the final epoch. When ``ncode == 2``, a
    ``plot_size x plot_size`` grid of generated samples and the 2-D latent
    embedding of the test split are written at the same points.
    """
    FLAGS = get_args()
    plot_size = 20
    # The generate model decodes the whole grid in one batch, so its batch
    # size must equal the number of grid cells.
    n_grid = plot_size * plot_size
    save_every = 10

    train_data = MNISTData('train',
                           data_dir=DATA_PATH,
                           shuffle=True,
                           pf=preprocess_im,
                           batch_dict_name=['im', 'label'])
    train_data.setup(epoch_val=0, batch_size=FLAGS.bsize)
    valid_data = MNISTData('test',
                           data_dir=DATA_PATH,
                           shuffle=True,
                           pf=preprocess_im,
                           batch_dict_name=['im', 'label'])
    valid_data.setup(epoch_val=0, batch_size=FLAGS.bsize)

    with tf.variable_scope('VAE'):
        model = VAE(n_code=FLAGS.ncode, wd=0)
        model.create_train_model()

    # Re-enter the same scope with reuse so the generate model shares the
    # variables created by the train model above.
    with tf.variable_scope('VAE') as scope:
        scope.reuse_variables()
        valid_model = VAE(n_code=FLAGS.ncode, wd=0)
        valid_model.create_generate_model(b_size=n_grid)

    trainer = Trainer(model,
                      valid_model,
                      train_data,
                      init_lr=FLAGS.lr,
                      save_path=SAVE_PATH)
    if FLAGS.ncode == 2:
        # A fixed latent grid, so samples are comparable across epochs.
        z = distribution.interpolate(plot_size=plot_size)
        z = np.reshape(z, (n_grid, 2))
        visualizer = Visualizer(model, save_path=SAVE_PATH)
    else:
        z = None
    generator = Generator(generate_model=valid_model, save_path=SAVE_PATH)

    last_epoch = FLAGS.maxepoch - 1
    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        writer = tf.summary.FileWriter(SAVE_PATH)
        try:
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            writer.add_graph(sess.graph)

            for epoch_id in range(FLAGS.maxepoch):
                trainer.train_epoch(sess, summary_writer=writer)
                trainer.valid_epoch(sess, summary_writer=writer)
                # Also checkpoint the final epoch; otherwise up to
                # save_every - 1 trailing epochs of training are never saved.
                if epoch_id % save_every != 0 and epoch_id != last_epoch:
                    continue
                saver.save(sess, '{}vae-epoch-{}'.format(SAVE_PATH, epoch_id))
                if FLAGS.ncode == 2:
                    generator.generate_samples(sess,
                                               plot_size=plot_size,
                                               z=z,
                                               file_id=epoch_id)
                    visualizer.viz_2Dlatent_variable(sess,
                                                     valid_data,
                                                     file_id=epoch_id)
        finally:
            # FileWriter writes asynchronously; close() flushes pending events
            # so the last summaries are not lost.
            writer.close()
    def generate_samples(self, sess, plot_size, manifold=False, file_id=None):
        """Draw latent codes, decode them and render the result via ``viz_samples``.

        Args:
            sess: session handed through to ``self.viz_samples``.
            plot_size: side length of the sample grid; ``plot_size * plot_size``
                codes are drawn in the non-manifold case.
            manifold: if True, render interpolated latent grids instead of
                random samples. With a ``'gaussian'`` prior this is one grid
                over ``[-3, 3] x [-3, 3]``. Otherwise it is one grid per
                mixture mode (``self._n_labels`` of them), each with
                ``file_id`` suffixed by the mode index.
            file_id: identifier forwarded to ``self.viz_samples``.

        When ``self._use_label`` is set, row ``r`` of the grid is assigned
        label ``r`` cycled over ``0 .. self._n_labels - 1``. These labels are
        passed as ``label_indices`` to ``distribution.gaussian_mixture``; they
        are only consumed in the non-manifold, non-gaussian case.
        """
        n_samples = plot_size * plot_size

        label_indices = None
        if self._use_label:
            label_indices = []
            cur_label = -1
            for _ in range(plot_size):
                # Advance to the next label, wrapping back to 0 after the last.
                cur_label = cur_label + 1 if cur_label < self._n_labels - 1 else 0
                # Every sample in this row gets the same label.
                label_indices.extend(np.ones(plot_size) * cur_label)

        if manifold:
            if self._dist == 'gaussian':
                # NOTE(review): the 4-element range presumably means a 2-D
                # code grid (x_min, x_max, y_min, y_max) — confirm.
                random_code = distribution.interpolate(
                    plot_size=plot_size, interpolate_range=[-3, 3, -3, 3])
                self.viz_samples(sess, random_code, plot_size, file_id=file_id)
            else:
                for mode_id in range(self._n_labels):
                    random_code = distribution.interpolate_gm(
                        plot_size=plot_size,
                        interpolate_range=[-1., 1., -0.2, 0.2],
                        mode_id=mode_id,
                        n_mode=self._n_labels)
                    # NOTE(review): when file_id is None this yields
                    # 'None_<mode_id>' — confirm that naming is intended.
                    self.viz_samples(sess,
                                     random_code,
                                     plot_size,
                                     file_id='{}_{}'.format(file_id, mode_id))
        else:
            if self._dist == 'gaussian':
                random_code = distribution.diagonal_gaussian(
                    n_samples, self._g_model.n_code, mean=0, var=1.0)
            else:
                random_code = distribution.gaussian_mixture(
                    n_samples,
                    n_dim=self._g_model.n_code,
                    n_labels=self._n_labels,
                    x_var=0.5,
                    y_var=0.1,
                    label_indices=label_indices)

            self.viz_samples(sess, random_code, plot_size, file_id=file_id)