Example #1
def predict(args):
    logger = set_logger("extractor", args.out_dir)
    for name, val in vars(args).items():
        logger.info("[PARAMS] {}: {}".format(name, val))
    try:
        base_model = nets[args.net]()
    except Exception as e:
        logger.error("Initialize {} error: {}".format(args.net, e))
        return 
    logger.info("Extracting {} feature.".format(args.net))

    ## construct model
    input_shape = (INPUT_SIZE, INPUT_SIZE, 3)
    inputs = keras.layers.Input(shape=input_shape)
    feature, logits = base_model.embedder(inputs)
    model = keras.Model(
        inputs=inputs,
        outputs=[feature, logits],
        name="retrieval"
        )
    model.summary()
    keras.utils.plot_model(model, '{}/model.png'.format(args.out_dir), show_shapes=True)
    
    ## load weights
    try:
        latest = tf.train.latest_checkpoint(args.out_dir)
        logger.info("Loading pretrained weights from {}".format(latest))
        ## expect_partial() silences warnings about variables unused at inference
        model.load_weights(latest).expect_partial()
        # model = keras.models.load_model(os.path.join(args.out_dir, "model_final.h5"), compile=False)
    except Exception as e:
        logger.info(e)
        logger.info("Loading failed, using initialized weights")
    keras.backend.set_learning_phase(False)

    logger.info("Loading test_data ......")
    query_data = load_test_data(args.query_data, args.batch_size)
    gallery_data = load_test_data(args.gallery_data, args.batch_size)
    logger.info("Start extracting ......")
    for test_data in [query_data, gallery_data]:
        len_data = int(tf.data.experimental.cardinality(test_data))
        progbar = keras.utils.Progbar(len_data)
        for i, (batch_imgs, batch_names) in enumerate(test_data, 1):
            progbar.update(i)
            batch_features, batch_predicts = model.predict(batch_imgs)
            write_result(
                logger, 
                args, 
                batch_features, 
                batch_predicts, 
                batch_names
            )
Example #2
def extract(args):
    logger = set_logger("extractor", args.out_dir)
    for name, val in vars(args).items():
        logger.info("[PARAMS] {}: {}".format(name, val))
    try:
        model = nets[args.net]()
    except Exception as e:
        logger.error("Initialize {} error: {}".format(args.net, e))
        return 
    logger.info("Extracting {} feature.".format(args.net))

    ## load weights
    try:
        latest = tf.train.latest_checkpoint(args.out_dir)
        logger.info("Loading pretrained weights from {}".format(latest))
        model.load_weights(latest).expect_partial()
    except Exception as e:
        logger.info(e)
        logger.info("Loading failed, using initialized weights")
    model.trainable = False
    keras.backend.set_learning_phase(False)

    logger.info("Loading test_data ......")
    query_data = load_test_data(args.query_data, args.batch_size)
    gallery_data = load_test_data(args.gallery_data, args.batch_size)
    logger.info("Start extracting ......")
    for test_data in [query_data, gallery_data]:
        len_data = int(tf.data.experimental.cardinality(test_data))
        progbar = keras.utils.Progbar(len_data)
        for i, (batch_imgs, batch_names) in enumerate(test_data, 1):
            progbar.update(i)
            # batch_imgs = tf.stop_gradient(batch_imgs)
            batch_features, batch_predicts = model(batch_imgs, training=False)
            write_result(
                logger, 
                args, 
                batch_features, 
                batch_predicts, 
                batch_names
            )
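
## Note: the `nets` registry used by both snippets is defined elsewhere. A
## plausible minimal sketch, assuming it maps CLI names to zero-argument
## constructors returning a (feature, logits) two-output model; the backbone
## and class count are assumptions. (Example #1 additionally expects an
## `.embedder(inputs)` method on the returned object; this sketch matches
## Example #2's direct-call usage.)
def _resnet_embedder(num_classes=1000):
    backbone = keras.applications.ResNet50(
        include_top=False, weights=None, pooling="avg")
    logits = keras.layers.Dense(num_classes, name="logits")(backbone.output)
    return keras.Model(backbone.input, [backbone.output, logits])

nets = {"resnet50": _resnet_embedder}
## usage mirrors the snippets above: model = nets[args.net]()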
Example #3
parser.add_argument('--omega',
                    type=float,
                    default=0.7,
                    help='weight between losses')
parser.add_argument('--out_dir',
                    type=str,
                    default='output',
                    help='output directory')
parser.add_argument('--use-cpu',
                    dest='use_cpu',
                    action='store_true',
                    help='run on CPU even though this is a tf-gpu build')
args, _ = parser.parse_known_args()
if args.use_cpu:
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
logger = set_logger("trainer", args.out_dir)
for name, val in vars(args).items():
    logger.info("[PARAMS] {}: {}".format(name, val))


def compute_accuracy(y_pred, y_true):
    correct_predictions = tf.equal(tf.argmax(y_pred, 1),
                                   tf.cast(y_true, tf.int64))
    accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))
    return accuracy


tf.reset_default_graph()
try:
    xentropy_loss = tf.losses.sparse_softmax_cross_entropy
    ## model forward
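    ## A sketch of how the pieces above typically combine in graph-mode TF1
    ## (placeholder shapes, the stand-in dense network, and the learning rate
    ## are assumptions, not the original continuation):
    x = tf.placeholder(tf.float32, [None, 224, 224, 3])
    y = tf.placeholder(tf.int32, [None])
    logits = tf.layers.dense(tf.layers.flatten(x), 10)
    loss = xentropy_loss(labels=y, logits=logits)
    accuracy = compute_accuracy(logits, y)
    train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)  ## lr assumed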
Example #4
                    action='store_true',
                    help='run on CPU even though this is a tf-gpu build')
args, _ = parser.parse_known_args()

if args.use_cpu:
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

if args.pcaw_path:
    m = np.load(os.path.join(args.pcaw_path, "mean.npy"))
    P = np.load(os.path.join(args.pcaw_path, "pcaw.npy"))
    pcaw = PCAW(m, P, args.pcaw_dims)
    args.pcaw = pcaw
else:
    args.pcaw = None

logger = set_logger("extractor", args.out_dir)
for name, val in vars(args).items():
    logger.info("[PARAMS] {}: {}".format(name, val))


def write_result(batch_features, batch_predicts, batch_names):
    if args.pcaw is not None:
        batch_features = args.pcaw(batch_features.T, transpose=True)
    batch_predicts = np.argmax(batch_predicts, axis=1)
    for feature_per_image, predict_per_image, name_per_image in zip(
            batch_features, batch_predicts, batch_names):
        ## convert name from `bytes` to `str`
        name_per_image = name_per_image.decode("utf-8")
        try:
            out_dir = os.path.join(args.out_dir, 'feat')
            if not os.path.exists(out_dir):
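
## Note: `PCAW` (PCA whitening) is imported from elsewhere. A minimal sketch
## consistent with the usage above -- PCAW(mean, components, dims) applied as
## pcaw(features.T, transpose=True) -- assuming it centers, projects onto the
## top `dims` whitened components, and L2-normalizes (all assumptions):
class PCAW:
    def __init__(self, mean, components, dims=None):
        self.mean = mean                    ## (D,) training-set mean
        self.components = components        ## (D, D) whitening matrix
        self.dims = dims or components.shape[0]

    def __call__(self, x, transpose=False):
        ## x arrives as (D, N) when called on `batch_features.T`
        x = self.components[:self.dims] @ (x - self.mean[:, None])
        x = x / (np.linalg.norm(x, axis=0, keepdims=True) + 1e-12)
        return x.T if transpose else x      ## back to (N, dims) rows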
Example #5
    parser = argparse.ArgumentParser(description='Deep Fashion2 Retrieval.')
    parser.add_argument('--not_use_gpu', dest='not_use_gpu', action='store_true', help='do not use gpu')
    parser.add_argument('--iteration_num', type=int, default=25000, help='number of training iterations')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')
    parser.add_argument('--checkpoint_period', type=int, default=10000, help='period to save checkpoint')
    parser.add_argument('--learning_rate', type=float, default=0.0001, help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.0005, help='weight decay')
    parser.add_argument('--lr_milestones', type=int, nargs='+', default=[12000, 20000], help='milestones for lr_scheduler')
    parser.add_argument('--lr_gamma', type=float, default=0.1, help='gamma for lr_scheduler')
    parser.add_argument('--margin', type=float, default=0.3, help='margin')
    parser.add_argument('--omega', type=float, default=0.9, help='weight between losses')
    parser.add_argument('--use-labelsmooth', dest='use_labelsmooth', action='store_true', help='use CrossEntropyLabelSmooth, otherwise CrossEntropyLoss')
    parser.add_argument('--use-warmup', dest='use_warmup', action='store_true', help='use warmup-scheduler for training')
    parser.add_argument('--warmup_factor', type=float, default=0.01, help='warmup factor for initial low-lr training')
    parser.add_argument('--warmup_iters', type=int, default=500, help='warmup training iterations')
    parser.add_argument('--warmup_method', type=str, default="linear", help='method to increase warmup lr')
    parser.add_argument('--use-hardtriplet', dest='use_hardtriplet', action='store_true', help='use triplet loss with hard mining')
    parser.add_argument('--use-amp', dest='use_amp', action='store_true', help='use Automatic Mixed Precision')
    parser.add_argument('--out_dir', type=str, default='output', help='output directory (do not save if out_dir is empty)')
    args, _ = parser.parse_known_args()
    args.device = torch.device("cuda") if torch.cuda.is_available() and not args.not_use_gpu else torch.device("cpu")

    logger = set_logger("main", args.out_dir)
    for name, val in vars(args).items():
        logger.info("[PARAMS] {}: {}".format(name, val))
        
    train(args)

'''bash
CUDA_VISIBLE_DEVICES=1 python trainer.py --use-amp
'''
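
## Note: every snippet calls `set_logger(name, out_dir)`, which is defined
## elsewhere. A minimal sketch, assuming it logs to stdout and to a file in
## out_dir (format and file name are assumptions):
import logging
import os
import sys

def set_logger(name, out_dir):
    logger = logging.getLogger(name)
    if logger.handlers:                     ## avoid duplicates on re-entry
        return logger
    logger.setLevel(logging.INFO)
    fmt = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    stream = logging.StreamHandler(sys.stdout)
    stream.setFormatter(fmt)
    logger.addHandler(stream)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
        fh = logging.FileHandler(os.path.join(out_dir, "log.txt"))
        fh.setFormatter(fmt)
        logger.addHandler(fh)
    return logger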
Example #6
def train(args):
    logger = set_logger("trainer", args.out_dir)
    for name, val in vars(args).items():
        logger.info("[PARAMS] {}: {}".format(name, val))
    try:
        model = nets[args.net]()
        siamese_model = ntk.build_threestream_siamese_network(model)
        siamese_model.summary()
        keras.utils.plot_model(siamese_model,
                               '{}/siamese_model.png'.format(args.out_dir),
                               show_shapes=True)
    except Exception as e:
        logger.error("Initialize {} error: {}".format(args.net, e))
        return
    logger.info("Training {}.".format(args.net))

    keras.backend.set_learning_phase(True)

    loss_objects = [
        net_loss.TripletLoss(args.margin),
        keras.losses.SparseCategoricalCrossentropy(),
    ]
    optimizer = make_optimizer(args)
    train_acc = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_accuracy')
    val_acc = tf.keras.metrics.SparseCategoricalAccuracy(
        name='validation_accuracy')

    arguments = {}
    arguments.update(vars(args))
    arguments["itr"] = 0

    logger.info("Loading train_data ......")
    train_data = load_train_triplet(args.batch_size, args.train_num)
    ## split train and validation set
    val_data = train_data.take(args.val_num)
    train_data = train_data.skip(args.val_num)
    ## train model
    logger.info("Start training ......")
    meters = MetricLogger(delimiter=", ")
    max_itr = args.train_num
    start_itr = arguments["itr"] + 1
    itr_start_time = time.time()
    training_start_time = time.time()
    for itr, batch_data in enumerate(train_data, start_itr):
        loss_dict = train_step(siamese_model, optimizer, train_acc, args.omega,
                               loss_objects[0], loss_objects[1], *batch_data)
        arguments["itr"] = itr
        meters.update(**loss_dict)
        itr_time = time.time() - itr_start_time
        itr_start_time = time.time()
        meters.update(itr_time=itr_time)
        if itr % 50 == 0:
            eta_seconds = meters.itr_time.avg * (max_itr - itr)
            eta = str(datetime.timedelta(seconds=int(eta_seconds)))
            logger.info(
                meters.delimiter.join([
                    "itr: {itr}/{max_itr}",
                    "lr: {lr:.7f}",
                    "{meters}",
                    "train_accuracy: {train_acc:.2f}",
                    "eta: {eta}\n",
                ]).format(
                    itr=itr,
                    # lr=optimizer.lr.numpy(),
                    lr=optimizer._decayed_lr(tf.float32),
                    max_itr=max_itr,
                    meters=str(meters),
                    train_acc=train_acc.result().numpy(),
                    eta=eta))

        ## save model
        if itr % args.checkpoint_period == 0:
            model.save_weights("{}/model_{:07d}.ckpt".format(
                args.out_dir, itr))
            ## validation
            for batch_data in val_data:
                val_step(model, val_acc, *batch_data)
            logger.info("val_accuracy: {:.2f}\n".format(
                val_acc.result().numpy()))
            train_acc.reset_states()
            val_acc.reset_states()

        if itr == max_itr:
            model.save_weights("{}/model_final.ckpt".format(args.out_dir))
            with open(os.path.join(args.out_dir, "arguments.json"), "w") as fw:
                json.dump(arguments, fw)
            break

    training_time = time.time() - training_start_time
    training_time = str(datetime.timedelta(seconds=int(training_time)))
    logger.info("total training time: {}".format(training_time))
Example #7
def fit(args):
    logger = set_logger("trainer", args.out_dir)
    for name, val in vars(args).items():
        logger.info("[PARAMS] {}: {}".format(name, val))
    try:
        base_model = nets[args.net]()
    except Exception as e:
        logger.error("Initialize {} error: {}".format(args.net, e))
        return
    logger.info("Training {}.".format(args.net))

    logger.info("Loading train_data ......")
    train_data = load_train_pair_with_two_targets(args.batch_size,
                                                  args.train_num)
    val_data = train_data.take(args.val_num)
    train_data = train_data.skip(args.val_num)
    # images, labels = load_train_pair_from_numpy(args.train_num)

    keras.backend.set_learning_phase(True)
    ## construct model
    input_shape = (INPUT_SIZE, INPUT_SIZE, 3)
    inputs = keras.layers.Input(shape=input_shape)
    feature, logits = base_model.embedder(inputs)
    model = keras.Model(inputs=inputs,
                        outputs=[feature, logits],
                        name="retrieval")
    model.summary()
    keras.utils.plot_model(model,
                           '{}/model.png'.format(args.out_dir),
                           show_shapes=True)
    ## compile model
    losses = {
        HEAD_FT: net_loss.TripletSemiHardLoss(args.margin),
        HEAD_CLS: keras.losses.SparseCategoricalCrossentropy()
    }
    loss_weights = {HEAD_FT: args.omega, HEAD_CLS: 1 - args.omega}
    metrics = {HEAD_CLS: [keras.metrics.SparseCategoricalAccuracy()]}

    if args.optimizer == "adam":
        optimizer = keras.optimizers.Adam(learning_rate=args.learning_rate)
    elif args.optimizer == "sgd":
        optimizer = keras.optimizers.SGD(learning_rate=args.learning_rate)
    else:
        raise NotImplementedError("Choose Adam or SGD")

    callbacks = make_callbacks(args)
    model.compile(optimizer=optimizer,
                  loss=losses,
                  loss_weights=loss_weights,
                  metrics=metrics)
    ## train model
    logger.info("Start training ......")
    H = model.fit(x=train_data.repeat(),
                  validation_data=val_data,
                  epochs=args.epoch_num,
                  steps_per_epoch=args.steps_per_epoch,
                  callbacks=callbacks,
                  verbose=1)
    # H = model.fit(
    #     x=images,
    #     y={
    #         HEAD_FT: labels,
    #         HEAD_CLS: labels
    #     },
    #     validation_split=0.1,
    #     epochs=args.epoch_num,
    #     batch_size=args.batch_size,
    #     callbacks=callbacks,
    #     verbose=1
    # )

    train_acc = H.history["logits_sparse_categorical_accuracy"]
    val_acc = H.history["val_logits_sparse_categorical_accuracy"]
    logger.info("train_accuracy: {}".format(
        list(map(lambda e: round(e, 3), train_acc))))
    logger.info("val_accuracy: {}".format(
        list(map(lambda e: round(e, 3), val_acc))))

    model.save(os.path.join(args.out_dir, "model_final.h5"))
    # new_model = keras.models.load_model("model_final.h5")
    with open(os.path.join(args.out_dir, "arguments.json"), "w") as fw:
        json.dump(vars(args), fw)
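
## Note: `make_callbacks` is not shown. A minimal sketch, assuming per-epoch
## weight checkpoints plus TensorBoard logging (file names and options are
## assumptions):
def make_callbacks(args):
    return [
        keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(args.out_dir, "model_{epoch:03d}.h5"),
            save_weights_only=True,
            save_freq="epoch"),
        keras.callbacks.TensorBoard(
            log_dir=os.path.join(args.out_dir, "tensorboard")),
    ]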