Пример #1
0
    def __init__(self, opt, shared=None):
        """Configure the agent from its options and obtain the model.

        Args:
            opt: option mapping; handed to the parent constructor and then
                read back through ``self.opt``.
            shared: optional mapping from an already constructed agent. When
                given, its ``'model'`` and ``'retrieval'`` entries are reused
                and no checkpoint is loaded.
        """
        super(TransformerAgent, self).__init__(opt, shared)

        cuda_allowed = not self.opt.get('no_cuda')
        self.use_cuda = cuda_allowed and torch.cuda.is_available()
        if self.use_cuda:
            torch.cuda.set_device(self.opt['gpu'])

        # Gradients are switched off; the model is only put in eval mode below.
        torch.set_grad_enabled(False)

        config = get_model_config()
        self.vocab = BPEVocab.from_files(config.bpe_vocab_path, config.bpe_codes_path)
        self.reply_checker = ReplyChecker(
            correct_generative=self.opt['correct_generative'],
            split_into_sentences=self.opt['split_into_sentences'])

        # Options mirrored onto the instance under the same attribute names.
        for option in ('replace_repeat', 'replace_ngram', 'ngram_size',
                       'detokenize', 'emoji_prob', 'add_questions',
                       'beam_size', 'clean_emoji', 'check_grammar'):
            setattr(self, option, self.opt[option])

        # Values noted for the decoding options:
        #   max_seq_len=128, beam_size=1, diversity_coef=0,
        #   diversity_groups=1, annealing_topk=None, annealing=0,
        #   length_penalty=0.6

        # Sanity checks on the decoding options.
        topk = self.opt['annealing_topk']
        if topk is not None:
            assert topk >= self.opt['beam_size']

        assert self.opt['diversity_coef'] >= 0
        assert self.opt['beam_size'] % self.opt['diversity_groups'] == 0

        if shared is not None:
            # Reuse what another agent instance has already built.
            self.model = shared['model']
            self.retrieval_bot = shared['retrieval']
        else:
            self.model = TransformerModel(
                n_layers=config.n_layers,
                n_embeddings=len(self.vocab),
                n_pos_embeddings=config.n_pos_embeddings,
                embeddings_size=config.embeddings_size,
                padding_idx=self.vocab.pad_id,
                n_heads=config.n_heads,
                dropout=config.dropout,
                embed_dropout=config.embed_dropout,
                attn_dropout=config.attn_dropout,
                ff_dropout=config.ff_dropout,
                bos_id=self.vocab.bos_id,
                eos_id=self.vocab.eos_id,
                max_seq_len=self.opt['max_seq_len'],
                beam_size=self.opt['beam_size'],
                length_penalty=self.opt['length_penalty'],
                n_segments=config.n_segments,
                sample=self.opt['sample'],
                annealing_topk=self.opt['annealing_topk'],
                annealing=self.opt['annealing'],
                diversity_coef=self.opt['diversity_coef'],
                diversity_groups=self.opt['diversity_groups'])
            self.retrieval_bot = RetrievalBot()

            # Load onto CPU storage first; the move to GPU happens below.
            checkpoint = torch.load(config.checkpoint_path,
                                    map_location=lambda storage, loc: storage)
            # A checkpoint may wrap the weights under a 'model' key.
            weights = checkpoint['model'] if 'model' in checkpoint else checkpoint

            self.model.load_state_dict(weights)
            print('Weights loaded from {}'.format(config.checkpoint_path))

            if self.use_cuda:
                self.model = self.model.cuda()

            self.model.eval()

        self.reset()
Пример #2
0
def main():
    """Train the dialogue transformer; entry point for single and distributed runs.

    Command-line arguments:
        --local_rank: process rank for distributed training (-1 disables it).
        --server_ip / --server_port: when both are set on the main process,
            wait for a ptvsd debugger to attach before continuing.

    The model and trainer settings come from ``get_model_config`` and
    ``get_trainer_config``. Checkpoints, config dumps and evaluation files
    are written under the summary writer's log directory. If training raises
    (including ``KeyboardInterrupt``), the main process saves an interrupt
    checkpoint and the exception is re-raised.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank',
                        type=int,
                        default=-1,
                        help="Distributed training.")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Used for debugging on GPU machine.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Used for debugging on GPU machine.")
    args = parser.parse_args()

    # Only the main process (rank -1 or 0) logs at INFO level.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.ERROR)
    logger = logging.getLogger(__file__)
    if args.server_ip and args.server_port and args.local_rank in [-1, 0]:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    model_config = get_model_config()
    trainer_config = get_trainer_config()

    # Log only on main process
    if args.local_rank not in [-1, 0]:
        # Redirect stdout to a per-rank file; it stays open for the whole run.
        sys.stdout = open(f"./runs/log_distributed_{args.local_rank}",
                          "w")  # dump stdout
        writer = DummyWriter()
    else:
        writer = SummaryWriter(comment=trainer_config.writer_comment)

    logger.info("model config: {}".format(model_config))
    logger.info("trainer config: {}".format(trainer_config))
    log_dir = writer.log_dir
    interrupt_checkpoint_path = os.path.join(
        log_dir, trainer_config.interrupt_checkpoint_path)
    last_checkpoint_path = os.path.join(log_dir,
                                        trainer_config.last_checkpoint_path)
    logger.info(
        "Logging to {}".format(log_dir)
    )  # Let's save everything on an experiment in the ./runs/XXX/directory
    if args.local_rank in [-1, 0]:
        with open(os.path.join(log_dir, "model_config.json"), "w") as f:
            json.dump(model_config, f)
        with open(os.path.join(log_dir, "trainer_config.json"), "w") as f:
            json.dump(trainer_config, f)

    set_seed(trainer_config.seed)
    device = torch.device(trainer_config.device)

    vocab = BPEVocab.from_files(model_config.bpe_vocab_path,
                                model_config.bpe_codes_path,
                                zero_shot=trainer_config.zero_shot)

    transformer = TransformerModel(
        n_layers=model_config.n_layers,
        n_embeddings=len(vocab),
        n_pos_embeddings=model_config.n_pos_embeddings,
        embeddings_size=model_config.embeddings_size,
        padding_idx=vocab.pad_id,
        n_heads=model_config.n_heads,
        dropout=model_config.dropout,
        embed_dropout=model_config.embed_dropout,
        attn_dropout=model_config.attn_dropout,
        ff_dropout=model_config.ff_dropout,
        normalize_embeddings=model_config.normalize_embeddings,
        bos_id=vocab.bos_id,
        eos_id=vocab.eos_id,
        sent_dialog_id=vocab.sent_dialog_id,
        max_seq_len=model_config.max_seq_len,
        beam_size=model_config.beam_size,
        length_penalty=model_config.length_penalty,
        n_segments=model_config.n_segments,
        annealing_topk=model_config.annealing_topk,
        annealing=model_config.annealing,
        diversity_coef=model_config.diversity_coef,
        diversity_groups=model_config.diversity_groups,
        multiple_choice_head=model_config.multiple_choice_head,
        constant_embedding=model_config.constant_embedding,
        single_input=model_config.single_input,
        dialog_embeddings=model_config.dialog_embeddings,
        share_models=model_config.share_models,
        successive_attention=model_config.successive_attention,
        sparse_embeddings=model_config.sparse_embeddings,
        shared_attention=model_config.shared_attention,
        bs_temperature=model_config.bs_temperature,
        bs_nucleus_p=model_config.bs_nucleus_p,
        vocab=None)  # for beam search debugging

    # Start from the OpenAI weights unless a trainer checkpoint is resumed.
    if not trainer_config.load_last:
        load_openai_weights(transformer.transformer_module,
                            trainer_config.openai_parameters_dir,
                            n_special_tokens=vocab.n_special_tokens)
        if not model_config.share_models:
            load_openai_weights(transformer.encoder_module,
                                trainer_config.openai_parameters_dir,
                                n_special_tokens=vocab.n_special_tokens)
        logger.info('OpenAI weights loaded from {}, model shared: {}'.format(
            trainer_config.openai_parameters_dir, model_config.share_models))

    logger.info('loading datasets')
    train_dataset = FacebookDataset(
        trainer_config.train_datasets,
        vocab,
        max_lengths=(transformer.n_pos_embeddings - 1) //
        (3 if model_config.single_input else 1),  # A bit restrictive here
        dialog_embeddings=model_config.dialog_embeddings,
        cache=trainer_config.train_datasets_cache,
        use_start_end=model_config.use_start_end,
        negative_samples=trainer_config.negative_samples,
        augment=trainer_config.persona_augment,
        aug_syn_proba=trainer_config.persona_aug_syn_proba,
        limit_size=trainer_config.limit_train_size)
    test_dataset = FacebookDataset(
        trainer_config.test_datasets,
        vocab,
        max_lengths=(transformer.n_pos_embeddings - 1) //
        (3 if model_config.single_input else 1),  # A bit restrictive here
        dialog_embeddings=model_config.dialog_embeddings,
        cache=trainer_config.test_datasets_cache,
        use_start_end=model_config.use_start_end,
        negative_samples=-1,  # Keep all negative samples
        augment=False,
        aug_syn_proba=0.0,
        limit_size=trainer_config.limit_eval_size)
    # Report both dataset sizes (the test size was previously logged as the
    # dataset object itself rather than its length).
    logger.info(
        f'train dataset {len(train_dataset)} test dataset {len(test_dataset)}')

    if args.local_rank != -1:

        torch.cuda.set_device(args.local_rank)
        device = torch.device('cuda', args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        transformer.distribute(device)

    model_trainer = Trainer(
        transformer,
        train_dataset,
        writer,
        test_dataset,
        train_batch_size=trainer_config.train_batch_size,
        batch_split=trainer_config.batch_split,
        test_batch_size=trainer_config.test_batch_size,
        lr=trainer_config.lr,
        lr_warmup=trainer_config.lr_warmup,
        weight_decay=trainer_config.weight_decay,
        s2s_weight=trainer_config.s2s_weight,
        lm_weight=trainer_config.lm_weight,
        risk_weight=trainer_config.risk_weight,
        hits_weight=trainer_config.hits_weight,
        single_input=model_config.single_input,
        n_jobs=trainer_config.n_jobs,
        clip_grad=trainer_config.clip_grad,
        device=device,
        ignore_idxs=vocab.special_tokens_ids,
        local_rank=args.local_rank,
        apex_level=model_config.apex_level,
        apex_loss_scale=trainer_config.apex_loss_scale,
        linear_schedule=trainer_config.linear_schedule,
        n_epochs=trainer_config.n_epochs,
        evaluate_full_sequences=trainer_config.evaluate_full_sequences)

    # Here load_last doubles as the checkpoint path to resume from.
    if trainer_config.load_last:
        state_dict = torch.load(trainer_config.load_last, map_location=device)
        model_trainer.load_state_dict(state_dict)
        logger.info('Weights loaded from {}'.format(trainer_config.load_last))

    # helpers -----------------------------------------------------
    def external_metrics_func(full_references,
                              full_predictions,
                              epoch,
                              metric=None):
        """Write references/predictions to files and score them.

        The files are placed in ``writer.log_dir`` with ``_<epoch>`` appended
        to the configured file names. When ``metric`` is given, only that
        metric is computed via ``specified_nlp_metric`` and returned;
        otherwise a dict of all metrics from ``nlp_metrics`` is returned.
        """
        references_file_path = os.path.join(
            writer.log_dir,
            trainer_config.eval_references_file + "_{}".format(epoch))
        predictions_file_path = os.path.join(
            writer.log_dir,
            trainer_config.eval_predictions_file + "_{}".format(epoch))
        # str.join already yields text in Python 3; the Python 2 ``unicode``
        # builtin does not exist there and would raise NameError.
        with open(references_file_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(full_references))
        with open(predictions_file_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(full_predictions))

        if metric is not None:
            return specified_nlp_metric([references_file_path],
                                        predictions_file_path, metric)

        nist, bleu, meteor, entropy, div, avg_len = nlp_metrics(
            [references_file_path], predictions_file_path)

        metrics = {'meteor': meteor, 'avg_len': avg_len}
        # Sequence-valued metrics are flattened to '<name>_<i>' keys, i from 1.
        for name, values in (('nist', nist), ('entropy', entropy),
                             ('div', div), ('bleu', bleu)):
            for i, m in enumerate(values, 1):
                metrics['{}_{}'.format(name, i)] = m

        return metrics

    def save_func(epoch):
        """Save the trainer state to the last-checkpoint path (skipped for epoch -1)."""
        if epoch != -1:
            torch.save(model_trainer.state_dict(), last_checkpoint_path)

    def sample_text_func(epoch):
        """Log model predictions for randomly chosen test examples."""
        # With n_samples = 0 nothing is sampled or logged; eval() still runs.
        n_samples = 0
        model_trainer.model.eval()
        samples_idxs = random.sample(range(len(test_dataset)), n_samples)
        samples = [test_dataset[idx] for idx in samples_idxs]
        for persona_info, dialog, target, _ in samples:
            # Empty contexts are dropped before prediction.
            contexts = [
                torch.tensor([c],
                             dtype=torch.long,
                             device=model_trainer.device)
                for c in [persona_info, dialog] if len(c) > 0
            ]
            prediction = model_trainer.model.predict(contexts)[0]

            persona_info_str = vocab.ids2string(persona_info[1:-1])
            dialog_str = vocab.ids2string(dialog)
            dialog_str = dialog_str.replace(vocab.talker1_bos,
                                            '\n\t- ').replace(
                                                vocab.talker2_bos, '\n\t- ')
            dialog_str = dialog_str.replace(vocab.talker1_eos,
                                            '').replace(vocab.talker2_eos, '')
            target_str = vocab.ids2string(target[1:-1])
            prediction_str = vocab.ids2string(prediction)

            logger.info('\n')
            logger.info('Persona info:\n\t{}'.format(persona_info_str))
            logger.info('Dialog:{}'.format(dialog_str))
            logger.info('Target:\n\t{}'.format(target_str))
            logger.info('Prediction:\n\t{}'.format(prediction_str))

    def test_func(epoch):
        """Run the trainer's test pass every ``test_period`` epochs."""
        if (epoch + 1) % trainer_config.test_period == 0:
            metric_funcs = {'f1_score': f1_score}
            model_trainer.test(metric_funcs, external_metrics_func, epoch)

    def f1_risk(predictions, targets):
        """Return per-sample risk ``1 - F1``."""
        scores = f1_score(predictions, targets, average=False)
        assert all(0 <= s <= 1.0 for s in scores)
        return [1 - s for s in scores]

    def get_risk_metric_func(risk_metric):
        """ risk_metric selected in:
            f1, meteor, avg_len, nist_{1, 2, 3, 4}, entropy_{1, 2, 3, 4}, div_{1, 2}, bleu_{1, 2, 3, 4}
        """
        def external_metric_risk(predictions, targets):
            """Score each (prediction, target) pair with the external metric."""
            string_targets = [vocab.ids2string(t) for t in targets]
            string_predictions = [vocab.ids2string(t) for t in predictions]
            # epoch=-1 makes every pair reuse the same '..._-1' scratch files.
            metrics = [
                external_metrics_func([t], [p], epoch=-1, metric=risk_metric)
                for p, t in zip(string_predictions, string_targets)
            ]

            # These metrics are negated rather than mapped through 1 - m.
            if any(s in risk_metric for s in ['entropy', 'nist', 'avg_len']):
                return [-m for m in metrics]

            assert all(0 <= s <= 1.0 for s in metrics), metrics

            return [1 - m for m in metrics]

        if risk_metric == 'f1':
            return f1_risk

        return external_metric_risk

    # helpers -----------------------------------------------------

    try:
        model_trainer.train(
            after_epoch_funcs=[save_func, sample_text_func, test_func],
            risk_func=get_risk_metric_func(trainer_config.risk_metric))
    except (KeyboardInterrupt, Exception):
        # Save progress on the main process, then propagate the original
        # exception with its traceback intact.
        if args.local_rank in [-1, 0]:
            torch.save(model_trainer.state_dict(), interrupt_checkpoint_path)
        raise
Пример #3
0
                else:
                    dialog_idx = int(line[:space_idx])

                if int(dialog_idx) == 1:
                    data.append({'persona_info': [], 'dialog': []})

                dialog_line = line[space_idx + 1:].split('\t')
                dialog_line = [l.strip() for l in dialog_line]

                if dialog_line[0].startswith('your persona:'):
                    persona_info = dialog_line[0].replace('your persona: ', '')
                    data[-1]['persona_info'].append(persona_info)

                elif len(dialog_line) > 1:
                    data[-1]['dialog'].append(dialog_line[0])
                    data[-1]['dialog'].append(dialog_line[1])
def make_dataset(data, vocab, max_lengths):
    """Convert parsed chats into ``(persona_ids, dialog_ids)`` tuples.

    Args:
        data: iterable of mappings with ``'persona_info'`` and ``'dialog'``
            lists of strings.
        vocab: object whose ``string2ids`` turns one string into ids.
        max_lengths: accepted for interface compatibility; not used here.

    Returns:
        A list with one ``(persona_ids, dialog_ids)`` tuple per chat. A
        dialog with an odd number of utterances loses its last utterance.
    """
    def encode(strings):
        return [vocab.string2ids(text) for text in strings]

    examples = []
    for chat in data:
        persona_ids = encode(chat['persona_info'])
        dialog_ids = encode(chat['dialog'])

        # Keep an even number of utterances, dropping a trailing odd one.
        paired_len = len(dialog_ids) - len(dialog_ids) % 2
        examples.append((persona_ids, dialog_ids[:paired_len]))

    return examples

# Build the BPE vocabulary, then turn the parsed chats in `data` into id
# sequences.
# NOTE(review): the same file is passed for both arguments of from_files; other
# call sites in this file pass distinct paths (model_config.bpe_vocab_path and
# model_config.bpe_codes_path) - confirm whether the second argument should
# point at the BPE codes file instead.
vocab = BPEVocab.from_files('./parameters/bpe.vocab', './parameters/bpe.vocab')
# 511 is passed as max_lengths; make_dataset as defined above does not read it.
data_updated =make_dataset(data, vocab, 511)
Пример #4
0
def main():
    """Train the dialogue transformer from the model and trainer configs.

    Builds the vocabulary, model, datasets and trainer, optionally resumes
    from ``trainer_config.last_checkpoint_path``, and runs training for
    ``trainer_config.n_epochs`` epochs. After every epoch the trainer state
    is saved, a few test samples are printed, and the test pass runs every
    ``test_period`` epochs. If training raises (including
    ``KeyboardInterrupt``), an interrupt checkpoint is saved and the
    exception is re-raised.
    """
    model_config = get_model_config()
    trainer_config = get_trainer_config()

    set_seed(trainer_config.seed)
    device = torch.device(trainer_config.device)

    vocab = BPEVocab.from_files(model_config.bpe_vocab_path,
                                model_config.bpe_codes_path)

    transformer = TransformerModel(
        n_layers=model_config.n_layers,
        n_embeddings=len(vocab),
        n_pos_embeddings=model_config.n_pos_embeddings,
        embeddings_size=model_config.embeddings_size,
        padding_idx=vocab.pad_id,
        n_heads=model_config.n_heads,
        dropout=model_config.dropout,
        embed_dropout=model_config.embed_dropout,
        attn_dropout=model_config.attn_dropout,
        ff_dropout=model_config.ff_dropout,
        bos_id=vocab.bos_id,
        eos_id=vocab.eos_id,
        max_seq_len=model_config.max_seq_len,
        beam_size=model_config.beam_size,
        length_penalty=model_config.length_penalty,
        n_segments=model_config.n_segments,
        annealing_topk=model_config.annealing_topk,
        annealing=model_config.annealing,
        diversity_coef=model_config.diversity_coef,
        diversity_groups=model_config.diversity_groups)

    # Start from the OpenAI weights unless the last checkpoint is resumed.
    if not trainer_config.load_last:
        load_openai_weights(transformer.transformer_module,
                            trainer_config.openai_parameters_dir,
                            n_special_tokens=vocab.n_special_tokens)
        print('OpenAI weights loaded from {}'.format(
            trainer_config.openai_parameters_dir))

    train_dataset = FacebookDataset(trainer_config.train_datasets, vocab,
                                    transformer.n_pos_embeddings - 1)
    test_dataset = FacebookDataset(trainer_config.test_datasets, vocab,
                                   transformer.n_pos_embeddings - 1)

    model_trainer = Trainer(transformer,
                            train_dataset,
                            test_dataset,
                            batch_size=trainer_config.batch_size,
                            batch_split=trainer_config.batch_split,
                            lr=trainer_config.lr,
                            lr_warmup=trainer_config.lr_warmup,
                            lm_weight=trainer_config.lm_weight,
                            risk_weight=trainer_config.risk_weight,
                            n_jobs=trainer_config.n_jobs,
                            clip_grad=trainer_config.clip_grad,
                            device=device,
                            ignore_idxs=vocab.special_tokens_ids)

    if trainer_config.load_last:
        state_dict = torch.load(trainer_config.last_checkpoint_path,
                                map_location=device)
        model_trainer.load_state_dict(state_dict)
        print('Weights loaded from {}'.format(
            trainer_config.last_checkpoint_path))

    # helpers -----------------------------------------------------
    def save_func(epoch):
        """Save the trainer state to the last-checkpoint path."""
        torch.save(model_trainer.state_dict(),
                   trainer_config.last_checkpoint_path)

    def sample_text_func(epoch):
        """Print model predictions for up to five random test examples."""
        # Clamp to the dataset size: random.sample raises ValueError when
        # asked for more items than the population holds.
        n_samples = min(5, len(test_dataset))
        samples_idxs = random.sample(range(len(test_dataset)), n_samples)
        samples = [test_dataset[idx] for idx in samples_idxs]
        for persona_info, dialog, target in samples:
            # Empty contexts are dropped before prediction.
            contexts = [
                torch.tensor([c],
                             dtype=torch.long,
                             device=model_trainer.device)
                for c in [persona_info, dialog] if len(c) > 0
            ]
            prediction = model_trainer.model.predict(contexts)[0]

            persona_info_str = vocab.ids2string(persona_info[1:-1])
            dialog_str = vocab.ids2string(dialog)
            dialog_str = dialog_str.replace(vocab.talker1_bos,
                                            '\n\t- ').replace(
                                                vocab.talker2_bos, '\n\t- ')
            dialog_str = dialog_str.replace(vocab.talker1_eos,
                                            '').replace(vocab.talker2_eos, '')
            target_str = vocab.ids2string(target[1:-1])
            prediction_str = vocab.ids2string(prediction)

            print('\n')
            print('Persona info:\n\t{}'.format(persona_info_str))
            print('Dialog:{}'.format(dialog_str))
            print('Target:\n\t{}'.format(target_str))
            print('Prediction:\n\t{}'.format(prediction_str))

    def test_func(epoch):
        """Run the trainer's test pass every ``test_period`` epochs."""
        if (epoch + 1) % trainer_config.test_period == 0:
            metric_funcs = {'f1_score': f1_score}
            model_trainer.test(metric_funcs)

    def f1_risk(predictions, targets):
        """Return per-sample risk ``1 - F1``."""
        scores = f1_score(predictions, targets, average=False)
        return [1 - s for s in scores]

    # helpers -----------------------------------------------------

    try:
        model_trainer.train(
            trainer_config.n_epochs,
            after_epoch_funcs=[save_func, sample_text_func, test_func],
            risk_func=f1_risk)
    except (KeyboardInterrupt, Exception):
        # Save progress, then propagate the original exception with its
        # traceback intact.
        torch.save(model_trainer.state_dict(),
                   trainer_config.interrupt_checkpoint_path)
        raise
    def __init__(self, opt, shared=None):
        """Configure the agent from its options and obtain the model.

        Args:
            opt: option mapping; handed to the parent constructor and then
                read back through ``self.opt``.
            shared: optional mapping from an already constructed agent. When
                given, its ``'model'`` entry is reused and no checkpoint is
                loaded.
        """
        super(TransformerAgent, self).__init__(opt, shared)

        self.use_cuda = not self.opt.get('no_cuda') and torch.cuda.is_available()
        if self.use_cuda:
            torch.cuda.set_device(self.opt['gpu'])

        # Gradients are switched off; the model is only put in eval mode below.
        torch.set_grad_enabled(False)

        model_config = get_model_config()
        # The vocabulary is loaded once (a second, identical load was removed).
        self.vocab = BPEVocab.from_files(model_config.bpe_vocab_path, model_config.bpe_codes_path)

        self.dialog_embeddings = model_config.dialog_embeddings
        self.use_start_end = model_config.use_start_end
        self.single_input = model_config.single_input
        self.apex_level = model_config.apex_level

        # 'max_seq_len': 128,
        # 'beam_size': 1,
        # 'diversity_coef': 0,
        # 'diversity_groups': 1,
        # 'annealing_topk': None,
        # 'annealing': 0,
        # 'length_penalty': 0.6,

        # Sanity checks on the decoding options.
        if self.opt['annealing_topk'] is not None:
            assert self.opt['annealing_topk'] > self.opt['beam_size']

        assert self.opt['diversity_coef'] >= 0
        assert self.opt['beam_size'] % self.opt['diversity_groups'] == 0

        if shared is None:
            self.model = TransformerModel(n_layers=model_config.n_layers,
                                          n_embeddings=len(self.vocab),
                                          n_pos_embeddings=model_config.n_pos_embeddings,
                                          embeddings_size=model_config.embeddings_size,
                                          padding_idx=self.vocab.pad_id,
                                          n_heads=model_config.n_heads,
                                          dropout=model_config.dropout,
                                          embed_dropout=model_config.embed_dropout,
                                          attn_dropout=model_config.attn_dropout,
                                          ff_dropout=model_config.ff_dropout,
                                          bos_id=self.vocab.bos_id,
                                          eos_id=self.vocab.eos_id,
                                          sent_dialog_id=self.vocab.sent_dialog_id,
                                          max_seq_len=self.opt['max_seq_len'],
                                          beam_size=self.opt['beam_size'],
                                          length_penalty=self.opt['length_penalty'],
                                          n_segments=model_config.n_segments,
                                          sample=self.opt['sample'],
                                          annealing_topk=self.opt['annealing_topk'],
                                          annealing=self.opt['annealing'],
                                          diversity_coef=self.opt['diversity_coef'],
                                          diversity_groups=self.opt['diversity_groups'],
                                          normalize_embeddings=model_config.normalize_embeddings,
                                          multiple_choice_head=model_config.multiple_choice_head,
                                          constant_embedding=model_config.constant_embedding,
                                          vocab=self.vocab,
                                          single_input=model_config.single_input,
                                          dialog_embeddings=model_config.dialog_embeddings,
                                          share_models=model_config.share_models,
                                          successive_attention=model_config.successive_attention,
                                          sparse_embeddings=model_config.sparse_embeddings,
                                          # Use the matching config field; this
                                          # previously received sparse_embeddings.
                                          shared_attention=model_config.shared_attention,
                                          bs_temperature=model_config.bs_temperature,
                                          bs_nucleus_p=model_config.bs_nucleus_p
                                          )

            # Load onto CPU storage first; the move to GPU happens below.
            state_dict = torch.load(model_config.checkpoint_path, map_location=lambda storage, loc: storage)
            # A checkpoint may wrap the weights under a 'model' key.
            if 'model' in state_dict:
                state_dict = state_dict['model']

            self.model.load_state_dict(state_dict)
            print('Weights loaded from {}'.format(model_config.checkpoint_path))

            if self.use_cuda:
                self.model = self.model.cuda()

            self.model.eval()

            self.model = apex_model(self.model, apex_level=self.apex_level)

        else:
            # Reuse the model another agent instance has already built.
            self.model = shared['model']

        self.reset()
Пример #6
0
def get_trainer():
    """Build a ``Trainer`` from the model and trainer configs.

    Seeds the RNGs, builds the vocabulary, model and both datasets, and
    creates the trainer. With ``trainer_config.load_last`` unset, the model
    starts from the OpenAI weights; with it set, the trainer state is loaded
    from ``trainer_config.last_checkpoint_path``.

    Returns:
        The constructed ``Trainer``.
    """
    model_cfg = get_model_config()
    trainer_cfg = get_trainer_config()

    set_seed(trainer_cfg.seed)
    device = torch.device(trainer_cfg.device)

    bpe_vocab = BPEVocab.from_files(model_cfg.bpe_vocab_path,
                                    model_cfg.bpe_codes_path)

    model = TransformerModel(
        n_layers=model_cfg.n_layers,
        n_embeddings=len(bpe_vocab),
        n_pos_embeddings=model_cfg.n_pos_embeddings,
        embeddings_size=model_cfg.embeddings_size,
        padding_idx=bpe_vocab.pad_id,
        n_heads=model_cfg.n_heads,
        dropout=model_cfg.dropout,
        embed_dropout=model_cfg.embed_dropout,
        attn_dropout=model_cfg.attn_dropout,
        ff_dropout=model_cfg.ff_dropout,
        bos_id=bpe_vocab.bos_id,
        eos_id=bpe_vocab.eos_id,
        max_seq_len=model_cfg.max_seq_len,
        beam_size=model_cfg.beam_size,
        length_penalty=model_cfg.length_penalty,
        n_segments=model_cfg.n_segments,
        annealing_topk=model_cfg.annealing_topk,
        annealing=model_cfg.annealing,
        diversity_coef=model_cfg.diversity_coef,
        diversity_groups=model_cfg.diversity_groups,
    )

    # Fresh run: initialise from the OpenAI weights.
    if not trainer_cfg.load_last:
        load_openai_weights(
            model.transformer_module,
            trainer_cfg.openai_parameters_dir,
            n_special_tokens=bpe_vocab.n_special_tokens,
        )
        print('OpenAI weights loaded from {}'.format(
            trainer_cfg.openai_parameters_dir))

    # Both datasets use the same length limit.
    max_len = model.n_pos_embeddings - 1
    train_set = FacebookDataset(trainer_cfg.train_datasets, bpe_vocab, max_len)
    test_set = FacebookDataset(trainer_cfg.test_datasets, bpe_vocab, max_len)

    trainer_kwargs = dict(
        batch_size=trainer_cfg.batch_size,
        batch_split=trainer_cfg.batch_split,
        lr=trainer_cfg.lr,
        lr_warmup=trainer_cfg.lr_warmup,
        lm_weight=trainer_cfg.lm_weight,
        risk_weight=trainer_cfg.risk_weight,
        n_jobs=trainer_cfg.n_jobs,
        clip_grad=trainer_cfg.clip_grad,
        device=device,
        ignore_idxs=bpe_vocab.special_tokens_ids,
    )
    model_trainer = Trainer(model, train_set, test_set, **trainer_kwargs)

    # Resumed run: restore the trainer state from the last checkpoint.
    if trainer_cfg.load_last:
        checkpoint = torch.load(trainer_cfg.last_checkpoint_path,
                                map_location=device)
        model_trainer.load_state_dict(checkpoint)
        print('Weights loaded from {}'.format(
            trainer_cfg.last_checkpoint_path))

    return model_trainer