# Exemplo n.º 1 (Example No. 1)
# 0
def main(args):
    """Run a 1:N face-identification evaluation configured by ``args``.

    Echoes the parsed arguments, builds the image preprocessing function,
    loads the gallery/probe split named by ``args.test_name`` and passes
    everything to ``FaceEvaluator.evaluate_1vsn``.

    Args:
        args: parsed command-line options exposing every attribute read below
            (presumably an ``argparse.Namespace`` -- confirm with the caller).
    """
    utils.print_arguments(args)
    print('--------------------------------------')

    # NOTE(review): this passes (height, width); train_and_evaluate in this
    # file passes (width, height) to the same helper -- confirm which order
    # dap_func_wrapper actually expects.
    preprocess_fn = dap_func_wrapper(args.image_height, args.image_width)
    split = get_gallery_and_probe(args.test_name)
    gallery_paths, gallery_labels, probe_paths, probe_labels = split

    evaluator = FaceEvaluator(
        devices=args.devices,
        # Data-input parameters.
        dap_func=preprocess_fn,
        num_dap_threads=args.num_dap_threads,
        batch_size=args.batch_size,
        # Network-model parameters.
        model_def=args.model_def,
        embedding_size=args.embedding_size,
        use_batch_norm=args.use_batch_norm,
        use_normalized=args.use_normalized,
        pretrained_model=args.pretrained_model,
        fusion_method=args.fusion_method,
    )

    # Probes are matched against the gallery; top_num presumably bounds the
    # number of ranked candidates considered -- confirm against FaceEvaluator.
    evaluation_options = dict(
        probe_paths=probe_paths,
        probe_labels=probe_labels,
        gallery_paths=gallery_paths,
        gallery_labels=gallery_labels,
        top_num=args.top_num,
        save_feature=args.save_feature,
        use_detail=args.use_detail,
    )
    evaluator.evaluate_1vsn(**evaluation_options)
# Exemplo n.º 2 (Example No. 2)
# 0
def train_and_evaluate(training_mode,
                       graph,
                       model,
                       logdir,
                       num_steps=8600,
                       verbose=False):
    """Train the source model from a CSV file list with a TF1 queue pipeline.

    Builds a fresh ``tf.Graph``, wires a CSV-driven input pipeline into
    ``build_model_source`` and runs an epoch/batch training loop. Every step
    is logged to ``<logdir>/train_log.txt``. A checkpoint named
    ``model_iter_<step>.ckpt`` is written into ``logdir`` every
    ``FLAGS.save_model_steps`` global steps.

    Most configuration comes from the module-level ``FLAGS`` object: image
    size, batch size, number of epochs, base learning rate, checkpoint
    options and ``file_names``. Each CSV line is parsed with the module-level
    ``record_defaults``.

    Args:
        training_mode: NOTE(review): not referenced in the body as shown.
        graph: NOTE(review): not referenced as shown; a new ``tf.Graph()``
            is always created instead.
        model: NOTE(review): not referenced in the body as shown.
        logdir: directory that receives ``train_log.txt`` and the
            ``model_iter_<step>.ckpt`` checkpoints.
        num_steps: NOTE(review): not referenced as shown; training length
            is governed by ``FLAGS.max_nrof_epochs`` instead.
        verbose: NOTE(review): not referenced in the body as shown.
    """

    # Send this run's log output to <logdir>/train_log.txt via the project's
    # set_logger helper (defined elsewhere).
    log_filename = os.path.join(logdir, 'train_log.txt')
    set_logger(log_filename, logging.INFO)

    logging.info('start training')
    # NOTE(review): old_time is not used in the code shown.
    old_time = time.time()

    with tf.Graph().as_default():
        # Non-trainable, so it is NOT included in the Saver's var_list below.
        global_step = tf.Variable(0, trainable=False)

        # The batch size is fed per step so the final batch of an epoch can
        # be smaller than FLAGS.batch_size.
        batch_size_placeholder = tf.placeholder(tf.int32, name='batch_size')
        filenames_placeholder = tf.placeholder(tf.string, name='image_paths')
        learning_rate_placeholder = tf.placeholder(tf.float32, [])

        # Queue of CSV file names; the training loop enqueues them once per
        # epoch.
        input_queue = tf.FIFOQueue(capacity=32, dtypes=tf.string, shapes=[()])
        enqueue_op = input_queue.enqueue_many(filenames_placeholder)

        # NOTE(review): (width, height) here, but main() in this file passes
        # (height, width) to the same helper -- confirm the expected order.
        dap_func = dap_func_wrapper(FLAGS.image_width, FLAGS.image_height)
        reader = tf.TextLineReader()
        images_and_labels_list = []

        # Build several identical read/decode branches. batch_join runs one
        # enqueue thread per branch, so they load data in parallel. In each
        # CSV line, column 0 is read as an image file path and column 1 is
        # used as the label.
        num_dap_threads = 4
        for _ in range(num_dap_threads):
            key, record = reader.read(input_queue)
            decoded = tf.decode_csv(record,
                                    record_defaults=record_defaults,
                                    field_delim=',')
            content = tf.read_file(decoded[0])
            image = dap_func(content, augment=True)
            images_and_labels_list.append([record, image, decoded[1]])

        # allow_smaller_final_batch lets the fed batch_size shrink at the end
        # of an epoch without blocking on a full batch.
        record_batch, image_batch, labels_batch = tf.train.batch_join(
            images_and_labels_list,
            batch_size=batch_size_placeholder,
            enqueue_many=False,
            capacity=2 * num_dap_threads * FLAGS.batch_size,
            allow_smaller_final_batch=True)

        # Judging by the names: a train op, the total loss, a prediction loss
        # and a learned scale -- confirm against build_model_source.
        regular_train_op, regular_loss, pred_loss, scale_inner = \
            build_model_source(image_batch, labels_batch, learning_rate_placeholder, global_step)

        # allow_growth: allocate GPU memory on demand instead of all up front.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        # NOTE(review): neither coord.request_stop()/coord.join() nor
        # sess.close() appears in the code shown -- confirm that the
        # queue-runner threads and the session are cleaned up.
        sess = tf.Session(config=config)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():
            tf.global_variables_initializer().run()
            # Only trainable variables are saved/restored. global_step and any
            # other non-trainable variables are not written to checkpoints.
            # NOTE(review): confirm this is intended (the commented-out line
            # below suggests extra variables were once considered).
            default_to_restore = tf.trainable_variables()
            #default_to_restore += tf.get_collection(tf.GraphKeys.RESTORE_VARIABLES)
            saver = tf.train.Saver(default_to_restore, max_to_keep=5)
            if FLAGS.checkpoint_path:
                _model_restore_fn(FLAGS.checkpoint_path, sess,
                                  default_to_restore,
                                  FLAGS.restore_from_base_network)

            # FLAGS.file_names is treated as a single path here.
            filenames = [FLAGS.file_names]
            # NOTE(review): range(1, N) runs N - 1 epochs -- confirm this is
            # the intended count for FLAGS.max_nrof_epochs.
            for epoch in range(1, FLAGS.max_nrof_epochs):
                # Enqueue one epoch of image paths and labels
                num_examples = sum(
                    [utils.get_file_line_count(name) for name in filenames])
                # Ceiling division: the last batch may be partial.
                num_batches = (num_examples + FLAGS.batch_size -
                               1) // FLAGS.batch_size
                # Current global step; drives the learning rate of the
                # epoch's first batch (later batches use the step returned
                # by the previous sess.run).
                step = sess.run(global_step, feed_dict=None)
                sess.run(enqueue_op, {filenames_placeholder: filenames})

                for i in range(num_batches):
                    # Remaining examples, capped at the configured batch size.
                    batch_size_actual = min(
                        num_examples - i * FLAGS.batch_size, FLAGS.batch_size)
                    source_lr = _learning_rate_fn(FLAGS.base_lr, step)
                    feed_dict = {learning_rate_placeholder: source_lr, \
                                         batch_size_placeholder: batch_size_actual}
                    start_time = time.time()
                    _, batch_loss, ploss, scale, step = sess.run(
                        [
                            regular_train_op, regular_loss, pred_loss,
                            scale_inner, global_step
                        ],
                        feed_dict=feed_dict)

                    elapsed = time.time() - start_time
                    # NOTE(review): 'lr: {:.4f}' prints learning rates below
                    # 5e-5 as 0.0000.
                    logging.info(('epoch:{} step:{} iter:{}/{} Time: {:.4f}s  loss: {:.4f} ploss: {:.4f}  lr: {:.4f}  scale:{:.4f}'\
                    .format(epoch, step, i, num_batches, \
                    elapsed, batch_loss, ploss, source_lr, scale)))

                    # Periodic checkpoint, keyed by the global step.
                    if step % FLAGS.save_model_steps == 0:
                        filename = os.path.join(
                            logdir, 'model_iter_{:d}'.format(step) + '.ckpt')
                        saver.save(sess, filename)