Example #1
0
def main():
    """Test a fashion recommender: load config, build model, run the test.

    Reads the config path, optional work dir and checkpoint from the CLI,
    optionally initializes a distributed environment, then evaluates the
    model on ``cfg.data.test``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # CLI overrides take precedence over the config file
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if args.checkpoint is not None:
        cfg.load_from = args.checkpoint

    # init distributed env first
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # init logger
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed test: {}'.format(distributed))

    # data loader
    dataset = build_dataset(cfg.data.test)
    print('dataset loaded')

    # create model
    model = build_fashion_recommender(cfg.model)
    load_checkpoint(model, cfg.load_from, map_location='cpu')
    print('load checkpoint from: {}'.format(cfg.load_from))

    # BUG FIX: `distributed` and `logger` were computed above but then
    # discarded (the call hard-coded distributed=False, logger=None), so a
    # distributed launch silently ran single-process and test-time log
    # records were lost. Pass the computed values through.
    test_fashion_recommender(model,
                             dataset,
                             cfg,
                             distributed=distributed,
                             validate=False,
                             logger=logger)
Example #2
0
def main():
    """Test one stage of a virtual try-on pipeline, selected by ``--stage``.

    ``--stage GMM`` tests the geometric matching module; ``--stage TOM``
    tests the try-on module. Any other value raises ``ValueError``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # CLI overrides take precedence over the config file
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if args.checkpoint is not None:
        cfg.load_from = args.checkpoint

    # init distributed env first
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # init logger
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed test: {}'.format(distributed))

    # build model and load checkpoint
    if args.stage == 'GMM':
        # test geometric matching
        # data loader
        dataset = get_dataset(cfg.data.test.GMM)
        print('GMM dataset loaded')

        model = build_geometric_matching(cfg.GMM)
        print('GMM model built')
        load_checkpoint(model, cfg.load_from, map_location='cpu')
        print('load checkpoint from: {}'.format(cfg.load_from))

        test_geometric_matching(model,
                                dataset,
                                cfg,
                                distributed=distributed,
                                validate=False,
                                logger=logger)

    elif args.stage == 'TOM':
        # test tryon module
        dataset = get_dataset(cfg.data.test.TOM)
        print('TOM dataset loaded')

        model = build_tryon(cfg.TOM)
        print('TOM model built')
        load_checkpoint(model, cfg.load_from, map_location='cpu')
        print('load checkpoint from: {}'.format(cfg.load_from))

        test_tryon(model,
                   dataset,
                   cfg,
                   distributed=distributed,
                   validate=False,
                   logger=logger)
    else:
        # CONSISTENCY FIX: fail loudly on an unknown stage instead of
        # silently doing nothing (matches the training entry point, which
        # raises the same error for an invalid --stage).
        raise ValueError('stage should be GMM or TOM')
Example #3
0
def main():
    """Train one stage of a virtual try-on pipeline, selected by ``--stage``.

    ``--stage GMM`` trains the geometric matching module; ``--stage TOM``
    trains the try-on module. Any other value raises ``ValueError``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # command-line overrides win over the config file; outputs of each
    # stage go into their own sub-directory of the work dir
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    cfg.work_dir = os.path.join(cfg.work_dir, args.stage)
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from

    # distributed environment must be set up before anything touches CUDA
    distributed = args.launcher != 'none'
    if distributed:
        init_dist(args.launcher, **cfg.dist_params)

    # logging
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed training: {}'.format(distributed))

    # reproducibility
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

    # pick the builder/dataset/trainer triple for the requested stage,
    # then launch training through a single call site
    if args.stage == 'GMM':
        model = build_geometric_matching(cfg.GMM)
        print('Geometric Matching Module built')
        dataset = get_dataset(cfg.data.train.GMM)
        print('GMM dataset loaded')
        trainer = train_geometric_matching
    elif args.stage == 'TOM':
        model = build_tryon(cfg.TOM)
        print('Try-On Module built')
        dataset = get_dataset(cfg.data.train.TOM)
        print('TOM dataset loaded')
        trainer = train_tryon
    else:
        raise ValueError('stage should be GMM or TOM')

    trainer(model,
            dataset,
            cfg,
            distributed=distributed,
            validate=args.validate,
            logger=logger)
Example #4
0
def main():
    """Score the compatibility of a folder of outfit item images.

    Walks ``--input-dir``, embeds every image file with the fashion
    recommender, and prints a single compatibility score computed against
    the Polyvore test dataset.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    cfg.load_from = args.checkpoint
    # this demo always runs single-process
    distributed = False

    # init logger
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed test: {}'.format(distributed))

    # create model
    model = build_fashion_recommender(cfg.model)
    load_checkpoint(model, cfg.load_from, map_location='cpu')
    print('load checkpoint from: {}'.format(cfg.load_from))
    if args.use_cuda:
        model.cuda()
    model.eval()

    # prepare input data: one tensor per image file under input_dir
    img_tensors = []
    item_ids = []

    for dirpath, _, fns in os.walk(args.input_dir):
        for imgname in fns:
            # item id is the file name without its extension
            item_ids.append(imgname.split('.')[0])
            tensor = get_img_tensor(
                os.path.join(dirpath, imgname), args.use_cuda)
            img_tensors.append(tensor)

    # ROBUSTNESS FIX: torch.cat([]) raises a cryptic RuntimeError; fail
    # with an actionable message when the directory holds no files.
    if not img_tensors:
        raise ValueError('no images found under {}'.format(args.input_dir))
    img_tensors = torch.cat(img_tensors)

    # extract embeddings for all items in one forward pass
    embeds = []
    with torch.no_grad():
        embed = model(img_tensors, return_loss=False)
        embeds.append(embed.data.cpu())
    embeds = torch.cat(embeds)

    # FIX: catch only the expected failure (model not wrapped in
    # DataParallel, or no metric branch) instead of a blanket Exception
    # that would also hide unrelated bugs.
    try:
        metric = model.module.triplet_net.metric_branch
    except AttributeError:
        metric = None

    # get compatibility score, so far only support images from polyvore
    dataset = build_dataset(cfg.data.test)

    score = dataset.get_single_compatibility_score(embeds, item_ids, metric,
                                                   args.use_cuda)
    print("Compatibility score: {:.3f}".format(score))
Example #5
0
def main():
    """Train a retriever model as described by the given config file."""
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # command-line overrides win over the config file
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from

    # distributed environment must be ready before model/data setup
    distributed = args.launcher != 'none'
    if distributed:
        init_dist(args.launcher, **cfg.dist_params)

    # logging
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed training: {}'.format(distributed))

    # reproducibility
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

    # embedding extractor, optionally warm-started from pretrained weights
    model = build_retriever(cfg.model)
    print('model built')
    if cfg.init_weights_from is not None:
        model = init_weights_from(cfg.init_weights_from, model)
        print('Initialize model weights from {}'.format(cfg.init_weights_from))

    # training data
    dataset = build_dataset(cfg.data.train)
    print('dataset loaded')

    # launch training
    train_retriever(model,
                    dataset,
                    cfg,
                    distributed=distributed,
                    validate=args.validate,
                    logger=logger)
Example #6
0
def main():
    """Test a retriever: embed query and gallery sets and evaluate retrieval."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # CLI overrides take precedence over the config file
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir

    # init distributed env first
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    if args.checkpoint is not None:
        cfg.load_from = args.checkpoint

    # init logger
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed test: {}'.format(distributed))

    # data loader
    # NOTE(review): find_three is forced off for both splits — presumably
    # triplet sampling is train-only; confirm against the dataset class.
    cfg.data.query.find_three = False
    cfg.data.gallery.find_three = False
    query_set, gallery_set = build_dataset(cfg.data.query), build_dataset(
        cfg.data.gallery)
    print('dataset loaded')

    # build model and load checkpoint
    model = build_retriever(cfg.model)
    print('model built')

    # CONSISTENCY FIX: load onto CPU like every other test script here;
    # without map_location a checkpoint saved on GPU would try to restore
    # its tensors onto CUDA devices and fail on CPU-only test nodes.
    load_checkpoint(model, cfg.load_from, map_location='cpu')
    print('load checkpoint from: {}'.format(cfg.load_from))

    # test
    test_retriever(model,
                   query_set,
                   gallery_set,
                   cfg,
                   distributed=distributed,
                   validate=args.validate,
                   logger=logger)
Example #7
0
def main():
    """Test a landmark detector: load config and checkpoint, run evaluation."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # CLI overrides take precedence over the config file
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir

    # init distributed env first
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    if args.checkpoint is not None:
        cfg.load_from = args.checkpoint

    # init logger
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed test: {}'.format(distributed))

    # data loader
    test_dataset = get_dataset(cfg.data.test)
    print('dataset loaded')

    # build model and load checkpoint
    model = build_landmark_detector(cfg.model)
    print('model built')

    # CLEANUP: the return value of load_checkpoint (checkpoint metadata)
    # was bound to an unused local; drop the dead binding.
    load_checkpoint(model, cfg.load_from, map_location='cpu')
    print('load checkpoint from: {}'.format(cfg.load_from))

    # test
    test_landmark_detector(
        model,
        test_dataset,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger)