Example #1
    def random_sampling(self,
                        sess,
                        keep_prob=1.0,
                        plot_size=10,
                        file_id=None,
                        save_path=None):
        """ Randomly sample a plot_size x plot_size grid of images.

        Args:
            sess (tf.Session): tensorflow session
            keep_prob (float): keep probability for dropout
            plot_size (int): number of samples per side of the saved image
            file_id (int): index for the saved image
            save_path (str): directory for saving image
        """
        n_samples = plot_size * plot_size
        random_vec = distributions.random_vector((n_samples, self.n_code),
                                                 dist_type='uniform')
        code_cont = distributions.random_vector((n_samples, self.n_continuous),
                                                dist_type='uniform')

        if self.n_discrete <= 0:
            code_discrete = np.zeros((n_samples, 0))
        else:
            # Build the discrete code by concatenating one one-hot block
            # per categorical variable.
            code_discrete = []
            for i in range(self.n_discrete):
                n_class = self.cat_n_class_list[i]
                cur_code = [
                    np.random.choice(n_class) for _ in range(n_samples)
                ]
                cur_code = dfutils.vec2onehot(cur_code, n_class)
                try:
                    code_discrete = np.concatenate((code_discrete, cur_code),
                                                   axis=-1)
                except ValueError:
                    # First block: concatenating with the empty list fails,
                    # so start from the current one-hot block.
                    code_discrete = cur_code

        self._viz_samples(sess,
                          random_vec,
                          code_discrete,
                          code_cont,
                          keep_prob,
                          plot_size=[plot_size, plot_size],
                          save_path=save_path,
                          file_name='random_sampling',
                          file_id=file_id)
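
The try/except above grows code_discrete one one-hot block at a time: the first concatenate against the empty list raises ValueError, so the block is assigned directly. A minimal standalone NumPy sketch of the same construction, with a hypothetical vec2onehot built from np.eye (dfutils itself is not shown on this page):

import numpy as np

def vec2onehot_sketch(labels, n_class):
    # One-hot encode an integer label vector: shape (n,) -> (n, n_class).
    return np.eye(n_class)[np.asarray(labels, dtype=int)]

n_samples = 4
cat_n_class_list = [3, 2]  # two categorical codes with 3 and 2 classes
blocks = [vec2onehot_sketch(np.random.choice(n, n_samples), n)
          for n in cat_n_class_list]
code_discrete = np.concatenate(blocks, axis=-1)
print(code_discrete.shape)  # (4, 5)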
Example #2
    def random_sampling(self, sess, plot_size=10, file_id=None):
        """ Randomly sampling from model.

        Randomly sample plot_size * plot_size images from model
        and save as a single image.

        Args:
            sess (tf.Session): tensorflow session
            plot_size (int): side size (number of samples) of saving image
            file_id (int): index for saving image 
        """
        n_samples = plot_size * plot_size
        random_vec = distributions.random_vector(
            (n_samples, self._g_model.in_len), dist_type='uniform')
        # Alternative: Gaussian instead of uniform input noise.
        # random_vec = np.random.normal(
        #     size=(n_samples, self._g_model.in_len))
        if self._save_path:
            self._viz_samples(sess, random_vec, plot_size, file_id=file_id)
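
distributions.random_vector is not defined on this page; from its usage it draws a latent vector of a given shape from the named distribution. A plausible sketch of such a helper (the [-1, 1] uniform range is an assumption, not confirmed by the source):

import numpy as np

def random_vector(shape, dist_type='uniform'):
    # Hypothetical stand-in for distributions.random_vector.
    if dist_type == 'uniform':
        return np.random.uniform(-1., 1., size=shape)
    if dist_type == 'normal':
        return np.random.normal(size=shape)
    raise ValueError('unknown dist_type: {}'.format(dist_type))

z = random_vector((100, 64))  # 100 latent vectors of length 64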
Example #3
    def interp_cont_sampling(self,
                             sess,
                             n_interpolation,
                             cont_code_id,
                             vary_discrete_id=None,
                             n_col_samples=None,
                             keep_prob=1.,
                             save_path=None,
                             file_id=None):
        """ Sample interpolation of one of continuous codes.

        Args:
            sess (tf.Session): tensorflow session
            n_interpolation (int): number of interpolation samples
            cont_code_id (int): index of continuous code for interpolation
            vary_discrete_id (int): Index of discrete code for varying
                while sample interpolation. All the discrete code will be fixed
                if it is None.
            keep_prob (float): keep probability for dropout
            save_path (str): directory for saving image
            file_id (int): index for saving image 
        """
        if cont_code_id >= self.n_continuous:
            return
        if vary_discrete_id is not None and vary_discrete_id < self.n_discrete:
            n_vary_class = self.cat_n_class_list[vary_discrete_id]
            n_samples = n_interpolation * n_vary_class
        elif n_col_samples is not None:
            n_vary_class = n_col_samples
            n_samples = n_interpolation * n_vary_class
        else:
            n_vary_class = 1
            n_samples = n_interpolation

        # Share one base noise vector across all samples so that only
        # the codes vary.
        random_vec = distributions.random_vector((1, self.n_code),
                                                 dist_type='uniform')
        random_vec = np.tile(random_vec, (n_samples, 1))

        if self.n_discrete <= 0:
            code_discrete = np.zeros((n_samples, 0))
        else:
            code_discrete = []
            for i in range(self.n_discrete):
                n_class = self.cat_n_class_list[i]
                if i == vary_discrete_id:
                    # One block of n_interpolation samples per class.
                    cur_code = [
                        k for k in range(n_class)
                        for _ in range(n_interpolation)
                    ]
                else:
                    # Fix this code to a single random class for all samples.
                    cur_code = np.full(n_samples, np.random.choice(n_class))
                cur_onehot_code = dfutils.vec2onehot(cur_code, n_class)
                try:
                    code_discrete = np.concatenate(
                        (code_discrete, cur_onehot_code), axis=-1)
                except ValueError:
                    code_discrete = cur_onehot_code

        if vary_discrete_id is not None and vary_discrete_id < self.n_discrete:
            code_cont = distributions.random_vector((1, self.n_continuous),
                                                    dist_type='uniform')
            code_cont = np.tile(code_cont, (n_samples, 1))
            cont_interp = np.linspace(-1., 1., n_interpolation)
            cont_interp = np.tile(cont_interp, (n_vary_class))
            code_cont[:, cont_code_id] = cont_interp
        else:
            # Use n_vary_class (n_col_samples, or 1 when n_col_samples is
            # None) so this branch cannot be called with a None shape.
            code_cont = distributions.random_vector(
                (n_vary_class, self.n_continuous), dist_type='uniform')
            code_cont = np.repeat(code_cont, n_interpolation, axis=0)
            cont_interp = np.linspace(-1., 1., n_interpolation)
            cont_interp = np.tile(cont_interp, (n_vary_class))
            code_cont[:, cont_code_id] = cont_interp

        self._viz_samples(sess,
                          random_vec,
                          code_discrete,
                          code_cont,
                          keep_prob,
                          plot_size=[n_vary_class, n_interpolation],
                          save_path=save_path,
                          file_name='interp_cont_{}'.format(cont_code_id),
                          file_id=file_id)
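
The repeat/tile calls above lay the samples out as an n_vary_class x n_interpolation grid: each row shares one random continuous code, and the selected dimension sweeps linspace(-1, 1) identically in every row. A self-contained sketch with hypothetical sizes:

import numpy as np

n_vary_class, n_interpolation, n_continuous, cont_code_id = 3, 5, 4, 0

# One random continuous code per row, repeated across the row.
code_cont = np.random.uniform(-1., 1., (n_vary_class, n_continuous))
code_cont = np.repeat(code_cont, n_interpolation, axis=0)

# Overwrite the chosen dimension with the same sweep in every row.
cont_interp = np.tile(np.linspace(-1., 1., n_interpolation), n_vary_class)
code_cont[:, cont_code_id] = cont_interp

print(code_cont.reshape(n_vary_class, n_interpolation, n_continuous)[0, :, 0])
# -> [-1.  -0.5  0.   0.5  1. ]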
Example #4
    def vary_discrete_sampling(self,
                               sess,
                               vary_discrete_id,
                               keep_prob=1.0,
                               sample_per_class=10,
                               file_id=None,
                               save_path=None):
        """ Sampling by varying a discrete code.

        Args:
            sess (tf.Session): tensorflow session
            vary_discrete_id (int): index of discrete code for varying
            keep_prob (float): keep probability for dropout
            sample_per_class (int): number of samples for each class
            file_id (int): index for saving image 
            save_path (str): directory for saving image
        """
        if vary_discrete_id >= self.n_discrete:
            return
        n_vary_class = self.cat_n_class_list[vary_discrete_id]
        n_samples = n_vary_class * sample_per_class

        random_vec = distributions.random_vector((n_samples, self.n_code),
                                                 dist_type='uniform')

        if self.n_discrete <= 0:
            code_discrete = np.zeros((n_samples, 0))
        else:
            code_discrete = []
            for i in range(self.n_discrete):
                n_class = self.cat_n_class_list[i]
                if i == vary_discrete_id:
                    # One block of sample_per_class samples per class row.
                    cur_code = [
                        k for k in range(n_class)
                        for _ in range(sample_per_class)
                    ]
                else:
                    # Draw one code per column and repeat it for every
                    # class row, so only the varied code changes.
                    cur_code = [
                        np.random.choice(n_class)
                        for _ in range(sample_per_class)
                    ]
                    cur_code = np.tile(cur_code, (n_vary_class))
                cur_code = dfutils.vec2onehot(cur_code, n_class)
                try:
                    code_discrete = np.concatenate((code_discrete, cur_code),
                                                   axis=-1)
                except ValueError:
                    code_discrete = cur_code

        code_cont = distributions.random_vector((n_samples, self.n_continuous),
                                                dist_type='uniform')

        self._viz_samples(
            sess,
            random_vec,
            code_discrete,
            code_cont,
            keep_prob,
            plot_size=[n_vary_class, sample_per_class],
            save_path=save_path,
            file_name='vary_discrete_{}'.format(vary_discrete_id),
            file_id=file_id)
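
The ordering of the varied one-hot block is what places class k in row k of the saved grid. A tiny sketch of the two index patterns used above (sizes are hypothetical):

import numpy as np

n_class, sample_per_class = 3, 4

# Varied code: class indices grouped in blocks, one block per plot row.
varied = [k for k in range(n_class) for _ in range(sample_per_class)]
# -> [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]

# Fixed codes: one random draw per column, repeated for every row,
# so only the varied code changes down the grid.
fixed = np.tile(np.random.choice(5, sample_per_class), n_class)
print(varied, fixed.shape)  # fixed.shape == (12,)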
Example #5
    def train_epoch(self,
                    sess,
                    train_data,
                    init_lr,
                    n_g_train=1,
                    n_d_train=1,
                    keep_prob=1.0,
                    summary_writer=None):
        """ Train for one epoch of training data

        Args:
            sess (tf.Session): tensorflow session
            train_data (DataFlow): DataFlow for training set
            init_lr (float): initial learning rate
            n_g_train (int): number of times of generator training for each step
            n_d_train (int): number of times of discriminator training for each step
            keep_prob (float): keep probability for dropout
            summary_writer (tf.FileWriter): write for summary. No summary will be
            saved if None.
        """

        assert int(n_g_train) > 0 and int(n_d_train) > 0
        display_name_list = ['d_loss', 'g_loss', 'LI_G', 'LI_D']
        cur_summary = None

        lr = init_lr
        # Hard-coded per-network learning rates; init_lr is only used in
        # the log line at the end of the epoch.
        lr_D = 2e-4
        lr_G = 1e-3

        cur_epoch = train_data.epochs_completed
        step = 0
        d_loss_sum = 0
        g_loss_sum = 0
        LI_G_sum = 0
        LI_D_sum = 0
        self.epoch_id += 1
        while cur_epoch == train_data.epochs_completed:
            self.global_step += 1
            step += 1

            batch_data = train_data.next_batch_dict()
            im = batch_data['im']

            random_vec = distributions.random_vector((len(im), self.n_code),
                                                     dist_type='uniform')

            if self.n_discrete <= 0:
                code_discrete = np.zeros((len(im), 0))
                discrete_label = np.zeros((len(im), self.n_discrete))
            else:
                code_discrete = []
                discrete_label = []
                for i in range(self.n_discrete):
                    n_class = self.cat_n_class_list[i]
                    cur_code = np.random.choice(n_class, (len(im)))
                    cur_onehot_code = dfutils.vec2onehot(cur_code, n_class)
                    try:
                        code_discrete = np.concatenate(
                            (code_discrete, cur_onehot_code), axis=-1)
                        discrete_label = np.concatenate(
                            (discrete_label, np.expand_dims(cur_code,
                                                            axis=-1)),
                            axis=-1)
                    except ValueError:
                        code_discrete = cur_onehot_code
                        discrete_label = np.expand_dims(cur_code, axis=-1)

            code_cont = distributions.random_vector(
                (len(im), self.n_continuous), dist_type='uniform')

            # train discriminator
            for i in range(int(n_d_train)):

                _, d_loss, LI_D = sess.run(
                    [self.train_d_op, self.d_loss_op, self.LI_D],
                    feed_dict={
                        self.real: im,
                        self.lr: lr_D,
                        self.keep_prob: keep_prob,
                        self.random_vec: random_vec,
                        self.code_discrete: code_discrete,
                        self.discrete_label: discrete_label,
                        self.code_continuous: code_cont
                    })
            # train generator
            for i in range(int(n_g_train)):
                _, g_loss, LI_G = sess.run(
                    [self.train_g_op, self.g_loss_op, self.LI_G],
                    feed_dict={
                        self.lr: lr_G,
                        self.keep_prob: keep_prob,
                        self.random_vec: random_vec,
                        self.code_discrete: code_discrete,
                        self.discrete_label: discrete_label,
                        self.code_continuous: code_cont
                    })

            d_loss_sum += d_loss
            g_loss_sum += g_loss
            LI_G_sum += LI_G
            LI_D_sum += LI_D

            if step % 100 == 0:
                cur_summary = sess.run(self.train_summary_op,
                                       feed_dict={
                                           self.real: im,
                                           self.keep_prob: keep_prob,
                                           self.random_vec: random_vec,
                                           self.code_discrete: code_discrete,
                                           self.discrete_label: discrete_label,
                                           self.code_continuous: code_cont
                                       })

                viz.display(self.global_step,
                            step, [
                                d_loss_sum / n_d_train, g_loss_sum / n_g_train,
                                LI_G_sum, LI_D_sum
                            ],
                            display_name_list,
                            'train',
                            summary_val=cur_summary,
                            summary_writer=summary_writer)

        print('==== epoch: {}, lr:{} ===='.format(cur_epoch, lr))
        cur_summary = sess.run(self.train_summary_op,
                               feed_dict={
                                   self.real: im,
                                   self.keep_prob: keep_prob,
                                   self.random_vec: random_vec,
                                   self.code_discrete: code_discrete,
                                   self.discrete_label: discrete_label,
                                   self.code_continuous: code_cont
                               })
        viz.display(self.global_step,
                    step, [
                        d_loss_sum / n_d_train, g_loss_sum / n_g_train,
                        LI_G_sum, LI_D_sum
                    ],
                    display_name_list,
                    'train',
                    summary_val=cur_summary,
                    summary_writer=summary_writer)
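
Stripped of the TensorFlow plumbing, each step runs n_d_train discriminator updates followed by n_g_train generator updates on the same batch and noise. A framework-free skeleton of that loop (train_d_step and train_g_step are hypothetical stubs, not part of the code above):

def run_step(batch, noise, n_d_train, n_g_train, train_d_step, train_g_step):
    # Update D n_d_train times, then G n_g_train times, reusing the
    # same real batch and latent noise; return the last losses.
    for _ in range(n_d_train):
        d_loss = train_d_step(batch, noise)
    for _ in range(n_g_train):
        g_loss = train_g_step(noise)
    return d_loss, g_loss

# Usage with dummy stubs:
d_loss, g_loss = run_step(
    batch=None, noise=None, n_d_train=2, n_g_train=1,
    train_d_step=lambda b, z: 0.7, train_g_step=lambda z: 1.3)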
Example #6
    def train_epoch(self,
                    sess,
                    n_g_train=1,
                    n_d_train=1,
                    keep_prob=1.0,
                    summary_writer=None):
        """ Train for one epoch of training data

        Args:
            sess (tf.Session): tensorflow session
            n_g_train (int): number of times of generator training for each step
            n_d_train (int): number of times of discriminator training for each step
            keep_prob (float): keep probability for dropout
            summary_writer (tf.FileWriter): write for summary. No summary will be
            saved if None.
        """
        assert int(n_g_train) > 0 and int(n_d_train) > 0
        display_name_list = ['d_loss', 'g_loss']
        cur_summary = None

        # Step schedule: divide the learning rate by 10 at epochs 100 and 300.
        if self.epoch_id == 100:
            self._lr = self._lr / 10
        if self.epoch_id == 300:
            self._lr = self._lr / 10

        cur_epoch = self._train_data.epochs_completed

        step = 0
        d_loss_sum = 0
        g_loss_sum = 0
        self.epoch_id += 1
        while cur_epoch == self._train_data.epochs_completed:
            self.global_step += 1
            step += 1

            batch_data = self._train_data.next_batch_dict()
            im = batch_data['im']

            # train discriminator
            for i in range(int(n_d_train)):
                random_vec = distributions.random_vector(
                    (len(im), self._t_model.in_len), dist_type='uniform')
                _, d_loss = sess.run(
                    [self._train_d_op, self._d_loss_op],
                    feed_dict={
                        self._t_model.real: im,
                        self._t_model.lr: self._lr,
                        self._t_model.keep_prob: keep_prob,
                        self._t_model.random_vec: random_vec
                    })
            # train generator
            for i in range(int(n_g_train)):
                random_vec = distributions.random_vector(
                    (len(im), self._t_model.in_len), dist_type='uniform')
                _, g_loss = sess.run(
                    [self._train_g_op, self._g_loss_op],
                    feed_dict={
                        self._t_model.lr: self._lr,
                        self._t_model.keep_prob: keep_prob,
                        self._t_model.random_vec: random_vec
                    })

            d_loss_sum += d_loss
            g_loss_sum += g_loss

            if step % 100 == 0:
                cur_summary = sess.run(self._train_summary_op,
                                       feed_dict={
                                           self._t_model.real: im,
                                           self._t_model.keep_prob: keep_prob,
                                           self._t_model.random_vec: random_vec
                                       })

                display(self.global_step,
                        step, [d_loss_sum / n_d_train, g_loss_sum / n_g_train],
                        display_name_list,
                        'train',
                        summary_val=cur_summary,
                        summary_writer=summary_writer)

        print('==== epoch: {}, lr:{} ===='.format(cur_epoch, self._lr))
        cur_summary = sess.run(self._train_summary_op,
                               feed_dict={
                                   self._t_model.real: im,
                                   self._t_model.keep_prob: keep_prob,
                                   self._t_model.random_vec: random_vec
                               })
        display(self.global_step,
                step, [d_loss_sum / n_d_train, g_loss_sum / n_g_train],
                display_name_list,
                'train',
                summary_val=cur_summary,
                summary_writer=summary_writer)
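
The two epoch checks at the top implement a step schedule: the learning rate drops by 10x at epochs 100 and 300. The same schedule as a small pure function (the milestones come from the code above; the function form is a sketch):

def step_decay_lr(base_lr, epoch_id, milestones=(100, 300), factor=0.1):
    # Multiply by `factor` once for every milestone already passed.
    lr = base_lr
    for m in milestones:
        if epoch_id >= m:
            lr *= factor
    return lr

print(step_decay_lr(2e-4, 50), step_decay_lr(2e-4, 150), step_decay_lr(2e-4, 350))
# -> roughly 2e-4, 2e-5, 2e-6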
Example #7
    def train_epoch(self,
                    sess,
                    train_data,
                    init_lr,
                    n_g_train=1,
                    n_d_train=1,
                    keep_prob=1.0,
                    summary_writer=None):
        """ Train for one epoch of training data.

        Args:
            sess (tf.Session): tensorflow session
            train_data (DataFlow): DataFlow for training set
            init_lr (float): initial learning rate
            n_g_train (int): number of generator updates per step
            n_d_train (int): number of discriminator updates per step
            keep_prob (float): keep probability for dropout
            summary_writer (tf.FileWriter): writer for summary. No summary
                will be saved if None.
        """

        assert int(n_g_train) > 0 and int(n_d_train) > 0
        display_name_list = ['d_loss', 'g_loss', 'L_fake', 'L_real']
        cur_summary = None

        # Exponential decay: shrink the learning rate by 10% each epoch.
        lr = init_lr * (0.9**self.epoch_id)

        cur_epoch = train_data.epochs_completed
        step = 0
        d_loss_sum = 0
        g_loss_sum = 0
        l_fake_sum = 0
        l_real_sum = 0
        self.epoch_id += 1
        while cur_epoch == train_data.epochs_completed:
            self.global_step += 1
            step += 1

            batch_data = train_data.next_batch_dict()
            im = batch_data['im']

            random_vec = distributions.random_vector((len(im), self.n_code),
                                                     dist_type='uniform')

            # train discriminator
            for i in range(int(n_d_train)):

                _, d_loss = sess.run(
                    [self.train_d_op, self.d_loss_op],
                    feed_dict={
                        self.real: im,
                        self.lr: lr,
                        self.keep_prob: keep_prob,
                        self.random_vec: random_vec
                    })

            # train generator
            for i in range(int(n_g_train)):
                _, g_loss = sess.run(
                    [self.train_g_op, self.g_loss_op],
                    feed_dict={
                        self.lr: lr,
                        self.keep_prob: keep_prob,
                        self.random_vec: random_vec
                    })

            # update k
            _, L_fake, L_real = sess.run(
                [self.update_op, self.L_fake, self.L_real],
                feed_dict={
                    self.real: im,
                    self.random_vec: random_vec,
                    self.keep_prob: keep_prob,
                })

            d_loss_sum += d_loss
            g_loss_sum += g_loss
            l_fake_sum += L_fake
            l_real_sum += L_real

            if step % 100 == 0:
                cur_summary = sess.run(self.train_summary_op,
                                       feed_dict={
                                           self.real: im,
                                           self.keep_prob: keep_prob,
                                           self.random_vec: random_vec
                                       })

                viz.display(self.global_step,
                            step, [
                                d_loss_sum / n_d_train, g_loss_sum / n_g_train,
                                l_fake_sum, l_real_sum
                            ],
                            display_name_list,
                            'train',
                            summary_val=cur_summary,
                            summary_writer=summary_writer)

        print('==== epoch: {}, lr:{} ===='.format(cur_epoch, lr))
        cur_summary = sess.run(self.train_summary_op,
                               feed_dict={
                                   self.real: im,
                                   self.keep_prob: keep_prob,
                                   self.random_vec: random_vec
                               })
        viz.display(self.global_step,
                    step, [
                        d_loss_sum / n_d_train, g_loss_sum / n_g_train,
                        l_fake_sum, l_real_sum
                    ],
                    display_name_list,
                    'train',
                    summary_val=cur_summary,
                    summary_writer=summary_writer)
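
Unlike the previous examples, each step here also runs update_op, which adjusts a balance term from L_real and L_fake; the names suggest a BEGAN-style equilibrium update, though its definition is not shown on this page. A NumPy sketch of that update under the BEGAN assumption (k, gamma, and lambda_k are hypothetical here):

import numpy as np

def update_k(k, L_real, L_fake, gamma=0.5, lambda_k=1e-3):
    # BEGAN-style proportional control: increase k when the loss on
    # real images dominates the loss on generated ones.
    k = k + lambda_k * (gamma * L_real - L_fake)
    return float(np.clip(k, 0., 1.))

k = 0.0
k = update_k(k, L_real=0.8, L_fake=0.3)  # nudges k toward equilibrium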