Exemple #1
0
def translate_sentence(sentence, model, opt, SRC, TRG, counter):
    """Translate a single source sentence with beam search.

    Args:
        sentence: Raw source text; it is tokenized with ``SRC.preprocess``.
        model: Translation model; it is switched to evaluation mode here.
        opt: Options object; ``opt.no_cuda`` decides whether the input
            tensor is moved to the GPU.
        SRC: Source field providing ``preprocess`` and ``vocab.stoi``.
        TRG: Target field, forwarded to ``beam_search``.
        counter: Not used by this function; kept for caller compatibility.

    Returns:
        The ``beam_search`` output after ``multiple_replace`` has been
        applied with the punctuation-spacing table below.
    """
    model.eval()
    tokens = SRC.preprocess(sentence)
    # A token whose lookup yields 0 is stored as 0 anyway, so the plain
    # lookup covers both the known and the unknown case.
    token_ids = [SRC.vocab.stoi[token] for token in tokens]

    batch = Variable(torch.LongTensor([token_ids]))
    if opt.no_cuda is False:
        batch = batch.cuda()

    decoded = beam_search(batch, model, SRC, TRG, opt)

    spacing_fixes = {' ?': '?', ' !': '!', ' .': '.', '\' ': '\'', ' ,': ','}
    return multiple_replace(spacing_fixes, decoded)
Exemple #2
0
def translate_sentence(sentence, model, opt, SRC, TRG):
    """Translate one sentence, capitalize it and tidy punctuation spacing.

    Args:
        sentence: Raw source text; tokenized with ``SRC.preprocess``.
        model: Translation model; switched to evaluation mode here.
        opt: Options object; reads ``opt.floyd`` and ``opt.device``.
        SRC: Source field providing ``preprocess`` and ``vocab.stoi``.
        TRG: Target field, forwarded to ``beam_search``.

    Returns:
        The capitalized ``beam_search`` output after ``multiple_replace``.
    """
    model.eval()
    tokens = SRC.preprocess(sentence)

    token_ids = []
    for token in tokens:
        idx = SRC.vocab.stoi[token]
        # Index 0 marks a token the vocabulary does not know; unless
        # opt.floyd is set, ask get_synonym for a replacement index.
        if idx == 0 and not (opt.floyd == True):
            idx = get_synonym(token, SRC)
        token_ids.append(idx)

    batch = Variable(torch.LongTensor([token_ids]))
    if opt.device == 0:
        batch = batch.cuda()

    decoded = capitalize(beam_search(batch, model, SRC, TRG, opt))

    spacing_fixes = {
        ' ?': '?',
        ' !': '!',
        ' .': '.',
        '\' ': '\'',
        ' ,': ',',
        " '": "'",
    }
    return multiple_replace(spacing_fixes, decoded)
Exemple #3
0
def translate_sentence(sentence, model, opt, SRC, TRG):
    """Translate one sentence with beam search and tidy punctuation spacing.

    The earlier version carried leftover debugging: several ``print`` calls
    and an ``input("hiiiii")`` prompt that blocked on stdin before every
    translation. Those have been removed so the function can run unattended.

    Args:
        sentence: Raw source text; tokenized with ``SRC.preprocess``.
        model: Translation model; switched to evaluation mode here.
        opt: Options object; reads ``opt.floyd`` and ``opt.device``.
        SRC: Source field providing ``preprocess`` and ``vocab.stoi``.
        TRG: Target field, forwarded to ``beam_search``.

    Returns:
        The ``beam_search`` output after ``multiple_replace`` has been
        applied with the punctuation-spacing table below.
    """
    model.eval()
    indexed = []
    for tok in SRC.preprocess(sentence):
        idx = SRC.vocab.stoi[tok]
        # Index 0 marks a token the vocabulary does not know; unless
        # opt.floyd is set, ask get_synonym for a replacement index.
        if idx == 0 and not (opt.floyd == True):
            idx = get_synonym(tok, SRC)
        indexed.append(idx)

    batch = Variable(torch.LongTensor([indexed]))
    if opt.device == 0:
        batch = batch.cuda()

    translated = beam_search(batch, model, SRC, TRG, opt)

    return multiple_replace(
        {
            ' ?': '?',
            ' !': '!',
            ' .': '.',
            '\' ': '\'',
            ' ,': ','
        }, translated)
Exemple #4
0
def translate_sentence(sentence, model, opt, src_vocab, trg_vocab):
    """Encode a sentence to a fixed-length index tensor and translate it.

    The encoded sequence is ``[bos, word ids..., eos]`` padded (or cut) to
    ``opt.max_src_len + 1`` entries before being handed to ``beam_search``.

    Args:
        sentence: Raw source text; cleaned with ``preprocess_input`` and
            split on whitespace.
        model: Translation model, forwarded to ``beam_search``.
        opt: Options object; reads ``opt.max_src_len`` and ``opt.device``.
        src_vocab: Source vocabulary exposing ``stoi``, ``bos_idx``,
            ``eos_idx``, ``unk_idx`` and ``pad_idx``.
        trg_vocab: Target vocabulary, forwarded to ``beam_search``.

    Returns:
        The ``beam_search`` output after ``multiple_replace`` has been
        applied with the punctuation-spacing table below.
    """
    sentence = preprocess_input(sentence)
    words = sentence.split()

    indices = [src_vocab.bos_idx]
    for i, w in enumerate(words):
        # Stop one word short of max_src_len so the closing eos still fits.
        if i + 1 == opt.max_src_len:
            break
        try:
            idx = src_vocab.stoi[w.lower()]
        except KeyError:
            # Only a missing vocabulary entry maps to <unk>; a bare except
            # here would also hide KeyboardInterrupt and genuine bugs.
            idx = src_vocab.unk_idx
        indices.append(idx)
    indices.append(src_vocab.eos_idx)

    # The leading bos token makes the target length max_src_len + 1.
    target_len = opt.max_src_len + 1
    if len(indices) < target_len:
        indices += [src_vocab.pad_idx] * (target_len - len(indices))
    elif len(indices) > target_len:
        indices = indices[:target_len]
        indices[-1] = src_vocab.eos_idx

    batch = Variable(torch.LongTensor([indices]))
    batch = batch.to(opt.device)

    translated = beam_search(batch, model, src_vocab, trg_vocab, opt)

    return multiple_replace(
        {
            ' ?': '?',
            ' !': '!',
            ' .': '.',
            '\' ': '\'',
            ' ,': ','
        }, translated)
Exemple #5
0
def translate_sentence(sentence, model, opt, SRC, TRG):
    """Translate one sentence with beam search and tidy punctuation spacing.

    The earlier version contained two ``ipdb.set_trace()`` breakpoints that
    stopped every call in the debugger and required the ``ipdb`` package to
    be installed; they have been removed.

    Args:
        sentence: Raw source text; tokenized with ``SRC.preprocess``.
        model: Translation model; switched to evaluation mode here.
        opt: Options object; ``opt.no_cuda`` decides whether the input
            tensor is moved to the GPU.
        SRC: Source field providing ``preprocess`` and ``vocab.stoi``.
        TRG: Target field, forwarded to ``beam_search``.

    Returns:
        The ``beam_search`` output after ``multiple_replace`` has been
        applied with the punctuation-spacing table below.
    """
    model.eval()
    tokens = SRC.preprocess(sentence)
    # A token whose lookup yields 0 is stored as 0 anyway, so the plain
    # lookup covers both the known and the unknown case.
    indexed = [SRC.vocab.stoi[tok] for tok in tokens]

    batch = Variable(torch.LongTensor([indexed]))
    if opt.no_cuda is False:
        batch = batch.cuda()

    translated = beam_search(batch, model, SRC, TRG, opt)

    return multiple_replace(
        {
            ' ?': '?',
            ' !': '!',
            ' .': '.',
            '\' ': '\'',
            ' ,': ','
        }, translated)
Exemple #6
0
def translate_sentence(sentence, model, opt, vocab, tokenizer):
    """Convert a sentence to ids and decode it with beam search.

    Args:
        sentence: Input text, passed to ``convert_tokens_to_ids``.
        model: Model to decode with; switched to evaluation mode here.
        opt: Options object, forwarded to ``beam_search``.
        vocab: Vocabulary used for the id conversion and by ``beam_search``.
        tokenizer: Tokenizer handed to ``convert_tokens_to_ids``.

    Returns:
        Whatever ``beam_search`` returns, without post-processing.
    """
    model.eval()

    token_ids = convert_tokens_to_ids(vocab, tokenizer, sentence)

    # The tensor is always moved to the GPU; there is no CPU fallback here.
    batch = Variable(torch.LongTensor([token_ids])).cuda()

    return beam_search(batch, model, vocab, opt)
def generate(model,opt):
    """Generate a sequence with beam search from two random start indices.

    Args:
        model: Model to decode with; switched to evaluation mode here.
        opt: Options object; reads ``opt.vocab`` and ``opt.device`` and is
            forwarded to ``beam_search``.

    Returns:
        The result of ``ProcessModelOutput`` applied to the generated
        sequence after ``IndexToPitch`` has mapped indices back through
        ``opt.vocab``.
    """
    print("generating music using beam search...")
    model.eval()

    # Draw two random start indices from [2, len(opt.vocab) - 1); the lower
    # bound of 2 presumably skips the rest/pad entries -- confirm against
    # the vocabulary layout.
    seed = torch.randint(2, len(opt.vocab)-1, (2,))
    # Arrange the two indices as a single row before moving to the device.
    seed = seed.unsqueeze(1).transpose(0,1).to(opt.device)

    raw_indices = beam_search(seed, model, opt)
    pitches = IndexToPitch(raw_indices, opt.vocab)

    # Final formatting step so the output matches the dataset layout.
    return ProcessModelOutput(pitches)
def translate_sentence(sentence, model, opt, SRC, TRG):
    """Translate one sentence and return the raw beam-search output.

    Args:
        sentence: Raw source text; tokenized with ``SRC.preprocess``.
        model: Translation model; switched to evaluation mode here.
        opt: Options object; reads ``opt.floyd`` and ``opt.device``.
        SRC: Source field providing ``preprocess`` and ``vocab.stoi``.
        TRG: Target field, forwarded to ``beam_search``.

    Returns:
        Whatever ``beam_search`` returns, without post-processing.
    """
    model.eval()
    tokens = SRC.preprocess(sentence)
    # Echo the tokens (all but the last one) to stdout.
    print(tokens[:-1])

    token_ids = []
    for token in tokens:
        idx = SRC.vocab.stoi[token]
        # Index 0 marks a token the vocabulary does not know; unless
        # opt.floyd is set, ask get_synonym for a replacement index.
        if idx == 0 and not (opt.floyd == True):
            idx = get_synonym(token, SRC)
        token_ids.append(idx)

    batch = Variable(torch.LongTensor([token_ids]))
    if opt.device == 0:
        batch = batch.cuda()

    return beam_search(batch, model, SRC, TRG, opt)
def translate(sentence, model, opt, SRC, TRG):
    """Translate a text sentence by sentence with beam search.

    The text is lower-cased and split on ``"."``; each piece gets its full
    stop back before being encoded and decoded. The earlier version had its
    ``return`` inside the loop, so only the first piece was ever translated.

    Args:
        sentence: Source text, possibly containing several sentences.
        model: Translation model; switched to evaluation mode here.
        opt: Options object; reads ``opt.floyd`` and ``opt.device``.
        SRC: Source field providing ``preprocess`` and ``vocab.stoi``.
        TRG: Target field, forwarded to ``beam_search``.

    Returns:
        The ``beam_search`` output when the text holds a single sentence
        (unchanged from before); otherwise the per-sentence outputs joined
        with a space, which assumes ``beam_search`` returns strings.
    """
    model.eval()

    translations = []
    for position, fragment in enumerate(sentence.lower().split(".")):
        # Splitting "a. b." leaves an empty trailing piece. Skip blank
        # pieces, except the first one, which was always translated before.
        if position > 0 and not fragment.strip():
            continue

        tokens = SRC.preprocess(fragment + '.')

        indexed = []
        for token in tokens:
            idx = SRC.vocab.stoi[token]
            # Index 0 marks a token the vocabulary does not know; unless
            # opt.floyd is set, ask get_synonym for a replacement index.
            if idx == 0 and not (opt.floyd == True):
                idx = get_synonym(token, SRC)
            indexed.append(idx)

        batch = Variable(torch.LongTensor([indexed])).to(opt.device)
        translations.append(beam_search(batch, model, SRC, TRG, opt))

    # Hand back a lone result untouched so single-sentence callers see
    # exactly what they saw before.
    if len(translations) == 1:
        return translations[0]
    return " ".join(translations)
    def translate_sentence(self, sentence, model, opt, SRC, TRG):
        """Translate one sentence and tidy the spacing of every candidate.

        The earlier version called ``self.multiple_replace`` in a loop but
        threw the result away, so the returned sentences were never cleaned
        up. The cleaned values are now collected and returned.

        Args:
            sentence: Raw source text; tokenized with ``SRC.preprocess``.
            model: Translation model; switched to evaluation mode here.
            opt: Options object; reads ``opt.floyd`` and ``opt.device``.
            SRC: Source field providing ``preprocess`` and ``vocab.stoi``.
            TRG: Target field, forwarded to ``beam_search``.

        Returns:
            A ``(sentences, query, string_query)`` tuple: the list of
            cleaned candidate sentences plus the two extra values
            ``beam_search`` returned, passed through unchanged.
        """
        model.eval()
        indexed = []
        for tok in SRC.preprocess(sentence):
            idx = SRC.vocab.stoi[tok]
            # Index 0 marks a token the vocabulary does not know; unless
            # opt.floyd is set, ask get_synonym for a replacement index.
            if idx == 0 and not (opt.floyd == True):
                idx = self.get_synonym(tok, SRC)
            indexed.append(idx)

        batch = Variable(torch.LongTensor([indexed]))
        if opt.device == 0:
            batch = batch.cuda()

        sentences, query, string_query = beam_search(batch, model, SRC, TRG, opt)

        spacing_fixes = {' ?': '?', ' !': '!', ' .': '.', '\' ': '\'', ' ,': ','}
        # Strings are immutable, so the replaced text has to be kept.
        sentences = [
            self.multiple_replace(spacing_fixes, candidate)
            for candidate in sentences
        ]
        return sentences, query, string_query
Exemple #11
0
def eval_epoch_bleu(model, validation_data, device, vocab, list_of_refs_dev, args):
    """Run one evaluation pass and score the hypotheses with ``Bleu(4)``.

    Args:
        model: Model to evaluate; switched to evaluation mode here.
        validation_data: Iterable of batches, each unpacking into
            ``(image0, image1, image0_attribute, image1_attribute)``.
        device: Device every batch element is moved to.
        vocab: Vocabulary forwarded to ``beam_search``.
        list_of_refs_dev: References passed to ``compute_score``.
        args: Options forwarded to ``beam_search``.

    Returns:
        The first value returned by ``Bleu(4).compute_score``.
    """
    model.eval()

    hypotheses = {}

    with torch.no_grad():
        progress = tqdm(
            validation_data, mininterval=2,
            desc='  - (Validation) ', leave=False)

        for sample_idx, batch in enumerate(progress):
            # Move every element of the batch to the target device.
            image0, image1, image0_attribute, image1_attribute = (
                item.to(device) for item in batch)

            decoded = beam_search(image0, image1, model, args, vocab,
                                  image0_attribute, image1_attribute)

            # Keep only the text before the first end marker.
            hypotheses[sample_idx] = [decoded.split("<end>")[0].strip()]

        score, _ = Bleu(4).compute_score(list_of_refs_dev, hypotheses)

    return score
Exemple #12
0
    def get_caption(self,
                    target_img,
                    candidate_img,
                    target_attr,
                    candidate_attr,
                    return_cap=False):
        """Decode captions for a batch of image pairs.

        With ``self.decode_mode == 'beam_search'`` each pair is decoded on
        its own and the index sequences are cut or padded to
        ``self.max_seq_len``; otherwise the whole batch goes through
        ``greedy_search``.

        Args:
            target_img: Batch of target images.
            candidate_img: Batch of candidate images.
            target_attr: Batch of target attributes.
            candidate_attr: Batch of candidate attributes.
            return_cap: When true, also return the decoded captions.

        Returns:
            The caption index tensor, or ``(indices, captions)`` when
            ``return_cap`` is true.
        """
        pad_idx = self.vocab('<pad>')

        if self.decode_mode != 'beam_search':
            pad_cap_idx, caps = greedy_search(candidate_img.unsqueeze(dim=1),
                                              target_img.unsqueeze(dim=1),
                                              self.model, self.opt, self.vocab,
                                              candidate_attr, target_attr)
        else:
            # Decode one pair at a time; each call yields a pair whose first
            # item is the index list and whose second is the caption.
            decoded = [
                beam_search(candidate_img[i].unsqueeze(dim=0).unsqueeze(dim=0),
                            target_img[i].unsqueeze(dim=0).unsqueeze(dim=0),
                            self.model, self.opt, self.vocab,
                            candidate_attr[i].unsqueeze(dim=0),
                            target_attr[i].unsqueeze(dim=0))
                for i in range(target_img.size(0))
            ]

            caps = []
            rows = []
            for result in decoded:
                caps.append(result[1])
                overflow = len(result[0]) > self.max_seq_len
                if overflow:
                    # Too long: cut down to the maximum length.
                    row = result[0][:self.max_seq_len]
                else:
                    # Too short (or exact): right-pad with the pad index.
                    missing = self.max_seq_len - len(result[0])
                    row = result[0] + [pad_idx] * missing
                rows.append(row)

            pad_cap_idx = torch.tensor(rows, dtype=torch.long)

        return (pad_cap_idx, caps) if return_cap else pad_cap_idx
Exemple #13
0
def translate_sentence(sentence, model, opt, SRC, TRG):
    """Translate one sentence; merge all beam candidates when ``opt.k != 1``.

    Args:
        sentence: Raw source text; tokenized with ``SRC.preprocess``.
        model: Translation model; switched to evaluation mode here.
        opt: Options object; reads ``opt.floyd``, ``opt.device`` and
            ``opt.k``.
        SRC: Source field providing ``preprocess`` and ``vocab.stoi``.
        TRG: Target field, forwarded to ``beam_search``.

    Returns:
        For ``opt.k == 1`` the first beam candidate after
        ``multiple_replace``. Otherwise the distinct space-separated words
        of all candidates, joined with spaces (in set order) and passed
        through ``multiple_replace``.
    """
    model.eval()
    tokens = SRC.preprocess(sentence)  # preprocess the input text

    token_ids = []
    for token in tokens:
        idx = SRC.vocab.stoi[token]
        # Tokens that map to index 0 are dropped unless opt.floyd is set.
        if idx != 0 or opt.floyd == True:
            token_ids.append(idx)

    batch = Variable(torch.LongTensor([token_ids]))  # convert to a tensor
    if opt.device == 0:
        batch = batch.cuda()

    candidates = beam_search(batch, model, SRC, TRG, opt)

    spacing_fixes = {
        ' ?': '?',
        ' !': '!',
        ' .': '.',
        '\' ': '\'',
        ' ,': ','
    }

    if opt.k == 1:
        return multiple_replace(spacing_fixes, candidates[0])

    words = []
    for candidate in candidates:
        print("i", candidate)
        words.extend(candidate.split(" "))
    # De-duplicate the words across candidates before re-joining them.
    return multiple_replace(spacing_fixes, " ".join(set(words)))
Exemple #14
0
def train_model(model, opt, SRC, TRG):
    """Train ``model`` and run the optional validation / test passes.

    Per epoch the function:

    1. trains on ``opt.train`` (one example per batch, enforced by an
       assert), skipping batches whose forward or backward pass raises
       ``RuntimeError``;
    2. optionally computes the validation loss (``opt.calculate_val_loss``);
    3. optionally measures forward-pass exact-match accuracy
       (``opt.val_forward_pass``);
    4. every ``opt.val_check_every_n`` epochs decodes validation examples
       with ``beam_search`` and steps ``opt.scheduler`` on the accuracy;
    5. on the last epoch, when ``opt.do_test`` is set, decodes the test set
       and writes the predictions to ``opt.output_dir``.

    Fix relative to the earlier version: the per-length counters in the
    forward-pass validation tested one key and indexed another
    (``eos_index`` vs ``sl`` and ``sl`` vs ``pad_index``), which miscounted
    and could raise ``KeyError``. Each counter now uses one key throughout.

    Args:
        model: Model to train; exposes ``encoder.embed``,
            ``decoder.embed.get_weights()`` and ``sep_token``.
        opt: Options object carrying the data iterators, the optimizer and
            schedulers, and the flags read throughout this function.
        SRC: Source field providing ``preprocess`` and ``vocab.stoi``.
        TRG: Target field, forwarded to ``beam_search``.

    Returns:
        None. Weights are checkpointed to ``weights/model_weights`` and
        test generations are written to
        ``opt.output_dir + '/test_generations.txt'``.
    """
    print("training model...")
    model.train()
    start = time.time()
    if opt.checkpoint > 0:
        cptime = time.time()

    for epoch in range(opt.epochs):
        model.train()
        total_loss = 0
        # Placeholder reported when no printevery boundary is reached.
        avg_loss = 1e5
        print("   %dm: epoch %d [%s]  %d%%  loss = %s" % ((time.time() - start) // 60, epoch + 1, "".join(' ' * 20), 0, '...'), end='\r')

        if opt.checkpoint > 0:
            torch.save(model.state_dict(), 'weights/model_weights')

        # ---------------------------------------------------------- training
        for i, batch in enumerate(opt.train):
            src = batch.src.transpose(0, 1).cuda()

            # The step-splitting below works on a single example at a time.
            assert src.shape[0] == 1
            src_tokens = ' '.join(pred_to_vocab(SRC, src[0]))
            refs, steps = train_split_input(SRC, src_tokens)
            steps = [torch.LongTensor([step]).cuda() for step in steps]

            # Placeholder target: all ones with a 2 in the first position
            # (presumably pad and start-of-sequence ids -- confirm).
            fake_trg = torch.ones((1, opt.max_strlen)).type(torch.LongTensor).cuda()
            fake_trg[:, 0] = fake_trg[:, 0] * 2
            real_trg = batch.trg.transpose(0, 1).cuda()
            fake_trg_input, real_trg_input = fake_trg[:, :-1], real_trg[:, :-1]
            _, fake_trg_mask = create_hard_masks(src, fake_trg_input, opt)
            _, real_trg_mask = create_masks(src, real_trg_input, opt)

            sep_tensor = model.encoder.embed(torch.LongTensor([model.sep_token]).cuda()).unsqueeze(0)
            decoder_embed = model.decoder.embed.get_weights()

            try:
                preds = model(sep_tensor, decoder_embed, steps, refs, fake_trg_input, real_trg_input, fake_trg_mask, real_trg_mask)
            except RuntimeError:
                # Best-effort: skip a batch the model cannot process.
                continue
            ys = real_trg[:, 1:].contiguous().view(-1)

            opt.optimizer.zero_grad()
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)), ys, ignore_index=opt.trg_pad)
            try:
                loss.backward()
            except RuntimeError:
                # Best-effort: skip the update when backward fails.
                continue
            opt.optimizer.step()

            if opt.SGDR:
                opt.sched.step()

            total_loss += loss.item()

            if (i + 1) % opt.printevery == 0:
                p = int(100 * (i + 1) / opt.train_len)
                avg_loss = total_loss / opt.printevery
                if opt.floyd is False:
                    print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" % \
                          ((time.time() - start) // 60, epoch + 1, "".join('#' * (p // 5)),
                           "".join(' ' * (20 - (p // 5))), p, avg_loss), end='\r')
                else:
                    print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" % \
                          ((time.time() - start) // 60, epoch + 1, "".join('#' * (p // 5)),
                           "".join(' ' * (20 - (p // 5))), p, avg_loss))
                total_loss = 0

            # Time-based checkpoint: save once opt.checkpoint minutes passed.
            if opt.checkpoint > 0 and ((time.time() - cptime) // 60) // opt.checkpoint >= 1:
                torch.save(model.state_dict(), 'weights/model_weights')
                cptime = time.time()

        print("%dm: epoch %d [%s%s]  %d%%  loss = %.3f\nepoch %d complete, loss = %.03f" % \
              ((time.time() - start) // 60, epoch + 1, "".join('#' * (100 // 5)), "".join(' ' * (20 - (100 // 5))), 100,
               avg_loss, epoch + 1, avg_loss))

        # --------------------------------------------------- validation loss
        if opt.calculate_val_loss:
            model.eval()
            val_losses = []
            for i, batch in enumerate(opt.val):
                src = batch.src.transpose(0, 1).cuda()
                trg = batch.trg.transpose(0, 1).cuda()
                trg_input = trg[:, :-1]
                src_mask, trg_mask = create_masks(src, trg_input, opt)
                preds = model(src, trg_input, src_mask, trg_mask)
                ys = trg[:, 1:].contiguous().view(-1)
                opt.optimizer.zero_grad()
                loss = F.cross_entropy(preds.view(-1, preds.size(-1)), ys, ignore_index=opt.trg_pad)
                val_losses.append(loss.item())

            print('validation loss:', sum(val_losses) / len(val_losses), '\n')

        # --------------------------------- forward-pass exact-match accuracy
        if opt.val_forward_pass:
            model.eval()
            val_losses = []
            val_losses_no_eos = []
            # length -> [exact matches, examples seen], for lengths 1..600.
            val_eos_dict = {k: [0, 0] for k in range(1, 601)}
            val_pad_dict = {k: [0, 0] for k in range(1, 601)}
            for i, batch in enumerate(opt.val):
                src = batch.src.transpose(0, 1).cuda()
                # TODO: confirm swap below
                real_trg = batch.trg.transpose(0, 1).cuda()
                # Placeholder target with the same layout as in training.
                trg = torch.ones_like(real_trg).type(torch.LongTensor).cuda()
                trg[:, 0] = trg[:, 0] * 2

                bs = src.shape[0]
                # Append a random number of extra placeholder columns.
                add_pad = math.ceil(random.random() * 3)
                trg = torch.cat((trg, torch.ones((bs, add_pad)).type(torch.LongTensor).cuda()), dim=1)

                trg_input = trg[:, :-1]
                src_mask, trg_mask = create_hard_masks(src, trg_input, opt)
                preds = model(src, trg_input, src_mask, trg_mask)
                pred_tokens = torch.argmax(preds, dim=-1)
                # Compare against the real target, not the placeholder.
                ys = real_trg[:, 1:]

                for b_ in range(bs):
                    pred_tok = pred_tokens[b_]
                    y = ys[b_]

                    # NOTE(review): raises IndexError when y has no eos.
                    eos_index = ((y == 3).nonzero(as_tuple=True)[0])[0]  # 3 = eos in vocab
                    if not isinstance(eos_index, int):
                        eos_index = eos_index.item()

                    if torch.equal(pred_tok[:eos_index], y[:eos_index]):
                        val_losses.append(1)
                        if eos_index in val_eos_dict:  # add to seq length counter
                            val_eos_dict[eos_index][0] += 1
                    else:
                        val_losses.append(0)

                    # Count the example under the key that was just tested.
                    if eos_index in val_eos_dict:
                        val_eos_dict[eos_index][1] += 1

                    pad_index = ((y == 1).nonzero(as_tuple=True)[0])  # 1 = pad in vocab
                    if pad_index.shape[0] == 0:
                        pad_index = y.shape[0]
                    else:
                        pad_index = pad_index[0]
                    if not isinstance(pad_index, int):
                        pad_index = pad_index.item()

                    if torch.equal(pred_tok[:pad_index], y[:pad_index]):
                        val_losses_no_eos.append(1)
                        if pad_index in val_pad_dict:  # add to seq length counter
                            val_pad_dict[pad_index][0] += 1
                    else:
                        val_losses_no_eos.append(0)

                    # Count the example under the key that was just tested.
                    if pad_index in val_pad_dict:
                        val_pad_dict[pad_index][1] += 1

            print('forward pass validation accuracy - no eos:',
                  round(sum(val_losses) / len(val_losses) * 100, 2), '%')
            print('forward pass validation accuracy - no pad:',
                  round(sum(val_losses_no_eos) / len(val_losses_no_eos) * 100, 2), '%')

        # ------------------------------------- beam-search validation check
        if (epoch + 1) % opt.val_check_every_n == 0:
            model.eval()
            val_acc, val_success = 0, 0
            val_data = zip_io_data(opt.data_path + '/val')
            for j, e in enumerate(val_data[:opt.n_val]):
                e_src, e_tgt = e[0], e[1]

                if opt.compositional_eval:
                    controller = eval_split_input(e_src)
                    intermediates = []
                    comp_failure = False
                    for controller_input in controller:
                        if len(controller_input) == 1:
                            controller_src = controller_input[0]

                        else:
                            # Build the input from earlier intermediate
                            # outputs, separated by the @@SEP@@ marker.
                            controller_src = ''
                            for src_index in range(len(controller_input) - 1):
                                controller_src += intermediates[controller_input[src_index]] + ' @@SEP@@ '
                            controller_src += controller_input[-1]
                            controller_src = remove_whitespace(controller_src)

                        indexed = []
                        sentence = SRC.preprocess(controller_src)
                        for tok in sentence:
                            if SRC.vocab.stoi[tok] != 0:
                                indexed.append(SRC.vocab.stoi[tok])
                            else:
                                # A token with index 0 fails the example.
                                comp_failure = True
                                break
                        if comp_failure:
                            break

                        sentence = Variable(torch.LongTensor([indexed]))
                        if opt.device == 0:
                            sentence = sentence.cuda()

                        try:
                            sentence = beam_search(sentence, model, SRC, TRG, opt)
                            intermediates.append(sentence)
                        except Exception:
                            comp_failure = True

                            break

                    if not comp_failure:
                        try:
                            val_acc += simple_em(intermediates[-1], e_tgt)
                            val_success += 1
                        except Exception:
                            continue
                else:
                    sentence = SRC.preprocess(e_src)
                    indexed = [SRC.vocab.stoi[tok] for tok in sentence]

                    sentence = Variable(torch.LongTensor([indexed]))
                    if opt.device == 0:
                        sentence = sentence.cuda()

                    try:
                        sentence = beam_search(sentence, model, SRC, TRG, opt)
                    except Exception:
                        continue
                    try:
                        val_acc += simple_em(sentence, e_tgt)
                        val_success += 1
                    except Exception:
                        continue

            # Avoid dividing by zero when nothing could be scored.
            if val_success == 0:
                val_success = 1
            val_acc = val_acc / val_success
            print('epoch', epoch, '- val accuracy:', round(val_acc * 100, 2))
            print()
            opt.scheduler.step(val_acc)

        # ------------------------------------------- final test-set decoding
        if epoch == opt.epochs - 1 and opt.do_test:
            model.eval()
            test_data = zip_io_data(opt.data_path + '/test')
            test_predictions = ''
            test_acc, test_success = 0, 0
            for j, e in enumerate(test_data[:opt.n_test]):
                if (j + 1) % 10000 == 0:
                    print(round(j / len(test_data) * 100, 2), '% complete with testing')
                e_src, e_tgt = e[0], e[1]

                if opt.compositional_eval:
                    controller = eval_split_input(e_src)
                    intermediates = []
                    comp_failure = False
                    for controller_input in controller:
                        if len(controller_input) == 1:
                            controller_src = controller_input[0]

                        else:
                            controller_src = ''
                            for src_index in range(len(controller_input) - 1):
                                controller_src += intermediates[controller_input[src_index]] + ' @@SEP@@ '
                            controller_src += controller_input[-1]
                            controller_src = remove_whitespace(controller_src)

                        indexed = []
                        sentence = SRC.preprocess(controller_src)
                        for tok in sentence:
                            if SRC.vocab.stoi[tok] != 0:
                                indexed.append(SRC.vocab.stoi[tok])
                            else:
                                comp_failure = True
                                break
                        if comp_failure:
                            break

                        sentence = Variable(torch.LongTensor([indexed]))
                        if opt.device == 0:
                            sentence = sentence.cuda()

                        try:
                            sentence = beam_search(sentence, model, SRC, TRG, opt)
                            intermediates.append(sentence)
                        except Exception:
                            comp_failure = True
                            break

                    # A failed example still gets an (empty) output line.
                    if not comp_failure:
                        try:
                            test_acc += simple_em(sentence, e_tgt)
                            test_success += 1
                            test_predictions += sentence + '\n'
                        except Exception:
                            test_predictions += '\n'
                            continue
                    else:
                        test_predictions += '\n'
                else:
                    indexed = []
                    sentence = SRC.preprocess(e_src)
                    pass_bool = False
                    for tok in sentence:
                        if SRC.vocab.stoi[tok] != 0:
                            indexed.append(SRC.vocab.stoi[tok])
                        else:
                            pass_bool = True
                            break
                    # NOTE(review): skipped examples write no output line
                    # here, unlike the compositional branch above.
                    if pass_bool:
                        continue

                    sentence = Variable(torch.LongTensor([indexed]))
                    if opt.device == 0:
                        sentence = sentence.cuda()

                    try:
                        sentence = beam_search(sentence, model, SRC, TRG, opt)
                    except Exception:
                        continue
                    try:
                        test_acc += simple_em(sentence, e_tgt)
                        test_success += 1
                        test_predictions += sentence + '\n'
                    except Exception:
                        test_predictions += '\n'
                        continue

            # Avoid dividing by zero when nothing could be scored.
            if test_success == 0:
                test_success = 1
            test_acc = test_acc / test_success
            print('test accuracy:', round(test_acc * 100, 2))
            print()

            os.makedirs(opt.output_dir, exist_ok=True)

            with open(opt.output_dir + '/test_generations.txt', 'w', encoding='utf-8') as f:
                f.write(test_predictions)
Exemple #15
0
def train_model(model, opt, SRC, TRG):
    print("training model...")
    model.train()
    start = time.time()
    if opt.checkpoint > 0:
        cptime = time.time()
                 
    for epoch in range(opt.epochs):
        model.train()
        total_loss = 0
        errors_per_epoch = 0
        if opt.floyd is False:
            print("   %dm: epoch %d [%s]  %d%%  loss = %s" %\
            ((time.time() - start)//60, epoch + 1, "".join(' '*20), 0, '...'), end='\r')
        
        if opt.checkpoint > 0:
            torch.save(model.state_dict(), 'weights/model_weights')

        for i, batch in enumerate(opt.train):
            src = batch.src.transpose(0, 1).cuda()
            trg = batch.trg.transpose(0, 1).cuda()

            trg_input = trg[:, :-1]
            src_mask, trg_mask = create_masks(src, trg_input, opt)
            preds = model(src, trg_input, src_mask, trg_mask)
            ys = trg[:, 1:].contiguous().view(-1)

            opt.optimizer.zero_grad()
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)), ys, ignore_index=opt.trg_pad)
            loss.backward()
            opt.optimizer.step()

            if opt.SGDR:
                opt.sched.step()

            total_loss += loss.item()

            if (i + 1) % opt.printevery == 0:
                 p = int(100 * (i + 1) / opt.train_len)
                 avg_loss = total_loss/opt.printevery
                 if opt.floyd is False:
                    print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %\
                    ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss), end='\r')
                 else:
                    print("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %\
                    ((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss))
                 total_loss = 0

            if opt.checkpoint > 0 and ((time.time()-cptime)//60) // opt.checkpoint >= 1:
                torch.save(model.state_dict(), 'weights/model_weights')
                cptime = time.time()

        print("%dm: epoch %d [%s%s]  %d%%  loss = %.3f\nepoch %d complete, loss = %.03f" %\
        ((time.time() - start)//60, epoch + 1, "".join('#'*(100//5)), "".join(' '*(20-(100//5))), 100, avg_loss, epoch + 1, avg_loss))

        print('errors per epoch:', errors_per_epoch)
        if opt.calculate_val_loss:
            model.eval()
            val_losses = []
            for i, batch in enumerate(opt.val):
                src = batch.src.transpose(0, 1).cuda()
                trg = batch.trg.transpose(0, 1).cuda()
                trg_input = trg[:, :-1]
                src_mask, trg_mask = create_masks(src, trg_input, opt)
                preds = model(src, trg_input, src_mask, trg_mask)
                ys = trg[:, 1:].contiguous().view(-1)
                opt.optimizer.zero_grad()
                loss = F.cross_entropy(preds.view(-1, preds.size(-1)), ys, ignore_index=opt.trg_pad)
                val_losses.append(loss.item())

            print('validation loss:', sum(val_losses)/len(val_losses), '\n')

        if (epoch + 1) % opt.val_check_every_n == 0:
            model.eval()
            val_acc, val_success = 0, 0
            val_data = zip_io_data(opt.data_path + '/val')
            for j, e in enumerate(val_data[:opt.n_val]):
                e_src, e_tgt = e[0], e[1]

                if opt.compositional_eval:
                    controller = eval_split_input(e_src)
                    intermediates = []
                    comp_failure = False
                    for controller_input in controller:
                        if len(controller_input) == 1:
                            controller_src = controller_input[0]

                        else:
                            controller_src = ''
                            for src_index in range(len(controller_input) - 1):
                                controller_src += intermediates[controller_input[src_index]] + ' @@SEP@@ '
                            controller_src += controller_input[-1]
                            controller_src = remove_whitespace(controller_src)

                        indexed = []
                        sentence = SRC.preprocess(controller_src)
                        for tok in sentence:
                            if SRC.vocab.stoi[tok] != 0:
                                indexed.append(SRC.vocab.stoi[tok])
                            else:
                                comp_failure = True
                                break
                        if comp_failure:
                            break

                        sentence = Variable(torch.LongTensor([indexed]))
                        if opt.device == 0:
                            sentence = sentence.cuda()

                        try:
                            sentence = beam_search(sentence, model, SRC, TRG, opt)
                            intermediates.append(sentence)
                        except Exception as e:
                            comp_failure = True

                            break

                    if not comp_failure:
                        try:
                            val_acc += simple_em(intermediates[-1], e_tgt)
                            val_success += 1
                        except Exception as e:
                            continue
                else:
                    sentence = SRC.preprocess(e_src)
                    indexed = [SRC.vocab.stoi[tok] for tok in sentence]

                    sentence = Variable(torch.LongTensor([indexed]))
                    if opt.device == 0:
                        sentence = sentence.cuda()

                    try:
                        sentence = beam_search(sentence, model, SRC, TRG, opt)
                    except Exception as e:
                        continue
                    try:
                        val_acc += simple_em(sentence, e_tgt)
                        val_success += 1
                    except Exception as e:
                        continue

            if val_success == 0:
                val_success = 1
            val_acc = val_acc / val_success
            print('epoch', epoch, '- val accuracy:', round(val_acc * 100, 2))
            print()
            opt.scheduler.step(val_acc)

        if epoch == opt.epochs - 1 and opt.do_test:
            model.eval()
            test_data = zip_io_data(opt.data_path + '/test')
            test_predictions = ''
            test_acc, test_success = 0, 0
            for j, e in enumerate(test_data[:opt.n_test]):
                if (j + 1) % 10000 == 0:
                    print(round(j/len(test_data) * 100, 2), '% complete with testing')
                e_src, e_tgt = e[0], e[1]

                if opt.compositional_eval:
                    controller = eval_split_input(e_src)
                    intermediates = []
                    comp_failure = False
                    for controller_input in controller:
                        if len(controller_input) == 1:
                            controller_src = controller_input[0]

                        else:
                            controller_src = ''
                            for src_index in range(len(controller_input) - 1):
                                controller_src += intermediates[controller_input[src_index]] + ' @@SEP@@ '
                            controller_src += controller_input[-1]
                            controller_src = remove_whitespace(controller_src)

                        indexed = []
                        sentence = SRC.preprocess(controller_src)
                        for tok in sentence:
                            if SRC.vocab.stoi[tok] != 0:
                                indexed.append(SRC.vocab.stoi[tok])
                            else:
                                comp_failure = True
                                break
                        if comp_failure:
                            break

                        sentence = Variable(torch.LongTensor([indexed]))
                        if opt.device == 0:
                            sentence = sentence.cuda()

                        try:
                            sentence = beam_search(sentence, model, SRC, TRG, opt)
                            intermediates.append(sentence)
                        except Exception as e:
                            comp_failure = True
                            break

                    if not comp_failure:
                        try:
                            test_acc += simple_em(sentence, e_tgt)
                            test_success += 1
                            test_predictions += sentence + '\n'
                        except Exception as e:
                            test_predictions += '\n'
                            continue
                    else:
                        test_predictions += '\n'
                else:
                    indexed = []
                    sentence = SRC.preprocess(e_src)
                    pass_bool = False
                    for tok in sentence:
                        if SRC.vocab.stoi[tok] != 0:
                            indexed.append(SRC.vocab.stoi[tok])
                        else:
                            pass_bool = True
                            break
                    if pass_bool:
                        continue

                    sentence = Variable(torch.LongTensor([indexed]))
                    if opt.device == 0:
                        sentence = sentence.cuda()

                    try:
                        sentence = beam_search(sentence, model, SRC, TRG, opt)
                    except Exception as e:
                        continue
                    try:
                        test_acc += simple_em(sentence, e_tgt)
                        test_success += 1
                        test_predictions += sentence + '\n'
                    except Exception as e:
                        test_predictions += '\n'
                        continue


            if test_success == 0:
                test_success = 1
            test_acc = test_acc / test_success
            print('test accuracy:', round(test_acc * 100, 2))
            print()

            if not os.path.exists(opt.output_dir):
                os.makedirs(opt.output_dir)

            with open(opt.output_dir + '/test_generations.txt', 'w', encoding='utf-8') as f:
                f.write(test_predictions)
Exemple #16
0
def _percent(hits):
    """Return the mean of ``hits`` as a percentage rounded to two decimals.

    An empty collection yields 0.0 instead of raising ZeroDivisionError,
    mirroring the ``val_success == 0`` guard used for beam accuracy.
    """
    if not hits:
        return 0.0
    return round(sum(hits) / len(hits) * 100, 2)


def _pad_block(rows, cols):
    """Return a ``(rows, cols)`` CUDA LongTensor filled with 1 (1 = pad in vocab)."""
    return torch.ones((rows, cols)).type(torch.LongTensor).cuda()


def _to_model_input(indexed, opt):
    """Wrap token ids in a batch-of-one LongTensor, on GPU when ``opt.device == 0``."""
    sentence = Variable(torch.LongTensor([indexed]))
    if opt.device == 0:
        sentence = sentence.cuda()
    return sentence


def _index_all_tokens(text, SRC):
    """Map every preprocessed token of ``text`` to its ``SRC`` vocabulary id."""
    return [SRC.vocab.stoi[tok] for tok in SRC.preprocess(text)]


def _index_known_tokens(text, SRC):
    """Like ``_index_all_tokens`` but return None if any token maps to id 0.

    Id 0 is presumably the unknown-token id -- TODO confirm against the vocab.
    """
    indexed = []
    for tok in SRC.preprocess(text):
        tok_id = SRC.vocab.stoi[tok]
        if tok_id == 0:
            return None
        indexed.append(tok_id)
    return indexed


def _classify_rationale(words, opt):
    """Encode ``words`` with the classifier vocab and return ``classify``'s label."""
    stoi = opt.classifier_SRC.vocab.stoi
    ids = [stoi[word] if word in stoi else stoi['<unk>'] for word in words]
    rationale = torch.Tensor([ids]).type(torch.LongTensor).cuda()
    return classify(rationale, opt.classifier, opt.classifier_SRC,
                    opt.classifier_TRG)


def _prefix_matches(pred_tok, y):
    """Return True if ``pred_tok`` equals ``y`` on every position before y's first pad."""
    pad_positions = (y == 1).nonzero(as_tuple=True)[0]  # 1 = pad in vocab
    if pad_positions.shape[0] == 0:
        end = y.shape[0]
    else:
        end = pad_positions[0].item()
    return torch.equal(pred_tok[:end], y[:end])


def _forward_from_blank_target(model, opt, src, trg):
    """Run one forward pass with an all-pad decoder input and return argmax tokens.

    ``trg`` is an all-ones LongTensor; its first column is set to 2
    (presumably the start-of-sequence id -- TODO confirm) and 0-3 random
    extra pad columns are appended before the last column is dropped.
    """
    trg[:, 0] = trg[:, 0] * 2
    add_pad = math.ceil(random.random() * 3)
    trg = torch.cat((trg, _pad_block(src.shape[0], add_pad)), dim=1)
    trg_input = trg[:, :-1]
    src_mask, trg_mask = create_hard_masks(src, trg_input, opt)
    preds = model(src, trg_input, src_mask, trg_mask)
    return torch.argmax(preds, dim=-1)


def _compositional_beam_search(e_src, model, opt, SRC, TRG):
    """Decode ``e_src`` step by step, feeding earlier outputs into later steps.

    Returns the list of intermediate decodings, or None when a step contains
    a token with vocabulary id 0 or when ``beam_search`` raises.
    """
    intermediates = []
    for controller_input in eval_split_input(e_src):
        if len(controller_input) == 1:
            controller_src = controller_input[0]
        else:
            # All but the last entry index into earlier outputs; the last
            # entry is literal text appended after the separators.
            controller_src = ''.join(
                intermediates[controller_input[k]] + ' @@SEP@@ '
                for k in range(len(controller_input) - 1))
            controller_src += controller_input[-1]
            controller_src = remove_whitespace(controller_src)

        indexed = _index_known_tokens(controller_src, SRC)
        if indexed is None:
            return None

        sentence = _to_model_input(indexed, opt)
        try:
            intermediates.append(beam_search(sentence, model, SRC, TRG, opt))
        except Exception:
            # Deliberate best-effort: a failed step fails the whole example.
            return None
    return intermediates


def _train_one_epoch(model, opt, epoch, start, cptime, mask_prob):
    """Run one pass over ``opt.train``; return ``(avg_loss, cptime)``.

    ``avg_loss`` is the mean loss of the last completed ``opt.printevery``
    window (1e5 if no window completed). ``cptime`` is the time of the most
    recent timed checkpoint.
    """
    total_loss = 0
    avg_loss = 1e5

    for i, batch in enumerate(opt.train):
        src = batch.src.transpose(0, 1).cuda()
        real_trg = batch.trg.transpose(0, 1).cuda()
        bs = src.shape[0]
        add_pad = math.ceil(random.random() * 3)
        classifier_task = opt.task == 'e_snli_o'

        masked = False
        if classifier_task:
            trg = real_trg
            trg_input = trg[:, :-1]
            src_mask, trg_mask = create_masks(src, trg_input, opt)
        elif random.random() > mask_prob:
            # Teacher forcing on the real target, extended with random padding.
            trg = torch.cat((real_trg, _pad_block(bs, add_pad)), dim=1)
            real_trg = torch.cat((real_trg, _pad_block(bs, add_pad)), dim=1)
            trg_input = trg[:, :-1]
            src_mask, trg_mask = create_masks(src, trg_input, opt)
        else:
            # Decoder sees only a start id followed by pads.
            masked = True
            trg = torch.ones_like(real_trg).type(torch.LongTensor).cuda()
            trg[:, 0] = trg[:, 0] * 2
            trg = torch.cat((trg, _pad_block(bs, add_pad)), dim=1)
            real_trg = torch.cat((real_trg, _pad_block(bs, add_pad)), dim=1)
            trg_input = trg[:, :-1]
            src_mask, trg_mask = create_hard_masks(src, trg_input, opt)

        if classifier_task:
            preds = model(src, src_mask)
        else:
            preds = model(src, trg_input, src_mask, trg_mask)

        opt.optimizer.zero_grad()
        if classifier_task:
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                                   trg.contiguous().view(-1),
                                   ignore_index=opt.trg_pad)
        else:
            ys = real_trg[:, 1:].contiguous().view(-1)
            scores = preds
            if masked:
                # NOTE(review): the sharpened, re-normalised softmax is fed to
                # cross_entropy, which applies log_softmax again -- kept as in
                # the original; confirm this is intended.
                peaked_soft = torch.exp(opt.alpha * F.softmax(preds, dim=-1))
                peaked_soft_sum = torch.sum(peaked_soft, dim=-1).unsqueeze(2)
                scores = torch.div(peaked_soft, peaked_soft_sum)
            loss = F.cross_entropy(scores.view(-1, preds.size(-1)),
                                   ys,
                                   ignore_index=opt.trg_pad)
        loss.backward()
        opt.optimizer.step()
        loss_value = loss.item()

        if opt.wandb and i % opt.log_interval == 0:
            # Log the Python float so no autograd graph is handed to wandb.
            wandb.log({"loss": loss_value})

        if opt.SGDR:
            opt.sched.step()

        total_loss += loss_value

        if (i + 1) % opt.printevery == 0:
            p = int(100 * (i + 1) / opt.train_len)
            avg_loss = total_loss / opt.printevery
            line = "   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" % (
                (time.time() - start) // 60, epoch + 1, '#' * (p // 5),
                ' ' * (20 - (p // 5)), p, avg_loss)
            if opt.floyd is False:
                print(line, end='\r')
            else:
                print(line)
            total_loss = 0

        if opt.checkpoint > 0 and (
                (time.time() - cptime) // 60) // opt.checkpoint >= 1:
            torch.save(model.state_dict(), 'weights/model_weights')
            cptime = time.time()

    return avg_loss, cptime


def _report_validation_loss(model, opt):
    """Print the mean teacher-forced cross-entropy over ``opt.val``."""
    val_losses = []
    with torch.no_grad():
        for batch in opt.val:
            src = batch.src.transpose(0, 1).cuda()
            trg = batch.trg.transpose(0, 1).cuda()
            trg_input = trg[:, :-1]
            src_mask, trg_mask = create_masks(src, trg_input, opt)
            preds = model(src, trg_input, src_mask, trg_mask)
            ys = trg[:, 1:].contiguous().view(-1)
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)),
                                   ys,
                                   ignore_index=opt.trg_pad)
            val_losses.append(loss.item())

    # An empty validation iterator reports NaN rather than dividing by zero.
    if val_losses:
        mean_loss = sum(val_losses) / len(val_losses)
    else:
        mean_loss = float('nan')
    print('validation loss:', mean_loss, '\n')


def _report_forward_accuracy(model, opt, TRG):
    """Score single-forward-pass predictions on ``opt.val`` and report them.

    Supports ``toy_task`` and ``e_snli_r``; any other task raises
    NotImplementedError.
    """
    val_losses_no_eos = []
    val_label_accuracy = []

    with torch.no_grad():
        if opt.task == 'toy_task':
            for batch in opt.val:
                src = batch.src.transpose(0, 1).cuda()
                real_trg = batch.trg.transpose(0, 1).cuda()
                blank = torch.ones_like(real_trg).type(
                    torch.LongTensor).cuda()
                pred_tokens = _forward_from_blank_target(
                    model, opt, src, blank)
                ys = real_trg[:, 1:]

                for b_ in range(src.shape[0]):
                    matched = _prefix_matches(pred_tokens[b_], ys[b_])
                    val_losses_no_eos.append(1 if matched else 0)
        elif opt.task == 'e_snli_r':
            for batch in opt.val:
                src = batch.src.transpose(0, 1).cuda()
                real_trgs = [
                    batch.trg1.transpose(0, 1).cuda(),
                    batch.trg2.transpose(0, 1).cuda(),
                    batch.trg3.transpose(0, 1).cuda(),
                ]
                labels = batch.label.transpose(0, 1).cuda()

                bs = src.shape[0]
                max_sl = max([t.shape[1] for t in real_trgs])
                pred_tokens = _forward_from_blank_target(
                    model, opt, src, _pad_block(bs, max_sl))
                references = [t[:, 1:] for t in real_trgs]

                for b_ in range(bs):
                    pred_tok = pred_tokens[b_]

                    # Correct if the prediction matches any of the three
                    # reference targets up to that reference's first pad.
                    matches = [
                        _prefix_matches(pred_tok, ys[b_]) for ys in references
                    ]
                    val_losses_no_eos.append(1 if any(matches) else 0)

                    words = []
                    for t in pred_tok:
                        word = TRG.vocab.itos[t]
                        if word == '<eos>':
                            break  # TODO: take out for differentiable version
                        words.append(word)
                    pred_label = _classify_rationale(words, opt)

                    true_label = opt.classifier_TRG.vocab.itos[labels[b_]]
                    val_label_accuracy.append(
                        1 if pred_label == true_label else 0)
        else:
            raise NotImplementedError(
                "No validation accuracy support for CoS-E or any other tasks yet."
            )

    if opt.wandb:
        wandb.log({
            'validation forward accuracy': _percent(val_losses_no_eos)
        })
        if opt.task == 'e_snli_r':
            wandb.log({
                'validation forward label accuracy':
                _percent(val_label_accuracy)
            })
            # NOTE(review): this print only happens when wandb is enabled,
            # as in the original; confirm that is intended.
            print('validation forward label accuracy:',
                  _percent(val_label_accuracy))
    print('validation forward accuracy:', _percent(val_losses_no_eos), '%')


def _report_beam_accuracy(model, opt, SRC, TRG):
    """Score decoded validation outputs, report them and step ``opt.scheduler``.

    Reads examples from ``opt.data_path + '/val'``. Examples whose decoding or
    scoring raises are skipped and do not count towards the denominator.
    """
    val_acc, val_success = 0, 0
    val_data = zip_io_data(opt.data_path + '/val')
    beam_label_acc = []

    with torch.no_grad():
        if opt.task == 'toy_task':
            for example in val_data[:opt.n_val]:
                e_src, e_tgt = example[0], example[1]

                if opt.compositional_eval:
                    outputs = _compositional_beam_search(
                        e_src, model, opt, SRC, TRG)
                    if outputs is None:
                        continue
                else:
                    sentence = _to_model_input(
                        _index_all_tokens(e_src, SRC), opt)
                    try:
                        outputs = [
                            beam_search(sentence, model, SRC, TRG, opt)
                        ]
                    except Exception:
                        continue

                try:
                    # outputs[-1] raises IndexError for an empty controller;
                    # that example is skipped like any other scoring failure.
                    val_acc += simple_em(outputs[-1], e_tgt)
                    val_success += 1
                except Exception:
                    continue
        elif opt.task == 'e_snli_r':
            val_labels = zip_io_data(opt.label_path + '/val')
            for j, example in enumerate(val_data[:opt.n_val]):
                e_src, e_tgt = example[0], example[1]
                e_label = val_labels[j][1]

                sentence = _to_model_input(_index_all_tokens(e_src, SRC), opt)
                try:
                    sentence = beam_search(sentence, model, SRC, TRG, opt)
                except Exception:
                    continue

                try:
                    # A list (not a generator) so every candidate is scored,
                    # keeping the original's no-short-circuit behaviour.
                    matches = [
                        simple_em(sentence, candidate)
                        for candidate in e_tgt.split(' @@SEP@@ ')
                    ]
                except Exception:
                    continue

                if any(matches):
                    val_acc += 1
                val_success += 1

                pred_label = _classify_rationale(sentence.split(), opt)
                beam_label_acc.append(1 if pred_label == e_label else 0)
        elif opt.task == 'e_snli_o':
            for example in val_data[:opt.n_val]:
                e_src, e_tgt = example[0], example[1]
                sentence = _to_model_input(_index_all_tokens(e_src, SRC), opt)

                try:
                    sentence = classify(sentence, model, SRC, TRG)
                except Exception:
                    continue

                if simple_em(sentence, e_tgt):
                    val_acc += 1
                val_success += 1
        else:
            raise NotImplementedError(
                "no beam search support for CoS-E or other tasks")

    if val_success == 0:
        val_success = 1
    val_acc = val_acc / val_success
    print('validation beam accuracy:', round(val_acc * 100, 2))
    if opt.wandb:
        wandb.log({'validation beam accuracy': round(val_acc * 100, 2)})
        if opt.task == 'e_snli_r':
            wandb.log({
                'validation beam label accuracy': _percent(beam_label_acc)
            })
            # NOTE(review): printed only when wandb is enabled, as in the
            # original; confirm that is intended.
            print('validation beam label accuracy:',
                  _percent(beam_label_acc))
    opt.scheduler.step(val_acc)
    print('-' * 10, '\n')


def _run_test(model, opt, SRC, TRG):
    """Decode the test split, print/log accuracies and write generation files.

    Examples containing a source token with vocabulary id 0 are skipped
    entirely. Output goes to ``opt.output_dir``.
    """
    test_data = zip_io_data(opt.data_path + '/test')
    beam_lines, fwd_lines = [], []
    beam_acc, fwd_acc = [], []
    beam_label_acc, fwd_label_acc = [], []
    label_task = opt.task == 'e_snli_r'

    os.makedirs(opt.output_dir, exist_ok=True)

    if label_task:
        test_labels = zip_io_data(opt.label_path + '/test')

    with torch.no_grad():
        for j, example in enumerate(test_data[:opt.n_test]):
            if (j + 1) % 10000 == 0:
                print(round(j / len(test_data) * 100, 2),
                      '% complete with testing')
            e_src, e_tgt = example[0], example[1]
            if label_task:
                e_label = test_labels[j][1]

            indexed = _index_known_tokens(e_src, SRC)
            if indexed is None:
                continue
            sentence = _to_model_input(indexed, opt)

            if opt.val_forward_pass:
                trg = _pad_block(1, opt.max_strlen)
                trg[:, 0] = trg[:, 0] * 2

                trg_input = trg[:, :-1]
                src_mask, trg_mask = create_hard_masks(
                    sentence, trg_input, opt)
                preds = model(sentence, trg_input, src_mask, trg_mask)
                pred_tok = torch.argmax(preds, dim=-1)[0].tolist()
                ys = [TRG.vocab.stoi[tok] for tok in e_tgt.split()
                      ] + [3]  # TODO: remove hardcode of EOS (3)

                fwd_acc.append(1 if pred_tok[:len(ys)] == ys else 0)

                pred_nl = ' '.join([TRG.vocab.itos[tok] for tok in pred_tok])
                # TODO: take both truncations out for differentiable version
                if ' <eos>' in pred_nl:
                    pred_nl = pred_nl[:pred_nl.index(' <eos>')]
                if ' .' in pred_nl:
                    pred_nl = pred_nl[:pred_nl.index(' .') + 2]
                fwd_lines.append(pred_nl + '\n')

                if label_task:
                    pred_label = _classify_rationale(pred_nl.split(), opt)
                    fwd_label_acc.append(1 if pred_label == e_label else 0)

            try:
                sentence = beam_search(sentence, model, SRC, TRG, opt)
            except Exception:
                continue
            try:
                beam_acc.append(simple_em(sentence, e_tgt))
                beam_lines.append(sentence + '\n')
            except Exception:
                beam_lines.append('\n')
                continue

            if label_task:
                pred_label = _classify_rationale(sentence.split(), opt)
                beam_label_acc.append(1 if pred_label == e_label else 0)

    # beam search logging
    if opt.wandb:
        wandb.log({'test beam accuracy': _percent(beam_acc)})
    print('test beam accuracy:', _percent(beam_acc))
    with open(opt.output_dir + '/test_beam_generations.txt',
              'w',
              encoding='utf-8') as f:
        f.write(''.join(beam_lines))

    # fwd pass logging
    if opt.val_forward_pass:
        print('test forward accuracy:', _percent(fwd_acc), '%')
        if opt.wandb:
            wandb.log({'test forward accuracy': _percent(fwd_acc)})
        with open(opt.output_dir + '/test_fwd_generations.txt',
                  'w',
                  encoding='utf-8') as f:
            f.write(''.join(fwd_lines))

    # e_snli_r logging
    if label_task:
        if opt.wandb:
            wandb.log({'test beam label accuracy': _percent(beam_label_acc)})
        print('test beam label accuracy:', _percent(beam_label_acc))

        if opt.val_forward_pass:
            if opt.wandb:
                wandb.log({
                    'test forward label accuracy': _percent(fwd_label_acc)
                })
            print('test forward label accuracy:', _percent(fwd_label_acc))


def train_model(model, opt, SRC, TRG):
    """Train ``model`` for ``opt.epochs`` epochs with optional evaluation.

    Each epoch runs, in order: one training pass over ``opt.train``; a
    validation-loss pass if ``opt.calculate_val_loss``; a forward-pass
    accuracy check if ``opt.val_forward_pass``; a decoded-output accuracy
    check every ``opt.val_check_every_n`` epochs (which also steps
    ``opt.scheduler``); and, after the last epoch when ``opt.do_test`` is set,
    a test run that writes generation files into ``opt.output_dir``.

    Args:
        model: network exposing ``train()``, ``eval()``, ``state_dict()`` and
            a forward call.
        opt: options object carrying hyper-parameters, data iterators,
            optimizer/schedulers and task settings read throughout.
        SRC: source field providing ``preprocess`` and ``vocab.stoi``.
        TRG: target field providing ``vocab.stoi`` / ``vocab.itos``.

    Returns:
        None. When ``opt.checkpoint > 0`` weights are saved to
        ``weights/model_weights`` at the start of every epoch and again
        whenever ``opt.checkpoint`` minutes have elapsed.

    Raises:
        NotImplementedError: if an enabled evaluation stage does not support
            ``opt.task``.
    """
    print("training model...")

    model.train()
    start = time.time()
    mask_prob = opt.mask_prob
    cptime = time.time() if opt.checkpoint > 0 else None

    for epoch in range(opt.epochs):
        model.train()
        print("   %dm: epoch %d [%s]  %d%%  loss = %s" %
              ((time.time() - start) // 60, epoch + 1, ' ' * 20, 0, '...'),
              end='\r')

        if opt.checkpoint > 0:
            torch.save(model.state_dict(), 'weights/model_weights')

        avg_loss, cptime = _train_one_epoch(model, opt, epoch, start, cptime,
                                            mask_prob)

        print("%dm: epoch %d [%s%s]  %d%%  loss = %.3f\nepoch %d complete, loss = %.03f" %
              ((time.time() - start) // 60, epoch + 1, '#' * 20, '', 100,
               avg_loss, epoch + 1, avg_loss))

        if opt.calculate_val_loss:
            model.eval()
            _report_validation_loss(model, opt)

        if opt.val_forward_pass:
            model.eval()
            _report_forward_accuracy(model, opt, TRG)

        if (epoch + 1) % opt.val_check_every_n == 0:
            model.eval()
            _report_beam_accuracy(model, opt, SRC, TRG)

        if epoch == opt.epochs - 1 and opt.do_test:
            model.eval()
            _run_test(model, opt, SRC, TRG)
Exemple #17
0
def _print_scores(scorers, refs, hyps):
    """Run every scorer on ``refs``/``hyps`` and print each resulting metric.

    Args:
        scorers: iterable of ``(scorer, method)`` pairs. ``method`` is either
            a single metric name or a list of names, one per value in the
            score returned by ``scorer.compute_score``.
        refs: references, passed unchanged to ``scorer.compute_score``.
        hyps: hypotheses, passed unchanged to ``scorer.compute_score``.
    """
    for scorer, method in scorers:
        print('computing %s score...' % (scorer.method()))
        score, scores = scorer.compute_score(refs, hyps)
        if isinstance(method, list):
            # Multi-valued scorer: pair each returned value with its name.
            # ``scores`` stays in the zip so the number of printed lines is
            # bounded by the shortest of the three sequences.
            for value, _, name in zip(score, scores, method):
                print("%s: %0.3f" % (name, value))
        else:
            print("%s: %0.3f" % (method, score))


def test(opt):
    """Generate a caption per test batch with beam search and print metrics.

    Builds the image transform, vocabulary, test data loader and model from
    ``opt``, decodes one hypothesis per batch, then prints BLEU-1..4,
    ROUGE-L, CIDEr and CIDEr-D first for the whole set and then for every
    individual hypothesis.

    Args:
        opt: options object; this function reads ``crop_size``, ``vocab``,
            ``data_test``, ``batch_size``, ``attribute_len`` and ``device``,
            and also passes ``opt`` to ``get_model`` and ``beam_search``.

    Returns:
        None. All results are written to stdout.
    """
    transform = transforms.Compose([
        transforms.CenterCrop(opt.crop_size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    vocab = Vocabulary()
    vocab.load(opt.vocab)

    data_loader = get_loader_test(opt.data_test,
                                  vocab,
                                  transform,
                                  opt.batch_size,
                                  shuffle=False,
                                  attribute_len=opt.attribute_len)

    list_of_refs = load_ori_token_data_new(opt.data_test)

    model = get_model(opt, load_weights=True)
    model.eval()

    # Maps the running batch index to a one-element list of hypothesis text.
    # NOTE(review): one hypothesis is stored per batch, so this presumably
    # expects opt.batch_size == 1 for indices to line up with list_of_refs --
    # confirm against get_loader_test / beam_search.
    hypotheses = {}

    progress = tqdm(data_loader,
                    mininterval=2,
                    desc='  - (Test)',
                    leave=False)
    for count, batch in enumerate(progress):
        image0, image1, image0_attribute, image1_attribute = (
            x.to(opt.device) for x in batch)

        hyp = beam_search(image0, image1, model, opt, vocab, image0_attribute,
                          image1_attribute)

        # Keep only the text before the first "<end>" marker.
        hyp = hyp.split("<end>")[0].strip()

        hypotheses[count] = ["it " + hyp]

    # =================================================
    # Set up scorers
    # =================================================
    print('setting up scorers...')
    # METEOR and SPICE are intentionally not part of this list.
    scorers = [
        (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
        (Rouge(), "ROUGE_L"),
        (Cider(), "CIDEr"),
        (CiderD(), "CIDEr-D")
    ]

    # Scores over the whole test set.
    _print_scores(scorers, list_of_refs, hypotheses)

    # Scores for each hypothesis on its own. Keys are 0..n-1 in insertion
    # order, so this visits the same indices as range(len(hypotheses)).
    for i, hyp_i in hypotheses.items():
        ref = {i: list_of_refs[i]}
        hyp = {i: hyp_i}
        print(ref)
        print(hyp)
        _print_scores(scorers, ref, hyp)