Example #1
def train(args, device_id):
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    def train_iter_fct():
        return data_loader.Dataloader(args, load_dataset(args, 'train', shuffle=True), args.batch_size, device,
                                                 shuffle=True, is_test=False)

    # temp change for reducing gpu memory
    model = Summarizer(args, device, load_pretrained_bert=True)
    #config = BertConfig.from_json_file(args.bert_config_path)
    #model = Summarizer(args, device, load_pretrained_bert=False, bert_config = config)

    if args.train_from != '':  # load saved weights, then train the new part from scratch
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
        model.load_cp(checkpoint, strict=False)
        # optimizer state keys may not match the new model, so rebuild the optimizer from scratch
        #optim = model_builder.build_optim(args, model, checkpoint)
        optim = model_builder.build_optim(args, model, None)
        if args.model_name == "ctx" and args.fix_scorer:
            logger.info("fix the saliency scorer")
            #for param in self.bert.model.parameters():
            for param in model.parameters():
                param.requires_grad = False

            if hasattr(model.encoder, "selector") and model.encoder.selector is not None:
                for param in model.encoder.selector.parameters():
                    param.requires_grad = True
            #print([p for p in model.parameters() if p.requires_grad])
    else:
        optim = model_builder.build_optim(args, model, None)

    logger.info(model)
    trainer = build_trainer(args, device_id, model, optim)
    _, neg_valid_loss = trainer.train(train_iter_fct, args.train_steps)
    while len(neg_valid_loss) > 0:
        # checkpoints pop from worst to best validation loss (e.g. 3rd best, then 2nd, then 1st)
        neg_loss, saved_model = heapq.heappop(neg_valid_loss)
        print(-neg_loss, saved_model)
        step = int(saved_model.split('.')[-2].split('_')[-1])
        test(args, device_id, saved_model, step)
    logger.info("Finish!")
Example #2
def train(args, device_id):
    # import ipdb; ipdb.set_trace()
    # import pdb; pdb.set_trace()
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        print("device_id = " + str(device_id))
        torch.cuda.set_device(device_id)
        print("device set")
        torch.cuda.manual_seed(args.seed)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    def train_iter_fct():
        return data_loader.Dataloader(args,
                                      load_dataset(args, 'train',
                                                   shuffle=True),
                                      args.batch_size,
                                      device,
                                      shuffle=True,
                                      is_test=False)

    model = Summarizer(args, device, load_pretrained_bert=True)
    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
        model.load_cp(checkpoint)
        optim = model_builder.build_optim(args, model, checkpoint)
    else:
        optim = model_builder.build_optim(args, model, None)

    logger.info(model)
    trainer = build_trainer(args, device_id, model, optim)
    with comet_experiment.train():  # comet_experiment: a comet_ml Experiment assumed to be created at module level
        trainer.train(train_iter_fct, args.train_steps)
Example #3
def train(args, device_id):
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    logger.info('data loading started at %s',
                time.strftime('%H:%M:%S', time.localtime(time.time())))

    def train_iter_fct():
        return data_loader.Dataloader(args,
                                      load_dataset(args, 'train',
                                                   shuffle=True),
                                      args.batch_size,
                                      device,
                                      shuffle=True,
                                      is_test=False)

    logger.info('data loading finished at %s',
                time.strftime('%H:%M:%S', time.localtime(time.time())))
    model = Summarizer(args, device, load_pretrained_bert=True)

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
        model.load_cp(checkpoint)
        optim = model_builder.build_optim(args, model, checkpoint)
    else:
        optim = model_builder.build_optim(args, model, None)

    logger.info(model)
    trainer = build_trainer(args, device_id, model, optim)
    trainer.train(train_iter_fct, args.train_steps)
Example #4
def train_single_ext(args, device_id):
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    #TODO -> add ability to load model from chkpt
    if args.train_from != '':
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)

        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if k in model_flags:
                setattr(args, k, opt[k])

    else:
        checkpoint = None

    def train_iter_fct():
        return data_loader.Dataloader(args,
                                      load_dataset(args, 'train',
                                                   shuffle=True),
                                      args.batch_size,
                                      args.device,
                                      shuffle=True,
                                      is_test=False)

    model = ExtSummarizer(args, checkpoint)
    optim = model_builder.build_optim(args, model, checkpoint)

    trainer = build_trainer(args, device_id, model, optim)
    trainer.train(train_iter_fct, args.train_steps)
Example #5
def train_abs_single(args, device_id):
    init_logger(args.log_file)
    logger.info(str(args))
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    if (args.load_from_extractive != ''):
        logger.info('Loading bert from extractive model %s' % args.load_from_extractive)
        bert_from_extractive = torch.load(args.load_from_extractive, map_location=lambda storage, loc: storage)
        bert_from_extractive = bert_from_extractive['model']
    else:
        bert_from_extractive = None
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    def train_iter_fct():
        return data_loader.Dataloader(args, load_dataset(args, 'train', shuffle=True), args.batch_size, device,
                                      shuffle=True, is_test=False)

    model = AbsSummarizer(args, device, checkpoint, bert_from_extractive)
    if (args.sep_optim):
        optim_bert = model_builder.build_optim_bert(args, model, checkpoint)
        optim_dec = model_builder.build_optim_dec(args, model, checkpoint)
        optim = [optim_bert, optim_dec]
    else:
        optim = [model_builder.build_optim(args, model, checkpoint)]

    logger.info(model)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir)
    symbols = {'BOS': tokenizer.vocab['[unused0]'], 'EOS': tokenizer.vocab['[unused1]'],
               'PAD': tokenizer.vocab['[PAD]'], 'EOQ': tokenizer.vocab['[unused2]']}

    train_loss = abs_loss(model.generator, symbols, model.vocab_size, device, train=True,
                          label_smoothing=args.label_smoothing)

    trainer = build_trainer(args, device_id, model, optim, train_loss)

    trainer.train(train_iter_fct, args.train_steps)
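For context, the sketch below shows how a driver script might invoke `train_abs_single`; the `device_id` convention (`-1` for CPU, otherwise the GPU index) follows the checks in the examples, but the flag names and defaults are assumptions covering only a small subset of the real repository's options.

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Hypothetical subset of flags; the real repository defines many more.
    parser.add_argument('-visible_gpus', default='-1', type=str)
    parser.add_argument('-log_file', default='../logs/abs_train.log', type=str)
    parser.add_argument('-train_from', default='', type=str)
    parser.add_argument('-load_from_extractive', default='', type=str)
    parser.add_argument('-train_steps', default=200000, type=int)
    parser.add_argument('-seed', default=666, type=int)
    args = parser.parse_args()

    # One process, one device: GPU 0 if any GPU is visible, otherwise CPU.
    device_id = 0 if args.visible_gpus != '-1' else -1
    train_abs_single(args, device_id)  # the function shown in the example above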
Example #6
def train(args, device_id):
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == "-1" else "cuda"
    logger.info("Device ID %d" % device_id)
    logger.info("Device %s" % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    def train_iter_fct():
        return data_loader.Dataloader(
            args,
            load_dataset(args, "train", shuffle=True),
            args.batch_size,
            device,
            shuffle=True,
            is_test=False,
        )

    model = Summarizer(args, device, load_pretrained_bert=True)
    if args.train_from != "":
        logger.info("Loading checkpoint from %s" % args.train_from)
        checkpoint = torch.load(
            args.train_from, map_location=lambda storage, loc: storage
        )
        opt = vars(checkpoint["opt"])
        for k in opt.keys():
            if k in model_flags:
                setattr(args, k, opt[k])
        model.load_cp(checkpoint)
        optim = model_builder.build_optim(args, model, checkpoint)
    else:
        optim = model_builder.build_optim(args, model, None)

    logger.info(model)
    trainer = build_trainer(args, device_id, model, optim)
    trainer.train(train_iter_fct, args.train_steps)
Example #7
def train_single_ext(args, device_id):
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    def train_iter_fct():
        if args.is_debugging:
            print("YES it is debugging")
            return data_loader.Dataloader(args,
                                          load_dataset(args,
                                                       'test',
                                                       shuffle=False),
                                          args.batch_size,
                                          device,
                                          shuffle=False,
                                          is_test=False)
            # exit()
        else:
            return data_loader.Dataloader(args,
                                          load_dataset(args,
                                                       'train',
                                                       shuffle=True),
                                          args.batch_size,
                                          device,
                                          shuffle=True,
                                          is_test=False)

    model = ExtSummarizer(args, device, checkpoint)
    optim = model_builder.build_optim(args, model, checkpoint)
    logger.info(model)

    trainer = build_trainer(args, device_id, model, optim)

    trainer.train(train_iter_fct, args.train_steps)
Example #8
def train_single_jigsaw(args, device_id):
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    def train_iter_fct():
        return jigsaw_data_loader.Dataloader(args,
                                             jigsaw_data_loader.load_dataset(
                                                 args, 'train', shuffle=True),
                                             args.batch_size,
                                             device,
                                             shuffle=True,
                                             is_test=False)

    jigsaw = args.jigsaw if 'jigsaw' in args else 'jigsaw_lab'
    if jigsaw == 'jigsaw_dec':
        model = SentenceTransformer(args, device, checkpoint, sum_or_jigsaw=1)
    else:
        model = Jigsaw(args, device, checkpoint)
    optim = build_optim(args, model, checkpoint)

    logger.info(model)
    # if args.fp16:
    #     opt_level = 'O1'  # typical fp16 training, can also try O2 to compare performance
    # else:
    #     opt_level = 'O0'  # pure fp32 training
    # model, optim.optimizer = amp.initialize(model, optim.optimizer, opt_level=opt_level)
    # logger.info('type(optim)'+str(type(optim)))
    trainer = build_trainer(args, device_id, model, optim)
    trainer.train(train_iter_fct, args.train_steps)
Example #9
    def train(self):
        model = model_builder.Summarizer(self.args,
                                         self.device,
                                         load_pretrained_bert=True)

        if self.args.train_from:
            logger.info('Loading checkpoint from %s' % self.args.train_from)
            checkpoint = torch.load(self.args.train_from,
                                    map_location=lambda storage, loc: storage)
            opt = vars(checkpoint['opt'])
            for k in opt.keys():
                if k in self.model_flags:
                    setattr(self.args, k, opt[k])
            model.load_cp(checkpoint)
            optimizer = model_builder.build_optim(self.args, model, checkpoint)
        else:
            optimizer = model_builder.build_optim(self.args, model, None)

        logger.info(model)
        trainer = build_trainer(self.args, self.device_id, model, optimizer)
        trainer.train(self.train_iter, self.args.train_steps)
Example #10
def train_abs_single(args, device_id):
    init_logger(args.log_file)
    logger.info(str(args))
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    def train_iter_fct():
        return data_pretrain.Dataloader(args,
                                        load_dataset(args,
                                                     'train',
                                                     shuffle=True),
                                        args.batch_size,
                                        device,
                                        shuffle=True,
                                        is_test=False)

    model = PretrainModel(args, device, checkpoint)
    optim = [model_builder.build_optim(args, model, checkpoint)]
    logger.info(model)
    symbols = {'PAD': 0}
    train_loss = abs_loss(model.generator,
                          symbols,
                          model.dec_vocab_size,
                          device,
                          train=True,
                          label_smoothing=args.label_smoothing)
    trainer = build_trainer(args, device_id, model, optim, train_loss)
    trainer.train(train_iter_fct, args.train_steps)
Example #11
def train(args, device_id):
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    def train_iter_fct():
        # return data_loader.AbstractiveDataloader(load_dataset('train', True), symbols, FLAGS.batch_size, device, True)
        return data_loader.Dataloader(args, load_dataset(args, 'train', shuffle=True), args.batch_size, device,
                                                 shuffle=True, is_test=False)

    model = Summarizer(args, device, checkpoint)
    # optim = model_builder.build_optim(args, model.reg, checkpoint)
    optim = model_builder.build_optim(args, model, checkpoint)
    # optim = BertAdam()
    logger.info(model)
    trainer = build_trainer(args, device_id, model, optim)
    #
    trainer.train(train_iter_fct, args.train_steps)
Example #12
def single_train(args, device_id):
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        # use the specified GPU
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)

    else:
        checkpoint = None

    def train_iter_method():
        return DataLoaderBert(load_dataset(args, 'train', shuffle=True),
                              args.batch_size,
                              shuffle=True,
                              is_test=False)

    model = NextSentencePrediction(args, device, checkpoint)
    optim = build_optim(args, model, checkpoint)

    logger.info(model)

    trainer = build_trainer(args, device_id, model, optim)
    trainer.train(train_iter_method, args.train_steps)
Example #13
def train(args, device_id):
    init_logger(args.log_file)

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    spm = sentencepiece.SentencePieceProcessor()
    spm.Load(args.vocab_path)
    word_padding_idx = spm.PieceToId('<PAD>')
    vocab_size = len(spm)

    def train_iter_fct():
        # return data_loader.AbstractiveDataloader(load_dataset('train', True), symbols, FLAGS.batch_size, device, True)
        return data_loader.Dataloader(args,
                                      load_dataset(args, 'train',
                                                   shuffle=True),
                                      {'PAD': word_padding_idx},
                                      args.batch_size,
                                      device,
                                      shuffle=True,
                                      is_test=False)

    model = Summarizer(args, word_padding_idx, vocab_size, device, checkpoint)
    optim = model_builder.build_optim(args, model, checkpoint)
    logger.info(model)
    trainer = build_trainer(args, device_id, model, optim)
    #
    trainer.train(train_iter_fct, args.train_steps)
Example #14
    device = "cuda"
    device_id = -1

    if args.seed:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    def train_loader_fct():
        return DataLoader(args.data_folder,
                          512,
                          args.batch_size,
                          device=device,
                          shuffle=True)

    model = Summarizer(device, args)
    if args.train_from != '':
        print('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = dict(checkpoint['opt'])
        for k in opt.keys():
            if k in model_flags:
                setattr(args, k, opt[k])
        model.load_cp(checkpoint['model'])
        optim = build_optim(args, model, checkpoint)
    else:
        optim = build_optim(args, model, None)

    trainer = build_trainer(args, model, optim)
    trainer.train(train_loader_fct, args.train_steps)
Example #15
def train_abs_single(args, device_id):
    init_logger(args.log_file)
    logger.info(str(args))
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    if (args.load_from_extractive != ''):
        logger.info('Loading bert from extractive model %s' %
                    args.load_from_extractive)
        bert_from_extractive = torch.load(
            args.load_from_extractive,
            map_location=lambda storage, loc: storage)
        bert_from_extractive = bert_from_extractive['model']
    else:
        bert_from_extractive = None
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    def train_iter_fct():
        return data_loader.Dataloader(args,
                                      load_dataset(args, 'train',
                                                   shuffle=True),
                                      args.batch_size,
                                      device,
                                      shuffle=True,
                                      is_test=False)

    def valid_iter_fct():
        return data_loader.Dataloader(args,
                                      load_dataset(args, 'valid',
                                                   shuffle=True),
                                      args.batch_size,
                                      device,
                                      shuffle=True,
                                      is_test=False)

    model = AbsSummarizer(args, device, checkpoint, bert_from_extractive)
    if (args.sep_optim):
        optim_bert = model_builder.build_optim_bert(args, model, checkpoint)
        optim_dec = model_builder.build_optim_dec(args, model, checkpoint)
        optim = [optim_bert, optim_dec]
    else:
        optim = [model_builder.build_optim(args, model, checkpoint)]

    logger.info(model)
    print("model.vocab_size" + str(model.vocab_size))

    parser = argparse.ArgumentParser()
    parser.add_argument('--bpe-codes',
                        default="/content/PhoBERT_base_transformers/bpe.codes",
                        required=False,
                        type=str,
                        help='path to fastBPE BPE')
    args1, unknown = parser.parse_known_args()
    bpe = fastBPE(args1)

    # Load the dictionary
    vocab = Dictionary()
    vocab.add_from_file("/content/PhoBERT_base_transformers/dict.txt")

    tokenizer = bpe
    symbols = {
        'BOS': vocab.indices['[unused0]'],
        'EOS': vocab.indices['[unused1]'],
        'PAD': vocab.indices['[PAD]'],
        'EOQ': vocab.indices['[unused2]']
    }

    train_loss = abs_loss(model.generator,
                          symbols,
                          model.vocab_size,
                          device,
                          train=True,
                          label_smoothing=args.label_smoothing)

    trainer = build_trainer(args, device_id, model, optim, train_loss)

    trainer.train(train_iter_fct=train_iter_fct,
                  train_steps=args.train_steps,
                  valid_iter_fct=valid_iter_fct)
Example #16
def train_abs_single(args, device_id):
    # Initialize training parameters
    init_logger(args.log_file)
    logger.info(str(args))
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # Load checkpoint
    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from, map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    # Load pretrained model
    if args.pretrain_path != '':
        logger.info('Loading pretrain_model from %s' % args.pretrain_path)
        pretrain_model = torch.load(args.pretrain_path, map_location=lambda storage, loc: storage)
        pretrain_model = pretrain_model['model']
    else:
        pretrain_model = None

    # Build the training data iterator
    def train_iter_fct():
        return data_hmm.Dataloader(args,
                                   load_dataset(args, 'train', shuffle=True),
                                   args.batch_size,
                                   device, shuffle=True,
                                   is_test=False)

    model = HMMModel(args, device, checkpoint, pretrain_model)
    if (args.sep_optim):
        optim_tok = model_builder.build_optim_tok(args, model, checkpoint)
        optim_hmm = model_builder.build_optim_hmm(args, model, checkpoint)
        optim = [optim_tok, optim_hmm]
    else:
        optim = [model_builder.build_optim(args, model, checkpoint)]
    train_loss = hmm_loss(model.generator,
                          args.pad_id,
                          args.relation_path,
                          args.fake_global,
                          model.dec_vocab_size,
                          device, train=True,
                          label_smoothing=args.label_smoothing)
    trainer = build_trainer(args, device_id, model, optim, train_loss)
    logger.info(model)

    # Start training
    trainer.train(train_iter_fct, args.train_steps)
Example #17
def train_ext(args, device_id):
    init_logger(args.log_file)
    if device_id == -1:
        device = "cpu"
    else:
        device = "cuda"
    logger.info('Device ID %s' % ','.join(map(str, device_id)))
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if device_id != -1:
        torch.cuda.set_device(device_id[0])
        torch.cuda.manual_seed(args.seed)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # Load checkpoint if necessary
    if args.load_model is not None:
        logger.info('Loading model_checkpoint from %s' % args.load_model)
        model_checkpoint = torch.load(
            args.load_model, map_location=lambda storage, loc: storage)
        if not args.transfer_learning:
            args.doc_classifier = model_checkpoint['opt'].doc_classifier
            args.nbr_class_neurons = model_checkpoint['opt'].nbr_class_neurons
    else:
        model_checkpoint = None

    if args.gan_mode and args.load_generator is not None:
        logger.info('Loading generator_checkpoint from %s' %
                    args.load_generator)
        generator_checkpoint = torch.load(
            args.load_generator, map_location=lambda storage, loc: storage)
        args.generator = generator_checkpoint['opt'].generator
    else:
        generator_checkpoint = None

    # Data generator for training
    def train_iter_fct():
        return data_loader.Dataloader(args,
                                      load_dataset(args, 'train',
                                                   shuffle=True),
                                      args.batch_size,
                                      device,
                                      shuffle=True)

    # Data generator for validation
    def valid_iter_fct():
        return data_loader.Dataloader(args,
                                      load_dataset(args,
                                                   'valid',
                                                   shuffle=False),
                                      args.test_batch_size,
                                      device,
                                      shuffle=False)

    # Create the model
    model = Ext_summarizer(args, device, model_checkpoint)
    optim = model_builder.build_optim(args, model, model_checkpoint)
    logger.info(model)

    if args.gan_mode:
        # Create the generator for GAN mode
        generator = Generator(args, model.length_embeddings, device,
                              generator_checkpoint)
        optim_generator = model_builder.build_optim_generator(
            args, generator, generator_checkpoint)
        logger.info(generator)
    else:
        generator = None
        optim_generator = None

    trainer = build_trainer(args, device_id, model, generator, optim,
                            optim_generator)
    trainer.train(train_iter_fct, args.train_steps, valid_iter_fct)
Example #18
def train_single_hybrid(args, device_id):
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    # reset the random seeds
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                # copy the flag onto args
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    if args.train_from_extractor != '':
        logger.info('Loading checkpoint from %s' % args.train_from_extractor)
        checkpoint_ext = torch.load(args.train_from_extractor,
                                    map_location=lambda storage, loc: storage)
        opt = vars(checkpoint_ext['opt'])
        for k in opt.keys():
            if (k in model_flags):
                # copy the flag onto args
                setattr(args, k, opt[k])
    else:
        checkpoint_ext = None

    if args.train_from_abstractor != '':
        logger.info('Loading checkpoint from %s' % args.train_from_abstractor)
        checkpoint_abs = torch.load(args.train_from_abstractor,
                                    map_location=lambda storage, loc: storage)
        opt = vars(checkpoint_abs['opt'])
        for k in opt.keys():
            if (k in model_flags):
                # copy the flag onto args
                setattr(args, k, opt[k])
    else:
        checkpoint_abs = None

    def train_iter_fct():
        # read the data once
        if args.is_debugging:
            print("YES it is debugging")
            # the third argument is batch_size
            return data_loader.Dataloader(args,
                                          load_dataset(args,
                                                       'test',
                                                       shuffle=False),
                                          args.batch_size,
                                          device,
                                          shuffle=False,
                                          is_test=False)
            # exit()
        else:
            return data_loader.Dataloader(args,
                                          load_dataset(args,
                                                       'train',
                                                       shuffle=True),
                                          args.batch_size,
                                          device,
                                          shuffle=True,
                                          is_test=False)

    # modules, consts, options = init_modules()
    # select the model: ExtSummarizer
    # print("1~~~~~~~~~~~~~~~~~~~~")
    model = HybridSummarizer(args,
                             device,
                             checkpoint,
                             checkpoint_ext=checkpoint_ext,
                             checkpoint_abs=checkpoint_abs)
    # build the optimizer
    # print("2~~~~~~~~~~~~~~~~~~~~")
    # optim = model_builder.build_optim(args, model, checkpoint)
    if (args.sep_optim):
        optim_bert = model_builder.build_optim_bert(args, model, checkpoint)
        optim_dec = model_builder.build_optim_dec(args, model, checkpoint)
        optim = [optim_bert, optim_dec]
        # print("????????")
        # print("optim")
        # print(optim)
        # exit()

    else:
        optim = [model_builder.build_optim(args, model, checkpoint)]

    # log the model
    logger.info(model)

    # print("3~~~~~~~~~~~~~~~~~~~~")
    # build the trainer
    # tokenizer = BertTokenizer.from_pretrained('/home/ybai/projects/PreSumm/PreSumm/temp/', do_lower_case=True, cache_dir=args.temp_dir)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                              do_lower_case=True,
                                              cache_dir=args.temp_dir)
    symbols = {
        'BOS': tokenizer.vocab['[unused0]'],
        'EOS': tokenizer.vocab['[unused1]'],
        'PAD': tokenizer.vocab['[PAD]'],
        'EOQ': tokenizer.vocab['[unused2]']
    }
    train_loss = abs_loss(model.abstractor.generator,
                          symbols,
                          model.abstractor.vocab_size,
                          device,
                          train=True,
                          label_smoothing=args.label_smoothing)
    trainer = build_trainer(args, device_id, model, optim, train_loss)

    # print("4~~~~~~~~~~~~~~~~~~~~")
    # start training
    trainer.train(train_iter_fct, args.train_steps)
Example #19
def train_abs(args, device_id):
    init_logger(args.log_file)
    logger.info(str(args))
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if k in model_flags:
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    if args.load_from_extractive != '':
        logger.info('Loading bert from extractive model %s' %
                    args.load_from_extractive)
        bert_from_extractive = torch.load(
            args.load_from_extractive,
            map_location=lambda storage, loc: storage)
        bert_from_extractive = bert_from_extractive['model']
    else:
        bert_from_extractive = None
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    symbols, tokenizer = get_symbol_and_tokenizer(args.encoder, args.temp_dir)

    model = AbsSummarizer(args,
                          device,
                          checkpoint,
                          bert_from_extractive,
                          symbols=symbols)
    if args.sep_optim:
        optim_enc = model_builder.build_optim_enc(args, model, checkpoint)
        optim_dec = model_builder.build_optim_dec(args, model, checkpoint)
        optim = [optim_enc, optim_dec]
    else:
        optim = [model_builder.build_optim(args, model, checkpoint)]

    logger.info(model)

    def train_iter_fct():
        return data_loader.Dataloader(args,
                                      load_dataset(args, 'train',
                                                   shuffle=True),
                                      args.batch_size,
                                      device,
                                      shuffle=True,
                                      is_test=False,
                                      tokenizer=tokenizer)

    train_loss = abs_loss(model.generator,
                          symbols,
                          model.vocab_size,
                          device,
                          train=True,
                          label_smoothing=args.label_smoothing)

    trainer = build_trainer(args, device_id, model, optim, train_loss)

    trainer.train(train_iter_fct, args.train_steps)
Example #20
def train_abs_single(args, device_id):
    """Implements training process (meta / non-meta)
    Args:
        device_id (int) : the GPU id to be used
    """

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d', device_id)
    logger.info('Device %s', device)

    # Fix random seed to keep the experiment reproducible
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    if device_id >= 0:  # if use GPU
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    # Load checkpoint and args
    if args.train_from != '':
        logger.info('Loading checkpoint from %s', args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])  # which is self.args
        for k in opt.keys():
            if k in model_flags:
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    # Load extractive model as initial parameter (proposed by Presumm)
    if args.load_from_extractive != '':
        logger.info('Loading bert from extractive model %s',
                    args.load_from_extractive)
        bert_from_extractive = torch.load(
            args.load_from_extractive,
            map_location=lambda storage, loc: storage)
        bert_from_extractive = bert_from_extractive['model']
    else:
        bert_from_extractive = None

    # Prepare dataloader
    if args.meta_mode:

        def meta_train_iter_fct():
            return data_loader.MetaDataloader(args,
                                              load_meta_dataset(args,
                                                                'train',
                                                                shuffle=True),
                                              args.batch_size,
                                              device,
                                              shuffle=True,
                                              is_test=False)
    else:

        def train_iter_fct():
            return data_loader.Dataloader(args,
                                          load_dataset(args,
                                                       'train',
                                                       shuffle=True),
                                          args.batch_size,
                                          device,
                                          shuffle=True,
                                          is_test=False)

    # Prepare model
    if args.meta_mode:
        model = MTLAbsSummarizer(args, device, checkpoint,
                                 bert_from_extractive)
    else:
        model = AbsSummarizer(args, device, checkpoint, bert_from_extractive)

    # Prepare optimizer for inner loop
    # Each task gets its own, separate optimizer
    if args.meta_mode:
        optims_inner = []
        for _ in range(args.num_task):
            if args.sep_optim:
                optim_bert_inner = model_builder.build_optim_bert_inner(
                    args, model, checkpoint, 'maml')
                optim_dec_inner = model_builder.build_optim_dec_inner(
                    args, model, checkpoint, 'maml')
                optims_inner.append([optim_bert_inner, optim_dec_inner])
            else:
                optims_inner.append([
                    model_builder.build_optim_inner(args, model, checkpoint,
                                                    'maml')
                ])

    # Prepare optimizer for outer loop
    if args.sep_optim:
        optim_bert = model_builder.build_optim_bert(args, model, checkpoint)
        optim_dec = model_builder.build_optim_dec(args, model, checkpoint)
        optims = [optim_bert, optim_dec]
    else:
        optims = [model_builder.build_optim(args, model, checkpoint)]

    # Prepare tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                              do_lower_case=True,
                                              cache_dir=args.temp_dir)
    symbols = {
        'BOS': tokenizer.vocab['[unused0]'],  # id = 1
        'EOS': tokenizer.vocab['[unused1]'],  # id = 2
        'EOQ': tokenizer.vocab['[unused2]'],  # id = 3
        'PAD': tokenizer.vocab['[PAD]']  # id = 0
    }

    # Self Check : special word ids
    special_words = [w for w in tokenizer.vocab.keys() if "[" in w]
    special_word_ids = [
        tokenizer.convert_tokens_to_ids(w) for w in special_words
    ]

    # Prepare loss computation
    train_loss = abs_loss(model.generator,
                          symbols,
                          model.vocab_size,
                          device,
                          train=True,
                          label_smoothing=args.label_smoothing)

    # Prepare trainer and perform training
    if args.meta_mode:
        trainer = build_MTLtrainer(args, device_id, model, optims,
                                   optims_inner, train_loss)
        trainer.train(meta_train_iter_fct)
    else:
        trainer = build_trainer(args, device_id, model, optims, train_loss)
        trainer.train(train_iter_fct, args.train_steps)
Example #21
def validate(args, device_id, pt, step):
    ''' Implements validation process (meta / non-meta)
    Arguments:
        device_id (int) : the GPU id to be used
        pt (str) : checkpoint path
        step (int) : checkpoint step
    Process:
        - load checkpoint
        - prepare dataloader class
        - prepare model class
        - prepare loss func, which return loss class
        - prepare trainer
        - trainer.validate()
    Meta vs Normal
        - MetaDataloader      vs Dataloader
        - load_meta_dataset   vs load_dataset
        - MTLAbsSummarizer    vs AbsSummarizer
        - build_MTLtrainer    vs build_trainer
    '''
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)

    # Fix random seed to keep the experiment reproducible
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    # Load checkpoint and args
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])  # which is self.args
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])

    # Prepare dataloader
    if (args.meta_mode):

        def valid_iter_fct():
            return data_loader.MetaDataloader(args,
                                              load_meta_dataset(args,
                                                                'valid',
                                                                shuffle=True),
                                              args.batch_size,
                                              device,
                                              shuffle=True,
                                              is_test=False)

    else:
        valid_iter = data_loader.Dataloader(args,
                                            load_dataset(args,
                                                         'valid',
                                                         shuffle=False),
                                            args.batch_size,
                                            device,
                                            shuffle=False,
                                            is_test=False)

    # Prepare model
    if (args.meta_mode):
        model = MTLAbsSummarizer(args, device, checkpoint)
    else:
        model = AbsSummarizer(args, device, checkpoint)
    #model.eval()

    # Prepare optimizer for inner loop
    # Each task gets its own, separate optimizer
    if (args.meta_mode):
        optims_inner = []
        for i in range(args.num_task):
            if (args.sep_optim):
                optim_bert_inner = model_builder.build_optim_bert_inner(
                    args, model, checkpoint, 'maml')
                optim_dec_inner = model_builder.build_optim_dec_inner(
                    args, model, checkpoint, 'maml')
                optims_inner.append([optim_bert_inner, optim_dec_inner])
            else:
                optims_inner.append([
                    model_builder.build_optim_inner(args, model, checkpoint,
                                                    'maml')
                ])

    # Prepare optimizer (not actually used, but it carries the step information)
    if (args.sep_optim):
        optim_bert = model_builder.build_optim_bert(args, model, checkpoint)
        optim_dec = model_builder.build_optim_dec(args, model, checkpoint)
        optim = [optim_bert, optim_dec]
    else:
        optim = [model_builder.build_optim(args, model, checkpoint)]

    # Prepare loss
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                              do_lower_case=True,
                                              cache_dir=args.temp_dir)
    symbols = {
        'BOS': tokenizer.vocab['[unused0]'],
        'EOS': tokenizer.vocab['[unused1]'],
        'PAD': tokenizer.vocab['[PAD]'],
        'EOQ': tokenizer.vocab['[unused2]']
    }

    # Prepare loss computation
    valid_loss = abs_loss(model.generator,
                          symbols,
                          model.vocab_size,
                          device,
                          train=False)

    # Prepare trainer and perform validation
    if (args.meta_mode):
        trainer = build_MTLtrainer(args, device_id, model, optim, optims_inner,
                                   valid_loss)
        stats = trainer.validate(valid_iter_fct, step)
    else:
        trainer = build_trainer(args, device_id, model, None, valid_loss)
        stats = trainer.validate(valid_iter, step)

    return stats.xent()
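A hedged usage sketch follows: driving `validate` over a directory of saved checkpoints and keeping the one with the lowest validation cross-entropy. The `model_step_<N>.pt` naming convention and the `args.model_path` attribute are assumptions, not taken from the example above.

import glob
import os

def validate_all(args, device_id):
    """Validate every saved checkpoint and return the best one by xent."""
    cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt')))
    results = []
    for cp in cp_files:
        step = int(cp.split('.')[-2].split('_')[-1])  # parse the step out of the filename
        xent = validate(args, device_id, cp, step)    # the function defined above
        results.append((xent, cp, step))
    best_xent, best_cp, best_step = min(results)
    return best_cp, best_step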
Example #22
def train_abs_single(args, device_id):
    init_logger(args.log_file)
    logger.info(str(args))
    device = "cpu" if args.visible_gpus == "-1" else "cuda"
    logger.info("Device ID %d" % device_id)
    logger.info("Device %s" % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    if args.train_from != "":
        logger.info("Loading checkpoint from %s" % args.train_from)
        checkpoint = torch.load(
            args.train_from, map_location=lambda storage, loc: storage
        )
        opt = vars(checkpoint["opt"])
        for k in opt.keys():
            if k in model_flags:
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    if args.load_from_extractive != "":
        logger.info("Loading bert from extractive model %s" % args.load_from_extractive)
        bert_from_extractive = torch.load(
            args.load_from_extractive, map_location=lambda storage, loc: storage
        )
        bert_from_extractive = bert_from_extractive["model"]
    else:
        bert_from_extractive = None
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    def train_iter_fct():
        return data_loader.Dataloader(
            args,
            load_dataset(args, "train", shuffle=True),
            args.batch_size,
            device,
            shuffle=True,
            is_test=False,
        )

    model = AbsSummarizer(args, device, checkpoint, bert_from_extractive)
    if args.sep_optim:
        optim_bert = model_builder.build_optim_bert(args, model, checkpoint)
        optim_dec = model_builder.build_optim_dec(args, model, checkpoint)
        optim = [optim_bert, optim_dec]
    else:
        optim = [model_builder.build_optim(args, model, checkpoint)]

    logger.info(model)

    tokenizer = BertTokenizer.from_pretrained(
        "chinese_roberta_wwm_ext_pytorch/", do_lower_case=True, cache_dir=args.temp_dir
    )
    symbols = {
        "BOS": tokenizer.vocab["[unused1]"],
        "EOS": tokenizer.vocab["[unused2]"],
        "PAD": tokenizer.vocab["[PAD]"],
        "EOQ": tokenizer.vocab["[unused3]"],
    }

    train_loss = abs_loss(
        model.generator,
        symbols,
        model.vocab_size,
        device,
        train=True,
        label_smoothing=args.label_smoothing,
    )

    trainer = build_trainer(args, device_id, model, optim, train_loss)

    trainer.train(train_iter_fct, args.train_steps)
Example #23
def abs_train(args, device_id, pt, recover_all=False):
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    # load extractive model
    if pt is not None:
        test_from = pt
        logger.info('Loading checkpoint from %s' % test_from)
        checkpoint = torch.load(test_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
        print(args)
    config = BertConfig.from_json_file(args.bert_config_path)
    # build extractive model
    model = Summarizer(args,
                       device_id,
                       load_pretrained_bert=False,
                       bert_config=config)

    # decoder
    decoder = Decoder(model.bert.model.config.hidden_size // 2,
                      model.bert.model.config.vocab_size,
                      model.bert.model.config.hidden_size,
                      model.bert.model.embeddings,
                      device_id)  # 2*hidden_dim = embedding_size

    # get initial s_t
    s_t_1 = get_initial_s(model.bert.model.config.hidden_size, device_id)
    if recover_all:
        model.load_cp(checkpoint)
        s_t_1.load_cp(checkpoint)
        decoder.load_cp(checkpoint)
        optim = model_builder.build_optim(args, [model, decoder, s_t_1],
                                          checkpoint)

    elif pt is not None:
        model.load_cp(checkpoint)
        optim = model_builder.build_optim(args, [model, decoder, s_t_1],
                                          checkpoint)
    else:
        optim = model_builder.build_optim(args, [model, decoder, s_t_1], None)

    # tokenizer,nlp
    tokenizer = BertTokenizer.from_pretrained(
        'bert-base-uncased',
        do_lower_case=True,
        never_split=('[SEP]', '[CLS]', '[PAD]', '[unused0]', '[unused1]',
                     '[unused2]', '[UNK]'),
        no_word_piece=True)
    nlp = StanfordCoreNLP(r'/home1/bqw/stanford-corenlp-full-2018-10-05')

    # build optim

    # load train dataset
    def train_iter_fct():
        return data_loader.Dataloader(args,
                                      load_dataset(args, 'train',
                                                   shuffle=True),
                                      args.batch_size,
                                      device_id,
                                      shuffle=True,
                                      is_test=False)

    # build trainer
    trainer = build_trainer(args,
                            device_id,
                            model,
                            optim,
                            decoder=decoder,
                            get_s_t=s_t_1,
                            device=device_id,
                            tokenizer=tokenizer,
                            nlp=nlp)
    trainer.abs_train(train_iter_fct, args.train_steps)