示例#1
0
    def train(self):
        """Build the graph and train, validating and checkpointing periodically.

        Files written (paths come from attributes on ``self``):
          * ``<output_dir>/model_vars.txt``: trainable variables, via
            ``analyze_vars``;
          * ``<output_dir>/regularizers.txt``: one line per tensor name in the
            ``REGULARIZATION_LOSSES`` collection;
          * TensorBoard events (graph + per-step training summary) under
            ``self.log_dir``;
          * every ``self.val_freq`` steps: a ``ckpt-m-<step>`` checkpoint in
            ``self.checkpoint_dir``, validation metrics appended to
            ``self.val_log`` and, when the mean validation accuracy over
            ``self.val_data`` improves, a ``best-m-<step>`` checkpoint in
            ``self.model_dir``.

        If the module-level ``config['pretrained_model']`` is non-empty, all
        variables are restored from that checkpoint path and the global step
        resumes from the number after the last ``-`` in its basename
        (e.g. ``ckpt-m-2000`` -> 2000).
        """
        self.build()
        analyze_vars(tf.trainable_variables(), os.path.join(self.output_dir, 'model_vars.txt'))
        with open(os.path.join(self.output_dir, 'regularizers.txt'), 'w') as f:
            for v in tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES):
                f.write(v.name + '\n')

        tf_config = tf.ConfigProto(allow_soft_placement=True)
        tf_config.gpu_options.allow_growth = True
        with tf.Session(config=tf_config) as sess:
            tf.global_variables_initializer().run()
            saver_ckpt = tf.train.Saver()
            saver_best = tf.train.Saver()
            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
            try:
                start_time = time.time()
                best_acc = 0
                counter = 0
                pretrained_model = config['pretrained_model']
                if pretrained_model != '':
                    saver_ckpt.restore(sess, pretrained_model)
                    # Resume the step counter from the checkpoint file name suffix.
                    step = int(os.path.basename(pretrained_model).split('.')[0].split('-')[-1])
                    sess.run(tf.assign(self.global_step, step))
                    counter = self.global_step.eval(sess)
                    print('start step: %d' % counter)

                for i in range(self.epoch_num):
                    for j in range(self.step_per_epoch):
                        _, l, l_wd, l_inf, train_acc, s, _ = sess.run(
                            [self.train_op, self.train_loss, self.wd_loss,
                             self.inference_loss, self.train_acc,
                             self.train_summary, self.inc_op],
                            feed_dict={self.train_phase_dropout: True,
                                       self.train_phase_bn: True})
                        counter += 1
                        # The summary is evaluated every step anyway; without this
                        # call it would be discarded and TensorBoard would only
                        # ever show the graph.
                        summary_writer.add_summary(s, counter)

                        print("Epoch: [%2d/%2d] [%6d/%6d] time: %.2f, loss: %.3f (inference: %.3f, wd: %.3f), acc: %.3f" % (i, self.epoch_num, j, self.step_per_epoch, time.time() - start_time, l, l_inf, l_wd, train_acc))
                        start_time = time.time()

                        if counter % self.val_freq == 0:
                            saver_ckpt.save(sess, os.path.join(self.checkpoint_dir, 'ckpt-m'), global_step=counter)
                            val_accs = []
                            with open(self.val_log, 'a') as f:
                                f.write('step: %d\n' % counter)
                                for k, v in self.val_data.items():
                                    imgs, imgs_f, issame = load_bin(v, self.image_size)
                                    embds = self.run_embds(sess, imgs)
                                    embds_f = self.run_embds(sess, imgs_f)
                                    # Sum of the row-wise L2-normalised embeddings of
                                    # both image sets (imgs_f presumably holds flipped
                                    # copies -- confirm in load_bin).
                                    embds = (embds / np.linalg.norm(embds, axis=1, keepdims=True)
                                             + embds_f / np.linalg.norm(embds_f, axis=1, keepdims=True))
                                    _, _, acc_mean, acc_std, tar, tar_std, far = evaluate(
                                        embds, issame, far_target=1e-3, distance_metric=0)
                                    f.write('eval on %s: acc--%1.5f+-%1.5f, tar--%1.5f+-%1.5f@far=%1.5f\n' % (k, acc_mean, acc_std, tar, tar_std, far))
                                    val_accs.append(acc_mean)
                                # Guard against an empty validation set: np.mean([]) is
                                # NaN and emits a RuntimeWarning.
                                if val_accs:
                                    mean_val_acc = np.mean(np.array(val_accs))
                                    if mean_val_acc > best_acc:
                                        saver_best.save(sess, os.path.join(self.model_dir, 'best-m'), global_step=counter)
                                        best_acc = mean_val_acc
            finally:
                # Flush pending events and release the event file.
                summary_writer.close()
    def train(self, var_names_file='/home/fangyu/fy/tflite/myinsightface_tf/origin_valiable_names.txt'):
        """Build the graph and train, optionally warm-starting from a checkpoint.

        If the module-level ``config['pretrained_model']`` is non-empty it is
        treated as a checkpoint directory. The latest checkpoint in it is found
        with ``tf.train.latest_checkpoint``. Only the trainable variables whose
        names appear (one per line) in ``var_names_file`` are restored. Every
        other variable keeps the value set by
        ``global_variables_initializer``. The step counter always starts at 0.

        For the first ``config['fixed_epoch_num']`` epochs
        ``self.train_op_softmax`` is run instead of ``self.train_op``
        (presumably it updates only part of the network -- confirm in
        ``build``).

        Files written:
          * ``model_vars.txt`` and ``regularizers.txt`` in ``self.output_dir``;
          * TensorBoard events in ``self.log_dir``;
          * every ``self.val_freq`` steps:
              - a ``ckpt-m-<step>`` checkpoint in ``self.checkpoint_dir``;
              - validation metrics appended to ``self.val_log``;
              - when the mean validation accuracy improves, a
                ``best-m-<step>`` checkpoint in ``self.model_dir``.

        Args:
            var_names_file: text file with the names of variables to restore,
                one per line. Defaults to the previously hard-coded path. It
                is only read when a pretrained model is configured.

        Raises:
            ValueError: if a pretrained model directory is configured but holds
                no checkpoint, or (raised by ``tf.train.Saver``) if none of the
                listed names match a trainable variable.
        """
        self.build()
        # Unlike tf.trainable_variables(), this also lists non-trainable variables.
        analyze_vars(tf.global_variables(),
                     os.path.join(self.output_dir, 'model_vars.txt'))
        with open(os.path.join(self.output_dir, 'regularizers.txt'), 'w') as f:
            for v in tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES):
                f.write(v.name + '\n')

        tf_config = tf.ConfigProto(allow_soft_placement=True)
        tf_config.gpu_options.allow_growth = True
        with tf.Session(config=tf_config) as sess:
            tf.global_variables_initializer().run()
            saver_ckpt = tf.train.Saver()
            saver_best = tf.train.Saver()

            pretrained_dir = config['pretrained_model']
            if pretrained_dir != '':
                # The name list and the partial Saver are only needed when
                # restoring. Building them unconditionally made training
                # without a pretrained model depend on this file existing.
                with open(var_names_file, 'r') as fd:
                    v_names = {line.strip() for line in fd}
                restore_vars = [v for v in tf.trainable_variables() if v.name in v_names]
                ckpt_path = tf.train.latest_checkpoint(pretrained_dir)
                if ckpt_path is None:
                    raise ValueError('no checkpoint found in %r' % (pretrained_dir,))
                saver_embd = tf.train.Saver(var_list=restore_vars)
                saver_embd.restore(sess, ckpt_path)

            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
            try:
                start_time = time.time()
                best_acc = 0
                counter = 0
                for i in range(self.epoch_num):
                    if i < config['fixed_epoch_num']:
                        cur_train_op = self.train_op_softmax
                    else:
                        cur_train_op = self.train_op
                    for j in range(self.step_per_epoch):
                        _, l, l_wd, l_inf, train_acc, s, _ = sess.run(
                            [
                                cur_train_op, self.train_loss, self.wd_loss,
                                self.inference_loss, self.train_acc,
                                self.train_summary, self.inc_op
                            ],
                            feed_dict={
                                self.train_phase_dropout: True,
                                self.train_phase_bn: True
                            })
                        counter += 1
                        # The summary is evaluated every step anyway. Without
                        # this call it would be discarded, and TensorBoard would
                        # only ever show the graph.
                        summary_writer.add_summary(s, counter)

                        print(
                            "Epoch: [%2d/%2d] [%6d/%6d] time: %.2f, loss: %.3f (inference: %.3f, wd: %.3f), acc: %.3f"
                            % (i, self.epoch_num, j, self.step_per_epoch,
                               time.time() - start_time, l, l_inf, l_wd, train_acc))
                        start_time = time.time()

                        if counter % self.val_freq == 0:
                            saver_ckpt.save(sess,
                                            os.path.join(self.checkpoint_dir, 'ckpt-m'),
                                            global_step=counter)
                            val_accs = []
                            with open(self.val_log, 'a') as f:
                                f.write('step: %d\n' % counter)
                                for k, v in self.val_data.items():
                                    imgs, imgs_f, issame = load_bin(v, self.image_size)
                                    embds = self.run_embds(sess, imgs)
                                    embds_f = self.run_embds(sess, imgs_f)
                                    # Sum of the row-wise L2-normalised embeddings of
                                    # both image sets (imgs_f presumably holds flipped
                                    # copies -- confirm in load_bin).
                                    embds = (embds / np.linalg.norm(embds, axis=1, keepdims=True)
                                             + embds_f / np.linalg.norm(embds_f, axis=1, keepdims=True))
                                    _, _, acc_mean, acc_std, tar, tar_std, far = evaluate(
                                        embds,
                                        issame,
                                        far_target=1e-3,
                                        distance_metric=0)
                                    f.write(
                                        'eval on %s: acc--%1.5f+-%1.5f, tar--%1.5f+-%1.5f@far=%1.5f\n'
                                        % (k, acc_mean, acc_std, tar, tar_std, far))
                                    val_accs.append(acc_mean)
                                # np.mean([]) is NaN and warns. Skip the comparison
                                # when there is no validation data.
                                if val_accs:
                                    mean_val_acc = np.mean(np.array(val_accs))
                                    if mean_val_acc > best_acc:
                                        saver_best.save(sess,
                                                        os.path.join(self.model_dir, 'best-m'),
                                                        global_step=counter)
                                        best_acc = mean_val_acc
            finally:
                # Flush pending events and release the event file.
                summary_writer.close()