def generate_samples(self, sess, plot_size, manifold=False, file_id=None):
        """Generate and visualize samples from the latent space.

        Args:
            sess: TensorFlow session used to run the sampling/visualization ops.
            plot_size (int): samples are arranged on a plot_size x plot_size grid.
            manifold (bool): if True, walk the latent space on a regular
                interpolation grid instead of drawing random codes.
            file_id: identifier appended to the output file name(s); for the
                GMM manifold case one file per mixture mode is written as
                '{file_id}_{mode_id}'.
        """
        n_samples = plot_size * plot_size

        # When labels are used, assign one label per grid row, cycling
        # through the label set: row r gets label r % n_labels.
        label_indices = None
        if self._use_label:
            label_indices = []
            for row in range(plot_size):
                # np.ones * label keeps the float dtype the downstream
                # distribution helpers originally received.
                label_indices.extend(np.ones(plot_size) * (row % self._n_labels))

        if manifold:
            if self._dist == 'gaussian':
                random_code = distribution.interpolate(
                    plot_size=plot_size, interpolate_range=[-3, 3, -3, 3])
                self.viz_samples(sess, random_code, plot_size, file_id=file_id)
            else:
                # One interpolation plot per mixture component.
                for mode_id in range(self._n_labels):
                    random_code = distribution.interpolate_gm(
                        plot_size=plot_size,
                        interpolate_range=[-1., 1., -0.2, 0.2],
                        mode_id=mode_id,
                        n_mode=self._n_labels)
                    self.viz_samples(sess,
                                     random_code,
                                     plot_size,
                                     file_id='{}_{}'.format(file_id, mode_id))
        else:
            if self._dist == 'gaussian':
                random_code = distribution.diagonal_gaussian(
                    n_samples, self._g_model.n_code, mean=0, var=1.0)
            else:
                random_code = distribution.gaussian_mixture(
                    n_samples,
                    n_dim=self._g_model.n_code,
                    n_labels=self._n_labels,
                    x_var=0.5,
                    y_var=0.1,
                    label_indices=label_indices)

            self.viz_samples(sess, random_code, plot_size, file_id=file_id)
# (removed: "Ejemplo n.º 2" / "0" — scraping/extraction artifact, not part of the source)
    def train_z_gan_epoch(self, sess, ae_dropout=1.0, summary_writer=None):
        """Run one epoch of adversarial-autoencoder training.

        Each step performs three updates in order: the autoencoder
        (reconstruction), the discriminator (real prior samples vs. encoded
        codes), and the generator/encoder (against the discriminator).

        Args:
            sess: TensorFlow session.
            ae_dropout (float): keep probability for the autoencoder update;
                the discriminator and generator updates always use 1.0.
            summary_writer: optional summary writer forwarded to display().
        """
        self._t_model.set_is_training(True)
        display_name_list = ['loss', 'd_loss', 'g_loss']
        cur_summary = None

        # Step-wise learning-rate schedule: decay by 10x at epochs 100 and 300.
        if self.epoch_id == 100:
            self._lr = self._lr / 10
        if self.epoch_id == 300:
            self._lr = self._lr / 10

        cur_epoch = self._train_data.epochs_completed

        step = 0
        loss_sum = 0
        d_loss_sum = 0
        g_loss_sum = 0
        self.epoch_id += 1
        # Loop until the data feeder rolls over to the next epoch.
        while cur_epoch == self._train_data.epochs_completed:
            self.global_step += 1
            step += 1

            batch_data = self._train_data.next_batch_dict()
            im = batch_data['im']
            label = batch_data['label']

            # Condition the prior on the batch labels only when requested.
            label_indices = label if self._use_label else None

            # Draw "real" samples from the target prior for the discriminator.
            if self._dist == 'gmm':
                # NOTE(review): n_labels is hard-coded to 10 here, while
                # generate_samples uses self._n_labels — confirm this is
                # intentional (e.g. MNIST-specific) or unify.
                real_sample = distribution.gaussian_mixture(
                    len(im),
                    n_dim=self._t_model.n_code,
                    n_labels=10,
                    x_var=0.5,
                    y_var=0.1,
                    label_indices=label_indices)
            else:
                real_sample = distribution.diagonal_gaussian(
                    len(im), self._t_model.n_code, mean=0, var=1.0)

            # 1) Autoencoder (reconstruction) update.
            _, loss, cur_summary = sess.run(
                [self._train_op, self._loss_op, self._train_summary_op],
                feed_dict={
                    self._t_model.image: im,
                    self._t_model.lr: self._lr,
                    self._t_model.keep_prob: ae_dropout,
                    self._t_model.label: label,
                    self._t_model.real_distribution: real_sample
                })

            # 2) Discriminator update (no dropout).
            _, d_loss = sess.run(
                [self._train_d_op, self._d_loss_op],
                feed_dict={
                    self._t_model.image: im,
                    self._t_model.label: label,
                    self._t_model.lr: self._lr,
                    self._t_model.keep_prob: 1.,
                    self._t_model.real_distribution: real_sample
                })

            # 3) Generator (encoder) update (no dropout).
            _, g_loss = sess.run(
                [self._train_g_op, self._g_loss_op],
                feed_dict={
                    self._t_model.image: im,
                    self._t_model.label: label,
                    self._t_model.lr: self._lr,
                    self._t_model.keep_prob: 1.
                })

            loss_sum += loss
            d_loss_sum += d_loss
            g_loss_sum += g_loss

            if step % 100 == 0:
                display(self.global_step,
                        step, [loss_sum, d_loss_sum, g_loss_sum],
                        display_name_list,
                        'train',
                        summary_val=cur_summary,
                        summary_writer=summary_writer)

        print('==== epoch: {}, lr:{} ===='.format(cur_epoch, self._lr))
        display(self.global_step,
                step, [loss_sum, d_loss_sum, g_loss_sum],
                display_name_list,
                'train',
                summary_val=cur_summary,
                summary_writer=summary_writer)