Code Example #1
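# Benchmarks TFRecord read throughput on the MS1M face dataset: iterates 100
# batches, prints tensor shapes, optionally previews decoded images in an
# OpenCV window, and reports frames per second at the end.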
def main(_):
    if FLAGS.binary_img:
        tfrecord_name = './data/ms1m_bin.tfrecord'
    else:
        tfrecord_name = './data/ms1m.tfrecord'

    train_dataset = load_tfrecord_dataset(tfrecord_name,
                                          FLAGS.batch_size,
                                          binary_img=FLAGS.binary_img,
                                          is_ccrop=FLAGS.is_ccrop)

    num_samples = 100
    start_time = time.time()
    for idx, parsed_record in enumerate(train_dataset.take(num_samples)):
        (x_train, _), y_train = parsed_record
        print("{} x_train: {}, y_train: {}".format(idx, x_train.shape,
                                                   y_train.shape))

        if FLAGS.visualization:
            recon_img = np.array(x_train[0].numpy() * 255, 'uint8')
            cv2.imshow('img', cv2.cvtColor(recon_img, cv2.COLOR_RGB2BGR))

            if cv2.waitKey(0) == ord('q'):  # keycode 113
                exit()

    print("data fps: {:.2f}".format(num_samples / (time.time() - start_time)))
Code Example #2
File: utils.py  Project: Pol22/esrgan-tf2
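# Thin wrapper for an ESRGAN training set: loads a binary-image TFRecord with
# random flip and rotation augmentation enabled.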
def load_dataset(path, size, batch, shuffle=True, buffer_size=10240):
    """load dataset"""
    dataset = load_tfrecord_dataset(tfrecord_name=path,
                                    batch_size=batch,
                                    size=size,
                                    shuffle=shuffle,
                                    using_bin=True,
                                    using_flip=True,
                                    using_rot=True,
                                    buffer_size=buffer_size)
    return dataset
Code Example #3
File: utils.py  Project: zjzhou521/nero_3_SRDenseNet
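# Config-driven loader for a super-resolution dataset: the path, ground-truth
# patch size, scale factor, and augmentation flags all come from the config
# entry selected by `key`.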
def load_dataset(cfg, key, shuffle=True, buffer_size=10240):
    """load dataset"""
    dataset_cfg = cfg[key]
    logging.info("load {} from {}".format(key, dataset_cfg['path']))
    dataset = load_tfrecord_dataset(tfrecord_name=dataset_cfg['path'],
                                    batch_size=cfg['batch_size'],
                                    gt_size=cfg['gt_size'],
                                    scale=cfg['scale'],
                                    shuffle=shuffle,
                                    using_bin=dataset_cfg['using_bin'],
                                    using_flip=dataset_cfg['using_flip'],
                                    using_rot=dataset_cfg['using_rot'],
                                    buffer_size=buffer_size)
    return dataset
Code Example #4
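# Reads 100 batches of the DIV2K800 sub-image TFRecord (binary or raw
# encoding, selected by a flag), optionally displays input/label pairs, and
# reports read throughput.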
def main(_):
    if FLAGS.using_bin:
        train_dataset = load_tfrecord_dataset(
            './data/DIV2K800_sub_bin.tfrecord', 16, 128, 4,
            using_bin=True, using_flip=True, using_rot=False, buffer_size=10)
    else:
        train_dataset = load_tfrecord_dataset(
            './data/DIV2K800_sub.tfrecord', 16, 128, 4,
            using_bin=False, using_flip=True, using_rot=False, buffer_size=10)

    num_samples = 100
    start_time = time.time()
    for idx, (inputs, labels) in enumerate(train_dataset.take(num_samples)):
        print("{} inputs:".format(idx), inputs.shape, "outputs:", labels.shape)

        if FLAGS.visualization:
            cv2.imshow('inputs', cv2.cvtColor(inputs[0].numpy(), cv2.COLOR_RGB2BGR))
            cv2.imshow('labels', cv2.cvtColor(labels[0].numpy(), cv2.COLOR_RGB2BGR))

            if cv2.waitKey(0) == ord('q'):
                exit()

    print("data fps: {:.2f}".format(num_samples / (time.time() - start_time)))
Code Example #5
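# Detection-dataset loader: encodes ground-truth boxes against the given
# anchor priors using the matching/ignore thresholds and box variances from
# the config.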
def load_dataset(cfg, priors, shuffle=True, buffer_size=10240):
    """load dataset"""
    logging.info("load dataset from {}".format(cfg['dataset_path']))
    dataset = load_tfrecord_dataset(tfrecord_name=cfg['dataset_path'],
                                    batch_size=cfg['batch_size'],
                                    img_dim=cfg['input_size'],
                                    using_bin=cfg['using_bin'],
                                    using_flip=cfg['using_flip'],
                                    using_distort=cfg['using_distort'],
                                    using_encoding=True,
                                    priors=priors,
                                    match_thresh=cfg['match_thresh'],
                                    ignore_thresh=cfg['ignore_thresh'],
                                    variances=cfg['variances'],
                                    shuffle=shuffle,
                                    buffer_size=buffer_size)
    return dataset
Code Example #6
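# Functionally identical to Code Example #5; only the quoting style and
# argument layout differ.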
def load_dataset(cfg, priors, shuffle=True, buffer_size=10240):
    """load dataset"""
    logging.info("load dataset from {}".format(cfg["dataset_path"]))
    dataset = load_tfrecord_dataset(
        tfrecord_name=cfg["dataset_path"],
        batch_size=cfg["batch_size"],
        img_dim=cfg["input_size"],
        using_bin=cfg["using_bin"],
        using_flip=cfg["using_flip"],
        using_distort=cfg["using_distort"],
        using_encoding=True,
        priors=priors,
        match_thresh=cfg["match_thresh"],
        ignore_thresh=cfg["ignore_thresh"],
        variances=cfg["variances"],
        shuffle=shuffle,
        buffer_size=buffer_size,
    )
    return dataset
Code Example #7
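# Split-aware loader: the 'train' split uses the configured batch size,
# shuffling, augmentation, and target encoding, while evaluation fixes the
# batch size at 1 and disables all of them. A Horovod handle (hvd) is passed
# through for distributed input pipelines.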
def load_dataset(cfg, priors, split, hvd):
    """load dataset"""
    logging.info("load dataset from {}".format(cfg['dataset_root']))

    if split == 'train':
        batch_size = cfg['batch_size']
        shuffle = True
        using_flip = cfg['using_flip']
        using_distort = cfg['using_distort']
        using_encoding = True
        buffer_size = 2000
        number_cycles = cfg['number_cycles']
        threads = tf.data.experimental.AUTOTUNE
    else:
        batch_size = 1
        shuffle = False
        using_flip = False
        using_distort = False
        using_encoding = False
        buffer_size = 2000
        number_cycles = 1
        threads = tf.data.experimental.AUTOTUNE

    dataset = load_tfrecord_dataset(dataset_root=cfg['dataset_root'],
                                    split=split,
                                    threads=threads,
                                    number_cycles=number_cycles,
                                    batch_size=batch_size,
                                    hvd=hvd,
                                    img_dim=cfg['input_size'],
                                    using_bin=cfg['using_bin'],
                                    using_flip=using_flip,
                                    using_distort=using_distort,
                                    using_encoding=using_encoding,
                                    priors=priors,
                                    match_thresh=cfg['match_thresh'],
                                    ignore_thresh=cfg['ignore_thresh'],
                                    variances=cfg['variances'],
                                    shuffle=shuffle,
                                    buffer_size=buffer_size)
    return dataset
Code Example #8
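# End-to-end ArcFace training script: builds the model, resumes from the
# latest checkpoint if one exists, then trains either with a manual eager
# GradientTape loop (with TensorBoard logging) or with Keras model.fit.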
def main(_):
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu

    logger = tf.get_logger()
    logger.disabled = True
    logger.setLevel(logging.FATAL)
    set_memory_growth()

    cfg = load_yaml(FLAGS.cfg_path)

    model = ArcFaceModel(size=cfg['input_size'],
                         backbone_type=cfg['backbone_type'],
                         num_classes=cfg['num_classes'],
                         head_type=cfg['head_type'],
                         embd_shape=cfg['embd_shape'],
                         w_decay=cfg['w_decay'],
                         training=True)
    model.summary(line_length=80)

    if cfg['train_dataset']:
        logging.info("load ms1m dataset.")
        dataset_len = cfg['num_samples']
        steps_per_epoch = dataset_len // cfg['batch_size']
        train_dataset = dataset.load_tfrecord_dataset(cfg['train_dataset'],
                                                      cfg['batch_size'],
                                                      cfg['binary_img'],
                                                      is_ccrop=cfg['is_ccrop'])
    else:
        logging.info("load fake dataset.")
        dataset_len = 1
        steps_per_epoch = 1
        train_dataset = dataset.load_fake_dataset(cfg['input_size'])

    learning_rate = tf.constant(cfg['base_lr'])
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
                                        momentum=0.9,
                                        nesterov=True)
    loss_fn = SoftmaxLoss()

    ckpt_path = tf.train.latest_checkpoint('./checkpoints/' + cfg['sub_name'])
    if ckpt_path is not None:
        print("[*] load ckpt from {}".format(ckpt_path))
        model.load_weights(ckpt_path)
        epochs, steps = get_ckpt_inf(ckpt_path, steps_per_epoch)
    else:
        print("[*] training from scratch.")
        epochs, steps = 1, 1

    if FLAGS.mode == 'eager_tf':
        # Eager mode is great for debugging
        # Non eager graph mode is recommended for real training
        summary_writer = tf.summary.create_file_writer('./logs/' +
                                                       cfg['sub_name'])

        train_dataset = iter(train_dataset)

        while epochs <= cfg['epochs']:
            inputs, labels = next(train_dataset)

            with tf.GradientTape() as tape:
                logits = model(inputs, training=True)
                reg_loss = tf.reduce_sum(model.losses)
                pred_loss = loss_fn(labels, logits)
                total_loss = pred_loss + reg_loss

            grads = tape.gradient(total_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            if steps % 5 == 0:
                verb_str = "Epoch {}/{}: {}/{}, loss={:.2f}, lr={:.4f}"
                print(
                    verb_str.format(epochs, cfg['epochs'],
                                    steps % steps_per_epoch, steps_per_epoch,
                                    total_loss.numpy(), learning_rate.numpy()))

                with summary_writer.as_default():
                    tf.summary.scalar('loss/total loss',
                                      total_loss,
                                      step=steps)
                    tf.summary.scalar('loss/pred loss', pred_loss, step=steps)
                    tf.summary.scalar('loss/reg loss', reg_loss, step=steps)
                    tf.summary.scalar('learning rate',
                                      optimizer.lr,
                                      step=steps)

            if steps % cfg['save_steps'] == 0:
                print('[*] save ckpt file!')
                model.save_weights('checkpoints/{}/e_{}_b_{}.ckpt'.format(
                    cfg['sub_name'], epochs, steps % steps_per_epoch))

            steps += 1
            epochs = steps // steps_per_epoch + 1
    else:
        model.compile(optimizer=optimizer,
                      loss=loss_fn,
                      run_eagerly=(FLAGS.mode == 'eager_fit'))

        mc_callback = ModelCheckpoint(
            'checkpoints/' + cfg['sub_name'] + '/e_{epoch}_b_{batch}.ckpt',
            save_freq=cfg['save_steps'] * cfg['batch_size'],
            verbose=1,
            save_weights_only=True)
        tb_callback = TensorBoard(log_dir='logs/',
                                  update_freq=cfg['batch_size'] * 5,
                                  profile_batch=0)
        tb_callback._total_batches_seen = steps
        tb_callback._samples_seen = steps * cfg['batch_size']
        callbacks = [mc_callback, tb_callback]

        history = model.fit(train_dataset,
                            epochs=cfg['epochs'],
                            steps_per_epoch=steps_per_epoch,
                            callbacks=callbacks,
                            initial_epoch=epochs - 1)

    print("[*] training done!")
Code Example #9
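# Fine-tunes ArcFace on masked faces (only the top 62 of 112 pixel rows are
# kept) and periodically evaluates rank-1 identification accuracy on an AR
# gallery/probe set using a faiss IndexFlatL2 nearest-neighbor search,
# checkpointing whenever the accuracy improves.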
def main(_):
    set_memory_growth()

    cfg = load_yaml(FLAGS.cfg_path)
    model = ArcFaceModel(size=cfg['input_size'],
                         backbone_type=cfg['backbone_type'],
                         num_classes=cfg['num_classes'],
                         head_type=cfg['head_type'],
                         embd_shape=cfg['embd_shape'],
                         w_decay=cfg['w_decay'],
                         training=True)
    learning_rate = tf.constant(cfg['base_lr'])
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
                                        momentum=0.9,
                                        nesterov=True)
    loss_fn = SoftmaxLoss()

    ckpt_path = tf.train.latest_checkpoint('./checkpoints/train_' +
                                           cfg['sub_name'])
    if ckpt_path is not None:
        print("[*] load ckpt from {}".format(ckpt_path))
        model.load_weights(ckpt_path)
    else:
        print("[*] training from scratch.")
    model.compile(optimizer=optimizer, loss=loss_fn)

    # resnet_model = tf.keras.Model(inputs=model.get_layer('resnet50').input,
    #                                 outputs=model.get_layer('resnet50').output)
    # resnet_head = tf.keras.Model(inputs=resnet_model.input,
    #                                 outputs=resnet_model.get_layer('conv2_block1_add').output)
    # resnet_tail = split(resnet_model, 18, 1000) # conv2_block1_out
    # output_model = tf.keras.Model(inputs=model.get_layer('OutputLayer').input,
    #                                 outputs=model.get_layer('OutputLayer').output)
    # archead = tf.keras.Model(inputs=model.get_layer('ArcHead').input,
    #                                 outputs=model.get_layer('ArcHead').output)

    temp1 = np.ones((62, 112, 3))
    temp2 = np.zeros((50, 112, 3))
    masked_img = np.concatenate([temp1, temp2], axis=0).reshape(1, 112, 112, 3)

    temp1 = np.ones((1, 18, 28, 256))
    temp2 = np.zeros((1, 10, 28, 256))
    masked = np.concatenate([temp1, temp2], axis=1)
    # inputs = Input((112, 112, 3))
    # labels = Input([])
    # s = resnet_head(inputs)
    # s = tf.keras.layers.Multiply()([s, tf.constant(masked, dtype=tf.float32)])
    # s = resnet_tail(s)
    # s = output_model(s)
    # s = archead([s, labels])
    # new_model = Model((inputs, labels), s)
    # new_model.summary()

    # new_model.compile(optimizer=optimizer, loss=loss_fn)

    path_to_data = '/home/anhdq23/Desktop/nguyen/data/AR/test2'
    anchor_names = get_gallery_pr2(path_to_data)  # From 1 to 100
    name_dicts = get_probe_pr2(
        path_to_data)  # Dictionary: {anchor_name:[name_image, ...]}

    augment = ImgAugTransform()
    import faiss

    if FLAGS.mode == 'eager_tf':
        top_left_all = [0.012]
        best_acc = 0.8
        for epochs in range(cfg['epochs']):
            logging.info("Shuffle ms1m dataset.")
            dataset_len = cfg['num_samples']
            steps_per_epoch = dataset_len // cfg['batch_size']
            train_dataset = dataset.load_tfrecord_dataset(
                cfg['train_dataset'],
                cfg['batch_size'],
                cfg['binary_img'],
                is_ccrop=cfg['is_ccrop'])

            sub_train_dataset = dataset.load_tfrecord_dataset(
                cfg['train_dataset'],
                cfg['batch_size'],
                cfg['binary_img'],
                is_ccrop=cfg['is_ccrop'])

            for batch, ((x, y), (sub_x, sub_y)) in enumerate(
                    zip(train_dataset, sub_train_dataset)):
                x0_new = np.array(x[0], dtype=np.uint8)
                x1_new = np.array(x[1], dtype=np.float32)
                for i in np.arange(len(x0_new)):
                    x0_new[i] = augment(x0_new[i])
                temp = np.array(x0_new, dtype=np.float32) / 255.0
                temp = np.multiply(temp, masked_img)

                sub_x0_new = np.array(sub_x[0], dtype=np.uint8)
                sub_x1_new = np.array(sub_x[1], dtype=np.float32)
                for i in np.arange(len(sub_x0_new)):
                    sub_x0_new[i] = augment(sub_x0_new[i])
                sub_temp = np.array(sub_x0_new, dtype=np.float32) / 255.0

                model.train_on_batch(*((sub_temp, sub_x1_new), sub_x1_new))
                loss = model.train_on_batch(*((temp, x1_new), x1_new))

                if batch % 50 == 0:
                    verb_str = "Epoch {}/{}: {}/{}, loss={:.6f}, lr={:.6f}"
                    print(
                        verb_str.format(
                            epochs, cfg['epochs'], batch, steps_per_epoch,
                            loss, cfg['base_lr'] / (1.0 + cfg['w_decay'] *
                                                    (epochs * 45489 + batch))))

                    if batch % cfg['save_steps'] == 0:
                        resnet_model = tf.keras.Model(
                            inputs=model.get_layer('resnet50').input,
                            outputs=model.get_layer('resnet50').output)

                        output_model = tf.keras.Model(
                            inputs=model.get_layer('OutputLayer').input,
                            outputs=model.get_layer('OutputLayer').output)

                        database_image_names = []
                        database_feature_list = []
                        for anchor_name in anchor_names:
                            img1 = Image.open(
                                os.path.join(path_to_data, anchor_name))
                            img1 = img1.resize((112, 112))
                            img1 = np.array(img1) / 255.0
                            img1 = np.multiply(img1, masked_img)

                            fc1 = resnet_model.predict(
                                img1.reshape((1, 112, 112, 3)))
                            fc1 = output_model.predict(fc1)
                            norm_fc1 = preprocessing.normalize(
                                fc1.reshape((1, 512)), norm='l2', axis=1)
                            database_image_names.append(anchor_name)
                            database_feature_list.append(norm_fc1)

                        index_flat = faiss.IndexFlatL2(512)
                        gpu_index_flat = index_flat
                        gpu_index_flat.add(
                            np.array(database_feature_list).reshape(
                                (-1, 512)))  # add vectors to the index
                        count = 0
                        for key in list(name_dicts.keys()):
                            for name in name_dicts[key]:
                                img2 = Image.open(
                                    os.path.join(path_to_data, name)).resize(
                                        (112, 112))
                                img2 = np.array(img2) / 255.0
                                img2 = np.multiply(img2, masked_img)

                                fc2 = resnet_model.predict(
                                    img2.reshape((1, 112, 112, 3)))
                                fc2 = output_model.predict(fc2)
                                norm_fc2 = preprocessing.normalize(
                                    fc2.reshape((1, 512)), norm='l2', axis=1)

                                D, I = gpu_index_flat.search(
                                    norm_fc2, k=1)  # actual search
                                top_match = database_image_names[I[0][0]]
                                if name[0:5] == top_match[0:5]:
                                    count += 1
                        acc = count / 600.0
                        if acc > best_acc:
                            best_acc = acc
                            print('[*] save ckpt file!')
                            model.save_weights(
                                'checkpoints/{}/e_{}_b_{}.ckpt'.format(
                                    cfg['sub_name'], epochs,
                                    batch % steps_per_epoch))
                        print("The current acc: {:.6f} ".format(acc))
                        print("The best acc: {:.6f} ".format(best_acc))

                    model.save_weights('checkpoints/train_{}/{}.ckpt'.format(
                        cfg['sub_name'], cfg['sub_name']))
Code Example #10
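# Training-setup excerpt: model construction, dataset selection (real ms1m
# TFRecord or a fake placeholder dataset), SGD optimizer with Nesterov
# momentum, and checkpoint restore. The snippet is truncated mid-function.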
    model = ArcFaceModel(size=cfg['input_size'],
                         backbone_type=cfg['backbone_type'],
                         num_classes=cfg['num_classes'],
                         head_type=cfg['head_type'],
                         embd_shape=cfg['embd_shape'],
                         w_decay=cfg['w_decay'],
                         training=True)
    model.summary(line_length=80)

    if cfg['train_dataset']:
        logging.info("load ms1m dataset.")
        dataset_len = cfg['num_samples']
        steps_per_epoch = dataset_len // cfg['batch_size']
        train_dataset = dataset.load_tfrecord_dataset(
            cfg['train_dataset'], cfg['batch_size'], cfg['binary_img'],
            is_ccrop=cfg['is_ccrop'])
    else:
        logging.info("load fake dataset.")
        dataset_len = 1
        steps_per_epoch = 1
        train_dataset = dataset.load_fake_dataset(cfg['input_size'])

    learning_rate = tf.constant(cfg['base_lr'])
    optimizer = tf.keras.optimizers.SGD(
        learning_rate=learning_rate, momentum=0.9, nesterov=True)
    loss_fn = SoftmaxLoss()

    ckpt_path = tf.train.latest_checkpoint('./checkpoints/' + cfg['sub_name'])
    if ckpt_path is not None:
        print("[*] load ckpt from {}".format(ckpt_path))
Code Example #11
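# Sanity check for WiderFace TFRecords: iterates the first 100 batches,
# printing shapes and optionally visualizing decoded labels. The snippet is
# truncated mid-loop.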
visualization = True  # False for time cost estimation
using_encoding = True  # batch size should be 1 when False
variances = [0.1, 0.2]
match_thresh = 0.45
ignore_thresh = 0.3
num_samples = 100

if using_bin:
    tfrecord_name = './data/widerface_train_bin.tfrecord'
else:
    tfrecord_name = './data/widerface_train.tfrecord'

train_dataset = load_tfrecord_dataset(
    tfrecord_name, batch_size, img_dim=640, using_bin=using_bin,
    using_flip=True, using_distort=False, using_encoding=using_encoding,
    priors=priors, match_thresh=match_thresh, ignore_thresh=ignore_thresh,
    variances=variances, shuffle=False)

start_time = time.time()
for idx, (inputs, labels) in enumerate(train_dataset.take(num_samples)):
    print("{} inputs:".format(idx), inputs.shape, "labels:", labels.shape)

    if not visualization:
        continue

    img = np.clip(inputs.numpy()[0], 0, 255).astype(np.uint8)
    if not using_encoding:
        # labels includes loc, landm, landm_valid.
        targets = labels.numpy()[0]
        for target in targets:
Code Example #12
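# Verifies a WiderFace-style TFRecord by drawing decoded boxes and landmarks
# (plus matched anchors when labels are encoded) and writing the
# visualizations to the output directory.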
def main(_):

    min_sizes = [[16, 32], [64, 128], [256, 512]]
    steps = [8, 16, 32]
    clip = False

    img_dim = 640
    priors = prior_box((img_dim, img_dim), min_sizes, steps, clip)

    variances = [0.1, 0.2]
    match_thresh = 0.45
    ignore_thresh = 0.3
    batch_size = 1
    shuffle = True
    using_flip = True
    using_distort = True
    using_bin = True
    buffer_size = 4000
    number_cycles = 2
    threads = 2

    check_dataset = load_tfrecord_dataset(dataset_root=FLAGS.dataset_path,
                                          split=FLAGS.split,
                                          threads=threads,
                                          number_cycles=number_cycles,
                                          batch_size=batch_size,
                                          hvd=[],
                                          img_dim=img_dim,
                                          using_bin=using_bin,
                                          using_flip=using_flip,
                                          using_distort=using_distort,
                                          using_encoding=FLAGS.using_encoding,
                                          priors=priors,
                                          match_thresh=match_thresh,
                                          ignore_thresh=ignore_thresh,
                                          variances=variances,
                                          shuffle=shuffle,
                                          buffer_size=buffer_size)

    for idx, (inputs, labels, _) in enumerate(check_dataset):
        print("{} inputs:".format(idx), inputs.shape, "labels:", labels.shape)

        if not FLAGS.visualization:
            continue

        img = np.clip(inputs.numpy()[0], 0, 255).astype(np.uint8)
        if not FLAGS.using_encoding:
            # labels includes loc, landm, landm_valid.
            targets = labels.numpy()[0]
            for target in targets:
                draw_bbox_landm(img, target, img_dim, img_dim)
        else:
            # labels includes loc, landm, landm_valid, conf.
            targets = decode_tf(labels[0], priors, variances=variances).numpy()
            for prior_index in range(len(targets)):
                if targets[prior_index][-1] != 1:
                    continue

                draw_bbox_landm(img, targets[prior_index], img_dim, img_dim)
                draw_anchor(img, priors[prior_index], img_dim, img_dim)

        cv2.imwrite('{}/{}.png'.format(FLAGS.output_path, str(idx)),
                    img[:, :, ::-1])
Code Example #13
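# Same check as Code Example #12, but driven by command-line FLAGS, shown in
# an OpenCV window instead of written to disk, and timed to report read fps.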
def main(_):
    min_sizes = [[16, 32], [64, 128], [256, 512]]
    steps = [8, 16, 32]
    clip = False

    img_dim = 640
    priors = prior_box((img_dim, img_dim), min_sizes, steps, clip)

    variances = [0.1, 0.2]
    match_thresh = 0.45
    ignore_thresh = 0.3
    num_samples = 100

    if FLAGS.using_encoding:
        assert FLAGS.batch_size == 1

    if FLAGS.using_bin:
        tfrecord_name = './data/widerface_train_bin.tfrecord'
    else:
        tfrecord_name = './data/widerface_train.tfrecord'

    train_dataset = load_tfrecord_dataset(tfrecord_name,
                                          FLAGS.batch_size,
                                          img_dim=640,
                                          using_bin=FLAGS.using_bin,
                                          using_flip=True,
                                          using_distort=False,
                                          using_encoding=FLAGS.using_encoding,
                                          priors=priors,
                                          match_thresh=match_thresh,
                                          ignore_thresh=ignore_thresh,
                                          variances=variances,
                                          shuffle=False)

    start_time = time.time()
    for idx, (inputs, labels) in enumerate(train_dataset.take(num_samples)):
        print("{} inputs:".format(idx), inputs.shape, "labels:", labels.shape)

        if not FLAGS.visualization:
            continue

        img = np.clip(inputs.numpy()[0], 0, 255).astype(np.uint8)
        if not FLAGS.using_encoding:
            # labels includes loc, landm, landm_valid.
            targets = labels.numpy()[0]
            for target in targets:
                draw_bbox_landm(img, target, img_dim, img_dim)
        else:
            # labels includes loc, landm, landm_valid, conf.
            targets = decode_tf(labels[0], priors, variances=variances).numpy()
            for prior_index in range(len(targets)):
                if targets[prior_index][-1] != 1:
                    continue

                draw_bbox_landm(img, targets[prior_index], img_dim, img_dim)
                draw_anchor(img, priors[prior_index], img_dim, img_dim)

        cv2.imshow('img', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        if cv2.waitKey(0) == ord('q'):
            exit()

    print("data fps: {:.2f}".format(num_samples / (time.time() - start_time)))
Code Example #14
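# Fine-tunes ArcFace with image augmentation and periodically evaluates on
# LFW verification pairs, scanning thresholds for the ROC "top-left" distance
# sqrt((1 - TPR)^2 + FPR^2) and checkpointing whenever it improves.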
def main(_):
    set_memory_growth()

    cfg = load_yaml(FLAGS.cfg_path)
    model = ArcFaceModel(size=cfg['input_size'],
                         backbone_type=cfg['backbone_type'],
                         num_classes=cfg['num_classes'],
                         head_type=cfg['head_type'],
                         embd_shape=cfg['embd_shape'],
                         w_decay=cfg['w_decay'],
                         training=True)
    model.summary()

    learning_rate = tf.constant(cfg['base_lr'])
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
                                        momentum=0.9,
                                        nesterov=True)
    loss_fn = SoftmaxLoss()
    ckpt_path = tf.train.latest_checkpoint('./checkpoints/' + cfg['sub_name'])
    if ckpt_path is not None:
        print("[*] load ckpt from {}".format(ckpt_path))
        model.load_weights(ckpt_path)
    else:
        print("[*] training from scratch.")

    model.compile(optimizer=optimizer, loss=loss_fn)

    data_path = 'data'
    lfw, lfw_issame = get_val_pair(data_path, 'lfw_align_112/lfw')
    lfw = np.transpose(lfw, [0, 2, 3, 1]) * 0.5 + 0.5

    image_1 = lfw[0::2]
    image_2 = lfw[1::2]

    augment = ImgAugTransform()
    if FLAGS.mode == 'eager_tf':
        top_left_all = [0.008807]
        for epochs in range(cfg['epochs']):
            logging.info("Shuffle ms1m dataset.")
            dataset_len = cfg['num_samples']
            steps_per_epoch = dataset_len // cfg['batch_size']
            train_dataset = dataset.load_tfrecord_dataset(
                cfg['train_dataset'],
                cfg['batch_size'],
                cfg['binary_img'],
                is_ccrop=cfg['is_ccrop'])

            for batch, (x, y) in enumerate(train_dataset):
                x0_new = np.array(x[0], dtype=np.uint8)
                x1_new = np.array(x[1], dtype=np.float32)
                for i in np.arange(len(x0_new)):
                    x0_new[i] = augment(x0_new[i])
                temp = np.array(x0_new, dtype=np.float32) / 255.0

                loss = model.train_on_batch(*((temp, x1_new), x1_new))

                if batch % 50 == 0:
                    verb_str = "Epoch {}/{}: {}/{}, loss={:.6f}, lr={:.6f}"
                    print(
                        verb_str.format(
                            epochs, cfg['epochs'], batch, steps_per_epoch,
                            loss, cfg['base_lr'] / (1.0 + cfg['w_decay'] *
                                                    (epochs * 45489 + batch))))

                    if batch % cfg['save_steps'] == 0:
                        resnet_model = tf.keras.Model(
                            inputs=model.get_layer('resnet50').input,
                            outputs=model.get_layer('resnet50').output)

                        output_model = tf.keras.Model(
                            inputs=model.get_layer('OutputLayer').input,
                            outputs=model.get_layer('OutputLayer').output)

                        dist_all = []
                        top_left_batch = []
                        for idx in range(0, len(lfw_issame),
                                         cfg['batch_size']):
                            tem = resnet_model.predict(
                                image_1[idx:idx + cfg['batch_size']])
                            embeds_1 = output_model.predict(tem)
                            norm_embeds_1 = preprocessing.normalize(embeds_1,
                                                                    norm='l2',
                                                                    axis=1)

                            tem = resnet_model.predict(
                                image_2[idx:idx + cfg['batch_size']])
                            embeds_2 = output_model.predict(tem)
                            norm_embeds_2 = preprocessing.normalize(embeds_2,
                                                                    norm='l2',
                                                                    axis=1)

                            diff = np.subtract(norm_embeds_1, norm_embeds_2)
                            dist = np.sqrt(np.sum(np.square(diff), 1)) / 2
                            dist_all.extend(dist)

                        thresholds = np.arange(0, 1, 0.01)
                        for thr in thresholds:
                            tpr, fpr, _ = calculate_accuracy(
                                thr, np.array(dist_all), lfw_issame)
                            top_left = np.sqrt((1 - tpr)**2 + fpr**2)
                            top_left_batch.append(top_left)
                        print(
                            "The current top left: {:.6f}     Threshold: {:.2f}"
                            .format(np.min(top_left_batch),
                                    0.01 * np.argmin(top_left_batch)))

                        if not len(top_left_all):
                            print(
                                "The best top left: {:.6f}     Threshold: {:.2f}"
                                .format(np.min(top_left_batch),
                                        0.01 * np.argmin(top_left_batch)))
                        else:
                            print("The best top left: {:.6f}".format(
                                top_left_all[-1]))

                        if (not len(top_left_all)
                                or top_left_all[-1] > np.min(top_left_batch)):
                            top_left_all.append(np.min(top_left_batch))
                            print('[*] save ckpt file!')
                            model.save_weights(
                                'checkpoints/{}/e_{}_b_{}.ckpt'.format(
                                    cfg['sub_name'], epochs,
                                    batch % steps_per_epoch))

                    model.save_weights('checkpoints/train_{}/{}.ckpt'.format(
                        cfg['sub_name'], cfg['sub_name']))