Example #1
def validate(args, device_id, pt, step, tokenizer):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    valid_iter = data_loader.Dataloader(args, load_dataset(args, 'dev', shuffle=False),
                                        args.batch_size, device,
                                        shuffle=False, is_test=False)

    symbols = {'BOS': tokenizer.convert_tokens_to_ids('<s>'), 'EOS': tokenizer.convert_tokens_to_ids('</s>'),
               'PAD': tokenizer.convert_tokens_to_ids('[PAD]')}

    valid_loss = abs_loss(model.generator, symbols, model.vocab_size, train=False, device=device)

    trainer = build_trainer(args, device_id, model, None, valid_loss)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()
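
Each validate variant in this listing returns the validation cross-entropy via stats.xent(), so a small driver can rank saved checkpoints by it. A minimal sketch of such a loop, assuming PreSumm-style checkpoint names like model_step_20000.pt under args.model_path (the pick_best_checkpoint helper is illustrative, not part of the example):

import glob
import os

def pick_best_checkpoint(args, device_id, tokenizer):
    # Validate every saved checkpoint and keep the one with the lowest xent.
    best_xent, best_cp = float('inf'), None
    for cp in sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt'))):
        step = int(cp.split('.')[-2].split('_')[-1])
        xent = validate(args, device_id, cp, step, tokenizer)
        if xent < best_xent:
            best_xent, best_cp = xent, cp
    return best_cp, best_xent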
Example #2
def validate(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if not args.test_from:
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)

    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    valid_iter = data_loader.Dataloader(args,
                                        load_dataset(args,
                                                     'valid',
                                                     shuffle=False),
                                        args.batch_size,
                                        device,
                                        shuffle=False,
                                        is_test=False)

    tokenizer = BertTokenizer.from_pretrained(
        '../ETRI_koBERT/003_bert_eojeol_pytorch/vocab.txt',
        do_lower_case=False,
        cache_dir=args.temp_dir)

    if not args.share_emb:
        tokenizer = add_tokens(tokenizer)

    symbols = {
        'BOS': tokenizer.vocab['<S>'],
        'EOS': tokenizer.vocab['<T>'],
        'PAD': tokenizer.vocab['[PAD]']
    }
    # symbols = {'BOS': tokenizer.vocab['[BOS]'], 'EOS': tokenizer.vocab['[EOS]'],
    #            'PAD': tokenizer.vocab['[PAD]'], 'EOQ': tokenizer.vocab['[EOQ]']}
    # symbols = {'BOS': tokenizer.vocab['[unused0]'], 'EOS': tokenizer.vocab['[unused1]'],
    #            'PAD': tokenizer.vocab['[PAD]'], 'EOQ': tokenizer.vocab['[unused2]']}

    valid_loss = abs_loss(model.generator,
                          symbols,
                          model.vocab_size,
                          train=False,
                          device=device)

    trainer = build_trainer(args, device_id, model, None, valid_loss)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()
Example #3
def test_ext(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)
    if args.ext_sum_dec:
        model = SentenceTransformer(args, device, checkpoint, sum_or_jigsaw=0)
    else:
        model = ExtSummarizer(args, device, checkpoint)
    model.eval()

    test_iter = data_loader.Dataloader(args,
                                       load_dataset(args,
                                                    'test',
                                                    shuffle=False),
                                       args.test_batch_size,
                                       device,
                                       shuffle=False,
                                       is_test=True)
    trainer = build_trainer(args, device_id, model, None)
    trainer.test(test_iter, step)
Example #4
def validate(args, device_id, pt, step):
    device = "cpu" if args.visible_gpu == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    spm = sentencepiece.SentencePieceProcessor()
    spm.Load(args.vocab_path)
    word_padding_idx = spm.PieceToId('<PAD>')
    vocab_size = len(spm)
    model = Summarizer(args, word_padding_idx, vocab_size, device, checkpoint)
    model.eval()

    valid_iter = data_loader.Dataloader(args,
                                        load_dataset(args,
                                                     'valid',
                                                     shuffle=False),
                                        {'PAD': word_padding_idx},
                                        args.batch_size,
                                        device,
                                        shuffle=False,
                                        is_test=False)
    trainer = build_trainer(args, device_id, model, None)
    stats = trainer.validate(valid_iter)
    trainer._report_step(0, step, valid_stats=stats)
    return stats.xent()
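
This variant derives both the padding id and the vocabulary size from a SentencePiece model rather than a BERT vocab file. A quick hedged illustration of those two calls (the model path is illustrative; the example itself loads args.vocab_path):

import sentencepiece

spm = sentencepiece.SentencePieceProcessor()
spm.Load('sentencepiece.model')   # illustrative path
pad_id = spm.PieceToId('<PAD>')   # id of the literal piece '<PAD>'
vocab_size = len(spm)             # total number of pieces in the model
print(pad_id, vocab_size)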
Example #5
def validate(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == "-1" else "cuda"
    if pt != "":
        test_from = pt
    else:
        test_from = args.test_from
    logger.info("Loading checkpoint from %s" % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    opt = vars(checkpoint["opt"])
    for k in opt.keys():
        if k in model_flags:
            setattr(args, k, opt[k])
    print(args)

    config = BertConfig.from_json_file(args.bert_config_path)
    model = Summarizer(args, device, load_pretrained_bert=True, bert_config=config)
    model.load_cp(checkpoint)
    model.eval()

    valid_iter = data_loader.Dataloader(
        args,
        load_dataset(args, "valid", shuffle=False),
        args.batch_size,
        device,
        shuffle=False,
        is_test=False,
    )
    trainer = build_trainer(args, device_id, model, None)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()
Example #6
def train_iter_fct():
    return data_loader.Dataloader(args,
                                  load_dataset(args, 'train',
                                               shuffle=True),
                                  args.batch_size,
                                  device,
                                  shuffle=True)
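
Wrapping the loader in a zero-argument closure like this lets a training loop rebuild a fresh, re-shuffled iterator whenever the previous one is exhausted. A sketch of how a consumer might drive it (the train function is illustrative, not from the example):

def train(train_iter_fct, train_steps):
    step = 0
    train_iter = train_iter_fct()
    while step < train_steps:
        for batch in train_iter:
            step += 1
            # ... forward/backward pass on `batch` goes here ...
            if step >= train_steps:
                break
        # Iterator exhausted: rebuild it so the next pass is re-shuffled.
        train_iter = train_iter_fct()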
Example #7
def val_iter_fct():
    return data_loader.Dataloader(args,
                                  load_dataset(args, 'val', shuffle=False),
                                  args.test_batch_size,
                                  device,
                                  shuffle=False,
                                  is_test=True)
Example #8
def test_text_abs(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)

    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    test_iter = data_loader.Dataloader(args,
                                       load_dataset(args,
                                                    'test',
                                                    shuffle=False),
                                       args.test_batch_size,
                                       device,
                                       shuffle=False,
                                       is_test=True)
    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased',
                                              do_lower_case=True,
                                              cache_dir=args.temp_dir)
    symbols = {'BOS': 1, 'EOS': 2, 'PAD': tokenizer.vocab['[PAD]'], 'EOQ': 3}
    predictor = build_predictor(args, tokenizer, symbols, model, logger)
    predictor.translate(test_iter, step)
Example #9
def test(args, device_id, pt, step):

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    config = BertConfig.from_json_file(args.bert_config_path)
    model = Summarizer(args, device, load_pretrained_bert=False, bert_config=config)
    model.load_cp(checkpoint)
    model.eval()

    test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
                                       args.batch_size, device,
                                       shuffle=False, is_test=True)
    trainer = build_trainer(args, device_id, model, None)
    trainer.test(test_iter, step)
Example #10
def test_ext(args, device_id, pt, step):
    if device_id == -1:
        device = "cpu"
    else:
        device = "cuda"
    logger.info('Device ID %s' % device_id)
    logger.info('Device %s' % device)
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.load_model
    logger.info('Loading model_checkpoint from %s' % test_from)
    model_checkpoint = torch.load(test_from,
                                  map_location=lambda storage, loc: storage)
    args.doc_classifier = model_checkpoint['opt'].doc_classifier
    args.nbr_class_neurons = model_checkpoint['opt'].nbr_class_neurons
    model = Ext_summarizer(args, device, model_checkpoint)

    test_iter = data_loader.Dataloader(args,
                                       load_dataset(args,
                                                    'test',
                                                    shuffle=False),
                                       args.test_batch_size,
                                       device,
                                       shuffle=False)
    trainer = build_trainer(args, device_id, model, None, None, None)

    for ref_patents, summaries, output_probas, prediction_contradiction, str_context in trainer.test(
            test_iter):
        yield ref_patents, summaries, output_probas, prediction_contradiction, str_context
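
Unlike the other test functions in this listing, this one is a generator: it streams batches of results out of trainer.test instead of returning once. A hedged consumption sketch:

for ref_patents, summaries, probas, contradiction, context in test_ext(args, device_id, pt, step):
    # One iteration per test batch, so results can be processed
    # without buffering everything in memory.
    print(ref_patents, contradiction)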
Example #11
def test_text_abs(args, device_id, pt, step, tokenizer):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)

    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
                                       args.test_batch_size, device,
                                       shuffle=False, is_test=True)
    symbols = {'BOS': tokenizer.convert_tokens_to_ids('<s>'), 'EOS': tokenizer.convert_tokens_to_ids('</s>'),
               'PAD': tokenizer.convert_tokens_to_ids('[PAD]')}
    predictor = build_predictor(args, tokenizer, symbols, model, logger)
    predictor.translate(test_iter, step)
Example #12
def test_abs(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if pt != '':
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)

    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if k in model_flags:
            setattr(args, k, opt[k])
    print(args)
    symbols, tokenizer = get_symbol_and_tokenizer(args.encoder, args.temp_dir)
    model = AbsSummarizer(args, device, checkpoint, symbols=symbols)
    model.eval()

    test_iter = data_loader.Dataloader(args,
                                       load_dataset(args,
                                                    'test',
                                                    shuffle=False),
                                       args.test_batch_size,
                                       device,
                                       shuffle=False,
                                       is_test=True,
                                       tokenizer=tokenizer)

    predictor = build_predictor(args, tokenizer, symbols, model, logger)
    predictor.translate(test_iter, step)
Example #13
def validate(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    valid_iter = data_loader.Dataloader(args, load_dataset(args, 'valid', shuffle=False),
                                        args.batch_size, device,
                                        shuffle=False, is_test=False)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir)
    symbols = {'BOS': tokenizer.vocab['[unused0]'], 'EOS': tokenizer.vocab['[unused1]'],
               'PAD': tokenizer.vocab['[PAD]'], 'EOQ': tokenizer.vocab['[unused2]']}

    valid_loss = abs_loss(model.generator, symbols, model.vocab_size, train=False, device=device)

    trainer = build_trainer(args, device_id, model, None, valid_loss)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()
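
The [unused0]/[unused1]/[unused2] entries are spare slots in the stock BERT WordPiece vocabulary that this codebase repurposes as BOS/EOS/EOQ markers. In bert-base-uncased they typically map to ids 1, 2 and 3 (with [PAD] at 0), which is presumably why Example #8 above can hard-code BOS=1, EOS=2, EOQ=3. A quick check, assuming the transformers (or pytorch_pretrained_bert) tokenizer:

from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
for sym in ('[PAD]', '[unused0]', '[unused1]', '[unused2]'):
    print(sym, tok.vocab[sym])   # typically 0, 1, 2, 3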
Example #14
def validate(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    model = ExtSummarizer(args, device, checkpoint)
    model.eval()

    valid_iter = data_loader.Dataloader(args,
                                        load_dataset(args,
                                                     'valid',
                                                     shuffle=False),
                                        args.batch_size,
                                        device,
                                        shuffle=False,
                                        is_test=False)
    trainer = build_trainer(args, device_id, model, None)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()
Example #15
    def gen_features_vector(self, step=None):
        if not step:
            try:
                step = int(self.args.test_from.split('.')[-2].split('_')[-1])
            except IndexError:
                step = 0

        logger.info('Loading checkpoint from %s' % self.args.test_from)
        checkpoint = torch.load(self.args.test_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if k in self.model_flags:
                setattr(self.args, k, opt[k])

        config = BertConfig.from_json_file(self.args.bert_config_path)
        model = model_builder.Summarizer(self.args,
                                         self.device,
                                         load_pretrained_bert=False,
                                         bert_config=config)
        model.load_cp(checkpoint)
        model.eval()
        # logger.info(model)
        trainer = build_trainer(self.args, self.device_id, model, None)
        test_iter = data_loader.DataLoader(self.args,
                                           data_loader.load_dataset(
                                               self.args,
                                               'test',
                                               shuffle=False),
                                           self.args.batch_size,
                                           self.device,
                                           shuffle=False,
                                           is_test=True)
        trainer.gen_features_vector(test_iter, step)
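
The step-recovery one-liner at the top of this method assumes checkpoint filenames of the form model_step_50000.pt. A worked example of what it extracts (the filename is illustrative):

name = 'model_step_50000.pt'
# split('.')[-2] -> 'model_step_50000'; split('_')[-1] -> '50000'
step = int(name.split('.')[-2].split('_')[-1])
assert step == 50000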
Example #16
    def validate(self, step):

        logger.info('Loading checkpoint from %s' % self.args.validate_from)
        checkpoint = torch.load(self.args.validate_from,
                                map_location=lambda storage, loc: storage)

        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if k in self.model_flags:
                setattr(self.args, k, opt[k])
        print(self.args)

        config = BertConfig.from_json_file(self.args.bert_config_path)
        model = model_builder.Summarizer(self.args,
                                         self.device,
                                         load_pretrained_bert=False,
                                         bert_config=config)
        model.load_cp(checkpoint)
        model.eval()

        valid_iter = data_loader.DataLoader(self.args,
                                            data_loader.load_dataset(
                                                self.args,
                                                'valid',
                                                shuffle=False),
                                            self.args.batch_size,
                                            self.device,
                                            shuffle=False,
                                            is_test=False)
        trainer = build_trainer(self.args, self.device_id, model, None)
        stats = trainer.validate(valid_iter, step)
        return stats.xent()
Example #17
def test_text_abs(args, device_id, pt, step, predictor):
    start_t = time.time()

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
                                       args.test_batch_size, device,
                                       shuffle=False, is_test=True)
    return predictor.translate(test_iter, step, args.report_rouge), str(time.time() - start_t)
Example #18
def valid_iter():
    return data_loader.Dataloader(args,
                                  load_dataset(args, 'test',
                                               shuffle=False),
                                  args.batch_size,
                                  device,
                                  shuffle=False,
                                  is_test=True)
Example #19
def train_iter(self):
    return data_loader.DataLoader(self.args,
                                  data_loader.load_dataset(self.args,
                                                           'train',
                                                           shuffle=True),
                                  self.args.batch_size,
                                  self.device,
                                  shuffle=True,
                                  is_test=False)
Example #20
def validate(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    valid_iter = data_loader.Dataloader(args,
                                        load_dataset(args,
                                                     'valid',
                                                     shuffle=False),
                                        args.batch_size,
                                        device,
                                        shuffle=False,
                                        is_test=False)

    parser = argparse.ArgumentParser()
    parser.add_argument('--bpe-codes',
                        default="/content/PhoBERT_base_transformers/bpe.codes",
                        required=False,
                        type=str,
                        help='path to fastBPE BPE')
    args1, unknown = parser.parse_known_args()
    bpe = fastBPE(args1)

    # Load the dictionary
    vocab = Dictionary()
    vocab.add_from_file("/content/PhoBERT_base_transformers/dict.txt")

    tokenizer = bpe
    symbols = {
        'BOS': vocab.indices['[unused0]'],
        'EOS': vocab.indices['[unused1]'],
        'PAD': vocab.indices['[PAD]'],
        'EOQ': vocab.indices['[unused2]']
    }

    valid_loss = abs_loss(model.generator,
                          symbols,
                          model.vocab_size,
                          train=False,
                          device=device)

    trainer = build_trainer(args, device_id, model, None, valid_loss)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()
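
The fastBPE/Dictionary pair assembled above is the usual PhoBERT preprocessing stack: fastBPE splits raw text into subword units and the fairseq Dictionary maps them to ids. A hedged sketch of how a sentence flows through them, following the PhoBERT README (the sample text is illustrative):

# Assumes `bpe` and `vocab` built as in the example above.
line = 'Tôi là sinh_viên .'
subwords = '<s> ' + bpe.encode(line) + ' </s>'
input_ids = vocab.encode_line(subwords, append_eos=False, add_if_not_exist=False)
print(input_ids.long().tolist())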
Example #21
def train_iter_fct():
    return data_loader.Dataloader(
        args,
        load_dataset(args, "train", shuffle=True),
        args.batch_size,
        device,
        shuffle=True,
        is_test=False,
    )
Example #22
def test_abs(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)

    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    test_iter = data_loader.Dataloader(args,
                                       load_dataset(args,
                                                    'test',
                                                    shuffle=False),
                                       args.test_batch_size,
                                       device,
                                       shuffle=False,
                                       is_test=True)
    if (args.bert_model == 'bert-base-multilingual-cased'):
        tokenizer = BertTokenizer.from_pretrained(
            'bert-base-multilingual-cased',
            do_lower_case=False,
            cache_dir=args.temp_dir)
    else:
        tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                                  do_lower_case=True,
                                                  cache_dir=args.temp_dir)
        print(len(tokenizer.vocab))
        if (len(tokenizer.vocab) == 31748):
            f = open(args.bert_model + "/vocab.txt", "a")
            f.write(
                "\n[unused1]\n[unused2]\n[unused3]\n[unused4]\n[unused5]\n[unused6]\n[unused7]"
            )
            f.close()
            tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                                      do_lower_case=True)
        print(len(tokenizer.vocab))

    symbols = {
        'BOS': tokenizer.vocab['[unused1]'],
        'EOS': tokenizer.vocab['[unused2]'],
        'PAD': tokenizer.vocab['[PAD]'],
        'EOQ': tokenizer.vocab['[unused3]']
    }

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    predictor = build_predictor(args, tokenizer, symbols, model, logger)
    predictor.translate(test_iter, step)
Example #23
def train_iter_fct():
    # return data_loader.AbstractiveDataloader(load_dataset('train', True), symbols, FLAGS.batch_size, device, True)
    return data_loader.Dataloader(args,
                                  load_dataset(args, 'train',
                                               shuffle=True),
                                  {'PAD': word_padding_idx},
                                  args.batch_size,
                                  device,
                                  shuffle=True,
                                  is_test=False)
Example #24
def baseline(args, cal_lead=False, cal_oracle=False):
    test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
                                       args.batch_size, 'cpu',
                                       shuffle=False, is_test=True)

    trainer = build_trainer(args, '-1', None, None, None)
    if (cal_lead):
        trainer.test(test_iter, 0, cal_lead=True)
    elif (cal_oracle):
        trainer.test(test_iter, 0, cal_oracle=True)
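
The cal_lead and cal_oracle flags select the two standard extractive reference points: LEAD (take the leading sentences) and ORACLE (greedy selection against the gold summary, an upper bound on extractive performance). A hedged usage sketch with the same args object:

baseline(args, cal_lead=True)    # score the LEAD baseline on the test set
baseline(args, cal_oracle=True)  # score the greedy ORACLE upper bound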
Example #25
def train_iter_fct():
    if args.is_debugging:
        print("YES it is debugging")
        return data_loader.Dataloader(args,
                                      load_dataset(args,
                                                   'test',
                                                   shuffle=False),
                                      args.batch_size,
                                      device,
                                      shuffle=False,
                                      is_test=False)
    else:
        return data_loader.Dataloader(args,
                                      load_dataset(args,
                                                   'train',
                                                   shuffle=True),
                                      args.batch_size,
                                      device,
                                      shuffle=True,
                                      is_test=False)
Example #26
    def test_model(self, corpus_type, topn=0):
        model_file = _top_model(self.model_path, n=topn)
        logger.info('Test GuidAbs model %s' % model_file)
        fn_touch = path.join(
            self.model_path,
            'finished_%s.test_guidabs_model%s' % (corpus_type, topn))
        if path.exists(fn_touch):
            return
        args = self._build_abs_args()
        args.mode = 'test'
        args.bert_data_path = path.join(self.data_path, 'cnndm')
        args.model_path = self.model_path
        args.log_file = path.join(
            self.model_path,
            'test_abs_bert_cnndm_%s_top%s.log' % (corpus_type, topn))
        args.result_path = path.join(self.model_path,
                                     'cnndm_%s_top%s' % (corpus_type, topn))
        init_logger(args.log_file)
        # load abs model; the training step is encoded in the checkpoint filename
        step_abs = int(model_file.split('.')[-2].split('_')[-1])
        checkpoint = torch.load(model_file,
                                map_location=lambda storage, loc: storage)
        model_abs = model_bld.AbsSummarizer(args, args.device, checkpoint)
        model_abs.eval()
        # init model testers
        tokenizer = BertTokenizer.from_pretrained(path.join(
            args.bert_model_path, model_abs.bert.model_name),
                                                  do_lower_case=True,
                                                  cache_dir=args.temp_dir)
        symbols = {
            'BOS': tokenizer.vocab['[unused0]'],
            'EOS': tokenizer.vocab['[unused1]'],
            'PAD': tokenizer.vocab['[PAD]'],
            'EOQ': tokenizer.vocab['[unused2]']
        }

        predictor = pred_abs.build_predictor(args, tokenizer, symbols,
                                             model_abs, logger)
        test_iter = data_ldr.Dataloader(args,
                                        data_ldr.load_dataset(args,
                                                              corpus_type,
                                                              shuffle=False),
                                        args.test_batch_size,
                                        args.device,
                                        shuffle=False,
                                        is_test=True,
                                        keep_order=True)

        avg_f1 = test_abs(logger, args, predictor, step_abs, test_iter)
        os.system('touch %s' % fn_touch)
        return avg_f1
Example #27
    def predict(self):

        test_iter = data_loader.DataLoader(self.args,
                                           data_loader.load_dataset(
                                               self.args,
                                               'test',
                                               shuffle=False),
                                           self.args.batch_size,
                                           self.device,
                                           shuffle=False,
                                           is_test=True)
        trainer = build_trainer(self.args, self.device_id, self.model, None)
        trainer.predict(test_iter, self.step)
Example #28
def baseline(args, cal_lead=False, cal_oracle=False):
    test_iter = data_loader.Dataloader(
        args,
        load_dataset(args, "test", shuffle=False),
        args.batch_size,
        "cpu",
        shuffle=False,
        is_test=True,
    )

    trainer = build_trainer(args, "-1", None, None, None)
    if cal_lead:
        trainer.test(test_iter, 0, cal_lead=True)
    elif cal_oracle:
        trainer.test(test_iter, 0, cal_oracle=True)
Example #29
def validate(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == "-1" else "cuda"
    if pt != "":
        test_from = pt
    else:
        test_from = args.test_from
    logger.info("Loading checkpoint from %s" % test_from)
    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint["opt"])
    for k in opt.keys():
        if k in model_flags:
            setattr(args, k, opt[k])
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    valid_iter = data_loader.Dataloader(
        args,
        load_dataset(args, "valid", shuffle=False),
        args.batch_size,
        device,
        shuffle=False,
        is_test=False,
    )

    tokenizer = BertTokenizer.from_pretrained(
        "chinese_roberta_wwm_ext_pytorch",
        do_lower_case=True,
        cache_dir=args.temp_dir)
    symbols = {
        "BOS": tokenizer.vocab["[unused1]"],
        "EOS": tokenizer.vocab["[unused2]"],
        "PAD": tokenizer.vocab["[PAD]"],
        "EOQ": tokenizer.vocab["[unused3]"],
    }

    valid_loss = abs_loss(model.generator,
                          symbols,
                          model.vocab_size,
                          train=False,
                          device=device)

    trainer = build_trainer(args, device_id, model, None, valid_loss)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()
Example #30
    def baseline(self, cal_lead=False, cal_oracle=False):
        test_iter = data_loader.DataLoader(self.args,
                                           data_loader.load_dataset(
                                               self.args,
                                               'test',
                                               shuffle=False),
                                           self.args.batch_size,
                                           self.device,
                                           shuffle=False,
                                           is_test=True)

        trainer = build_trainer(self.args, self.device_id, None, None)

        if cal_lead:
            trainer.test(test_iter, 0, cal_lead=True)
        elif cal_oracle:
            trainer.test(test_iter, 0, cal_oracle=True)