예제 #1
0
def main():
    """Train a TransformerModel on the configured Facebook dialogue datasets.

    Command-line flags only cover distributed training (``--local_rank``) and
    remote debugging (``--server_ip`` / ``--server_port``); every other setting
    comes from ``get_model_config()`` and ``get_trainer_config()``.

    Only the main process (``local_rank`` in ``[-1, 0]``) logs at INFO level,
    writes real TensorBoard summaries and dumps both configs as JSON into the
    run directory. If training raises (including Ctrl-C), the main process
    saves the trainer state to the interrupt checkpoint before re-raising.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank',
                        type=int,
                        default=-1,
                        help="Distributed training.")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Used for debugging on GPU machine.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Used for debugging on GPU machine.")
    args = parser.parse_args()

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.ERROR)
    logger = logging.getLogger(__file__)
    if args.server_ip and args.server_port and args.local_rank in [-1, 0]:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        # --server_port is parsed as a string, but the (host, port) attach
        # address needs an integer port.
        ptvsd.enable_attach(address=(args.server_ip, int(args.server_port)),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    model_config = get_model_config()
    trainer_config = get_trainer_config()

    # Log only on main process: other ranks send stdout to a per-rank file
    # and use a writer that does nothing.
    if args.local_rank not in [-1, 0]:
        sys.stdout = open(f"./runs/log_distributed_{args.local_rank}",
                          "w")  # dump stdout
        writer = DummyWriter()
    else:
        writer = SummaryWriter(comment=trainer_config.writer_comment)

    logger.info("model config: {}".format(model_config))
    logger.info("trainer config: {}".format(trainer_config))
    log_dir = writer.log_dir
    interrupt_checkpoint_path = os.path.join(
        log_dir, trainer_config.interrupt_checkpoint_path)
    last_checkpoint_path = os.path.join(log_dir,
                                        trainer_config.last_checkpoint_path)
    logger.info(
        "Logging to {}".format(log_dir)
    )  # Let's save everything on an experiment in the ./runs/XXX/directory
    if args.local_rank in [-1, 0]:
        with open(os.path.join(log_dir, "model_config.json"), "w") as f:
            json.dump(model_config, f)
        with open(os.path.join(log_dir, "trainer_config.json"), "w") as f:
            json.dump(trainer_config, f)

    set_seed(trainer_config.seed)
    device = torch.device(trainer_config.device)

    vocab = BPEVocab.from_files(model_config.bpe_vocab_path,
                                model_config.bpe_codes_path,
                                zero_shot=trainer_config.zero_shot)

    transformer = TransformerModel(
        n_layers=model_config.n_layers,
        n_embeddings=len(vocab),
        n_pos_embeddings=model_config.n_pos_embeddings,
        embeddings_size=model_config.embeddings_size,
        padding_idx=vocab.pad_id,
        n_heads=model_config.n_heads,
        dropout=model_config.dropout,
        embed_dropout=model_config.embed_dropout,
        attn_dropout=model_config.attn_dropout,
        ff_dropout=model_config.ff_dropout,
        normalize_embeddings=model_config.normalize_embeddings,
        bos_id=vocab.bos_id,
        eos_id=vocab.eos_id,
        sent_dialog_id=vocab.sent_dialog_id,
        max_seq_len=model_config.max_seq_len,
        beam_size=model_config.beam_size,
        length_penalty=model_config.length_penalty,
        n_segments=model_config.n_segments,
        annealing_topk=model_config.annealing_topk,
        annealing=model_config.annealing,
        diversity_coef=model_config.diversity_coef,
        diversity_groups=model_config.diversity_groups,
        multiple_choice_head=model_config.multiple_choice_head,
        constant_embedding=model_config.constant_embedding,
        single_input=model_config.single_input,
        dialog_embeddings=model_config.dialog_embeddings,
        share_models=model_config.share_models,
        successive_attention=model_config.successive_attention,
        sparse_embeddings=model_config.sparse_embeddings,
        shared_attention=model_config.shared_attention,
        bs_temperature=model_config.bs_temperature,
        bs_nucleus_p=model_config.bs_nucleus_p,
        vocab=None)  # for beam search debugging

    # Start from the OpenAI pretrained weights unless resuming from a checkpoint
    # (the checkpoint is loaded into the trainer further below).
    if not trainer_config.load_last:
        load_openai_weights(transformer.transformer_module,
                            trainer_config.openai_parameters_dir,
                            n_special_tokens=vocab.n_special_tokens)
        if not model_config.share_models:
            load_openai_weights(transformer.encoder_module,
                                trainer_config.openai_parameters_dir,
                                n_special_tokens=vocab.n_special_tokens)
        logger.info('OpenAI weights loaded from {}, model shared: {}'.format(
            trainer_config.openai_parameters_dir, model_config.share_models))

    logger.info('loading datasets')
    # Both datasets use the same length budget. With single_input the budget is
    # split three ways (presumably persona, dialog and target share one
    # sequence -- confirm). A bit restrictive here.
    max_lengths = (transformer.n_pos_embeddings - 1) // (
        3 if model_config.single_input else 1)
    train_dataset = FacebookDataset(
        trainer_config.train_datasets,
        vocab,
        max_lengths=max_lengths,
        dialog_embeddings=model_config.dialog_embeddings,
        cache=trainer_config.train_datasets_cache,
        use_start_end=model_config.use_start_end,
        negative_samples=trainer_config.negative_samples,
        augment=trainer_config.persona_augment,
        aug_syn_proba=trainer_config.persona_aug_syn_proba,
        limit_size=trainer_config.limit_train_size)
    test_dataset = FacebookDataset(
        trainer_config.test_datasets,
        vocab,
        max_lengths=max_lengths,
        dialog_embeddings=model_config.dialog_embeddings,
        cache=trainer_config.test_datasets_cache,
        use_start_end=model_config.use_start_end,
        negative_samples=-1,  # Keep all negative samples
        augment=False,
        aug_syn_proba=0.0,
        limit_size=trainer_config.limit_eval_size)
    # Log the sizes of both datasets, not the test dataset object itself.
    logger.info(
        f'train dataset {len(train_dataset)} test dataset {len(test_dataset)}')

    if args.local_rank != -1:
        # One process per GPU: bind this process to its GPU and join the NCCL group.
        torch.cuda.set_device(args.local_rank)
        device = torch.device('cuda', args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        transformer.distribute(device)

    model_trainer = Trainer(
        transformer,
        train_dataset,
        writer,
        test_dataset,
        train_batch_size=trainer_config.train_batch_size,
        batch_split=trainer_config.batch_split,
        test_batch_size=trainer_config.test_batch_size,
        lr=trainer_config.lr,
        lr_warmup=trainer_config.lr_warmup,
        weight_decay=trainer_config.weight_decay,
        s2s_weight=trainer_config.s2s_weight,
        lm_weight=trainer_config.lm_weight,
        risk_weight=trainer_config.risk_weight,
        hits_weight=trainer_config.hits_weight,
        single_input=model_config.single_input,
        n_jobs=trainer_config.n_jobs,
        clip_grad=trainer_config.clip_grad,
        device=device,
        ignore_idxs=vocab.special_tokens_ids,
        local_rank=args.local_rank,
        apex_level=model_config.apex_level,
        apex_loss_scale=trainer_config.apex_loss_scale,
        linear_schedule=trainer_config.linear_schedule,
        n_epochs=trainer_config.n_epochs,
        evaluate_full_sequences=trainer_config.evaluate_full_sequences)

    if trainer_config.load_last:
        state_dict = torch.load(trainer_config.load_last, map_location=device)
        model_trainer.load_state_dict(state_dict)
        logger.info('Weights loaded from {}'.format(trainer_config.load_last))

    # helpers -----------------------------------------------------
    def external_metrics_func(full_references,
                              full_predictions,
                              epoch,
                              metric=None):
        """Write references/predictions for ``epoch`` to the log dir and score them.

        If ``metric`` is given, return ``specified_nlp_metric`` for that single
        metric. Otherwise return a dict with ``meteor`` and ``avg_len`` plus the
        list-valued metrics from ``nlp_metrics``, expanded into 1-based keys
        such as ``bleu_1`` or ``nist_4``.
        """
        references_file_path = os.path.join(
            writer.log_dir,
            trainer_config.eval_references_file + "_{}".format(epoch))
        predictions_file_path = os.path.join(
            writer.log_dir,
            trainer_config.eval_predictions_file + "_{}".format(epoch))
        # str.join already returns str on Python 3; no unicode() wrapper needed.
        with open(references_file_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(full_references))
        with open(predictions_file_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(full_predictions))

        if metric is not None:
            return specified_nlp_metric([references_file_path],
                                        predictions_file_path, metric)

        nist, bleu, meteor, entropy, div, avg_len = nlp_metrics(
            [references_file_path], predictions_file_path)

        metrics = {'meteor': meteor, 'avg_len': avg_len}
        for name, values in (('nist', nist), ('entropy', entropy),
                             ('div', div), ('bleu', bleu)):
            for i, value in enumerate(values, 1):
                metrics['{}_{}'.format(name, i)] = value

        return metrics

    def save_func(epoch):
        """Save the full trainer state to the last-checkpoint path (skipped for epoch -1)."""
        if epoch != -1:
            torch.save(model_trainer.state_dict(), last_checkpoint_path)

    def sample_text_func(epoch):
        """Log sample test dialogues with the model's predicted reply.

        ``n_samples`` is 0, so at present this only switches the model to eval
        mode. Raise it to log example predictions.
        """
        n_samples = 0
        model_trainer.model.eval()
        samples_idxs = random.sample(range(len(test_dataset)), n_samples)
        samples = [test_dataset[idx] for idx in samples_idxs]
        for persona_info, dialog, target, _ in samples:
            contexts = [
                torch.tensor([c],
                             dtype=torch.long,
                             device=model_trainer.device)
                for c in [persona_info, dialog] if len(c) > 0
            ]
            prediction = model_trainer.model.predict(contexts)[0]

            persona_info_str = vocab.ids2string(persona_info[1:-1])
            dialog_str = vocab.ids2string(dialog)
            dialog_str = dialog_str.replace(vocab.talker1_bos,
                                            '\n\t- ').replace(
                                                vocab.talker2_bos, '\n\t- ')
            dialog_str = dialog_str.replace(vocab.talker1_eos,
                                            '').replace(vocab.talker2_eos, '')
            target_str = vocab.ids2string(target[1:-1])
            prediction_str = vocab.ids2string(prediction)

            logger.info('\n')
            logger.info('Persona info:\n\t{}'.format(persona_info_str))
            logger.info('Dialog:{}'.format(dialog_str))
            logger.info('Target:\n\t{}'.format(target_str))
            logger.info('Prediction:\n\t{}'.format(prediction_str))

    def test_func(epoch):
        """Run the trainer's test loop every ``trainer_config.test_period`` epochs."""
        if (epoch + 1) % trainer_config.test_period == 0:
            metric_funcs = {'f1_score': f1_score}
            model_trainer.test(metric_funcs, external_metrics_func, epoch)

    def f1_risk(predictions, targets):
        """Return per-example risk ``1 - F1``; asserts each F1 is in [0, 1]."""
        scores = f1_score(predictions, targets, average=False)
        assert all(0 <= s <= 1.0 for s in scores)
        return [1 - s for s in scores]

    def get_risk_metric_func(risk_metric):
        """Return a per-example risk function for ``risk_metric``.

        risk_metric selected in:
            f1, meteor, avg_len, nist_{1, 2, 3, 4}, entropy_{1, 2, 3, 4}, div_{1, 2}, bleu_{1, 2, 3, 4}
        """
        def external_metric_risk(predictions, targets):
            string_targets = list(vocab.ids2string(t) for t in targets)
            string_predictions = list(vocab.ids2string(t) for t in predictions)
            metrics = [
                external_metrics_func([t], [p], epoch=-1, metric=risk_metric)
                for p, t in zip(string_predictions, string_targets)
            ]

            # These metrics are not bounded to [0, 1], so the risk is their negation.
            if any(s in risk_metric for s in ['entropy', 'nist', 'avg_len']):
                return [-m for m in metrics]

            assert all(0 <= s <= 1.0 for s in metrics), metrics

            return [1 - m for m in metrics]

        if risk_metric == 'f1':
            return f1_risk

        return external_metric_risk

    # helpers -----------------------------------------------------

    try:
        model_trainer.train(
            after_epoch_funcs=[save_func, sample_text_func, test_func],
            risk_func=get_risk_metric_func(trainer_config.risk_metric))
    except (KeyboardInterrupt, Exception):
        # Exception already covers RuntimeError; KeyboardInterrupt is listed
        # separately because it is not an Exception subclass.
        if args.local_rank in [-1, 0]:
            torch.save(model_trainer.state_dict(), interrupt_checkpoint_path)
        raise
예제 #2
0
def main():
    """Train a TransformerModel from OpenAI weights (or the last checkpoint).

    All settings come from ``get_model_config()`` and ``get_trainer_config()``.
    After every epoch the trainer state is saved and a few test samples are
    printed; evaluation runs every ``trainer_config.test_period`` epochs. If
    training raises (including Ctrl-C), the trainer state is written to
    ``trainer_config.interrupt_checkpoint_path`` before re-raising.
    """
    model_config = get_model_config()
    trainer_config = get_trainer_config()

    set_seed(trainer_config.seed)
    device = torch.device(trainer_config.device)

    vocab = BPEVocab.from_files(model_config.bpe_vocab_path,
                                model_config.bpe_codes_path)

    transformer = TransformerModel(
        n_layers=model_config.n_layers,
        n_embeddings=len(vocab),
        n_pos_embeddings=model_config.n_pos_embeddings,
        embeddings_size=model_config.embeddings_size,
        padding_idx=vocab.pad_id,
        n_heads=model_config.n_heads,
        dropout=model_config.dropout,
        embed_dropout=model_config.embed_dropout,
        attn_dropout=model_config.attn_dropout,
        ff_dropout=model_config.ff_dropout,
        bos_id=vocab.bos_id,
        eos_id=vocab.eos_id,
        max_seq_len=model_config.max_seq_len,
        beam_size=model_config.beam_size,
        length_penalty=model_config.length_penalty,
        n_segments=model_config.n_segments,
        annealing_topk=model_config.annealing_topk,
        annealing=model_config.annealing,
        diversity_coef=model_config.diversity_coef,
        diversity_groups=model_config.diversity_groups)

    # Pretrained OpenAI weights are only needed when not resuming.
    if not trainer_config.load_last:
        load_openai_weights(transformer.transformer_module,
                            trainer_config.openai_parameters_dir,
                            n_special_tokens=vocab.n_special_tokens)
        print('OpenAI weights loaded from {}'.format(
            trainer_config.openai_parameters_dir))

    train_dataset = FacebookDataset(trainer_config.train_datasets, vocab,
                                    transformer.n_pos_embeddings - 1)
    test_dataset = FacebookDataset(trainer_config.test_datasets, vocab,
                                   transformer.n_pos_embeddings - 1)

    model_trainer = Trainer(transformer,
                            train_dataset,
                            test_dataset,
                            batch_size=trainer_config.batch_size,
                            batch_split=trainer_config.batch_split,
                            lr=trainer_config.lr,
                            lr_warmup=trainer_config.lr_warmup,
                            lm_weight=trainer_config.lm_weight,
                            risk_weight=trainer_config.risk_weight,
                            n_jobs=trainer_config.n_jobs,
                            clip_grad=trainer_config.clip_grad,
                            device=device,
                            ignore_idxs=vocab.special_tokens_ids)

    if trainer_config.load_last:
        state_dict = torch.load(trainer_config.last_checkpoint_path,
                                map_location=device)
        model_trainer.load_state_dict(state_dict)
        print('Weights loaded from {}'.format(
            trainer_config.last_checkpoint_path))

    # helpers -----------------------------------------------------
    def save_func(epoch):
        """Save the full trainer state to ``trainer_config.last_checkpoint_path``."""
        torch.save(model_trainer.state_dict(),
                   trainer_config.last_checkpoint_path)

    def sample_text_func(epoch):
        """Print up to 5 random test dialogues with the model's predicted reply."""
        # random.sample raises ValueError when asked for more items than the
        # population holds, so cap the sample size at the test-set size.
        n_samples = min(5, len(test_dataset))
        samples_idxs = random.sample(range(len(test_dataset)), n_samples)
        samples = [test_dataset[idx] for idx in samples_idxs]
        for persona_info, dialog, target in samples:
            contexts = [
                torch.tensor([c],
                             dtype=torch.long,
                             device=model_trainer.device)
                for c in [persona_info, dialog] if len(c) > 0
            ]
            prediction = model_trainer.model.predict(contexts)[0]

            persona_info_str = vocab.ids2string(persona_info[1:-1])
            dialog_str = vocab.ids2string(dialog)
            dialog_str = dialog_str.replace(vocab.talker1_bos,
                                            '\n\t- ').replace(
                                                vocab.talker2_bos, '\n\t- ')
            dialog_str = dialog_str.replace(vocab.talker1_eos,
                                            '').replace(vocab.talker2_eos, '')
            target_str = vocab.ids2string(target[1:-1])
            prediction_str = vocab.ids2string(prediction)

            print('\n')
            print('Persona info:\n\t{}'.format(persona_info_str))
            print('Dialog:{}'.format(dialog_str))
            print('Target:\n\t{}'.format(target_str))
            print('Prediction:\n\t{}'.format(prediction_str))

    def test_func(epoch):
        """Evaluate F1 on the test set every ``trainer_config.test_period`` epochs."""
        if (epoch + 1) % trainer_config.test_period == 0:
            metric_funcs = {'f1_score': f1_score}
            model_trainer.test(metric_funcs)

    def f1_risk(predictions, targets):
        """Return per-example risk ``1 - F1``."""
        scores = f1_score(predictions, targets, average=False)
        return [1 - s for s in scores]

    # helpers -----------------------------------------------------

    try:
        model_trainer.train(
            trainer_config.n_epochs,
            after_epoch_funcs=[save_func, sample_text_func, test_func],
            risk_func=f1_risk)
    except (KeyboardInterrupt, Exception):
        # Exception already covers RuntimeError; KeyboardInterrupt is listed
        # separately because it is not an Exception subclass.
        torch.save(model_trainer.state_dict(),
                   trainer_config.interrupt_checkpoint_path)
        raise
예제 #3
0
def main():
    args = InputConfig().args

    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt = '%m/%d/%Y %H:%M:%S',
                        level = logging.INFO if args.local_rank in [-1, 0] else logging.ERROR)
    logger = logging.getLogger(__file__)
    if args.server_ip and args.server_port and args.local_rank in [-1, 0]:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    trainer_config = get_trainer_config(args)
    # with open('/apdcephfs/share_916081/rainyucao/transformer_chatbot_experiments/test_log', 'w') as f:
    #     a = []
    #     a.append('args local rank is ' + str(args.local_rank) + '\n')
    #     a.append('cuda count' + str(torch.cuda.device_count()) + '\n')
    #     if args.local_rank not in [-1, 0] and torch.cuda.device_count() == 1:
    #         args.local_rank = -1
    #     a.append('args local rank is ' + str(args.local_rank) + '\n')
    #     f.writelines(a)

    # Log only on main process
    if args.local_rank not in [-1, 0]:
        sys.stdout = open("./runs/log_distributed_{}".format(args.local_rank), "w")  # dump sdtout
        writer = DummyWriter()
        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S', level=logging.ERROR)
        logger = logging.getLogger(__file__)
    else:
        from datetime import datetime
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        if args.single_input:
            comment = '_{}_{}_single'.format(args.model_type, args.data_type)
        else:
            if args.model_type == 'seq2seq':
                comment = '_seq2seq_multi_{}_{}'.format(args.data_type, args.attention_fusion_type)
            else:
                comment = '_{}_{}_{}_{}_{}'.format(args.model_type, args.data_type, args.attention_fusion_type,
                           ('sm' if args.shared_module == 1 else 'nm'), ('sa' if args.shared_attention == 1 else 'na'))
        logdir = os.path.join('runs', current_time + comment)
        writer = SummaryWriter(logdir=logdir)
        logger = config_logger(os.path.join(logdir, 'train.log'))

    log_dir = writer.logdir
    logger.info("Training args: {}".format(args))
    logger.info("trainer config: {}".format(trainer_config))
    interrupt_checkpoint_path = os.path.join(log_dir, trainer_config.interrupt_checkpoint_path)
    last_checkpoint_path = os.path.join(log_dir, trainer_config.last_checkpoint_path)
    best_checkpoint_path = os.path.join(log_dir, 'best_model')
    logger.info("Logging to {}".format(log_dir))  # Let's save everything on an experiment in the ./runs/XXX/directory
    if args.local_rank in [-1, 0]:
        with open(os.path.join(log_dir, "trainer_config.json"), "w") as f:
            json.dump(trainer_config, f)

    set_seed(trainer_config.seed)
    device = torch.device(trainer_config.device)

    parsed_train_data, parsed_valid_data, parsed_test_data = None, None, None
    if args.model_type == 'gpt':
        if args.single_input:
            model = OpenAIGPTLMHeadModel.from_pretrained('./openai-gpt')
        else:
            model = OpenAIGPTEncoderDecoderModel.from_pretrained('./openai-gpt')
        tokenizer = OpenAIGPTTokenizer.from_pretrained('./openai-gpt')
    elif args.model_type == 'dialogpt':
        if args.single_input:
            model = GPT2DoubleHeadsModel.from_pretrained('./dialogpt')
        else:
            model = GPT2EncoderDecoderModel.from_pretrained('./dialogpt')
        tokenizer = GPT2Tokenizer.from_pretrained('./dialogpt')
    elif args.model_type == 'seq2seq':
        seq2seq_vocab = Seq2seqVocab(trainer_config.train_datasets, trainer_config.valid_datasets,
                                 trainer_config.test_datasets, args.vocab_path, data_type=args.data_type)
        tokenizer = seq2seq_vocab.vocab
        parsed_train_data, parsed_valid_data, parsed_test_data = seq2seq_vocab.all_data[0], seq2seq_vocab.all_data[1], \
                                                                 seq2seq_vocab.all_data[2]
        args.dialog_embeddings = False
        model = TransformerSeq2Seq(args.emb_dim, args.hidden_dim, args.num_layers, args.heads, args.depth_size,
                               args.filter_size, tokenizer, args.pretrained_emb_file, args.pointer_gen, logger,
                                multi_input=not args.single_input, attention_fusion_type=args.attention_fusion_type)
    else:
        if args.single_input:
            model = GPT2DoubleHeadsModel.from_pretrained('./gpt2-small')
        else:
            model = GPT2EncoderDecoderModel.from_pretrained('./gpt2-small')
        tokenizer = GPT2Tokenizer.from_pretrained('./gpt2-small')


    if args.model_type in ['gpt', 'dialogpt', 'gpt2']:
        tokenizer, additional_length = modify_tokenizer(tokenizer, args.data_type)
        model.embeddings_size = 768
        model.n_embeddings = len(tokenizer)
        model.shared_attention = (args.shared_attention == 1)
        model.shared_module = (args.shared_module == 1)
        model.attention_fusion_type = args.attention_fusion_type
        model.single_input = args.single_input
        if args.model_type == 'gpt':
            model_embedding_weight = model.transformer.tokens_embed.weight
            model.transformer.tokens_embed = nn.Embedding(model.n_embeddings, 768)
            model.lm_head = nn.Linear(768, model.n_embeddings, bias=False)
            model.transformer.tokens_embed.weight.data[:-additional_length, :] = model_embedding_weight.data
            model.transformer.tokens_embed.weight.data[-additional_length:, :] = 0
            model.lm_head.weight = model.transformer.tokens_embed.weight
        else:
            model_embedding_weight = model.transformer.wte.weight
            model.transformer.wte = nn.Embedding(model.n_embeddings, 768)
            model.lm_head = nn.Linear(768, model.n_embeddings, bias=False)
            model.transformer.wte.weight.data[:-additional_length, :] = model_embedding_weight.data
            model.transformer.wte.weight.data[-additional_length:, :] = 0
            model.lm_head.weight = model.transformer.wte.weight

        if not args.single_input:
            model.reload_module_dict()
        model.sent_dialog_id = tokenizer.sent_dialog_id
    model.talker1_id = tokenizer.talker1_bos_id
    model.talker2_id = tokenizer.talker2_bos_id

    model.padding_idx = tokenizer.pad_id
    model.n_pos_embeddings = 512

    model.bos_id = tokenizer.bos_id
    model.eos_id = tokenizer.eos_id
    model.beam_size = args.beam_size
    model.diversity_groups = 1
    model.max_seq_len = 32
    model.dialog_embeddings = args.dialog_embeddings
    model.bs_temperature = args.bs_temperature
    model.bs_nucleus_p = args.bs_nucleus_p
    model.annealing_topk = args.annealing_topk
    model.length_penalty_coef = args.length_penalty
    model.vocab = None
    model.annealing = args.annealing
    model.diversity_coef = args.diversity_coef
    model.sample = False
    model.inference_mode = args.inference_mode
    model.response_k = args.response_k

    logger.info('loading datasets')
    train_dataset = FacebookDataset(trainer_config.train_datasets, tokenizer,
                                    max_lengths=model.n_pos_embeddings - 1,  # A bit restrictive here
                                    dialog_embeddings=args.dialog_embeddings,
                                    cache=trainer_config.train_datasets_cache,
                                    use_start_end=False,
                                    negative_samples=trainer_config.negative_samples,
                                    augment=trainer_config.persona_augment,
                                    aug_syn_proba=trainer_config.persona_aug_syn_proba,
                                    limit_size=trainer_config.limit_train_size,
                                    max_history_size=trainer_config.max_history_size,
                                    single_input=args.single_input,
                                    data_type=args.data_type,
                                    parsed_data=parsed_train_data)
    valid_dataset = FacebookDataset(trainer_config.valid_datasets, tokenizer,
                                    max_lengths=model.n_pos_embeddings - 1,  # A bit restrictive here
                                    dialog_embeddings=args.dialog_embeddings,
                                    cache=trainer_config.valid_datasets_cache,
                                    use_start_end=False,
                                    negative_samples=-1,  # Keep all negative samples
                                    augment=False,
                                    aug_syn_proba=0.0,
                                    limit_size=trainer_config.limit_eval_size,
                                    max_history_size=trainer_config.max_history_size,
                                    single_input=args.single_input,
                                    data_type=args.data_type,
                                    parsed_data=parsed_valid_data)
    test_dataset = FacebookDataset(trainer_config.test_datasets, tokenizer,
                                   max_lengths=model.n_pos_embeddings - 1,  # A bit restrictive here
                                   dialog_embeddings=args.dialog_embeddings,
                                   cache=trainer_config.test_datasets_cache,
                                   use_start_end=False,
                                   negative_samples=-1,  # Keep all negative samples
                                   augment=False,
                                   aug_syn_proba=0.0,
                                   limit_size=trainer_config.limit_eval_size,
                                   max_history_size=trainer_config.max_history_size,
                                   single_input=args.single_input,
                                   data_type=args.data_type,
                                   parsed_data=parsed_test_data)
    logger.info('train dataset {} valid dataset {} test dataset {}'
                .format(len(train_dataset), len(valid_dataset), len(test_dataset)))

    # if args.local_rank != -1:
    #     os.environ['MASTER_ADDR'] = 'localhost'
    #     os.environ['MASTER_PORT'] = '12355'
    #
    #     # initialize the process group
    #     torch.distributed.init_process_group("nccl", rank=args.local_rank, world_size=1)
    #     n = torch.cuda.device_count()
    #     device_ids = list(range(args.local_rank * n, (args.local_rank + 1) * n))
    #     torch.cuda.set_device(args.local_rank)
    #     device = torch.device('cuda', args.local_rank)
    #     transformer.distribute(device_ids[0], device_ids)
    '''Normal training will use normal trainer'''
    model_trainer = Trainer(model,
                            train_dataset,
                            writer,
                            logger=logger,
                            valid_dataset=valid_dataset,
                            test_dataset=test_dataset,
                            train_batch_size=trainer_config.train_batch_size,
                            batch_split=trainer_config.batch_split,
                            test_batch_size=trainer_config.test_batch_size,
                            lr=trainer_config.lr,
                            lr_warmup=trainer_config.lr_warmup,
                            weight_decay=trainer_config.weight_decay,
                            s2s_weight=trainer_config.s2s_weight,
                            lm_weight=trainer_config.lm_weight,
                            risk_weight=trainer_config.risk_weight,
                            hits_weight=trainer_config.hits_weight,
                            single_input=trainer_config.single_input,
                            n_jobs=trainer_config.n_jobs,
                            clip_grad=trainer_config.clip_grad,
                            device=device,
                            ignore_idxs=tokenizer.all_special_ids,
                            local_rank=args.local_rank,
                            apex_level=None,
                            apex_loss_scale=trainer_config.apex_loss_scale,
                            linear_schedule=trainer_config.linear_schedule,
                            n_epochs=trainer_config.n_epochs,
                            evaluate_full_sequences=trainer_config.evaluate_full_sequences,
                            full_input=trainer_config.full_input,
                            uncertainty_loss=args.uncertainty_loss,
                            best_model_path=best_checkpoint_path,
                            extra_module_lr_rate=args.extra_module_lr_rate,
                            no_persona=args.no_persona)

    if args.load_last:
        state_dict = torch.load(trainer_config.load_last, map_location=device)
        model_trainer.load_state_dict(state_dict)

    # helpers -----------------------------------------------------
    def external_metrics_func(full_references, full_predictions, epoch, metric=None, is_best=False):
        """Write references/predictions to the log dir and score them with ``nlp_metrics``.

        Args:
            full_references: reference strings, aligned with ``full_predictions``.
            full_predictions: predicted strings. This list is not modified.
            epoch: validation epoch index; ``-1`` selects the test-file names.
            metric: kept for interface compatibility; not used by this function.
            is_best: only consulted when ``epoch == -1``; picks
                ``test_predictions_file_best`` instead of ``test_predictions_file_last``.

        Returns:
            dict of scalar metrics plus list-valued metrics expanded into
            1-based keys such as ``bleu_1`` or ``corpus_div_2``.
        """
        if epoch == -1:
            references_file_path = os.path.join(writer.logdir, 'test_references_file')
            predictions_name = 'test_predictions_file_best' if is_best else 'test_predictions_file_last'
            predictions_file_path = os.path.join(writer.logdir, predictions_name)
        else:
            references_file_path = os.path.join(writer.logdir, trainer_config.eval_references_file)
            predictions_file_path = os.path.join(writer.logdir, trainer_config.eval_predictions_file + "_{}".format(epoch))
        # The references file name carries no epoch suffix, so it is written only
        # the first time and reused on later calls.
        if not os.path.exists(references_file_path):
            with open(references_file_path, 'w', encoding='utf-8') as f:
                f.write('\n'.join(full_references))
        # Dump of the raw predictions, overwritten on every call.
        with open(os.path.join(writer.logdir, 'tt.json'), 'w') as f:
            json.dump(full_predictions, f)

        # Patch a copy so the caller's list is left untouched. An empty final
        # prediction is replaced by a placeholder, presumably so the metric
        # scripts do not drop the trailing line and misalign it with the
        # references -- confirm against nlp_metrics. The emptiness guard avoids
        # an IndexError when there are no predictions at all.
        prediction_lines = list(full_predictions)
        if prediction_lines and len(prediction_lines[-1]) == 0:
            prediction_lines[-1] = 'a '
        with open(predictions_file_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(prediction_lines))

        (bleu, bleu_list, nist, nist_list, nist_bleu, nist_bleu_list, s_dist, c_dist, entropy, meteor,
         rouge_l, f1, avg_length) = nlp_metrics(references_file_path, predictions_file_path, root_path=log_dir)

        metrics = {'meteor': meteor, 'avg_len': avg_length, 'rouge-l': rouge_l, 'bleu': bleu, 'nist': nist,
                   'nist-bleu': nist_bleu, 'f1': f1}
        for name, values in (('bleu', bleu_list), ('nist', nist_list), ('nist_bleu', nist_bleu_list),
                             ('entropy', entropy), ('sentence_div', s_dist), ('corpus_div', c_dist)):
            for i, value in enumerate(values, 1):
                metrics['{}_{}'.format(name, i)] = value

        return metrics

    def save_func(epoch):
        if epoch != -1:
            torch.save(model_trainer.model.state_dict(), last_checkpoint_path)
            logger.info('Model on Epoch %d has been saved', epoch)

    def sample_text_func(epoch):
        n_samples = 0
        model_trainer.model.eval()
        samples_idxs = random.sample(range(len(valid_dataset)), n_samples)
        samples = [valid_dataset[idx] for idx in samples_idxs]
        for persona_info, dialog, target, _ in samples:
            contexts = [torch.tensor([c], dtype=torch.long, device=model_trainer.device) for c in [persona_info, dialog] if len(c) > 0]
            prediction = model_trainer.model.predict(contexts)[0]

            persona_info_str = tokenizer.ids2string(persona_info[1:-1])
            dialog_str = tokenizer.ids2string(dialog)
            dialog_str = dialog_str.replace(tokenizer.talker1_bos, '\n\t- ').replace(tokenizer.talker2_bos, '\n\t- ')
            dialog_str = dialog_str.replace(tokenizer.talker1_eos, '').replace(tokenizer.talker2_eos, '')
            target_str = tokenizer.ids2string(target[1:-1])
            prediction_str = tokenizer.ids2string(prediction)

            logger.info('\n')
            logger.info('Persona info:\n\t{}'.format(persona_info_str))
            logger.info('Dialog:{}'.format(dialog_str))
            logger.info('Target:\n\t{}'.format(target_str))
            logger.info('Prediction:\n\t{}'.format(prediction_str))

    def test_func(epoch):
        if (epoch+1) % trainer_config.test_period == 0:
            metric_funcs = {'f1_score': f1_score}
            model_trainer.test(metric_funcs, external_metrics_func, epoch)

    def f1_risk(predictions, targets):
        scores = f1_score(predictions, targets, average=False)
        assert all([0 <= s <= 1.0 for s in scores])
        return [1 - s for s in scores]

    def get_risk_metric_func(risk_metric):
        """ risk_metric selected in:
            f1, meteor, avg_len, nist_{1, 2, 3, 4}, entropy_{1, 2, 3, 4}, div_{1, 2}, bleu_{1, 2, 3, 4}
        """
        def external_metric_risk(predictions, targets):
            string_targets = list(tokenizer.ids2string(t) for t in targets)
            string_predictions = list(tokenizer.ids2string(t) for t in predictions)
            metrics = [external_metrics_func([t], [p], epoch=-1, metric=risk_metric) for p, t in zip(string_predictions, string_targets)]

            if any([s in risk_metric for s in ['entropy', 'nist', 'avg_len']]):
                return [-m for m in metrics]

            assert all([0 <= s <= 1.0 for s in metrics]), metrics

            return [1 - m for m in metrics]

        if risk_metric == 'f1':
            return f1_risk

        return external_metric_risk

    # helpers -----------------------------------------------------

    try:
        model_trainer.train(after_epoch_funcs=[save_func, sample_text_func, test_func],
                            risk_func=get_risk_metric_func(trainer_config.risk_metric))
    except (KeyboardInterrupt, Exception, RuntimeError) as e:
        if args.local_rank in [-1, 0]:
            torch.save(model_trainer.state_dict(), interrupt_checkpoint_path)
        raise e
예제 #4
0
def training_procedure(args, trainer_config, model, tokenizer, device, writer, logger, best_checkpoint_path,
                       last_checkpoint_path, interrupt_checkpoint_path, log_dir, test_data_type=None):
    """Build the datasets and the Trainer, then run training with per-epoch hooks.

    Args:
        args: parsed command-line options (dialog embeddings, mixup, pointer-generator and
            checkpoint-loading switches; ``local_rank`` is read if present).
        trainer_config: training configuration (dataset paths/caches, output file names,
            test period, risk metric, ...).
        model: model to train; ``model.n_pos_embeddings - 1`` bounds the dataset sequence lengths.
        tokenizer: tokenizer used by the datasets and to decode token ids into strings.
        device: torch device for the Trainer and for ``torch.load``.
        writer: summary writer whose ``logdir`` receives the reference/prediction files
            scored by the external metrics.
        logger: logger for progress messages.
        best_checkpoint_path: path where the Trainer keeps the best model.
        last_checkpoint_path: path where the model weights are saved after each epoch.
        interrupt_checkpoint_path: path where the trainer state is saved when training is
            interrupted by an exception or Ctrl-C.
        log_dir: root path passed to ``nlp_metrics``.
        test_data_type: data type of the test set; defaults to ``trainer_config.data_type``.

    Raises:
        Whatever interrupts ``model_trainer.train`` is re-raised after the interrupt
        checkpoint has been written (main process only).
    """
    logger.info("trainer config: {}".format(trainer_config))
    logger.info('loading datasets')
    max_lengths = model.n_pos_embeddings - 1  # A bit restrictive here
    train_dataset = FacebookDataset(trainer_config.train_datasets, tokenizer,
                                    max_lengths=max_lengths,
                                    dialog_embeddings=args.dialog_embeddings,
                                    cache=trainer_config.train_datasets_cache,
                                    use_start_end=False,
                                    augment=trainer_config.persona_augment,
                                    aug_syn_proba=trainer_config.persona_aug_syn_proba,
                                    limit_size=trainer_config.limit_train_size,
                                    max_history_size=trainer_config.max_history_size,
                                    data_type=trainer_config.data_type)
    valid_dataset = FacebookDataset(trainer_config.valid_datasets, tokenizer,
                                    max_lengths=max_lengths,
                                    dialog_embeddings=args.dialog_embeddings,
                                    cache=trainer_config.valid_datasets_cache,
                                    use_start_end=False,
                                    augment=False,
                                    aug_syn_proba=0.0,
                                    limit_size=trainer_config.limit_eval_size,
                                    max_history_size=trainer_config.max_history_size,
                                    data_type=trainer_config.data_type)
    if test_data_type is None:
        test_data_type = trainer_config.data_type
    test_dataset = FacebookDataset(trainer_config.test_datasets, tokenizer,
                                   max_lengths=max_lengths,
                                   dialog_embeddings=args.dialog_embeddings,
                                   cache=trainer_config.test_datasets_cache,
                                   use_start_end=False,
                                   augment=False,
                                   aug_syn_proba=0.0,
                                   limit_size=trainer_config.limit_eval_size,
                                   max_history_size=trainer_config.max_history_size,
                                   data_type=test_data_type)
    mixup_dataset = None
    if args.mixup:
        logger.info('Load Mixup neighbor dict')
        mixup_dataset = MixUpDataset(trainer_config.train_datasets, tokenizer, args.mixup_model_path,
                                     cache=trainer_config.mixup_cache, data_type=args.data_type,
                                     th=args.mixup_candidate_th)
    logger.info('train dataset {} valid dataset {} test dataset {}'
                .format(len(train_dataset), len(valid_dataset), len(test_dataset)))

    # Normal training uses the normal trainer.
    model_trainer = Trainer(model,
                            train_dataset,
                            trainer_config,
                            writer,
                            logger=logger,
                            valid_dataset=valid_dataset,
                            test_dataset=test_dataset,
                            n_jobs=trainer_config.n_jobs,
                            device=device,
                            ignore_idxs=tokenizer.all_special_ids,
                            evaluate_full_sequences=trainer_config.evaluate_full_sequences,
                            full_input=trainer_config.full_input,
                            best_model_path=best_checkpoint_path,
                            no_persona=args.no_persona,
                            mixup=args.mixup,
                            mixup_dataset=mixup_dataset,
                            mixup_ratio=args.mixup_ratio,
                            bert_mixup=args.bert_mixup,
                            replace=args.replace,
                            pointer_gen=args.pointer_gen)

    if args.load_last:
        state_dict = torch.load(trainer_config.load_last, map_location=device)
        model_trainer.load_state_dict(state_dict)

    # Metrics that external_metrics_func reports multiplied by 100 (see below).
    percent_scaled = ('meteor', 'rouge-l', 'f1', 'sentence_div', 'corpus_div')

    # helpers -----------------------------------------------------
    def external_metrics_func(full_references, full_predictions, epoch, is_best=False, metric=None):
        """Write references/predictions to disk and score them with ``nlp_metrics``.

        Args:
            full_references: reference strings, written one per line.
            full_predictions: prediction strings, written one per line; an empty last
                prediction is replaced in place by ``'a '``.
            epoch: training epoch, or -1 for the test run.
            is_best: for ``epoch == -1``, whether these are the best model's outputs.
            metric: when given, score a single risk sample. Dedicated scratch files are
                (re)written and only ``metrics[metric]`` is returned.

        Returns:
            dict of metric name -> value rounded to 6 digits, or the single value
            selected by ``metric``.
        """
        if metric is not None:
            # Risk scoring runs per sample: never reuse a stale references file and
            # never overwrite the test/eval output files.
            references_file_path = os.path.join(writer.logdir, 'risk_references_file')
            predictions_file_path = os.path.join(writer.logdir, 'risk_predictions_file')
        elif epoch == -1:
            references_file_path = os.path.join(writer.logdir, trainer_config.test_references_file)
            if is_best:
                predictions_file_path = os.path.join(writer.logdir, trainer_config.test_predictions_file_best)
            else:
                predictions_file_path = os.path.join(writer.logdir, trainer_config.test_predictions_file_last)
        else:
            references_file_path = os.path.join(writer.logdir, trainer_config.eval_references_file)
            predictions_file_path = os.path.join(writer.logdir,
                                                 trainer_config.eval_predictions_file + "_{}".format(epoch))

        # Test/eval references are written only once; risk references change every call.
        if metric is not None or not os.path.exists(references_file_path):
            with open(references_file_path, 'w', encoding='utf-8') as f:
                f.write('\n'.join(full_references))
        if metric is None:
            with open(os.path.join(writer.logdir, 'tt.json'), 'w') as f:
                json.dump(full_predictions, f)
        with open(predictions_file_path, 'w', encoding='utf-8') as f:
            # Presumably avoids an empty trailing line confusing the metric scripts.
            if full_predictions and len(full_predictions[-1]) == 0:
                full_predictions[-1] = 'a '
            f.write('\n'.join(full_predictions))

        bleu, bleu_list, nist, nist_list, nist_bleu, nist_bleu_list, s_dist, c_dist, entropy, meteor, \
            rouge_l, f1_value, avg_length = nlp_metrics(references_file_path, predictions_file_path,
                                                        root_path=log_dir)

        metrics = {'meteor': meteor * 100, 'avg_len': avg_length, 'rouge-l': rouge_l * 100, 'bleu': bleu,
                   'nist': nist, 'nist-bleu': nist_bleu, 'f1': f1_value * 100}
        for name, values in (('bleu', bleu_list), ('nist', nist_list), ('nist_bleu', nist_bleu_list),
                             ('entropy', entropy), ('sentence_div', s_dist), ('corpus_div', c_dist)):
            for i, m in enumerate(values, 1):
                if name in ('sentence_div', 'corpus_div'):
                    metrics['{}_{}'.format(name, i)] = m * 100
                else:
                    metrics['{}_{}'.format(name, i)] = m
        metrics = {k: round(v, 6) for k, v in metrics.items()}

        if metric is not None:
            return metrics[metric]
        return metrics

    def save_func(epoch):
        """Save the model weights to ``last_checkpoint_path`` after a real epoch."""
        if epoch != -1:
            torch.save(model_trainer.model.state_dict(), last_checkpoint_path)
            logger.info('Model on Epoch %d has been saved', epoch)

    def sample_text_func(epoch):
        """Log decoded validation samples (currently disabled: ``n_samples`` is 0)."""
        n_samples = 0
        model_trainer.model.eval()
        samples_idxs = random.sample(range(len(valid_dataset)), n_samples)
        samples = [valid_dataset[idx] for idx in samples_idxs]
        for persona_info, dialog, target, _ in samples:
            contexts = [torch.tensor([c], dtype=torch.long, device=model_trainer.device) for c in [persona_info, dialog]
                        if len(c) > 0]
            prediction = model_trainer.model.predict(contexts)[0]

            persona_info_str = tokenizer.ids2string(persona_info[1:-1])
            dialog_str = tokenizer.ids2string(dialog)
            dialog_str = dialog_str.replace(tokenizer.talker1_bos, '\n\t- ').replace(tokenizer.talker2_bos, '\n\t- ')
            dialog_str = dialog_str.replace(tokenizer.talker1_eos, '').replace(tokenizer.talker2_eos, '')
            target_str = tokenizer.ids2string(target[1:-1])
            prediction_str = tokenizer.ids2string(prediction)

            logger.info('\n')
            logger.info('Persona info:\n\t{}'.format(persona_info_str))
            logger.info('Dialog:{}'.format(dialog_str))
            logger.info('Target:\n\t{}'.format(target_str))
            logger.info('Prediction:\n\t{}'.format(prediction_str))

    def test_func(epoch):
        """Run the test pass every ``trainer_config.test_period`` epochs."""
        if (epoch + 1) % trainer_config.test_period == 0:
            metric_funcs = {'f1_score': f1_score}
            model_trainer.test(metric_funcs, external_metrics_func, epoch)

    def f1_risk(predictions, targets):
        """Return ``1 - F1`` per prediction/target pair."""
        scores = f1_score(predictions, targets, average=False)
        assert all([0 <= s <= 1.0 for s in scores])
        return [1 - s for s in scores]

    def get_risk_metric_func(risk_metric):
        """Return a per-sample risk function for ``risk_metric``.

        ``risk_metric`` is ``'f1'`` or a key produced by ``external_metrics_func``:
        meteor, rouge-l, avg_len, bleu, nist, nist-bleu, bleu_{i}, nist_{i}, nist_bleu_{i},
        entropy_{i}, sentence_div_{i}, corpus_div_{i}.
        Entropy/nist/avg_len risks are the negated scores; the others are ``1 - score``
        with percentage-scaled scores brought back to the 0-1 range first.
        """

        def external_metric_risk(predictions, targets):
            string_targets = [tokenizer.ids2string(t) for t in targets]
            string_predictions = [tokenizer.ids2string(t) for t in predictions]
            scores = [external_metrics_func([t], [p], epoch=-1, metric=risk_metric)
                      for p, t in zip(string_predictions, string_targets)]

            if any(s in risk_metric for s in ['entropy', 'nist', 'avg_len']):
                return [-m for m in scores]

            if risk_metric.startswith(percent_scaled):
                scores = [m / 100 for m in scores]
            assert all(0 <= s <= 1.0 for s in scores), scores

            return [1 - m for m in scores]

        if risk_metric == 'f1':
            return f1_risk

        return external_metric_risk

    # helpers -----------------------------------------------------

    try:
        model_trainer.train(after_epoch_funcs=[save_func, sample_text_func, test_func],
                            risk_func=get_risk_metric_func(trainer_config.risk_metric))
    except (KeyboardInterrupt, Exception):
        # Mirror the other entry points: keep the trainer state so a run can be resumed.
        if getattr(args, 'local_rank', -1) in [-1, 0]:
            try:
                torch.save(model_trainer.state_dict(), interrupt_checkpoint_path)
                logger.info('Interrupt checkpoint saved to %s', interrupt_checkpoint_path)
            except Exception:
                # Never let a failed save hide the original error.
                logger.exception('Failed to save interrupt checkpoint to %s', interrupt_checkpoint_path)
        raise
예제 #5
0
def main():
    """Train the dialog TransformerModel, optionally with distributed data parallelism.

    Reads the model/trainer configuration and parses ``--local_rank``; when it is not -1
    the CUDA device is selected and an NCCL process group is initialised. Without
    ``trainer_config.load_last`` the transformer is initialised from the pretrained
    weights at ``trainer_config.openai_parameters_dir``. Only the main process
    (local_rank -1 or 0) runs the per-epoch hooks and writes the interrupt checkpoint.
    """
    model_config = get_model_config_dialog()
    trainer_config = get_trainer_config_dialog()

    set_seed(trainer_config.seed)
    device = torch.device(trainer_config.device)
    # zrs
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1)
    args = parser.parse_args()
    distributed = (args.local_rank != -1)
    # Only one process should run the hooks and write shared checkpoint files.
    is_main_process = args.local_rank in [-1, 0]
    if distributed:
        print(args.local_rank)
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    vocab = myVocab(model_config.vocab_path)

    transformer = TransformerModel(
        n_layers=model_config.n_layers,
        n_embeddings=len(vocab),
        n_pos_embeddings=model_config.n_pos_embeddings,
        embeddings_size=model_config.embeddings_size,
        padding_idx=vocab.pad_id,
        n_heads=model_config.n_heads,
        dropout=model_config.dropout,
        embed_dropout=model_config.embed_dropout,
        attn_dropout=model_config.attn_dropout,
        ff_dropout=model_config.ff_dropout,
        bos_id=vocab.bos_id,
        eos_id=vocab.eos_id,
        max_seq_len=model_config.max_seq_len,
        beam_size=model_config.beam_size,
        length_penalty=model_config.length_penalty,
        n_segments=model_config.n_segments,
        annealing_topk=model_config.annealing_topk,
        temperature=model_config.temperature,
        annealing=model_config.annealing,
        diversity_coef=model_config.diversity_coef,
        diversity_groups=model_config.diversity_groups)

    if not trainer_config.load_last:
        openai_model = torch.load(trainer_config.openai_parameters_dir,
                                  map_location=device)
        openai_model.pop('decoder.pre_softmax.weight')
        # Drop every key that has a dot-separated component equal to a layer index in
        # [n_layers, 12) (i.e. layers the smaller model does not have), and strip the
        # first key component from the rest so names match transformer_module.
        for key in list(openai_model.keys()):
            parts = key.split('.')
            keep = not any(str(j) in parts for j in range(model_config.n_layers, 12))
            if keep:
                openai_model[key.split('.', 1)[1]] = openai_model.pop(key)
            else:
                print(key)
                openai_model.pop(key)
        transformer.transformer_module.load_state_dict(openai_model,
                                                       strict=True)
        # load_openai_weights_chinese(transformer.transformer_module, trainer_config.openai_parameters_dir)
        print('OpenAI weights chinese loaded from {}'.format(
            trainer_config.openai_parameters_dir))

    train_dataset = S2sDataset_dialog(trainer_config.train_datasets, vocab,
                                      transformer.n_pos_embeddings - 1)
    test_dataset = S2sDataset_dialog(trainer_config.test_datasets, vocab,
                                     transformer.n_pos_embeddings - 1)

    model_trainer = Trainer(
        transformer,
        train_dataset,
        test_dataset,
        batch_size=trainer_config.batch_size,
        batch_split=trainer_config.batch_split,
        lr=trainer_config.lr,
        lr_warmup=trainer_config.lr_warmup,
        lm_weight=trainer_config.lm_weight,
        risk_weight=trainer_config.risk_weight,
        n_jobs=trainer_config.n_jobs,
        clip_grad=trainer_config.clip_grad,
        # label_smoothing=trainer_config.label_smoothing,
        device=device,
        ignore_idxs=vocab.special_tokens_ids,
        distributed=distributed)
    if distributed:
        model_trainer.model.transformer_module = DistributedDataParallel(
            model_trainer.model.transformer_module,
            device_ids=[args.local_rank],
            output_device=args.local_rank)

    start_epoch = 0
    # NOTE(review): the resume epoch is hard-coded; with init_epoch = 0 the checkpoint
    # loaded below is last_checkpoint_path + "-1". Edit init_epoch before resuming.
    init_epoch = 0

    if trainer_config.load_last:
        state_dict = torch.load(trainer_config.last_checkpoint_path +
                                str(init_epoch - 1),
                                map_location=device)
        model_trainer.load_state_dict(state_dict)
        # start_epoch = int(cop.sub('', trainer_config.last_checkpoint_path.split('/')[-1])) + 1
        start_epoch = init_epoch
        print('Weights loaded from {}'.format(
            trainer_config.last_checkpoint_path + str(init_epoch - 1)))

    # helpers -----------------------------------------------------
    def save_func(epoch):
        """Save the trainer state as the latest and per-epoch checkpoint.

        The per-epoch checkpoint from 100 epochs earlier is deleted if present.
        """
        dirs = '/'.join(trainer_config.last_checkpoint_path.split('/')[:-1])
        if not os.path.exists(dirs):
            os.makedirs(dirs)
        torch.save(model_trainer.state_dict(),
                   trainer_config.last_checkpoint_path)
        torch.save(model_trainer.state_dict(),
                   trainer_config.last_checkpoint_path + str(epoch))
        if os.path.exists(trainer_config.last_checkpoint_path +
                          str(epoch - 100)):
            os.remove(trainer_config.last_checkpoint_path + str(epoch - 100))

    def sample_text_func(epoch):
        """Print up to 5 random test samples with the model's predictions."""
        # random.sample raises ValueError if asked for more items than exist.
        n_samples = min(5, len(test_dataset))
        samples_idxs = random.sample(range(len(test_dataset)), n_samples)
        samples = [test_dataset[idx] for idx in samples_idxs]
        for source, target in samples:
            contexts = [
                torch.tensor([c],
                             dtype=torch.long,
                             device=model_trainer.device) for c in [source]
                if len(c) > 0
            ]
            prediction = model_trainer.model.predict(contexts)[0]
            source_str = vocab.ids2string(source)
            target_str = vocab.ids2string(target[1:-1])
            prediction_str = vocab.ids2string(prediction)
            print('\n')
            print('Source:{}'.format(source_str))
            print('Target:\n\t{}'.format(target_str))
            print('Prediction:\n\t{}'.format(prediction_str))

    def test_func(epoch):
        """Run the test pass every ``trainer_config.test_period`` epochs."""
        if (epoch + 1) % trainer_config.test_period == 0:
            metric_funcs = {'f1_score': f1_score}
            model_trainer.test(metric_funcs)

    def f1_risk(predictions, targets):
        """Return ``1 - F1`` per prediction/target pair."""
        scores = f1_score(predictions, targets, average=False)
        return [1 - s for s in scores]

    # helpers -----------------------------------------------------

    try:
        if is_main_process:
            model_trainer.train(
                start_epoch,
                trainer_config.n_epochs,
                after_epoch_funcs=[save_func, sample_text_func, test_func],
                risk_func=f1_risk)
        else:
            model_trainer.train(start_epoch, trainer_config.n_epochs)
    except (KeyboardInterrupt, Exception):
        # Every rank would otherwise write the same file concurrently.
        if is_main_process:
            torch.save(model_trainer.state_dict(),
                       trainer_config.interrupt_checkpoint_path)
        raise
예제 #6
0
def get_trainer():
    """Create a ready-to-train Trainer for the BPE-vocabulary TransformerModel.

    The transformer starts from the pretrained OpenAI weights unless
    ``trainer_config.load_last`` is set, in which case the whole trainer state is
    restored from ``trainer_config.last_checkpoint_path`` instead.

    Returns:
        The constructed Trainer.
    """
    mcfg = get_model_config()
    tcfg = get_trainer_config()

    set_seed(tcfg.seed)
    device = torch.device(tcfg.device)

    bpe_vocab = BPEVocab.from_files(mcfg.bpe_vocab_path, mcfg.bpe_codes_path)

    model = TransformerModel(n_layers=mcfg.n_layers,
                             n_embeddings=len(bpe_vocab),
                             n_pos_embeddings=mcfg.n_pos_embeddings,
                             embeddings_size=mcfg.embeddings_size,
                             padding_idx=bpe_vocab.pad_id,
                             n_heads=mcfg.n_heads,
                             dropout=mcfg.dropout,
                             embed_dropout=mcfg.embed_dropout,
                             attn_dropout=mcfg.attn_dropout,
                             ff_dropout=mcfg.ff_dropout,
                             bos_id=bpe_vocab.bos_id,
                             eos_id=bpe_vocab.eos_id,
                             max_seq_len=mcfg.max_seq_len,
                             beam_size=mcfg.beam_size,
                             length_penalty=mcfg.length_penalty,
                             n_segments=mcfg.n_segments,
                             annealing_topk=mcfg.annealing_topk,
                             annealing=mcfg.annealing,
                             diversity_coef=mcfg.diversity_coef,
                             diversity_groups=mcfg.diversity_groups)

    resume = tcfg.load_last
    if not resume:
        pretrained_dir = tcfg.openai_parameters_dir
        load_openai_weights(model.transformer_module,
                            pretrained_dir,
                            n_special_tokens=bpe_vocab.n_special_tokens)
        print('OpenAI weights loaded from {}'.format(pretrained_dir))

    # Both datasets share the model's positional-embedding budget.
    max_len = model.n_pos_embeddings - 1
    train_dataset = FacebookDataset(tcfg.train_datasets, bpe_vocab, max_len)
    test_dataset = FacebookDataset(tcfg.test_datasets, bpe_vocab, max_len)

    trainer = Trainer(model,
                      train_dataset,
                      test_dataset,
                      batch_size=tcfg.batch_size,
                      batch_split=tcfg.batch_split,
                      lr=tcfg.lr,
                      lr_warmup=tcfg.lr_warmup,
                      lm_weight=tcfg.lm_weight,
                      risk_weight=tcfg.risk_weight,
                      n_jobs=tcfg.n_jobs,
                      clip_grad=tcfg.clip_grad,
                      device=device,
                      ignore_idxs=bpe_vocab.special_tokens_ids)

    if resume:
        checkpoint_path = tcfg.last_checkpoint_path
        trainer.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print('Weights loaded from {}'.format(checkpoint_path))

    return trainer