def VAE_encode_dataset(model_path, dataset, factor_dim, dim_z,
                       M1_hidden_layers_px, M1_hidden_layers_qz):
    """Encode ``dataset`` with a restored Conv2dVariationalAutoencoder.

    A model is built with a fixed batch size of 100, its weights are
    restored from ``model_path``, and ``dataset`` is pushed through
    ``VAE.encode`` batch by batch (batches come from ``utils.feed_numpy``).

    Args:
        model_path: Checkpoint path handed to ``VAE.saver.restore``.
        dataset: Data to encode; iterated via ``utils.feed_numpy``.
        factor_dim: Passed to the model as ``dim_x``.
        dim_z: Passed to the model as ``dim_z``.
        M1_hidden_layers_px: Passed to the model as ``hidden_layers_px``.
        M1_hidden_layers_qz: Passed to the model as ``hidden_layers_qz``.

    Returns:
        A tuple ``(enc_x_mean, enc_x_var)``: the two outputs of
        ``VAE.encode`` for every batch, stacked row-wise in batch order.
        Both are ``None`` if ``utils.feed_numpy`` yields no batch.
    """
    batch_size = 100
    VAE = Conv2dVariationalAutoencoder(dim_x=factor_dim,
                                       dim_z=dim_z,
                                       batch_size=batch_size,
                                       hidden_layers_px=M1_hidden_layers_px,
                                       hidden_layers_qz=M1_hidden_layers_qz)

    # Collect per-batch results and stack once at the end; calling
    # np.vstack inside the loop re-copies everything accumulated so far on
    # every iteration (quadratic in the number of batches).
    mean_chunks = []
    var_chunks = []
    with VAE.session:
        VAE.saver.restore(VAE.session, model_path)
        for x_batch in utils.feed_numpy(batch_size, dataset):
            batch_mean, batch_var = VAE.encode(x_batch)
            mean_chunks.append(batch_mean)
            var_chunks.append(batch_var)

    def _stack(chunks):
        # No batch -> None; a single batch is returned exactly as encode()
        # produced it; several batches are stacked vertically.
        if not chunks:
            return None
        if len(chunks) == 1:
            return chunks[0]
        return np.vstack(chunks)

    enc_x_mean = _stack(mean_chunks)
    enc_x_var = _stack(var_chunks)
    print("Shape of enc_x_mean is ", np.shape(enc_x_mean))
    print("Shape of enc_x_var is ", np.shape(enc_x_var))

    return enc_x_mean, enc_x_var
示例#2
0
    def train(
            self,
            x,
            x_valid,
            epochs,
            save_path=None,
            print_every=1,
            learning_rate=3e-4,
            beta1=0.9,
            beta2=0.999,
            seed=31415,
            stop_iter=100,
            load_path=None,
            draw_img=1):
        """Train the model with Adam and checkpoint on validation improvement.

        Each epoch shuffles ``x`` in place, runs one optimisation step per
        batch from ``utils.feed_numpy``, and, every ``print_every`` epochs,
        averages ``self.eval_log_lik`` over the validation batches. The
        model is saved to ``self.save_path`` whenever that average improves
        on the best value seen so far.

        Args:
            x: Training data. ``x.shape[0]`` must be a multiple of
                ``self.batch_size``. NOTE: shuffled in place every epoch.
            x_valid: Validation data, consumed in ``self.batch_size`` batches.
            epochs: Maximum number of epochs to run.
            save_path: Checkpoint path. When ``None`` a timestamped path
                under ``checkpoints/`` is generated.
            print_every: Evaluate and print metrics every this many epochs.
            learning_rate: Adam learning rate.
            beta1: Adam ``beta1``.
            beta2: Adam ``beta2``.
            seed: Seed given to ``np.random.seed`` and ``tf.set_random_seed``.
            stop_iter: Training stops once this many epochs have passed
                since the last validation improvement.
            load_path: ``None`` to start from freshly initialised variables,
                ``'default'`` to restore from ``self.save_path``, or an
                explicit checkpoint path to restore from.
            draw_img: Currently unused (the reconstruction-plotting code it
                controlled is disabled); kept for interface compatibility.

        Raises:
            AssertionError: If the number of training examples is not a
                multiple of ``self.batch_size``.
            ValueError: If an evaluation pass yields no validation batch.

        Side effects:
            Sets ``self.num_examples``, ``self.save_path``,
            ``self.optimiser``, ``self.train_op``, ``self._test_vars`` and
            (when test variables exist) ``self._test_var_init_op``.
            ``self.session`` is used as a context manager; presumably it is
            a ``tf.Session`` and is therefore closed when training returns
            -- TODO confirm.
        """
        self.num_examples = x.shape[0]
        assert self.num_examples % self.batch_size == 0, '#Examples % #Batches != 0'

        # Session and summary: decide where checkpoints are written.
        if save_path is None:
            self.save_path = 'checkpoints/model_CONV2D-VAE_{}.cpkt'.format(
                time.strftime("%m-%d-%H%M%S", time.localtime()))
        else:
            self.save_path = save_path

        np.random.seed(seed)
        # NOTE(review): tf.set_random_seed seeds the *current default* graph
        # and this call sits outside `self.G.as_default()`; confirm it
        # reaches the intended graph.
        tf.set_random_seed(seed)

        with self.G.as_default():
            self.optimiser = tf.train.AdamOptimizer(
                learning_rate=learning_rate, beta1=beta1, beta2=beta2)
            self.train_op = self.optimiser.minimize(self.cost)
            # Created after the optimiser so its slot variables are covered.
            init = tf.global_variables_initializer()
            self._test_vars = None

        with self.session as sess:

            sess.run(init)

            # Optional warm start from a checkpoint. The original author
            # noted that this path is not actually exercised in practice.
            if load_path == 'default':
                self.saver.restore(sess, self.save_path)
            elif load_path is not None:
                self.saver.restore(sess, load_path)

            training_cost = 0.
            best_eval_log_lik = -np.inf
            stop_counter = 0

            for epoch in range(epochs):
                # Shuffle the training data (in place).
                np.random.shuffle(x)

                # Training: one optimiser step per batch. Only the cost of
                # the last batch of the epoch is kept for reporting.
                for x_batch in utils.feed_numpy(self.batch_size, x):
                    _, training_cost = sess.run(
                        [self.train_op, self.cost],
                        feed_dict={self.x: x_batch})

                # Counts epochs since the last validation improvement.
                stop_counter += 1

                # Evaluation.
                if epoch % print_every == 0:

                    # (Re)initialise any variables registered in the
                    # TEST_VARIABLES collection; the init op is rebuilt only
                    # when the set of test variables has changed.
                    test_vars = tf.get_collection(
                        bookkeeper.GraphKeys.TEST_VARIABLES)
                    if test_vars:
                        if test_vars != self._test_vars:
                            self._test_vars = list(test_vars)
                            self._test_var_init_op = tf.variables_initializer(
                                test_vars)
                        self._test_var_init_op.run()

                    # Average the validation log-likelihood over the batches
                    # that were actually evaluated, rather than over
                    # x_valid.shape[0] / batch_size, which may be fractional
                    # or disagree with the real batch count.
                    eval_log_lik = 0
                    num_valid_batches = 0
                    for x_valid_batch in utils.feed_numpy(
                            self.batch_size, x_valid):
                        # x_recon_eval is still fetched so the executed
                        # sub-graph is unchanged; its value is not used.
                        log_lik, _ = sess.run(
                            [self.eval_log_lik, self.x_recon_eval],
                            feed_dict={self.x: x_valid_batch})
                        eval_log_lik += log_lik
                        num_valid_batches += 1
                    if num_valid_batches == 0:
                        raise ValueError(
                            'x_valid produced no batches of size {}'.format(
                                self.batch_size))
                    eval_log_lik /= num_valid_batches

                    if eval_log_lik > best_eval_log_lik:
                        best_eval_log_lik = eval_log_lik
                        self.saver.save(sess, self.save_path)
                        stop_counter = 0

                    utils.print_metrics(
                        epoch + 1, ['Training', 'cost', training_cost],
                        ['Validation', 'log-likelihood', eval_log_lik])

                # Early stopping.
                if stop_counter >= stop_iter:
                    print('Stopping VAE training')
                    print(
                        'No change in validation log-likelihood for {} iterations'
                        .format(stop_iter))
                    print('Best validation log-likelihood: {}'.format(
                        best_eval_log_lik))
                    print('Model saved in {}'.format(self.save_path))
                    break