def main():
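    # Compute the mean ingredient and instruction features of the recipe
    # encoder over a whole split and cache them to <dir_extract>/ingrs.pth
    # and <dir_extract>/instrs.pth (skipped if the directory already exists).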

    Logger('.')

    #classes = ['pizza', 'pork chops', 'cupcake', 'hamburger', 'green beans']
    split = 'test'
    dir_exp = '/home/cadene/doc/bootstrap.pytorch/logs/recipe1m/trijoint/2017-12-14-15-04-51'
    path_opts = os.path.join(dir_exp, 'options.yaml')
    dir_extract = os.path.join(dir_exp, 'extract_mean_features', split)
    dir_img = os.path.join(dir_extract, 'image')
    dir_rcp = os.path.join(dir_extract, 'recipe')
    path_model_ckpt = os.path.join(
        dir_exp,
        'ckpt_best_val_epoch.metric.recall_at_1_im2recipe_mean_model.pth.tar')

    Options.load_from_yaml(path_opts)
    utils.set_random_seed(Options()['misc']['seed'])

    Logger()('Load dataset...')
    dataset = factory(split)

    Logger()('Load model...')
    model = model_factory()
    model_state = torch.load(path_model_ckpt)
    model.load_state_dict(model_state)
    model.set_mode(split)

    if not os.path.isdir(dir_extract):
        Logger()('Create extract_dir {}'.format(dir_extract))
        os.system('mkdir -p ' + dir_extract)

        # The ingredient encoder is a bidirectional LSTM, hence the factor 2.
        mean_ingrs = torch.zeros(model.network.recipe_embedding.dim_ingr_out * 2)
        mean_instrs = torch.zeros(model.network.recipe_embedding.dim_instr_out)

        for i in tqdm(range(len(dataset))):
            item = dataset[i]
            batch = dataset.items_tf()([item])

            batch = model.prepare_batch(batch)
            out_ingrs = model.network.recipe_embedding.forward_ingrs(
                batch['recipe']['ingrs'])
            out_instrs = model.network.recipe_embedding.forward_instrs(
                batch['recipe']['instrs'])

            mean_ingrs += out_ingrs.data.cpu().squeeze(0)
            mean_instrs += out_instrs.data.cpu().squeeze(0)

        mean_ingrs /= len(dataset)
        mean_instrs /= len(dataset)

        path_ingrs = os.path.join(dir_extract, 'ingrs.pth')
        path_instrs = os.path.join(dir_extract, 'instrs.pth')

        torch.save(mean_ingrs, path_ingrs)
        torch.save(mean_instrs, path_instrs)

    Logger()('End')
def main():
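    # Load pre-extracted image and recipe embeddings for a random subset of
    # the split, compute their pairwise distance matrix and sort it along
    # both axes to inspect the retrieval rankings (ends in an ipdb session).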
    classes = ['pizza', 'pork chops', 'cupcake', 'hamburger', 'green beans']
    nb_points = 100
    split = 'test'
    dir_exp = 'logs/recipe1m/trijoint/2017-12-14-15-04-51'
    path_opts = os.path.join(dir_exp, 'options.yaml')
    dir_extract = os.path.join(dir_exp, 'extract', split)
    dir_img = os.path.join(dir_extract, 'image')
    dir_rcp = os.path.join(dir_extract, 'recipe')
    path_model = os.path.join(dir_exp, 'ckpt_best_val_epoch.metric.recall_at_1_im2recipe_mean_model.pth.tar')
    
    dir_visu = os.path.join(dir_exp, 'visu', 'top5')

    #Options(path_opts)
    Options.load_from_yaml(path_opts)
    dataset = factory(split)

    network = Trijoint()
    network.eval()
    model_state = torch.load(path_model)
    network.load_state_dict(model_state['network'])

    list_idx = torch.randperm(len(dataset))

    img_embs = []
    rcp_embs = []
    for i in range(nb_points):
        idx = list_idx[i]
        path_img = os.path.join(dir_img, '{}.pth'.format(idx))
        path_rcp = os.path.join(dir_rcp, '{}.pth'.format(idx))
        img_embs.append(torch.load(path_img))
        rcp_embs.append(torch.load(path_rcp))

    img_embs = torch.stack(img_embs, 0)
    rcp_embs = torch.stack(rcp_embs, 0)

    dist = fast_distance(img_embs, rcp_embs)

    # Sort the distance matrix along both axes to get the retrieval rankings.
    im2recipe_ids = np.argsort(dist.numpy(), axis=0)
    recipe2im_ids = np.argsort(dist.numpy(), axis=1)

    # Drop into the debugger to inspect the rankings interactively.
    import ipdb; ipdb.set_trace()
def main():
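    # Count how many recipes contain each (interim) ingredient over a split
    # and cache the resulting dictionary to <dir_extract>/ingrs.pth.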

    Logger('.')

    split = 'train'
    dir_exp = '/home/ubuntu/moochi/recipe1m.bootstrap.pytorch/logs/recipe1m/trijoint/2017-12-14-15-04-51'
    path_opts = os.path.join(dir_exp, 'options.yaml')
    dir_extract = os.path.join(dir_exp, 'extract_count', split)
    path_ingrs_count = os.path.join(dir_extract, 'ingrs.pth')

    Options(path_opts)
    utils.set_random_seed(Options()['misc']['seed'])

    dataset = factory(split)

    if not os.path.isfile(path_ingrs_count):
        ingrs_count = {}
        os.system('mkdir -p ' + dir_extract)

        for i in tqdm(range(len(dataset.recipes_dataset))):
            item = dataset.recipes_dataset[i]
            for ingr in item['ingrs']['interim']:
                if ingr not in ingrs_count:
                    ingrs_count[ingr] = 1
                else:
                    ingrs_count[ingr] += 1

        torch.save(ingrs_count, path_ingrs_count)
    else:
        ingrs_count = torch.load(path_ingrs_count)

    # Sort ingredient names by ascending count, then drop into the debugger
    # to inspect the distribution interactively.
    sorted_ingrs = sorted(ingrs_count, key=ingrs_count.get)
    import ipdb
    ipdb.set_trace()

    Logger()('End')
def main():
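    # Retrieval with a manually edited recipe: rebuild one chosen item with a
    # subset of its instructions and hard-coded ingredient ids, embed it, and
    # save the query and its 20 nearest neighbours to a visualisation directory.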

    Logger('.')

    #classes = ['hamburger']
    split = 'test'
    class_name = None  # e.g. 'potato salad'
    modality_to_modality = 'recipe_to_image'
    dir_exp = '/home/cadene/doc/bootstrap.pytorch/logs/recipe1m/trijoint/2017-12-14-15-04-51'
    path_opts = os.path.join(dir_exp, 'options.yaml')
    dir_extract = os.path.join(dir_exp, 'extract', split)
    dir_extract_mean = os.path.join(dir_exp, 'extract_mean_features', split)
    dir_img = os.path.join(dir_extract, 'image')
    dir_rcp = os.path.join(dir_extract, 'recipe')
    path_model_ckpt = os.path.join(dir_exp, 'ckpt_best_val_epoch.metric.recall_at_1_im2recipe_mean_model.pth.tar')

    #is_mean = True
    #ingrs_list = ['carrot', 'salad', 'tomato']  # ['avocado']

    Options(path_opts)
    Options()['misc']['seed'] = 11
    utils.set_random_seed(Options()['misc']['seed'])

    chosen_item_id = 51259
    dataset = factory(split)
    if class_name:
        class_id = dataset.cname_to_cid[class_name]
        indices_by_class = dataset._make_indices_by_class()
        nb_points = len(indices_by_class[class_id])
        list_idx = torch.Tensor(indices_by_class[class_id])
        rand_idx = torch.randperm(list_idx.size(0))
        list_idx = list_idx[rand_idx]
        list_idx = list_idx.view(-1).int()
        dir_visu = os.path.join(dir_exp, 'visu', 'remove_ingrs_item:{}_nb_points:{}_class:{}'.format(chosen_item_id, nb_points, class_name.replace(' ', '_')))
    else:
        nb_points = 1000
        list_idx = torch.randperm(len(dataset))
        dir_visu = os.path.join(dir_exp, 'visu', 'remove_ingrs_item:{}_nb_points:{}_removed'.format(chosen_item_id, nb_points))


    # for i in range(20):
    #     item_id = list_idx[i]
    #     item = dataset[item_id]
    #     write_img_rcp(dir_visu, item, top=i)

    Logger()('Load model...')
    model = model_factory()
    model_state = torch.load(path_model_ckpt)
    model.load_state_dict(model_state)
    model.eval()

    item = dataset[chosen_item_id]

    # from tqdm import tqdm
    # ids = []
    # for i in tqdm(range(len(dataset.recipes_dataset))):
    #     item = dataset.recipes_dataset[i]#23534]
    #     if 'broccoli' in item['ingrs']['interim']:
    #         print('broccoli', i)
    #         ids.append(i)
            
    #     # if 'mushroom' in item['ingrs']['interim']:
    #     #     print('mushroom', i)
    #     #     break

    # import ipdb; ipdb.set_trace()

    

    # input_ = {
    #     'recipe': {
    #         'ingrs': {
    #             'data': item['recipe']['ingrs']['data'],
    #             'lengths': item['recipe']['ingrs']['lengths']
    #         },
    #         'instrs': {
    #             'data': item['recipe']['instrs']['data'],
    #             'lengths': item['recipe']['instrs']['lengths']
    #         }
    #     }
    # }

    # Rebuild the instruction embeddings of the chosen item, keeping only a
    # subset of the original steps (indices 2 and 5 are dropped).
    instrs = torch.FloatTensor(6, 1024)
    instrs[0] = item['recipe']['instrs']['data'][0]
    instrs[1] = item['recipe']['instrs']['data'][1]
    instrs[2] = item['recipe']['instrs']['data'][3]
    instrs[3] = item['recipe']['instrs']['data'][4]
    instrs[4] = item['recipe']['instrs']['data'][6]
    instrs[5] = item['recipe']['instrs']['data'][7]

    # Hard-coded ingredient ids used as the modified ingredient list.
    ingrs = torch.LongTensor([612, 585, 844, 3087, 144, 188, 1])

    input_ = {
        'recipe': {
            'ingrs': {
                'data': ingrs,
                'lengths': ingrs.size(0)
            },
            'instrs': {
                'data': instrs,
                'lengths': instrs.size(0)
            }
        }
    }

    batch = dataset.items_tf()([input_])
    batch = model.prepare_batch(batch)
    out = model.network.recipe_embedding(batch['recipe'])

    # path_rcp = os.path.join(dir_rcp, '{}.pth'.format(23534))
    # rcp_emb = torch.load(path_rcp)
    

    Logger()('Load embeddings...')
    img_embs = []
    for i in range(nb_points):
        idx = list_idx[i]
        path_img = os.path.join(dir_img, '{}.pth'.format(idx))
        if not os.path.isfile(path_img):
            Logger()('No such file: {}'.format(path_img))
            continue
        img_embs.append(torch.load(path_img))

    img_embs = torch.stack(img_embs, 0)

    Logger()('Fast distance...')

    # Distance between the modified-recipe embedding and every loaded image embedding.
    dist = fast_distance(out.data.cpu().expand_as(img_embs), img_embs)
    dist = dist[0, :]
    sorted_ids = np.argsort(dist.numpy())

    os.system('rm -rf '+dir_visu)
    os.system('mkdir -p '+dir_visu)

    Logger()('Load/save images in {}...'.format(dir_visu))
    write_img_rcp(dir_visu, item, top=0, begin_with='query')
    for i in range(20):
        idx = int(sorted_ids[i])
        item_id = list_idx[idx]
        #item_id = idx
        item = dataset[item_id]
        write_img_rcp(dir_visu, item, top=i, begin_with='nn')

    Logger()('End')
def main():
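    # Embed the precomputed mean ingredient/instruction features and save the
    # 20 closest images to <dir_exp>/visu/mean_to_image.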

    Logger('.')

    #classes = ['pizza', 'pork chops', 'cupcake', 'hamburger', 'green beans']
    nb_points = 1000
    split = 'test'
    dir_exp = '/home/cadene/doc/bootstrap.pytorch/logs/recipe1m/trijoint/2017-12-14-15-04-51'
    path_opts = os.path.join(dir_exp, 'options.yaml')
    dir_extract = os.path.join(dir_exp, 'extract', split)
    dir_extract_mean = os.path.join(dir_exp, 'extract_mean_features', split)
    dir_img = os.path.join(dir_extract, 'image')
    dir_rcp = os.path.join(dir_extract, 'recipe')
    path_model_ckpt = os.path.join(
        dir_exp,
        'ckpt_best_val_epoch.metric.recall_at_1_im2recipe_mean_model.pth.tar')

    dir_visu = os.path.join(dir_exp, 'visu', 'mean_to_image')
    os.system('mkdir -p ' + dir_visu)

    #Options(path_opts)
    Options.load_from_yaml(path_opts)
    utils.set_random_seed(Options()['misc']['seed'])

    dataset = factory(split)

    Logger()('Load model...')
    model = model_factory()
    model_state = torch.load(path_model_ckpt)
    model.load_state_dict(model_state)
    model.set_mode(split)

    #emb = network.recipe_embedding.forward_ingrs(input_['recipe']['ingrs'])
    list_idx = torch.randperm(len(dataset))

    Logger()('Load embeddings...')
    img_embs = []
    rcp_embs = []
    for i in range(nb_points):
        idx = list_idx[i]
        path_img = os.path.join(dir_img, '{}.pth'.format(idx))
        path_rcp = os.path.join(dir_rcp, '{}.pth'.format(idx))
        if not os.path.isfile(path_img):
            Logger()('No such file: {}'.format(path_img))
            continue
        if not os.path.isfile(path_rcp):
            Logger()('No such file: {}'.format(path_rcp))
            continue
        img_embs.append(torch.load(path_img))
        rcp_embs.append(torch.load(path_rcp))

    img_embs = torch.stack(img_embs, 0)
    rcp_embs = torch.stack(rcp_embs, 0)

    Logger()('Load means')
    path_ingrs = os.path.join(dir_extract_mean, 'ingrs.pth')
    path_instrs = os.path.join(dir_extract_mean, 'instrs.pth')

    mean_ingrs = torch.load(path_ingrs)
    mean_instrs = torch.load(path_instrs)

    mean_ingrs = Variable(mean_ingrs.unsqueeze(0).cuda(), requires_grad=False)
    mean_instrs = Variable(mean_instrs.unsqueeze(0).cuda(),
                           requires_grad=False)

    Logger()('Forward ingredient...')
    # Embed the dataset-mean ingredient/instruction features into the joint
    # space, then broadcast the single embedding against every image embedding.
    ingr_emb = model.network.recipe_embedding.forward_ingrs_instrs(
        mean_ingrs, mean_instrs)
    ingr_emb = ingr_emb.data.cpu()
    ingr_emb = ingr_emb.expand_as(img_embs)

    Logger()('Fast distance...')
    dist = fast_distance(img_embs, ingr_emb)[:, 0]

    sorted_img_ids = np.argsort(dist.numpy())

    Logger()('Load/save images...')
    for i in range(20):
        img_id = sorted_img_ids[i]
        img_id = int(img_id)

        path_img_from = dataset[img_id]['image']['path']
        path_img_to = os.path.join(dir_visu, 'image_top_{}.png'.format(i + 1))
        img = Image.open(path_img_from)
        img.save(path_img_to)
        #os.system('cp {} {}'.format(path_img_from, path_img_to))

    Logger()('End')
def main():
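    # Query the joint space with a hand-picked ingredient list (instructions
    # replaced by the precomputed mean), restrict retrieval to one class, and
    # save the 20 closest items to a visualisation directory.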

    Logger('.')

    #classes = ['pizza', 'pork chops', 'cupcake', 'hamburger', 'green beans']
    nb_points = 1000
    split = 'test'
    class_name = 'pizza'
    dir_exp = '/home/ubuntu/moochi/recipe1m.bootstrap.pytorch/logs/recipe1m/trijoint/2017-12-14-15-04-51'
    path_opts = os.path.join(dir_exp, 'options.yaml')
    dir_extract = os.path.join(dir_exp, 'extract', split)
    dir_extract_mean = os.path.join(dir_exp, 'extract_mean_features', split)
    dir_img = os.path.join(dir_extract, 'image')
    dir_rcp = os.path.join(dir_extract, 'recipe')
    path_model_ckpt = os.path.join(
        dir_exp,
        'ckpt_best_val_epoch.metric.recall_at_1_im2recipe_mean_model.pth.tar')

    is_mean = True
    ingrs_list = ['fresh_strawberries']  #['avocado']

    Options(path_opts)
    Options()['misc']['seed'] = 2
    utils.set_random_seed(Options()['misc']['seed'])

    dataset = factory(split)

    Logger()('Load model...')
    model = model_factory()
    model_state = torch.load(path_model_ckpt)
    model.load_state_dict(model_state)
    model.eval()

    if not os.path.isdir(dir_extract):
        os.system('mkdir -p ' + dir_rcp)
        os.system('mkdir -p ' + dir_img)

        for i in tqdm(range(len(dataset))):
            item = dataset[i]
            batch = dataset.items_tf()([item])

            if model.is_cuda:
                batch = model.cuda_tf()(batch)

            is_volatile = (model.mode not in ['train', 'trainval'])
            batch = model.variable_tf(volatile=is_volatile)(batch)

            out = model.network(batch)

            path_image = os.path.join(dir_img, '{}.pth'.format(i))
            path_recipe = os.path.join(dir_rcp, '{}.pth'.format(i))
            torch.save(out['image_embedding'][0].data.cpu(), path_image)
            torch.save(out['recipe_embedding'][0].data.cpu(), path_recipe)

    # b = dataset.make_batch_loader().__iter__().__next__()
    # import ipdb; ipdb.set_trace()

    ingrs = torch.LongTensor(1, len(ingrs_list))
    for i, ingr_name in enumerate(ingrs_list):
        ingrs[0, i] = dataset.recipes_dataset.ingrname_to_ingrid[ingr_name]

    input_ = {
        'recipe': {
            'ingrs': {
                'data': Variable(ingrs.cuda(), requires_grad=False),
                'lengths': [ingrs.size(1)]
            },
            'instrs': {
                'data':
                Variable(torch.FloatTensor(1, 1, 1024).fill_(0).cuda(),
                         requires_grad=False),
                'lengths': [1]
            }
        }
    }

    #emb = network.recipe_embedding.forward_ingrs(input_['recipe']['ingrs'])
    #list_idx = torch.randperm(len(dataset))

    indices_by_class = dataset._make_indices_by_class()

    #import ipdb; ipdb.set_trace()

    class_id = dataset.cname_to_cid[class_name]
    list_idx = torch.Tensor(indices_by_class[class_id])
    rand_idx = torch.randperm(list_idx.size(0))
    list_idx = list_idx[rand_idx]
    list_idx = list_idx.view(-1).int()

    nb_points = list_idx.size(0)

    dir_visu = os.path.join(
        dir_exp, 'visu',
        'ingrs_to_image_nb_points:{}_class:{}_instrs:{}_mean:{}_v2'.format(
            nb_points, class_name, '-'.join(ingrs_list), is_mean))
    os.system('mkdir -p ' + dir_visu)

    Logger()('Load embeddings...')
    img_embs = []
    rcp_embs = []
    for i in range(nb_points):
        idx = list_idx[i]
        path_img = os.path.join(dir_img, '{}.pth'.format(idx))
        path_rcp = os.path.join(dir_rcp, '{}.pth'.format(idx))
        if not os.path.isfile(path_img):
            Logger()('No such file: {}'.format(path_img))
            continue
        if not os.path.isfile(path_rcp):
            Logger()('No such file: {}'.format(path_rcp))
            continue
        img_embs.append(torch.load(path_img))
        rcp_embs.append(torch.load(path_rcp))

    img_embs = torch.stack(img_embs, 0)
    rcp_embs = torch.stack(rcp_embs, 0)

    Logger()('Load mean embeddings')

    path_ingrs = os.path.join(dir_extract_mean, 'ingrs.pth')
    path_instrs = os.path.join(dir_extract_mean, 'instrs.pth')

    mean_ingrs = torch.load(path_ingrs)
    mean_instrs = torch.load(path_instrs)

    Logger()('Forward ingredient...')
    #ingr_emb = model.network.recipe_embedding(input_['recipe'])
    ingr_emb = model.network.recipe_embedding.forward_one_ingr(
        input_['recipe']['ingrs'], emb_instrs=mean_instrs.unsqueeze(0))

    ingr_emb = ingr_emb.data.cpu()
    ingr_emb = ingr_emb.expand_as(img_embs)

    Logger()('Fast distance...')
    dist = fast_distance(img_embs, ingr_emb)[:, 0]

    sorted_ids = np.argsort(dist.numpy())

    Logger()('Load/save images in {}...'.format(dir_visu))
    for i in range(20):
        idx = int(sorted_ids[i])
        item_id = list_idx[idx]
        item = dataset[item_id]
        # path_img_from = item['image']['path']
        # ingrs = [ingr.replace('/', '\'') for ingr in item['recipe']['ingrs']['interim']]
        # cname = item['recipe']['class_name']
        # path_img_to = os.path.join(dir_visu, 'image_top:{}_ingrs:{}_cname:{}.png'.format(i+1, '-'.join(ingrs), cname))
        # img = Image.open(path_img_from)
        # img.save(path_img_to)
        #os.system('cp {} {}'.format(path_img_from, path_img_to))

        write_img_rcp(dir_visu, item, top=i, begin_with='nn')

    Logger()('End')
def main():
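    # Top-20 retrieval between two modalities (image/recipe) chosen on the
    # command line, over a random subset of the split, with the results
    # written to a visualisation directory.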

    Logger('.')

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'modality_to_modality', nargs='?', default='recipe_to_image',
        help='one of image_to_recipe, recipe_to_image, recipe_to_recipe, image_to_image')
    args = parser.parse_args()


    #classes = ['pizza', 'pork chops', 'cupcake', 'hamburger', 'green beans']
    nb_points = 1000
    modality_to_modality = args.modality_to_modality  # e.g. 'image_to_image'
    Logger()(modality_to_modality)
    split = 'test'
    dir_exp = '/home/ubuntu/moochi/recipe1m.bootstrap.pytorch/logs/recipe1m/trijoint/2017-12-14-15-04-51'
    path_opts = os.path.join(dir_exp, 'options.yaml')
    dir_extract = os.path.join(dir_exp, 'extract', split)
    dir_img = os.path.join(dir_extract, 'image')
    dir_rcp = os.path.join(dir_extract, 'recipe')
    path_model_ckpt = os.path.join(dir_exp, 'ckpt_best_val_epoch.metric.recall_at_1_im2recipe_mean_model.pth.tar')

    Options.load_from_yaml(path_opts)
    Options()['misc']['seed'] = 11
    utils.set_random_seed(Options()['misc']['seed'])

    dataset = factory(split)

    Logger()('Load model...')
    model = model_factory()
    model_state = torch.load(path_model_ckpt)
    model.load_state_dict(model_state)
    model.eval()

    if not os.path.isdir(dir_extract):
        os.system('mkdir -p '+dir_rcp)
        os.system('mkdir -p '+dir_img)

        for i in tqdm(range(len(dataset))):
            item = dataset[i]
            batch = dataset.items_tf()([item])

            if model.is_cuda:
                batch = model.cuda_tf()(batch)

            is_volatile = (model.mode not in ['train', 'trainval'])
            batch = model.variable_tf(volatile=is_volatile)(batch)

            out = model.network(batch)

            path_image = os.path.join(dir_img, '{}.pth'.format(i))
            path_recipe = os.path.join(dir_rcp, '{}.pth'.format(i))
            torch.save(out['image_embedding'][0].data.cpu(), path_image)
            torch.save(out['recipe_embedding'][0].data.cpu(), path_recipe)



    indices_by_class = dataset._make_indices_by_class()

    # class_name = classes[0] # TODO
    # class_id = dataset.cname_to_cid[class_name]
    # list_idx = torch.Tensor(indices_by_class[class_id])
    # rand_idx = torch.randperm(list_idx.size(0))
    # list_idx = list_idx[rand_idx]
    # list_idx = list_idx.view(-1).int()
    list_idx = torch.randperm(len(dataset))

    #nb_points = list_idx.size(0)

    dir_visu = os.path.join(dir_exp, 'visu', '{}_top20_seed:{}'.format(modality_to_modality, Options()['misc']['seed']))
    os.system('rm -rf '+dir_visu)
    os.system('mkdir -p '+dir_visu)

    Logger()('Load embeddings...')
    img_embs = []
    rcp_embs = []
    for i in range(nb_points):
        idx = list_idx[i]
        #idx = i
        path_img = os.path.join(dir_img, '{}.pth'.format(idx))
        path_rcp = os.path.join(dir_rcp, '{}.pth'.format(idx))
        if not os.path.isfile(path_img):
            Logger()('No such file: {}'.format(path_img))
            continue
        if not os.path.isfile(path_rcp):
            Logger()('No such file: {}'.format(path_rcp))
            continue
        img_embs.append(torch.load(path_img))
        rcp_embs.append(torch.load(path_rcp))

    img_embs = torch.stack(img_embs, 0)
    rcp_embs = torch.stack(rcp_embs, 0)

    # Logger()('Forward ingredient...')
    # #ingr_emb = model.network.recipe_embedding(input_['recipe'])
    # ingr_emb = model.network.recipe_embedding.forward_one_ingr(
    #     input_['recipe']['ingrs'],
    #     emb_instrs=mean_instrs.unsqueeze(0))

    # ingr_emb = ingr_emb.data.cpu()
    # ingr_emb = ingr_emb.expand_as(img_embs)


    Logger()('Fast distance...')

    if modality_to_modality == 'image_to_recipe':
        dist = fast_distance(img_embs, rcp_embs)
    elif modality_to_modality == 'recipe_to_image':
        dist = fast_distance(rcp_embs, img_embs)
    elif modality_to_modality == 'recipe_to_recipe':
        dist = fast_distance(rcp_embs, rcp_embs)
    elif modality_to_modality == 'image_to_image':
        dist = fast_distance(img_embs, img_embs)
    else:
        raise ValueError('Unknown modality_to_modality: {}'.format(modality_to_modality))

    dist = dist[:, 0]
    sorted_ids = np.argsort(dist.numpy())

    Logger()('Load/save images in {}...'.format(dir_visu))
    for i in range(20):
        idx = int(sorted_ids[i])
        item_id = list_idx[idx]
        #item_id = idx
        item = dataset[item_id]
        write_img_rcp(dir_visu, item, top=i)
        #os.system('cp {} {}'.format(path_img_from, path_img_to))


    Logger()('End')
def main():
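    # Query the joint space with a small hard-coded ingredient list (here
    # 'mushroom'/'mushrooms') restricted to one class, and save the 20
    # closest images to a visualisation directory.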

    Logger('.')

    #classes = ['pizza', 'pork chops', 'cupcake', 'hamburger', 'green beans']
    nb_points = 1000
    split = 'test'
    dir_exp = 'logs/recipe1m/trijoint/2017-12-14-15-04-51'
    path_opts = os.path.join(dir_exp, 'options.yaml')
    dir_extract = os.path.join(dir_exp, 'extract', split)
    dir_img = os.path.join(dir_extract, 'image')
    dir_rcp = os.path.join(dir_extract, 'recipe')
    path_model_ckpt = os.path.join(
        dir_exp,
        'ckpt_best_val_epoch.metric.recall_at_1_im2recipe_mean_model.pth.tar')

    #Options(path_opts)
    Options.load_from_yaml(path_opts)
    utils.set_random_seed(Options()['misc']['seed'])

    dataset = factory(split)

    Logger()('Load model...')
    model = model_factory()
    model_state = torch.load(path_model_ckpt)
    model.load_state_dict(model_state)
    model.set_mode(split)

    if not os.path.isdir(dir_extract):
        os.system('mkdir -p ' + dir_rcp)
        os.system('mkdir -p ' + dir_img)

        for i in tqdm(range(len(dataset))):
            item = dataset[i]
            batch = dataset.items_tf()([item])

            if model.is_cuda:
                batch = model.cuda_tf()(batch)

            is_volatile = (model.mode not in ['train', 'trainval'])
            batch = model.variable_tf(volatile=is_volatile)(batch)

            out = model.network(batch)

            path_image = os.path.join(dir_img, '{}.pth'.format(i))
            path_recipe = os.path.join(dir_rcp, '{}.pth'.format(i))
            torch.save(out['image_embedding'][0].data.cpu(), path_image)
            torch.save(out['recipe_embedding'][0].data.cpu(), path_recipe)

    # b = dataset.make_batch_loader().__iter__().__next__()
    # class_name = 'pizza'
    # ingrs = torch.LongTensor(1, 2)
    # ingrs[0, 0] = dataset.recipes_dataset.ingrname_to_ingrid['mushrooms']
    # ingrs[0, 1] = dataset.recipes_dataset.ingrname_to_ingrid['mushroom']

    class_name = 'hamburger'
    ingrs = torch.LongTensor(1, 2)
    ingrs[0, 0] = dataset.recipes_dataset.ingrname_to_ingrid['mushroom']
    ingrs[0, 1] = dataset.recipes_dataset.ingrname_to_ingrid['mushrooms']

    #ingrs[0, 0] = dataset.recipes_dataset.ingrname_to_ingrid['tomato']
    #ingrs[0, 1] = dataset.recipes_dataset.ingrname_to_ingrid['salad']
    #ingrs[0, 2] = dataset.recipes_dataset.ingrname_to_ingrid['onion']
    #ingrs[0, 3] = dataset.recipes_dataset.ingrname_to_ingrid['chicken']

    input_ = {
        'recipe': {
            'ingrs': {
                'data': Variable(ingrs.cuda(), requires_grad=False),
                'lengths': [ingrs.size(1)]
            },
            'instrs': {
                'data':
                Variable(torch.FloatTensor(1, 1, 1024).fill_(0).cuda(),
                         requires_grad=False),
                'lengths': [1]
            }
        }
    }

    #emb = network.recipe_embedding.forward_ingrs(input_['recipe']['ingrs'])
    #list_idx = torch.randperm(len(dataset))

    indices_by_class = dataset._make_indices_by_class()
    class_id = dataset.cname_to_cid[class_name]
    list_idx = torch.Tensor(indices_by_class[class_id])
    rand_idx = torch.randperm(list_idx.size(0))
    list_idx = list_idx[rand_idx]

    list_idx = list_idx.view(-1).int()

    img_embs = []
    rcp_embs = []

    nb_points = min(nb_points, list_idx.size(0))

    Logger()('Load {} embeddings...'.format(nb_points))
    for i in range(nb_points):
        idx = list_idx[i]
        path_img = os.path.join(dir_img, '{}.pth'.format(idx))
        path_rcp = os.path.join(dir_rcp, '{}.pth'.format(idx))
        if not os.path.isfile(path_img):
            Logger()('No such file: {}'.format(path_img))
            continue
        if not os.path.isfile(path_rcp):
            Logger()('No such file: {}'.format(path_rcp))
            continue
        img_embs.append(torch.load(path_img))
        rcp_embs.append(torch.load(path_rcp))

    img_embs = torch.stack(img_embs, 0)
    rcp_embs = torch.stack(rcp_embs, 0)

    Logger()('Forward ingredient...')
    #ingr_emb = model.network.recipe_embedding(input_['recipe'])
    ingr_emb = model.network.recipe_embedding.forward_one_ingr(
        input_['recipe']['ingrs'])

    ingr_emb = ingr_emb.data.cpu()
    ingr_emb = ingr_emb.expand_as(img_embs)

    Logger()('Fast distance...')
    dist = fast_distance(img_embs, ingr_emb)[:, 0]

    sorted_ids = np.argsort(dist.numpy())

    dir_visu = os.path.join(dir_exp, 'visu',
                            'ingrs_to_image_by_class_{}'.format(class_name))
    os.system('mkdir -p ' + dir_visu)

    Logger()('Load/save images...')
    for i in range(20):
        idx = int(sorted_ids[i])
        item_id = list_idx[idx]
        item = dataset[item_id]
        Logger()(item['recipe']['class_name'])
        Logger()(item['image']['class_name'])
        path_img_from = item['image']['path']
        path_img_to = os.path.join(dir_visu, 'image_top_{}.png'.format(i + 1))
        img = Image.open(path_img_from)
        img.save(path_img_to)
        #os.system('cp {} {}'.format(path_img_from, path_img_to))

    Logger()('End')