def generate_samples(self, sess, plot_size, manifold=False, file_id=None):
        """Visualize decoder outputs for a grid of latent codes.

        Args:
            sess: session forwarded to ``self.viz_samples``.
            plot_size: number of rows (and columns) of the sample grid.
            manifold (bool): if True, build the codes with
                ``distribution.interpolate`` (``'gaussian'`` prior,
                ``interpolate_range=[-3, 3, -3, 3]``) or, for any other
                prior, with ``distribution.interpolate_gm`` once per mode.
                If False, draw ``plot_size * plot_size`` random codes from
                the prior.
            file_id: forwarded to ``self.viz_samples``. For the non-Gaussian
                prior in manifold mode, each mode's call receives
                ``'<file_id>_<mode_id>'`` instead.
        """
        grid_count = plot_size * plot_size

        # With labels enabled, every row of the grid gets one label; labels
        # cycle through 0 .. n_labels - 1 row by row, and each row
        # contributes plot_size identical float entries. Only the random
        # mixture sampling below consumes them.
        label_indices = None
        if self._use_label:
            label_indices = []
            row = 0
            row_label = -1
            while row < plot_size:
                row_label = 0 if row_label >= self._n_labels - 1 else row_label + 1
                label_indices.extend(np.ones(plot_size) * row_label)
                row += 1

        if not manifold:
            # Random draws from the prior.
            if self._dist == 'gaussian':
                code = distribution.diagonal_gaussian(
                    grid_count, self._g_model.n_code, mean=0, var=1.0)
            else:
                code = distribution.gaussian_mixture(
                    grid_count,
                    n_dim=self._g_model.n_code,
                    n_labels=self._n_labels,
                    x_var=0.5,
                    y_var=0.1,
                    label_indices=label_indices)
            self.viz_samples(sess, code, plot_size, file_id=file_id)
            return

        if self._dist == 'gaussian':
            code = distribution.interpolate(
                plot_size=plot_size, interpolate_range=[-3, 3, -3, 3])
            self.viz_samples(sess, code, plot_size, file_id=file_id)
            return

        # Non-Gaussian prior: one interpolated grid per mixture mode.
        for mode_id in range(self._n_labels):
            code = distribution.interpolate_gm(
                plot_size=plot_size,
                interpolate_range=[-1., 1., -0.2, 0.2],
                mode_id=mode_id,
                n_mode=self._n_labels)
            self.viz_samples(sess,
                             code,
                             plot_size,
                             file_id='{}_{}'.format(file_id, mode_id))
Ejemplo n.º 2
0
    def train_z_gan_epoch(self, sess, ae_dropout=1.0, summary_writer=None,
                          n_labels=10):
        """Run one epoch of adversarial-autoencoder training.

        The epoch lasts until ``self._train_data.epochs_completed`` changes.
        Each step draws one batch and runs, in order: the autoencoder update
        (``self._train_op``), the latent discriminator update
        (``self._train_d_op``) and the latent generator update
        (``self._train_g_op``).

        Args:
            sess: session used to run the training ops.
            ae_dropout (float): value fed to ``keep_prob`` for the autoencoder
                update; the discriminator/generator updates always feed 1.0.
            summary_writer: forwarded to ``display``.
            n_labels (int): number of mixture components of the ``'gmm'``
                prior sample. Defaults to 10, the value previously
                hard-coded here. NOTE(review): this should presumably match
                the number of modes used when sampling from the trained
                model (``self._n_labels`` in ``generate_samples``) — confirm.
        """
        self._t_model.set_is_training(True)
        display_name_list = ['loss', 'd_loss', 'g_loss']
        cur_summary = None

        # Step decay: divide the learning rate by 10 when epoch_id is 100 or
        # 300. epoch_id is checked here, before it is incremented below.
        if self.epoch_id in (100, 300):
            self._lr = self._lr / 10

        cur_epoch = self._train_data.epochs_completed

        step = 0
        loss_sum = 0
        d_loss_sum = 0
        g_loss_sum = 0
        self.epoch_id += 1
        while cur_epoch == self._train_data.epochs_completed:
            self.global_step += 1
            step += 1

            batch_data = self._train_data.next_batch_dict()
            im = batch_data['im']
            label = batch_data['label']

            # Sample from the prior the discriminator treats as "real".
            if self._dist == 'gmm':
                # With self._use_label, the batch labels are passed through
                # as label_indices to the mixture sampler.
                label_indices = label if self._use_label else None
                real_sample = distribution.gaussian_mixture(
                    len(im),
                    n_dim=self._t_model.n_code,
                    n_labels=n_labels,
                    x_var=0.5,
                    y_var=0.1,
                    label_indices=label_indices)
            else:
                real_sample = distribution.diagonal_gaussian(
                    len(im), self._t_model.n_code, mean=0, var=1.0)

            # train autoencoder
            _, loss, cur_summary = sess.run(
                [self._train_op, self._loss_op, self._train_summary_op],
                feed_dict={
                    self._t_model.image: im,
                    self._t_model.lr: self._lr,
                    self._t_model.keep_prob: ae_dropout,
                    self._t_model.label: label,
                    self._t_model.real_distribution: real_sample
                })

            # train discriminator
            _, d_loss = sess.run(
                [self._train_d_op, self._d_loss_op],
                feed_dict={
                    self._t_model.image: im,
                    self._t_model.label: label,
                    self._t_model.lr: self._lr,
                    self._t_model.keep_prob: 1.,
                    self._t_model.real_distribution: real_sample
                })

            # train generator
            _, g_loss = sess.run(
                [self._train_g_op, self._g_loss_op],
                feed_dict={
                    self._t_model.image: im,
                    self._t_model.label: label,
                    self._t_model.lr: self._lr,
                    self._t_model.keep_prob: 1.
                })

            loss_sum += loss
            d_loss_sum += d_loss
            g_loss_sum += g_loss

            if step % 100 == 0:
                display(self.global_step,
                        step, [loss_sum, d_loss_sum, g_loss_sum],
                        display_name_list,
                        'train',
                        summary_val=cur_summary,
                        summary_writer=summary_writer)

        print('==== epoch: {}, lr:{} ===='.format(cur_epoch, self._lr))
        display(self.global_step,
                step, [loss_sum, d_loss_sum, g_loss_sum],
                display_name_list,
                'train',
                summary_val=cur_summary,
                summary_writer=summary_writer)
Ejemplo n.º 3
0
    def train_semisupervised_epoch(self,
                                   sess,
                                   ae_dropout=1.0,
                                   summary_writer=None):
        """Run one epoch of semi-supervised adversarial-autoencoder training.

        The epoch is measured on ``self._train_data['unlabeled']``. Every
        step runs, on one unlabeled batch, the following updates:

        - the autoencoder update;
        - the latent-code (z) discriminator and generator updates;
        - the categorical (y) discriminator and generator updates.

        Every ``cls_interval``-th global step, one labeled batch is drawn
        from ``self._train_data['labeled']`` and the classifier update runs
        on it.

        Args:
            sess: session used to run the training ops.
            ae_dropout (float): value fed to ``keep_prob`` for the autoencoder
                update; all other updates feed 1.0.
            summary_writer: forwarded to ``display``.
        """
        # Run the classifier on one global step out of this many.
        cls_interval = 10

        label_data = self._train_data['labeled']
        unlabel_data = self._train_data['unlabeled']
        display_name_list = [
            'loss', 'z_d_loss', 'z_g_loss', 'y_d_loss', 'y_g_loss', 'cls_loss',
            'cls_accuracy'
        ]
        cur_summary = None
        cur_epoch = unlabel_data.epochs_completed

        self.epoch_id += 1

        # Step decay: divide the learning rate by 10 when epoch_id (already
        # incremented above) reaches 150 or 200.
        if self.epoch_id in (150, 200):
            self._lr = self._lr / 10

        step = 0
        loss_sum = 0
        z_d_loss_sum = 0
        z_g_loss_sum = 0
        y_d_loss_sum = 0
        y_g_loss_sum = 0
        cls_loss_sum = 0
        cls_accuracy_sum = 0

        def summary_values():
            # The classifier runs on roughly one step in cls_interval, so its
            # sums are scaled by cls_interval before being reported next to
            # the per-step losses. NOTE(review): assumes display() averages
            # each value over `step` — confirm. Shared by the periodic and
            # end-of-epoch displays so both report all seven metrics.
            return [
                loss_sum, z_d_loss_sum, z_g_loss_sum, y_d_loss_sum,
                y_g_loss_sum, cls_loss_sum * cls_interval,
                cls_accuracy_sum * cls_interval
            ]

        while cur_epoch == unlabel_data.epochs_completed:
            self.global_step += 1
            step += 1

            batch_data = unlabel_data.next_batch_dict()
            im = batch_data['im']
            label = batch_data['label']

            z_real_sample = distribution.diagonal_gaussian(
                len(im), self._t_model.n_code, mean=0, var=1.0)
            # len(im) integer class ids drawn uniformly from range(n_class).
            y_real_sample = np.random.choice(self._t_model.n_class, len(im))

            # train autoencoder
            _, loss, cur_summary = sess.run(
                [self._train_op, self._loss_op, self._train_summary_op],
                feed_dict={
                    self._t_model.image: im,
                    self._t_model.lr: self._lr,
                    self._t_model.keep_prob: ae_dropout,
                    self._t_model.label: label,
                    self._t_model.real_distribution: z_real_sample,
                    self._t_model.real_y: y_real_sample
                })

            # z discriminator
            _, z_d_loss = sess.run(
                [self._train_d_op, self._d_loss_op],
                feed_dict={
                    self._t_model.image: im,
                    self._t_model.lr: self._lr,
                    self._t_model.keep_prob: 1.,
                    self._t_model.real_distribution: z_real_sample
                })

            # z generator
            _, z_g_loss = sess.run(
                [self._train_g_op, self._g_loss_op],
                feed_dict={
                    self._t_model.image: im,
                    self._t_model.lr: self._lr,
                    self._t_model.keep_prob: 1.
                })

            # y discriminator
            _, y_d_loss = sess.run(
                [self._train_cat_d_op, self._cat_d_loss_op],
                feed_dict={
                    self._t_model.image: im,
                    self._t_model.lr: self._lr,
                    self._t_model.keep_prob: 1.,
                    self._t_model.real_y: y_real_sample
                })

            # y generator
            _, y_g_loss = sess.run(
                [self._train_cat_g_op, self._cat_g_loss_op],
                feed_dict={
                    self._t_model.image: im,
                    self._t_model.lr: self._lr,
                    self._t_model.keep_prob: 1.
                })

            # semi-supervised classifier step. The labeled batch is drawn
            # only when it is used; previously one was drawn every step and
            # 9 out of 10 were discarded.
            if self.global_step % cls_interval == 0:
                labeled_batch = label_data.next_batch_dict()
                _, cls_loss, cls_accuracy = sess.run(
                    [
                        self._cls_train_op, self._cls_loss_op,
                        self._cls_accuracy_op
                    ],
                    feed_dict={
                        self._t_model.image: labeled_batch['im'],
                        self._t_model.label: labeled_batch['label'],
                        self._t_model.lr: self._lr,
                        self._t_model.keep_prob: 1.
                    })
                cls_loss_sum += cls_loss
                cls_accuracy_sum += cls_accuracy

            loss_sum += loss
            z_d_loss_sum += z_d_loss
            z_g_loss_sum += z_g_loss
            y_d_loss_sum += y_d_loss
            y_g_loss_sum += y_g_loss

            if step % 100 == 0:
                display(
                    self.global_step,
                    step,
                    summary_values(),
                    display_name_list,
                    'train',
                    summary_val=cur_summary,
                    summary_writer=summary_writer)

        print('==== epoch: {}, lr:{} ===='.format(cur_epoch, self._lr))
        display(
            self.global_step,
            step,
            summary_values(),
            display_name_list,
            'train',
            summary_val=cur_summary,
            summary_writer=summary_writer)