Example #1
def main():
    # parse arguments
    args = parse_args()

    # parameters
    batch_size = 10000
    n_dim = 2

    # get samples from prior
    if args.prior_type == 'mixGaussian':
        z_id_ = np.random.randint(0, 10, size=[batch_size])
        z = prior.gaussian_mixture(batch_size, n_dim, label_indices=z_id_)
    elif args.prior_type == 'swiss_roll':
        z_id_ = np.random.randint(0, 10, size=[batch_size])
        z = prior.swiss_roll(batch_size, n_dim, label_indices=z_id_)
    elif args.prior_type == 'normal':
        z, z_id_ = prior.gaussian(batch_size, n_dim, use_label_info=True)
    else:
        raise Exception("[!] There is no option for " + args.prior_type)

    # plot
    plt.figure(figsize=(8, 6))
    plt.scatter(z[:, 0],
                z[:, 1],
                c=z_id_,
                marker='o',
                edgecolor='none',
                cmap=discrete_cmap(10, 'jet'))
    plt.colorbar(ticks=range(10))
    plt.grid(True)
    axes = plt.gca()
    axes.set_xlim([-4.5, 4.5])
    axes.set_ylim([-4.5, 4.5])
    plt.show()
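The plot above relies on a discrete_cmap helper that is not defined in this snippet. A minimal sketch, assuming it simply slices a continuous matplotlib colormap into N evenly spaced bins (the original helper may differ):

import numpy as np
import matplotlib.pyplot as plt

def discrete_cmap(N, base_cmap=None):
    # Build an N-bin discrete colormap by sampling N evenly spaced
    # colors from the given continuous base colormap.
    base = plt.cm.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    return base.from_list(base.name + str(N), color_list, N)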
Example #2
    def visualize_results(self, epoch):
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
        """ random condition, random noise """

        z_sample = prior.gaussian(self.batch_size, self.z_dim)

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(
            samples[:image_frame_dim * image_frame_dim, :, :, :],
            [image_frame_dim, image_frame_dim], self.result_dir + '/' +
            self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')
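Several of these examples call into a prior module (prior.gaussian, prior.gaussian_mixture, prior.swiss_roll) that is not shown. A minimal sketch of the plain Gaussian prior as it is used here, assuming the optional label is just a random class index per sample (the real module may assign labels differently):

import numpy as np

def gaussian(batch_size, n_dim, mean=0.0, var=1.0, use_label_info=False, n_labels=10):
    # Draw isotropic Gaussian noise; optionally return a random class
    # index per sample so callers can color-code the points.
    z = np.random.normal(mean, np.sqrt(var), (batch_size, n_dim)).astype(np.float32)
    if use_label_info:
        z_id = np.random.randint(0, n_labels, size=batch_size)
        return z, z_id
    return z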
Example #3
    def visualize_results(self, epoch):
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
        """ random condition, random noise """

        z_sample = prior.gaussian(self.batch_size, self.z_dim)

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(
            samples[:image_frame_dim * image_frame_dim, :, :, :],
            [image_frame_dim, image_frame_dim],
            check_folder(self.result_dir + '/' + self.model_dir) + '/' +
            self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')

        # When the noise vector is 2-dimensional, the model can be viewed as a classification
        # setup, and the separation between classes can be shown with a scatter plot.
        # Note that this requires z_dim == 2.
        """ learned manifold """
        if self.z_dim == 2:

            z_tot = None
            id_tot = None
            for idx in range(0, 100):
                # randomly sampling
                id = np.random.randint(0, self.num_batches)
                batch_images = self.data_X[id * self.batch_size:(id + 1) *
                                           self.batch_size]
                # record the labels of the images so classes can be shown by color
                batch_labels = self.data_y[id * self.batch_size:(id + 1) *
                                           self.batch_size]

                z = self.sess.run(self.mu,
                                  feed_dict={self.inputs: batch_images})

                if idx == 0:
                    z_tot = z
                    id_tot = batch_labels
                else:
                    z_tot = np.concatenate((z_tot, z), axis=0)
                    id_tot = np.concatenate((id_tot, batch_labels), axis=0)

            save_scattered_image(
                z_tot,
                id_tot,
                -4,
                4,
                name=check_folder(self.result_dir + '/' + self.model_dir) +
                '/' + self.model_name + '_epoch%03d' % epoch +
                '_learned_manifold.png')
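save_scattered_image is another undefined helper; judging by the call site it takes the 2-D codes, one-hot labels, axis limits, and an output path. A sketch along the lines of the scatter plot in Example #1, assuming discrete_cmap from the sketch above and one-hot id_ (the original helper may differ):

import numpy as np
import matplotlib.pyplot as plt

def save_scattered_image(z, id_, z_min, z_max, name='scattered_image.png'):
    # Scatter-plot 2-D latent codes, colored by the argmax of the
    # one-hot labels, and write the figure to disk.
    n_classes = id_.shape[1]
    plt.figure(figsize=(8, 6))
    plt.scatter(z[:, 0], z[:, 1], c=np.argmax(id_, 1), marker='o',
                edgecolor='none', cmap=discrete_cmap(n_classes, 'jet'))
    plt.colorbar(ticks=range(n_classes))
    plt.grid(True)
    plt.xlim(z_min, z_max)
    plt.ylim(z_min, z_max)
    plt.savefig(name)
    plt.close()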
Example #4
    def visualize_results(self, epoch):
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        """ random condition, random noise """
        y = np.random.choice(self.y_dim, self.batch_size)
        y_one_hot = np.zeros((self.batch_size, self.y_dim))
        y_one_hot[np.arange(self.batch_size), y] = 1

        z_sample = prior.gaussian(self.batch_size, self.z_dim)

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y_one_hot})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d'
                    % epoch + '_test_all_classes.png')

        """ specified condition, random noise """
        n_styles = 10  # must be less than or equal to self.batch_size

        np.random.seed()
        si = np.random.choice(self.batch_size, n_styles)

        for l in range(self.y_dim):
            y = np.zeros(self.batch_size, dtype=np.int64) + l
            y_one_hot = np.zeros((self.batch_size, self.y_dim))
            y_one_hot[np.arange(self.batch_size), y] = 1

            samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y_one_hot})
            save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                        check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d'
                        % epoch + '_test_class_%d.png' % l)

            samples = samples[si, :, :, :]

            if l == 0:
                all_samples = samples
            else:
                all_samples = np.concatenate((all_samples, samples), axis=0)

        """ save merged images to check style-consistency """
        canvas = np.zeros_like(all_samples)
        for s in range(n_styles):
            for c in range(self.y_dim):
                canvas[s * self.y_dim + c, :, :, :] = all_samples[c * n_styles + s, :, :, :]

        save_images(canvas, [n_styles, self.y_dim],
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d'
                    % epoch + '_test_all_classes_style_by_style.png')
Example #5
    def visualize_results(self, epoch):
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        """ random condition, random noise """
        y = np.random.choice(self.y_dim, self.batch_size)
        y_one_hot = np.zeros((self.batch_size, self.y_dim))
        y_one_hot[np.arange(self.batch_size), y] = 1

        z_sample = prior.gaussian(self.batch_size, self.z_dim)

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y_one_hot})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')

        """ specified condition, random noise """
        n_styles = 10  # must be less than or equal to self.batch_size

        np.random.seed()
        si = np.random.choice(self.batch_size, n_styles)

        for l in range(self.y_dim):
            y = np.zeros(self.batch_size, dtype=np.int64) + l
            y_one_hot = np.zeros((self.batch_size, self.y_dim))
            y_one_hot[np.arange(self.batch_size), y] = 1

            samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y_one_hot})
            save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                        check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_class_%d.png' % l)

            samples = samples[si, :, :, :]

            if l == 0:
                all_samples = samples
            else:
                all_samples = np.concatenate((all_samples, samples), axis=0)

        """ save merged images to check style-consistency """
        canvas = np.zeros_like(all_samples)
        for s in range(n_styles):
            for c in range(self.y_dim):
                canvas[s * self.y_dim + c, :, :, :] = all_samples[c * n_styles + s, :, :, :]

        save_images(canvas, [n_styles, self.y_dim],
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes_style_by_style.png')
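Every example above funnels its samples through save_images(images, size, path), which is not defined here. A minimal sketch, assuming float images in [0, 1] of shape (batch, h, w, c) and that size gives the [rows, cols] grid layout:

import numpy as np
import imageio

def save_images(images, size, image_path):
    # Tile the first size[0] * size[1] images into a grid and save it.
    h, w, c = images.shape[1], images.shape[2], images.shape[3]
    grid = np.zeros((size[0] * h, size[1] * w, c), dtype=np.float32)
    for idx, img in enumerate(images[:size[0] * size[1]]):
        row, col = idx // size[1], idx % size[1]
        grid[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = img
    imageio.imwrite(image_path, (np.squeeze(grid) * 255).astype(np.uint8))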
Example #6
    def sample(self, cl, n):
        y = np.zeros(self.batch_size, dtype=np.int64) + cl
        y_one_hot = np.zeros((self.batch_size, self.y_dim))
        y_one_hot[np.arange(self.batch_size), y] = 1

        all_samples = []
        while len(all_samples) < n:
            z_sample = prior.gaussian(self.batch_size, self.z_dim)

            samples = self.sess.run(self.fake_images,
                                    feed_dict={
                                        self.z: z_sample,
                                        self.y: y_one_hot
                                    })

            if len(all_samples) == 0:
                all_samples = samples
            else:
                all_samples = np.concatenate((all_samples, samples), axis=0)

        # trim: the loop may overshoot n by up to batch_size - 1 samples
        return all_samples[:n]
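A hypothetical call, assuming a trained model instance named model (the name is a placeholder, not part of the example):

# draw 500 generated samples conditioned on class 3 (shape: (500, h, w, c))
samples = model.sample(cl=3, n=500)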
Example #7
    def visualize_results(self, epoch):
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        """ random condition, random noise """

        z_sample = prior.gaussian(self.batch_size, self.z_dim)

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    check_folder(
                        self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')

        """ learned manifold """
        if self.z_dim == 2:

            z_tot = None
            id_tot = None
            for idx in range(0, 100):
                # randomly sampling
                id = np.random.randint(0, self.num_batches)
                batch_images = self.data_X[id * self.batch_size:(id + 1) * self.batch_size]
                batch_labels = self.data_y[id * self.batch_size:(id + 1) * self.batch_size]

                z = self.sess.run(self.mu, feed_dict={self.inputs: batch_images})

                if idx == 0:
                    z_tot = z
                    id_tot = batch_labels
                else:
                    z_tot = np.concatenate((z_tot, z), axis=0)
                    id_tot = np.concatenate((id_tot, batch_labels), axis=0)

            save_scattered_image(z_tot, id_tot, -4, 4, name=check_folder(
                self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_learned_manifold.png')
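The recurring check_folder helper appears to create the output directory on demand and return the path so the call can be embedded inline when building file names; a sketch consistent with that usage:

import os

def check_folder(log_dir):
    # Create the directory if it does not exist, then return the path
    # so the call can sit inside a file-name expression.
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    return log_dir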
Example #8
    def train(self):

        # initialize all variables
        tf.global_variables_initializer().run()

        # graph inputs for visualizing training results
        self.sample_z = prior.gaussian(self.batch_size, self.z_dim)

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(
            self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore the checkpoint if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = int(checkpoint_counter / self.num_batches)
            start_batch_id = checkpoint_counter - start_epoch * self.num_batches
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx * self.batch_size:(idx + 1) *
                                           self.batch_size]
                batch_z = prior.gaussian(self.batch_size, self.z_dim)

                # update autoencoder
                _, summary_str, loss, nll_loss, kl_loss = self.sess.run(
                    [
                        self.optim, self.merged_summary_op, self.loss,
                        self.neg_loglikelihood, self.KL_divergence
                    ],
                    feed_dict={
                        self.inputs: batch_images,
                        self.z: batch_z
                    })
                self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.8f, nll: %.8f, kl: %.8f" \
                      % (epoch, idx, self.num_batches, time.time() - start_time, loss, nll_loss, kl_loss))

                # save training results for every 300 steps
                if np.mod(counter, 300) == 0:
                    samples = self.sess.run(self.fake_images,
                                            feed_dict={self.z: self.sample_z})

                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(
                        samples[:manifold_h * manifold_w, :, :, :],
                        [manifold_h, manifold_w], './' +
                        check_folder(self.result_dir + '/' + self.model_dir) +
                        '/' + self.model_name +
                        '_train_{:02d}_{:04d}.png'.format(epoch, idx))

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show intermediate results
            self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)
Example #9
    def train(self):

        # initialize all variables
        tf.global_variables_initializer().run()

        # graph inputs for visualizing training results
        self.sample_z = prior.gaussian(self.batch_size, self.z_dim)
        self.test_labels = self.data_y[0:self.batch_size]

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore the checkpoint if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = int(checkpoint_counter / self.num_batches)
            start_batch_id = checkpoint_counter - start_epoch * self.num_batches
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx*self.batch_size:(idx+1)*self.batch_size]
                batch_labels = self.data_y[idx * self.batch_size:(idx + 1) * self.batch_size]
                batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

                # update autoencoder
                _, summary_str, loss, nll_loss, kl_loss = self.sess.run([self.optim, self.merged_summary_op, self.loss, self.neg_loglikelihood, self.KL_divergence],
                                               feed_dict={self.inputs: batch_images, self.y: batch_labels, self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.8f, nll: %.8f, kl: %.8f" \
                      % (epoch, idx, self.num_batches, time.time() - start_time, loss, nll_loss, kl_loss))

                # save training results for every 300 steps
                if np.mod(counter, 300) == 0:
                    samples = self.sess.run(self.fake_images,
                                            feed_dict={self.z: self.sample_z, self.y: self.test_labels})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                                './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
                                    epoch, idx))

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show intermediate results
            self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)
Example #10
    def train(self):
        # initialize all variables
        tf.global_variables_initializer().run()

        # graph inputs for visualizing training results
        self.sample_z = prior.gaussian(self.batch_size, self.z_dim)

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore the checkpoint if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = int(checkpoint_counter / self.num_batches)
            start_batch_id = checkpoint_counter - start_epoch * self.num_batches
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_videos = self.data_X_train[idx*self.batch_size:(idx+1)*self.batch_size]
                batch_z = np.random.randn(self.batch_size, self.z_dim).astype(np.float32)
                
                # Extract initial and final frames
                batch_images1 = batch_videos[:,0,:,:,:].copy()
                batch_images2 = batch_videos[:,-1,:,:,:].copy()

                # update autoencoder
                _, summary_str, loss, nll_loss, kl_loss = self.sess.run([self.optim, self.merged_summary_op, self.loss, self.neg_loglikelihood, self.KL_divergence],
                                               feed_dict={self.inputs: batch_videos, self.z: batch_z, self.img1:batch_images1,
                                                         self.img2:batch_images2})
                self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.8f, nll: %.8f, kl: %.8f" \
                      % (epoch, idx, self.num_batches, time.time() - start_time, loss, nll_loss, kl_loss))

                # save training results for every 100 steps
                if np.mod(counter, 100) == 0:
                    samples = self.sess.run(self.fake_videos,
                                            feed_dict={self.z: self.sample_z,self.img1:batch_images1,self.img2:batch_images2})
                    #tot_num_samples = min(self.sample_num, self.batch_size)
                    #tot_num_samples = 10
                    tot_num_samples = self.batch_size
                    #samples = np.clip(samples, 0.0, 1.0)
                    samples = samples*255.
                    samples = samples.astype(np.uint8)
                    
                    for ind_vid in range(tot_num_samples):
                        uri = './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name 
                        uri_vid = uri + '_train_{:02d}_{:04d}_{:04d}.mp4'.format(epoch, idx, ind_vid)
                        uri_im1 = uri + '_train_{:02d}_{:04d}_{:04d}_img1.jpg'.format(epoch, idx, ind_vid)
                        
                        imageio.mimwrite(uri_vid, samples[ind_vid], fps=10)
                        
                        img1_array = batch_images1[ind_vid,...]*255.
                        img1_array = img1_array.astype(np.uint8)
                        img1 = Image.fromarray(img1_array, 'RGB')
                        img1.save(uri_im1)

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            if epoch % 50 == 0 and epoch > 0:
                self.save(self.checkpoint_dir, counter)

            # show intermediate results
            # self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)
Example #11
def main(args):
    """ parameters """
    RESULTS_DIR = args.results_path

    # network architecture

    n_hidden = args.n_hidden
    dim_img = IMAGE_SIZE_MNIST**2  # number of pixels for a MNIST image
    dim_z = 2  # to visualize learned manifold

    # train
    n_epochs = args.num_epochs
    batch_size = args.batch_size
    learn_rate = args.learn_rate

    # Plot
    PRR = args.PRR  # Plot Reproduce Result
    PRR_n_img_x = args.PRR_n_img_x  # number of images along x-axis in a canvas
    PRR_n_img_y = args.PRR_n_img_y  # number of images along y-axis in a canvas
    PRR_resize_factor = args.PRR_resize_factor  # resize factor for each image in a canvas

    PMLR = args.PMLR  # Plot Manifold Learning Result
    PMLR_n_img_x = args.PMLR_n_img_x  # number of images along x-axis in a canvas
    PMLR_n_img_y = args.PMLR_n_img_y  # number of images along y-axis in a canvas
    PMLR_resize_factor = args.PMLR_resize_factor  # resize factor for each image in a canvas
    PMLR_z_range = args.PMLR_z_range  # range for random latent vector
    PMLR_n_samples = args.PMLR_n_samples  # number of labeled samples to plot a map from input data space to the latent space
    """ prepare MNIST data """

    train_total_data, train_size, _, _, test_data, test_labels = mnist_data.prepare_MNIST_data()
    n_samples = train_size
    """ build graph """

    # input placeholders
    # In denoising-autoencoder, x_hat == x + noise, otherwise x_hat == x
    x_hat = tf.placeholder(tf.float32, shape=[None, dim_img], name='input_img')
    x = tf.placeholder(tf.float32, shape=[None, dim_img], name='target_img')
    x_id = tf.placeholder(tf.float32, shape=[None, 10], name='input_img_label')

    # dropout
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')

    # input for PMLR
    z_in = tf.placeholder(tf.float32,
                          shape=[None, dim_z],
                          name='latent_variable')

    # samples drawn from prior distribution
    z_sample = tf.placeholder(tf.float32,
                              shape=[None, dim_z],
                              name='prior_sample')
    z_id = tf.placeholder(tf.float32,
                          shape=[None, 10],
                          name='prior_sample_label')

    # network architecture
    y, z, neg_marginal_likelihood, D_loss, G_loss = aae.adversarial_autoencoder(
        x_hat, x, x_id, z_sample, z_id, dim_img, dim_z, n_hidden, keep_prob)

    # optimization
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if "discriminator" in var.name]
    g_vars = [var for var in t_vars if "MLP_encoder" in var.name]
    ae_vars = [
        var for var in t_vars
        if "MLP_encoder" in var.name or "MLP_decoder" in var.name
    ]

    train_op_ae = tf.train.AdamOptimizer(learn_rate).minimize(
        neg_marginal_likelihood, var_list=ae_vars)
    train_op_d = tf.train.AdamOptimizer(learn_rate / 5).minimize(
        D_loss, var_list=d_vars)
    train_op_g = tf.train.AdamOptimizer(learn_rate).minimize(G_loss,
                                                             var_list=g_vars)
    """ training """

    # Plot for reproduce performance
    if PRR:
        PRR = plot_utils.Plot_Reproduce_Performance(RESULTS_DIR, PRR_n_img_x,
                                                    PRR_n_img_y,
                                                    IMAGE_SIZE_MNIST,
                                                    IMAGE_SIZE_MNIST,
                                                    PRR_resize_factor)

        x_PRR = test_data[0:PRR.n_tot_imgs, :]

        x_PRR_img = x_PRR.reshape(PRR.n_tot_imgs, IMAGE_SIZE_MNIST,
                                  IMAGE_SIZE_MNIST)
        PRR.save_images(x_PRR_img, name='input.jpg')

    # Plot for manifold learning result
    if PMLR and dim_z == 2:

        PMLR = plot_utils.Plot_Manifold_Learning_Result(
            RESULTS_DIR, PMLR_n_img_x, PMLR_n_img_y, IMAGE_SIZE_MNIST,
            IMAGE_SIZE_MNIST, PMLR_resize_factor, PMLR_z_range)

        x_PMLR = test_data[0:PMLR_n_samples, :]
        id_PMLR = test_labels[0:PMLR_n_samples, :]

        decoded = aae.decoder(z_in, dim_img, n_hidden)

    # train
    total_batch = int(n_samples / batch_size)
    min_tot_loss = 1e99

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer(), feed_dict={keep_prob: 0.9})

        for epoch in range(n_epochs):

            # Random shuffling
            np.random.shuffle(train_total_data)
            train_data_ = train_total_data[:, :-mnist_data.NUM_LABELS]
            train_label_ = train_total_data[:, -mnist_data.NUM_LABELS:]

            # Loop over all batches
            for i in range(total_batch):
                # Compute the offset of the current minibatch in the data.
                offset = (i * batch_size) % (n_samples)
                batch_xs_input = train_data_[offset:(offset + batch_size), :]
                batch_ids_input = train_label_[offset:(offset + batch_size), :]
                batch_xs_target = batch_xs_input

                # draw samples from prior distribution
                if args.prior_type == 'mixGaussian':
                    z_id_ = np.random.randint(0, 10, size=[batch_size])
                    samples = prior.gaussian_mixture(batch_size,
                                                     dim_z,
                                                     label_indices=z_id_)
                elif args.prior_type == 'swiss_roll':
                    z_id_ = np.random.randint(0, 10, size=[batch_size])
                    samples = prior.swiss_roll(batch_size,
                                               dim_z,
                                               label_indices=z_id_)
                elif args.prior_type == 'normal':
                    samples, z_id_ = prior.gaussian(batch_size,
                                                    dim_z,
                                                    use_label_info=True)
                else:
                    raise Exception("[!] There is no option for " +
                                    args.prior_type)

                z_id_one_hot_vector = np.zeros((batch_size, 10))
                z_id_one_hot_vector[np.arange(batch_size), z_id_] = 1

                # reconstruction loss
                _, loss_likelihood = sess.run(
                    (train_op_ae, neg_marginal_likelihood),
                    feed_dict={
                        x_hat: batch_xs_input,
                        x: batch_xs_target,
                        x_id: batch_ids_input,
                        z_sample: samples,
                        z_id: z_id_one_hot_vector,
                        keep_prob: 0.9
                    })

                # discriminator loss
                _, d_loss = sess.run(
                    (train_op_d, D_loss),
                    feed_dict={
                        x_hat: batch_xs_input,
                        x: batch_xs_target,
                        x_id: batch_ids_input,
                        z_sample: samples,
                        z_id: z_id_one_hot_vector,
                        keep_prob: 0.9
                    })

                # generator loss
                for _ in range(2):
                    _, g_loss = sess.run(
                        (train_op_g, G_loss),
                        feed_dict={
                            x_hat: batch_xs_input,
                            x: batch_xs_target,
                            x_id: batch_ids_input,
                            z_sample: samples,
                            z_id: z_id_one_hot_vector,
                            keep_prob: 0.9
                        })

            tot_loss = loss_likelihood + d_loss + g_loss

            # print cost every epoch
            print(
                "epoch %d: L_tot %03.2f L_likelihood %03.2f d_loss %03.2f g_loss %03.2f"
                % (epoch, tot_loss, loss_likelihood, d_loss, g_loss))

            # if minimum loss is updated or final epoch, plot results
            if epoch % 2 == 0 or min_tot_loss > tot_loss or epoch + 1 == n_epochs:
                min_tot_loss = min(min_tot_loss, tot_loss)
                # Plot for reproduce performance
                if PRR:
                    y_PRR = sess.run(y, feed_dict={x_hat: x_PRR, keep_prob: 1})
                    y_PRR_img = y_PRR.reshape(PRR.n_tot_imgs, IMAGE_SIZE_MNIST,
                                              IMAGE_SIZE_MNIST)
                    PRR.save_images(y_PRR_img,
                                    name="/PRR_epoch_%02d" % (epoch) + ".jpg")

                # Plot for manifold learning result
                if PMLR and dim_z == 2:
                    y_PMLR = sess.run(decoded,
                                      feed_dict={
                                          z_in: PMLR.z,
                                          keep_prob: 1
                                      })
                    y_PMLR_img = y_PMLR.reshape(PMLR.n_tot_imgs,
                                                IMAGE_SIZE_MNIST,
                                                IMAGE_SIZE_MNIST)
                    PMLR.save_images(y_PMLR_img,
                                     name="/PMLR_epoch_%02d" % (epoch) +
                                     ".jpg")

                    # plot distribution of labeled images
                    z_PMLR = sess.run(z,
                                      feed_dict={
                                          x_hat: x_PMLR,
                                          keep_prob: 1
                                      })
                    PMLR.save_scattered_image(z_PMLR,
                                              id_PMLR,
                                              name="/PMLR_map_epoch_%02d" %
                                              (epoch) + ".jpg")
Example #12
def main(args):

    np.random.seed(1337)
    """ parameters """
    RESULTS_DIR = args.results_path

    # network architecture
    n_hidden = args.n_hidden

    # train
    n_epochs = args.num_epochs
    batch_size = args.batch_size
    learn_rate = args.learn_rate

    # Plot
    PRR = args.PRR  # Plot Reproduce Result
    PRR_n_img_x = args.PRR_n_img_x  # number of images along x-axis in a canvas
    PRR_n_img_y = args.PRR_n_img_y  # number of images along y-axis in a canvas
    PRR_resize_factor = args.PRR_resize_factor  # resize factor for each image in a canvas

    PMLR = args.PMLR  # Plot Manifold Learning Result
    PMLR_n_img_x = args.PMLR_n_img_x  # number of images along x-axis in a canvas
    PMLR_n_img_y = args.PMLR_n_img_y  # number of images along y-axis in a canvas
    PMLR_resize_factor = args.PMLR_resize_factor  # resize factor for each image in a canvas
    PMLR_z_range = args.PMLR_z_range  # range for random latent vector
    PMLR_n_samples = args.PMLR_n_samples  # number of labeled samples to plot a map from input data space to the latent space
    """ prepare MNIST data """
    '''

    esense_files = [
                    "AAU_livingLab4_202481591532165_1541682359",
                    "fabio_1-202481588431654_1541691060", 
                    "alemino_ZRH_202481601716927_1541691041",
                    "IMDEA_wideband_202481598624002_1541682492"
                    ]
    esense_folder = "./datadumps/esense_data_jan2019/"
    #train_data, train_labels, test_data, test_labels, bw_labels, pos_labels = spec_data.gendata()
    for ei,efile in enumerate(esense_files):
        print(efile)
        if ei==0:
            train_data, train_labels,_ = esense_seqload.gendata(esense_folder+efile)
        else:
            dtrain_data, dtrain_labels,_ = esense_seqload.gendata(esense_folder+efile)
            train_data = np.vstack((train_data,dtrain_data))
            train_labels = np.vstack((train_labels,dtrain_labels))
    '''
    #train_data, train_labels, _,_,_,_,_ = synthetic_data.gendata()
    train_data, train_labels, _, _, _ = hackrf_data.gendata(
        "./datadumps/sample_hackrf_data.csv")
    #train_data, train_labels = rawdata.gendata()
    #Split the data
    train_data, train_labels = shuffle_in_unison_inplace(
        train_data, train_labels)
    splitval = int(train_data.shape[0] * 0.5)
    test_data = train_data[:splitval]
    test_labels = train_labels[:splitval]
    train_data = train_data[splitval:]
    train_labels = train_labels[splitval:]
    #Semsup splitting
    splitval = int(train_data.shape[0] * 0.2)
    train_data_sup = train_data[:splitval]
    train_data = train_data[splitval:]
    train_labels_sup = train_labels[:splitval]
    train_labels = train_labels[splitval:]
    n_samples = train_data.shape[0]
    tsamples = train_data.shape[1]
    fsamples = train_data.shape[2]
    dim_img = [tsamples, fsamples]
    nlabels = train_labels.shape[1]
    print(nlabels)

    encoder = "CNN"
    #encoder="LSTM"
    dim_z = args.dimz  # to visualize learned manifold
    enable_sel = False
    """ build graph """

    # input placeholders
    x_hat = tf.placeholder(tf.float32,
                           shape=[None, tsamples, fsamples],
                           name='input_img')
    x = tf.placeholder(tf.float32,
                       shape=[None, tsamples, fsamples],
                       name='target_img')
    x_id = tf.placeholder(tf.float32,
                          shape=[None, nlabels],
                          name='input_img_label')

    # dropout
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')

    # input for PMLR
    z_in = tf.placeholder(tf.float32,
                          shape=[None, dim_z],
                          name='latent_variable')

    # samples drawn from prior distribution
    z_sample = tf.placeholder(tf.float32,
                              shape=[None, dim_z],
                              name='prior_sample')
    cat_sample = tf.placeholder(tf.float32,
                                shape=[None, nlabels],
                                name='prior_sample_label')

    # network architecture
    #y, z, neg_marginal_likelihood, D_loss, G_loss = aae.adversarial_autoencoder(x_hat, x, x_id, z_sample, z_id, dim_img,
    #                                                                            dim_z, n_hidden, keep_prob)
    y, z, neg_marginal_likelihood, D_loss, G_loss, cat_gen_loss, cat = spec_aae.adversarial_autoencoder_semsup_cat_nodimred(
        x_hat,
        x,
        x_id,
        z_sample,
        cat_sample,
        dim_img,
        dim_z,
        n_hidden,
        keep_prob,
        nlabels=nlabels,
        vdim=2)

    # optimization
    t_vars = tf.trainable_variables()
    d_vars = [
        var for var in t_vars
        if "discriminator" in var.name or "discriminator_cat" in var.name
    ]
    g_vars = [var for var in t_vars if encoder + "_encoder_cat" in var.name]
    ae_vars = [
        var for var in t_vars
        if encoder + "_encoder_cat" in var.name or "CNN_decoder" in var.name
    ]

    train_op_ae = tf.train.AdamOptimizer(learn_rate).minimize(
        neg_marginal_likelihood, var_list=ae_vars)
    train_op_d = tf.train.AdamOptimizer(learn_rate / 2.0).minimize(
        D_loss, var_list=d_vars)
    train_op_g = tf.train.AdamOptimizer(learn_rate).minimize(G_loss,
                                                             var_list=g_vars)
    train_op_cat = tf.train.AdamOptimizer(learn_rate).minimize(cat_gen_loss,
                                                               var_list=g_vars)
    """ training """

    # Plot for reproduce performance
    if PRR:
        PRR = plot_utils.Plot_Reproduce_Performance(RESULTS_DIR, PRR_n_img_x,
                                                    PRR_n_img_y, tsamples,
                                                    fsamples,
                                                    PRR_resize_factor)

        x_PRR = test_data[0:PRR.n_tot_imgs, :]

        x_PRR_img = x_PRR.reshape(PRR.n_tot_imgs, tsamples, fsamples)
        PRR.save_images(x_PRR_img, name='input.jpg')

    # Plot for manifold learning result
    if PMLR and dim_z == 2:

        PMLR = plot_utils.Plot_Manifold_Learning_Result(
            RESULTS_DIR, PMLR_n_img_x, PMLR_n_img_y, tsamples, fsamples,
            PMLR_resize_factor, PMLR_z_range)

        x_PMLR = test_data[0:PMLR_n_samples, :]
        id_PMLR = test_labels[0:PMLR_n_samples, :]

        decoded = spec_aae.decoder(z_in, dim_img, n_hidden)
    else:
        x_PMLR = test_data[0:PMLR_n_samples, :]
        id_PMLR = test_labels[0:PMLR_n_samples, :]
        z_in = tf.placeholder(tf.float32,
                              shape=[None, dim_z],
                              name='latent_variable')

    # train
    total_batch = int(n_samples / batch_size)
    min_tot_loss = 1e99
    prev_loss = 1e99

    saver = tf.train.Saver()
    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer(), feed_dict={keep_prob: 0.9})

        for epoch in range(n_epochs):

            # Random shuffling
            train_data_, train_label_ = shuffle_in_unison_inplace(
                train_data, train_labels)
            train_data_sup_, train_labels_sup_ = shuffle_in_unison_inplace(
                train_data_sup, train_labels_sup)

            # Loop over all batches
            for i in range(total_batch):
                # Compute the offset of the current minibatch in the data.
                offset = (i * batch_size) % (n_samples)
                offset_sup = (i * batch_size) % (train_data_sup.shape[0])
                batch_xs_input = train_data_[offset:(offset + batch_size), :]
                batch_ids_input = train_label_[offset:(offset + batch_size), :]
                batch_xs_sup_input = train_data_sup_[offset_sup:(
                    offset_sup + batch_size), :]
                batch_ids_sup_input = train_labels_sup_[offset_sup:(
                    offset_sup + batch_size), :]
                batch_xs_target = batch_xs_input
                batch_xs_sup_target = batch_xs_sup_input

                # draw samples from prior distribution
                if dim_z > 2:
                    if enable_sel:
                        if args.prior_type == 'mixGaussian':
                            z_id_ = np.random.randint(0,
                                                      nlabels,
                                                      size=[batch_size])
                            samples = np.zeros((batch_size, dim_z))
                            for el in range(dim_z // 2):
                                samples_ = prior.gaussian_mixture(
                                    batch_size,
                                    2,
                                    n_labels=nlabels,
                                    label_indices=z_id_,
                                    y_var=(1.0 / nlabels))
                                samples[:, el * 2:(el + 1) * 2] = samples_
                        elif args.prior_type == 'swiss_roll':
                            z_id_ = np.random.randint(0,
                                                      nlabels,
                                                      size=[batch_size])
                            samples = np.zeros((batch_size, dim_z))
                            for el in range(dim_z // 2):
                                samples_ = prior.swiss_roll(
                                    batch_size, 2, label_indices=z_id_)
                                samples[:, el * 2:(el + 1) * 2] = samples_
                        elif args.prior_type == 'normal':
                            samples, z_id_ = prior.gaussian(
                                batch_size,
                                dim_z,
                                n_labels=nlabels,
                                use_label_info=True)
                        else:
                            raise Exception("[!] There is no option for " +
                                            args.prior_type)
                    else:
                        z_id_ = np.random.randint(0,
                                                  nlabels,
                                                  size=[batch_size])
                        samples = np.random.normal(
                            0.0, 1, (batch_size, dim_z)).astype(np.float32)
                else:
                    if args.prior_type == 'mixGaussian':
                        z_id_ = np.random.randint(0,
                                                  nlabels,
                                                  size=[batch_size])
                        samples = prior.gaussian_mixture(batch_size,
                                                         dim_z,
                                                         n_labels=nlabels,
                                                         label_indices=z_id_,
                                                         y_var=(1.0 / nlabels))
                    elif args.prior_type == 'swiss_roll':
                        z_id_ = np.random.randint(0,
                                                  nlabels,
                                                  size=[batch_size])
                        samples = prior.swiss_roll(batch_size,
                                                   dim_z,
                                                   label_indices=z_id_)
                    elif args.prior_type == 'normal':
                        samples, z_id_ = prior.gaussian(batch_size,
                                                        dim_z,
                                                        n_labels=nlabels,
                                                        use_label_info=True)
                    else:
                        raise Exception("[!] There is no option for " +
                                        args.prior_type)

                z_id_one_hot_vector = np.zeros((batch_size, nlabels))
                z_id_one_hot_vector[np.arange(batch_size), z_id_] = 1

                # reconstruction loss
                _, loss_likelihood0 = sess.run(
                    (train_op_ae, neg_marginal_likelihood),
                    feed_dict={
                        x_hat: batch_xs_input,
                        x: batch_xs_target,
                        z_sample: samples,
                        cat_sample: z_id_one_hot_vector,
                        keep_prob: 0.9
                    })

                _, loss_likelihood1 = sess.run(
                    (train_op_ae, neg_marginal_likelihood),
                    feed_dict={
                        x_hat: batch_xs_sup_input,
                        x: batch_xs_sup_target,
                        z_sample: samples,
                        cat_sample: batch_ids_sup_input,
                        keep_prob: 0.9
                    })
                loss_likelihood = loss_likelihood0 + loss_likelihood1
                # discriminator loss
                _, d_loss = sess.run(
                    (train_op_d, D_loss),
                    feed_dict={
                        x_hat: batch_xs_input,
                        x: batch_xs_target,
                        z_sample: samples,
                        cat_sample: z_id_one_hot_vector,
                        keep_prob: 0.9
                    })

                # generator loss
                for _ in range(2):
                    _, g_loss = sess.run(
                        (train_op_g, G_loss),
                        feed_dict={
                            x_hat: batch_xs_input,
                            x: batch_xs_target,
                            z_sample: samples,
                            cat_sample: z_id_one_hot_vector,
                            keep_prob: 0.9
                        })

                    # supervised phase
                    _, cat_loss = sess.run(
                        (train_op_cat, cat_gen_loss),
                        feed_dict={
                            x_hat: batch_xs_sup_input,
                            x: batch_xs_sup_target,
                            x_id: batch_ids_sup_input,
                            keep_prob: 0.9
                        })

            tot_loss = loss_likelihood + d_loss + g_loss + cat_loss

            # print cost every epoch
            print(
                "epoch %d: L_tot %03.2f L_likelihood %03.4f d_loss %03.2f g_loss %03.2f "
                % (epoch, tot_loss, loss_likelihood, d_loss, g_loss))

            #for v in sess.graph.get_operations():
            #    print(v.name)
            # if minimum loss is updated or final epoch, plot results
            if epoch % 2 == 0 or min_tot_loss > tot_loss or epoch + 1 == n_epochs:
                min_tot_loss = min(min_tot_loss, tot_loss)
                # Plot for reproduce performance
                if PRR:
                    y_PRR = sess.run(y, feed_dict={x_hat: x_PRR, keep_prob: 1})
                    save_subimages([x_PRR[:10], y_PRR[:10]],
                                   "./results/Reco_%02d" % (epoch))
                    #y_PRR_img = y_PRR.reshape(PRR.n_tot_imgs, tsamples, fsamples)
                    #PRR.save_images(y_PRR_img, name="/PRR_epoch_%02d" %(epoch) + ".jpg")

                # Plot for manifold learning result
                if PMLR and dim_z == 2:
                    y_PMLR = sess.run(decoded,
                                      feed_dict={
                                          z_in: PMLR.z,
                                          keep_prob: 1
                                      })
                    y_PMLR_img = y_PMLR.reshape(PMLR_n_img_x, PMLR_n_img_x,
                                                tsamples, fsamples)
                    save_subimages(y_PMLR_img, "./results/Mani_%02d" % (epoch))
                    #y_PMLR_img = y_PMLR.reshape(PMLR.n_tot_imgs, fsamples, tsamples)
                    #PMLR.save_images(y_PMLR_img, name="/PMLR_epoch_%02d" % (epoch) + ".jpg")

                    # plot distribution of labeled images
                    z_PMLR = sess.run(z,
                                      feed_dict={
                                          x_hat: x_PMLR,
                                          keep_prob: 1
                                      })
                    PMLR.save_scattered_image(z_PMLR,
                                              id_PMLR,
                                              name="/PMLR_map_epoch_%02d" %
                                              (epoch) + ".jpg",
                                              N=nlabels)
                else:
                    retcat, test_cat_loss, test_ll = sess.run(
                        (cat, cat_gen_loss, neg_marginal_likelihood),
                        feed_dict={
                            x_hat: x_PMLR,
                            x_id: id_PMLR,
                            x: x_PMLR,
                            keep_prob: 1
                        })
                    print(
                        "Accuracy: ", 100.0 *
                        np.sum(np.argmax(retcat, 1) == np.argmax(id_PMLR, 1)) /
                        retcat.shape[0], test_cat_loss, test_ll)
                    save_loss = test_cat_loss + test_ll
                    if prev_loss > save_loss and (epoch % 100
                                                  == 0):  # and epoch!=0:
                        prev_loss = save_loss
                        #save_graph(sess,"./savedmodels/","saved_checkpoint","checkpoint_state","input_graph.pb","output_graph.pb",encoder+"_encoder_cat/zout/BiasAdd,"+encoder+"_encoder_cat/catout/Softmax,CNN_decoder/reshaped/Reshape,discriminator_cat_1/add_2,discriminator_1/add_2")
                        save_path = saver.save(
                            sess, "./savedmodels_allsensors/allsensors.ckpt")
                        tf.train.write_graph(sess.graph_def,
                                             "./savedmodels_allsensors/",
                                             "allsensors.pb",
                                             as_text=False)
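Example #12 shuffles data and labels with shuffle_in_unison_inplace, which is not shown. A sketch, assuming it applies one random permutation to both arrays along the first axis (despite the name, it is used here as if it returns the shuffled pair):

import numpy as np

def shuffle_in_unison_inplace(a, b):
    # Apply the same random permutation to both arrays so that
    # sample/label pairs stay aligned after shuffling.
    assert len(a) == len(b)
    p = np.random.permutation(len(a))
    return a[p], b[p]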
Example #13
    def visualize_results(self, epoch):
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
        """ random condition, random noise """

        if self.test_input_type == 'dataset':
            samples = self.sess.run(
                self.out,
                feed_dict={self.inputs: self.data_X[:tot_num_samples]})
        elif self.test_input_type == 'noise':
            z_sample = prior.gaussian(self.batch_size, self.z_dim)

            samples = self.sess.run(self.fake_images,
                                    feed_dict={self.z: z_sample})

        save_images(
            samples[:image_frame_dim * image_frame_dim, :, :, :],
            [image_frame_dim, image_frame_dim],
            check_folder(self.result_dir + '/' + self.model_dir) + '/' +
            self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')

        if self.z_dim == 2:
            """ learned manifold """

            z_tot = None
            id_tot = None
            for idx in range(0, 100):
                # randomly sampling
                id = np.random.randint(0, self.num_batches)
                batch_images = self.data_X[id * self.batch_size:(id + 1) *
                                           self.batch_size]
                batch_labels = self.data_y[id * self.batch_size:(id + 1) *
                                           self.batch_size]

                z = self.sess.run(self.mu,
                                  feed_dict={self.inputs: batch_images})

                if idx == 0:
                    z_tot = z
                    id_tot = batch_labels
                else:
                    z_tot = np.concatenate((z_tot, z), axis=0)
                    id_tot = np.concatenate((id_tot, batch_labels), axis=0)

            save_scattered_image(
                z_tot,
                id_tot,
                -4,
                4,
                name=check_folder(self.result_dir + '/' + self.model_dir) +
                '/' + self.model_name + '_epoch%03d' % epoch +
                '_learned_manifold.png')
        else:
            """ N dimensional manifold"""
            z_tot = None
            id_tot = None
            for idx in range(0, 100):
                # randomly sampling
                id = np.random.randint(0, self.num_batches)
                batch_images = self.data_X[id * self.batch_size:(id + 1) *
                                           self.batch_size]
                batch_labels = self.data_y[id * self.batch_size:(id + 1) *
                                           self.batch_size]

                z = self.sess.run(self.mu,
                                  feed_dict={self.inputs: batch_images})

                if idx == 0:
                    z_tot = z
                    id_tot = batch_labels
                else:
                    z_tot = np.concatenate((z_tot, z), axis=0)
                    id_tot = np.concatenate((id_tot, batch_labels), axis=0)

            reduced_z = PCA(n_components=2).fit_transform(z_tot)
            save_scattered_image(
                reduced_z,
                id_tot,
                -4,
                4,
                name=check_folder(self.result_dir + '/' + self.model_dir) +
                '/' + self.model_name + '_epoch%03d' % epoch +
                '_learned_manifold.png')

        pickle.dump(
            z_tot,
            open(
                check_folder(self.result_dir + '/' + self.model_dir) + '/' +
                self.model_name + '_epoch%03d' % epoch + '_z_tot.p', 'wb'))
        pickle.dump(
            id_tot,
            open(
                check_folder(self.result_dir + '/' + self.model_dir) + '/' +
                self.model_name + '_epoch%03d' % epoch + '_id_tot.p', 'wb'))
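The dumped latent codes can be reloaded later for offline analysis; a hypothetical snippet (the directory, model name, and epoch below are placeholders matching the naming pattern above):

import pickle

# file names follow <result_dir>/<model_dir>/<model_name>_epoch%03d_{z_tot,id_tot}.p
with open('results/model_dir/model_name_epoch000_z_tot.p', 'rb') as f:
    z_tot = pickle.load(f)
with open('results/model_dir/model_name_epoch000_id_tot.p', 'rb') as f:
    id_tot = pickle.load(f)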