def evaluate(args):
    """Evaluate trained retrieval encoders on the validation split.

    Loads the saved image/caption encoders for ``args.data_set``, builds the
    validation data loader and ranker, and returns the metrics produced by
    ``eval_batch``.

    Args:
        args: Namespace with at least ``data_set``, ``batch_size``,
            ``num_workers`` and ``embed_size``.

    Returns:
        Whatever ``eval_batch`` returns (previously computed but discarded).
    """
    data = args.data_set
    print("Evaluating trained model on ", data)
    img_path = "data/resized_images/" + data
    val_cap = "data/captions/cap." + data + ".val.json"
    dict_vocab = "data/captions/dict." + data + ".json"
    val_set = "data/image_splits/split." + data + ".val.json"

    # Deterministic preprocessing for evaluation (no augmentation).
    transform_dev = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    # Reuse the path strings built above instead of rebuilding them inline.
    ranker = Ranker(img_path,
                    val_set,
                    transform=transform_dev,
                    num_workers=args.num_workers)

    vocab = Vocabulary()
    vocab.load(dict_vocab)

    data_loader_dev = get_loader(img_path,
                                 val_cap,
                                 vocab,
                                 transform_dev,
                                 args.batch_size,
                                 shuffle=False,
                                 return_target=True,
                                 num_workers=args.num_workers)

    image_encoder = ImageEncoder_MulGate_3_2Seq_SA_multiple_sentence_attention_map_soft_multiple(
        args.embed_size).cuda()
    caption_encoder = DummyCaptionEncoder_without_embed_multiple_random_embeddings(
        vocab_size=len(vocab),
        vocab_embed_size=2 * args.embed_size,
        embed_size=args.embed_size).cuda()

    image_encoder.load_state_dict(
        torch.load("models/imagenet_randomemb_image_" + data + ".pth",
                   map_location='cuda:0'))

    caption_encoder.load_state_dict(
        torch.load("models/imagenet_randomemb_text_" + data + ".pth",
                   map_location='cuda:0'))

    image_encoder.eval()
    caption_encoder.eval()

    results = eval_batch(data_loader_dev, image_encoder, caption_encoder,
                         ranker)
    # Bug fix: the original dropped the metrics on the floor; return them so
    # callers can actually inspect the evaluation result.
    return results
Example #2
0
def evaluate(args):
    """Rank candidate images for every query and dump predictions to JSON.

    Loads the trained dummy encoders from ``args.model_folder``, computes a
    ranking over the candidate split for each (candidate image, caption)
    query, and writes ``<data_set>.<data_split>.pred.json`` — a copy of the
    caption file with a ``ranking`` field added to each entry.
    """
    # Image pre-processing, normalization for the pre-trained resnet
    transform_test = transforms.Compose([
        transforms.CenterCrop(args.crop_size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    vocab = Vocabulary()
    vocab.load(DICT.format(args.data_set))
    # Build data loader
    data_loader_test = get_loader(IMAGE_ROOT.format(args.data_set),
                                  CAPT.format(args.data_set, args.data_split),
                                  vocab,
                                  transform_test,
                                  args.batch_size,
                                  shuffle=False,
                                  return_target=False,
                                  num_workers=args.num_workers)
    # Ranker holds the image ids (and, after update_emb, the embeddings) of
    # the full retrieval split.
    ranker = Ranker(root=IMAGE_ROOT.format(args.data_set),
                    image_split_file=SPLIT.format(args.data_set,
                                                  args.data_split),
                    transform=transform_test,
                    num_workers=args.num_workers)

    # Build the dummy models
    image_encoder = DummyImageEncoder(args.embed_size).to(device)
    caption_encoder = DummyCaptionEncoder(
        vocab_size=len(vocab),
        vocab_embed_size=args.embed_size * 2,
        embed_size=args.embed_size).to(device)
    # load trained models
    image_model = os.path.join(args.model_folder,
                               "image-{}.th".format(args.embed_size))
    # The checkpoint was saved with the resnet backbone detached (see the
    # matching delete_resnet() call in train), so detach it here too before
    # loading the state dict, then re-attach it. Order matters.
    resnet = image_encoder.delete_resnet()
    image_encoder.load_state_dict(torch.load(image_model, map_location=device))
    image_encoder.load_resnet(resnet)

    cap_model = os.path.join(args.model_folder,
                             "cap-{}.th".format(args.embed_size))
    caption_encoder.load_state_dict(torch.load(cap_model, map_location=device))

    # Pre-compute retrieval-split embeddings with the freshly loaded encoder.
    ranker.update_emb(image_encoder)
    image_encoder.eval()
    caption_encoder.eval()

    # Start from the raw caption file; each entry gets a 'ranking' added.
    output = json.load(open(CAPT.format(args.data_set, args.data_split)))

    # Running index into `output`; relies on the loader preserving file order
    # (shuffle=False above).
    index = 0
    for _, candidate_images, captions, lengths, meta_info in data_loader_test:
        with torch.no_grad():
            candidate_images = candidate_images.to(device)
            candidate_ft = image_encoder.forward(candidate_images)
            captions = captions.to(device)
            caption_ft = caption_encoder(captions, lengths)
            # Query vector = candidate image feature + caption feature.
            rankings = ranker.get_nearest_neighbors(candidate_ft + caption_ft)
            # print(rankings)
            for j in range(rankings.size(0)):
                # Map ranked indices back to image ids (ASINs).
                output[index]['ranking'] = [
                    ranker.data_asin[rankings[j, m].item()]
                    for m in range(rankings.size(1))
                ]
                index += 1

    json.dump(output,
              open("{}.{}.pred.json".format(args.data_set, args.data_split),
                   'w'),
              indent=4)
    print('eval completed. Output file: {}'.format("{}.{}.pred.json".format(
        args.data_set, args.data_split)))
Example #3
0
def main():
    """Parse command-line arguments, then train and/or test the model.

    Training runs when ``-data_train`` is given; testing runs when
    ``-data_test`` is given, using the model saved during training.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('-data_train', type=str, default="")
    parser.add_argument('-data_dev', required=True)
    parser.add_argument('-data_test', type=str, default="")
    parser.add_argument('-vocab', required=True)

    parser.add_argument('-epoch', type=int, default=10000)
    parser.add_argument('-batch_size', type=int, default=64)

    parser.add_argument('-d_model', type=int, default=512)

    parser.add_argument('-n_heads', type=int, default=8)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-n_warmup_steps', type=int, default=4000)

    parser.add_argument('-dropout', type=float, default=0.1)

    parser.add_argument('-log', default=None)
    parser.add_argument('-save_model', default=None)
    parser.add_argument('-save_mode', type=str, choices=['all', 'best'],
                        default='best')

    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-label_smoothing', action='store_true')
    parser.add_argument('-num_workers', type=int, default=1)

    parser.add_argument('-cnn_name', type=str, default="resnet101")
    parser.add_argument('-cnn_pretrained_model', type=str, default="")
    parser.add_argument('-joint_enc_func', type=str,
                        default="element_multiplication")
    parser.add_argument('-lr', type=float, default=0.01)
    parser.add_argument('-crop_size', type=int, default=224)
    parser.add_argument('-max_seq_len', type=int, default=64)
    parser.add_argument('-attribute_len', type=int, default=5)

    parser.add_argument('-pretrained_model', type=str, default="")

    parser.add_argument('-rank_alpha', type=float, default=1.0)
    parser.add_argument('-patience', type=int, default=7)
    parser.add_argument('-bleu_valid_every_n', type=int, default=5)
    parser.add_argument('-data_dev_combined', required=True)
    parser.add_argument('-beam_size', type=int, default=5)
    parser.add_argument('-seed', type=int, default=0)
    parser.add_argument('-attribute_vocab_size', type=int, default=1000)
    parser.add_argument('-add_attribute', action='store_true')

    args = parser.parse_args()
    args.cuda = not args.no_cuda
    # Word-embedding size is tied to the model dimension.
    args.d_word_vec = args.d_model

    args.load_weights = bool(args.pretrained_model)

    # Bug fix: honour the -seed argument (the original always seeded with 0,
    # silently ignoring -seed).
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    args.device = torch.device('cuda' if args.cuda else 'cpu')

    # Create the parent directories of the log and model files if needed.
    # Bug fix: -log / -save_model default to None — the original crashed on
    # `.split` for missing values and on makedirs("") for bare filenames.
    for path in (args.log, args.save_model):
        if path:
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)

    print(args)

    if args.data_train:
        print("======================================start training======================================")
        # Training-time augmentation: random crop + horizontal flip.
        transform = transforms.Compose([
            transforms.RandomCrop(args.crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])

        # Deterministic preprocessing for the dev split.
        transform_dev = transforms.Compose([
            transforms.CenterCrop(args.crop_size),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])

        vocab = Vocabulary()
        vocab.load(args.vocab)
        args.vocab_size = len(vocab)

        # Build data loaders.
        data_loader_training = get_loader(args.data_train,
                                          vocab, transform,
                                          args.batch_size, shuffle=True,
                                          num_workers=args.num_workers,
                                          max_seq_len=args.max_seq_len,
                                          attribute_len=args.attribute_len)

        data_loader_dev = get_loader(args.data_dev,
                                     vocab, transform_dev,
                                     args.batch_size, shuffle=False,
                                     num_workers=args.num_workers,
                                     max_seq_len=args.max_seq_len,
                                     attribute_len=args.attribute_len)

        # Batch size 1: BLEU validation decodes one example at a time.
        data_loader_bleu = get_loader_test(args.data_dev_combined,
                                           vocab, transform_dev,
                                           1, shuffle=False,
                                           attribute_len=args.attribute_len)

        list_of_refs_dev = load_ori_token_data_new(args.data_dev_combined)

        model = get_model(args, load_weights=False)
        print(count_parameters(model))

        # Noam-style optimizer with warmup (see get_std_opt).
        optimizer = get_std_opt(model, args)

        train(model, data_loader_training, data_loader_dev, optimizer, args,
              vocab, list_of_refs_dev, data_loader_bleu)

    if args.data_test:
        print("======================================start testing==============================")
        # Test with the model checkpoint written during training.
        args.pretrained_model = args.save_model
        test(args)
Example #4
0
def test(opt):

    transform = transforms.Compose([
        transforms.CenterCrop(opt.crop_size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    vocab = Vocabulary()

    vocab.load(opt.vocab)

    data_loader = get_loader_test(opt.data_test,
                                  vocab,
                                  transform,
                                  opt.batch_size,
                                  shuffle=False,
                                  attribute_len=opt.attribute_len)

    list_of_refs = load_ori_token_data_new(opt.data_test)

    model = get_model(opt, load_weights=True)

    count = 0

    hypotheses = {}

    model.eval()

    for batch in tqdm(data_loader,
                      mininterval=2,
                      desc='  - (Test)',
                      leave=False):

        image0, image1, image0_attribute, image1_attribute = map(
            lambda x: x.to(opt.device), batch)

        hyp = beam_search(image0, image1, model, opt, vocab, image0_attribute,
                          image1_attribute)
        #         hyp = greedy_search(image1.to(device), image2.to(device), model, opt, vocab)

        hyp = hyp.split("<end>")[0].strip()

        hypotheses[count] = ["it " + hyp]

        count += 1

    # =================================================
    # Set up scorers
    # =================================================
    print('setting up scorers...')
    scorers = [
        (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
        # (Meteor(),"METEOR"),
        (Rouge(), "ROUGE_L"),
        # (Cider(), "CIDEr"),
        (Cider(), "CIDEr"),
        (CiderD(), "CIDEr-D")
        # (Spice(), "SPICE")
    ]

    for scorer, method in scorers:
        print('computing %s score...' % (scorer.method()))
        score, scores = scorer.compute_score(list_of_refs, hypotheses)
        if type(method) == list:
            for sc, scs, m in zip(score, scores, method):
                # self.setEval(sc, m)
                # self.setImgToEvalImgs(scs, gts.keys(), m)
                print("%s: %0.3f" % (m, sc))
        else:
            # self.setEval(score, method)
            # self.setImgToEvalImgs(scores, gts.keys(), method)
            print("%s: %0.3f" % (method, score))

    for i in range(len(hypotheses)):
        ref = {i: list_of_refs[i]}
        hyp = {i: hypotheses[i]}
        print(ref)
        print(hyp)
        for scorer, method in scorers:
            print('computing %s score...' % (scorer.method()))
            score, scores = scorer.compute_score(ref, hyp)
            if type(method) == list:
                for sc, scs, m in zip(score, scores, method):
                    # self.setEval(sc, m)
                    # self.setImgToEvalImgs(scs, gts.keys(), m)
                    print("%s: %0.3f" % (m, sc))
            else:
                # self.setEval(score, method)
                # self.setImgToEvalImgs(scores, gts.keys(), method)
                print("%s: %0.3f" % (method, score))
Example #5
0
def train(args):
    """Train the dummy image/caption encoders with a triplet ranking loss.

    Runs the training loop, evaluates on the dev split after each epoch,
    saves the best-scoring encoders to a timestamped experiment folder, and
    halves the learning rate when the dev score stops improving.
    """
    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Deterministic preprocessing for the dev split (no augmentation).
    transform_dev = transforms.Compose([
        transforms.CenterCrop(args.crop_size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    vocab = Vocabulary()
    vocab.load(DICT.format(args.data_set))

    # Build data loader
    data_loader = get_loader(IMAGE_ROOT.format(args.data_set),
                             CAPT.format(args.data_set, 'train'),
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             return_target=True,
                             num_workers=args.num_workers)

    data_loader_dev = get_loader(IMAGE_ROOT.format(args.data_set),
                                 CAPT.format(args.data_set, 'val'),
                                 vocab,
                                 transform_dev,
                                 args.batch_size,
                                 shuffle=False,
                                 return_target=True,
                                 num_workers=args.num_workers)

    # Ranker over the val image split, used by eval_batch for retrieval.
    ranker = Ranker(root=IMAGE_ROOT.format(args.data_set),
                    image_split_file=SPLIT.format(args.data_set, 'val'),
                    transform=transform_dev,
                    num_workers=args.num_workers)

    # Timestamped experiment folder; also snapshots the listed source files.
    save_folder = '{}/{}-{}'.format(args.save, args.data_set,
                                    time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(save_folder,
                   scripts_to_save=[
                       'models.py', 'data_loader.py', 'train.py',
                       'build_vocab.py', 'utils.py'
                   ])

    def logging(s, print_=True, log_=True):
        # Echo to stdout and append to the experiment's log file.
        if print_:
            print(s)
        if log_:
            with open(os.path.join(save_folder, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging(str(args))
    # Build the dummy models
    image_encoder = DummyImageEncoder(args.embed_size).to(device)
    caption_encoder = DummyCaptionEncoder(
        vocab_size=len(vocab),
        vocab_embed_size=args.embed_size * 2,
        embed_size=args.embed_size).to(device)

    image_encoder.train()
    caption_encoder.train()
    # Jointly optimize the trainable parameters of both encoders.
    params = image_encoder.get_trainable_parameters(
    ) + caption_encoder.get_trainable_parameters()

    current_lr = args.learning_rate
    optimizer = torch.optim.Adam(params, lr=current_lr)

    cur_patient = 0
    best_score = float('-inf')
    stop_train = False
    total_step = len(data_loader)
    # epoch = 1 for dummy setting
    for epoch in range(1):

        for i, (target_images, candidate_images, captions, lengths,
                meta_info) in enumerate(data_loader):

            target_images = target_images.to(device)
            target_ft = image_encoder.forward(target_images)

            candidate_images = candidate_images.to(device)
            candidate_ft = image_encoder.forward(candidate_images)

            captions = captions.to(device)
            caption_ft = caption_encoder(captions, lengths)

            # random select negative examples
            # (here: reverse the batch, so each anchor is paired with another
            # in-batch target as its negative — not actually random)
            m = target_images.size(0)
            random_index = [m - 1 - n for n in range(m)]
            random_index = torch.LongTensor(random_index)
            negative_ft = target_ft[random_index]

            # Triplet loss: candidate image feature + caption feature should
            # be closer to the true target than to the negative.
            loss = triplet_avg(anchor=(candidate_ft + caption_ft),
                               positive=target_ft,
                               negative=negative_ft)

            caption_encoder.zero_grad()
            image_encoder.zero_grad()
            loss.backward()
            optimizer.step()

            if i % args.log_step == 0:
                logging(
                    '| epoch {:3d} | step {:6d}/{:6d} | lr {:06.6f} | train loss {:8.3f}'
                    .format(epoch, i, total_step, current_lr, loss.item()))

        # End-of-epoch evaluation on the dev split.
        image_encoder.eval()
        caption_encoder.eval()
        logging('-' * 77)
        metrics = eval_batch(data_loader_dev, image_encoder, caption_encoder,
                             ranker)
        logging('| eval loss: {:8.3f} | score {:8.5f} / {:8.5f} '.format(
            metrics['loss'], metrics['score'], best_score))
        logging('-' * 77)

        image_encoder.train()
        caption_encoder.train()

        dev_score = metrics['score']
        if dev_score > best_score:
            best_score = dev_score
            # save best model
            # The resnet backbone is detached before saving (its weights are
            # excluded from the checkpoint) and re-attached afterwards.
            resnet = image_encoder.delete_resnet()
            torch.save(
                image_encoder.state_dict(),
                os.path.join(save_folder,
                             'image-{}.th'.format(args.embed_size)))
            image_encoder.load_resnet(resnet)

            torch.save(
                caption_encoder.state_dict(),
                os.path.join(save_folder, 'cap-{}.th'.format(args.embed_size)))

            cur_patient = 0
        else:
            # No improvement: after args.patient stale epochs, halve the LR;
            # stop entirely once it decays below learning_rate * 1e-3.
            cur_patient += 1
            if cur_patient >= args.patient:
                current_lr /= 2
                for param_group in optimizer.param_groups:
                    param_group['lr'] = current_lr
                if current_lr < args.learning_rate * 1e-3:
                    stop_train = True
                    break

        if stop_train:
            break
    logging('best_dev_score: {}'.format(best_score))