Example #1
import os
import time

import numpy as np
import tensorflow as tf
from scipy import interpolate
from scipy.optimize import brentq
from sklearn import metrics

# project-local helper modules referenced below
import inception_resnet_v1
import test_utils


def main(args):
    if not os.path.exists(args.pretrained_model):
        print('invalid pretrained model path')
        return
    weights = np.load(args.pretrained_model)  # 0-d object array wrapping a dict of weights

    pairs = test_utils.read_pairs('/exports_data/czj/data/lfw/files/pairs.txt')
    imglist, labels = test_utils.get_paths(
        '/exports_data/czj/data/lfw/lfw_aligned/', pairs, '_face_.jpg')
    total_images = len(imglist)

    # ---- build graph ---- #
    images_pl = tf.placeholder(tf.float32,
                               shape=[None, 160, 160, 3],
                               name='image_batch')
    prelogits, _ = inception_resnet_v1.inference(images_pl, 1,
                                                 phase_train=False)
    embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10)

    # ---- extract ---- #
    gpu_options = tf.GPUOptions(allow_growth=True)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                            log_device_placement=False,
                                            allow_soft_placement=True))
    with sess.as_default():
        beg_time = time.time()
        # assign every trainable variable its value from the pickled weight dict
        to_assign = [
            v.assign(weights[()][v.name][0])
            for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        ]
        sess.run(to_assign)
        print('restore parameters: %.2fsec' % (time.time() - beg_time))

        beg_time = time.time()
        images = load_data(imglist)  # load_data: helper defined elsewhere in the original script
        print('load images: %.2fsec' % (time.time() - beg_time))

        beg_time = time.time()
        batch_size = 32
        beg = 0
        end = 0
        features = np.zeros((total_images, 128))
        while end < total_images:
            end = min(beg + batch_size, total_images)
            features[beg:end] = sess.run(embeddings, {images_pl: images[beg:end]})
            beg = end
        print('extract features: %.2fsec' % (time.time() - beg_time))

    tpr, fpr, acc, vr, vr_std, far = test_utils.evaluate(features,
                                                         labels,
                                                         num_folds=10)
    # display
    auc = metrics.auc(fpr, tpr)
    eer = brentq(lambda x: 1. - x - interpolate.interp1d(fpr, tpr)(x), 0., 1.)
    print('Acc:           %1.3f+-%1.3f' % (np.mean(acc), np.std(acc)))
    print('VR@FAR=%2.5f:  %2.5f+-%2.5f' % (far, vr, vr_std))
    print('AUC:           %1.3f' % auc)
    print('EER:           %1.3f' % eer)
    sess.close()
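
Each example receives a single `args` namespace; the argument parsers are not shown. As a usage illustration, a minimal command-line wrapper for Example #1 might look like the sketch below, where the flag name `--pretrained_model` is an assumption mirroring the `args.pretrained_model` attribute read above:

# Hypothetical entry point for Example #1 (the original parser is not shown).
import argparse
import sys


def parse_arguments(argv):
    parser = argparse.ArgumentParser()
    # assumed flag, mirroring the 'args.pretrained_model' attribute used above
    parser.add_argument('--pretrained_model', type=str,
                        help='Path to a .npy file holding the pickled weight dict.')
    return parser.parse_args(argv)


if __name__ == '__main__':
    main(parse_arguments(sys.argv[1:]))
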
Example #2
import math

import numpy as np
import tensorflow as tf
from scipy import interpolate
from scipy.optimize import brentq
from sklearn import metrics

import test_utils  # project-local helper module


def main(args):
    pairs = test_utils.read_pairs(args.lfw_pairs)
    model_list = test_utils.get_model_list(args.model_list)
    for t, model in enumerate(model_list):
        # [notice: start from a fresh graph so tensor names from the previous
        #  model do not collide with the one loaded next]
        tf.reset_default_graph()
        # get lfw pair filenames for this model's file extension
        paths, labels = test_utils.get_paths(args.lfw_dir, pairs, model[1])
        with tf.device('/gpu:%d' % (t + 1)):
            gpu_options = tf.GPUOptions(allow_growth=True)
            sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                    log_device_placement=False, allow_soft_placement=True))
            with sess.as_default():
                print("[%d] model: %s" % (t, model[1]))
                # restore model
                test_utils.load_model(sess, model[0])
                # load data tensor
                images_pl = tf.get_default_graph().get_tensor_by_name('image_batch:0')
                embeddings = tf.get_default_graph().get_tensor_by_name('embeddings:0')
                phase_train_pl = tf.get_default_graph().get_tensor_by_name('phase_train:0')
                image_size = args.image_size
                emb_size = embeddings.get_shape()[1]
                # extract feature
                batch_size = args.lfw_batch_size
                num_images = len(paths)
                num_batches = (num_images + batch_size - 1) // batch_size  # ceil: keep the final partial batch
                emb_arr = np.zeros((num_images, emb_size))
                for i in range(num_batches):
                    print('process %d/%d' % (i + 1, num_batches), end='\r')
                    beg_idx = i * batch_size
                    end_idx = min((i + 1) * batch_size, num_images)
                    images = test_utils.load_data(paths[beg_idx:end_idx], image_size)
                    emb = sess.run(embeddings, feed_dict={images_pl: images, phase_train_pl: False})
                    emb_arr[beg_idx:end_idx, :] = emb
        sess.close()
        print("\ndone.")
        # concatenate each model's features, scaled by its ensemble weight model[2]
        if t == 0:
            emb_ensemble = emb_arr * math.sqrt(float(model[2]))
        else:
            emb_ensemble = np.concatenate((emb_ensemble, emb_arr * math.sqrt(float(model[2]))), axis=1)
        print("ensemble feature:", emb_ensemble.shape)

    '''
    norm = np.linalg.norm(emb_ensemble, axis=1)
    for i in range(emb_ensemble.shape[0]):
        emb_ensemble[i] = emb_ensemble[i] / norm[i]
    '''

    tpr, fpr, acc, vr, vr_std, far = test_utils.evaluate(emb_ensemble, labels, num_folds=args.num_folds)
    # display
    auc = metrics.auc(fpr, tpr)
    eer = brentq(lambda x: 1. - x - interpolate.interp1d(fpr, tpr)(x), 0., 1.)
    print('Acc:           %1.3f+-%1.3f' % (np.mean(acc), np.std(acc)))
    print('VR@FAR=%2.5f:  %2.5f+-%2.5f' % (far, vr, vr_std))
    print('AUC:           %1.3f' % auc)
    print('EER:           %1.3f' % eer)
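
Examples #1 and #2 both read the equal error rate (EER) off the ROC curve: `interpolate.interp1d(fpr, tpr)` turns the sampled ROC into a function TPR(FPR), and `brentq` finds the operating point where the false reject rate 1 - TPR equals the false accept rate FPR. A self-contained sketch of the same computation on dummy ROC samples:

# EER as the root of 1 - x - TPR(x), on dummy ROC points.
import numpy as np
from scipy import interpolate
from scipy.optimize import brentq

fpr = np.array([0.0, 0.1, 0.3, 1.0])  # dummy false accept rates
tpr = np.array([0.0, 0.7, 0.9, 1.0])  # dummy true accept rates
roc = interpolate.interp1d(fpr, tpr)  # ROC as a function TPR(FPR)
eer = brentq(lambda x: 1. - x - roc(x), 0., 1.)  # = 0.2 for these samples
print('EER: %1.3f' % eer)
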
Example #3
import importlib
import os
import random
from datetime import datetime

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim

# project-local helper modules referenced below
import test_utils
import train_utils


def main(args):
    model_module = importlib.import_module(args.model_def)
    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(args.logs_base_dir, subdir)
    model_dir = os.path.join(args.models_base_dir, subdir)
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    print('log   dir: %s' % log_dir)
    print('model dir: %s' % model_dir)

    if args.lfw_dir:
        print('LFW directory: %s' % args.lfw_dir)
        pairs = test_utils.read_pairs(args.lfw_pairs)
        lfw_paths, actual_issame = test_utils.get_paths(
            args.lfw_dir, pairs, args.lfw_file_ext)

    with tf.Graph().as_default():
        # ---- data preparation ---- #
        image_list, label_list, num_classes = train_utils.get_datasets(
            args.data_dir, args.imglist_path)
        range_size = len(image_list)
        assert range_size > 0, 'The dataset should not be empty.'
        # random indices producer
        indices_que = tf.train.range_input_producer(range_size)
        deque_op = indices_que.dequeue_many(args.batch_size * args.epoch_size,
                                            'index_dequeue')

        tf.set_random_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)
        global_step = tf.Variable(0, trainable=False)
        lr_pl = tf.placeholder(tf.float32, name='learning_rate')
        batch_size_pl = tf.placeholder(tf.int32, name='batch_size')
        phase_train_pl = tf.placeholder(tf.bool, name='phase_train')
        imgpaths_pl = tf.placeholder(tf.string, name='image_paths')
        labels_pl = tf.placeholder(tf.int64, name='labels')

        # filename queue
        input_queue = tf.FIFOQueue(
            # [notice: capacity > batch_size*epoch_size]
            capacity=100000,
            dtypes=[tf.string, tf.int64],
            shapes=[(1, ), (1, )],
            shared_name=None,
            name='input_que')
        enque_op = input_queue.enqueue_many([imgpaths_pl, labels_pl],
                                            name='enque_op')
        # define 4 readers
        num_threads = 4
        threads_input_list = []
        for _ in range(num_threads):
            img_paths, label = input_queue.dequeue(
            )  # [notice: 'img_paths' and 'label' are both tensors]
            images = []
            for img_path in tf.unstack(img_paths):
                img_contents = tf.read_file(img_path)
                img = tf.image.decode_jpeg(img_contents)
                if args.random_crop:
                    img = tf.random_crop(img,
                                         [args.image_size, args.image_size, 3])
                else:
                    img = tf.image.resize_image_with_crop_or_pad(
                        img, args.image_size, args.image_size)
                if args.random_flip:
                    img = tf.image.random_flip_left_right(img)
                img.set_shape((args.image_size, args.image_size, 3))
                images.append(
                    tf.image.per_image_standardization(img))  # per-image prewhitening
            threads_input_list.append([images, label])

        # define 4 buffer queue
        image_batch, label_batch = tf.train.batch_join(
            threads_input_list,
            # [notice: here is 'batch_size_pl', not 'batch_size'!!]
            batch_size=batch_size_pl,
            shapes=[(args.image_size, args.image_size, 3), ()],
            enqueue_many=True,
            # [notice: how far prefetching is allowed to fill the queue]
            capacity=4 * num_threads * args.batch_size,
            allow_smaller_final_batch=True)
        image_batch = tf.identity(image_batch, 'image_batch')
        image_batch = tf.identity(image_batch, 'input')
        label_batch = tf.identity(label_batch, 'label_batch')
        print('Total classes: %d' % num_classes)
        print('Total images:  %d' % range_size)
        tf.summary.image('input_images', image_batch, 10)

        # ---- build graph ---- #
        with tf.device('/gpu:%d' % args.gpu_id):
            # embeddings
            prelogits, _ = model_module.inference(
                image_batch,
                args.keep_prob,
                phase_train=phase_train_pl,
                weight_decay=args.weight_decay)
            # logits
            logits = slim.fully_connected(
                prelogits,
                num_classes,
                activation_fn=None,
                weights_initializer=tf.truncated_normal_initializer(
                    stddev=0.1),
                weights_regularizer=slim.l2_regularizer(args.weight_decay),
                scope='Logits',
                reuse=False)
            # normalized features
            # [notice: used in test stage]
            embeddings = tf.nn.l2_normalize(prelogits,
                                            1,
                                            1e-10,
                                            name='embeddings')
            # ---- define loss & train op ---- #
            # center loss
            if args.center_loss_factor > 0.0:
                prelogits_center_loss, _ = train_utils.center_loss(
                    prelogits, label_batch, args.center_loss_alpha,
                    num_classes)
                tf.summary.scalar(
                    'center_loss',
                    prelogits_center_loss * args.center_loss_factor)
                tf.add_to_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES,
                    prelogits_center_loss * args.center_loss_factor)
            # cross-entropy
            cross_entropy_mean = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=label_batch, logits=logits),
                name='cross_entropy')
            tf.add_to_collection('losses', cross_entropy_mean)
            # regularity: weight decay
            reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            # total loss
            total_loss = tf.add_n([cross_entropy_mean] + reg_loss,
                                  name='total_loss')
            # [notice: here we decay manually]
            lr = tf.train.exponential_decay(lr_pl,
                                            global_step,
                                            args.lr_decay_epochs *
                                            args.epoch_size,
                                            args.lr_decay_factor,
                                            staircase=True)
            tf.summary.scalar('learning_rate', lr)
            train_op = train_utils.get_train_op(
                total_loss,
                global_step,
                args.optimizer,
                lr,
                args.moving_average_decay,
                # [notice: tf.global_variables() would also include
                #  non-trainable variables such as batch-norm moving averages]
                tf.trainable_variables())

        # ---- training ---- #
        # [notice: use 'allow_growth' instead of memory_fraction]
        gpu_options = tf.GPUOptions(allow_growth=True)
        # [notice: use 'allow_soft_placement' to solve the problem of 'no supported kernel...']
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                log_device_placement=False,
                                                allow_soft_placement=True))
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # [notice: 'max_to_keep': keep at most 'max_to_keep' checkpoint files]
        saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=5)
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():
            if args.pretrained_model:
                print('Resume training: %s' % args.pretrained_model)
                saver.restore(sess, args.pretrained_model)
            print('Start training ...')
            epoch = 0
            while epoch < args.max_num_epochs:
                step = sess.run(global_step,
                                feed_dict=None)  # training counter
                epoch = step // args.epoch_size

                # run epoch
                run_epoch(args, sess, epoch, image_list, label_list, deque_op,
                          enque_op, imgpaths_pl, labels_pl, lr_pl,
                          phase_train_pl, batch_size_pl, global_step,
                          total_loss, reg_loss, train_op, summary_op,
                          summary_writer)

                # snapshot for currently learnt weights
                snapshot(sess, saver, model_dir, subdir, step)

                # evaluate on LFW
                if args.lfw_dir:
                    evaluate(sess, enque_op, imgpaths_pl, labels_pl,
                             phase_train_pl, batch_size_pl, embeddings,
                             label_batch, lfw_paths, actual_issame,
                             args.lfw_batch_size, args.lfw_num_folds, log_dir,
                             step, summary_writer)
    sess.close()
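
`run_epoch` is not shown in these examples, but the graph above dictates its shape: dequeue one epoch's worth of shuffled indices, push the corresponding paths and labels into the filename queue through `enque_op`, then repeatedly run `train_op` while feeding the placeholders. A hedged sketch of that loop, reconstructed from the placeholders and queue ops defined above; names such as `args.lr` are assumptions, and the original body may differ:

# Hypothetical run_epoch for Example #3 (the original body is not shown).
import numpy as np


def run_epoch_sketch(args, sess, image_list, label_list, deque_op, enque_op,
                     imgpaths_pl, labels_pl, lr_pl, phase_train_pl,
                     batch_size_pl, train_op):
    # one epoch's worth of shuffled indices from the range_input_producer
    indices = sess.run(deque_op)
    paths = np.expand_dims(np.array(image_list)[indices], 1)
    labels = np.expand_dims(np.array(label_list)[indices], 1)
    # fill the filename queue; its shapes are [(1,), (1,)], hence expand_dims
    sess.run(enque_op, {imgpaths_pl: paths, labels_pl: labels})
    for _ in range(args.epoch_size):
        sess.run(train_op, {lr_pl: args.lr,  # 'args.lr' is an assumed flag name
                            phase_train_pl: True,
                            batch_size_pl: args.batch_size})
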
Example #4
import os
import random
from datetime import datetime

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim

# project-local helper modules referenced below
import inception_resnet_v1
import test_utils
import train_utils


def main(args):
    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(args.logs_base_dir, subdir, 'logs')
    model_dir = os.path.join(args.logs_base_dir, subdir, 'models')
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    print('log   dir: %s' % log_dir)
    print('model dir: %s' % model_dir)

    # build the graph
    # ---- load pretrained model ---- #
    pretrained = {}
    # [notice: each .npy file stores a pickled dict of {var_name: [value]},
    #  recovered with np.load(path)[()]]
    # Face model
    pretrained['Face'] = np.load(args.face_model)[()]
    # Nose model
    pretrained['Nose'] = np.load(args.nose_model)[()]
    # Lefteye model
    pretrained['Lefteye'] = np.load(args.lefteye_model)[()]
    # Rightmouth model
    pretrained['Rightmouth'] = np.load(args.rightmouth_model)[()]
    # ---- data preparation ---- #
    image_list, label_list, num_classes = train_utils.get_datasets(
        args.data_dir, args.imglist)
    range_size = len(image_list)

    if args.lfw_dir:
        print('LFW directory: %s' % args.lfw_dir)
        pairs = test_utils.read_pairs(args.lfw_pairs)
        lfw_paths, actual_issame = test_utils.get_paths(
            args.lfw_dir, pairs, args.lfw_file_ext)

    with tf.Graph().as_default():
        # random indices producer
        indices_que = tf.train.range_input_producer(range_size)
        dequeue_op = indices_que.dequeue_many(
            args.batch_size * args.epoch_size, 'index_dequeue')

        tf.set_random_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)
        global_step = tf.Variable(0, trainable=False)
        # lr_base_pl    = tf.placeholder(tf.float32, name='base_learning_rate')
        lr_fusion_pl = tf.placeholder(tf.float32, name='fusion_learning_rate')
        batch_size_pl = tf.placeholder(tf.int32, name='batch_size')
        phase_train_pl = tf.placeholder(tf.bool, name='phase_train')
        face_pl = tf.placeholder(tf.string, name='image_paths1')  # face images
        nose_pl = tf.placeholder(tf.string, name='image_paths2')  # nose images
        lefteye_pl = tf.placeholder(tf.string,
                                    name='image_paths3')  # left eye images
        rightmouth_pl = tf.placeholder(
            tf.string, name='image_paths4')  # right mouth images
        labels_pl = tf.placeholder(tf.int64, name='labels')

        # define a filename queue
        input_queue = tf.FIFOQueue(
            # [notice: capacity > batch_size*epoch_size]
            capacity=100000,
            dtypes=[tf.string, tf.string, tf.string, tf.string, tf.int64],
            shapes=[(1, ), (1, ), (1, ), (1, ), (1, )],
            shared_name=None,
            name='input_que')
        enque_op = input_queue.enqueue_many(
            [face_pl, nose_pl, lefteye_pl, rightmouth_pl, labels_pl],
            name='enque_op')
        # define 4 readers
        num_threads = 4
        threads_input_list = []
        for _ in range(num_threads):
            imgpath1, imgpath2, imgpath3, imgpath4, label = input_queue.dequeue(
            )
            images1 = []
            images2 = []
            images3 = []
            images4 = []
            # face
            for img_path in tf.unstack(imgpath1):
                img_contents = tf.read_file(img_path)
                img = tf.image.decode_jpeg(img_contents)
                # [notice: random crop is applied only to the face image]
                if args.random_crop:
                    img = tf.random_crop(img, [160, 160, 3])
                else:
                    img = tf.image.resize_image_with_crop_or_pad(img, 160, 160)
                # [notice: random flip is applied only to the face image and nose patch]
                if args.random_flip:
                    img = tf.image.random_flip_left_right(img)
                img.set_shape((160, 160, 3))
                images1.append(tf.image.per_image_standardization(img))
            # Nose
            for img_path in tf.unstack(imgpath2):
                img_contents = tf.read_file(img_path)
                img = tf.image.decode_jpeg(img_contents)
                # [notice: random flip is applied only to the face image and nose patch]
                if args.random_flip:
                    img = tf.image.random_flip_left_right(img)
                img.set_shape((160, 160, 3))
                images2.append(tf.image.per_image_standardization(img))
            # Lefteye
            for img_path in tf.unstack(imgpath3):
                img_contents = tf.read_file(img_path)
                img = tf.image.decode_jpeg(img_contents)
                img.set_shape((160, 160, 3))
                images3.append(tf.image.per_image_standardization(img))
            # Rightmouth
            for img_path in tf.unstack(imgpath4):
                img_contents = tf.read_file(img_path)
                img = tf.image.decode_jpeg(img_contents)
                img.set_shape((160, 160, 3))
                images4.append(tf.image.per_image_standardization(img))
            threads_input_list.append(
                [images1, images2, images3, images4, label])

        # define 4 buffer queue
        face_batch, nose_batch, lefteye_batch, rightmouth_batch, label_batch = tf.train.batch_join(
            threads_input_list,
            # [notice: here is 'batch_size_pl', not 'batch_size'!!]
            batch_size=batch_size_pl,
            shapes=[
                # [notice: shape of each element should be assigned, otherwise it raises
                # "tensorflow queue shapes must have the same length as dtype" exception]
                (args.image_size, args.image_size, 3),
                (args.image_size, args.image_size, 3),
                (args.image_size, args.image_size, 3),
                (args.image_size, args.image_size, 3),
                ()
            ],
            enqueue_many=True,
            # [notice: how long the prefetching is allowed to fill the queue]
            capacity=4 * num_threads * args.batch_size,
            allow_smaller_final_batch=True)
        print('Total classes: %d' % num_classes)
        print('Total images:  %d' % range_size)
        tf.summary.image('face_images', face_batch, 10)
        tf.summary.image('nose_images', nose_batch, 10)
        tf.summary.image('lefteye_images', lefteye_batch, 10)
        tf.summary.image('rightmouth_images', rightmouth_batch, 10)

        # ---- build graph ---- #
        with tf.variable_scope('BaseModel'):
            with tf.device('/gpu:%d' % args.gpu_id1):
                # embeddings for face model
                features1, _ = inception_resnet_v1.inference(
                    face_batch,
                    args.keep_prob,
                    phase_train=phase_train_pl,
                    weight_decay=args.weight_decay,
                    scope='Face')
            with tf.device('/gpu:%d' % args.gpu_id2):
                # embeddings for nose model
                features2, _ = inception_resnet_v1.inference(
                    nose_batch,
                    args.keep_prob,
                    phase_train=phase_train_pl,
                    weight_decay=args.weight_decay,
                    scope='Nose')
            with tf.device('/gpu:%d' % args.gpu_id3):
                # embeddings for left eye model
                features3, _ = inception_resnet_v1.inference(
                    lefteye_batch,
                    args.keep_prob,
                    phase_train=phase_train_pl,
                    weight_decay=args.weight_decay,
                    scope='Lefteye')
            with tf.device('/gpu:%d' % args.gpu_id4):
                # embeddings for right mouth model
                features4, _ = inception_resnet_v1.inference(
                    rightmouth_batch,
                    args.keep_prob,
                    phase_train=phase_train_pl,
                    weight_decay=args.weight_decay,
                    scope='Rightmouth')
        with tf.device('/gpu:%d' % args.gpu_id5):
            with tf.variable_scope("Fusion"):
                # ---- concatenate ---- #
                concated_features = tf.concat(
                    [features1, features2, features3, features4], 1)
                # prelogits
                prelogits = slim.fully_connected(
                    concated_features,
                    args.fusion_dim,
                    activation_fn=None,
                    weights_initializer=tf.truncated_normal_initializer(
                        stddev=0.1),
                    weights_regularizer=slim.l2_regularizer(args.weight_decay),
                    scope='prelogits',
                    reuse=False)
                # logits
                logits = slim.fully_connected(
                    prelogits,
                    num_classes,
                    activation_fn=None,
                    weights_initializer=tf.truncated_normal_initializer(
                        stddev=0.1),
                    weights_regularizer=slim.l2_regularizer(args.weight_decay),
                    scope='logits',
                    reuse=False)
                # normalized features
                # [notice: used in test stage]
                embeddings = tf.nn.l2_normalize(prelogits,
                                                1,
                                                1e-10,
                                                name='embeddings')

            # ---- define loss & train op ---- #
            cross_entropy = -tf.reduce_sum(
                tf.one_hot(indices=tf.cast(label_batch, tf.int32),
                           depth=num_classes) *
                tf.log(tf.nn.softmax(logits) + 1e-10),
                reduction_indices=[1])
            cross_entropy_mean = tf.reduce_mean(cross_entropy)
            tf.add_to_collection('losses', cross_entropy_mean)
            # weight decay
            reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            # total loss: cross_entropy + weight_decay
            total_loss = tf.add_n([cross_entropy_mean] + reg_loss,
                                  name='total_loss')
            '''
            lr_base = tf.train.exponential_decay(lr_base_pl,
                global_step,
                args.lr_decay_epochs * args.epoch_size,
                args.lr_decay_factor,
                staircase = True)
            '''
            lr_fusion = tf.train.exponential_decay(lr_fusion_pl,
                                                   global_step,
                                                   args.lr_decay_epochs *
                                                   args.epoch_size,
                                                   args.lr_decay_factor,
                                                   staircase=True)
            # tf.summary.scalar('base_learning_rate', lr_base)
            tf.summary.scalar('fusion_learning_rate', lr_fusion)
            var_list1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='BaseModel')
            var_list2 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='Fusion')
            '''
            train_op = train_utils.get_fusion_train_op(
                total_loss, global_step, args.optimizer,
                lr_base, var_list1, lr_fusion, var_list2,
                args.moving_average_decay)
            '''
            train_op = train_utils.get_train_op(total_loss, global_step,
                                                args.optimizer, lr_fusion,
                                                args.moving_average_decay,
                                                var_list2)

        # ---- training ---- #
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=gpu_options,
            log_device_placement=False,
            # [notice: 'allow_soft_placement' will switch to cpu automatically
            #  when some operations are not supported by GPU]
            allow_soft_placement=True))
        saver = tf.train.Saver(var_list1 + var_list2)
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():
            # ---- restore pre-trained parameters ---- #
            to_assign = []
            print("restore pretrained parameters...")
            print("total:", len(var_list1))
            for v in var_list1:
                v_name = v.name  # 'BaseModel/Face/xxx'
                v_name = v_name[v_name.find('/') + 1:]  # 'Face/xxx'
                v_name_1 = v_name[:v_name.find('/')]  # 'Face'
                v_name_2 = v_name[v_name.find('/'):]  # '/xxx'
                print("precess: %s" % v_name, end=" ")
                if v_name_1 in pretrained:
                    to_assign.append(
                        v.assign(pretrained[v_name_1][v_name_2][0]))
                    print("[ok]")
                else:
                    # no matching pretrained dict for this scope; skip the variable
                    print("[not found]")
            sess.run(to_assign)

            print("start training ...")
            epoch = 0
            while epoch < args.max_num_epochs:
                step = sess.run(global_step, feed_dict=None)
                epoch = step // args.epoch_size

                # run one epoch
                run_epoch(args, sess, epoch, image_list, label_list,
                          dequeue_op, enque_op, face_pl, nose_pl, lefteye_pl,
                          rightmouth_pl, labels_pl, lr_fusion_pl,
                          phase_train_pl, batch_size_pl, global_step,
                          total_loss, reg_loss, train_op, summary_op,
                          summary_writer)

                # snapshot for currently learnt weights
                snapshot(sess, saver, model_dir, subdir, step)

                # evaluate on LFW
                if args.lfw_dir:
                    evaluate(sess, enque_op, face_pl, nose_pl, lefteye_pl,
                             rightmouth_pl, labels_pl, phase_train_pl,
                             batch_size_pl, embeddings, label_batch, lfw_paths,
                             actual_issame, args.lfw_batch_size,
                             args.lfw_num_folds, log_dir, step, summary_writer)
    sess.close()
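
Example #4 restores the part models from .npy files: `np.load(path)[()]` unpickles a dict, which is then indexed as `pretrained[scope][suffix][0]`. That implies each file maps a variable's name, minus its leading scope, to a sequence whose first element is the weight array. A hedged sketch of an exporter that would produce this layout; the original export script is not shown, so take it as an assumption about the file format:

# Hypothetical exporter for the .npy weight dicts consumed above.
import numpy as np
import tensorflow as tf


def dump_weights(sess, scope, out_path):
    weights = {}
    for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope):
        # strip the leading scope: 'Face/conv1/weights:0' -> '/conv1/weights:0'
        key = v.name[v.name.find('/'):]
        weights[key] = [sess.run(v)]  # wrapped in a list, hence the [0] lookup
    np.save(out_path, weights)       # np.load(out_path)[()] recovers the dict
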
Example #5
import os
import random
from datetime import datetime

import numpy as np
import tensorflow as tf

# project-local helper modules referenced below
import test_utils
import train_utils
# [notice: 'resnet' is a project-local helper in the original script; it is
#  used below both as a namespace (resnet.HParams, resnet.decay,
#  resnet.fully_connected) and as a model constructor]


def main(args):
    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(args.logs_base_dir, 'logs', subdir)
    model_dir = os.path.join(args.logs_base_dir, 'models', subdir)
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    print('log   dir: %s' % log_dir)
    print('model dir: %s' % model_dir)

    if args.lfw_dir:
        print('lfw directory: %s' % args.lfw_dir)
        pairs = test_utils.read_pairs(args.lfw_pairs)
        lfw_paths, lfw_label = test_utils.get_paths(args.lfw_dir, pairs,
                                                    args.lfw_file_ext)

    with tf.Graph().as_default():
        # ------------ data preparation ------------ #
        image_list, label_list, num_classes = train_utils.get_datasets(
            args.data_dir, args.imglist_path)
        range_size = len(image_list)
        assert range_size > 0, 'The data set should not be empty.'

        # random indices producer
        indices_que = tf.train.range_input_producer(range_size)
        deque_op = indices_que.dequeue_many(args.batch_size * args.epoch_size,
                                            'index_dequeue')

        # [notice: fix the TF graph-level, Python, and NumPy random seeds]
        tf.set_random_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)

        # filename queue
        imgpaths_pl = tf.placeholder(tf.string, name='image_paths')
        labels_pl = tf.placeholder(tf.int64, name='labels')
        input_queue = tf.FIFOQueue(
            # [notice: capacity > batch_size*epoch_size]
            capacity=100000,
            dtypes=[tf.string, tf.int64],
            shapes=[(1, ), (1, )],
            shared_name=None,
            name='input_que')
        enque_op = input_queue.enqueue_many([imgpaths_pl, labels_pl],
                                            name='enque_op')

        # define 4 readers
        num_threads = 4
        threads_input_list = []
        for _ in range(num_threads):
            img_paths, label = input_queue.dequeue(
            )  # [notice: 'img_paths' and 'label' are both tensors]
            images = []
            for img_path in tf.unstack(img_paths):
                img_contents = tf.read_file(img_path)
                img = tf.image.decode_jpeg(img_contents)
                if args.random_crop:
                    img = tf.random_crop(img,
                                         [args.image_size, args.image_size, 3])
                else:
                    img = tf.image.resize_image_with_crop_or_pad(
                        img, args.image_size, args.image_size)
                if args.random_flip:
                    img = tf.image.random_flip_left_right(img)
                img.set_shape((args.image_size, args.image_size, 3))
                images.append(
                    tf.image.per_image_standardization(img))  # per-image prewhitening
            threads_input_list.append([images, label])

        # define 4 buffer queue
        batch_size_pl = tf.placeholder(tf.int32, name='batch_size')
        image_batch, label_batch = tf.train.batch_join(
            threads_input_list,
            # [notice: here is 'batch_size_pl', not 'batch_size'!!]
            batch_size=batch_size_pl,
            shapes=[(args.image_size, args.image_size, 3), ()],
            enqueue_many=True,
            # [notice: how far the pre-fetching is allowed to fill the queue]
            capacity=4 * num_threads * args.batch_size,
            allow_smaller_final_batch=True)
        image_batch = tf.identity(image_batch, 'image_batch')
        label_batch = tf.identity(label_batch, 'label_batch')

        print('Total classes: %d' % num_classes)
        print('Total images:  %d' % range_size)
        tf.summary.image('input_images', image_batch, 10)

        # ------------ build graph ------------ #
        hps_train = resnet.HParams(batch_size=batch_size_pl,
                                   num_residual_units=5,
                                   use_bottleneck=True,
                                   relu_leakiness=0.1)

        global_step = tf.Variable(0, trainable=False)
        phase_train_pl = tf.placeholder(tf.bool, name='phase_train')
        resnet_model = resnet(hps_train, phase_train_pl)
        with tf.device('/gpu:%d' % args.gpu_id):
            # ---- base graph ---- #
            with tf.variable_scope('ResNet'):
                # prelogits
                prelogits = resnet_model.inference(image_batch)

            # prelogits -> embeddings [notice: used in test stage]
            embeddings = tf.nn.l2_normalize(prelogits,
                                            1,
                                            1e-10,
                                            name='embeddings')

            # prelogits -> logits
            with tf.variable_scope('Logits'):
                logits = resnet.fully_connected(prelogits, num_classes)
                # predictions = tf.nn.softmax(logits)

            #  ---- losses ---- #
            # cross entropy
            with tf.variable_scope('cross_entropy'):
                # [notice: the leading minus makes this the negative
                #  log-likelihood, as in Example #4]
                cross_entropy = -tf.reduce_sum(
                    tf.one_hot(indices=tf.cast(label_batch, tf.int32),
                               depth=num_classes) *
                    tf.log(tf.nn.softmax(logits) + 1e-10),
                    reduction_indices=[1])
                cross_entropy_mean = tf.reduce_mean(cross_entropy)
                tf.summary.scalar('cross_entropy', cross_entropy_mean)

            # l2 loss
            reg_loss = resnet.decay(args.weight_decay)
            tf.summary.scalar('reg_loss', reg_loss)

            # total loss
            # [notice: reg_loss is a scalar tensor here, so it joins the list]
            total_loss = tf.add_n([cross_entropy_mean, reg_loss],
                                  name='total_loss')
            train_op = resnet_model.get_train_op(total_loss, global_step,
                                                 args.lr)

        # ------------ training ------------ #
        # [notice: use 'allow_growth' instead of memory_fraction]
        gpu_options = tf.GPUOptions(allow_growth=True)
        # [notice: use 'allow_soft_placement' to solve the problem of 'no supported kernel...']
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                log_device_placement=False,
                                                allow_soft_placement=True))
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # [notice: 'max_to_keep': keep at most 'max_to_keep' checkpoint files]
        saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=5)
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():
            if args.pretrained_model:
                print('Resume training: %s' % args.pretrained_model)
                saver.restore(sess, args.pretrained_model)

            print('Start training ...')
            epoch = 0
            while epoch < args.max_num_epochs:
                step = sess.run(global_step,
                                feed_dict=None)  # training counter
                epoch = step // args.epoch_size

                # run epoch
                run_epoch(args, sess, epoch, image_list, label_list, deque_op,
                          enque_op, imgpaths_pl, labels_pl, phase_train_pl,
                          batch_size_pl, global_step, total_loss, reg_loss,
                          train_op, summary_op, summary_writer)

                # snapshot for currently learnt weights
                snapshot(sess, saver, model_dir, subdir, step)

                # evaluate on LFW
                if args.lfw_dir:
                    evaluate(sess, enque_op, imgpaths_pl, labels_pl,
                             phase_train_pl, batch_size_pl, embeddings,
                             label_batch, lfw_paths, lfw_label,
                             args.lfw_batch_size, args.lfw_num_folds, log_dir,
                             step, summary_writer)
    sess.close()
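
Examples #4 and #5 compute the softmax cross-entropy by hand from a one-hot encoding, flooring the log's argument at 1e-10 to avoid log(0), whereas Example #3 uses the built-in `tf.nn.sparse_softmax_cross_entropy_with_logits`, which fuses the softmax and the log for better numerical stability. The two forms are interchangeable here, as the self-contained sketch below (with dummy logits and labels) illustrates:

# Equivalence of the manual one-hot cross-entropy and the fused TF1 op.
import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0], [0.1, 3.0, 0.2]])
labels = tf.constant([0, 1], dtype=tf.int64)
num_classes = 3

manual_ce = -tf.reduce_sum(
    tf.one_hot(indices=tf.cast(labels, tf.int32), depth=num_classes) *
    tf.log(tf.nn.softmax(logits) + 1e-10),
    reduction_indices=[1])
fused_ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels, logits=logits)  # numerically stabler, preferred

with tf.Session() as sess:
    print(sess.run([manual_ce, fused_ce]))  # nearly identical values
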