def extract_args(self, features, mode, params):
     super().extract_args(features, mode, params)
     if self.hparams.vocab_size > 0:
         self.vocab = Vocabulary(size=self.hparams.vocab_size)
     else:
         self.vocab = Vocabulary(
             fname=self.hparams.vocab_file,
             skip_tokens=self.hparams.skip_tokens,
             skip_tokens_start=self.hparams.skip_tokens_start)
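
The two branches above cover both ways a Vocabulary can be built: sized up front, or loaded from a one-token-per-line file (the format build_vocab_files writes later in this listing). A usage sketch with hypothetical values:

    vocab_from_size = Vocabulary(size=32000)                    # reserve 32k token ids
    vocab_from_file = Vocabulary(fname="dummy_data/vocab.txt")  # one token per line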
Example 2
def main(_):
    print("Loading hyperparameters..")
    params = util.load_params(FLAGS.params_file)

    print("Building model..")
    model_dir = FLAGS.model_dir
    if FLAGS.clean_model_dir:
        util.clean_model_dir(model_dir)
    if FLAGS.model_cls == "transformer":
        model_cls = TransformerEstimator
    elif FLAGS.model_cls == "seq2seq":
        model_cls = Seq2SeqEstimator
    else:
        raise ValueError("Model class not supported.")
    model = model_cls(model_dir, params)

    print("Getting sources..")
    fields = {"train/inputs": "int", "train/targets": "int"}
    train_source = DataSource(FLAGS.train_file, fields)
    test_source = DataSource(FLAGS.test_file, fields)

    field_map = {"inputs": "train/inputs", "targets": "train/targets"}
    train_input_fn = train_source.get_input_fn(
        "train_in", field_map, None, FLAGS.batch_size)
    test_input_fn = test_source.get_input_fn(
        "test_in", field_map, 1, FLAGS.batch_size)

    print("Processing model..")
    model.train(train_input_fn, steps=FLAGS.train_batches)
    model.evaluate(test_input_fn)

    if FLAGS.interactive:
        print("Interactive decoding...")
        vocab = Vocabulary(fname=params["vocab_file"])
        decoding.cmd_decode(model, vocab)
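
This main() relies on command-line flags defined elsewhere in the script. A minimal sketch of the definitions it expects, with hypothetical defaults (the flag names are taken from the usage above):

    import tensorflow as tf

    flags = tf.app.flags
    flags.DEFINE_string("params_file", "dummy_params/example.params", "Hyperparameter file.")
    flags.DEFINE_string("model_dir", "/tmp/icecaps_model", "Checkpoint directory.")
    flags.DEFINE_boolean("clean_model_dir", True, "Wipe model_dir before training.")
    flags.DEFINE_string("model_cls", "seq2seq", "Either 'transformer' or 'seq2seq'.")
    flags.DEFINE_string("train_file", "dummy_data/train.tfrecord", "Training TFRecord file.")
    flags.DEFINE_string("test_file", "dummy_data/test.tfrecord", "Test TFRecord file.")
    flags.DEFINE_integer("batch_size", 32, "Batch size.")
    flags.DEFINE_integer("train_batches", 1000, "Number of training steps.")
    flags.DEFINE_boolean("interactive", False, "Run interactive decoding after training.")
    FLAGS = flags.FLAGS

    if __name__ == "__main__":
        tf.app.run()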
Example 3
 def extract_args(self, features, mode, params):
     super().extract_args(features, mode, params)
     if (self.hparams.src_vocab_size == 0
             and self.hparams.tgt_vocab_size == 0
             and self.hparams.src_vocab_file == ""
             and self.hparams.tgt_vocab_file == ""):
         self.src_vocab = self.vocab
         self.tgt_vocab = self.vocab
     else:
         if self.hparams.src_vocab_size > 0:
             self.src_vocab = Vocabulary(size=self.hparams.src_vocab_size)
         else:
             self.src_vocab = Vocabulary(fname=self.hparams.src_vocab_file)
         if self.hparams.tgt_vocab_size > 0:
             self.tgt_vocab = Vocabulary(size=self.hparams.tgt_vocab_size)
         else:
             self.tgt_vocab = Vocabulary(fname=self.hparams.tgt_vocab_file)
Example 4
 def extract_args(self, features, mode, params):
     super().extract_args(features, mode, params)
     self.d_k = self.hparams.d_model // self.hparams.num_heads
     # Fall back to defaults when unset; the original assigned the same value
     # on both branches, so the fallbacks below are assumed intent
     # (d_model for d_pos, 4 * d_model for d_ff, per the transformer paper).
     self.d_pos = self.hparams.d_model if self.hparams.d_pos == 0 else self.hparams.d_pos
     self.d_ff = 4 * self.hparams.d_model if self.hparams.d_ff == 0 else self.hparams.d_ff
     if self.hparams.vocab_size > 0:
         self.vocab = Vocabulary(size=self.hparams.vocab_size)
     else:
         self.vocab = Vocabulary(fname=self.hparams.vocab_file)
     if not self.hparams.fixed_learning_rate:
         self.train_step = tf.get_variable(
             'train_step',
             shape=[],
             dtype=tf.float32,
             # The initializer must match the variable's float32 dtype
             # (the original passed dtype=tf.int32 here, which conflicts).
             initializer=tf.zeros_initializer(),
             trainable=False)
         self.learning_rate = (  # warmup schedule from the transformer paper
             tf.sqrt(1.0 / self.hparams.d_model) * tf.minimum(
                 self.train_step * tf.pow(float(self.hparams.warmup_steps), -1.5),
                 tf.pow(self.train_step, -0.5)))
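
The expression above is the warmup schedule from the transformer paper: lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5), which ramps up linearly for warmup_steps and then decays with the inverse square root of the step. A plain-Python sketch to sanity-check it (the d_model and warmup_steps values are hypothetical):

    def noam_lr(step, d_model=512, warmup_steps=4000):
        # Linear ramp until warmup_steps, then step**-0.5 decay.
        return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)

    print(noam_lr(1000))    # warmup phase
    print(noam_lr(4000))    # peak at the warmup boundary
    print(noam_lr(40000))   # decay phase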
Example 5
 def write_to_tfrecord(self, out_file, pipeline=None, max_lines=None):
     print("Writing to TFRecord..")
     writer = tf.python_io.TFRecordWriter(out_file)
     line_ctr = 0
     for row in self.row_gen():
         if not self.process_row(pipeline, row):
             continue
         feature = dict()
         for i in range(len(row)):
             key_ = self.headers[i].name
             type_ = self.headers[i].data_type
             vocab_ = self.headers[i].vocab_file
             mode_ = self.headers[i].vocab_mode
             if type_ == "text":
                 if vocab_ not in self.vocabs:
                     if mode_ != "write":
                         self.vocabs[vocab_] = Vocabulary(fname=vocab_)
                     else:
                         self.vocabs[vocab_] = Vocabulary()
                 row[i] = self.vocabs[vocab_].tokenize(
                     row[i], fixed_vocab=(mode_ == "read"))
                 feature[key_] = self.int64_feature(row[i])
             elif type_ == "int":
                 feature[key_] = self.int64_feature([int(row[i])])
             elif type_ == "float":
                 feature[key_] = self.float_feature([float(row[i])])
             else:
                 raise ValueError("Header type " + str(type_) +
                                  " not supported.")
         example = tf.train.Example(features=tf.train.Features(
             feature=feature))
         writer.write(example.SerializeToString())
         line_ctr = self.print_lines_processed(line_ctr)
         if max_lines is not None and line_ctr >= max_lines:
             break
     writer.close()
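
Records written this way can be read back with a feature spec that mirrors the write path. A minimal TF1 parsing sketch (the field names, and treating token sequences as variable-length int64 lists, are assumptions based on the fields dictionaries used in the training scripts above):

    import tensorflow as tf

    def parse_example(serialized):
        spec = {
            "train/inputs": tf.VarLenFeature(tf.int64),
            "train/targets": tf.VarLenFeature(tf.int64),
        }
        parsed = tf.parse_single_example(serialized, features=spec)
        # VarLenFeature yields sparse tensors; densify them for downstream use.
        return {k: tf.sparse.to_dense(v) for k, v in parsed.items()}

    dataset = tf.data.TFRecordDataset("dummy_data/paired.tfrecord").map(parse_example)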
Example 6
 def build_vocab_files(self, count_cutoff=0):
     print("Building vocabularies..")
     read_only = True
     self.vocabs = dict()
     for i in range(len(self.headers)):
         vocab_ = self.headers[i].vocab_file
         mode = self.headers[i].vocab_mode
         if ((vocab_ is not None) and (vocab_ not in self.vocabs)
                 and (mode != "read")):
             read_only = False
             if mode == "write":
                 self.vocabs[vocab_] = Vocabulary()
             elif mode == "append":
                 self.vocabs[vocab_] = Vocabulary(fname=vocab_)
             else:
                 raise ValueError("Vocab mode " + str(mode) +
                                  " not supported.")
         elif vocab_ is not None and mode == "read":
             self.vocabs[vocab_] = Vocabulary(fname=vocab_)
     if read_only:
         return
     line_ctr = 0
     for row in self.row_gen():
         for i in range(len(row)):
             vocab_ = self.headers[i].vocab_file
             if vocab_ in self.vocabs:
                 self.vocabs[vocab_].tokenize(row[i], fixed_vocab=False)
         line_ctr = self.print_lines_processed(line_ctr)
     for vocab_ in self.vocabs:
         if count_cutoff >= 0:
             self.vocabs[vocab_].count_cutoff(count_cutoff)
         with open(vocab_, "w", encoding="utf8") as vocab_f:
             for word in self.vocabs[vocab_].words:
                 vocab_f.write(word + "\n")
     for i in range(len(self.headers)):
         self.headers[i].vocab_mode = "read"
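
The three vocab modes above behave as follows: "write" starts a fresh vocabulary, "append" extends an existing file, and "read" loads one without modifying it. A hypothetical header setup (the DataHeader keyword arguments are an assumption based on the attributes this method accesses):

    headers = [
        DataHeader("train/inputs", "text", vocab_file="vocab.txt", vocab_mode="write"),
        DataHeader("train/targets", "text", vocab_file="vocab.txt", vocab_mode="append"),
        DataHeader("train/speakers", "int"),
    ]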
Example 7
def main(_):
    print("Loading hyperparameters..")
    params = util.load_params(FLAGS.params_file)

    print("Building model..")
    validation_config = tf.estimator.RunConfig(
        save_checkpoints_steps=100,
        keep_checkpoint_max=None,
    )
    model_dir = FLAGS.model_dir
    if FLAGS.clean_model_dir:
        util.clean_model_dir(model_dir)
    if FLAGS.model_cls == "transformer":
        model_cls = TransformerEstimator
    elif FLAGS.model_cls == "seq2seq":
        model_cls = Seq2SeqEstimator
    else:
        raise ValueError("Model class not supported.")
    model = model_cls(model_dir, params, config=validation_config)

    print("Getting sources..")
    fields = {"train/inputs": "int", "train/targets": "int"}
    train_source = DataSource(FLAGS.train_file, fields)
    test_source = DataSource(FLAGS.test_file, fields)

    field_map = {"inputs": "train/inputs", "targets": "train/targets"}
    train_input_fn = train_source.get_input_fn("train_in", field_map, None,
                                               FLAGS.batch_size)
    test_input_fn = test_source.get_input_fn("test_in", field_map, 1,
                                             FLAGS.batch_size)

    print("Processing model..")
    model.train(train_input_fn, steps=FLAGS.train_batches)
    model.choose_best_checkpoint(test_input_fn)
    model.evaluate(test_input_fn)

    if FLAGS.interaction != "off":
        print("Interactive decoding...")
        vocab = Vocabulary(fname=params["vocab_file"])
        if FLAGS.interaction == "cmd":
            decoding.cmd_decode(model, vocab, persona=True)
        elif FLAGS.interaction == "gui":
            decoding.gui_decode(model, vocab)
Example 8
 def write_to_tfrecord(self,
                       out_file,
                       pipeline=None,
                       max_lines=None,
                       line_gen=None,
                       line_shard_len=None,
                       streamline=True,
                       traversal="depth_first",
                       max_pos_len=32):
     print("Writing to TFRecord..")
     writer = tf.python_io.TFRecordWriter(out_file)
     line_ctr = 0
     if line_gen is None:
         line_gen = self.row_gen()
     for row in line_gen:
         if not self.process_row(row, pipeline):
             continue
         feature = {}
         for i in range(len(row)):
             key_ = self.headers[i].name
             type_ = self.headers[i].data_type
             vocab_ = self.headers[i].vocab_file
             mode_ = self.headers[i].vocab_mode
             if type_ == "text" or type_ == "tree":
                 if vocab_ not in self.vocabs:
                     if mode_ != "write":
                         self.vocabs[vocab_] = Vocabulary(fname=vocab_)
                     else:
                         self.vocabs[vocab_] = Vocabulary()
                 if type_ == "text":
                     row[i] = self.vocabs[vocab_].tokenize(
                         row[i], fixed_vocab=(mode_ == "read"))
                     feature[key_] = self.int64_feature(row[i])
                 else:
                     tree_ints = []
                     tree_pos = []
                     for node in row[i].choose_traversal(traversal):
                         if streamline:
                             # Skip _NULL padding nodes with no structural role.
                             if node.value == "_NULL" and (
                                     not node.parent
                                     or node.parent.children[0].value == "_NULL"):
                                 continue
                             node.value = str(node.value)
                             # Suffix records which children are _NULL:
                             # "_0" = both of them, "_1" = only the left one.
                             if not node.is_leaf() and node.children[0].value == "_NULL":
                                 if node.children[1].value == "_NULL":
                                     node.value += "_0"
                                 else:
                                     node.value += "_1"
                             # Map unseen tokens to _UNK, preserving the suffix.
                             if (mode_ == "read" and
                                     node.value not in self.vocabs[vocab_].word2idx):
                                 if len(node.value) > 2 and node.value.endswith("_0"):
                                     node.value = "_UNK_0"
                                 elif len(node.value) > 2 and node.value.endswith("_1"):
                                     node.value = "_UNK_1"
                                 else:
                                     node.value = "_UNK"
                         tree_ints.append(self.vocabs[vocab_].get_token_id(
                             node.value, mode_ == "read"))
                         tree_pos += node.get_padded_positional_encoding(max_pos_len)
                     field = self.headers[i].name
                     feature[field] = self.int64_feature(tree_ints)
                     feature[field + "_pos"] = self.float_feature(tree_pos)
             elif type_ == "int":
                 feature[key_] = self.int64_feature([int(row[i])])
             elif type_ == "float":
                 feature[key_] = self.float_feature([float(row[i])])
             else:
                 raise ValueError("Header type " + str(type_) +
                                  " not supported.")
         example = tf.train.Example(features=tf.train.Features(
             feature=feature))
         writer.write(example.SerializeToString())
         line_ctr = self.print_lines_processed(line_ctr, "trees")
         if max_lines is not None and line_ctr >= max_lines:
             break
     writer.close()
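
In the streamline branch above, pruned structure is folded into the token itself: a "_0" suffix marks a node whose children were both _NULL, "_1" marks one whose left child alone was, and out-of-vocabulary tokens keep that suffix on their _UNK replacement. A self-contained illustration of the suffix rule:

    def streamline_suffix(value, left_is_null, right_is_null):
        # Mirrors the non-leaf branch above: "_0" = both children pruned
        # (the node now acts as a leaf), "_1" = only the left child pruned.
        if left_is_null and right_is_null:
            return value + "_0"
        if left_is_null:
            return value + "_1"
        return value

    assert streamline_suffix("NP", True, True) == "NP_0"
    assert streamline_suffix("VP", True, False) == "VP_1"
    assert streamline_suffix("S", False, False) == "S"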
Example 9
 def apply_byte_pair_encodings(self, out_file, max_lines=None):
     self.build_vocab_files()
     print("Applying byte pair encodings..")
     all_bpe_vocabs = dict()
     word_encodings = dict()
     for vocab_ in self.vocabs:
         all_bpe_vocabs[vocab_] = Vocabulary(fname=vocab_)
         word_encodings[vocab_] = dict()
     length_headers = OrderedDict()
     for i in range(len(self.headers)):
         if self.headers[i].vocab_file is not None:
             length_headers[self.headers[i].name] = DataHeader(
                 self.headers[i].name + "/_length", "int")
     for header_name in length_headers:
         self.headers.append(length_headers[header_name])
     with open(out_file, "w", encoding="utf8") as out_f:
         line_ctr = 0
         for row in self.row_gen():
             row_extension = []
             for i in range(len(row)):
                 vocab_ = self.headers[i].vocab_file
                 if vocab_ is not None:
                     row_extension.append(len(row[i].strip().split()))
                     new_elem = ""
                     for word in row[i].strip().split():
                         if word in word_encodings[vocab_]:
                             encoding = word_encodings[vocab_][word]
                         else:
                             word2idx = all_bpe_vocabs[vocab_].word2idx
                             encoding = list(word) + ["</EOW>"]
                             # Positions of adjacent pairs found in the BPE
                             # vocab, mapped to their rank in that vocab.
                             bigrams = {j: word2idx[encoding[j] + encoding[j + 1]]
                                        for j in range(len(encoding) - 1)
                                        if encoding[j] + encoding[j + 1] in word2idx}
                             while bigrams:
                                 # Greedily merge the lowest-ranked (i.e.
                                 # earliest-learned) pair, then rescan.
                                 j = min(bigrams, key=bigrams.get)
                                 encoding = (encoding[:j]
                                             + [encoding[j] + encoding[j + 1]]
                                             + encoding[j + 2:])
                                 bigrams = {j: word2idx[encoding[j] + encoding[j + 1]]
                                            for j in range(len(encoding) - 1)
                                            if encoding[j] + encoding[j + 1] in word2idx}
                             word_encodings[vocab_][word] = encoding
                         for subword in encoding:
                             new_elem += subword + " "
                     row[i] = new_elem
             row += row_extension
             out_f.write(self.concatenate_segments(row))
             line_ctr = self.print_lines_processed(line_ctr)
             if max_lines is not None and line_ctr >= max_lines:
                 break
     self.in_files = [out_file]
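
The merge loop above implements greedy byte pair encoding: a word is split into characters plus an end-of-word marker, and at each step the adjacent pair with the lowest index in the BPE vocabulary (the earliest-learned merge) is joined, until no known pair remains. A self-contained sketch of the same procedure (the merge ranks are hypothetical):

    def bpe_encode(word, merge_ranks):
        # merge_ranks maps a merged subword to its priority (lower = earlier).
        encoding = list(word) + ["</EOW>"]
        while True:
            candidates = {j: merge_ranks[encoding[j] + encoding[j + 1]]
                          for j in range(len(encoding) - 1)
                          if encoding[j] + encoding[j + 1] in merge_ranks}
            if not candidates:
                return encoding
            j = min(candidates, key=candidates.get)
            encoding = encoding[:j] + [encoding[j] + encoding[j + 1]] + encoding[j + 2:]

    ranks = {"lo": 0, "low": 1, "low</EOW>": 2}
    print(bpe_encode("low", ranks))  # ['low</EOW>']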
Example 10
def main(_):
    print("Loading parameters..")
    params = util.load_params(FLAGS.params_file)

    print("Building model..")
    model_dir = FLAGS.model_dir
    if FLAGS.clean_model_dir:
        util.clean_model_dir(model_dir)
    first_model = PersonaSeq2SeqEstimator(model_dir, params, scope="first")
    second_model_encoder = Seq2SeqEncoderEstimator(model_dir,
                                                   params,
                                                   scope="second_encoder")
    second_model = EstimatorChain([second_model_encoder, first_model.decoder],
                                  model_dir,
                                  params,
                                  scope="second")
    mmi_model = PersonaSeq2SeqEstimator(model_dir,
                                        params,
                                        scope="mmi",
                                        is_mmi_model=True)
    model_group = EstimatorGroup([first_model, second_model, mmi_model],
                                 model_dir,
                                 params,
                                 scope="group")

    print("Getting sources..")
    fields = {
        "train/inputs": "int",
        "train/targets": "int",
        "train/speakers": "int"
    }
    train_source = DataSource(FLAGS.train_file, fields)
    autoenc_source = DataSource(FLAGS.autoenc_file, fields)
    test_source = DataSource(FLAGS.test_file, fields)

    train_field_map = {
        "inputs": "train/inputs",
        "targets": "train/targets",
        "speaker_ids": "train/speakers"
    }
    autoenc_field_map = {
        "inputs": "train/inputs",
        "targets": "train/inputs",
        "speaker_ids": "train/speakers"
    }
    mmi_field_map = {
        "inputs": "train/targets",
        "targets": "train/inputs",
        "speaker_ids": "train/speakers"
    }

    paired_input_fn = train_source.get_input_fn("paired_in", train_field_map,
                                                None, FLAGS.batch_size)
    autoenc_input_fn = autoenc_source.get_input_fn(
        "autoenc_in", autoenc_field_map, None, FLAGS.batch_size)
    mmi_input_fn = train_source.get_input_fn("mmi_in", mmi_field_map, None,
                                             FLAGS.batch_size)
    train_input_fn = DataSource.group_input_fns(
        ["first", "second", "mmi"],
        [paired_input_fn, autoenc_input_fn, mmi_input_fn])
    test_input_fn = test_source.get_input_fn("test_in", train_field_map, 1,
                                             FLAGS.batch_size)

    print("Processing models..")
    print("Pretraining primary model..")
    model_group.train(train_input_fn,
                      first_model,
                      steps=FLAGS.pretrain_batches)
    print("Multitask training..")
    model_group.train(train_input_fn,
                      {"first": 1, "second": 1, "mmi": 0},
                      steps=FLAGS.train_batches)
    print("Training MMI model..")
    model_group.train(train_input_fn, mmi_model, steps=FLAGS.mmi_batches)
    print("Evaluating..")
    model_group.evaluate(test_input_fn, first_model)

    if FLAGS.interactive:
        print("Interactive decoding...")
        vocab = Vocabulary(fname=params["vocab_file"])
        decoding.cmd_decode(first_model,
                            vocab,
                            persona=True,
                            mmi_component=mmi_model)
Example 11
 def __init__(self, fname, fields, vocab=None):
     self.fname = fname
     self.parse_fields(fields)
     self.input_fns = dict()
     self.vocab = vocab if vocab is not None else Vocabulary()
Example 12
def main(_):
    '''
    This is a simple example of how to build an Icecaps training script, and is essentially
    the "Hello World" of Icecaps. Icecaps training scripts follow a basic five-phase pattern
    that we describe here. We train a basic model on the paired data stored in
    dummy_data/paired_personalized.tfrecord. For information on how to build TFRecords
    from text data files, please see data_processing_example.py.
    '''

    print("Loading hyperparameters..")
    # The first phase is to load hyperparameters from a .params file. These files follow a
    # simple colon-delimited format (e.g. see dummy_params/simple_example_seq2seq.params).
    params = util.load_params(FLAGS.params_file)
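    # A hypothetical illustration of that colon-delimited format (these keys
    # and values are invented; see the shipped dummy_params files for real ones):
    #
    #     use_default_params: True
    #     vocab_file: dummy_data/vocab.txt
    #     max_length: 30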

    print("Building model..")
    # Second, we build our architecture based on our loaded hyperparameters. Our architecture
    # here is very basic: we use a simple LSTM-based seq2seq model. For information on more
    # complex architectures, see train_persona_mmi_example.py.
    model_dir = FLAGS.model_dir
    if FLAGS.clean_model_dir:
        util.clean_model_dir(model_dir)
    model_cls = Seq2SeqEstimator

    # Every estimator expects a different set of hyperparameters. If you set use_default_params
    # to True in your .params file, the estimator will employ default values for any unspecified
    # hyperparameters. To view the list of hyperparameters with default values, you can run the
    # class method list_params(). E.g. you can open a Python session and run
    # Seq2SeqEstimator.list_params() to view what hyperparameters our seq2seq estimator expects.
    model = model_cls(model_dir, params)

    print("Getting sources..")
    # Third, we set up our data sources. DataSource objects allow you to build input_fns that
    # efficiently feed data into the training pipeline from TFRecord files. In our simple example,
    # we set up two data sources: one for training and one for testing.

    # TFRecords are created with named variables per data point. You must create a fields dictionary
    # to tell the DataSource which variables to load and what their types are.
    fields = {"train/inputs": "int", "train/targets": "int"}
    train_source = DataSource(FLAGS.train_file, fields)
    test_source = DataSource(FLAGS.test_file, fields)

    # Then, you must create a field_map dictionary to tell your estimator how to map the TFRecord's
    # variable names to the names expected by the estimator. While this may seem like unnecessary
    # overhead in this simple example, it provides useful flexibility in more complex scenarios.
    field_map = {"inputs": "train/inputs", "targets": "train/targets"}

    # Finally, build input_fns from your DataSources.
    # None lets our input_fn run for an unbounded number of epochs.
    train_input_fn = train_source.get_input_fn(
        "train_in", field_map, None, FLAGS.batch_size)
    # For testing, we only want to run the input_fn for one epoch instead.
    test_input_fn = test_source.get_input_fn(
        "test_in", field_map, 1, FLAGS.batch_size)

    print("Processing model..")
    # Fourth, we pipe our input_fns through our model for training and evaluation.
    model.train(train_input_fn, steps=FLAGS.train_batches)
    model.evaluate(test_input_fn)

    if FLAGS.interactive:
        print("Interactive decoding...")
        # Fifth, you may optionally set up an interactive session to test your system by directly
        # engaging with it.
        vocab = Vocabulary(fname=params["vocab_file"])
        decoding.cmd_decode(model, vocab)
Example 13
def main(_):
    '''
    This is a more complex example in which we build an Icecaps script involving
    component chaining and multi-task learning. We recommend you start with
    train_simple_example.py. In this example, we build a personalized conversation system
    that combines paired and unpaired data, and applies MMI during decoding.
    '''

    print("Loading parameters..")
    # When multiple estimators are involved, you can specify which hyperparameters in your
    # params file belong to which estimator using scoping. See dummy_params/persona_mmi_example.params
    # for an example. If no scope is specified, the hyperparameter is provided to all
    # models in your architecture.
    params = util.load_params(FLAGS.params_file)

    print("Building model..")
    model_dir = FLAGS.model_dir
    if FLAGS.clean_model_dir:
        util.clean_model_dir(model_dir)

    # For this system, we will need to build three different estimators.
    # The first estimator is a personalized seq2seq estimator that will be responsible for
    # learning the conversational model.
    first_model = PersonaSeq2SeqEstimator(model_dir, params, scope="first")

    # The second estimator is a personalized seq2seq estimator that shares its decoder with
    # the first model. This model will learn an autoencoder on an unpaired personalized
    # data set. The purpose of this configuration is to influence the first model with
    # stylistic information from the unpaired dataset.

    # To construct this second estimator, we first build a seq2seq encoder separate from
    # the first model. Then, we use an EstimatorChain to chain that encoder to the first
    # model's decoder, allowing the two models to share that decoder.
    second_model_encoder = Seq2SeqEncoderEstimator(model_dir,
                                                   params,
                                                   scope="second_encoder")
    second_model = EstimatorChain([second_model_encoder, first_model.decoder],
                                  model_dir,
                                  params,
                                  scope="second")

    # The third estimator is used for MMI decoding. This model will learn the inverse
    # function of the first model. During decoding, this estimator will be used to rerank
    # hypotheses generated by the first model during beam search decoding. While this
    # won't have much of an effect on our toy data sets, the purpose of this model in
    # real-world settings is to penalize generic responses applicable to many contexts
    # such as "I don't know."
    mmi_model = PersonaSeq2SeqEstimator(model_dir,
                                        params,
                                        scope="mmi",
                                        is_mmi_model=True)
    model_group = EstimatorGroup([first_model, second_model, mmi_model],
                                 model_dir,
                                 params,
                                 scope="group")

    print("Getting sources..")
    # We will use two DataSources for training and one for testing.
    fields = {
        "train/inputs": "int",
        "train/targets": "int",
        "train/speakers": "int"
    }
    paired_source = DataSource(FLAGS.paired_file, fields)
    unpaired_source = DataSource(FLAGS.unpaired_file, fields)
    test_source = DataSource(FLAGS.test_file, fields)

    # We construct three field maps.
    # The paired field map is similar to the field map shown in train_simple_example.py.
    # The unpaired field map maps train/inputs to both the estimator's inputs and targets,
    # in order to train an autoencoder.
    # The mmi field map maps train/inputs to targets and train/targets to inputs, in
    # order to learn the inverse of the first estimator.
    paired_field_map = {
        "inputs": "train/inputs",
        "targets": "train/targets",
        "speaker_ids": "train/speakers"
    }
    unpaired_field_map = {
        "inputs": "train/inputs",
        "targets": "train/inputs",
        "speaker_ids": "train/speakers"
    }
    mmi_field_map = {
        "inputs": "train/targets",
        "targets": "train/inputs",
        "speaker_ids": "train/speakers"
    }

    paired_input_fn = paired_source.get_input_fn("paired_in", paired_field_map,
                                                 None, FLAGS.batch_size)
    unpaired_input_fn = unpaired_source.get_input_fn("unpaired_in",
                                                     unpaired_field_map, None,
                                                     FLAGS.batch_size)
    mmi_input_fn = paired_source.get_input_fn("mmi_in", mmi_field_map, None,
                                              FLAGS.batch_size)
    # For multi-task learning, you will need to group your input_fns together with group_input_fns().
    train_input_fn = DataSource.group_input_fns(
        ["first", "second", "mmi"],
        [paired_input_fn, unpaired_input_fn, mmi_input_fn])
    test_input_fn = test_source.get_input_fn("test_in", paired_field_map, 1,
                                             FLAGS.batch_size)

    print("Processing models..")
    # Icecaps supports flexible multi-task training pipelines. You can set up multiple phases
    # where each phase trains your architecture with different weights across your objectives.
    # In this example, we will first pre-train the first model by itself, then jointly train
    # the first and second models, then finally train the MMI model by itself.
    print("Pretraining primary model..")
    model_group.train(train_input_fn,
                      first_model,
                      steps=FLAGS.pretrain_batches)
    print("Multitask training..")
    model_group.train(train_input_fn,
                      {"first": 1, "second": 1, "mmi": 0},
                      steps=FLAGS.train_batches)
    print("Training MMI model..")
    model_group.train(train_input_fn, mmi_model, steps=FLAGS.mmi_batches)
    print("Evaluating..")
    model_group.evaluate(test_input_fn, first_model)

    if FLAGS.interactive:
        print("Interactive decoding...")
        vocab = Vocabulary(fname=params["vocab_file"])
        # To decode with MMI, you can pass in your MMI model to cmd_decode().
        # lambda_balance represents how the first model and MMI model's scores are weighted during decoding.
        decoding.cmd_decode(first_model,
                            vocab,
                            persona=True,
                            mmi_component=mmi_model,
                            lambda_balance=FLAGS.lambda_balance)
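
Schematically, MMI reranking scores each beam hypothesis by combining the forward model's log-probability with the inverse model's, weighted by lambda_balance. A sketch of one common formulation (Li et al., 2016), not Icecaps' internal code:

    def mmi_rescore(hypotheses, lambda_balance):
        # hypotheses: (text, log p(target|source), log p(source|target)) triples,
        # the second log-probability coming from the inverse (MMI) model.
        return max(hypotheses, key=lambda h: h[1] + lambda_balance * h[2])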