def inference():
    """Decode captions for the test split and dump them as JSONL.

    Parses CLI options, restores a trained checkpoint from ``--test_path``,
    runs translation over the evaluation dataloader, and writes
    ``pred.jsonl`` into a timestamped sub-directory of ``--eval_path``.
    """

    def _str2bool(value):
        # argparse hands the raw CLI token to `type`; with the previous
        # `type=str` every non-empty value (including "False") was a truthy
        # string downstream.  Map the usual spellings to a real bool.
        if isinstance(value, bool):
            return value
        return value.strip().lower() in ('true', '1', 'yes', 'y')

    parser = argparse.ArgumentParser(description="Image Captioning Evaluation")
    parser.add_argument('--vocab_path', default='data/vocab.pickle', type=str)
    parser.add_argument('--img_path', default='data/test2017/', type=str)
    parser.add_argument('--test_visual_feature_path',
                        default='data/visual_feature_test.pickle',
                        type=str)
    parser.add_argument("--test_path", type=str, help="model path")
    parser.add_argument('--num_workers', default=4, type=int)
    parser.add_argument('--batch_size', type=int, default=16)
    # BUGFIX: was `type=str, default=False` -- "--is_train False" parsed to
    # the truthy string 'False'.  _str2bool keeps the same CLI surface.
    parser.add_argument('--is_train', type=_str2bool, default=False)
    parser.add_argument('--eval_coco_idx_path',
                        default='data/test_coco_idx.npy',
                        type=str)
    parser.add_argument("--eval_path",
                        default='eval/',
                        type=str,
                        help="evaluation result path")
    # BUGFIX: default was the *string* 'False', which is truthy.
    parser.add_argument("--shuffle", default=False, type=_str2bool)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--max_sub_len', type=int, default=30)

    args = parser.parse_args()

    # Restore the weights saved by the training run.
    checkpoint = torch.load(os.path.join(args.test_path, 'model.ckpt'))
    eval_dataloader = get_eval_dataloader(args)
    translator = Translator(args, checkpoint)

    eval_result = translate(args, translator, eval_dataloader)

    # Layout: <eval_path>/<timestamp>/pred.jsonl (one prediction per line).
    mkdirp(args.eval_path)
    result_path = os.path.join(args.eval_path, start_time())
    mkdirp(result_path)

    filename = os.path.join(result_path, 'pred.jsonl')
    save_jsonl(eval_result, filename)
    logger.info("Save predict json file at {}".format(result_path))
# Example #2
def main():
    """Translate the pickled test split sentence-by-sentence.

    Restores the model named by ``-model``, reads vocabulary and instances
    from ``-data_pkl``, beam-searches one decoded sequence per test example,
    and writes the stripped predictions line-by-line to ``-output``.
    """

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model',
                        required=True,
                        help='Path to model weight file')
    parser.add_argument('-data_pkl',
                        required=True,
                        help='Pickle file with both instances and vocabulary.')
    parser.add_argument('-output',
                        default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence""")
    parser.add_argument('-beam_size', type=int, default=5)
    parser.add_argument('-max_seq_len', type=int, default=100)
    parser.add_argument('-no_cuda', action='store_true')

    # TODO: Translate bpe encoded files
    # parser.add_argument('-src', required=True,
    #                    help='Source sequence to decode (one line per sequence)')
    # parser.add_argument('-vocab', required=True,
    #                    help='Source sequence to decode (one line per sequence)')
    # TODO: Batch translation
    # parser.add_argument('-batch_size', type=int, default=30,
    #                    help='Batch size')
    # parser.add_argument('-n_best', type=int, default=1,
    #                    help="""If verbose is set, will output the n_best
    #                    decoded sentences""")

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # BUGFIX: `pickle.load(open(...))` never closed the file handle;
    # the context manager guarantees it is closed.
    with open(opt.data_pkl, 'rb') as pkl_file:
        data = pickle.load(pkl_file)
    SRC, TRG = data['vocab']['src'], data['vocab']['trg']
    # Special-token indices the decoder needs (pad/bos/eos).
    opt.src_pad_idx = SRC.vocab.stoi[constants.PAD_WORD]
    opt.trg_pad_idx = TRG.vocab.stoi[constants.PAD_WORD]
    opt.trg_bos_idx = TRG.vocab.stoi[constants.BOS_WORD]
    opt.trg_eos_idx = TRG.vocab.stoi[constants.EOS_WORD]

    test_loader = Dataset(examples=data['test'],
                          fields={
                              'src': SRC,
                              'trg': TRG
                          })

    device = torch.device('cuda' if opt.cuda else 'cpu')
    translator = Translator(model=load_model(opt, device),
                            beam_size=opt.beam_size,
                            max_seq_len=opt.max_seq_len,
                            src_pad_idx=opt.src_pad_idx,
                            trg_pad_idx=opt.trg_pad_idx,
                            trg_bos_idx=opt.trg_bos_idx,
                            trg_eos_idx=opt.trg_eos_idx).to(device)

    unk_idx = SRC.vocab.stoi[SRC.unk_token]
    with open(opt.output, 'w') as f:
        for example in tqdm(test_loader,
                            mininterval=2,
                            desc='  - (Test)',
                            leave=False):
            # Out-of-vocabulary source words fall back to <unk>.
            src_seq = [
                SRC.vocab.stoi.get(word, unk_idx) for word in example.src
            ]
            pred_seq = translator.translate_sentence(
                torch.LongTensor([src_seq]).to(device))
            pred_line = ' '.join(TRG.vocab.itos[idx] for idx in pred_seq)
            # Strip the sentence-boundary markers before writing.
            pred_line = pred_line.replace(constants.BOS_WORD,
                                          '').replace(constants.EOS_WORD, '')
            f.write(pred_line.strip() + '\n')

    print('[Info] Finished.')
# Example #3
 def test_1000000_to_str(self):
     """1,000,000 spells as "one million"."""
     result = Translator.num_to_string(1000000)
     self.assertEqual("one million", result)
# Example #4
 def test_21121_to_str(self):
     """21,121 mixes thousands, hundreds, and hyphenated tens."""
     expected = "twenty-one thousand one hundred twenty-one"
     self.assertEqual(expected, Translator.num_to_string(21121))
# Example #5
 def test_21000_to_str(self):
     """A round thousands value keeps its hyphenated multiplier."""
     result = Translator.num_to_string(21000)
     self.assertEqual("twenty-one thousand", result)
# Example #6
 def test_4373_to_str(self):
     """4,373 uses thousand, hundred, and hyphenated tens parts."""
     expected = "four thousand three hundred seventy-three"
     self.assertEqual(expected, Translator.num_to_string(4373))
# Example #7
 def test_1010_to_str(self):
     """Zero hundreds are skipped: 1,010 -> "one thousand ten"."""
     result = Translator.num_to_string(1010)
     self.assertEqual("one thousand ten", result)
# Example #8
 def test_1001_to_str(self):
     """Zero hundreds and tens are skipped: 1,001 -> "one thousand one"."""
     result = Translator.num_to_string(1001)
     self.assertEqual("one thousand one", result)
# Example #9
 def test_7_dig_to_str(self):
     """A single digit maps straight to its word."""
     result = Translator.num_to_string(7)
     self.assertEqual("seven", result)
# Example #10
 def test_negative_number_raise_error(self):
     """num_to_string must reject negative input with ValueError.

     BUGFIX: the old try/except passed silently when NO exception was
     raised; assertRaises fails the test in that case.
     """
     with self.assertRaises(ValueError):
         Translator.num_to_string(-55)
# Example #11
 def test_not_int_raise_error(self):
     """num_to_string must reject non-int input with TypeError.

     BUGFIX: the old try/except passed silently when NO exception was
     raised; assertRaises fails the test in that case.
     """
     with self.assertRaises(TypeError):
         Translator.num_to_string(111.0)
# Example #12
 def test_2_dig_to_str(self):
     """2 spells as "two"."""
     result = Translator.num_to_string(2)
     self.assertEqual("two", result)
# Example #13
 def test_1000111_to_str(self):
     """Millions with an empty thousands group: 1,000,111."""
     expected = "one million one hundred eleven"
     self.assertEqual(expected, Translator.num_to_string(1000111))
# Example #14
 def test_101_to_str(self):
     """Hundreds with a zero tens digit: 101 -> "one hundred one"."""
     result = Translator.num_to_string(101)
     self.assertEqual("one hundred one", result)
# Example #15
 def test_10_to_str(self):
     """10 spells as "ten"."""
     result = Translator.num_to_string(10)
     self.assertEqual("ten", result)
# Example #16
 def test_987_to_str(self):
     """987 combines hundreds with hyphenated tens."""
     expected = "nine hundred eighty-seven"
     self.assertEqual(expected, Translator.num_to_string(987))
# Example #17
 def test_11_to_str(self):
     """Teens are irregular words: 11 -> "eleven"."""
     result = Translator.num_to_string(11)
     self.assertEqual("eleven", result)
# Example #18
 def test_0_dig_to_str(self):
     """Zero is a special case: 0 -> "zero"."""
     result = Translator.num_to_string(0)
     self.assertEqual("zero", result)
# Example #19
 def test_12_to_str(self):
     """Teens are irregular words: 12 -> "twelve"."""
     result = Translator.num_to_string(12)
     self.assertEqual("twelve", result)
# Example #20
 def test_1110_to_str(self):
     """1,110 spells every non-zero group."""
     expected = "one thousand one hundred ten"
     self.assertEqual(expected, Translator.num_to_string(1110))
# Example #21
 def test_17_to_str(self):
     """Teens are irregular words: 17 -> "seventeen"."""
     result = Translator.num_to_string(17)
     self.assertEqual("seventeen", result)
# Example #22
 def test_10000_to_str(self):
     """A round ten-thousand: 10,000 -> "ten thousand"."""
     result = Translator.num_to_string(10000)
     self.assertEqual("ten thousand", result)
# Example #23
 def test_21_to_str(self):
     """Tens and units join with a hyphen: 21 -> "twenty-one"."""
     result = Translator.num_to_string(21)
     self.assertEqual("twenty-one", result)
# Example #24
 def test_1_dig_to_str(self):
     """1 spells as "one"."""
     result = Translator.num_to_string(1)
     self.assertEqual("one", result)
# Example #25
 def test_20_to_str(self):
     """A round tens value has no unit suffix: 20 -> "twenty"."""
     result = Translator.num_to_string(20)
     self.assertEqual("twenty", result)
# Example #26
 def test_111111_to_str(self):
     """Both the thousands group and the remainder are fully spelled."""
     expected = "one hundred eleven thousand one hundred eleven"
     self.assertEqual(expected, Translator.num_to_string(111111))
# Example #27
 def test_49_to_str(self):
     """Tens and units join with a hyphen: 49 -> "forty-nine"."""
     result = Translator.num_to_string(49)
     self.assertEqual("forty-nine", result)
# Example #28
def main():
    """Average several checkpoints of one experiment and report dev BLEU.

    Loads every ``<experiment_name>-<number>.chkpt`` listed in
    ``-model_numbers``, averages their parameters element-wise, then
    beam-search-decodes the dev split and prints the corpus BLEU score.
    """
    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-data_pkl',
                        help='Pickle file with both instances and vocabulary.')
    parser.add_argument('-experiment_name', required=True)
    parser.add_argument('-model_numbers', nargs='+', type=int, required=True)
    parser.add_argument('-output', default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence""")
    parser.add_argument('-beam_size', type=int, default=4)
    parser.add_argument('-batch_size', type=int, default=1)
    parser.add_argument('-max_seq_len', type=int, default=130)
    parser.add_argument('-alpha', type=float, default=0.6)
    parser.add_argument('-device', choices=['cpu', 'cuda'], default='cuda')
    parser.add_argument('-langs', nargs='+', required=True)

    args = parser.parse_args()
    device = torch.device(args.device)

    # Checkpoint averaging: accumulate parameter sums into the first model,
    # then divide by the number of checkpoints below.
    for i, number in enumerate(args.model_numbers):
        model_name = f'{args.experiment_name}-{number}.chkpt'
        if i == 0:
            model = load_model(model_name, device)
        else:
            temp_model = load_model(model_name, device)
            temp_params = dict(temp_model.named_parameters())
            for name, param in model.named_parameters():
                temp_params[name].data.copy_(param.data + temp_params[name].data)
            model.load_state_dict(temp_params)
    for _, param in model.named_parameters():
        param.data.copy_(param.data / len(args.model_numbers))

    args.data_reduce_size = -1
    test_loader, total_tokens, SRC, TRG = load_data_dict(
        experiment_name=args.experiment_name,
        corpora_type='dev',
        langs=args.langs,
        args=args,
        device=device
    )

    # Special-token indices the translator needs.
    args.src_pad_idx = SRC.vocab.stoi[PAD_WORD]
    args.trg_pad_idx = TRG.vocab.stoi[PAD_WORD]
    args.trg_bos_idx = TRG.vocab.stoi[BOS_WORD]
    args.trg_eos_idx = TRG.vocab.stoi[EOS_WORD]
    args.trg_unk_idx = TRG.vocab.stoi[UNK_WORD]
    translator = Translator(
        model=model,
        beam_size=args.beam_size,
        max_seq_len=args.max_seq_len,
        src_pad_idx=args.src_pad_idx,
        trg_pad_idx=args.trg_pad_idx,
        trg_bos_idx=args.trg_bos_idx,
        trg_eos_idx=args.trg_eos_idx,
        device=device,
        alpha=args.alpha
    ).to(device)

    total_bleu, total_sentence = 0, 0
    bleu_score = 0
    for example in tqdm(test_loader, mininterval=20, desc='  - (Test)', leave=False, total=total_tokens//args.batch_size):
        source_sequence = patch_source(example.src).to(device)
        target_sequence, gold = map(lambda x: x.to(device), patch_target(example.trg))
        pred_seq, ends = translator.translate_sentence(source_sequence)

        bleu = translation_score(pred_seq, ends, gold, TRG)
        total_bleu += bleu[0]
        total_sentence += bleu[1]
        # BUGFIX: guard the running-average division -- translation_score may
        # report zero sentences, which previously raised ZeroDivisionError.
        if total_sentence:
            bleu_score = (total_bleu / total_sentence) * 100
        print('\n', bleu_score)
    if total_sentence:
        bleu_score = (total_bleu / total_sentence) * 100
    print('BLEU score for model: ', bleu_score)
# Example #29
def run_one_epoch(model,
                  data,
                  args,
                  device,
                  TRG,
                  total_tokens,
                  optimizer=None,
                  smoothing=False,
                  bleu=False):
    '''Run one training or validation epoch.

    Trains when `optimizer` is given, otherwise evaluates; optionally
    computes a running BLEU score via beam search when `bleu` is True.
    Returns (loss_per_word, accuracy, bleu_score_or_None), or (0, 0, None)
    when the epoch saw no target words.
    '''
    training = optimizer is not None
    total_loss, total_num_words, total_num_correct_words, total_bleu, total_sentence = 0, 0, 0, 0, 0
    if training:
        desc = '  - (Training)   '
        model.train()
    else:
        desc = '  - (Validation) '
        model.eval()
    if bleu:
        # Beam-search decoder used only for BLEU scoring.
        translator = Translator(model=model,
                                beam_size=args.beam_size,
                                max_seq_len=args.max_seq_len,
                                src_pad_idx=args.src_pad_idx,
                                trg_pad_idx=args.trg_pad_idx,
                                trg_bos_idx=args.trg_bos_idx,
                                trg_eos_idx=args.trg_eos_idx,
                                device=device)
        translator = TranslatorParallel(translator)

    for batch in tqdm(data,
                      mininterval=10,
                      desc=desc,
                      leave=False,
                      total=total_tokens // args.batch_size):
        # Prepare source/target tensors; `gold` is the shifted target used
        # as the loss reference.
        source_sequence = patch_source(batch.src).to(device)
        target_sequence, gold = map(lambda x: x.to(device),
                                    patch_target(batch.trg))
        if training:
            optimizer.zero_grad()
        if bleu:
            pred_seq, ends = translator.translate_sentence(source_sequence)
            score = translation_score(pred_seq, ends, gold, TRG)
            total_bleu += score[0]
            total_sentence += score[1]
        # Forward pass + projection to vocabulary logits.
        prediction = model(source_sequence, target_sequence)
        output = model.generator(prediction)
        output = output.view(-1, output.size(-1))
        loss, num_correct, num_words = calculate_metrics(
            output,
            gold.contiguous().view(-1),
            args.trg_pad_idx,
            smoothing=smoothing)
        if training:
            loss.backward()
            optimizer.step_and_update_lr()
        total_num_words += num_words
        total_num_correct_words += num_correct
        total_loss += loss.item()
    if total_num_words != 0:
        loss_per_word = total_loss / total_num_words
        accuracy = total_num_correct_words / total_num_words
        if bleu:
            # BUGFIX: guard against total_sentence == 0 (no scored
            # sentences), which previously raised ZeroDivisionError.
            bleu_score = total_bleu / total_sentence if total_sentence else 0.0
            print('current BLEU score: ', bleu_score)
        else:
            bleu_score = None
        return loss_per_word, accuracy, bleu_score
    else:
        return 0, 0, None