Example #1
    def __init__(
        self,
        n_class,
        cnn_kwargs=None,
        encoder_kwargs=None,
        encoder_type="Conformer",
        pooling="token",
        layer_init="pytorch",
    ):
        super(SEDModel, self).__init__()

        self.cnn = CNN(n_in_channel=1, **cnn_kwargs)
        input_dim = self.cnn.nb_filters[-1]
        adim = encoder_kwargs["adim"]
        self.pooling = pooling

        if encoder_type == "Transformer":
            self.encoder = TransformerEncoder(input_dim, **encoder_kwargs)
        elif encoder_type == "Conformer":
            self.encoder = ConformerEncoder(input_dim, **encoder_kwargs)
        else:
            raise ValueError("Choose encoder_type in ['Transformer', 'Conformer']")

        self.classifier = torch.nn.Linear(adim, n_class)

        if self.pooling == "attention":
            self.dense = torch.nn.Linear(adim, n_class)
            self.sigmoid = torch.sigmoid
            self.softmax = torch.nn.Softmax(dim=-1)

        elif self.pooling == "token":
            self.linear_emb = torch.nn.Linear(1, input_dim)

        self.reset_parameters(layer_init)
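
A minimal sketch of how this constructor might be called. Only the "adim" key is read directly by the code above; every other key in cnn_kwargs and encoder_kwargs is a placeholder assumption about what CNN and ConformerEncoder accept in this code base.

# Hypothetical instantiation; kwargs keys other than "adim" are assumptions.
model = SEDModel(
    n_class=10,
    cnn_kwargs={"nb_filters": [16, 64, 128]},       # assumed CNN option
    encoder_kwargs={"adim": 144, "num_layers": 4},  # "adim" is read by __init__ above
    encoder_type="Conformer",
    pooling="attention",
)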
Example #2
    def __init__(self,
                 d_model,
                 num_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 dropout=0.0):
        super().__init__()
        self.encoder = TransformerEncoder(num_encoder_layers, d_model,
                                          num_heads, dropout)
        self.decoder = TransformerDecoder(num_decoder_layers, d_model,
                                          num_heads, dropout)
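
The owning class is not shown in this snippet; assuming it is a plain encoder-decoder wrapper, constructing it might look like the following (the class name Seq2SeqTransformer is hypothetical).

# Hypothetical usage; "Seq2SeqTransformer" stands in for whichever class
# defines the __init__ above.
model = Seq2SeqTransformer(d_model=512,
                           num_heads=8,
                           num_encoder_layers=6,
                           num_decoder_layers=6,
                           dropout=0.1)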
Example #3
    def __init__(self,
                initialize_xavier: bool,
                src_vocab_size: int,
                tgt_vocab_size: int,
                src_embedding: nn.Embedding,
                tgt_embedding: nn.Embedding,
                d_word_vec: int,
                n_enc_blocks=constants.DEFAULT_ENCODER_BLOCKS,
                n_head=constants.DEFAULT_NUMBER_OF_ATTENTION_HEADS,
                d_model=constants.DEFAULT_LAYER_SIZE,
                dropout_rate=constants.DEFAULT_MODEL_DROPOUT,
                pointwise_layer_size=constants.DEFAULT_DIMENSION_OF_PWFC_HIDDEN_LAYER,
                d_k=constants.DEFAULT_DIMENSION_OF_KEYQUERY_WEIGHTS,
                d_v=constants.DEFAULT_DIMENSION_OF_VALUE_WEIGHTS):
        """Initializes a Transformer Model
        
        Arguments:
            initialize_xavier {boolean} -- [description]
            src_vocab_size {int} -- Size / Length of the source vocabulary (how many tokens)
            tgt_vocab_size {[type]} -- Size / Length of the target vocabulary (how many tokens)
            d_word_vec {[type]} -- Dimension of word vectors (default: 512)
        
        Keyword Arguments:
            n_enc_blocks {[type]} -- [description] (default: {constants.DEFAULT_ENCODER_BLOCKS})
            n_head {[type]} -- [description] (default: {constants.DEFAULT_NUMBER_OF_ATTENTION_HEADS})
            d_model {[type]} -- [description] (default: {constants.DEFAULT_LAYER_SIZE})
            dropout_rate {[type]} -- [description] (default: {constants.DEFAULT_MODEL_DROPOUT})
            pointwise_layer_size {[type]} -- [description] (default: {constants.DEFAULT_DIMENSION_OF_PWFC_HIDDEN_LAYER})
            d_k {[type]} -- [description] (default: {constants.DEFAULT_DIMENSION_OF_KEYQUERY_WEIGHTS})
            d_v {[type]} -- [description] (default: {constants.DEFAULT_DIMENSION_OF_VALUE_WEIGHTS})
        """

        super(GoogleTransformer, self).__init__()

        assert d_model == d_word_vec, (
            'To facilitate the residual connections, the dimensions of all '
            'module outputs shall be the same: '
            'input (word embedding) = output of encoder layer')

        self._init_loggers()

        self.encoder = TransformerEncoder(src_embedding)
        self.decoder = TransformerDecoder(tgt_embedding)

        # generate a last layer that projects from the last output of the decoder layer
        # with size d_model to the size of the target vocabulary
        self.decoder_to_tgt_vocabulary = nn.Linear(d_model, tgt_vocab_size, bias=False)

        self.generator = TransformerTargetGenerator(d_model, src_vocab_size)

        if initialize_xavier:
            for p in self.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)
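
A minimal instantiation sketch with placeholder sizes; note that the assert above requires d_model == d_word_vec.

import torch.nn as nn

# Placeholder vocabulary sizes and dimensions; d_model must equal d_word_vec.
src_emb = nn.Embedding(num_embeddings=32000, embedding_dim=512)
tgt_emb = nn.Embedding(num_embeddings=32000, embedding_dim=512)
model = GoogleTransformer(
    initialize_xavier=True,
    src_vocab_size=32000,
    tgt_vocab_size=32000,
    src_embedding=src_emb,
    tgt_embedding=tgt_emb,
    d_word_vec=512,
    d_model=512,
)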
Example #4
def load_model(dataset, rc, experiment_name):
    loss = LossCombiner(4, dataset.class_weights, NllLoss)
    transformer = TransformerEncoder(dataset.source_embedding,
                                     hyperparameters=rc)
    model = JointAspectTagger(transformer, rc, 4, 20, dataset.target_names)
    optimizer = get_optimizer(model, rc)
    trainer = Trainer(model,
                      loss,
                      optimizer,
                      rc,
                      dataset,
                      experiment_name,
                      enable_tensorboard=False,
                      verbose=False)
    return trainer
Example #5
def get_encoder(st_ds_conf: dict):
    emb_sz = st_ds_conf['emb_sz']
    if st_ds_conf['encoder'] == 'lstm':
        encoder = StackedEncoder(
            [
                PytorchSeq2SeqWrapper(
                    torch.nn.LSTM(emb_sz, emb_sz, batch_first=True))
                for _ in range(st_ds_conf['num_enc_layers'])
            ],
            emb_sz,
            emb_sz,
            input_dropout=st_ds_conf['intermediate_dropout'])
    elif st_ds_conf['encoder'] == 'bilstm':
        encoder = StackedEncoder(
            [
                PytorchSeq2SeqWrapper(
                    torch.nn.LSTM(
                        emb_sz, emb_sz, batch_first=True, bidirectional=True))
            ] + [
                PytorchSeq2SeqWrapper(
                    torch.nn.LSTM(emb_sz * 2,
                                  emb_sz,
                                  batch_first=True,
                                  bidirectional=True))
                for _ in range(st_ds_conf['num_enc_layers'] - 1)
            ],
            emb_sz,
            emb_sz * 2,
            input_dropout=st_ds_conf['intermediate_dropout'])
    elif st_ds_conf['encoder'] == 'transformer':
        encoder = StackedEncoder([
            TransformerEncoder(
                input_dim=emb_sz,
                num_layers=st_ds_conf['num_enc_layers'],
                num_heads=st_ds_conf['num_heads'],
                feedforward_hidden_dim=emb_sz,
                feedforward_dropout=st_ds_conf['feedforward_dropout'],
                residual_dropout=st_ds_conf['residual_dropout'],
                attention_dropout=st_ds_conf['attention_dropout'],
            ) for _ in range(st_ds_conf['num_enc_layers'])
        ],
                                 emb_sz,
                                 emb_sz,
                                 input_dropout=0.)
    else:
        raise ValueError(f"unknown encoder type: {st_ds_conf['encoder']}")
    return encoder
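
For reference, a hypothetical st_ds_conf containing every key that get_encoder reads above; the values are illustrative placeholders, not taken from the project's configuration.

# Hypothetical configuration; keys mirror the lookups in get_encoder above.
st_ds_conf = {
    'encoder': 'transformer',   # one of 'lstm', 'bilstm', 'transformer'
    'emb_sz': 256,
    'num_enc_layers': 2,
    'num_heads': 8,
    'intermediate_dropout': 0.1,
    'feedforward_dropout': 0.1,
    'residual_dropout': 0.1,
    'attention_dropout': 0.1,
}
encoder = get_encoder(st_ds_conf)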
Example #6
    def load_model(self, dataset, rc, experiment_name, iteration):
        loss = LossCombiner(4, dataset.class_weights, NllLoss)

        if self.produce_baseline:
            iteration = 0

        if iteration == 0:
            self.current_transformer = TransformerEncoder(
                dataset.source_embedding, hyperparameters=rc)
            model = JointAspectTagger(self.current_transformer,
                                      rc,
                                      4,
                                      20,
                                      dataset.target_names,
                                      initialize_params=True)
        else:
            model = JointAspectTagger(self.current_transformer,
                                      rc,
                                      4,
                                      20,
                                      dataset.target_names,
                                      initialize_params=False)

        optimizer = get_optimizer(model, rc)
        trainer = Trainer(model,
                          loss,
                          optimizer,
                          rc,
                          dataset,
                          experiment_name,
                          enable_tensorboard=True,
                          verbose=True)

        # see if we might be able to restore the source model
        if iteration == 0 and self.load_model_path is not None:
            model, optimizer, epoch = trainer.load_model(
                custom_path=self.load_model_path)
            self.skip_source_training = True

        return trainer
Example #7
	def load_model(self, dataset, rc, experiment_name):

		transformer = TransformerEncoder(dataset.source_embedding,
										hyperparameters=rc)

		if rc.use_random_classifier:
			from models.random_model import RandomModel
			model = RandomModel(rc, dataset.target_size, len(dataset.target_names), dataset.target_names)
			loss = NllLoss(dataset.target_size, dataset.class_weights[0])

			
		else:
			# NER or ABSA-task?
			if rc.task == 'ner':
				from models.transformer_tagger import TransformerTagger
				from models.output_layers import SoftmaxOutputLayer
				loss = NllLoss(dataset.target_size, dataset.class_weights[0])
				softmax = SoftmaxOutputLayer(rc.model_size, dataset.target_size)
				model = TransformerTagger(transformer, softmax)

			else:
				from models.jointAspectTagger import JointAspectTagger
				loss = LossCombiner(dataset.target_size, dataset.class_weights, NllLoss)
				model = JointAspectTagger(transformer, rc, dataset.target_size, len(dataset.target_names), dataset.target_names)


		optimizer = get_optimizer(model, rc)
		trainer = Trainer(
							model,
							loss,
							optimizer,
							rc,
							dataset,
							experiment_name,
							enable_tensorboard=False,
							verbose=False)
		return trainer
Example #8
def main():
    parser = utils.opt_parser.get_trainer_opt_parser()
    parser.add_argument('models',
                        nargs='*',
                        help='pretrained models for the same setting')
    parser.add_argument('--test', action="store_true", help='use testing mode')
    parser.add_argument('--num-layer',
                        type=int,
                        help="stacked layer of transformer model")

    args = parser.parse_args()

    reader = data_adapter.GeoQueryDatasetReader()
    training_set = reader.read(config.DATASETS[args.dataset].train_path)
    try:
        validation_set = reader.read(config.DATASETS[args.dataset].dev_path)
    except Exception:  # fall back when the dev split is unavailable
        validation_set = None

    vocab = allennlp.data.Vocabulary.from_instances(training_set)
    st_ds_conf = config.TRANSFORMER_CONF[args.dataset]
    if args.num_layer:
        st_ds_conf['num_layers'] = args.num_layer

    encoder = TransformerEncoder(
        input_dim=st_ds_conf['emb_sz'],
        num_layers=st_ds_conf['num_layers'],
        num_heads=st_ds_conf['num_heads'],
        feedforward_hidden_dim=st_ds_conf['emb_sz'],
    )
    decoder = TransformerDecoder(
        input_dim=st_ds_conf['emb_sz'],
        num_layers=st_ds_conf['num_layers'],
        num_heads=st_ds_conf['num_heads'],
        feedforward_hidden_dim=st_ds_conf['emb_sz'],
        feedforward_dropout=0.1,
    )
    source_embedding = allennlp.modules.Embedding(
        num_embeddings=vocab.get_vocab_size('nltokens'),
        embedding_dim=st_ds_conf['emb_sz'])
    target_embedding = allennlp.modules.Embedding(
        num_embeddings=vocab.get_vocab_size('lftokens'),
        embedding_dim=st_ds_conf['emb_sz'])
    model = ParallelSeq2Seq(
        vocab=vocab,
        encoder=encoder,
        decoder=decoder,
        source_embedding=source_embedding,
        target_embedding=target_embedding,
        target_namespace='lftokens',
        start_symbol=START_SYMBOL,
        eos_symbol=END_SYMBOL,
        max_decoding_step=st_ds_conf['max_decoding_len'],
    )

    if args.models:
        model.load_state_dict(torch.load(args.models[0]))

    if not args.test or not args.models:
        iterator = BucketIterator(
            sorting_keys=[("source_tokens", "num_tokens")],
            batch_size=st_ds_conf['batch_sz'])
        iterator.index_with(vocab)

        optim = torch.optim.Adam(model.parameters())

        savepath = os.path.join(
            config.SNAPSHOT_PATH, args.dataset, 'transformer',
            datetime.datetime.now().strftime('%Y%m%d%H%M%S'))
        if not os.path.exists(savepath):
            os.makedirs(savepath, mode=0o755)

        trainer = allennlp.training.Trainer(
            model=model,
            optimizer=optim,
            iterator=iterator,
            train_dataset=training_set,
            validation_dataset=validation_set,
            serialization_dir=savepath,
            cuda_device=args.device,
            num_epochs=config.TRAINING_LIMIT,
        )

        trainer.train()

    else:
        testing_set = reader.read(config.DATASETS[args.dataset].test_path)
        model.eval()

        predictor = allennlp.predictors.SimpleSeq2SeqPredictor(model, reader)

        for instance in testing_set:
            print('SRC: ', instance.fields['source_tokens'].tokens)
            print(
                'GOLD:', ' '.join(
                    str(x)
                    for x in instance.fields['target_tokens'].tokens[1:-1]))
            del instance.fields['target_tokens']
            output = predictor.predict_instance(instance)
            print('PRED:', ' '.join(output['predicted_tokens']))
Example #9
def main():
    parser = utils.opt_parser.get_trainer_opt_parser()
    parser.add_argument('models',
                        nargs='*',
                        help='pretrained models for the same setting')
    parser.add_argument('--test', action="store_true", help='use testing mode')
    parser.add_argument('--num-layer',
                        type=int,
                        help='maximum number of stacked layers')
    parser.add_argument(
        '--use-ut',
        action="store_true",
        help='Use universal transformer instead of transformer')

    args = parser.parse_args()

    reader = data_adapter.GeoQueryDatasetReader()
    training_set = reader.read(config.DATASETS[args.dataset].train_path)
    try:
        validation_set = reader.read(config.DATASETS[args.dataset].dev_path)
    except Exception:  # fall back when the dev split is unavailable
        validation_set = None

    vocab = allennlp.data.Vocabulary.from_instances(training_set)
    st_ds_conf = config.TRANS2SEQ_CONF[args.dataset]
    if args.num_layer:
        st_ds_conf['max_num_layers'] = args.num_layer
    if args.epoch:
        config.TRAINING_LIMIT = args.epoch
    if args.batch:
        st_ds_conf['batch_sz'] = args.batch
    bsz = st_ds_conf['batch_sz']
    emb_sz = st_ds_conf['emb_sz']

    src_embedder = BasicTextFieldEmbedder(
        token_embedders={
            "tokens": Embedding(vocab.get_vocab_size('nltokens'), emb_sz)
        })

    if args.use_ut:
        transformer_encoder = UTEncoder(
            input_dim=emb_sz,
            max_num_layers=st_ds_conf['max_num_layers'],
            num_heads=st_ds_conf['num_heads'],
            feedforward_hidden_dim=emb_sz,
            feedforward_dropout=st_ds_conf['feedforward_dropout'],
            attention_dropout=st_ds_conf['attention_dropout'],
            residual_dropout=st_ds_conf['residual_dropout'],
            use_act=st_ds_conf['act'],
            use_vanilla_wiring=st_ds_conf['vanilla_wiring'])
    else:
        transformer_encoder = TransformerEncoder(
            input_dim=emb_sz,
            num_layers=st_ds_conf['max_num_layers'],
            num_heads=st_ds_conf['num_heads'],
            feedforward_hidden_dim=emb_sz,
            feedforward_dropout=st_ds_conf['feedforward_dropout'],
            attention_dropout=st_ds_conf['attention_dropout'],
            residual_dropout=st_ds_conf['residual_dropout'],
        )

    model = allennlp.models.SimpleSeq2Seq(
        vocab,
        source_embedder=src_embedder,
        encoder=transformer_encoder,
        max_decoding_steps=50,
        attention=allennlp.modules.attention.DotProductAttention(),
        beam_size=6,
        target_namespace="lftokens",
        use_bleu=True)

    if args.models:
        model.load_state_dict(torch.load(args.models[0]))

    if not args.test or not args.models:
        iterator = BucketIterator(
            sorting_keys=[("source_tokens", "num_tokens")],
            batch_size=bsz)
        iterator.index_with(vocab)

        optim = torch.optim.Adam(model.parameters())

        savepath = os.path.join(
            config.SNAPSHOT_PATH, args.dataset, 'transformer2seq',
            datetime.datetime.now().strftime('%Y%m%d-%H%M%S') + "--" +
            args.memo)
        if not os.path.exists(savepath):
            os.makedirs(savepath, mode=0o755)

        trainer = allennlp.training.Trainer(
            model=model,
            optimizer=optim,
            iterator=iterator,
            train_dataset=training_set,
            validation_dataset=validation_set,
            serialization_dir=savepath,
            cuda_device=args.device,
            num_epochs=config.TRAINING_LIMIT,
        )

        trainer.train()

    else:
        testing_set = reader.read(config.DATASETS[args.dataset].test_path)
        model.eval()

        predictor = allennlp.predictors.SimpleSeq2SeqPredictor(model, reader)

        for instance in tqdm.tqdm(testing_set, total=len(testing_set)):
            print('SRC: ', instance.fields['source_tokens'].tokens)
            print(
                'GOLD:', ' '.join(
                    str(x)
                    for x in instance.fields['target_tokens'].tokens[1:-1]))
            del instance.fields['target_tokens']
            output = predictor.predict_instance(instance)
            print('PRED:', ' '.join(output['predicted_tokens']))
Example #10
def get_model(vocab, st_ds_conf):
    emb_sz = st_ds_conf['emb_sz']

    source_embedding = allennlp.modules.Embedding(
        num_embeddings=vocab.get_vocab_size('nltokens'), embedding_dim=emb_sz)
    target_embedding = allennlp.modules.Embedding(
        num_embeddings=vocab.get_vocab_size('lftokens'), embedding_dim=emb_sz)

    if st_ds_conf['encoder'] == 'lstm':
        encoder = StackedEncoder(
            [
                PytorchSeq2SeqWrapper(
                    torch.nn.LSTM(emb_sz, emb_sz, batch_first=True))
                for _ in range(st_ds_conf['num_enc_layers'])
            ],
            emb_sz,
            emb_sz,
            input_dropout=st_ds_conf['intermediate_dropout'])
    elif st_ds_conf['encoder'] == 'bilstm':
        encoder = StackedEncoder(
            [
                PytorchSeq2SeqWrapper(
                    torch.nn.LSTM(
                        emb_sz, emb_sz, batch_first=True, bidirectional=True))
                for _ in range(st_ds_conf['num_enc_layers'])
            ],
            emb_sz,
            emb_sz,
            input_dropout=st_ds_conf['intermediate_dropout'])
    elif st_ds_conf['encoder'] == 'transformer':
        encoder = StackedEncoder(
            [
                TransformerEncoder(
                    input_dim=emb_sz,
                    num_layers=st_ds_conf['num_enc_layers'],
                    num_heads=st_ds_conf['num_heads'],
                    feedforward_hidden_dim=emb_sz,
                    feedforward_dropout=st_ds_conf['feedforward_dropout'],
                    residual_dropout=st_ds_conf['residual_dropout'],
                    attention_dropout=st_ds_conf['attention_dropout'],
                ) for _ in range(st_ds_conf['num_enc_layers'])
            ],
            emb_sz,
            emb_sz,
            input_dropout=st_ds_conf['intermediate_dropout'])
    else:
        raise ValueError(f"unknown encoder type: {st_ds_conf['encoder']}")

    enc_out_dim = encoder.get_output_dim()
    dec_out_dim = emb_sz

    dec_hist_attn = get_attention(st_ds_conf, st_ds_conf['dec_hist_attn'])
    enc_attn = get_attention(st_ds_conf, st_ds_conf['enc_attn'])
    if st_ds_conf['enc_attn'] == 'dot_product':
        assert enc_out_dim == dec_out_dim, "encoder hidden states must be able to multiply with decoder output"

    def sum_attn_dims(attns, dims):
        return sum(dim for attn, dim in zip(attns, dims) if attn is not None)

    if st_ds_conf['concat_attn_to_dec_input']:
        dec_in_dim = dec_out_dim + sum_attn_dims([enc_attn, dec_hist_attn],
                                                 [enc_out_dim, dec_out_dim])
    else:
        dec_in_dim = dec_out_dim
    rnn_cell = get_rnn_cell(st_ds_conf, dec_in_dim, dec_out_dim)

    if st_ds_conf['concat_attn_to_dec_input']:
        proj_in_dim = dec_out_dim + sum_attn_dims([enc_attn, dec_hist_attn],
                                                  [enc_out_dim, dec_out_dim])
    else:
        proj_in_dim = dec_out_dim

    word_proj = torch.nn.Linear(proj_in_dim, vocab.get_vocab_size('lftokens'))

    model = BaseSeq2Seq(
        vocab=vocab,
        encoder=encoder,
        decoder=rnn_cell,
        word_projection=word_proj,
        source_embedding=source_embedding,
        target_embedding=target_embedding,
        target_namespace='lftokens',
        start_symbol=START_SYMBOL,
        eos_symbol=END_SYMBOL,
        max_decoding_step=st_ds_conf['max_decoding_len'],
        enc_attention=enc_attn,
        dec_hist_attn=dec_hist_attn,
        intermediate_dropout=st_ds_conf['intermediate_dropout'],
        concat_attn_to_dec_input=st_ds_conf['concat_attn_to_dec_input'],
    )
    return model
Example #11
def get_model(vocab, st_ds_conf):
    emb_sz = st_ds_conf['emb_sz']

    source_embedding = allennlp.modules.Embedding(
        num_embeddings=vocab.get_vocab_size('nltokens'), embedding_dim=emb_sz)
    target_embedding = allennlp.modules.Embedding(
        num_embeddings=vocab.get_vocab_size('lftokens'), embedding_dim=emb_sz)

    if st_ds_conf['encoder'] == 'lstm':
        encoder = allennlp.modules.seq2seq_encoders.PytorchSeq2SeqWrapper(
            torch.nn.LSTM(emb_sz,
                          emb_sz,
                          st_ds_conf['num_enc_layers'],
                          batch_first=True))
    elif st_ds_conf['encoder'] == 'bilstm':
        encoder = allennlp.modules.seq2seq_encoders.PytorchSeq2SeqWrapper(
            torch.nn.LSTM(emb_sz,
                          emb_sz,
                          st_ds_conf['num_enc_layers'],
                          batch_first=True,
                          bidirectional=True))
    elif st_ds_conf['encoder'] == 'transformer':
        encoder = TransformerEncoder(
            input_dim=emb_sz,
            num_layers=st_ds_conf['num_enc_layers'],
            num_heads=st_ds_conf['num_heads'],
            feedforward_hidden_dim=emb_sz,
            feedforward_dropout=st_ds_conf['feedforward_dropout'],
            residual_dropout=st_ds_conf['residual_dropout'],
            attention_dropout=st_ds_conf['attention_dropout'],
        )
    else:
        raise ValueError(f"unknown encoder type: {st_ds_conf['encoder']}")

    enc_out_dim = encoder.get_output_dim()
    dec_out_dim = emb_sz

    dwa = get_attention(st_ds_conf, st_ds_conf['dwa'])
    dec_hist_attn = get_attention(st_ds_conf, st_ds_conf['dec_hist_attn'])
    enc_attn = get_attention(st_ds_conf, st_ds_conf['enc_attn'])
    if st_ds_conf['enc_attn'] == 'dot_product':
        assert enc_out_dim == dec_out_dim, "encoder hidden states must be able to multiply with decoder output"

    def sum_attn_dims(attns, dims):
        return sum(dim for attn, dim in zip(attns, dims) if attn is not None)

    dec_in_dim = dec_out_dim + 1 + sum_attn_dims([enc_attn, dec_hist_attn],
                                                 [enc_out_dim, dec_out_dim])
    rnn_cell = get_rnn_cell(st_ds_conf, dec_in_dim, dec_out_dim)

    halting_in_dim = dec_out_dim + sum_attn_dims(
        [enc_attn, dec_hist_attn, dwa],
        [enc_out_dim, dec_out_dim, dec_out_dim])
    halting_fn = torch.nn.Sequential(
        # torch.nn.Dropout(st_ds_conf['act_dropout']),
        torch.nn.Linear(halting_in_dim, halting_in_dim),
        torch.nn.ReLU(),
        torch.nn.Dropout(st_ds_conf['act_dropout']),
        torch.nn.Linear(halting_in_dim, 1),
        torch.nn.Sigmoid(),
    )

    decoder = ACTRNNCell(
        rnn_cell=rnn_cell,
        halting_fn=halting_fn,
        use_act=st_ds_conf['act'],
        act_max_layer=st_ds_conf['act_max_layer'],
        act_epsilon=st_ds_conf['act_epsilon'],
        rnn_input_dropout=st_ds_conf['decoder_dropout'],
        depth_wise_attention=dwa,
        state_mode=st_ds_conf['act_mode'],
    )
    proj_in_dim = dec_out_dim + sum_attn_dims([enc_attn, dec_hist_attn],
                                              [enc_out_dim, dec_out_dim])
    word_proj = torch.nn.Linear(proj_in_dim, vocab.get_vocab_size('lftokens'))
    model = AdaptiveSeq2Seq(
        vocab=vocab,
        encoder=encoder,
        decoder=decoder,
        word_projection=word_proj,
        source_embedding=source_embedding,
        target_embedding=target_embedding,
        target_namespace='lftokens',
        start_symbol=START_SYMBOL,
        eos_symbol=END_SYMBOL,
        max_decoding_step=st_ds_conf['max_decoding_len'],
        enc_attention=enc_attn,
        dec_hist_attn=dec_hist_attn,
        act_loss_weight=st_ds_conf['act_loss_weight'],
        prediction_dropout=st_ds_conf['prediction_dropout'],
        embedding_dropout=st_ds_conf['embedding_dropout'],
    )
    return model
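
For reference, a hypothetical st_ds_conf covering the keys this adaptive (ACT) variant reads; the values are illustrative, and get_attention / get_rnn_cell may require additional keys that are not visible here.

# Hypothetical configuration; keys mirror the lookups in get_model above.
st_ds_conf = {
    'emb_sz': 256,
    'encoder': 'transformer',       # or 'lstm' / 'bilstm'
    'num_enc_layers': 2,
    'num_heads': 8,
    'feedforward_dropout': 0.1,
    'residual_dropout': 0.1,
    'attention_dropout': 0.1,
    'enc_attn': 'dot_product',
    'dec_hist_attn': 'dot_product',
    'dwa': 'dot_product',           # depth-wise attention over decoder states
    'act': True,
    'act_mode': 'mean',             # placeholder; valid values depend on ACTRNNCell
    'act_max_layer': 3,
    'act_epsilon': 0.1,
    'act_dropout': 0.1,
    'act_loss_weight': 0.01,
    'decoder_dropout': 0.1,
    'prediction_dropout': 0.1,
    'embedding_dropout': 0.1,
    'max_decoding_len': 60,
}
model = get_model(vocab, st_ds_conf)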