Example #1
def train_ddi_vae(dim_z, hidden_layers_px, hidden_layers_qz, save_path, learning_rate):
    #############################
    ''' Experiment Parameters '''
    #############################

    # num_batches = 100       #Number of minibatches in a single epoch, num_examples % self.num_batches == 0
    # dim_z = 50              #Dimensionality of latent variable (z)
    epochs = 1000           #Number of epochs through the full dataset
    # learning_rate = 3e-3    #Learning rate of ADAM
    l2_loss = 1e-6          #L2 Regularisation weight
    seed = 31415            #Seed for RNG

    #Neural Networks parameterising p(x|z), q(z|x)
    # hidden_layers_px = [ 600, 1000, 800, 500 ]
    # hidden_layers_qz = [ 600, 1000, 800, 500 ]

    ####################
    ''' Load Dataset '''
    ####################

    # mnist_path = 'mnist/mnist_28.pkl.gz'
    # #Uses anglpy module from original paper (linked at top) to load the dataset
    # train_x, train_y, valid_x, valid_y, test_x, test_y = mnist.load_numpy(mnist_path, binarize_y=True)
    #
    # x_train, y_train = train_x.T, train_y.T
    # x_valid, y_valid = valid_x.T, valid_y.T
    # x_test, y_test = test_x.T, test_y.T

    x_train, y_train, x_valid, y_valid, x_test, y_test = load_dataset("/data/cdy/ykq/ddi_dataset/train_dataset")
    utils.print_metrics(['x_train', x_train.shape[0], x_train.shape[1]],
                        ['y_train', y_train.shape[0], y_train.shape[1]],
                        ['x_valid', x_valid.shape[0], x_valid.shape[1]],
                        ['y_valid', y_valid.shape[0], y_valid.shape[1]],
                        ['x_test', x_test.shape[0], x_test.shape[1]],
                        ['y_test', y_test.shape[0], y_test.shape[1]],
                        )

    dim_x = x_train.shape[1]
    dim_y = y_train.shape[1]

    ######################################
    ''' Train Variational Auto-Encoder '''
    ######################################

    VAE = Conv2dVariationalAutoencoder(dim_x=dim_x,
                                       dim_z=dim_z,
                                       hidden_layers_px=hidden_layers_px,
                                       hidden_layers_qz=hidden_layers_qz,
                                       l2_loss=l2_loss)

    #draw_img uses pylab and seaborn to draw images of original vs. reconstruction
    #every n iterations (set to 0 to disable)

    VAE.train(x=x_train, x_valid=x_valid, epochs=epochs, save_path=save_path,
              learning_rate=learning_rate, seed=seed, stop_iter=30,
              print_every=10, draw_img=0)
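
A minimal invocation sketch (not from the source): dim_z and the layer widths mirror the commented-out defaults above, while the save path and learning rate are illustrative placeholders.

if __name__ == '__main__':
    # Hypothetical hyperparameters for a quick run; adjust to the real dataset.
    train_ddi_vae(dim_z=50,
                  hidden_layers_px=[600, 1000, 800, 500],
                  hidden_layers_qz=[600, 1000, 800, 500],
                  save_path='checkpoints/ddi_vae.ckpt',
                  learning_rate=3e-3)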
Example #2
    def predict_labels(self, x_test, y_test):

        test_vars = tf.get_collection(bookkeeper.GraphKeys.TEST_VARIABLES)
        tf.initialize_variables(test_vars).run()

        x_test_mu = x_test[:, :self.dim_x]
        x_test_lsgms = x_test[:, self.dim_x:2 * self.dim_x]

        accuracy, cross_entropy, precision, recall = \
            self.session.run([self.eval_accuracy, self.eval_cross_entropy, self.eval_precision, self.eval_recall],
                             feed_dict={self.x_labelled_mu: x_test_mu, self.x_labelled_lsgms: x_test_lsgms,
                                        self.y_lab: y_test})

        utils.print_metrics('X', ['Test', 'accuracy', accuracy],
                            ['Test', 'cross-entropy', cross_entropy],
                            ['Test', 'precision', precision],
                            ['Test', 'recall', recall])
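
The examples call utils.print_metrics without showing it. Below is a hypothetical reconstruction inferred from the call sites, where each positional argument is either a scalar header (an epoch number, or a label such as 'X') or a [group, name, value] triple; the repository's real helper may format its output differently.

def print_metrics(*metrics):
    # Sketch of utils.print_metrics as inferred from usage; not the source code.
    for metric in metrics:
        if isinstance(metric, (list, tuple)):
            group, name, value = metric  # e.g. ['Test', 'accuracy', accuracy]
            print('{} {}: {}'.format(group, name, value))
        else:
            print('--- {} ---'.format(metric))  # scalar header, e.g. an epoch number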
Example #3
    def train(
            self,
            x,
            x_valid,
            epochs,  #num_batches,
            save_path=None,
            print_every=1,
            learning_rate=3e-4,
            beta1=0.9,
            beta2=0.999,
            seed=31415,
            stop_iter=100,
            load_path=None,
            draw_img=1):

        self.num_examples = x.shape[0]
        # self.num_batches = num_batches
        assert self.num_examples % self.batch_size == 0, \
            'x.shape[0] must be divisible by batch_size'

        # self.batch_size = self.num_examples // self.num_batches

        # x_size = x.shape[0]
        # x_valid_size = x.shape[0]
        # x = np.reshape(x, [x_size, self.dim_x, 1, 1])
        # x_valid = np.reshape(x_valid, [x_valid_size, self.dim_x, 1, 1])
        ''' Session and Summary '''
        if save_path is None:
            self.save_path = 'checkpoints/model_CONV2D-VAE_{}.ckpt'.format(
                time.strftime("%m-%d-%H%M%S", time.localtime()))
        else:
            self.save_path = save_path

        np.random.seed(seed)
        tf.set_random_seed(seed)

        with self.G.as_default():

            self.optimiser = tf.train.AdamOptimizer(
                learning_rate=learning_rate, beta1=beta1, beta2=beta2)
            self.train_op = self.optimiser.minimize(self.cost)
            init = tf.global_variables_initializer()  # tf.initialize_all_variables()
            self._test_vars = None

        with self.session as sess:

            sess.run(init)

            # Not executed in the default run: load_path is None unless a checkpoint is supplied
            if load_path == 'default':
                self.saver.restore(sess, self.save_path)
            elif load_path is not None:
                self.saver.restore(sess, load_path)

            training_cost = 0.
            best_eval_log_lik = -np.inf
            stop_counter = 0

            for epoch in range(epochs):
                ''' Shuffle Data '''
                np.random.shuffle(x)
                ''' Training '''

                for x_batch in utils.feed_numpy(self.batch_size, x):
                    training_result = sess.run([self.train_op, self.cost],
                                               feed_dict={self.x: x_batch})

                    training_cost = training_result[1]
                ''' Evaluation '''

                stop_counter += 1

                # Evaluate periodically while training keeps updating the parameters
                if epoch % print_every == 0:

                    test_vars = tf.get_collection(
                        bookkeeper.GraphKeys.TEST_VARIABLES)
                    if test_vars:
                        if test_vars != self._test_vars:
                            self._test_vars = list(test_vars)
                            self._test_var_init_op = tf.initialize_variables(
                                test_vars)
                        self._test_var_init_op.run()

                    # eval_log_lik, x_recon_eval = \
                    #     sess.run([self.eval_log_lik, self.x_recon_eval],
                    #              feed_dict={self.x: x_valid})
                    eval_log_lik = 0
                    x_recon_eval = 0
                    valid_times = x_valid.shape[0] // self.batch_size  # number of validation minibatches
                    for x_valid_batch in utils.feed_numpy(
                            self.batch_size, x_valid):
                        log_lik, recon_eval = sess.run(
                            [self.eval_log_lik, self.x_recon_eval],
                            feed_dict={self.x: x_valid_batch})
                        eval_log_lik += log_lik
                        x_recon_eval += recon_eval
                    eval_log_lik /= valid_times
                    x_recon_eval /= valid_times

                    if eval_log_lik > best_eval_log_lik:
                        best_eval_log_lik = eval_log_lik
                        self.saver.save(sess, self.save_path)
                        stop_counter = 0

                    utils.print_metrics(
                        epoch + 1, ['Training', 'cost', training_cost],
                        ['Validation', 'log-likelihood', eval_log_lik])

                ## Plotting (disabled)
                # if draw_img > 0 and epoch % draw_img == 0:
                #
                # 	import matplotlib
                # 	matplotlib.use('Agg')
                # 	import pylab
                # 	import seaborn as sns
                #
                # 	five_random = np.random.random_integers(x_valid.shape[0], size = 5)
                # 	x_sample = x_valid[five_random]
                # 	x_recon_sample = x_recon_eval[five_random]
                #
                # 	sns.set_style('white')
                # 	f, axes = pylab.subplots(5, 2, figsize=(8,12))
                # 	for i,row in enumerate(axes):
                #
                # 		row[0].imshow(x_sample[i].reshape(28, 28), vmin=0, vmax=1)
                # 		im = row[1].imshow(x_recon_sample[i].reshape(28, 28), vmin=0, vmax=1,
                # 			cmap=sns.light_palette((1.0, 0.4980, 0.0549), input="rgb", as_cmap=True))
                #
                # 		pylab.setp([a.get_xticklabels() for a in row], visible=False)
                # 		pylab.setp([a.get_yticklabels() for a in row], visible=False)
                #
                # 	f.subplots_adjust(left=0.0, right=0.9, bottom=0.0, top=1.0)
                # 	cbar_ax = f.add_axes([0.9, 0.1, 0.04, 0.8])
                # 	f.colorbar(im, cax=cbar_ax, use_gridspec=True)
                #
                # 	pylab.tight_layout()
                # 	pylab.savefig('img/recon-'+str(epoch)+'.png', format='png')
                # 	pylab.clf()
                # 	pylab.close('all')

                if stop_counter >= stop_iter:
                    print('Stopping VAE training')
                    print(
                        'No change in validation log-likelihood for {} iterations'
                        .format(stop_iter))
                    print('Best validation log-likelihood: {}'.format(
                        best_eval_log_lik))
                    print('Model saved in {}'.format(self.save_path))
                    break
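
train() iterates over utils.feed_numpy, which is not shown. A plausible sketch inferred from its use with a single array here and with several parallel arrays in Example #4: it yields consecutive minibatches of batch_size rows and drops any remainder (the assert in train() guarantees divisibility for the training set).

def feed_numpy(batch_size, *arrays):
    # Hypothetical sketch; the repository's actual helper may differ.
    num_batches = arrays[0].shape[0] // batch_size
    for i in range(num_batches):
        batch = [a[i * batch_size:(i + 1) * batch_size] for a in arrays]
        yield batch[0] if len(batch) == 1 else batch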
Example #4
    def train(self,
              x_labelled,
              y,
              x_unlabelled,
              epochs,
              x_valid,
              y_valid,
              print_every=1,
              learning_rate=3e-4,
              beta1=0.9,
              beta2=0.999,
              seed=31415,
              stop_iter=100,
              save_path=None,
              load_path=None):
        ''' Session and Summary '''
        if save_path is None:
            self.save_path = 'checkpoints/model_CONV2D_GC_{}-{}-{}_{}.ckpt'.format(
                self.num_lab, learning_rate, self.batch_size, time.time())
        else:
            self.save_path = save_path

        np.random.seed(seed)
        tf.set_random_seed(seed)

        with self.G.as_default():

            self.optimiser = tf.train.AdamOptimizer(
                learning_rate=learning_rate, beta1=beta1, beta2=beta2)
            self.train_op = self.optimiser.minimize(self.cost)
            init = tf.global_variables_initializer()  # tf.initialize_all_variables()
            self._test_vars = None
        _data_labelled = np.hstack([x_labelled, y])
        _data_unlabelled = x_unlabelled
        x_valid_mu = x_valid[:, :int(self.dim_x)]
        x_valid_lsgms = x_valid[:, int(self.dim_x):int(2 * self.dim_x)]

        with self.session as sess:

            sess.run(init)
            if load_path == 'default':
                self.saver.restore(sess, self.save_path)
            elif load_path is not None:
                self.saver.restore(sess, load_path)

            best_eval_accuracy = 0.
            best_train_accuracy = 0.
            stop_counter = 0

            # print("****lab_clf", self.weights['lab_clf'])
            # print("****lab_recon_clf", self.weights['lab_recon_clf'])
            # print("****ulab_clf", self.weights['ulab_clf'])
            # print("****ulab_recon_clf", self.weights['ulab_recon_clf'])
            # print("****L_loss", self.weights['L_loss'])
            # print("****U_loss", self.weights['U_loss'])

            for epoch in range(epochs):
                ''' Shuffle Data '''
                np.random.shuffle(_data_labelled)
                np.random.shuffle(_data_unlabelled)
                ''' Training '''

                # y_l avoids shadowing the y argument used to build _data_labelled
                for x_l_mu, x_l_lsgms, y_l, x_u_mu, x_u_lsgms in utils.feed_numpy_semisupervised(
                        self.num_lab_batch, self.num_ulab_batch,
                        _data_labelled[:, :2 * self.dim_x],
                        _data_labelled[:, 2 * self.dim_x:], _data_unlabelled):
                    training_result = sess.run(
                        [self.train_op, self.cost],
                        feed_dict={
                            self.x_labelled_mu: x_l_mu,
                            self.x_labelled_lsgms: x_l_lsgms,
                            self.y_lab: y_l,
                            self.x_unlabelled_mu: x_u_mu,
                            self.x_unlabelled_lsgms: x_u_lsgms
                        })

                    training_cost = training_result[1]
                ''' Evaluation '''

                stop_counter += 1

                if epoch % print_every == 0:

                    test_vars = tf.get_collection(
                        bookkeeper.GraphKeys.TEST_VARIABLES)
                    if test_vars:
                        if test_vars != self._test_vars:
                            self._test_vars = list(test_vars)
                            self._test_var_init_op = tf.initialize_variables(
                                test_vars)
                        self._test_var_init_op.run()

                    # eval_accuracy, eval_cross_entropy = \
                    #     sess.run([self.eval_accuracy, self.eval_cross_entropy],
                    #              feed_dict={self.x_labelled_mu: x_valid_mu,
                    #                         self.x_labelled_lsgms: x_valid_lsgms,
                    #                         self.y_lab: y_valid})
                    eval_accuracy = 0
                    eval_cross_entropy = 0
                    assert x_valid_mu.shape[0] % self.num_lab_batch == 0, \
                        'x_valid size must be divisible by num_lab_batch'
                    x_valid_batch_count = x_valid_mu.shape[0] // self.num_lab_batch
                    print("valid: self.num_lab_batch: ", self.num_lab_batch)
                    # print("y_")
                    for x_valid_mu_batch, x_valid_lsgms_batch, y_valid_batch in pt.train.feed_numpy(
                            self.num_lab_batch, x_valid_mu, x_valid_lsgms,
                            y_valid):
                        tmp_accuracy, tmp_cross_entropy = sess.run(
                            [self.eval_accuracy, self.eval_cross_entropy],
                            feed_dict={
                                self.x_labelled_mu: x_valid_mu_batch,
                                self.x_labelled_lsgms: x_valid_lsgms_batch,
                                self.y_lab: y_valid_batch
                            })
                        eval_accuracy += tmp_accuracy
                        eval_cross_entropy += tmp_cross_entropy
                    eval_accuracy /= x_valid_batch_count
                    eval_cross_entropy /= x_valid_batch_count
                    if eval_accuracy > best_eval_accuracy:
                        best_eval_accuracy = eval_accuracy
                        self.saver.save(sess, self.save_path)
                        stop_counter = 0
                        print("accuracy update: ", best_eval_accuracy)

                    utils.print_metrics(
                        epoch + 1, ['Training', 'cost', training_cost],
                        ['Validation', 'accuracy', eval_accuracy],
                        ['Validation', 'cross-entropy', eval_cross_entropy])

                if stop_counter >= stop_iter:
                    print('Stopping GC training')
                    print('No change in validation accuracy for {} iterations'.
                          format(stop_iter))
                    print('Best validation accuracy: {}'.format(
                        best_eval_accuracy))
                    print('Model saved in {}'.format(self.save_path))
                    break
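
The loop above draws one labelled and one unlabelled minibatch per step from utils.feed_numpy_semisupervised, which is not shown. A hedged sketch inferred from the call site: labelled rows pack [mu | log-sigma^2] followed by the label columns, unlabelled rows pack just [mu | log-sigma^2], each Gaussian half being dim_x columns wide.

def feed_numpy_semisupervised(num_lab_batch, num_ulab_batch, x_lab, y_lab, x_ulab):
    # Hypothetical sketch; the real helper may sample or interleave differently.
    dim_x = x_lab.shape[1] // 2
    num_batches = min(x_lab.shape[0] // num_lab_batch,
                      x_ulab.shape[0] // num_ulab_batch)
    for i in range(num_batches):
        lab = x_lab[i * num_lab_batch:(i + 1) * num_lab_batch]
        y_batch = y_lab[i * num_lab_batch:(i + 1) * num_lab_batch]
        ulab = x_ulab[i * num_ulab_batch:(i + 1) * num_ulab_batch]
        yield (lab[:, :dim_x], lab[:, dim_x:], y_batch,
               ulab[:, :dim_x], ulab[:, dim_x:])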