예제 #1
0
def main():
    """Embed one query image and display the closest gallery images.

    Everything is driven by ``parse_args()``: the model/data config
    (``args.config``), the weights (``args.checkpoint``), the query image
    path (``args.input``) and whether to run on the GPU (``args.use_cuda``).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``args.input``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # Flag -1 is cv2.IMREAD_UNCHANGED: alpha / single-channel images are kept
    # as-is. NOTE(review): confirm img_to_tensor copes with non-3-channel input.
    img = cv2.imread(args.input, -1)
    # cv2.imread reports failure by returning None instead of raising; without
    # this check the error would surface as an opaque cv2.resize assertion.
    if img is None:
        raise FileNotFoundError(
            'could not read image: {}'.format(args.input))
    img = cv2.resize(img, (224, 224))
    img_tensor = img_to_tensor(img, squeeze=True, cuda=args.use_cuda)

    model = build_retriever(cfg.model)
    # Load onto the CPU first so GPU-saved checkpoints also load on CPU-only
    # hosts; the model is moved to the GPU below when requested.
    load_checkpoint(model, args.checkpoint, map_location='cpu')

    if args.use_cuda:
        model.cuda()

    model.eval()

    query_embed = model(img_tensor, landmark=None, return_loss=False)
    # Detach from the graph and bring the embedding back to host memory.
    query_embed = query_embed.data.cpu().numpy()

    gallery_set = build_dataset(cfg.data.gallery)
    gallery_embeds = _process_embeds(gallery_set, model, cfg)

    retriever = ClothesRetriever(cfg.data.gallery.img_file, cfg.data_root,
                                 cfg.data.gallery.img_path)
    retriever.show_retrieved_images(query_embed, gallery_embeds)
예제 #2
0
def main():
    """Retrieve and show the gallery items most similar to a query image.

    Seeds torch's RNGs (and all CUDA devices when CUDA is available), builds
    the retriever from ``args.config``, loads ``args.checkpoint``, embeds the
    image at ``args.input`` and hands the query/gallery embeddings to
    ``ClothesRetriever.show_retrieved_images``.
    """
    rng_seed = 0
    torch.manual_seed(rng_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(rng_seed)

    opts = parse_args()
    config = Config.fromfile(opts.config)

    retriever_net = build_retriever(config.model)
    load_checkpoint(retriever_net, opts.checkpoint)
    print('load checkpoint from {}'.format(opts.checkpoint))

    if opts.use_cuda:
        retriever_net.cuda()
    retriever_net.eval()

    # Embed the query image, then detach it and copy it to host memory.
    query = retriever_net(get_img_tensor(opts.input, opts.use_cuda),
                          landmark=None, return_loss=False)
    query = query.data.cpu().numpy()

    gallery_cfg = config.data.gallery
    gallery_embeds = _process_embeds(build_dataset(gallery_cfg),
                                     retriever_net, config)

    searcher = ClothesRetriever(gallery_cfg.img_file, config.data_root,
                                gallery_cfg.img_path)
    searcher.show_retrieved_images(query, gallery_embeds)
예제 #3
0
파일: app.py 프로젝트: Dogacel/mmfashion
def retrieve():
    """Flask view: return the gallery paths most similar to an uploaded image.

    Expects the query image in the multipart field ``image``. Responds with
    ``{"paths": [...]}`` on success, or a 400 JSON error when the field is
    missing.
    """
    upload = request.files.get('image')
    # request.files.get returns None when the field is absent; reject the
    # request instead of failing inside get_img_tensor with a 500.
    if upload is None:
        return jsonify({'error': "missing 'image' file field"}), 400

    # NOTE(review): CUDA is forced on here; confirm the server always has a GPU
    # and that model_ret lives on it.
    img_tensor = get_img_tensor(upload, True)

    query_feat = model_ret(img_tensor, landmark=None, return_loss=False)
    query_feat = query_feat.data.cpu().numpy()

    gallery_embeds, retriever = _cached_gallery()
    result = retriever.show_retrieved_images(query_feat, gallery_embeds)
    return jsonify({'paths': result})


# Holds the (gallery_embeds, retriever) pair under the key 'gallery'. Both
# depend only on the module-level model_ret / cfg_ret, so they are built once
# rather than re-embedding the whole gallery on every request.
_GALLERY_CACHE = {}


def _cached_gallery():
    """Return ``(gallery_embeds, retriever)``, building them on first use."""
    cached = _GALLERY_CACHE.get('gallery')
    if cached is None:
        gallery_cfg = cfg_ret.data.gallery
        gallery_set = build_dataset(gallery_cfg)
        embeds = _process_embeds(gallery_set, model_ret, cfg_ret)
        # NOTE(review): caching assumes show_retrieved_images does not mutate
        # the retriever between calls; confirm.
        retriever = ClothesRetriever(gallery_cfg.img_file, cfg_ret.data_root,
                                     gallery_cfg.img_path)
        cached = (embeds, retriever)
        # Publish both values in one store so a concurrent first request never
        # observes a half-filled cache.
        _GALLERY_CACHE['gallery'] = cached
    return cached
예제 #4
0
def _init_models():
    """Build the retrieval and landmark models and the retrieval database.

    All paths and flags come from ``parse_args()``. Both models get their
    checkpoints loaded, are moved to the GPU when ``use_cuda`` is set, and are
    put in eval mode.

    Returns:
        tuple: ``(model_rt, model_lm, gallery_embeds, retriever)``, i.e. the
        retrieval model, the landmark detector, the value returned by
        ``_process_embeds`` for ``args.image_embeddings``, and the
        ``ClothesRetriever`` built over ``args.image_list``.
    """
    opts = parse_args()

    retrieval_cfg = mmcv.Config.fromfile(opts.config_retrieval)
    retrieval_net = build_retriever(retrieval_cfg.model)
    load_checkpoint(retrieval_net, opts.checkpoint_retrieval)
    print('load retriever checkpoint from {}'.format(
        opts.checkpoint_retrieval))

    landmark_cfg = mmcv.Config.fromfile(opts.config_landmark)
    landmark_net = build_landmark_detector(landmark_cfg.model)
    load_checkpoint(landmark_net, opts.checkpoint_landmark)
    print('load landmark detector checkpoint from: {}'.format(
        opts.checkpoint_landmark))

    networks = (retrieval_net, landmark_net)
    if opts.use_cuda:
        for net in networks:
            net.cuda()
    for net in networks:
        net.eval()

    image_paths = np.load(opts.image_list)
    gallery_embeds = _process_embeds(opts.image_embeddings)
    retriever = ClothesRetriever(image_paths, [opts.topk])
    print('build database for retrieval')

    return retrieval_net, landmark_net, gallery_embeds, retriever
예제 #5
0
def main():
    """Embed a query image, retrieve similar gallery items and print them.

    Seeds torch (and CUDA when ``--use_cuda`` is set and a GPU is present),
    builds the retriever from ``args.config``, loads ``args.checkpoint`` and
    prints each entry returned by ``ClothesRetriever.show_retrieved_images``.
    The intermediate prints are progress markers.
    """
    seed = 0

    torch.manual_seed(seed)
    args = parse_args()
    if args.use_cuda and torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cfg = Config.fromfile(args.config)

    model = build_retriever(cfg.model)
    # Load onto the CPU: a hard-coded 'cuda:0' map_location fails on hosts
    # without a GPU (or when --use_cuda is off). The model is moved to the GPU
    # below only when requested.
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    print('load checkpoint from {}'.format(args.checkpoint))

    if args.use_cuda:
        model.cuda()
    model.eval()

    print('Model evaled')
    img_tensor = get_img_tensor(args.input, args.use_cuda)
    print('Image tensor got.')
    query_feat = model(img_tensor, landmark=None, return_loss=False)
    print('Query feat 1')
    query_feat = query_feat.data.cpu().numpy()
    print('Query feat 2')
    gallery_set = build_dataset(cfg.data.gallery)
    print('Gallery set')
    gallery_embeds = _process_embeds(gallery_set, model, cfg)
    print('Gallery embeds')
    retriever = ClothesRetriever(cfg.data.gallery.img_file, cfg.data_root,
                                 cfg.data.gallery.img_path)
    print('Retriever')
    results = retriever.show_retrieved_images(query_feat, gallery_embeds)
    print('Show retriever')
    for result in results:
        print(result)
예제 #6
0
    # Continuation of a server entry point; `seed` and `args` are defined by
    # statements that precede this excerpt.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cfg = Config.fromfile(args.config)

    model = build_retriever(cfg.model)
    load_checkpoint(model, args.checkpoint)
    print('load checkpoint from {}'.format(args.checkpoint))

    if args.use_cuda:
        model = model.cuda()
    model.eval()

    # Embed the whole gallery once, at start-up.
    gallery_set = build_dataset(cfg.data.gallery)
    gallery_embeds = _process_embeds(gallery_set, model, cfg)

    retriever = ClothesRetriever(cfg.data.gallery.img_file, cfg.data_root,
                                 cfg.data.gallery.img_path)

    # Move the model back to host memory and release cached GPU blocks now
    # that the gallery embeddings have been computed.
    model = model.cpu()
    torch.cuda.empty_cache()

    # Flask
    #------------------------------------
    print("MMFashion server started!")
    app.debug = args.debug
    # The previous if/else was inverted (enable_threaded -> threaded=False);
    # pass the flag straight through instead.
    app.run(host=args.host, port=args.port,
            threaded=bool(args.enable_threaded))