Пример #1
0
    def train(self):
        """Run the adversarial training loop.

        Each iteration assembles one sub-batch per GPU from the training
        loader, runs one discriminator update followed by ``self.g_loop``
        generator updates on that same batch, then fetches synthesised
        outputs, encodings, the generator loss, the training summary and the
        global step for logging.

        Every ``int(sample_interval * batch_idxs)`` iterations it scores and
        writes the training samples and calls ``self.validation``; every
        ``int(test_interval * batch_idxs)`` iterations it runs
        ``self.slerp_interpolation`` and, when ``self.ifsave`` is set (and this
        is not the very first iteration), saves a checkpoint.

        :return: None
        """
        start_time = time.time()

        # ``int(...)`` truncates to 0 when interval * batch_idxs < 1, which
        # would make the modulo checks below raise ZeroDivisionError; clamp
        # to "every iteration" instead. Both are loop-invariant.
        sample_every = max(1, int(self.sample_interval * self.batch_idxs))
        test_every = max(1, int(self.test_interval * self.batch_idxs))
        total_batch = self.batch_size * self.gpus_count

        curr_interval = 0  # iteration counter that runs across epochs
        for epoch_n in xrange(self.epoch):
            for interval_i in trange(self.batch_idxs):
                batch_image = np.zeros([
                    total_batch, self.input_size, self.input_size,
                    self.input_channel
                ], np.float32)
                batch_label = np.zeros(
                    [self.data_loader_train.labels_nums, total_batch],
                    np.float32)
                # Fill one contiguous slice of the batch per GPU; the loader
                # returns an (images, labels) pair for a single sub-batch.
                for gpu_i in xrange(self.gpus_count):
                    lo = gpu_i * self.batch_size
                    hi = lo + self.batch_size
                    batch_image[lo:hi], batch_label[:, lo:hi] = \
                        self.data_loader_train.read_data_batch()

                # Label rows 0 / 1 / 2 feed the label, pose and light inputs.
                feed = {
                    self.batch_data: batch_image,
                    self.input_label: batch_label[0],
                    self.input_pose: batch_label[1],
                    self.input_light: batch_label[2],
                }

                # Discriminator update.
                _, loss_d = self.sess.run([self.train_d_op, self.d_loss],
                                          feed_dict=feed)

                # Generator updates on the same batch.
                for _ in xrange(self.g_loop):
                    self.sess.run(self.train_g_op, feed_dict=feed)

                # Fetch outputs for logging / sampling. ``self.batch_data`` is
                # no longer fetched: it is a fed placeholder, so fetching it
                # only copied the input batch back.
                # NOTE(review): ``self.input_data_ex`` is fetched but unused;
                # kept so the evaluated subgraph is unchanged — drop it once
                # confirmed it has no side effects (e.g. an input queue).
                (sample_data, sample_data_ex, encode_real, encode_syn1,
                 encode_syn2, loss_g, train_summary, _data_ex,
                 step) = self.sess.run(
                     [self.output_syn1, self.output_syn2, self.pidrcontent,
                      self.encode_syn_1, self.encode_syn_2, self.g_loss,
                      self.summary_train, self.input_data_ex,
                      self.global_step],
                     feed_dict=feed)
                self.summary_write.add_summary(train_summary, global_step=step)

                logging.info(
                    'Epoch [%4d/%4d] [gpu%s] [global_step:%d]time:%.2f h, '
                    'd_loss:%.4f, g_loss:%.4f', epoch_n, self.epoch,
                    self.gpus_list, step, (time.time() - start_time) / 3600.0,
                    loss_d, loss_g)

                if curr_interval % sample_every == 0:
                    # Record training data: compare the real encoding (as
                    # rows of 512 features) with both synthesised encodings.
                    real_feats = np.reshape(encode_real, [-1, 512])
                    score_train = np.concatenate([
                        utils.compare_pair_features(
                            real_feats, np.reshape(encode_syn1, [-1, 512])),
                        utils.compare_pair_features(
                            real_feats, np.reshape(encode_syn2, [-1, 512])),
                    ], axis=0)
                    logging.info('[score_train] {:08} {}'.format(
                        step, score_train))
                    utils.write_batch(self.result_path,
                                      0,
                                      sample_data,
                                      batch_image,
                                      epoch_n,
                                      interval_i,
                                      othersample=sample_data_ex,
                                      ifmerge=True,
                                      score_f_id=score_train)
                    self.validation(interval_i, epoch_n, step)

                # Slerp interpolation and (optional) checkpointing.
                if curr_interval % test_every == 0:
                    self.slerp_interpolation(batch_image, batch_label,
                                             epoch_n, interval_i)
                    # Skip saving on the very first iteration.
                    if self.ifsave and curr_interval != 0:
                        self.saver.save(self.sess,
                                        os.path.join(self.check_point_path,
                                                     self.model_name),
                                        global_step=step)
                        # Single-argument call form prints identically on
                        # Python 2 and Python 3.
                        print('*' * 20 + 'save model successed!!!!~~~~')
                curr_interval += 1
Пример #2
0
    def validation(self, interval_i, epoch_n, step):
        """Run inference checks on one validation subject and write results.

        Two checks are performed:

        1. Identity preservation: all images of validation subject 0 are
           fed with identity label 0 and their recorded pose/light labels;
           the real encoding is compared with both synthesised encodings.
        2. Pose invariance: each image in the test batch is replicated
           ``test_batch_size`` times and fed with pose indices
           ``0..test_batch_size-1`` and random light indices in
           ``[0, light_c)``; the resulting encodings are compared likewise.

        Scores are logged and sample grids written via ``utils.write_batch``.

        :param interval_i: batch index within the epoch (used for output naming).
        :param epoch_n: current epoch (used for output naming).
        :param step: global step, included in the log lines.
        :return: None
        """

        def pair_scores(real, syn1, syn2):
            # Compare the real encoding (rows of 512 features) against each
            # synthesised encoding and stack both score arrays.
            real_feats = np.reshape(real, [-1, 512])
            return np.concatenate([
                utils.compare_pair_features(real_feats,
                                            np.reshape(syn1, [-1, 512])),
                utils.compare_pair_features(real_feats,
                                            np.reshape(syn2, [-1, 512])),
            ], axis=0)

        sample_batch = np.zeros([
            self.test_batch_size, self.input_size, self.input_size,
            self.input_channel
        ], np.float32)
        label_batch = np.zeros(
            [self.data_loader_train.labels_nums, self.test_batch_size],
            np.float32)
        # Load the test batch: all pictures of one person (validation subject
        # 0) go into the first ``pose_c`` slots. Any remaining rows (if
        # test_batch_size > pose_c) stay zero.
        person_images, person_labels = \
            self.data_loader_valid.oneperson_allpose(0)
        sample_batch[:self.pose_c] = person_images
        label_batch[:, :self.pose_c] = person_labels
        sample_count = sample_batch.shape[0]

        # Identity-preservation test.
        # NOTE(review): this run fetches ``self.output_en`` as the real
        # encoding, while the pose test below (and ``train``) use
        # ``self.pidrcontent`` — confirm this difference is intended.
        idlabel_batch = [0] * sample_count
        (sample_data, sample_data_ex, encode_real, encode_syn1,
         encode_syn2) = self.sess.run(
             [self.output_syn1, self.output_syn2, self.output_en,
              self.encode_syn_1, self.encode_syn_2],
             feed_dict={
                 self.batch_data: sample_batch,
                 self.input_label: idlabel_batch,
                 self.input_pose: label_batch[1],
                 self.input_light: label_batch[2]
             })
        score_identity = pair_scores(encode_real, encode_syn1, encode_syn2)
        utils.write_batch(self.result_path,
                          1,
                          sample_data,
                          sample_batch,
                          epoch_n,
                          interval_i,
                          othersample=sample_data_ex,
                          ifmerge=True,
                          score_f_id=score_identity)
        logging.info('[score_identity] {:08} {}'.format(step, score_identity))

        # Pose-invariance test: run every image (every pose) of the person.
        # Inputs below do not depend on ``idx`` and are built once.
        tppn = self.test_batch_size
        # Not actually used by the checks; it only fills out the batch
        # (originally sized 8) for a later split.
        label_batch_sub = [sample_count] * tppn
        pose_batch = list(range(tppn))
        for idx in xrange(sample_count):
            # Fresh random lighting per image (drawn inside the loop so the
            # RNG sequence matches one draw per image).
            light_batch = np.random.randint(0, self.light_c, tppn)
            # np.tile of an (H, W, C) image with 4-D reps already yields
            # shape (tppn, H, W, C); no reshape is needed.
            tmp_batch = np.tile(sample_batch[idx], (tppn, 1, 1, 1))
            (sample_data, sample_data_ex, encode_real, encode_syn1,
             encode_syn2) = self.sess.run(
                 [self.output_syn1, self.output_syn2, self.pidrcontent,
                  self.encode_syn_1, self.encode_syn_2],
                 feed_dict={
                     self.batch_data: tmp_batch,
                     self.input_label: label_batch_sub,
                     self.input_pose: pose_batch,
                     self.input_light: light_batch
                 })
            score_pose = pair_scores(encode_real, encode_syn1, encode_syn2)
            utils.write_batch(self.result_path,
                              2,
                              sample_data,
                              tmp_batch,
                              epoch_n,
                              interval_i,
                              sample_idx=idx,
                              othersample=sample_data_ex,
                              ifmerge=True,
                              score_f_id=score_pose)
            logging.info('[score_pose] {:08} {}'.format(step, score_pose))