Example #1
0
    def train(
        self,
        num_epochs=200,  # number of epochs
        learning_rate=0.0002,  # learning rate of optimizer
        beta1=0.5,  # parameter for Adam optimizer
        decay_rate=1.0,  # learning rate decay (0, 1], 1 means no decay
        enable_shuffle=True,  # enable shuffle of the dataset
        use_trained_model=True,  # use the saved checkpoint to initialize the network
        use_init_model=True,  # use the init model to initialize the network
        weigts=(0.0001, 0, 0)  # the weights of adversarial loss and TV loss
    ):
        """Run the full training loop.

        Combines the reconstruction, adversarial, and total-variation losses,
        builds Adam optimizers with exponential learning-rate decay, optionally
        restores a checkpoint (or the bundled init model), then trains for
        ``num_epochs`` epochs while logging TensorBoard summaries and saving
        per-epoch sample images plus periodic checkpoints.

        NOTE(review): the misspelled ``weigts`` parameter name is kept for
        backward compatibility with existing callers; it holds the weights of
        (G adversarial loss, E_z adversarial loss, TV loss).
        """

        # *************************** load file names of images ******************************************************
        file_names = glob(os.path.join('./data', self.dataset_name, '*.jpg'))
        size_data = len(file_names)
        np.random.seed(seed=2017)  # fixed seed -> reproducible shuffle / sample split
        if enable_shuffle:
            np.random.shuffle(file_names)

        def _age_group(age):
            # Map a raw age (parsed from the file name) to one of the ten
            # one-hot age buckets.  Anything outside the listed ranges
            # (including ages > 70) falls into bucket 9, matching the
            # original if/elif chain.
            bounds = ((0, 5), (6, 10), (11, 15), (16, 20), (21, 30),
                      (31, 40), (41, 50), (51, 60), (61, 70))
            for group, (low, high) in enumerate(bounds):
                if low <= age <= high:
                    return group
            return 9

        def _parse_labels(file_path):
            # File names look like "<age>_<gender>_...jpg".  Use basename so
            # parsing works with OS-specific path separators, not only '/'.
            parts = os.path.basename(str(file_path)).split('_')
            return _age_group(int(parts[0])), int(parts[1])

        # *********************************** optimizer **************************************************************
        # over all, there are three loss functions, weights may differ from the paper because of different datasets
        self.loss_EG = self.EG_loss + weigts[0] * self.G_img_loss + weigts[
            1] * self.E_z_loss + weigts[
                2] * self.tv_loss  # slightly increase the params
        self.loss_Dz = self.D_z_loss_prior + self.D_z_loss_z
        self.loss_Di = self.D_img_loss_input + self.D_img_loss_G

        # set learning rate decay
        self.EG_global_step = tf.Variable(0,
                                          trainable=False,
                                          name='global_step')
        EG_learning_rate = tf.compat.v1.train.exponential_decay(
            learning_rate=learning_rate,
            global_step=self.EG_global_step,
            decay_steps=size_data / self.size_batch * 2,
            decay_rate=decay_rate,
            staircase=True)

        # optimizer for encoder + generator
        with tf.compat.v1.variable_scope('opt', reuse=tf.compat.v1.AUTO_REUSE):
            self.EG_optimizer = tf.compat.v1.train.AdamOptimizer(
                learning_rate=EG_learning_rate, beta1=beta1).minimize(
                    loss=self.loss_EG,
                    global_step=self.EG_global_step,
                    var_list=self.E_variables + self.G_variables)

            # optimizer for discriminator on z
            self.D_z_optimizer = tf.compat.v1.train.AdamOptimizer(
                learning_rate=EG_learning_rate,
                beta1=beta1).minimize(loss=self.loss_Dz,
                                      var_list=self.D_z_variables)

            # optimizer for discriminator on image
            self.D_img_optimizer = tf.compat.v1.train.AdamOptimizer(
                learning_rate=EG_learning_rate,
                beta1=beta1).minimize(loss=self.loss_Di,
                                      var_list=self.D_img_variables)

        # *********************************** tensorboard *************************************************************
        # for visualization (TensorBoard): $ tensorboard --logdir path/to/log-directory
        self.EG_learning_rate_summary = tf.summary.scalar(
            'EG_learning_rate', EG_learning_rate)
        self.summary = tf.compat.v1.summary.merge([
            self.z_summary, self.z_prior_summary, self.D_z_loss_z_summary,
            self.D_z_loss_prior_summary, self.D_z_logits_summary,
            self.D_z_prior_logits_summary, self.EG_loss_summary,
            self.E_z_loss_summary, self.D_img_loss_input_summary,
            self.D_img_loss_G_summary, self.G_img_loss_summary,
            self.EG_learning_rate_summary, self.D_G_logits_summary,
            self.D_input_logits_summary
        ])
        self.writer = tf.summary.FileWriter(
            os.path.join(self.save_dir, 'summary'), self.session.graph)

        # ************* get some random samples as testing data to visualize the learning process *********************
        sample_files = file_names[0:self.size_batch]
        file_names[0:self.size_batch] = []  # held-out samples are excluded from training
        sample = [
            load_image(
                image_path=sample_file,
                image_size=self.size_image,
                image_value_range=self.image_value_range,
                is_gray=(self.num_input_channels == 1),
            ) for sample_file in sample_files
        ]
        if self.num_input_channels == 1:
            # gray images need an explicit trailing channel axis
            sample_images = np.array(sample).astype(np.float32)[:, :, :, None]
        else:
            sample_images = np.array(sample).astype(np.float32)
        # one-hot labels are encoded in the image value range, e.g. (-1, 1)
        sample_label_age = np.ones(
            shape=(len(sample_files), self.num_categories),
            dtype=np.float32) * self.image_value_range[0]
        sample_label_gender = np.ones(
            shape=(len(sample_files), 2),
            dtype=np.float32) * self.image_value_range[0]
        for i, sample_file in enumerate(sample_files):
            age_group, gender = _parse_labels(sample_file)
            sample_label_age[i, age_group] = self.image_value_range[-1]
            sample_label_gender[i, gender] = self.image_value_range[-1]

        # ******************************************* training *******************************************************
        # initialize the graph
        tf.compat.v1.global_variables_initializer().run()

        # load check point
        if use_trained_model:
            if self.load_checkpoint():
                print("\tSUCCESS ^_^")
            else:
                print("\tFAILED >_<!")
                # load init model
                if use_init_model:
                    if not os.path.exists(
                            'init_model/model-init.data-00000-of-00001'):
                        from init_model.zip_opt import join
                        try:
                            join('init_model/model_parts',
                                 'init_model/model-init.data-00000-of-00001')
                        except Exception as err:
                            # BUG FIX: was a bare `except:`, which also caught
                            # SystemExit/KeyboardInterrupt and hid the cause.
                            raise Exception('Error joining files') from err
                    self.load_checkpoint(model_path='init_model')

        # epoch iteration
        num_batches = len(file_names) // self.size_batch
        for epoch in range(num_epochs):
            if enable_shuffle:
                np.random.shuffle(file_names)
            for ind_batch in range(num_batches):
                start_time = time.time()
                # read batch images and labels
                batch_files = file_names[ind_batch *
                                         self.size_batch:(ind_batch + 1) *
                                         self.size_batch]
                batch = [
                    load_image(
                        image_path=batch_file,
                        image_size=self.size_image,
                        image_value_range=self.image_value_range,
                        is_gray=(self.num_input_channels == 1),
                    ) for batch_file in batch_files
                ]
                if self.num_input_channels == 1:
                    batch_images = np.array(batch).astype(np.float32)[:, :, :,
                                                                      None]
                else:
                    batch_images = np.array(batch).astype(np.float32)
                # BUG FIX: was `np.float`, an alias removed in NumPy >= 1.24;
                # use np.float32 for consistency with the sample labels above.
                batch_label_age = np.ones(
                    shape=(len(batch_files), self.num_categories),
                    dtype=np.float32) * self.image_value_range[0]
                batch_label_gender = np.ones(
                    shape=(len(batch_files), 2),
                    dtype=np.float32) * self.image_value_range[0]
                for i, batch_file in enumerate(batch_files):
                    age_group, gender = _parse_labels(batch_file)
                    batch_label_age[i, age_group] = self.image_value_range[-1]
                    batch_label_gender[i, gender] = self.image_value_range[-1]

                # prior distribution on the prior of z
                batch_z_prior = np.random.uniform(
                    self.image_value_range[0], self.image_value_range[-1],
                    [self.size_batch, self.num_z_channels]).astype(np.float32)

                # update all three optimizers in one session step and fetch the losses
                _, _, _, EG_err, Ez_err, Dz_err, Dzp_err, Gi_err, DiG_err, Di_err, TV = self.session.run(
                    fetches=[
                        self.EG_optimizer, self.D_z_optimizer,
                        self.D_img_optimizer, self.EG_loss, self.E_z_loss,
                        self.D_z_loss_z, self.D_z_loss_prior, self.G_img_loss,
                        self.D_img_loss_G, self.D_img_loss_input, self.tv_loss
                    ],
                    feed_dict={
                        self.input_image: batch_images,
                        self.age: batch_label_age,
                        self.gender: batch_label_gender,
                        self.z_prior: batch_z_prior
                    })

                print(
                    "\nEpoch: [%3d/%3d] Batch: [%3d/%3d]\n\tEG_err=%.4f\tTV=%.4f"
                    % (epoch + 1, num_epochs, ind_batch + 1, num_batches,
                       EG_err, TV))
                print("\tEz=%.4f\tDz=%.4f\tDzp=%.4f" %
                      (Ez_err, Dz_err, Dzp_err))
                print("\tGi=%.4f\tDi=%.4f\tDiG=%.4f" %
                      (Gi_err, Di_err, DiG_err))

                # estimate left run time
                elapse = time.time() - start_time
                time_left = ((num_epochs - epoch - 1) * num_batches +
                             (num_batches - ind_batch - 1)) * elapse
                print("\tTime left: %02d:%02d:%02d" %
                      (int(time_left / 3600), int(
                          time_left % 3600 / 60), time_left % 60))

                # add to summary
                summary = self.summary.eval(
                    feed_dict={
                        self.input_image: batch_images,
                        self.age: batch_label_age,
                        self.gender: batch_label_gender,
                        self.z_prior: batch_z_prior
                    })
                self.writer.add_summary(summary, self.EG_global_step.eval())

            # save sample images for each epoch
            name = '{:02d}.png'.format(epoch + 1)
            self.sample(sample_images, sample_label_age, sample_label_gender,
                        name)
            self.test(sample_images, sample_label_gender, name)

            # save checkpoint for each 5 epoch
            if np.mod(epoch, 5) == 4:
                self.save_checkpoint()

        # save the trained model
        self.save_checkpoint()
        # close the summary writer
        self.writer.close()
Example #2
0
    def train(
        self,
        num_epochs=200,  # number of epochs
        learning_rate=0.0002,  # learning rate of optimizer
        beta1=0.5,  # parameter for Adam optimizer
        decay_rate=1.0,  # learning rate decay (0, 1], 1 means no decay
        enable_shuffle=True,  # enable shuffle of the dataset
        use_trained_model=True,  # use the saved checkpoint to initialize the network
        use_init_model=True,  # use the init model to initialize the network
        weigts=(0.0001, 0.0001, 0.001)  # the weights of adversarial loss and TV loss
    ):
        """Run the full training loop.

        Builds the three combined losses (encoder+generator, z-discriminator,
        image discriminator) plus their scalar summaries, creates Adam
        optimizers with exponential learning-rate decay, caches the held-out
        sample file list to disk, optionally restores a checkpoint, then
        trains for ``num_epochs`` epochs while logging TensorBoard summaries
        and saving per-epoch sample images plus periodic checkpoints.

        NOTE(review): the misspelled ``weigts`` parameter name is kept for
        backward compatibility with existing callers; it holds the weights of
        (G adversarial loss, E_z adversarial loss, TV loss).
        """

        # *************************** load file names of images ******************************************************
        file_names = glob(os.path.join('./data', self.dataset_name, '*.jpg'))
        size_data = len(file_names)
        np.random.seed(seed=2017)  # fixed seed -> reproducible shuffle / sample split
        if enable_shuffle:
            np.random.shuffle(file_names)

        # *********************************** optimizer **************************************************************
        # over all, there are three loss functions, weights may differ from the paper because of different datasets
        self.loss_EG = self.EG_loss + weigts[0] * self.G_img_loss + weigts[1] * self.E_z_loss + weigts[2] * self.tv_loss # slightly increase the params
        self.loss_Dz = self.D_z_loss_prior + self.D_z_loss_z
        self.loss_Di = self.D_img_loss_input + self.D_img_loss_G
        self.loss_EG_summary = tf.summary.scalar('loss_EG', self.loss_EG)
        self.loss_Dz_summary = tf.summary.scalar('loss_Dz', self.loss_Dz)
        self.loss_Di_summary = tf.summary.scalar('loss_Di', self.loss_Di)

        # set learning rate decay
        self.EG_global_step = tf.Variable(0, trainable=False, name='global_step')
        EG_learning_rate = tf.train.exponential_decay(
            learning_rate=learning_rate,
            global_step=self.EG_global_step,
            decay_steps=size_data / self.size_batch * 2,
            decay_rate=decay_rate,
            staircase=True
        )

        # optimizer for encoder + generator
        with tf.variable_scope('opt', reuse=tf.AUTO_REUSE):
            self.EG_optimizer = tf.train.AdamOptimizer(
                learning_rate=EG_learning_rate,
                beta1=beta1
            ).minimize(
                loss=self.loss_EG,
                global_step=self.EG_global_step,
                var_list=self.E_variables + self.G_variables
            )

            # optimizer for discriminator on z
            self.D_z_optimizer = tf.train.AdamOptimizer(
                learning_rate=EG_learning_rate,
                beta1=beta1
            ).minimize(
                loss=self.loss_Dz,
                var_list=self.D_z_variables
            )

            # optimizer for discriminator on image
            self.D_img_optimizer = tf.train.AdamOptimizer(
                learning_rate=EG_learning_rate,
                beta1=beta1
            ).minimize(
                loss=self.loss_Di,
                var_list=self.D_img_variables
            )

        # *********************************** tensorboard *************************************************************
        # for visualization (TensorBoard): $ tensorboard --logdir path/to/log-directory
        self.EG_learning_rate_summary = tf.summary.scalar('EG_learning_rate', EG_learning_rate)
        self.summary = tf.summary.merge([
            self.z_summary, self.z_prior_summary,
            self.D_z_loss_z_summary, self.D_z_loss_prior_summary,
            self.D_z_logits_summary, self.D_z_prior_logits_summary,
            self.EG_loss_summary, self.E_z_loss_summary,
            self.D_img_loss_input_summary, self.D_img_loss_G_summary,
            self.G_img_loss_summary, self.EG_learning_rate_summary,
            self.D_G_logits_summary, self.D_input_logits_summary,
            self.tv_loss_summary, self.loss_EG_summary,
            self.loss_Dz_summary, self.loss_Di_summary,
        ])
        self.writer = tf.summary.FileWriter(os.path.join(self.save_dir, 'summary'), self.session.graph)

        # ************* get some random samples as testing data to visualize the learning process *********************
        print("\n\tLoading Dataset...")
        sample_files = file_names[0:self.size_batch]
        file_names[0:self.size_batch] = []  # held-out samples are excluded from training
        sample_images, sample_label_age, sample_label_gender = self.get_dataset(sample_files)

        # cache the held-out file list so later evaluation can reuse the same split
        if not os.path.exists(os.path.join('save', 'test')):
            os.makedirs(os.path.join('save', 'test'))
        cache_file = os.path.join('save', 'test', 'test_files.pkl')
        with open(cache_file, 'wb') as fid:
            cPickle.dump(sample_files, fid, cPickle.HIGHEST_PROTOCOL)

        # ******************************************* training *******************************************************
        # initialize the graph
        tf.global_variables_initializer().run()

        # load check point
        if use_trained_model:
            if self.load_checkpoint():
                print("\tSUCCESS ^_^")
            else:
                print("\tFAILED >_<!")
                # load init model
                if use_init_model:
                    if not os.path.exists('init_model/model-init.data-00000-of-00001'):
                        from init_model.zip_opt import join
                        try:
                            join('init_model/model_parts', 'init_model/model-init.data-00000-of-00001')
                        except Exception as err:
                            # BUG FIX: was a bare `except:`, which also caught
                            # SystemExit/KeyboardInterrupt and hid the cause.
                            raise Exception('Error joining files') from err
                    self.load_checkpoint(model_path='init_model')

        # epoch iteration
        print("\n\tStart Training...")
        num_batches = len(file_names) // self.size_batch
        for epoch in range(num_epochs):
            if enable_shuffle:
                np.random.shuffle(file_names)
            for ind_batch in range(num_batches):
                start_time = time.time()

                # read batch images and labels
                batch_files = file_names[ind_batch*self.size_batch:(ind_batch+1)*self.size_batch]
                batch_images, batch_label_age, batch_label_gender = self.get_dataset(batch_files)

                # prior distribution on the prior of z
                batch_z_prior = np.random.uniform(
                    self.image_value_range[0],
                    self.image_value_range[-1],
                    [self.size_batch, self.num_z_channels]
                ).astype(np.float32)

                # update all three optimizers in one session step and fetch the losses
                _, _, _, EG_err, Ez_err, Dz_err, Dzp_err, Gi_err, DiG_err, Di_err, TV = self.session.run(
                    fetches=[
                        self.EG_optimizer,
                        self.D_z_optimizer,
                        self.D_img_optimizer,
                        self.EG_loss,
                        self.E_z_loss,
                        self.D_z_loss_z,
                        self.D_z_loss_prior,
                        self.G_img_loss,
                        self.D_img_loss_G,
                        self.D_img_loss_input,
                        self.tv_loss
                    ],
                    feed_dict={
                        self.input_image: batch_images,
                        self.age: batch_label_age,
                        self.gender: batch_label_gender,
                        self.z_prior: batch_z_prior
                    }
                )

                print("\nEpoch: [%3d/%3d] Batch: [%3d/%3d]\n\tEG_err=%.4f\tTV=%.4f" %
                    (epoch+1, num_epochs, ind_batch+1, num_batches, EG_err, TV))
                print("\tEz=%.4f\tDz=%.4f\tDzp=%.4f" % (Ez_err, Dz_err, Dzp_err))
                print("\tGi=%.4f\tDi=%.4f\tDiG=%.4f" % (Gi_err, Di_err, DiG_err))

                # estimate left run time
                elapse = time.time() - start_time
                time_left = ((num_epochs - epoch - 1) * num_batches + (num_batches - ind_batch - 1)) * elapse
                print("\tTime left: %02d:%02d:%02d" %
                      (int(time_left / 3600), int(time_left % 3600 / 60), time_left % 60))

                # add to summary
                summary = self.summary.eval(
                    feed_dict={
                        self.input_image: batch_images,
                        self.age: batch_label_age,
                        self.gender: batch_label_gender,
                        self.z_prior: batch_z_prior
                    }
                )
                self.writer.add_summary(summary, self.EG_global_step.eval())

            # save sample images for each epoch
            name = '{:02d}.png'.format(epoch+1)
            self.sample(sample_images, sample_label_age, sample_label_gender, name)
            self.test(sample_images, sample_label_gender, name)

            # save checkpoint for each 5 epoch
            if np.mod(epoch, 5) == 4:
                self.save_checkpoint()

        # save the trained model
        self.save_checkpoint()
        # close the summary writer
        self.writer.close()