Example #1
 def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
     self.backbone.eval()
     self.idprehead.eval()
     idx = 0
     embeddings = np.zeros([len(carray), conf.embedding_size])
     with torch.no_grad():
         while idx + conf.batch_size <= len(carray):
             batch = torch.tensor(carray[idx:idx + conf.batch_size])
             if tta:
                 fliped = hflip_batch(batch)
                 emb_batch = self.idprehead(
                     self.backbone(batch.to(conf.device))) + self.idprehead(
                         self.backbone(fliped.to(conf.device)))
                 embeddings[idx:idx + conf.batch_size] = l2_norm(emb_batch).cpu()
             else:
                 embeddings[idx:idx + conf.batch_size] = self.idprehead(
                     self.backbone(batch.to(conf.device))).cpu()
             idx += conf.batch_size
         if idx < len(carray):
             batch = torch.tensor(carray[idx:])
             if tta:
                 fliped = hflip_batch(batch)
                 emb_batch = self.idprehead(
                     self.backbone(batch.to(conf.device))) + self.idprehead(
                         self.backbone(fliped.to(conf.device)))
                 embeddings[idx:] = l2_norm(emb_batch).cpu()
             else:
                 embeddings[idx:] = self.idprehead(
                     self.backbone(batch.to(conf.device))).cpu()
     tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                    nrof_folds)
     buf = gen_plot(fpr, tpr)
     roc_curve = Image.open(buf)
     roc_curve_tensor = trans.ToTensor()(roc_curve)
     return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor
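
Both Example #1 and the next example rely on two helpers, l2_norm and hflip_batch, that are defined elsewhere in their repositories. A minimal sketch of what they typically look like, inferred from how the snippets call them (the actual project code may differ):

import torch

def l2_norm(x, axis=1):
    # Scale each embedding row to unit L2 norm.
    norm = torch.norm(x, 2, axis, keepdim=True)
    return torch.div(x, norm)

def hflip_batch(batch):
    # Horizontally flip a batch of NCHW image tensors (flip the width axis).
    return torch.flip(batch, dims=[3])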
Example #2
def perform_val(dataLoader,
                length,
                embedding_size,
                backbone,
                issame,
                nrof_folds=10,
                tta=True):
    idx = 0
    embeddings = np.zeros([length, embedding_size])
    #prefetcher = DataPrefetcher(dataLoader)
    with torch.no_grad():
        begin = 0
        for i, (imgs, labels) in enumerate(tqdm(dataLoader)):
            imgs = imgs.permute(0, 3, 1, 2).float().cuda()
            embedding = backbone(imgs)
            if tta:
                flipImgs = hflip_batch(imgs)
                embedding += backbone(flipImgs)
            embeddings[begin:begin + imgs.shape[0]] = np.copy(
                l2_norm(embedding).cpu().data.numpy())
            begin = begin + imgs.shape[0]

    tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                   nrof_folds)
    return float(accuracy.mean())
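
A note on the tta path above: adding the original and flipped embeddings and then L2-normalizing gives the same result as normalizing their mean, because the scalar factor 1/2 cancels under normalization. A quick NumPy check of that identity:

import numpy as np

emb = np.random.randn(4, 512)
emb_flip = np.random.randn(4, 512)
summed = emb + emb_flip
fused = summed / np.linalg.norm(summed, axis=1, keepdims=True)
mean_fused = (summed / 2) / np.linalg.norm(summed / 2, axis=1, keepdims=True)
assert np.allclose(fused, mean_fused)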
Example #3
    def on_batch_end(self, batch, logs=None):
        if batch >= 0 and batch % self.period == 0:
            acc_list = []
            for i in range(len(self.valid_list)):
                embeddings = test(self.valid_list[i], self.extractor, self.batch_size, 10)

                # evaluate the embeddings with 10-fold cross-validation
                _, _, accuracy, val, val_std, far = verification.evaluate(embeddings, self.valid_name_list, nrof_folds=10)
                acc, std = np.mean(accuracy), np.std(accuracy)
                x_norm = np.mean(np.linalg.norm(embeddings, axis=1))  # average embedding norm

                print('[%s][%d]XNorm: %f' % (self.valid_name_list[i], self.batch_size, x_norm))
                print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.valid_name_list[i], self.batch_size, acc, std))
                acc_list.append(acc)

            self.save_step += 1

            is_highest = False
            if len(acc_list) > 0:
                score = sum(acc_list)
                if acc_list[-1] >= self.highest_acc[-1]:
                    if acc_list[-1] > self.highest_acc[-1]:
                        is_highest = True
                    else:
                        if score >= self.highest_acc[0]:
                            is_highest = True
                            self.highest_acc[0] = score
                    self.highest_acc[-1] = acc_list[-1]
            if is_highest:
                print('saving', self.save_step)
                filepath = self.filepath.format(epoch=batch + 1, **(logs or {}))  # only the batch index is available in this callback
                self.model.save_weights(filepath, overwrite=True)
            print('[%d]Accuracy-Highest: %1.5f' % (batch, self.highest_acc[-1]))
Example #4
def tflite_model_test(model_dir):
    model = FACENET_TFLITE(model_dir)
    img_size = 112
    valid_dir = VALID_DIR
    datadirs = os.listdir(valid_dir)
    embedding_size = 512

    result = pt.PrettyTable(["data set", "AUC", "ACC", "VR @ FAR ", "dist max", "dist min"])
    val_all = []
    auc_all = []
    acc_all = []
    dist_all = []
    for i in range(len(datadirs)):
        path = valid_dir + "/" + datadirs[i]
        embeddings, issamelab = predict_evaluate_data_TFLITE(path, [img_size,img_size], model, embedding_size)
            
        embeddings1 = embeddings[0::2]
        embeddings2 = embeddings[1::2]
        diff = np.subtract(embeddings1, embeddings2)
        dist = np.sum(np.square(diff), 1)
        del embeddings1, embeddings2
        fpr, tpr, ths = roc_curve(np.asarray(issamelab).astype('int'), dist, pos_label=0)
        auc_score = auc(fpr, tpr)
        if 0:
            plt.figure()
            lw = 2
            plt.plot(fpr, tpr, color='darkorange',
                              lw=lw, label='ROC curve (area = %0.2f)' % auc_score)
            plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('Receiver operating characteristic example')
            plt.legend(loc="lower right")
            #  plt.show()
            plt.savefig("roc.png")

        #  print('-----10 folds------')
        tpr, fpr, accuracy, val, val_std, far = evaluate(embeddings, issamelab)
        del embeddings, issamelab
        gc.collect()
        result.add_row([
            datadirs[i],
            round(auc_score, 2),
            "%1.3f+-%1.3f" % (np.mean(accuracy), np.std(accuracy)),
            "%2.5f+-%2.5f @ FAR=%2.5f" % (val, val_std, far),
            round(np.max(dist), 2),
            round(np.min(dist), 2),
        ])
        val_all.append(val)
        acc_all.append(accuracy)
        auc_all.append(auc_score)
        dist_all.append([np.max(dist), np.min(dist)])
    
    print(result)
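
On the pos_label=0 passed to roc_curve above: sklearn treats larger scores as evidence for the positive class, and here the score is a squared L2 distance, where larger values indicate a different identity. Passing the "different" label (0) as the positive class keeps that orientation consistent. A self-contained toy illustration:

import numpy as np
from sklearn.metrics import auc, roc_curve

labels = np.array([1, 1, 0, 0])         # 1 = same pair, 0 = different pair
dists = np.array([0.3, 0.5, 1.2, 1.4])  # larger distance => more different
fpr, tpr, ths = roc_curve(labels, dists, pos_label=0)
print(auc(fpr, tpr))  # 1.0 for this perfectly separable toy data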
Example #5
    def on_batch_end(self, batch, logs=None):
        # t = time.time()
        # print('time:', t - self.t)
        # self.t = t
        global_step = self.epoch * self.steps_per_epoch + batch
        if global_step > 0 and global_step % self.period == 0:
            acc_list = []
            logs = logs or {}
            for key in self.valid_list:
                print('Test on valid set:', key)
                bins, is_same_list = self.valid_list[key]
                dataset = tf.data.Dataset.from_tensor_slices(bins) \
                    .map(data_input.get_valid_parse_function(False), num_parallel_calls=tf.data.experimental.AUTOTUNE) \
                    .batch(256)
                dataset_flip = tf.data.Dataset.from_tensor_slices(bins) \
                    .map(data_input.get_valid_parse_function(True), num_parallel_calls=tf.data.experimental.AUTOTUNE) \
                    .batch(256)
                batch_num = len(bins) // 256
                if len(bins) % 256 != 0:
                    batch_num += 1
                print('predicting...')
                embeddings = self.extractor.predict(dataset,
                                                    steps=batch_num,
                                                    verbose=0)
                embeddings_flip = self.extractor.predict(dataset_flip,
                                                         steps=batch_num,
                                                         verbose=0)
                embeddings_parts = [embeddings, embeddings_flip]
                x_norm = 0.0
                x_norm_cnt = 0
                for part in embeddings_parts:
                    for i in range(part.shape[0]):
                        embedding = part[i]
                        norm = np.linalg.norm(embedding)
                        x_norm += norm
                        x_norm_cnt += 1
                x_norm /= x_norm_cnt
                embeddings = embeddings_parts[0] + embeddings_parts[1]
                embeddings = sklearn.preprocessing.normalize(embeddings)
                print(embeddings.shape)
                _, _, accuracy, val, val_std, far = verification.evaluate(
                    embeddings, is_same_list, folds=10)
                acc, std = np.mean(accuracy), np.std(accuracy)

                print('[%s][%d]XNorm: %f' % (key, batch, x_norm))
                print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                      (key, batch, acc, std))
                acc_list.append(acc)

            logs['score'] = sum(acc_list)
            if self.verbose > 0:
                print('\nScore of step %05d: %0.5f' %
                      (global_step, logs['score']))
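
The XNorm bookkeeping above is just the mean L2 norm over every embedding in both the original and flipped parts. A vectorized equivalent, assuming each part is a NumPy array of shape [n, d]:

import numpy as np

parts = [np.random.randn(6, 512), np.random.randn(6, 512)]
x_norm = np.mean([np.linalg.norm(p, axis=1) for p in parts])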
Example #6
def main(args):
    with tf.Graph().as_default():
        with tf.Session() as sess:
            # prepare validate datasets
            ver_list = []
            ver_name_list = []
            for db in args.eval_datasets:
                print('begin db %s convert.' % db)
                data_set = load_data(db, args.image_size, args)
                ver_list.append(data_set)
                ver_name_list.append(db)

            # Load the model
            load_model(args.model)

            # Get input and output tensors; ignore phase_train_placeholder since it has a default value.
            inputs_placeholder = tf.get_default_graph().get_tensor_by_name("input:0")
            embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0")

            # image_size = images_placeholder.get_shape()[1]  # For some reason this doesn't work for frozen graphs
            embedding_size = embeddings.get_shape()[1]

            for db_index in range(len(ver_list)):
                # Run forward pass to calculate embeddings
                print('\nRunning forward pass on {} images'.format(ver_name_list[db_index]))
                start_time = time.time()
                data_sets, issame_list = ver_list[db_index]
            
                if data_sets.shape[0] % args.test_batch_size == 0:
                    nrof_batches = data_sets.shape[0] // args.test_batch_size
                else:
                    nrof_batches = data_sets.shape[0] // args.test_batch_size + 1
                emb_array = np.zeros((data_sets.shape[0], embedding_size))

                for index in range(nrof_batches):
                    start_index = index * args.test_batch_size
                    end_index = min((index + 1) * args.test_batch_size, data_sets.shape[0])

                    feed_dict = {inputs_placeholder: data_sets[start_index:end_index, ...]}
                    emb_array[start_index:end_index, :] = sess.run(embeddings, feed_dict=feed_dict)

                tpr, fpr, accuracy, val, val_std, far = evaluate(emb_array, issame_list, nrof_folds=args.eval_nrof_folds)
                duration = time.time() - start_time
                print("total time %.3fs to evaluate %d images of %s" % (duration, data_sets.shape[0], ver_name_list[db_index]))
                print('Accuracy: %1.3f+-%1.3f' % (np.mean(accuracy), np.std(accuracy)))
                print('Validation rate: %2.5f+-%2.5f @ FAR=%2.5f' % (val, val_std, far))
                print('fpr and tpr: %1.3f %1.3f' % (np.mean(fpr, 0), np.mean(tpr, 0)))

                auc = metrics.auc(fpr, tpr)
                print('Area Under Curve (AUC): %1.3f' % auc)
                eer = brentq(lambda x: 1. - x - interpolate.interp1d(fpr, tpr)(x), 0., 1.)
                print('Equal Error Rate (EER): %1.3f' % eer)
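
The EER computation above uses the fact that the equal error rate is the point on the ROC curve where the false-accept and false-reject rates cross, i.e. where FPR = 1 - TPR; brentq finds that crossing on the interpolated curve. A self-contained toy version:

import numpy as np
from scipy import interpolate
from scipy.optimize import brentq

fpr = np.array([0.0, 0.1, 0.3, 1.0])
tpr = np.array([0.0, 0.7, 0.9, 1.0])
eer = brentq(lambda x: 1. - x - interpolate.interp1d(fpr, tpr)(x), 0., 1.)
print(eer)  # the FPR value at which FPR == 1 - TPR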
Example #7
            # ceil division so the final partial batch is not dropped
            nrof_batches = (data_sets.shape[0] + args.test_batch_size - 1) // args.test_batch_size
            for index in range(nrof_batches):  # the test total is 2x the number of pairs
                start_index = index * args.test_batch_size
                end_index = min((index + 1) * args.test_batch_size,
                                data_sets.shape[0])

                feed_dict = {
                    inputs: data_sets[start_index:end_index, ...],
                    phase_train_placeholder: False
                }
                emb_array[start_index:end_index, :] = sess.run(
                    embeddings, feed_dict=feed_dict)

            duration = time.time() - start_time
            tpr, fpr, accuracy, val, val_std, far = evaluate(
                emb_array, issame_list, nrof_folds=args.eval_nrof_folds)

            print("total time %.3f to evaluate %d images of %s" %
                  (duration, data_sets.shape[0], ver_name_list[ver_step]))
            print('Accuracy: %1.3f+-%1.3f' %
                  (np.mean(accuracy), np.std(accuracy)))
            print('fpr and tpr: %1.3f %1.3f' %
                  (np.mean(fpr, 0), np.mean(tpr, 0)))
            print('Validation rate: %2.5f+-%2.5f @ FAR=%2.5f' %
                  (val, val_std, far))

            auc = metrics.auc(fpr, tpr)
            print('Area Under Curve (AUC): %1.3f' % auc)
            eer = brentq(lambda x: 1. - x - interpolate.interp1d(fpr, tpr)(x),
                         0., 1.)
            print('Equal Error Rate (EER): %1.3f\n' % eer)
Example #8
def train_net(args):
    data_dir = config.dataset_path
    image_size = config.image_shape[0:2]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    print('image_size', image_size)
    print('num_classes', config.num_classes)
    training_path = os.path.join(data_dir, "train.tfrecords")

    print('Called with argument:', args, config)
    train_dataset, batches_per_epoch = data_input.training_dataset(
        training_path, default.per_batch_size)

    extractor, classifier = build_model((image_size[0], image_size[1], 3),
                                        args)

    global_step = 0
    ckpt_path = os.path.join(
        args.models_root, '%s-%s-%s' % (args.network, args.loss, args.dataset),
        'model-{step:04d}.ckpt')
    ckpt_dir = os.path.dirname(ckpt_path)
    print('ckpt_path', ckpt_path)
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    if len(args.pretrained) == 0:
        latest = tf.train.latest_checkpoint(ckpt_dir)
        if latest:
            global_step = int(latest.split('-')[-1].split('.')[0])
            classifier.load_weights(latest)
    else:
        print('loading', args.pretrained, args.pretrained_epoch)
        # os.path.join would insert path separators here; build the
        # '<prefix>-<epoch>.ckpt' name with string formatting instead
        load_path = '{}-{}.ckpt'.format(args.pretrained, args.pretrained_epoch)
        classifier.load_weights(load_path)

    initial_epoch = global_step // batches_per_epoch
    rest_batches = global_step % batches_per_epoch

    lr_decay_steps = [(int(x), args.lr * np.power(0.1, i + 1))
                      for i, x in enumerate(args.lr_steps.split(','))]
    print('lr_steps', lr_decay_steps)

    valid_datasets = data_input.load_valid_set(data_dir, config.val_targets)

    classifier.compile(
        optimizer=keras.optimizers.SGD(lr=args.lr, momentum=args.mom),
        # labels are integer class ids, so the sparse loss matches the sparse accuracy metric
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()])
    classifier.summary()

    tensor_board = keras.callbacks.TensorBoard(ckpt_dir)
    tensor_board.set_model(classifier)

    train_names = ['train_loss', 'train_acc']
    train_results = []
    highest_score = 0
    for epoch in range(initial_epoch, default.end_epoch):
        for batch in range(rest_batches, batches_per_epoch + 1):
            utils.update_learning_rate(classifier, lr_decay_steps, global_step)
            train_results = classifier.train_on_batch(train_dataset,
                                                      reset_metrics=False)
            global_step += 1
            if global_step % 1000 == 0:
                print('lr-batch-epoch:',
                      float(K.get_value(classifier.optimizer.lr)), batch,
                      epoch)
            if global_step >= 0 and global_step % args.verbose == 0:
                acc_list = []
                for key in valid_datasets:
                    data_set, data_set_flip, is_same_list = valid_datasets[key]
                    embeddings = extractor.predict(data_set)
                    embeddings_flip = extractor.predict(data_set_flip)
                    embeddings_parts = [embeddings, embeddings_flip]
                    x_norm = 0.0
                    x_norm_cnt = 0
                    for part in embeddings_parts:
                        for i in range(part.shape[0]):
                            embedding = part[i]
                            norm = np.linalg.norm(embedding)
                            x_norm += norm
                            x_norm_cnt += 1
                    x_norm /= x_norm_cnt
                    embeddings = embeddings_parts[0] + embeddings_parts[1]
                    embeddings = sklearn.preprocessing.normalize(embeddings)
                    print(embeddings.shape)
                    _, _, accuracy, val, val_std, far = verification.evaluate(
                        embeddings, is_same_list, folds=10)
                    acc, std = np.mean(accuracy), np.std(accuracy)

                    print('[%s][%d]XNorm: %f' % (key, batch, x_norm))
                    print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                          (key, batch, acc, std))
                    acc_list.append(acc)

                if len(acc_list) > 0:
                    score = sum(acc_list)
                    if highest_score == 0:
                        highest_score = score
                    elif highest_score >= score:
                        print('\nStep %05d: score did not improve from %0.5f' %
                              (global_step, highest_score))
                    else:
                        path = ckpt_path.format(step=global_step)
                        print(
                            '\nStep %05d: score improved from %0.5f to %0.5f,'
                            ' saving model to %s' %
                            (global_step, highest_score, score, path))
                        highest_score = score
                        classifier.save_weights(path)

        utils.write_log(tensor_board, train_names, train_results, epoch)
        classifier.reset_metrics()
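
The resume logic near the top of train_net recovers both the epoch and the position inside it from a single restored global step. Worked numbers, with hypothetical values:

# assuming batches_per_epoch = 1000 and a restored global_step of 2500,
# training resumes at epoch 2, batch 500
global_step, batches_per_epoch = 2500, 1000
initial_epoch = global_step // batches_per_epoch  # 2
rest_batches = global_step % batches_per_epoch    # 500
assert (initial_epoch, rest_batches) == (2, 500)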
Example #9
def custom_insightface_evaluation(args):
    tf.reset_default_graph()
    # Read the directory containing images
    pairs = read_pairs(args.insightface_pair)
    image_list, issame_list = get_paths_with_pairs(
        args.insightface_dataset_dir, pairs)

    #  Evaluate custom dataset with facenet pre-trained model
    print("Getting embeddings with facenet pre-trained model")

    # Getting batched images by TF dataset
    tf_dataset = facenet.tf_gen_dataset(
        image_list=image_list,
        label_list=None,
        nrof_preprocess_threads=args.nrof_preprocess_threads,
        image_size=args.insightface_image_size,
        method='cache_slices',
        BATCH_SIZE=args.batch_size,
        repeat_count=1,
        to_float32=True,
        shuffle=False)
    # tf_dataset = facenet.tf_gen_dataset(image_list, label_list, args.nrof_preprocess_threads, args.facenet_image_size, method='cache_slices',
    #                                     BATCH_SIZE=args.batch_size, repeat_count=1, shuffle=False)
    tf_dataset_iterator = tf_dataset.make_initializable_iterator()
    tf_dataset_next_element = tf_dataset_iterator.get_next()

    images = tf.placeholder(name='img_inputs',
                            shape=[
                                None, args.insightface_image_size,
                                args.insightface_image_size, 3
                            ],
                            dtype=tf.float32)
    labels = tf.placeholder(name='img_labels', shape=[
        None,
    ], dtype=tf.int64)
    dropout_rate = tf.placeholder(name='dropout_rate', dtype=tf.float32)

    w_init_method = tf.contrib.layers.xavier_initializer(uniform=False)
    net = L_Resnet_E_IR_fix_issue9.get_resnet(images,
                                              args.net_depth,
                                              type='ir',
                                              w_init=w_init_method,
                                              trainable=False,
                                              keep_rate=dropout_rate)
    embeddings = net.outputs
    # mv_mean = tl.layers.get_variables_with_name('resnet_v1_50/bn0/moving_mean', False, True)[0]
    # 3.2 get arcface loss
    logit = arcface_loss(embedding=net.outputs,
                         labels=labels,
                         w_init=w_init_method,
                         out_num=args.num_output)

    sess = tf.Session()
    saver = tf.train.Saver()

    feed_dict = {}
    feed_dict_flip = {}
    path = args.ckpt_file + args.ckpt_index_list[0]
    saver.restore(sess, path)
    print('ckpt file %s restored!' % args.ckpt_index_list[0])
    feed_dict.update(tl.utils.dict_to_one(net.all_drop))
    feed_dict_flip.update(tl.utils.dict_to_one(net.all_drop))
    feed_dict[dropout_rate] = 1.0
    feed_dict_flip[dropout_rate] = 1.0

    batch_size = args.batch_size
    input_placeholder = images

    sess.run(tf_dataset_iterator.initializer)
    print('getting embeddings..')

    total_time = 0
    batch_number = 0
    embeddings_array = None
    embeddings_array_flip = None
    while True:
        try:
            images = sess.run(tf_dataset_next_element)

            data_tmp = images.copy()  # fix issues #4

            for i in range(data_tmp.shape[0]):
                data_tmp[i, ...] -= 127.5
                data_tmp[i, ...] *= 0.0078125
                data_tmp[i, ...] = cv2.cvtColor(data_tmp[i, ...],
                                                cv2.COLOR_RGB2BGR)

            # Getting flip to left_right batched images by TF dataset
            data_tmp_flip = images.copy()  # fix issues #4
            for i in range(data_tmp_flip.shape[0]):
                data_tmp_flip[i, ...] = np.fliplr(data_tmp_flip[i, ...])
                data_tmp_flip[i, ...] -= 127.5
                data_tmp_flip[i, ...] *= 0.0078125
                data_tmp_flip[i, ...] = cv2.cvtColor(data_tmp_flip[i, ...],
                                                     cv2.COLOR_RGB2BGR)

            start_time = time.time()

            feed_dict[input_placeholder] = data_tmp
            _embeddings = sess.run(embeddings, feed_dict)

            feed_dict_flip[input_placeholder] = data_tmp_flip
            _embeddings_flip = sess.run(embeddings, feed_dict_flip)

            if embeddings_array is None:
                embeddings_array = np.zeros(
                    (len(image_list), _embeddings.shape[1]))
                embeddings_array_flip = np.zeros(
                    (len(image_list), _embeddings_flip.shape[1]))
            try:
                embeddings_array[batch_number *
                                 batch_size:min((batch_number + 1) *
                                                batch_size, len(image_list)),
                                 ...] = _embeddings
                embeddings_array_flip[batch_number * batch_size:min(
                    (batch_number + 1) * batch_size, len(image_list)),
                                      ...] = _embeddings_flip
                # print('try: ', batch_number * batch_size, min((batch_number + 1) * batch_size, len(image_list)), ...)
            except ValueError:
                print(
                    'batch_number*batch_size value is %d min((batch_number+1)*batch_size, len(image_list)) %d,'
                    ' batch_size %d, data.shape[0] %d' %
                    (batch_number * batch_size,
                     min((batch_number + 1) * batch_size,
                         len(image_list)), batch_size, images.shape[0]))
                print('except: ', batch_number * batch_size,
                      min((batch_number + 1) * batch_size, images.shape[0]),
                      ...)

            duration = time.time() - start_time
            batch_number += 1
            total_time += duration
        except tf.errors.OutOfRangeError:
            print(
                'tf.errors.OutOfRangeError, Reinitialize tf_dataset_iterator')
            sess.run(tf_dataset_iterator.initializer)
            break

    print(f"total_time: {total_time}")

    _xnorm = 0.0
    _xnorm_cnt = 0
    for embed in [embeddings_array, embeddings_array_flip]:
        for i in range(embed.shape[0]):
            _em = embed[i]
            _norm = np.linalg.norm(_em)
            # print(_em.shape, _norm)
            _xnorm += _norm
            _xnorm_cnt += 1
    _xnorm /= _xnorm_cnt

    final_embeddings_output = embeddings_array + embeddings_array_flip
    final_embeddings_output = sklearn.preprocessing.normalize(
        final_embeddings_output)
    print(final_embeddings_output.shape)

    tpr, fpr, accuracy, val, val_std, far = verification.evaluate(
        final_embeddings_output, issame_list, nrof_folds=10)
    acc2, std2 = np.mean(accuracy), np.std(accuracy)

    auc = metrics.auc(fpr, tpr)
    print('XNorm: %f' % (_xnorm))
    print('Accuracy-Flip: %1.5f+-%1.5f' % (acc2, std2))
    print('TPR: ', np.mean(tpr), 'FPR: ', np.mean(fpr))
    print('Area Under Curve (AUC): %1.3f' % auc)

    tpr_lfw, fpr_lfw, accuracy_lfw, val_lfw, val_std_lfw, far_lfw = lfw.evaluate(
        final_embeddings_output,
        issame_list,
        nrof_folds=10,
        distance_metric=0,
        subtract_mean=False)

    print('accuracy_lfw: %2.5f+-%2.5f' %
          (np.mean(accuracy_lfw), np.std(accuracy_lfw)))
    print(
        f"val_lfw: {val_lfw}, val_std_lfw: {val_std_lfw}, far_lfw: {far_lfw}")

    print('val_lfw rate: %2.5f+-%2.5f @ FAR=%2.5f' %
          (val_lfw, val_std_lfw, far_lfw))
    auc_lfw = metrics.auc(fpr_lfw, tpr_lfw)
    print('TPR_LFW:', np.mean(tpr_lfw), 'FPR_LFW: ', np.mean(fpr_lfw))

    print('Area Under Curve LFW (AUC): %1.3f' % auc_lfw)

    sess.close()

    return acc2, std2, _xnorm, [embeddings_array, embeddings_array_flip]
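
The per-image preprocessing in Examples #9 and #10 maps uint8 pixel values into roughly [-1, 1]: 0.0078125 is exactly 1/128, so (x - 127.5) * 0.0078125 equals (x - 127.5) / 128. A quick check:

import numpy as np

x = np.array([0., 127.5, 255.])
print((x - 127.5) * 0.0078125)  # approximately [-0.996, 0., 0.996]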
Example #10
def custom_facenet_evaluation(args):
    tf.reset_default_graph()
    # Read the directory containing images
    pairs = read_pairs(args.insightface_pair)
    image_list, issame_list = get_paths_with_pairs(args.facenet_dataset_dir,
                                                   pairs)

    #  Evaluate custom dataset with facenet pre-trained model
    print("Getting embeddings with facenet pre-trained model")
    with tf.Graph().as_default():
        # Getting batched images by TF dataset
        # image_list = path_list
        tf_dataset = facenet.tf_gen_dataset(
            image_list=image_list,
            label_list=None,
            nrof_preprocess_threads=args.nrof_preprocess_threads,
            image_size=args.facenet_image_size,
            method='cache_slices',
            BATCH_SIZE=args.batch_size,
            repeat_count=1,
            to_float32=True,
            shuffle=False)
        tf_dataset_iterator = tf_dataset.make_initializable_iterator()
        tf_dataset_next_element = tf_dataset_iterator.get_next()

        with tf.Session() as sess:
            sess.run(tf_dataset_iterator.initializer)

            phase_train_placeholder = tf.placeholder(tf.bool,
                                                     name='phase_train')

            image_batch = tf.placeholder(name='img_inputs',
                                         shape=[
                                             None, args.facenet_image_size,
                                             args.facenet_image_size, 3
                                         ],
                                         dtype=tf.float32)
            label_batch = tf.placeholder(name='img_labels',
                                         shape=[
                                             None,
                                         ],
                                         dtype=tf.int32)

            # Load the model
            input_map = {
                'image_batch': image_batch,
                'label_batch': label_batch,
                'phase_train': phase_train_placeholder
            }
            facenet.load_model(args.facenet_model, input_map=input_map)

            # Get output tensor
            embeddings = tf.get_default_graph().get_tensor_by_name(
                "embeddings:0")

            batch_size = args.batch_size
            input_placeholder = image_batch

            print('getting embeddings..')

            total_time = 0
            batch_number = 0
            embeddings_array = None
            embeddings_array_flip = None
            while True:
                try:
                    images = sess.run(tf_dataset_next_element)

                    data_tmp = images.copy()  # fix issues #4

                    for i in range(data_tmp.shape[0]):
                        data_tmp[i, ...] -= 127.5
                        data_tmp[i, ...] *= 0.0078125
                        data_tmp[i,
                                 ...] = cv2.cvtColor(data_tmp[i, ...],
                                                     cv2.COLOR_RGB2BGR)

                    # Getting flip to left_right batched images by TF dataset
                    data_tmp_flip = images.copy()  # fix issues #4
                    for i in range(data_tmp_flip.shape[0]):
                        data_tmp_flip[i, ...] = np.fliplr(data_tmp_flip[i,
                                                                        ...])
                        data_tmp_flip[i, ...] -= 127.5
                        data_tmp_flip[i, ...] *= 0.0078125
                        data_tmp_flip[i, ...] = cv2.cvtColor(
                            data_tmp_flip[i, ...], cv2.COLOR_RGB2BGR)

                    start_time = time.time()

                    mr_feed_dict = {
                        input_placeholder: data_tmp,
                        phase_train_placeholder: False
                    }
                    mr_feed_dict_flip = {
                        input_placeholder: data_tmp_flip,
                        phase_train_placeholder: False
                    }
                    _embeddings = sess.run(embeddings, mr_feed_dict)
                    _embeddings_flip = sess.run(embeddings, mr_feed_dict_flip)

                    if embeddings_array is None:
                        embeddings_array = np.zeros(
                            (len(image_list), _embeddings.shape[1]))
                        embeddings_array_flip = np.zeros(
                            (len(image_list), _embeddings_flip.shape[1]))
                    try:
                        embeddings_array[batch_number * batch_size:min(
                            (batch_number + 1) * batch_size, len(image_list)),
                                         ...] = _embeddings
                        embeddings_array_flip[batch_number * batch_size:min(
                            (batch_number + 1) * batch_size, len(image_list)),
                                              ...] = _embeddings_flip
                        # print('try: ', batch_number * batch_size, min((batch_number + 1) * batch_size, len(image_list)), ...)
                    except ValueError:
                        print(
                            'batch_number*batch_size value is %d min((batch_number+1)*batch_size, len(image_list)) %d,'
                            ' batch_size %d, data.shape[0] %d' %
                            (batch_number * batch_size,
                             min((batch_number + 1) * batch_size,
                                 len(image_list)), batch_size,
                             images.shape[0]))
                        print(
                            'except: ', batch_number * batch_size,
                            min((batch_number + 1) * batch_size,
                                images.shape[0]), ...)

                    duration = time.time() - start_time
                    batch_number += 1
                    total_time += duration
                except tf.errors.OutOfRangeError:
                    print(
                        'tf.errors.OutOfRangeError, Reinitialize tf_dataset_iterator'
                    )
                    sess.run(tf_dataset_iterator.initializer)
                    break

    print(f"total_time: {total_time}")

    _xnorm = 0.0
    _xnorm_cnt = 0
    for embed in [embeddings_array, embeddings_array_flip]:
        for i in range(embed.shape[0]):
            _em = embed[i]
            _norm = np.linalg.norm(_em)
            # print(_em.shape, _norm)
            _xnorm += _norm
            _xnorm_cnt += 1
    _xnorm /= _xnorm_cnt

    final_embeddings_output = embeddings_array + embeddings_array_flip
    final_embeddings_output = sklearn.preprocessing.normalize(
        final_embeddings_output)
    print(final_embeddings_output.shape)

    tpr, fpr, accuracy, val, val_std, far = verification.evaluate(
        final_embeddings_output, issame_list, nrof_folds=10)
    acc2, std2 = np.mean(accuracy), np.std(accuracy)

    auc = metrics.auc(fpr, tpr)
    print('XNorm: %f' % (_xnorm))
    print('Accuracy-Flip: %1.5f+-%1.5f' % (acc2, std2))
    print('TPR: ', np.mean(tpr), 'FPR: ', np.mean(fpr))
    print('Area Under Curve (AUC): %1.3f' % auc)

    tpr_lfw, fpr_lfw, accuracy_lfw, val_lfw, val_std_lfw, far_lfw = lfw.evaluate(
        final_embeddings_output,
        issame_list,
        nrof_folds=10,
        distance_metric=0,
        subtract_mean=False)

    print('accuracy_lfw: %2.5f+-%2.5f' %
          (np.mean(accuracy_lfw), np.std(accuracy_lfw)))
    print(
        f"val_lfw: {val_lfw}, val_std_lfw: {val_std_lfw}, far_lfw: {far_lfw}")

    print('val_lfw rate: %2.5f+-%2.5f @ FAR=%2.5f' %
          (val_lfw, val_std_lfw, far_lfw))
    auc_lfw = metrics.auc(fpr_lfw, tpr_lfw)
    print('TPR_LFW:', np.mean(tpr_lfw), 'FPR_LFW: ', np.mean(fpr_lfw))

    print('Area Under Curve LFW (AUC): %1.3f' % auc_lfw)

    return acc2, std2, _xnorm, [embeddings_array, embeddings_array_flip]
Example #11
def perform_val(multi_gpu,
                device,
                embedding_size,
                batch_size,
                backbone,
                carray,
                issame,
                nrof_folds=10,
                tta=True,
                mask=0,
                save_tag=False,
                part=None,
                size=25,
                batch_same=True,
                masknet=None):
    if multi_gpu:
        backbone = backbone.module  # unpackage model from DataParallel
        backbone = backbone.to(device)
    else:
        backbone = backbone.to(device)
    backbone.eval()  # switch to evaluation mode

    if masknet is not None:
        if multi_gpu:
            masknet = masknet.module  # unpackage model from DataParallel
            masknet = masknet.to(device)
        else:
            masknet = masknet.to(device)
        masknet.eval()  # switch to evaluation mode

    idx = 0
    embeddings = np.zeros([len(carray), embedding_size])
    with torch.no_grad():
        while idx + batch_size <= len(carray):
            batch = torch.tensor(carray[idx:idx + batch_size][:, [2, 1, 0], :, :])
            if save_tag is True:
                if mask == 1:
                    batch = random_mask(batch, batch_same=batch_same)
                elif mask == 2:
                    batch = random_mask(batch, part, size)
                imgs = depre_batch(batch)
                for i, img in enumerate(imgs):
                    if part is not None:
                        img.save('./data/lfw_occlusion/' + part + '/' +
                                 str(idx + i) + '.jpg')
                    else:
                        img.save('./data/lfw_noocclusion/' + str(idx + i) +
                                 '.jpg')

            if tta:
                ccropped = ccrop_batch(batch)
                fliped = hflip_batch(ccropped)
                if masknet is not None:
                    if mask == 1:
                        ccropped, ccropped_mask = random_mask(
                            ccropped, batch_same=batch_same, mask_return=True)
                        fliped, fliped_mask = random_mask(
                            fliped, batch_same=batch_same, mask_return=True)
                    elif mask == 2:
                        ccropped, ccropped_mask = random_mask(ccropped,
                                                              part,
                                                              size,
                                                              mask_return=True)
                        fliped, fliped_mask = random_mask(fliped,
                                                          part,
                                                          size,
                                                          mask_return=True)

                    emb_batch = add_mask(
                        ccropped[0::2], ccropped[1::2], ccropped_mask[0::2],
                        ccropped_mask[1::2], masknet, device, backbone,
                        embedding_size, batch_size) + add_mask(
                            fliped[0::2], fliped[1::2], fliped_mask[0::2],
                            fliped_mask[1::2], masknet, device, backbone,
                            embedding_size, batch_size)
                    embeddings[idx:idx + batch_size] = l2_norm(emb_batch.cpu())

                else:
                    if mask == 1:
                        ccropped = random_mask(ccropped, batch_same=batch_same)
                        fliped = random_mask(fliped, batch_same=batch_same)
                    elif mask == 2:
                        ccropped = random_mask(ccropped, part, size)
                        fliped = random_mask(fliped, part, size)
                    emb_batch = backbone(ccropped.to(device)).cpu() + backbone(
                        fliped.to(device)).cpu()
                    embeddings[idx:idx + batch_size] = l2_norm(emb_batch)
            else:
                ccropped = ccrop_batch(batch)
                if masknet is not None:
                    if mask == 1:
                        ccropped, ccropped_mask = random_mask(
                            ccropped, batch_same=batch_same, mask_return=True)
                    elif mask == 2:
                        ccropped, ccropped_mask = random_mask(ccropped,
                                                              part,
                                                              size,
                                                              mask_return=True)
                    emb_batch = add_mask(ccropped[0::2], ccropped[1::2],
                                         ccropped_mask[0::2],
                                         ccropped_mask[1::2], masknet, device,
                                         backbone, embedding_size, batch_size)
                    embeddings[idx:idx + batch_size] = l2_norm(emb_batch.cpu())
                else:
                    if mask == 1:
                        ccropped = random_mask(ccropped, batch_same=batch_same)
                    elif mask == 2:
                        ccropped = random_mask(ccropped, part, size)
                    embeddings[idx:idx + batch_size] = l2_norm(
                        backbone(ccropped.to(device))).cpu()

            idx += batch_size
        if idx < len(carray):
            batch = torch.tensor(carray[idx:][:, [2, 1, 0], :, :])  # apply the same channel swap as the full batches
            if save_tag is True:
                if mask == 1:
                    batch = random_mask(batch, batch_same=batch_same)
                elif mask == 2:
                    batch = random_mask(batch, part, size)
                imgs = depre_batch(batch)
                for i, img in enumerate(imgs):
                    if part is not None:
                        img.save('./data/lfw_occlusion/' + part + '/' +
                                 str(idx + i) + '.jpg')
                    else:
                        img.save('./data/lfw_noocclusion/' + str(idx + i) +
                                 '.jpg')

            if tta:
                ccropped = ccrop_batch(batch)
                fliped = hflip_batch(ccropped)
                if masknet is not None:
                    if mask == 1:
                        ccropped, ccropped_mask = random_mask(
                            ccropped, batch_same=batch_same, mask_return=True)
                        fliped, fliped_mask = random_mask(
                            fliped, batch_same=batch_same, mask_return=True)
                    elif mask == 2:
                        ccropped, ccropped_mask = random_mask(ccropped,
                                                              part,
                                                              size,
                                                              mask_return=True)
                        fliped, fliped_mask = random_mask(fliped,
                                                          part,
                                                          size,
                                                          mask_return=True)
                    emb_batch = add_mask(
                        ccropped[0::2], ccropped[1::2], ccropped_mask[0::2],
                        ccropped_mask[1::2], masknet, device, backbone,
                        embedding_size,
                        len(carray) - idx) + add_mask(
                            fliped[0::2], fliped[1::2], fliped_mask[0::2],
                            fliped_mask[1::2], masknet, device, backbone,
                            embedding_size,
                            len(carray) - idx)
                    embeddings[idx:] = l2_norm(emb_batch.cpu())

                else:
                    if mask == 1:
                        ccropped = random_mask(ccropped, batch_same=batch_same)
                        fliped = random_mask(fliped, batch_same=batch_same)
                    elif mask == 2:
                        ccropped = random_mask(ccropped, part, size)
                        fliped = random_mask(fliped, part, size)
                    emb_batch = backbone(ccropped.to(device)).cpu() + backbone(
                        fliped.to(device)).cpu()
                    embeddings[idx:] = l2_norm(emb_batch)
            else:
                ccropped = ccrop_batch(batch)
                if masknet is not None:
                    if mask == 1:
                        ccropped, ccropped_mask = random_mask(
                            ccropped, batch_same=batch_same, mask_return=True)
                    elif mask == 2:
                        ccropped, ccropped_mask = random_mask(ccropped,
                                                              part,
                                                              size,
                                                              mask_return=True)
                    emb_batch = add_mask(ccropped[0::2], ccropped[1::2],
                                         ccropped_mask[0::2],
                                         ccropped_mask[1::2], masknet, device,
                                         backbone, embedding_size,
                                         len(carray) - idx)
                    embeddings[idx:] = l2_norm(emb_batch.cpu())
                else:
                    if mask == 1:
                        ccropped = random_mask(ccropped, batch_same=batch_same)
                    elif mask == 2:
                        ccropped = random_mask(ccropped, part, size)
                    embeddings[idx:] = l2_norm(backbone(
                        ccropped.to(device))).cpu()

    tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                   nrof_folds)
    buf = gen_plot(fpr, tpr)
    roc_curve = Image.open(buf)
    roc_curve_tensor = transforms.ToTensor()(roc_curve)

    return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor
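
The fancy index carray[idx:idx + batch_size][:, [2, 1, 0], :, :] near the top of this perform_val reverses the channel axis of an NCHW array, i.e. a BGR-to-RGB (or RGB-to-BGR) swap. A minimal demonstration:

import numpy as np

batch = np.arange(2 * 3 * 2 * 2).reshape(2, 3, 2, 2)
swapped = batch[:, [2, 1, 0], :, :]
assert (swapped[:, 0] == batch[:, 2]).all() and (swapped[:, 2] == batch[:, 0]).all()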
Example #12
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}  # guard against logs being None
        result = pt.PrettyTable(
            ["data set", "AUC", "ACC", "VR @ FAR ", "dist max", "dist min"])
        val_all = []
        auc_all = []
        acc_all = []
        dist_all = []
        for i in range(len(self.datadirs)):
            path = FLAGS.valid_dir + "/" + self.datadirs[i]
            embeddings, issamelab = predict_evaluate_data(
                path, [FLAGS.img_size, FLAGS.img_size], self.model,
                FLAGS.embedding_size)

            embeddings1 = embeddings[0::2]
            embeddings2 = embeddings[1::2]
            diff = np.subtract(embeddings1, embeddings2)
            dist = np.sum(np.square(diff), 1)
            del embeddings1, embeddings2
            fpr, tpr, ths = roc_curve(np.asarray(issamelab).astype('int'),
                                      dist,
                                      pos_label=0)
            auc_score = auc(fpr, tpr)
            #  print('\ndataset:',self.datadirs[i])
            #  print('embed:',np.max(embeddings),np.min(embeddings))
            #  print("dist:",np.max(dist),np.min(dist))
            if 0:
                plt.figure()
                lw = 2
                plt.plot(fpr,
                         tpr,
                         color='darkorange',
                         lw=lw,
                         label='ROC curve (area = %0.2f)' % auc_score)
                plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
                plt.xlim([0.0, 1.0])
                plt.ylim([0.0, 1.05])
                plt.xlabel('False Positive Rate')
                plt.ylabel('True Positive Rate')
                plt.title('Receiver operating characteristic example')
                plt.legend(loc="lower right")
                #  plt.show()
                plt.savefig("roc.png")

            #  print('-----10 folds------')
            tpr, fpr, accuracy, val, val_std, far = evaluate(
                embeddings, issamelab)
            del embeddings, issamelab
            gc.collect()
            result.add_row([
                self.datadirs[i],
                round(auc_score, 2),
                "%1.3f+-%1.3f" % (np.mean(accuracy), np.std(accuracy)),
                "%2.5f+-%2.5f @ FAR=%2.5f" % (val, val_std, far),
                round(np.max(dist), 2),
                round(np.min(dist), 2),
            ])
            val_all.append(val)
            acc_all.append(accuracy)
            auc_all.append(auc_score)
            dist_all.append([np.max(dist), np.min(dist)])

        logs['predict_labels_validacc'] = np.mean(np.array(val_all))
        logs['max_dis'] = np.mean(np.array(dist_all), 0)[0]
        logs['min_dis'] = np.mean(np.array(dist_all), 0)[1]
        logs['auc'] = np.mean(np.array(auc_all))
        print(result)
Example #13
        # Preprocessing the data to obtain MFCCs or spectrograms as input to the network
        ppr.preprocess_data(
            cfg.audio_dir, cfg.dev_set_path, cfg.train_data_path, cfg.session
        )  # the split into train and validation sets has to be done manually
        ppr.preprocess_data(cfg.audio_dir, cfg.val_set_path, cfg.val_data_path,
                            cfg.session)
        vs.train(cfg.audio_dir, cfg.train_data_path, cfg.val_data_path)

    # Evaluating the model
    elif cfg.session == 'evaluate':
        print("EVALUATION SESSION...")
        ppr.preprocess_data(cfg.audio_dir, cfg.enroll_set_path,
                            cfg.enroll_data_path, 'enroll')
        ppr.preprocess_data(cfg.audio_dir, cfg.eval_set_path,
                            cfg.eval_data_path, cfg.session)
        vs.evaluate(cfg.audio_dir, cfg.eval_data_path, cfg.enroll_data_path)

    # Enrolling the speakers
    elif cfg.session == 'enroll':
        print("ENROLLMENT SESSION...")
        reply = enr.yes_or_no(
            'Do you want to give input through microphone?(y/n): ')
        if reply:
            enr.get_audio()
            ppr.preprocess_data(cfg.audio_dir, cfg.RT_enroll_set_path,
                                cfg.RT_enroll_data_path, cfg.session)
            vs.enroll(cfg.audio_dir, cfg.RT_enroll_data_path, cfg.session)
        else:
            ppr.preprocess_data(cfg.audio_dir, cfg.enroll_set_path,
                                cfg.enroll_data_path, cfg.session)
            vs.enroll(cfg.audio_dir, cfg.enroll_data_path, cfg.session)
Example #14
def main():
    cur_time = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    print(f'\n\n\n***TRAINING SESSION START AT {cur_time}***\n\n\n')
    with tf.Graph().as_default():
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        args = get_parser()

        # define global params
        global_step = tf.Variable(name='global_step',
                                  initial_value=0,
                                  trainable=False)
        epoch_step = tf.Variable(name='epoch_step',
                                 initial_value=0,
                                 trainable=False)
        epoch = tf.Variable(name='epoch', initial_value=0, trainable=False)

        # def placeholders
        print(f'***Input of size: {args.image_size}')
        print(
            f'***Perform evaluation after each {args.validate_interval} on datasets: {args.eval_datasets}'
        )
        inputs = tf.placeholder(name='img_inputs',
                                shape=[None, *args.image_size, 3],
                                dtype=tf.float32)
        labels = tf.placeholder(name='img_labels',
                                shape=[
                                    None,
                                ],
                                dtype=tf.int64)
        phase_train_placeholder = tf.placeholder_with_default(
            tf.constant(False, dtype=tf.bool), shape=None, name='phase_train')

        # prepare train dataset
        # each image has 127.5 subtracted and is multiplied by 1/128,
        # then randomly flipped left-right
        tfrecords_f = os.path.join(args.tfrecords_file_path, 'train.tfrecords')
        dataset = tf.data.TFRecordDataset(tfrecords_f)
        dataset = dataset.map(parse_function)
        # dataset = dataset.shuffle(buffer_size=args.buffer_size)
        dataset = dataset.batch(args.train_batch_size)
        iterator = dataset.make_initializable_iterator()
        next_element = iterator.get_next()

        # identity the input, for inference
        inputs = tf.identity(inputs, 'input')

        prelogits, net_points = inference(
            inputs,
            bottleneck_layer_size=args.embedding_size,
            phase_train=phase_train_placeholder,
            weight_decay=args.weight_decay)
        # record the network architecture
        hd = open("./arch/txt/MobileFaceNet_architecture.txt", 'w')
        for key in net_points.keys():
            info = '{}:{}\n'.format(key, net_points[key].get_shape().as_list())
            hd.write(info)
        hd.close()

        embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10, name='embeddings')

        # Norm for the prelogits
        eps = 1e-5
        prelogits_norm = tf.reduce_mean(
            tf.norm(tf.abs(prelogits) + eps, ord=args.prelogits_norm_p,
                    axis=1))
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                             prelogits_norm * args.prelogits_norm_loss_factor)

        # inference_loss, logit = cos_loss(prelogits, labels, args.class_number)
        w_init_method = slim.initializers.xavier_initializer()
        if args.loss_type == 'insightface':
            print(
                f'INSIGHTFACE LOSS WITH s={args.margin_s}, m={args.margin_m}')
            inference_loss, logit = insightface_loss(embeddings,
                                                     labels,
                                                     args.class_number,
                                                     w_init_method,
                                                     s=args.margin_s,
                                                     m=args.margin_m)
        elif args.loss_type == 'cosine':
            inference_loss, logit = cosineface_loss(embeddings, labels,
                                                    args.class_number,
                                                    w_init_method)
        elif args.loss_type == 'combine':
            inference_loss, logit = combine_loss(embeddings, labels,
                                                 args.train_batch_size,
                                                 args.class_number,
                                                 w_init_method)
        else:
            assert 0, 'loss type error, choice item just one of [insightface, cosine, combine], please check!'
        tf.add_to_collection('losses', inference_loss)

        # total losses
        regularization_losses = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES)
        total_loss = tf.add_n([inference_loss] + regularization_losses,
                              name='total_loss')

        # define the learning rate schedule
        learning_rate = tf.train.piecewise_constant(
            epoch,
            boundaries=args.lr_schedule,
            values=[0.1, 0.01, 0.001, 0.0001, 0.00001],
            name='lr_schedule')

        # define sess
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=args.log_device_mapping,
                                gpu_options=gpu_options)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

        # calculate accuracy op
        pred = tf.nn.softmax(logit)
        correct_prediction = tf.cast(
            tf.equal(tf.argmax(pred, 1), tf.cast(labels, tf.int64)),
            tf.float32)
        Accuracy_Op = tf.reduce_mean(correct_prediction)

        # summary writer
        summary = tf.summary.FileWriter(args.summary_path, sess.graph)
        summaries = []
        # add train info to tensorboard summary
        summaries.append(tf.summary.scalar('inference_loss', inference_loss))
        summaries.append(tf.summary.scalar('total_loss', total_loss))
        summaries.append(tf.summary.scalar('learning_rate', learning_rate))
        summaries.append(tf.summary.scalar('training_acc', Accuracy_Op))
        summary_op = tf.summary.merge(summaries)

        # train op
        train_op = train(total_loss, global_step, args.optimizer,
                         learning_rate, args.moving_average_decay,
                         tf.global_variables(), summaries, args.log_histograms)
        inc_global_step_op = tf.assign_add(global_step,
                                           1,
                                           name='increment_global_step')
        inc_epoch_step_op = tf.assign_add(epoch_step,
                                          1,
                                          name='increment_epoch_step')
        reset_epoch_step_op = tf.assign(epoch_step, 0, name='reset_epoch_step')
        inc_epoch_op = tf.assign_add(epoch, 1, name='increment_epoch')

        # record trainable variable
        hd = open("./arch/txt/trainable_var.txt", "w")
        for var in tf.trainable_variables():
            hd.write(str(var))
            hd.write('\n')
        hd.close()

        # init all variables
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # RELOAD CHECKPOINT FOR PRETRAINED MODEL
        # pretrained model path
        pretrained_model = None
        if args.pretrained_model:
            pretrained_model = os.path.expanduser(args.pretrained_model)
            print('***Pre-trained model: %s' % pretrained_model)

        if pretrained_model is None:
            # saver to load pretrained model or save model
            saver = tf.train.Saver(tf.trainable_variables() +
                                   [epoch, epoch_step, global_step],
                                   max_to_keep=args.saver_maxkeep)
        else:
            saver = tf.train.Saver(tf.trainable_variables(),
                                   max_to_keep=args.saver_maxkeep)
        # last checkpoint path
        checkpoint_path = None
        if args.ckpt_path:
            ckpts = os.listdir(args.ckpt_path)
            if 'checkpoint' in ckpts:
                ckpts.remove('checkpoint')
            ckpts_prefix = [x.split('_')[0] for x in ckpts]
            ckpts_prefix.sort(key=lambda x: int(x), reverse=True)

            # Get last checkpoint
            if len(ckpts_prefix) > 0:
                last_ckpt = f"{ckpts_prefix[0]}_MobileFaceNet.ckpt"
                checkpoint_path = os.path.expanduser(
                    os.path.join(args.ckpt_path, last_ckpt))
                print('***Last checkpoint: %s' % checkpoint_path)

        # load checkpoint model
        if checkpoint_path is not None:
            print('***Restoring checkpoint: %s' % checkpoint_path)
            saver.restore(sess, checkpoint_path)
        # load pretrained model
        elif pretrained_model:
            print('***Restoring pretrained model: %s' % pretrained_model)
            # ckpt = tf.train.get_checkpoint_state(pretrained_model)
            # print(ckpt)
            saver.restore(sess, pretrained_model)
        else:
            print('***No checkpoint or pretrained model found.')
            print('***Training from scratch')

        # output file path
        if not os.path.exists(args.log_file_path):
            os.makedirs(args.log_file_path)
        if not os.path.exists(args.ckpt_best_path):
            os.makedirs(args.ckpt_best_path)

        # prepare validate datasets
        ver_list = []
        ver_name_list = []
        print('***LOADING VALIDATION DATABASES..')
        for db in args.eval_datasets:
            print('\t- Loading database: %s' % db)
            data_set = load_data(db, args.image_size, args)
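            # load_data returns an (images, issame_list) pair that the
            # validation block below unpacks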
            ver_list.append(data_set)
            ver_name_list.append(db)

        cur_epoch, cur_global_step, cur_epoch_step = sess.run(
            [epoch, global_step, epoch_step])
        print('****************************************')
        print(
            f'Resuming training at EPOCH={cur_epoch}, GLOBAL_STEP={cur_global_step}, EPOCH_STEP={cur_epoch_step}'
        )
        print('****************************************')

        total_losses_per_summary = []
        inference_losses_per_summary = []
        train_acc_per_summary = []
        avg_total_loss_per_summary = 0
        avg_inference_loss_per_summary = 0
        avg_train_acc_per_summary = 0
        for i in range(cur_epoch, args.max_epoch + 1):
            sess.run(iterator.initializer)
            # tf.data iterators cannot seek, so fast-forward by replaying and
            # discarding the batches already trained in this epoch
            print(f'Skipping {cur_epoch_step} already-trained steps...')
            start = time.time()
            for _j in range(cur_epoch_step):
                images_train, labels_train = sess.run(next_element)
                if _j % 1000 == 0:
                    end = time.time()
                    iter_time = end - start
                    start = time.time()
                    print(f'{_j} steps skipped ({iter_time:.2f}s)')
            print('***Training started***')
            while True:
                try:
                    start = time.time()
                    images_train, labels_train = sess.run(next_element)
                    feed_dict = {
                        inputs: images_train,
                        labels: labels_train,
                        phase_train_placeholder: True
                    }
                    _, total_loss_val, inference_loss_val, reg_loss_val, _, acc_val = \
                    sess.run([train_op, total_loss, inference_loss, regularization_losses, inc_epoch_step_op, Accuracy_Op],
                             feed_dict=feed_dict)
                    end = time.time()
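                    # throughput for this step, in samples per second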
                    pre_sec = args.train_batch_size / (end - start)

                    cur_global_step += 1
                    cur_epoch_step += 1

                    total_losses_per_summary.append(total_loss_val)
                    inference_losses_per_summary.append(inference_loss_val)
                    train_acc_per_summary.append(acc_val)

                    # print training information
                    if cur_global_step > 0 and cur_global_step % args.show_info_interval == 0:
                        print(
                            'epoch %d, total_step %d, epoch_step %d, total loss %.2f, inference loss %.2f, reg_loss %.2f, training accuracy %.6f, rate %.3f samples/sec'
                            % (i, cur_global_step, cur_epoch_step,
                               total_loss_val, inference_loss_val,
                               np.sum(reg_loss_val), acc_val, pre_sec))

                    # save summary
                    if cur_global_step > 0 and cur_global_step % args.summary_interval == 0:
                        feed_dict = {
                            inputs: images_train,
                            labels: labels_train,
                            phase_train_placeholder: True
                        }
                        summary_op_val = sess.run(summary_op,
                                                  feed_dict=feed_dict)
                        summary.add_summary(summary_op_val, cur_global_step)
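                        # average the per-step metrics accumulated since the
                        # last summary write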

                        avg_total_loss_per_summary = sum(
                            total_losses_per_summary) / len(
                                total_losses_per_summary)
                        total_losses_per_summary = []
                        avg_inference_loss_per_summary = sum(
                            inference_losses_per_summary) / len(
                                inference_losses_per_summary)
                        inference_losses_per_summary = []
                        avg_train_acc_per_summary = sum(
                            train_acc_per_summary) / len(train_acc_per_summary)
                        train_acc_per_summary = []
                        # create a Summary protobuf for the averaged metrics
                        summary2 = tf.Summary()
                        summary2.value.add(
                            tag='avg_total_loss',
                            simple_value=avg_total_loss_per_summary)
                        summary2.value.add(
                            tag='avg_inference_loss',
                            simple_value=avg_inference_loss_per_summary)
                        summary2.value.add(
                            tag='avg_train_acc',
                            simple_value=avg_train_acc_per_summary)

                        # Add it to the Tensorboard summary writer
                        # Make sure to specify a step parameter to get nice graphs over time
                        summary.add_summary(summary2, cur_global_step)

                    # save ckpt files
                    if cur_global_step > 0 and cur_global_step % args.ckpt_interval == 0:
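                        # keep the '<global_step>_MobileFaceNet.ckpt' naming
                        # that the resume logic above parses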
                        filename = '{:d}_MobileFaceNet'.format(
                            cur_global_step) + '.ckpt'
                        filename = os.path.join(args.ckpt_path, filename)
                        saver.save(sess, filename)

                    # validate
                    if cur_global_step > 0 and cur_global_step % args.validate_interval == 0:
                        print(
                            '-------------------------------------------------'
                        )
                        print('\nIteration', cur_global_step, 'validating...')
                        for db_index in range(len(ver_list)):
                            start_time = time.time()
                            data_sets, issame_list = ver_list[db_index]
                            emb_array = np.zeros(
                                (data_sets.shape[0], args.embedding_size))
                            if data_sets.shape[0] % args.test_batch_size == 0:
                                nrof_batches = data_sets.shape[
                                    0] // args.test_batch_size
                            else:
                                nrof_batches = data_sets.shape[
                                    0] // args.test_batch_size + 1
                            # data_sets stores two images per pair, so
                            # data_sets.shape[0] == 2 * len(issame_list)
                            for index in range(nrof_batches):
                                start_index = index * args.test_batch_size
                                end_index = min(
                                    (index + 1) * args.test_batch_size,
                                    data_sets.shape[0])

                                feed_dict = {
                                    inputs: data_sets[start_index:end_index,
                                                      ...],
                                    phase_train_placeholder: False
                                }
                                emb_array[start_index:end_index, :] = sess.run(
                                    embeddings, feed_dict=feed_dict)

                            tpr, fpr, accuracy, val, val_std, far = evaluate(
                                emb_array,
                                issame_list,
                                nrof_folds=args.eval_nrof_folds)
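                            # accuracy is reported per fold (mean/std printed
                            # below); val, val_std and far feed the
                            # validation-rate line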
                            duration = time.time() - start_time

                            print(
                                "---Total time %.3fs to evaluate %d images of %s"
                                % (duration, data_sets.shape[0],
                                   ver_name_list[db_index]))
                            print('\t- Accuracy: %1.3f+-%1.3f' %
                                  (np.mean(accuracy), np.std(accuracy)))
                            print(
                                '\t- Validation rate: %2.5f+-%2.5f @ FAR=%2.5f'
                                % (val, val_std, far))
                            print('\t- Mean FPR and TPR: %1.3f %1.3f' %
                                  (np.mean(fpr, 0), np.mean(tpr, 0)))

                            auc = metrics.auc(fpr, tpr)
                            print('\t- Area Under Curve (AUC): %1.3f' % auc)
                            # eer = brentq(lambda x: 1. - x - interpolate.interp1d(fpr, tpr)(x), 0., 1.)
                            # print('Equal Error Rate (EER): %1.3f\n' % eer)
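                            # append one line per run: step, mean accuracy,
                            # val, val_std, far, auc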

                            with open(
                                    os.path.join(
                                        args.log_file_path,
                                        '{}_result.txt'.format(
                                            ver_name_list[db_index])),
                                    'at') as f:
                                f.write('%d\t%.5f\t%.5f\t%.5f\t%.5f\t%.5f\n' %
                                        (cur_global_step, np.mean(accuracy),
                                         val, val_std, far, auc))

                            if ver_name_list[db_index] == 'lfw' and np.mean(
                                    accuracy) > 0.994:
                                print('High accuracy: %.5f' %
                                      np.mean(accuracy))
                                filename = 'MobileFaceNet_iter_best_{:d}'.format(
                                    cur_global_step) + '.ckpt'
                                filename = os.path.join(
                                    args.ckpt_best_path, filename)
                                saver.save(sess, filename)
                            print(
                                '---------------------------------------------------'
                            )

                except tf.errors.OutOfRangeError:
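                    # the dataset iterator raises OutOfRangeError once the
                    # epoch's data is exhausted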
                    _, _ = sess.run([inc_epoch_op, reset_epoch_step_op])
                    # Save checkpoint
                    filename = '{:d}_MobileFaceNet'.format(
                        cur_global_step) + '.ckpt'
                    filename = os.path.join(args.ckpt_path, filename)
                    saver.save(sess, filename)
                    cur_epoch_step = 0
                    print("\n\n-------End of epoch %d\n\n" % i)
                    break