Code Example #1
File: train.py Project: peter-xbs/SUMO
def validate(args, device_id, pt, step):
    device = "cpu" if args.visible_gpu == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    spm = sentencepiece.SentencePieceProcessor()
    spm.Load(args.vocab_path)
    word_padding_idx = spm.PieceToId('<PAD>')
    vocab_size = len(spm)
    model = Summarizer(args, word_padding_idx, vocab_size, device, checkpoint)
    model.eval()

    valid_iter = data_loader.Dataloader(args,
                                        load_dataset(args,
                                                     'valid',
                                                     shuffle=False),
                                        {'PAD': word_padding_idx},
                                        args.batch_size,
                                        device,
                                        shuffle=False,
                                        is_test=False)
    trainer = build_trainer(args, device_id, model, None)
    stats = trainer.validate(valid_iter)
    trainer._report_step(0, step, valid_stats=stats)
    return stats.xent()
Code Example #2
File: train.py Project: domgoodwin/in-the-loop
def test(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    model = Summarizer(args, device, checkpoint)
    model.eval()

    test_iter = data_loader.Dataloader(args,
                                       load_dataset(args,
                                                    'test',
                                                    shuffle=False),
                                       args.batch_size,
                                       device,
                                       shuffle=False,
                                       is_test=True)
    trainer = build_trainer(args, device_id, model, None)
    trainer.test(test_iter, step)
Code Example #3
def train(args, device_id):
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    def train_iter_fct():
        return data_loader.Dataloader(args, load_dataset(args, 'train', shuffle=True), args.batch_size, device,
                                      shuffle=True, is_test=False)

    # temp change for reducing gpu memory
    model = Summarizer(args, device, load_pretrained_bert=True)
    #config = BertConfig.from_json_file(args.bert_config_path)
    #model = Summarizer(args, device, load_pretrained_bert=False, bert_config = config)

    if args.train_from != '': #train another part from beginning
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
        model.load_cp(checkpoint, strict=False)
        # keys can not match
        #optim = model_builder.build_optim(args, model, checkpoint)
        optim = model_builder.build_optim(args, model, None)
        if args.model_name == "ctx" and args.fix_scorer:
            logger.info("fix the saliency scorer")
            #for param in self.bert.model.parameters():
            for param in model.parameters():
                param.requires_grad = False

            if hasattr(model.encoder, "selector") and model.encoder.selector is not None:
                for param in model.encoder.selector.parameters():
                    param.requires_grad = True
            #print([p for p in model.parameters() if p.requires_grad])
    else:
        optim = model_builder.build_optim(args, model, None)

    logger.info(model)
    trainer = build_trainer(args, device_id, model, optim)
    _, neg_valid_loss = trainer.train(train_iter_fct, args.train_steps)
    while len(neg_valid_loss) > 0:
        #from 3rd to 2nd to 1st.
        neg_loss, saved_model = heapq.heappop(neg_valid_loss)
        print(-neg_loss, saved_model)
        step = int(saved_model.split('.')[-2].split('_')[-1])
        test(args, device_id, saved_model, step)
    logger.info("Finish!")
Code Example #4
def load_model(model_type):
    checkpoint = torch.load(f'checkpoints/{model_type}.pt', map_location='cpu')
    config = BertConfig.from_json_file('models/config.json')
    model = Summarizer(args=None,
                       device="cpu",
                       load_pretrained_bert=False,
                       bert_config=config)
    model.load_cp(checkpoint)
    return model
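A minimal usage sketch for the helper above; the checkpoint name "bertsum" is hypothetical and simply assumes a matching checkpoints/bertsum.pt saved by one of the training scripts.

# Hypothetical checkpoint name; any .pt file placed under checkpoints/ would do.
model = load_model('bertsum')
model.eval()  # inference only, so switch off dropout before scoring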
Code Example #5
def train(args, device_id):
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    logger.info('data begin load %s',
                time.strftime('%H:%M:%S', time.localtime(time.time())))

    def train_iter_fct():
        return data_loader.Dataloader(args,
                                      load_dataset(args, 'train',
                                                   shuffle=True),
                                      args.batch_size,
                                      device,
                                      shuffle=True,
                                      is_test=False)

    logger.info('data end load %s',
                time.strftime('%H:%M:%S', time.localtime(time.time())))
    model = Summarizer(args, device, load_pretrained_bert=True)

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
        model.load_cp(checkpoint)
        optim = model_builder.build_optim(args, model, checkpoint)
    else:
        optim = model_builder.build_optim(args, model, None)

    logger.info(model)
    trainer = build_trainer(args, device_id, model, optim)
    trainer.train(train_iter_fct, args.train_steps)
Code Example #6
File: train.py Project: parker84/BertSum
def train(args, device_id):
    # import ipdb; ipdb.set_trace()
    # import pdb; pdb.set_trace()
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        print("device_id = " + str(device_id))
        torch.cuda.set_device(device_id)
        print("device set")
        torch.cuda.manual_seed(args.seed)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    def train_iter_fct():
        return data_loader.Dataloader(args,
                                      load_dataset(args, 'train',
                                                   shuffle=True),
                                      args.batch_size,
                                      device,
                                      shuffle=True,
                                      is_test=False)

    model = Summarizer(args, device, load_pretrained_bert=True)
    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
        model.load_cp(checkpoint)
        optim = model_builder.build_optim(args, model, checkpoint)
    else:
        optim = model_builder.build_optim(args, model, None)

    logger.info(model)
    trainer = build_trainer(args, device_id, model, optim)
    with comet_experiment.train():
        trainer.train(train_iter_fct, args.train_steps)
Code Example #7
File: train.py Project: summarizers/BertSum
def train(args, device_id):
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == "-1" else "cuda"
    logger.info("Device ID %d" % device_id)
    logger.info("Device %s" % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    def train_iter_fct():
        return data_loader.Dataloader(
            args,
            load_dataset(args, "train", shuffle=True),
            args.batch_size,
            device,
            shuffle=True,
            is_test=False,
        )

    model = Summarizer(args, device, load_pretrained_bert=True)
    if args.train_from != "":
        logger.info("Loading checkpoint from %s" % args.train_from)
        checkpoint = torch.load(
            args.train_from, map_location=lambda storage, loc: storage
        )
        opt = vars(checkpoint["opt"])
        for k in opt.keys():
            if k in model_flags:
                setattr(args, k, opt[k])
        model.load_cp(checkpoint)
        optim = model_builder.build_optim(args, model, checkpoint)
    else:
        optim = model_builder.build_optim(args, model, None)

    logger.info(model)
    trainer = build_trainer(args, device_id, model, optim)
    trainer.train(train_iter_fct, args.train_steps)
Code Example #8
File: train.py Project: NoraH3/rickmorty-nlp
def train(args):
    device = "cpu"

    model = Summarizer(args, device, load_pretrained_bert=True)
    print(model)

    train_data, train_results, test_data, test_results = preprocess_data.get_data()

    ids = torch.tensor(train_data[0].ids)
    model(ids)  # TODO
Code Example #9
File: train.py Project: summarizers/BertSum
def test(args, device_id, pt, step):

    device = "cpu" if args.visible_gpus == "-1" else "cuda"
    if pt != "":
        test_from = pt
    else:
        test_from = args.test_from
    logger.info("Loading checkpoint from %s" % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    opt = vars(checkpoint["opt"])
    for k in opt.keys():
        if k in model_flags:
            setattr(args, k, opt[k])
    print(args)

    config = BertConfig.from_json_file(args.bert_config_path)
    model = Summarizer(args, device, load_pretrained_bert=True, bert_config=config)
    model.load_cp(checkpoint)
    model.eval()

    test_iter = data_loader.Dataloader(
        args,
        load_dataset(args, "test", shuffle=False),
        args.batch_size,
        device,
        shuffle=False,
        is_test=True,
    )
    trainer = build_trainer(args, device_id, model, None)
    trainer.test(test_iter, step)
Code Example #10
File: train.py Project: johndpope/BertSum
def validate(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    config = BertConfig.from_json_file(args.bert_config_path)
    model = Summarizer(args,
                       device,
                       load_pretrained_bert=False,
                       bert_config=config)
    model.load_cp(checkpoint)
    model.eval()

    valid_iter = data_loader.Dataloader(args,
                                        load_dataset(args,
                                                     'valid',
                                                     shuffle=False),
                                        args.batch_size,
                                        device,
                                        shuffle=False,
                                        is_test=False)
    trainer = build_trainer(args, device_id, model, None)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()
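As a rough sketch of how validate() is usually driven in these BertSum-style scripts (score every saved checkpoint on the validation set and keep the one with the lowest cross-entropy), assuming an args.model_path directory with the model_step_*.pt naming that the other snippets load:

import glob
import os

def pick_best_checkpoint(args, device_id):
    # Score each saved checkpoint and return (lowest xent, checkpoint path).
    cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt')))
    scored = []
    for cp in cp_files:
        step = int(cp.split('.')[-2].split('_')[-1])
        scored.append((validate(args, device_id, cp, step), cp))
    return min(scored)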
Code Example #11
def test(args, device_id, pt, step):

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    config = BertConfig.from_json_file(args.bert_config_path)
    model = Summarizer(args, device, load_pretrained_bert=False, bert_config=config)
    #model.load_cp(checkpoint) #TODO: change it back to strict=True
    model.load_cp(checkpoint, strict=False)
    model.eval()

    trainer = build_trainer(args, device_id, model, None)
    #if False:
    #args.block_trigram = True
    #if not args.only_initial or args.model_name == 'seq':
    if args.model_name == 'base':
        test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
                                           args.batch_size, device,
                                           shuffle=False, is_test=True)
        trainer.test(test_iter, step)
    else:
        test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
                                           args.batch_size, device,
                                           shuffle=False, is_test=True)
        trainer.iter_test(test_iter, step)
Code Example #12
File: train.py Project: tuhinjubcse/BertSum
def train(args, device_id):
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    def train_iter_fct():
        # return data_loader.AbstractiveDataloader(load_dataset('train', True), symbols, FLAGS.batch_size, device, True)
        return data_loader.Dataloader(args, load_dataset(args, 'train', shuffle=True), args.batch_size, device,
                                      shuffle=True, is_test=False)

    model = Summarizer(args, device, checkpoint)
    # optim = model_builder.build_optim(args, model.reg, checkpoint)
    optim = model_builder.build_optim(args, model, checkpoint)
    # optim = BertAdam()
    logger.info(model)
    trainer = build_trainer(args, device_id, model, optim)
    #
    trainer.train(train_iter_fct, args.train_steps)
Code Example #13
def save_state_dict(args, device_id):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    ckpt = "cnndm_bertsum_classifier_best.pt"

    logger.info('Loading checkpoint from %s' % ckpt)
    checkpoint = torch.load(ckpt, map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    config = BertConfig.from_json_file(args.bert_config_path)
    model = Summarizer(args, device, load_pretrained_bert=False, bert_config=config)
    model.load_cp(checkpoint)
    model.eval()

    # save state_dict
    torch.save(model.state_dict(), "weights.pt")
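A short sketch of the reverse direction, rebuilding a model from the exported plain state dict; it assumes the same args and BERT config that were in effect when weights.pt was written.

config = BertConfig.from_json_file(args.bert_config_path)
model = Summarizer(args, "cpu", load_pretrained_bert=False, bert_config=config)
# weights.pt holds a bare state_dict, so load_state_dict() is used instead of load_cp().
model.load_state_dict(torch.load("weights.pt", map_location="cpu"))
model.eval()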
Code Example #14
File: train.py Project: peter-xbs/SUMO
def train(args, device_id):
    init_logger(args.log_file)

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    spm = sentencepiece.SentencePieceProcessor()
    spm.Load(args.vocab_path)
    word_padding_idx = spm.PieceToId('<PAD>')
    vocab_size = len(spm)

    def train_iter_fct():
        # return data_loader.AbstractiveDataloader(load_dataset('train', True), symbols, FLAGS.batch_size, device, True)
        return data_loader.Dataloader(args,
                                      load_dataset(args, 'train',
                                                   shuffle=True),
                                      {'PAD': word_padding_idx},
                                      args.batch_size,
                                      device,
                                      shuffle=True,
                                      is_test=False)

    model = Summarizer(args, word_padding_idx, vocab_size, device, checkpoint)
    optim = model_builder.build_optim(args, model, checkpoint)
    logger.info(model)
    trainer = build_trainer(args, device_id, model, optim)
    #
    trainer.train(train_iter_fct, args.train_steps)
Code Example #15
def getTranslator():
    # set up model

    device = "cpu"

    logger.info('Loading checkpoint from %s' % args.test_from)
    checkpoint = torch.load(args.test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])

    print(args)

    config = BertConfig.from_json_file(args.bert_config_path)
    model = Summarizer(args,
                       device,
                       load_pretrained_bert=False,
                       bert_config=config)
    model.load_cp(checkpoint)
    model.eval()

    return build_trainer(args, -1, model, None)
Code Example #16
    device = "cuda"
    device_id = -1

    if args.seed:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    def train_loader_fct():
        return DataLoader(args.data_folder,
                          512,
                          args.batch_size,
                          device=device,
                          shuffle=True)

    model = Summarizer(device, args)
    if args.train_from != '':
        print('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = dict(checkpoint['opt'])
        for k in opt.keys():
            if k in model_flags:
                setattr(args, k, opt[k])
        model.load_cp(checkpoint['model'])
        optim = build_optim(args, model, checkpoint)
    else:
        optim = build_optim(args, model, None)

    trainer = build_trainer(args, model, optim)
    trainer.train(train_loader_fct, args.train_steps)
Code Example #17
def abs_train(args, device_id, pt, recover_all=False):
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    # load extractive model
    if pt is not None:
        test_from = pt
        logger.info('Loading checkpoint from %s' % test_from)
        checkpoint = torch.load(test_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
        print(args)
    config = BertConfig.from_json_file(args.bert_config_path)
    # build extractive model
    model = Summarizer(args,
                       device_id,
                       load_pretrained_bert=False,
                       bert_config=config)

    # decoder
    decoder = Decoder(model.bert.model.config.hidden_size // 2,
                      model.bert.model.config.vocab_size,
                      model.bert.model.config.hidden_size,
                      model.bert.model.embeddings,
                      device_id)  # 2*hidden_dim = embedding_size

    # get initial s_t
    s_t_1 = get_initial_s(model.bert.model.config.hidden_size, device_id)
    if recover_all:
        model.load_cp(checkpoint)
        s_t_1.load_cp(checkpoint)
        decoder.load_cp(checkpoint)
        optim = model_builder.build_optim(args, [model, decoder, s_t_1],
                                          checkpoint)

    elif pt is not None:
        model.load_cp(checkpoint)
        optim = model_builder.build_optim(args, [model, decoder, s_t_1],
                                          checkpoint)
    else:
        optim = model_builder.build_optim(args, [model, decoder, s_t_1], None)

    # tokenizer,nlp
    tokenizer = BertTokenizer.from_pretrained(
        'bert-base-uncased',
        do_lower_case=True,
        never_split=('[SEP]', '[CLS]', '[PAD]', '[unused0]', '[unused1]',
                     '[unused2]', '[UNK]'),
        no_word_piece=True)
    nlp = StanfordCoreNLP(r'/home1/bqw/stanford-corenlp-full-2018-10-05')

    # build optim

    # load train dataset
    def train_iter_fct():
        return data_loader.Dataloader(args,
                                      load_dataset(args, 'train',
                                                   shuffle=True),
                                      args.batch_size,
                                      device_id,
                                      shuffle=True,
                                      is_test=False)

    # build trainer
    trainer = build_trainer(args,
                            device_id,
                            model,
                            optim,
                            decoder=decoder,
                            get_s_t=s_t_1,
                            device=device_id,
                            tokenizer=tokenizer,
                            nlp=nlp)
    trainer.abs_train(train_iter_fct, args.train_steps)
Code Example #18
            step = int(cp.split('.')[-2].split('_')[-1])
        except:
            step = 0

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (cp != ''):
        test_from = cp
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    config = BertConfig.from_json_file(args.bert_config_path)
    model = Summarizer(args,
                       device,
                       load_pretrained_bert=False,
                       bert_config=config)
    model.load_cp(checkpoint)
    model.eval()

    model_args = args

    # start flask app
    app.run(host='0.0.0.0')
Code Example #19
def abs_decoder(args, device_id, pt):

    step = int(pt.split('.')[-2].split('_')[-1])

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    config = BertConfig.from_json_file(args.bert_config_path)
    model = Summarizer(args,
                       device,
                       load_pretrained_bert=False,
                       bert_config=config)

    # decoder
    decoder = Decoder(model.bert.model.config.hidden_size // 2,
                      model.bert.model.config.vocab_size,
                      model.bert.model.config.hidden_size,
                      model.bert.model.embeddings, device,
                      logger)  # 2*hidden_dim = embedding_size

    # get initial s_t
    s_t_1 = get_initial_s(model.bert.model.config.hidden_size, device)

    model.load_cp(checkpoint)
    s_t_1.load_cp(checkpoint)
    decoder.load_cp(checkpoint)

    model.eval()
    decoder.eval()
    s_t_1.eval()
    # tokenizer,nlp
    tokenizer = BertTokenizer.from_pretrained(
        'bert-base-uncased',
        do_lower_case=True,
        never_split=('[SEP]', '[CLS]', '[PAD]', '[unused0]', '[unused1]',
                     '[unused2]', '[UNK]'),
        no_word_piece=True)
    nlp = StanfordCoreNLP(r'/home1/bqw/stanford-corenlp-full-2018-10-05')
    # nlp.logging_level = 10

    test_iter = data_loader.Dataloader(args,
                                       load_dataset(args,
                                                    'test',
                                                    shuffle=False),
                                       args.batch_size,
                                       device,
                                       shuffle=False,
                                       is_test=True)
    trainer = build_trainer(args,
                            device_id,
                            model,
                            None,
                            decoder=decoder,
                            get_s_t=s_t_1,
                            device=device_id,
                            tokenizer=tokenizer,
                            nlp=nlp,
                            extract_num=args.extract_num)
    trainer.abs_decode(test_iter, step)
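For context, all of the snippets above are excerpts from train.py scripts in the BertSum/SUMO family of repositories, so they omit their imports and command-line wiring. The following is a hedged sketch of that surrounding scaffolding, modelled on the upstream BertSum train.py; the module paths, the contents of model_flags, and the exact flag set vary between the forks listed above.

import argparse
import random

import torch
from pytorch_pretrained_bert import BertConfig

from models import data_loader, model_builder
from models.data_loader import load_dataset
from models.model_builder import Summarizer
from models.trainer import build_trainer
from others.logging import logger, init_logger

# Checkpoint options that override the current command line when a model is reloaded.
model_flags = ['hidden_size', 'ff_size', 'heads', 'inter_layers', 'encoder', 'ff_actv',
               'use_interval', 'rnn_size']

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-mode", default='train', type=str)
    parser.add_argument("-visible_gpus", default='-1', type=str)
    parser.add_argument("-test_from", default='', type=str)
    # ... remaining flags (batch_size, train_from, seed, model_path, ...) omitted here.
    args = parser.parse_args()

    device_id = 0 if args.visible_gpus != '-1' else -1
    if args.mode == 'train':
        train(args, device_id)
    elif args.mode == 'validate':
        validate(args, device_id, '', 0)
    elif args.mode == 'test':
        step = int(args.test_from.split('.')[-2].split('_')[-1])
        test(args, device_id, args.test_from, step)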