Code example #1
def test_text_abs(args):

    logger.info('Loading checkpoint from %s' % args.test_from)
    device = "cpu" if args.visible_gpus == '-1' else "cuda"

    checkpoint = torch.load(args.test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if k in model_flags:
            setattr(args, k, opt[k])
    print(args)
    logger.info('Loading args inside test_text_abs %s' % args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    test_iter = data_loader.load_text(args, args.text_src, args.text_tgt,
                                      device)

    logger.info('test_iter is %s' % test_iter)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                              do_lower_case=True,
                                              cache_dir=args.temp_dir)
    symbols = {
        'BOS': tokenizer.vocab['[unused0]'],
        'EOS': tokenizer.vocab['[unused1]'],
        'PAD': tokenizer.vocab['[PAD]'],
        'EOQ': tokenizer.vocab['[unused2]']
    }

    logger.info('symbols is %s' % symbols)
    predictor = build_predictor(args, tokenizer, symbols, model, logger)
    predictor.translate(test_iter, -1)
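
The BOS/EOS/EOQ symbols above are mapped onto BERT's reserved '[unused*]' vocabulary slots, so the decoder's control tokens never collide with ordinary text. A minimal standalone sketch to inspect those ids (this assumes the transformers package; PreSumm-era code imported BertTokenizer from pytorch_transformers, but the vocabulary mapping is the same):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
# '[PAD]' is id 0 and the '[unused*]' tokens follow it at the start of
# the vocab, so the special symbols get ids that never occur in real text
symbols = {
    'BOS': tokenizer.vocab['[unused0]'],
    'EOS': tokenizer.vocab['[unused1]'],
    'PAD': tokenizer.vocab['[PAD]'],
    'EOQ': tokenizer.vocab['[unused2]'],
}
print(symbols)  # expected: {'BOS': 1, 'EOS': 2, 'PAD': 0, 'EOQ': 3}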
Code example #2
File: custom_train.py  Project: dchandak99/LongSumm
def train_ext(args):
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # load the model state from a checkpoint when -train_from is given
    if args.train_from != '':
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)

        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if k in model_flags:
                setattr(args, k, opt[k])

    else:
        checkpoint = None

    device = "cpu" if args.visible_gpus == '-1' else "cuda"

    model = ExtSummarizer(args, device, checkpoint)
    optim = model_builder.build_optim(args, model, checkpoint)

    train_iter = data_loader.load_text(args, args.text_src, args.text_tgt,
                                       device)

    # load_text returns a batch iterator, not an int, so iterate it
    # directly instead of calling range() on it
    for i, batch in enumerate(train_iter):
        print(i)
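
Both the train and test paths load checkpoints with map_location=lambda storage, loc: storage, which deserializes every tensor onto the CPU no matter which GPU it was saved from. A self-contained sketch of that round trip (toy file name and contents, not a real PreSumm checkpoint):

import torch

# save a toy checkpoint shaped like the ones above: a dict with the
# training options stored under 'opt' (PreSumm stores an argparse
# Namespace there, hence the vars() calls in the snippets)
torch.save({'opt': {'ext_layers': 2, 'seed': 666}}, 'toy_ckpt.pt')

# without map_location, torch.load would try to restore tensors to the
# (possibly absent) GPU device they were saved from
checkpoint = torch.load('toy_ckpt.pt', map_location=lambda storage, loc: storage)
print(checkpoint['opt'])  # {'ext_layers': 2, 'seed': 666}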
Code example #3
def test_text_ext(args):
    logger.info('Loading checkpoint from %s' % args.test_from)
    checkpoint = torch.load(args.test_from, map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if k in model_flags:
            setattr(args, k, opt[k])
    print(args)
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    device_id = 0 if device == "cuda" else -1

    model = ExtSummarizer(args, device, checkpoint)
    model.eval()

    test_iter = data_loader.load_text(args, args.text_src, args.text_tgt, device)

    trainer = build_trainer(args, device_id, model, None)
    trainer.test(test_iter, -1)
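
Every example repeats the same opt/model_flags loop: architecture-defining options are copied from the checkpoint back onto args, so the rebuilt model matches the saved weights even when the command line disagrees. A standalone sketch of the pattern (the flag names here are illustrative, not PreSumm's actual model_flags list):

from argparse import Namespace

# hypothetical flags; the real code keeps the list in a module-level
# model_flags variable
model_flags = ['hidden_size', 'ff_size']
args = Namespace(hidden_size=512, ff_size=2048, lr=0.002)
opt = {'hidden_size': 768, 'ff_size': 3072, 'lr': 0.05}

for k in opt.keys():
    if k in model_flags:
        setattr(args, k, opt[k])

# architecture flags now follow the checkpoint; runtime flags like lr do not
print(args)  # Namespace(ff_size=3072, hidden_size=768, lr=0.002)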
Code example #4
    device_id = 0 if device == "cuda" else -1

    model = ExtSummarizer(args, device, checkpoint)
    model.eval()

    # load data_files
    # args.text_src and args.result_path change for every paper

    file_dir_papers = "N:/Organisatorisches/Bereiche_Teams/ID/03_Studenten/Korte/Newsletter/Automatic Text Summarization/PreSumm_dev/cnndm/papers/"
    file_dir_results = "N:/Organisatorisches/Bereiche_Teams/ID/03_Studenten/Korte/Newsletter/Automatic Text Summarization/PreSumm_dev/cnndm/results/"


    for filename in os.listdir(file_dir_papers):
        print("Inference for", filename)

        # point args at the current paper and its result file
        args.text_src = file_dir_papers + filename
        resultname = filename.replace('.raw_src', '')
        args.result_path = file_dir_results + "result_" + resultname

        try:
            test_iter = data_loader.load_text(args, args.text_src, args.text_tgt, device)

            trainer = build_trainer(args, device_id, model, None)
            trainer.test(test_iter, -1)
        except Exception as e:
            # a bare except would also swallow KeyboardInterrupt, so catch
            # Exception and report which file failed before moving on
            print("Encoding error at file", filename, ":", e)