Example 1
def vocab_loader(data_loc):
    """
    :param data_loc: the address of corpus
    :return: Vocabulary class variable vocab
    """
    with open(data_loc, 'r') as f:
        words = f.read().split("\n")
    vocab = Vocabulary()
    for word in words:
        vocab.add_word(word)
    return vocab
def evaluate(args):

    data = args.data_set
    print("Evaluating trained model on ", data)
    img_path = "data/resized_images/" + data
    train_cap = "data/captions/cap." + data + ".train.json"
    val_cap = "data/captions/cap." + data + ".val.json"
    dict_vocab = "data/captions/dict." + data + ".json"
    val_set = "data/image_splits/split." + data + ".val.json"

    transform_dev = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    ranker = Ranker("data/resized_images/" + data,
                    "data/image_splits/split." + args.data_set + ".val.json",
                    transform=transform_dev,
                    num_workers=args.num_workers)

    vocab = Vocabulary()

    vocab.load(dict_vocab)

    data_loader_dev = get_loader(img_path,
                                 val_cap,
                                 vocab,
                                 transform_dev,
                                 args.batch_size,
                                 shuffle=False,
                                 return_target=True,
                                 num_workers=args.num_workers)

    image_encoder = ImageEncoder_MulGate_3_2Seq_SA_multiple_sentence_attention_map_soft_multiple(
        args.embed_size).cuda()
    caption_encoder = DummyCaptionEncoder_without_embed_multiple_random_embeddings(
        vocab_size=len(vocab),
        vocab_embed_size=2 * args.embed_size,
        embed_size=args.embed_size).cuda()

    image_encoder.load_state_dict(
        torch.load("models/imagenet_randomemb_image_" + data + ".pth",
                   map_location='cuda:0'))

    caption_encoder.load_state_dict(
        torch.load("models/imagenet_randomemb_text_" + data + ".pth",
                   map_location='cuda:0'))

    image_encoder.eval()
    caption_encoder.eval()

    results = eval_batch(data_loader_dev, image_encoder, caption_encoder,
                         ranker)
def ans_type_to_idx(annotations):
    ans_type = []
    for anno in annotations:
        ans_type.append(anno["answer_type"])

    uniques = list(set(ans_type))
    ans_type_vocab = Vocabulary()

    for word in uniques:
        ans_type_vocab.add_word(word)

    for anno in annotations:
        anno["answer_type"] = ans_type_vocab(anno["answer_type"])

    return ans_type_vocab
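The snippets on this page all rely on the same Vocabulary wrapper: add_word registers a token, calling the instance maps a token to its index, len() returns the vocabulary size, and idx holds the running counter; several examples also use a load method for reading a saved dictionary. The class itself does not appear on this page, so the following is only a minimal sketch consistent with that usage (the '<unk>' fallback in __call__ is an assumption, and load is omitted):

class Vocabulary(object):
    """Minimal sketch of the vocab wrapper used throughout these examples."""

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0  # running counter printed by the build_vocab helpers

    def add_word(self, word):
        # ignore duplicates so repeated add_word calls are safe
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        # assumption: unseen words fall back to '<unk>' when it was added
        if word not in self.word2idx:
            return self.word2idx.get('<unk>', 0)
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)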
Example 4
def build_vocab_decoder_SemCor(threshold):

    counter = Counter()

    # iterate through all definitions in SemCor
    with open(
            "../../WSD_Evaluation_Framework/Training_Corpora/SemCor/semcor.gold.key.txt",
            "r") as target_file:
        for line in target_file:

            # sense key -> synset and literal definition from the WN
            key = line.replace('\n', '').split(' ')[-1]
            synset = wn.lemma_from_key(key).synset()
            definition = synset.definition()
            def_tokens = nltk.tokenize.word_tokenize(definition)
            counter.update(def_tokens)

    # add SemEval synsets
    with open(
            "../../WSD_Evaluation_Framework/Evaluation_Datasets/semeval2007/semeval2007.gold.key.txt",
            "r") as semeval_file:
        for line in semeval_file:
            key = line.replace('\n', '').split(' ')[-1]
            synset = wn.lemma_from_key(key).synset()
            definition = synset.definition()
            def_tokens = nltk.tokenize.word_tokenize(definition)
            counter.update(def_tokens)

    # If the word frequency is less than 'threshold', then the word is discarded.
    words = [word for word, cnt in counter.items() if cnt >= threshold]

    # Create a vocab wrapper and add some special tokens.
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')

    # Add the words to the vocabulary.
    for word in words:
        vocab.add_word(word)

    print("Total vocabulary size: {}".format(vocab.idx))
    return vocab
def prepare_question_type_vocab(annotations, topk=65):
    counter = Counter()
    print("counting question types ... ")

    for annotation in tqdm(annotations):
        caption = str(annotation['question_type'])
        counter.update([caption])

    print("Top 20 Most Common Questions")
    for word, cnt in counter.most_common(20):
        print("{} - {}".format(word.ljust(10), cnt))

    common_words = [word for word, cnt in counter.most_common(topk)]

    question_type_vocab = Vocabulary()
    for word in common_words:
        question_type_vocab.add_word(word)

    return question_type_vocab
Example 6
def test_seer_init():
    delphi = Seer()
    check.is_true(isinstance(delphi.encoder, EncoderCNN))
    check.is_true(isinstance(delphi.decoder, DecoderRNN))
    check.is_true(delphi.vocab_path == 'torchdata/vocab.pkl')
    check.is_true(delphi.encoder_path == 'torchdata/encoder-5-3000.pkl')
    check.is_true(delphi.decoder_path == 'torchdata/decoder-5-3000.pkl')
    check.is_true(delphi.embed_size == 256)
    check.is_true(delphi.hidden_size == 512)
    check.is_true(delphi.num_layers == 1)
    check.is_true(isinstance(delphi.vocab, Vocabulary))
def prepare_answers_vocab(annotations, topk):
    s = 'multiple_choice_answer'
    
    counter = Counter()
    print("counting occurrences ... ")
    for annotation in tqdm(annotations):
        caption = str(annotation[s])
        counter.update([caption])

    print("Top 20 Most Common Words")
    for word, cnt in counter.most_common(20):
        print("{} - {}".format(word.ljust(10), cnt))
    common_words = [word for word, cnt in counter.most_common(topk)]

    ans_vocab = Vocabulary()
    for word in common_words:
        ans_vocab.add_word(word)

    return ans_vocab
Example 8
def build_vocab_synset():

    total_size = len(set(wn.all_synsets()))

    # Create a vocab wrapper.
    vocab = Vocabulary()

    # iterate through all synsets in the WN
    for i, synset in enumerate(wn.all_synsets(), start=1):

        # convert '.' to '__' for hashing
        name = synset.name().replace('.', '__')
        vocab.add_word(name)

        if i % 1000 == 0:
            print("[{}/{}] synsets done.".format(i, total_size))

    print("Total vocabulary size: {}".format(vocab.idx))
    return vocab
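The '.' to '__' conversion above is reversible, so a stored vocabulary entry can be mapped back to its WordNet synset. A small illustrative round trip (assuming wn is NLTK's WordNet reader, as in the example):

name = wn.synset('dog.n.01').name()   # 'dog.n.01'
key = name.replace('.', '__')         # 'dog__n__01', safe to store in the vocab
restored = wn.synset(key.replace('__', '.'))
assert restored.name() == name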
Example 9
def build_vocab(threshold):
    """Build a simple vocabulary wrapper."""
    counter = Counter()
    i = 0
    total_size = len(set(wn.all_synsets()))

    # iterate through all definitions in the WN
    for synset in wn.all_synsets():

        definition = synset.definition()
        # tokenize the (lowercased) definition string; punctuation could
        # additionally be stripped, e.g. with
        # s.translate(str.maketrans('', '', string.punctuation))
        def_tokens = nltk.tokenize.word_tokenize(definition.lower())
        counter.update(def_tokens)

        i += 1
        if i % 1000 == 0:
            print("[{}/{}] Tokenized the definitions.".format(i, total_size))

    # If the word frequency is less than 'threshold', then the word is discarded.
    words = [word for word, cnt in counter.items() if cnt >= threshold]

    # Create a vocab wrapper and add some special tokens.
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')

    # Add the words to the vocabulary.
    for i, word in enumerate(words):
        vocab.add_word(word)

    print("Total vocabulary size: {}".format(vocab.idx))
    return vocab
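A hypothetical use of the returned wrapper, illustrating how the special tokens and the frequency threshold interact: frequent tokens keep their own ids, while tokens pruned by the threshold map to '<unk>' (assuming the wrapper's __call__ falls back to '<unk>', as in the sketch near the top of this page):

vocab = build_vocab(threshold=5)

definition = wn.synset('dog.n.01').definition()
tokens = nltk.tokenize.word_tokenize(definition.lower())

# bracket the token ids with the special tokens added above
ids = [vocab('<start>')] + [vocab(t) for t in tokens] + [vocab('<end>')]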
Example 10
def build_vocab_synset_SemCor():

    # Create a vocab wrapper.
    vocab = Vocabulary()

    # iterate through all sense keys in SemCor
    with open(
            "../../WSD_Evaluation_Framework/Training_Corpora/SemCor/semcor.gold.key.txt",
            "r") as target_file:
        for line in target_file:

            # sense key -> synset from the WN
            key = line.replace('\n', '').split(' ')[-1]
            synset = wn.lemma_from_key(key).synset()

            # convert '.' to '__' for hashing
            name = synset.name().replace('.', '__')
            vocab.add_word(name)

    # add SemEval synsets
    with open(
            "../../WSD_Evaluation_Framework/Evaluation_Datasets/semeval2007/semeval2007.gold.key.txt",
            "r") as semeval_file:
        for line in semeval_file:
            key = line.replace('\n', '').split(' ')[-1]
            synset = wn.lemma_from_key(key).synset()
            name = synset.name().replace('.', '__')
            vocab.add_word(name)

    print("Total vocabulary size: {} {}".format(vocab.idx, len(vocab)))
    return vocab
def prepare_question_vocab(annotations, field="question"):
    question_vocab = Vocabulary()
    print("Generating question vocabulary")

    for annotation in tqdm(annotations):
        tokens = annotation[field]
        for token in tokens:
            question_vocab.add_word(token)

    question_vocab.add_word('<end>')
    question_vocab.add_word('<unk>')

    return question_vocab
Example 12
def main(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             args.dictionary, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    # Build the models
    #encoder = EncoderCNN(args.embed_size).to(device)
    dictionary = pd.read_csv(args.dictionary, header=0,
                             encoding='unicode_escape', error_bad_lines=False)
    dictionary = list(dictionary['keys'])

    decoder = Decoder(len(vocab), len(dictionary), args.units, args.batch_size)

    # Loss and optimizer
    optimizer = tf.train.AdamOptimizer()


    def loss_function(real, pred):
        mask = 1 - np.equal(real, 0)
        loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask
        return tf.reduce_mean(loss_)

    # Train the models
    total_step = len(data_loader)
    # for epoch in range(args.num_epochs):
    #     for i, (array, captions, lengths) in enumerate(data_loader):
    #
    #         # Set mini-batch dataset
    #         array = array.to(device)
    #         captions = captions.to(device)
    #         targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
    #
    #         # Forward, backward and optimize
    #         #features = encoder(images)
    #         outputs = decoder(array, captions, lengths)
    #         loss = criterion(outputs, targets)
    #         decoder.zero_grad()
    #         #encoder.zero_grad()
    #         loss.backward()
    #         optimizer.step()
    #
    #         # Print log info
    #         if i % args.log_step == 0:
    #             print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
    #                   .format(epoch, args.num_epochs, i, total_step, loss.item(), np.exp(loss.item())))
    #
    #         # Save the model checkpoints
    #         if (i+1) % args.save_step == 0:
    #             torch.save(decoder.state_dict(), os.path.join(
    #                 args.model_path, 'decoder-{}-{}.ckpt'.format(epoch+1, i+1)))

    checkpoint_dir = './training_checkpoints'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     decoder=decoder)

    EPOCHS = 1

    for epoch in range(EPOCHS):
        start = time.time()

        # hidden = encoder.initialize_hidden_state()
        total_loss = 0

        vocab_ins = Vocabulary()  # instantiated here but not used below

        # for (batch, (inp, targ)) in enumerate(dataset):
        for i, (array, captions, lengths) in enumerate(data_loader):
            loss = 0
            array = array.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

            with tf.GradientTape() as tape:
                # enc_output, enc_hidden = encoder(captions, hidden)
                dec_hidden = decoder.initialize_hidden_state()

                dec_input = tf.expand_dims([vocab('<start>')] * args.batch_size, 1)
                print(targets.shape)

                # Teacher forcing - feeding the target as the next input
                # step over the caption time dimension (captions is batch x seq_len)
                for t in range(1, captions.shape[1]):
                    predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, array)

                    loss += loss_function(captions[:, t], predictions)

                    # using teacher forcing
                    dec_input = tf.expand_dims(captions[:, t], 1)

            batch_loss = (loss / int(captions.shape[1]))

            total_loss += batch_loss

            variables = decoder.variables

            gradients = tape.gradient(loss, variables)

            optimizer.apply_gradients(zip(gradients, variables))

            if i % 100 == 0:
                print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                             i,
                                                             batch_loss.numpy()))
        # saving (checkpoint) the model every epoch
        checkpoint.save(file_prefix = checkpoint_prefix)

        print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                            total_loss / total_step))
        print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))
Example 13
def evaluate(args):
    # Image pre-processing, normalization for the pre-trained resnet
    transform_test = transforms.Compose([
        transforms.CenterCrop(args.crop_size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    vocab = Vocabulary()
    vocab.load(DICT.format(args.data_set))
    # Build data loader
    data_loader_test = get_loader(IMAGE_ROOT.format(args.data_set),
                                  CAPT.format(args.data_set, args.data_split),
                                  vocab,
                                  transform_test,
                                  args.batch_size,
                                  shuffle=False,
                                  return_target=False,
                                  num_workers=args.num_workers)
    ranker = Ranker(root=IMAGE_ROOT.format(args.data_set),
                    image_split_file=SPLIT.format(args.data_set,
                                                  args.data_split),
                    transform=transform_test,
                    num_workers=args.num_workers)

    # Build the dummy models
    image_encoder = DummyImageEncoder(args.embed_size).to(device)
    caption_encoder = DummyCaptionEncoder(
        vocab_size=len(vocab),
        vocab_embed_size=args.embed_size * 2,
        embed_size=args.embed_size).to(device)
    # load trained models
    image_model = os.path.join(args.model_folder,
                               "image-{}.th".format(args.embed_size))
    resnet = image_encoder.delete_resnet()
    image_encoder.load_state_dict(torch.load(image_model, map_location=device))
    image_encoder.load_resnet(resnet)

    cap_model = os.path.join(args.model_folder,
                             "cap-{}.th".format(args.embed_size))
    caption_encoder.load_state_dict(torch.load(cap_model, map_location=device))

    ranker.update_emb(image_encoder)
    image_encoder.eval()
    caption_encoder.eval()

    with open(CAPT.format(args.data_set, args.data_split)) as f:
        output = json.load(f)

    index = 0
    for _, candidate_images, captions, lengths, meta_info in data_loader_test:
        with torch.no_grad():
            candidate_images = candidate_images.to(device)
            candidate_ft = image_encoder.forward(candidate_images)
            captions = captions.to(device)
            caption_ft = caption_encoder(captions, lengths)
            rankings = ranker.get_nearest_neighbors(candidate_ft + caption_ft)
            # print(rankings)
            for j in range(rankings.size(0)):
                output[index]['ranking'] = [
                    ranker.data_asin[rankings[j, m].item()]
                    for m in range(rankings.size(1))
                ]
                index += 1

    pred_file = "{}.{}.pred.json".format(args.data_set, args.data_split)
    with open(pred_file, 'w') as f:
        json.dump(output, f, indent=4)
    print('eval completed. Output file: {}'.format(pred_file))
Example 14
def test(opt):

    transform = transforms.Compose([
        transforms.CenterCrop(opt.crop_size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    vocab = Vocabulary()

    vocab.load(opt.vocab)

    data_loader = get_loader_test(opt.data_test,
                                  vocab,
                                  transform,
                                  opt.batch_size,
                                  shuffle=False,
                                  attribute_len=opt.attribute_len)

    list_of_refs = load_ori_token_data_new(opt.data_test)

    model = get_model(opt, load_weights=True)

    count = 0

    hypotheses = {}

    model.eval()

    for batch in tqdm(data_loader,
                      mininterval=2,
                      desc='  - (Test)',
                      leave=False):

        image0, image1, image0_attribute, image1_attribute = map(
            lambda x: x.to(opt.device), batch)

        hyp = beam_search(image0, image1, model, opt, vocab, image0_attribute,
                          image1_attribute)
        #         hyp = greedy_search(image1.to(device), image2.to(device), model, opt, vocab)

        hyp = hyp.split("<end>")[0].strip()

        hypotheses[count] = ["it " + hyp]

        count += 1

    # =================================================
    # Set up scorers
    # =================================================
    print('setting up scorers...')
    scorers = [
        (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
        # (Meteor(), "METEOR"),
        (Rouge(), "ROUGE_L"),
        (Cider(), "CIDEr"),
        (CiderD(), "CIDEr-D")
        # (Spice(), "SPICE")
    ]

    for scorer, method in scorers:
        print('computing %s score...' % (scorer.method()))
        score, scores = scorer.compute_score(list_of_refs, hypotheses)
        if type(method) == list:
            for sc, scs, m in zip(score, scores, method):
                # self.setEval(sc, m)
                # self.setImgToEvalImgs(scs, gts.keys(), m)
                print("%s: %0.3f" % (m, sc))
        else:
            # self.setEval(score, method)
            # self.setImgToEvalImgs(scores, gts.keys(), method)
            print("%s: %0.3f" % (method, score))

    for i in range(len(hypotheses)):
        ref = {i: list_of_refs[i]}
        hyp = {i: hypotheses[i]}
        print(ref)
        print(hyp)
        for scorer, method in scorers:
            print('computing %s score...' % (scorer.method()))
            score, scores = scorer.compute_score(ref, hyp)
            if type(method) == list:
                for sc, scs, m in zip(score, scores, method):
                    # self.setEval(sc, m)
                    # self.setImgToEvalImgs(scs, gts.keys(), m)
                    print("%s: %0.3f" % (m, sc))
            else:
                # self.setEval(score, method)
                # self.setImgToEvalImgs(scores, gts.keys(), method)
                print("%s: %0.3f" % (method, score))
Example 15
def main():
    ''' Main function '''
    parser = argparse.ArgumentParser()

    parser.add_argument('-data_train', type=str, default="")
    parser.add_argument('-data_dev', required=True)
    parser.add_argument('-data_test', type=str, default="")
    parser.add_argument('-vocab', required=True)

    parser.add_argument('-epoch', type=int, default=10000)
    parser.add_argument('-batch_size', type=int, default=64)

    #parser.add_argument('-d_word_vec', type=int, default=512)
    parser.add_argument('-d_model', type=int, default=512)
    # parser.add_argument('-d_inner_hid', type=int, default=2048)
    # parser.add_argument('-d_k', type=int, default=64)
    # parser.add_argument('-d_v', type=int, default=64)

    parser.add_argument('-n_heads', type=int, default=8)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-n_warmup_steps', type=int, default=4000)

    parser.add_argument('-dropout', type=float, default=0.1)
    # parser.add_argument('-embs_share_weight', action='store_true')
    # parser.add_argument('-proj_share_weight', action='store_true')

    parser.add_argument('-log', default=None)
    parser.add_argument('-save_model', default=None)
    parser.add_argument('-save_mode', type=str, choices=['all', 'best'], default='best')

    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-label_smoothing', action='store_true')
    parser.add_argument('-num_workers', type=int, default=1)

    parser.add_argument('-cnn_name', type=str, default="resnet101")
    parser.add_argument('-cnn_pretrained_model', type=str, default="")
    parser.add_argument('-joint_enc_func', type=str, default="element_multiplication")
    # parser.add_argument('-comparative_module_name', type=str, default="transformer_encoder")
    parser.add_argument('-lr', type=float, default=0.01)
    # parser.add_argument('-step_size', type=int, default=1000)
    # parser.add_argument('-gamma', type=float, default=0.9)
    parser.add_argument('-crop_size', type=int, default=224)
    parser.add_argument('-max_seq_len', type=int, default=64)
    parser.add_argument('-attribute_len', type=int, default=5)

    parser.add_argument('-pretrained_model', type=str, default="")

    parser.add_argument('-rank_alpha', type=float, default=1.0)
    parser.add_argument('-patience', type=int, default=7)
    parser.add_argument('-bleu_valid_every_n', type=int, default=5)
    parser.add_argument('-data_dev_combined', required=True)
    parser.add_argument('-beam_size', type=int, default=5)
    parser.add_argument('-seed', type=int, default=0)
    parser.add_argument('-attribute_vocab_size', type=int, default=1000)
    parser.add_argument('-add_attribute', action='store_true')

    

    args = parser.parse_args()
    args.cuda = not args.no_cuda
    args.d_word_vec = args.d_model

    args.load_weights = False
    if args.pretrained_model:
        args.load_weights = True

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    args.device = torch.device('cuda' if args.cuda else 'cpu')

    if args.log:
        log_dir = os.path.dirname(args.log)
        if log_dir and not os.path.exists(log_dir):
            os.makedirs(log_dir)

    if args.save_model:
        model_dir = os.path.dirname(args.save_model)
        if model_dir and not os.path.exists(model_dir):
            os.makedirs(model_dir)

    print(args)

    if args.data_train:
        print("======================================start training======================================")
        transform = transforms.Compose([
            transforms.RandomCrop(args.crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])

        transform_dev = transforms.Compose([
            transforms.CenterCrop(args.crop_size),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])

        vocab = Vocabulary()
        
        vocab.load(args.vocab)

        args.vocab_size = len(vocab)

        # Build data loader
        data_loader_training = get_loader(args.data_train,
                                          vocab, transform,
                                          args.batch_size, shuffle=True,
                                          num_workers=args.num_workers,
                                          max_seq_len=args.max_seq_len,
                                          attribute_len=args.attribute_len)

        data_loader_dev = get_loader(args.data_dev,
                                     vocab, transform_dev,
                                     args.batch_size, shuffle=False,
                                     num_workers=args.num_workers,
                                     max_seq_len=args.max_seq_len,
                                     attribute_len=args.attribute_len)

        data_loader_bleu = get_loader_test(args.data_dev_combined,
                                           vocab, transform_dev,
                                           1, shuffle=False,
                                           attribute_len=args.attribute_len)

        list_of_refs_dev = load_ori_token_data_new(args.data_dev_combined)

        model = get_model(args, load_weights=False)


        print(count_parameters(model))

        # print(model.get_trainable_parameters())
        # init_lr = np.power(args.d_model, -0.5)

        # optimizer = torch.optim.Adam(model.get_trainable_parameters(), lr=init_lr)
        optimizer = get_std_opt(model, args)
        
        train(model, data_loader_training, data_loader_dev, optimizer, args,
              vocab, list_of_refs_dev, data_loader_bleu)

    if args.data_test:
        print("======================================start testing==============================")
        args.pretrained_model = args.save_model 
        test(args)
Example 16
import numpy as np
import torch

# Read the per-word weights into a dictionary
word2weight = dict()
with open("./glove.6B/glove.6B.50d.txt", "r") as glove_file:
    for line in glove_file:
        parts = line.split()
        word = parts[0]
        weights = list(map(float, parts[1:]))
        word2weight[word] = weights

# Load target vocabulary
train_caption_path = "./data/annotations/captions_train2014.json"
vocab_path = "./vocabulary.pkl"
vocab = Vocabulary(train_caption_path, vocab_path)

words_in_vocab = vocab.word_to_id.keys()

# Copy the pre-trained weights for words found in the vocabulary; initialize the rest randomly
embedding_dim = 50  # Must fit the dimension from the pre-trained file
pretrained_weight_matrix = np.zeros((len(words_in_vocab), embedding_dim))
words_found = 0
for i, word in enumerate(words_in_vocab):
    try:
        pretrained_weight_matrix[i] = word2weight[word]
        words_found += 1
    except KeyError:
        pretrained_weight_matrix[i] = np.random.normal(scale=0.6,
                                                       size=(embedding_dim, ))
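A typical next step, not shown in the original snippet, is to hand the matrix to an embedding layer; a minimal sketch using the torch import above:

# freeze=False keeps the GloVe-initialized vectors trainable during fine-tuning
embedding = torch.nn.Embedding.from_pretrained(
    torch.FloatTensor(pretrained_weight_matrix), freeze=False)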
Example 17
def train(args):
    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    transform_dev = transforms.Compose([
        transforms.CenterCrop(args.crop_size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    vocab = Vocabulary()
    vocab.load(DICT.format(args.data_set))

    # Build data loader
    data_loader = get_loader(IMAGE_ROOT.format(args.data_set),
                             CAPT.format(args.data_set, 'train'),
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             return_target=True,
                             num_workers=args.num_workers)

    data_loader_dev = get_loader(IMAGE_ROOT.format(args.data_set),
                                 CAPT.format(args.data_set, 'val'),
                                 vocab,
                                 transform_dev,
                                 args.batch_size,
                                 shuffle=False,
                                 return_target=True,
                                 num_workers=args.num_workers)

    ranker = Ranker(root=IMAGE_ROOT.format(args.data_set),
                    image_split_file=SPLIT.format(args.data_set, 'val'),
                    transform=transform_dev,
                    num_workers=args.num_workers)

    save_folder = '{}/{}-{}'.format(args.save, args.data_set,
                                    time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(save_folder,
                   scripts_to_save=[
                       'models.py', 'data_loader.py', 'train.py',
                       'build_vocab.py', 'utils.py'
                   ])

    def logging(s, print_=True, log_=True):
        if print_:
            print(s)
        if log_:
            with open(os.path.join(save_folder, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging(str(args))
    # Build the dummy models
    image_encoder = DummyImageEncoder(args.embed_size).to(device)
    caption_encoder = DummyCaptionEncoder(
        vocab_size=len(vocab),
        vocab_embed_size=args.embed_size * 2,
        embed_size=args.embed_size).to(device)

    image_encoder.train()
    caption_encoder.train()
    params = (image_encoder.get_trainable_parameters() +
              caption_encoder.get_trainable_parameters())

    current_lr = args.learning_rate
    optimizer = torch.optim.Adam(params, lr=current_lr)

    cur_patient = 0
    best_score = float('-inf')
    stop_train = False
    total_step = len(data_loader)
    # epoch = 1 for dummy setting
    for epoch in range(1):

        for i, (target_images, candidate_images, captions, lengths,
                meta_info) in enumerate(data_loader):

            target_images = target_images.to(device)
            target_ft = image_encoder.forward(target_images)

            candidate_images = candidate_images.to(device)
            candidate_ft = image_encoder.forward(candidate_images)

            captions = captions.to(device)
            caption_ft = caption_encoder(captions, lengths)

            # use the reversed batch as negative examples
            m = target_images.size(0)
            random_index = [m - 1 - n for n in range(m)]
            random_index = torch.LongTensor(random_index)
            negative_ft = target_ft[random_index]

            loss = triplet_avg(anchor=(candidate_ft + caption_ft),
                               positive=target_ft,
                               negative=negative_ft)

            caption_encoder.zero_grad()
            image_encoder.zero_grad()
            loss.backward()
            optimizer.step()

            if i % args.log_step == 0:
                logging(
                    '| epoch {:3d} | step {:6d}/{:6d} | lr {:06.6f} | train loss {:8.3f}'
                    .format(epoch, i, total_step, current_lr, loss.item()))

        image_encoder.eval()
        caption_encoder.eval()
        logging('-' * 77)
        metrics = eval_batch(data_loader_dev, image_encoder, caption_encoder,
                             ranker)
        logging('| eval loss: {:8.3f} | score {:8.5f} / {:8.5f} '.format(
            metrics['loss'], metrics['score'], best_score))
        logging('-' * 77)

        image_encoder.train()
        caption_encoder.train()

        dev_score = metrics['score']
        if dev_score > best_score:
            best_score = dev_score
            # save best model
            resnet = image_encoder.delete_resnet()
            torch.save(
                image_encoder.state_dict(),
                os.path.join(save_folder,
                             'image-{}.th'.format(args.embed_size)))
            image_encoder.load_resnet(resnet)

            torch.save(
                caption_encoder.state_dict(),
                os.path.join(save_folder, 'cap-{}.th'.format(args.embed_size)))

            cur_patient = 0
        else:
            cur_patient += 1
            if cur_patient >= args.patient:
                current_lr /= 2
                for param_group in optimizer.param_groups:
                    param_group['lr'] = current_lr
                if current_lr < args.learning_rate * 1e-3:
                    stop_train = True
                    break

        if stop_train:
            break
    logging('best_dev_score: {}'.format(best_score))