def create_train_model(model_creator, hps, scope=None, extra_args=None):
    """Create train graph, model, and iterator."""

    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "train"):
        vocab_table = data.create_vocab_tables(hps.vocab_file, hps.vocab_size,
                                               hps.unk_id)
        train_dataset = data.get_dataset(hps.data_dir, hps.train_prefix)

        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
        iterator = data.get_iterator(train_dataset, vocab_table, hps)

        # Note: One can set model_device_fn to
        # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
        model_device_fn = None
        if extra_args: model_device_fn = extra_args.model_device_fn
        with tf.device(model_device_fn):
            model = model_creator(iterator=iterator,
                                  hps=hps,
                                  mode=tf.contrib.learn.ModeKeys.TRAIN,
                                  vocab_table=vocab_table,
                                  scope=scope)

    return TrainModel(graph=graph,
                      model=model,
                      iterator=iterator,
                      skip_count_placeholder=skip_count_placeholder)
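The TrainModel, EvalModel, and InferModel containers returned throughout these examples are plain namedtuples. A minimal sketch of how they might be declared, with field names taken from the constructor calls in this example (the actual definitions live elsewhere in each repo):

import collections

# Hypothetical declarations; fields mirror the keyword arguments used above.
TrainModel = collections.namedtuple(
    "TrainModel", ("graph", "model", "iterator", "skip_count_placeholder"))
EvalModel = collections.namedtuple(
    "EvalModel", ("graph", "model", "src_file_placeholder",
                  "tgt_file_placeholder", "iterator"))
InferModel = collections.namedtuple(
    "InferModel", ("graph", "model", "src_placeholder",
                   "batch_size_placeholder", "iterator"))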
Example #2
def create_eval_model(model_creator, hps, scope=None, extra_args=None):
    """Create eval graph, model, src/tgt file placeholders, and iterator."""

    graph = tf.Graph()
    with graph.as_default(), tf.container(scope or "eval"):
        vocab_table = data.create_vocab_tables(hps.vocab_file, hps.vocab_size,
                                               hps.unk_id)
        src_file_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        tgt_file_placeholder = tf.placeholder(shape=[None], dtype=tf.string)

        article_data_set = tf.data.TextLineDataset(src_file_placeholder)
        abstract_data_set = tf.data.TextLineDataset(tgt_file_placeholder)
        dataset = tf.data.Dataset.zip((article_data_set, abstract_data_set))

        iterator = data.get_iterator(dataset, vocab_table, hps)

        model = model_creator(iterator=iterator,
                              hps=hps,
                              mode=tf.contrib.learn.ModeKeys.EVAL,
                              vocab_table=vocab_table,
                              scope=scope)

    return EvalModel(graph=graph,
                     model=model,
                     src_file_placeholder=src_file_placeholder,
                     tgt_file_placeholder=tgt_file_placeholder,
                     iterator=iterator)
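A sketch of how an EvalModel like this is typically driven, assuming the iterator returned by data.get_iterator exposes an initializer and the model exposes a per-batch eval method (both assumptions; the file names are illustrative):

eval_model = create_eval_model(model_creator, hps)
with tf.Session(graph=eval_model.graph) as sess:
    sess.run(tf.tables_initializer())
    # Re-point the pipeline at concrete files via the placeholders.
    sess.run(eval_model.iterator.initializer,
             feed_dict={eval_model.src_file_placeholder: ["articles.txt"],
                        eval_model.tgt_file_placeholder: ["abstracts.txt"]})
    while True:  # consume the whole eval set
        try:
            eval_model.model.eval(sess)  # assumed per-batch eval method
        except tf.errors.OutOfRangeError:
            break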
Example #3
def create_infer_model(hparams):
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file
    graph = tf.Graph()
    with graph.as_default(), tf.container('infer'):
        src_vocab_table = lookup_ops.index_table_from_file(
            src_vocab_file, default_value=UNK_ID)
        tgt_vocab_table = lookup_ops.index_table_from_file(
            tgt_vocab_file, default_value=UNK_ID)
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=UNK)
        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

        src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)
        iterator = get_infer_iterator(src_dataset,
                                      src_vocab_table,
                                      batch_size_placeholder,
                                      EOS,
                                      src_max_len=hparams.src_max_len_infer)
        model = NMTModel(hparams, 'infer', iterator, src_vocab_table,
                         tgt_vocab_table, reverse_tgt_vocab_table)
        return InferModel(graph=graph,
                          model=model,
                          src_placeholder=src_placeholder,
                          batch_size_placeholder=batch_size_placeholder,
                          iterator=iterator)
Example #4
def create_infer_model(model_creator, hps, scope=None, sampling=None):
    """Create inference model."""
    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "infer"):
        vocab_table = data.create_vocab_tables(hps.vocab_file, hps.vocab_size,
                                               hps.unk_id)
        reverse_target_vocab_table = data.create_id_tables(
            hps.vocab_file, hps.vocab_size)

        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

        src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)

        iterator = data.get_infer_iterator(src_dataset, vocab_table,
                                           batch_size_placeholder,
                                           reverse_target_vocab_table, hps.eos,
                                           hps.src_max_len)

        model = model_creator(
            iterator=iterator,
            hps=hps,
            mode=tf.contrib.learn.ModeKeys.INFER,
            vocab_table=vocab_table,
            reverse_target_vocab_table=reverse_target_vocab_table,
            scope=scope)

    return InferModel(graph=graph,
                      model=model,
                      batch_size_placeholder=batch_size_placeholder,
                      src_placeholder=src_placeholder,
                      iterator=iterator)
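Inference feeds raw sentences through src_placeholder rather than reading files, so a run looks like the sketch below (the decode call is an assumed model method, and the iterator is assumed to expose an initializer):

infer_model = create_infer_model(model_creator, hps)
with tf.Session(graph=infer_model.graph) as sess:
    sess.run(tf.tables_initializer())
    sess.run(infer_model.iterator.initializer,
             feed_dict={infer_model.src_placeholder: ["first source sentence",
                                                      "second source sentence"],
                        infer_model.batch_size_placeholder: 2})
    outputs = infer_model.model.decode(sess)  # assumed decode API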
Example #5
def create_eval_model(hparams):
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file
    graph = tf.Graph()
    with graph.as_default(), tf.container('eval'):
        src_vocab_table = lookup_ops.index_table_from_file(
            src_vocab_file, default_value=UNK_ID)
        tgt_vocab_table = lookup_ops.index_table_from_file(
            tgt_vocab_file, default_value=UNK_ID)
        src_file_placeholder = tf.placeholder(shape=[], dtype=tf.string)
        tgt_file_placeholder = tf.placeholder(shape=[], dtype=tf.string)
        src_dataset = tf.data.TextLineDataset(src_file_placeholder)
        tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)
        iterator = get_iterator(src_dataset,
                                tgt_dataset,
                                src_vocab_table,
                                tgt_vocab_table,
                                hparams.batch_size,
                                SOS,
                                EOS,
                                src_max_len=hparams.src_max_len,
                                tgt_max_len=hparams.tgt_max_len)
        model = NMTModel(hparams, 'eval', iterator, src_vocab_table,
                         tgt_vocab_table)
        return EvalModel(graph=graph,
                         model=model,
                         src_file_placeholder=src_file_placeholder,
                         tgt_file_placeholder=tgt_file_placeholder,
                         iterator=iterator)
Example #6
def create_train_model(hparams):
    src_file = hparams.src_train_file
    tgt_file = hparams.tgt_train_file
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file
    graph = tf.Graph()
    with graph.as_default(), tf.container('train'):
        src_vocab_table = lookup_ops.index_table_from_file(
            src_vocab_file, default_value=UNK_ID)
        tgt_vocab_table = lookup_ops.index_table_from_file(
            tgt_vocab_file, default_value=UNK_ID)
        src_dataset = tf.data.TextLineDataset(src_file)
        tgt_dataset = tf.data.TextLineDataset(tgt_file)
        iterator = get_iterator(src_dataset,
                                tgt_dataset,
                                src_vocab_table,
                                tgt_vocab_table,
                                hparams.batch_size,
                                SOS,
                                EOS,
                                src_max_len=hparams.src_max_len,
                                tgt_max_len=hparams.tgt_max_len)
        model = NMTModel(hparams, 'train', iterator, src_vocab_table,
                         tgt_vocab_table)
        return TrainModel(graph=graph, model=model, iterator=iterator)
Example #7
    def testResetFails(self):
        # Creates variable with container name.
        with tf.container("test0"):
            v0 = tf.Variable(1.0, name="v0")
        # Creates variable with default container.
        v1 = tf.Variable(2.0, name="v1")
        # Verifies resetting the non-existent target returns error.
        with self.assertRaises(tf.errors.NotFoundError):
            tf.Session.reset("nonexistent", ["test0"])

        # Verifies resetting with config.
        # Verifies that resetting target with no server times out.
        with self.assertRaises(tf.errors.DeadlineExceededError):
            tf.Session.reset("grpc://localhost:0", ["test0"],
                             config=tf.ConfigProto(operation_timeout_in_ms=5))

        # Verifies no containers are reset with non-existent container.
        server = tf.train.Server.create_local_server()
        sess = tf.Session(server.target)
        sess.run(tf.global_variables_initializer())
        self.assertAllEqual(1.0, sess.run(v0))
        self.assertAllEqual(2.0, sess.run(v1))
        # No container is reset, but the server is reset.
        tf.Session.reset(server.target, ["test1"])
        # Verifies that both variables are still valid.
        sess = tf.Session(server.target)
        self.assertAllEqual(1.0, sess.run(v0))
        self.assertAllEqual(2.0, sess.run(v1))
Example #8
def create_eval_model(hparams, scope=None):
    """Create eval graph, model, eval file placeholder, and iterator."""
    vocab_file = hparams.vocab_file
    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "eval"):
        vocab_table = vocab.create_vocab_table(vocab_file)
        eval_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)

        eval_dataset = tf.data.TextLineDataset(eval_file_placeholder)
        iterator = taware_iterators.get_iterator(
            eval_dataset,
            vocab_table,
            hparams.batch_size,
            num_buckets=hparams.num_buckets,
            topic_words_per_utterance=hparams.topic_words_per_utterance,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len)
        model = TopicAwareSeq2SeqModel(mode=tf.contrib.learn.ModeKeys.EVAL,
                                       iterator=iterator,
                                       params=hparams,
                                       scope=scope,
                                       log_trainables=False)
    return EvalModel(graph=graph,
                     model=model,
                     eval_file_placeholder=eval_file_placeholder,
                     iterator=iterator)
Example #9
  def testResetFails(self):
    # Creates variable with container name.
    with tf.container("test0"):
      v0 = tf.Variable(1.0, name="v0")
    # Creates variable with default container.
    v1 = tf.Variable(2.0, name="v1")
    # Verifies resetting the non-existent target returns error.
    with self.assertRaises(tf.errors.NotFoundError):
      tf.Session.reset("nonexistent", ["test0"])

    # Verifies resetting with config.
    # Verifies that resetting target with no server times out.
    with self.assertRaises(tf.errors.DeadlineExceededError):
      tf.Session.reset(
          "grpc://localhost:0", ["test0"],
          config=tf.ConfigProto(operation_timeout_in_ms=5))

    # Verifies no containers are reset with non-existent container.
    server = tf.train.Server.create_local_server()
    sess = tf.Session(server.target)
    sess.run(tf.global_variables_initializer())
    self.assertAllEqual(1.0, sess.run(v0))
    self.assertAllEqual(2.0, sess.run(v1))
    # No container is reset, but the server is reset.
    tf.Session.reset(server.target, ["test1"])
    # Verifies that both variables are still valid.
    sess = tf.Session(server.target)
    self.assertAllEqual(1.0, sess.run(v0))
    self.assertAllEqual(2.0, sess.run(v1))
Example #10
def create_infer_model(model_class, hparams, scope=None):
    """Create inference model."""
    graph = tf.Graph()
    vocab_file = hparams.vocab_file

    with graph.as_default(), tf.container(scope or "infer"):
        vocab_table = vocab.create_vocab_table(vocab_file)
        reverse_vocab_table = vocab.create_rev_vocab_table(vocab_file)

        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

        src_dataset = tf.data.Dataset.from_tensor_slices(
            src_placeholder)
        iterator = taware_iterators.get_infer_iterator(
            src_dataset,
            vocab_table,
            batch_size=batch_size_placeholder,
            topic_words_per_utterance=hparams.topic_words_per_utterance,
            src_max_len=hparams.src_max_len)
        model = model_class(
            mode=tf.contrib.learn.ModeKeys.INFER,
            iterator=iterator,
            params=hparams,
            rev_vocab_table=reverse_vocab_table,
            scope=scope,
            log_trainables=False)
    return InferModel(
        graph=graph,
        model=model,
        src_placeholder=src_placeholder,
        batch_size_placeholder=batch_size_placeholder,
        iterator=iterator)
Example #11
def create_eval_model(model_creator, hparams, mode):
    graph = tf.Graph()
    with graph.as_default(), tf.container("eval"):
        # create a table to map words to vocab ids.
        input_vocab_table = vocab_utils.create_vocab_table(hparams.vocab_path)
        # define a placeholder for the input dataset.
        # we will dynamically initialize this placeholder with a file name during validation.
        # The reason for this is that during validation, we may want to evaluate our trained model on different datasets.
        input_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        input_dataset = tf.contrib.data.TextLineDataset(input_file_placeholder)
        output_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        output_dataset = tf.contrib.data.TextLineDataset(
            output_file_placeholder)

        iterator = iterator_utils.get_iterator(
            input_dataset,
            output_dataset,
            input_vocab_table,
            batch_size=hparams.eval_batch_size,
            random_seed=hparams.random_seed,
            pad=hparams.pad,
            input_max_len=hparams.input_max_len)
        model = model_creator(hparams,
                              mode,
                              iterator,
                              input_vocab_table=input_vocab_table,
                              reverse_input_vocab_table=None)
        return EvalModel(graph, model, input_file_placeholder,
                         output_file_placeholder, iterator)
Example #12
def build_val_model(log_file, ckpt_dir, scope='validation'):
    model_creator = _get_model_creator()
    graph = tf.Graph()
    with graph.as_default(), tf.container(scope):
        # src_file = "%s.%s" % (PARAM.val_prefix, PARAM.src)
        # tgt_file = "%s.%s" % (PARAM.val_prefix, PARAM.tgt)
        # src_file = misc_utils.add_rootdir(src_file)
        # tgt_file = misc_utils.add_rootdir(tgt_file)
        src_file_ph = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_ph = tf.placeholder(shape=(), dtype=tf.string)
        vocab_tables = vocab_utils.create_vocab_word2id_tables(
            log_file)  # word->id
        src_vocab_table, tgt_vocab_table, src_table_size, tgt_table_size = vocab_tables

        val_set = dataset_utils.get_batch_inputs_form_dataset(
            log_file,
            src_file_ph,
            tgt_file_ph,
            src_vocab_table,
            tgt_vocab_table,
            shuffle=False,
            bucket=False,
            filter_zero_seq=False,
        )

        val_model = model_creator(
            log_file=log_file,
            mode=PARAM.MODEL_VALIDATE_KEY,
            source_id_seq=val_set.source_id_seq,
            source_seq_lengths=val_set.source_seq_lengths,
            tgt_vocab_table=tgt_vocab_table,
            src_vocab_size=src_table_size,
            tgt_vocab_size=tgt_table_size,
            target_in_id_seq=val_set.target_in_id_seq,
            target_out_id_seq=val_set.target_out_id_seq,
            target_seq_lengths=val_set.target_seq_lengths,
        )
        # init = tf.group(tf.global_variables_initializer(),
        #                 tf.local_variables_initializer())
        config_proto = misc_utils.get_session_config_proto()
        val_sess = tf.Session(config=config_proto, graph=graph)
        # val_sess.run(init)
        val_sess.run(tf.tables_initializer())

        # restore model
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        if ckpt and ckpt.model_checkpoint_path:
            tf.logging.set_verbosity(tf.logging.WARN)
            val_model.saver.restore(val_sess, ckpt.model_checkpoint_path)
            tf.logging.set_verbosity(tf.logging.INFO)
        else:
            msg = 'Checkpoint not found. code:fweikgn2394jasdjf2'
            tf.logging.fatal(msg)
            misc_utils.printinfo(msg, log_file, noPrt=True)

    return BuildModelOutputs(session=val_sess,
                             graph=graph,
                             model=val_model,
                             dataset=val_set)
Example #13
def create_train_model(model_creator,
                       hparams,
                       scope=None,
                       num_workers=1,
                       jobid=0,
                       extra_args=None):
    """Create train graph, model, and iterator."""
    src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file

    graph = tf.Graph()  # Create the graph.

    with graph.as_default(), tf.container(scope or "train"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)

        src_dataset = tf.data.TextLineDataset(tf.gfile.Glob(src_file))
        tgt_dataset = tf.data.TextLineDataset(tf.gfile.Glob(tgt_file))
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        # Build the tf.data iterator.
        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            src_vocab_table,
            tgt_vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_placeholder,
            num_shards=num_workers,
            shard_index=jobid,
            use_char_encode=hparams.use_char_encode)

        # Note: One can set model_device_fn to
        # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
        model_device_fn = None
        if extra_args: model_device_fn = extra_args.model_device_fn
        with tf.device(model_device_fn):
            # Create the model.
            model = model_creator(hparams,
                                  iterator=iterator,
                                  mode=tf.contrib.learn.ModeKeys.TRAIN,
                                  source_vocab_table=src_vocab_table,
                                  target_vocab_table=tgt_vocab_table,
                                  scope=scope,
                                  extra_args=extra_args)

    # Return a TrainModel namedtuple holding the graph, model, iterator, and placeholder.
    return TrainModel(graph=graph,
                      model=model,
                      iterator=iterator,
                      skip_count_placeholder=skip_count_placeholder)
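As the comment above suggests, distributed training only needs a model_device_fn passed in through extra_args. A sketch, assuming extra_args is a simple namedtuple with a model_device_fn field (the real extra_args type in each repo may carry more fields, and a real cluster would also need a cluster spec):

import collections

ExtraArgs = collections.namedtuple("ExtraArgs", ("model_device_fn",))

# Variables go to parameter servers; ops stay on this worker.
device_fn = tf.train.replica_device_setter(
    ps_tasks=2, worker_device="/job:worker/task:0")
train_model = create_train_model(
    model_creator, hparams, num_workers=4, jobid=0,
    extra_args=ExtraArgs(model_device_fn=device_fn))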
Example #14
  def _LoopEnqueue(self, op):
    """Runs the enqueue op in a loop."""
    with tf.container(self._container_id), self._GetSession() as sess:
      if self.initialize_tables is not None:
        sess.run(self.initialize_tables)
      gsteps = self._model.global_step
      local_enqueue_steps = 0

      # Avoid calling trial.ShouldStop too often as it can slow down the
      # infeed queue by adding latency. last_should_stop_check_time tracks
      # the last time we made the call, and rate limits below.
      last_should_stop_check_time = 0

      # Global enqueue steps measures how many global steps have data enqueued
      # for already. We use this to terminate; note that the enqueue op may
      # hang in session.run if we do not terminate with this check.
      global_enqueue_steps = None

      # Each session run to the tpu trainer makes tpu_steps_per_loop. We need
      # to continue enqueueing beyond the max train steps since the tpu_steps
      # in the loop may exceed the max train steps. adjust_steps makes an
      # appropriate adjustment.
      adjust_steps = (
          self.params.train.tpu_steps_per_loop if py_utils.use_tpu() else 0)

      tf.logging.info('params.train.max_steps: %d, enqueue_max_steps: %d',
                      self.params.train.max_steps, FLAGS.enqueue_max_steps)
      while True:
        global_step, = sess.run([gsteps])
        if global_enqueue_steps is None:
          global_enqueue_steps = global_step
        if local_enqueue_steps % 1000 == 0:
          tf.logging.info(
              'Current global_enqueue_steps: %d, '
              'local_enqueue_steps: %d, global_step: %d', global_enqueue_steps,
              local_enqueue_steps, global_step)

        # Check trial.ShouldStop only every 10 seconds
        trial_should_stop = False
        if time.time() > last_should_stop_check_time + 10:
          trial_should_stop = self._trial.ShouldStop()
          last_should_stop_check_time = time.time()

        if (trial_should_stop or
            self._ShouldStop(sess, global_enqueue_steps - adjust_steps) or
            self._ShouldStop(sess, global_step)):
          tf.logging.info('Done. Params.train.max_steps reached.')
          return
        if (FLAGS.enqueue_max_steps > 0 and
            local_enqueue_steps > FLAGS.enqueue_max_steps):
          tf.logging.info('Done. FLAGS.enqueue_max_steps reached.')
          return
        local_enqueue_steps += 1

        # There are tpu_infeed_parallism parallel threads enqueuing.
        # We account for all of them when updating global_enqueue_steps.
        global_enqueue_steps += self.params.input.tpu_infeed_parallism

        sess.run([op])
Example #15
def build_train_model(hparams, scope="train"):
    """Builds a training Seq2Seq model
        Args:
            hparams: a HParams object
            scope: scope of train model

        Returns:
            model: a NTModel tuple, representing a handle to the model
    """
    src_lang = hparams.src_lang
    src_vocab_file_name = hparams.src_vocab_file_name
    tgt_lang = hparams.tgt_lang
    tgt_vocab_file_name = hparams.tgt_vocab_file_name

    tf.reset_default_graph()

    train_graph = tf.Graph()
    with train_graph.as_default() as g:
        with tf.container(scope):
            src_vocab, tgt_vocab = load_vocabs(src_lang, src_vocab_file_name,
                                               tgt_lang, tgt_vocab_file_name)
            src_dataset_file_name = tf.placeholder(
                tf.string, name="src_dataset_file_name")
            tgt_dataset_file_name = tf.placeholder(
                tf.string, name="tgt_dataset_file_name")

            src_dataset = tf.data.TextLineDataset(src_dataset_file_name)
            tgt_dataset = tf.data.TextLineDataset(tgt_dataset_file_name)

            batch_size = tf.placeholder(tf.int64, name="batch_size")

            # maximum sequence length for training example
            max_len = tf.placeholder(tf.int64, name="max_len")

            iterator = Iterator(src_dataset,
                                src_vocab,
                                tgt_dataset,
                                tgt_vocab,
                                batch_size=batch_size,
                                max_len=max_len)

            # actual TensorFlow Dataset Iterator
            iterator_tf = iterator.create_iterator()

            model_class = _get_model_from_str_type(hparams.model_name)

            model = model_class(hparams, src_vocab, tgt_vocab)

            model_graph = model.build_graph(iterator_tf,
                                            tf.contrib.learn.ModeKeys.TRAIN,
                                            batch_size, g)

            return NTModel(src_vocab=src_vocab,
                           tgt_vocab=tgt_vocab,
                           iterator_tf=iterator_tf,
                           model_graph=model_graph,
                           model=model,
                           hparams=hparams,
                           mode=tf.contrib.learn.ModeKeys.TRAIN)
Example #16
  def testContainer(self):
    with tf.Graph().as_default():
      v0 = tf.Variable([0])
      with tf.container("l1"):
        v1 = tf.Variable([1])
        with tf.container("l2"):
          v2 = tf.Variable([2])
          special_v = state_ops.variable_op([1], tf.float32, container="l3")
        v3 = tf.Variable([3])
      v4 = tf.Variable([4])
    self.assertEqual(tf.compat.as_bytes(""), v0.op.get_attr("container"))
    self.assertEqual(tf.compat.as_bytes("l1"), v1.op.get_attr("container"))
    self.assertEqual(tf.compat.as_bytes("l2"), v2.op.get_attr("container"))
    self.assertEqual(tf.compat.as_bytes("l3"),
                     special_v.op.get_attr("container"))
    self.assertEqual(tf.compat.as_bytes("l1"), v3.op.get_attr("container"))
    self.assertEqual(tf.compat.as_bytes(""), v4.op.get_attr("container"))
Example #17
def create_train_model(model_creator, hparams, input_path, target_path, mode):
    graph = tf.Graph()
    with graph.as_default(), tf.container("train"):
        # create iterator over the train batches.
        iterator = get_dataset_iterator(hparams, input_path, target_path, hparams.batch_size)
        # create the actual model (the tf.graph).
        model = model_creator(hparams, mode, iterator)
        return TrainModel(graph, model, iterator)
Example #18
def _train(args):
    container_name = ""

    R = lambda: nav_env.get_multiplexer_class(args.navtask, args.solver.task)
    m = utils.Foo()
    m.tf_graph = tf.Graph()

    config = tf.ConfigProto()
    config.device_count['GPU'] = 1

    with m.tf_graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(args.solver.ps_tasks,
                                               merge_devices=True)):
            with tf.container(container_name):
                m = args.setup_to_run(m,
                                      args,
                                      is_training=True,
                                      batch_norm_is_training=True,
                                      summary_mode='train')

                train_step_kwargs = args.setup_train_step_kwargs(
                    m,
                    R(),
                    os.path.join(args.logdir, 'train'),
                    rng_seed=args.solver.task,
                    is_chief=args.solver.task == 0,
                    num_steps=args.navtask.task_params.num_steps *
                    args.navtask.task_params.num_goals,
                    iters=1,
                    train_display_interval=args.summary.display_interval,
                    dagger_sample_bn_false=args.arch.dagger_sample_bn_false)

                delay_start = (
                    args.solver.task *
                    (args.solver.task + 1)) / 2 * FLAGS.delay_start_iters
                logging.error('delaying start for task %d by %d steps.',
                              args.solver.task, delay_start)

                additional_args = {}
                final_loss = slim.learning.train(
                    train_op=m.train_op,
                    logdir=args.logdir,
                    master=args.solver.master,
                    is_chief=args.solver.task == 0,
                    number_of_steps=args.solver.max_steps,
                    train_step_fn=tf_utils.train_step_custom_online_sampling,
                    train_step_kwargs=train_step_kwargs,
                    global_step=m.global_step_op,
                    init_op=m.init_op,
                    init_fn=m.init_fn,
                    sync_optimizer=m.sync_optimizer,
                    saver=m.saver_op,
                    startup_delay_steps=delay_start,
                    summary_op=None,
                    session_config=config,
                    **additional_args)
Example #19
  def testContainer(self):
    with tf.Graph().as_default():
      v0 = tf.Variable([0])
      with tf.container("l1"):
        v1 = tf.Variable([1])
        with tf.container("l2"):
          v2 = tf.Variable([2])
          special_v = gen_state_ops._variable(
              shape=[1], dtype=tf.float32, name="VariableInL3",
              container="l3", shared_name="")
        v3 = tf.Variable([3])
      v4 = tf.Variable([4])
    self.assertEqual(tf.compat.as_bytes(""), v0.op.get_attr("container"))
    self.assertEqual(tf.compat.as_bytes("l1"), v1.op.get_attr("container"))
    self.assertEqual(tf.compat.as_bytes("l2"), v2.op.get_attr("container"))
    self.assertEqual(tf.compat.as_bytes("l3"),
                     special_v.op.get_attr("container"))
    self.assertEqual(tf.compat.as_bytes("l1"), v3.op.get_attr("container"))
    self.assertEqual(tf.compat.as_bytes(""), v4.op.get_attr("container"))
Example #21
def create_selfplay_model(model_creator,
                          is_mutable,
                          num_workers,
                          jobid,
                          hparams,
                          scope=None,
                          extra_args=None):
    """create slef play models."""
    graph = tf.Graph()
    with graph.as_default(), tf.container(scope or "selfplay"):
        vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
        reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
            hparams.vocab_file, default_value=vocab_utils.UNK)

        if is_mutable:
            mutable_index = 0
        else:
            mutable_index = 1

        # get a list of iterators and placeholders
        iterators, placeholders = self_play_iterator_creator(
            hparams, num_workers, jobid)
        train_iterator, self_play_fulltext_iterator, self_play_structured_iterator = iterators
        data_placeholder, kb_placeholder, batch_size_placeholder, skip_count_placeholder = placeholders

        # get an iterator handler
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, train_iterator.output_types, train_iterator.output_shapes)
        batched_iterator = iterator_utils.get_batched_iterator(iterator)

        model_device_fn = None
        if extra_args:
            model_device_fn = extra_args.model_device_fn
        with tf.device(model_device_fn):
            model = model_creator(hparams,
                                  iterator=batched_iterator,
                                  handle=handle,
                                  mode=[
                                      dialogue_utils.mode_self_play_mutable,
                                      dialogue_utils.mode_self_play_immutable
                                  ][mutable_index],
                                  vocab_table=vocab_table,
                                  reverse_vocab_table=reverse_vocab_table,
                                  scope=scope,
                                  extra_args=extra_args)
    return SelfplayModel(graph=graph,
                         model=model,
                         placeholder_iterator=iterator,
                         placeholder_handle=handle,
                         train_iterator=train_iterator,
                         self_play_ft_iterator=self_play_fulltext_iterator,
                         self_play_st_iterator=self_play_structured_iterator,
                         data_placeholder=data_placeholder,
                         kb_placeholder=kb_placeholder,
                         skip_count_placeholder=skip_count_placeholder,
                         batch_size_placeholder=batch_size_placeholder)
Example #22
def create_train_model(model_creator,
                       hparams,
                       scope=None,
                       num_workers=1,
                       jobid=0,
                       extra_args=None):
    """Create graph, model and iterator for training."""
    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "train"):
        vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
        data_dataset = tf.data.TextLineDataset(hparams.train_data)
        kb_dataset = tf.data.TextLineDataset(hparams.train_kb)
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
        reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
            hparams.vocab_file, default_value=vocab_utils.UNK)
        # this is the actual train_iterator
        train_iterator = iterator_utils.get_iterator(
            data_dataset,
            kb_dataset,
            vocab_table,
            batch_size=hparams.batch_size,
            t1=hparams.t1,
            t2=hparams.t2,
            eod=hparams.eod,
            len_action=hparams.len_action,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            max_dialogue_len=hparams.max_dialogue_len,
            skip_count=skip_count_placeholder,
            num_shards=num_workers,
            shard_index=jobid)

        # This is the placeholder iterator. One can use it to switch between
        # training and evaluation.
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, train_iterator.output_types, train_iterator.output_shapes)
        batched_iterator = iterator_utils.get_batched_iterator(iterator)
        model_device_fn = None
        if extra_args:
            model_device_fn = extra_args.model_device_fn
        with tf.device(model_device_fn):
            model = model_creator(hparams,
                                  iterator=batched_iterator,
                                  handle=handle,
                                  mode=tf.contrib.learn.ModeKeys.TRAIN,
                                  vocab_table=vocab_table,
                                  scope=scope,
                                  extra_args=extra_args,
                                  reverse_vocab_table=reverse_vocab_table)
    return TrainModel(graph=graph,
                      model=model,
                      placeholder_iterator=iterator,
                      train_iterator=train_iterator,
                      placeholder_handle=handle,
                      skip_count_placeholder=skip_count_placeholder)
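A sketch of the handle indirection at session time: materialize a handle string for the concrete train iterator, then feed it to placeholder_handle so batched_iterator pulls from the training data. This assumes the train iterator exposes initializer and string_handle, and the training op name on the model is also an assumption:

with tf.Session(graph=train_model.graph) as sess:
    sess.run(tf.tables_initializer())
    sess.run(train_model.train_iterator.initializer,
             feed_dict={train_model.skip_count_placeholder: 0})
    # A handle string identifies the concrete iterator inside this session;
    # feeding a different handle would switch the same graph to eval data.
    train_handle = sess.run(train_model.train_iterator.string_handle())
    sess.run(train_model.model.train_op,  # assumed name of the training op
             feed_dict={train_model.placeholder_handle: train_handle})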
Example #23
def create_train_model(
        model_creator, hparams, scope=None, num_workers=1, jobid=0,
        extra_args=None):
    """Create train graph, model, and iterator."""
    src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file

    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "train"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)

        src_dataset = tf.data.TextLineDataset(src_file)
        tgt_dataset = tf.data.TextLineDataset(tgt_file)
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            src_vocab_table,
            tgt_vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_placeholder,
            num_shards=num_workers,
            shard_index=jobid)

        # Note: One can set model_device_fn to
        # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
        model_device_fn = None
        if extra_args:
            model_device_fn = extra_args.model_device_fn
        with tf.device(model_device_fn):
            model = model_creator(
                hparams,
                iterator=iterator,
                mode=tf.contrib.learn.ModeKeys.TRAIN,
                source_vocab_table=src_vocab_table,
                target_vocab_table=tgt_vocab_table,
                scope=scope,
                extra_args=extra_args)

    return TrainModel(
        graph=graph,
        model=model,
        iterator=iterator,
        skip_count_placeholder=skip_count_placeholder,
        insert_op=(src_vocab_table.init, tgt_vocab_table.init))
Example #24
def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
    """Create inference model."""
    graph = tf.Graph()
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file

    with graph.as_default(), tf.container(scope or "infer"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)
        reverse_src_vocab_table = lookup_ops.index_to_string_table_from_file(
            src_vocab_file, default_value=vocab_utils.UNK)
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=vocab_utils.UNK)

        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        tgt_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

        src_dataset = tf.data.Dataset.from_tensor_slices(
            src_placeholder)
        tgt_dataset = tf.data.Dataset.from_tensor_slices(
            tgt_placeholder)
        # Build the concrete infer iterators from the placeholder-backed datasets.
        iterator_src = iterator_utils.get_infer_iterator(
            src_dataset,
            src_vocab_table,
            batch_size=batch_size_placeholder,
            eos=hparams.eos,
            src_max_len=hparams.src_max_len_infer)
        iterator_tgt = iterator_utils.get_infer_iterator(
            tgt_dataset,
            tgt_vocab_table,
            batch_size=batch_size_placeholder,
            eos=hparams.eos,
            src_max_len=hparams.tgt_max_len_infer)
        model = model_creator(
            hparams,
            iterator_s2s=iterator_src,
            iterator_s2t=iterator_src,
            iterator_t2t=iterator_tgt,
            iterator_t2s=iterator_tgt,
            mode=tf.contrib.learn.ModeKeys.INFER,
            source_vocab_table=src_vocab_table,
            target_vocab_table=tgt_vocab_table,
            reverse_source_vocab_table=reverse_src_vocab_table,
            reverse_target_vocab_table=reverse_tgt_vocab_table,
            scope=scope,
            extra_args=extra_args)
    return InferModel(
        graph=graph,
        model=model,
        src_placeholder=src_placeholder,
        tgt_placeholder=tgt_placeholder,
        batch_size_placeholder=batch_size_placeholder,
        iterator_src=iterator_src,
        iterator_tgt=iterator_tgt)
Example #25
def create_unified_model(model_creator, hparams, scope=None):
  unified_graph = tf.Graph()

  with unified_graph.as_default(), tf.container(scope or tf.estimator.ModeKeys.TRAIN):
    unified_model = model_creator(hparams, mode=tf.estimator.ModeKeys.TRAIN, scope=scope)

  return TrainModel(
      graph=unified_graph,
      model=unified_model,
      placeholders=unified_model.placeholders)
Example #26
    def testContainer(self):
        with tf.Graph().as_default():
            v0 = tf.Variable([0])
            with tf.container("l1"):
                v1 = tf.Variable([1])
                with tf.container("l2"):
                    v2 = tf.Variable([2])
                    special_v = state_ops.variable_op([1],
                                                      tf.float32,
                                                      container="l3")
                v3 = tf.Variable([3])
            v4 = tf.Variable([4])
        self.assertEqual(tf.compat.as_bytes(""), v0.op.get_attr("container"))
        self.assertEqual(tf.compat.as_bytes("l1"), v1.op.get_attr("container"))
        self.assertEqual(tf.compat.as_bytes("l2"), v2.op.get_attr("container"))
        self.assertEqual(tf.compat.as_bytes("l3"),
                         special_v.op.get_attr("container"))
        self.assertEqual(tf.compat.as_bytes("l1"), v3.op.get_attr("container"))
        self.assertEqual(tf.compat.as_bytes(""), v4.op.get_attr("container"))
Example #27
def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
    """Create inference model."""
    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "infer"):
        vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
        reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
            hparams.vocab_file, default_value=vocab_utils.UNK)

        data_src_placeholder = tf.placeholder(shape=[None],
                                              dtype=tf.string,
                                              name="src_ph")
        kb_placeholder = tf.placeholder(shape=[None],
                                        dtype=tf.string,
                                        name="kb_ph")
        batch_size_placeholder = tf.placeholder(shape=[],
                                                dtype=tf.int64,
                                                name="bs_ph")

        data_src_dataset = tf.data.Dataset.from_tensor_slices(
            data_src_placeholder)
        kb_dataset = tf.data.Dataset.from_tensor_slices(kb_placeholder)

        # this is the actual infer iterator
        infer_iterator = iterator_utils.get_infer_iterator(
            data_src_dataset,
            kb_dataset,
            vocab_table,
            batch_size=batch_size_placeholder,
            eod=hparams.eod,
            len_action=hparams.len_action)

        # this is the placeholder infer iterator
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, infer_iterator.output_types, infer_iterator.output_shapes)
        batched_iterator = iterator_utils.get_batched_iterator(iterator)

        model = model_creator(hparams,
                              iterator=batched_iterator,
                              handle=handle,
                              mode=tf.contrib.learn.ModeKeys.INFER,
                              vocab_table=vocab_table,
                              reverse_vocab_table=reverse_vocab_table,
                              scope=scope,
                              extra_args=extra_args)

    return InferModel(graph=graph,
                      model=model,
                      placeholder_iterator=iterator,
                      placeholder_handle=handle,
                      infer_iterator=infer_iterator,
                      data_src_placeholder=data_src_placeholder,
                      kb_placeholder=kb_placeholder,
                      batch_size_placeholder=batch_size_placeholder)
Example #28
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
    """Create train graph, model, src/tgt file holders, and iterator."""
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file
    lbl_vocab_file = hparams.lbl_vocab_file
    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "eval"):
        src_vocab_table, tgt_vocab_table, lbl_vocab_table = \
          vocab_utils.create_vocab_tables(src_vocab_file, tgt_vocab_file,
                                          lbl_vocab_file, hparams.share_vocab)
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=vocab_utils.UNK)
        reverse_lbl_vocab_table = lookup_ops.index_to_string_table_from_file(
            lbl_vocab_file, default_value=vocab_utils.UNK)

        src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        lbl_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        src_dataset = tf.data.TextLineDataset(src_file_placeholder)
        tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)
        lbl_dataset = tf.data.TextLineDataset(lbl_file_placeholder)

        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            lbl_dataset,
            src_vocab_table,
            tgt_vocab_table,
            lbl_vocab_table,
            hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len_infer,
            tgt_max_len=hparams.tgt_max_len_infer)
        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.EVAL,
            source_vocab_table=src_vocab_table,
            target_vocab_table=tgt_vocab_table,
            label_vocab_table=lbl_vocab_table,
            reverse_target_vocab_table=reverse_tgt_vocab_table,
            reverse_target_intent_vocab_table=reverse_lbl_vocab_table,
            scope=scope,
            extra_args=extra_args)
    return EvalModel(graph=graph,
                     model=model,
                     src_file_placeholder=src_file_placeholder,
                     tgt_file_placeholder=tgt_file_placeholder,
                     lbl_file_placeholder=lbl_file_placeholder,
                     iterator=iterator)
Example #29
def _train(args):
    agent = navi_env.Environment('5cf0e1e9493994e483e985c436b9d3bc', args.navi)
    Z = utils.Foo()
    Z.tf_graph = tf.Graph()

    config = tf.ConfigProto()
    config.device_count['GPU'] = 1

    with Z.tf_graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(args.solver.ps_tasks,
                                               merge_devices=True)):
            with tf.container("planner"):
                Z = args.setup_to_run(Z,
                                      args,
                                      is_training=True,
                                      batch_norm_is_training=True,
                                      summary_mode='train')
                train_step_kwargs = args.setup_train_step_kwargs(
                    Z,
                    agent,
                    os.path.join(args.logdir, 'train'),
                    rng_seed=args.solver.rng_seed,
                    is_chief=args.solver.rng_seed == 0,
                    num_steps=args.navi.num_steps * args.navi.num_goals,
                    iters=1,
                    train_display_interval=args.summary.display_interval,
                    dagger_sample_bn_false=args.solver.dagger_sample_bn_false)

                delay_start = (args.solver.task *
                               (args.solver.task + 1)) / 2 * DELAY_START_ITERS
                logging.info('delaying start for task %d by %d steps.',
                             args.solver.task, delay_start)

                additional_args = {}
                final_loss = slim.learning.train(
                    train_op=Z.train_op,
                    logdir=args.logdir,
                    is_chief=args.solver.task == 0,
                    number_of_steps=args.solver.max_steps,
                    train_step_fn=tf_utils.train_step_fn,
                    train_step_kwargs=train_step_kwargs,
                    master=args.solver.master,
                    global_step=Z.global_step_op,
                    init_op=Z.init_op,
                    init_fn=Z.init_fn,
                    sync_optimizer=Z.sync_optimizer,
                    saver=Z.saver_op,
                    save_summaries_secs=5000,
                    save_interval_secs=5000,
                    startup_delay_steps=delay_start,
                    summary_op=None,
                    session_config=config,
                    **additional_args)
Example #30
def create_train_model(
    model_creator, hparams, scope=None, num_workers=1,
    jobid=0, extra_args=None):
    src_file = '{}.{}'.format(hparams.train_prefix, hparams.src)
    tgt_file = '{}.{}'.format(hparams.train_prefix, hparams.tgt)
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file

    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or 'train'):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)

        src_dataset = tf.data.TextLineDataset(src_file)
        tgt_dataset = tf.data.TextLineDataset(tgt_file)
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            src_vocab_table,
            tgt_vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            source_reverse=hparams.source_reverse,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_placeholder,
            num_shards=num_workers,
            shard_index=jobid)

        model_device_fn = None
        if extra_args:
            model_device_fn = extra_args.model_device_fn
        with tf.device(model_device_fn):
            model = model_creator(
                hparams,
                iterator=iterator,
                mode=tf.contrib.learn.ModeKeys.TRAIN,
                source_vocab_table=src_vocab_table,
                target_vocab_table=tgt_vocab_table,
                scope=scope,
                extra_args=extra_args)

    return TrainModel(
        graph=graph,
        model=model,
        iterator=iterator,
        skip_count_placeholder=skip_count_placeholder)
Example #31
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
    """Create train graph, model, src/tgt file holders, and iterator."""
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file
    speaker_file = hparams.speaker_file

    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "eval"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)
        src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        src_dataset = tf.data.TextLineDataset(src_file_placeholder)
        tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)

        spkr_table, _ = vocab_utils.create_vocab_tables(speaker_file, "", True)

        src_spk_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_spk_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        src_spkr_dataset = tf.data.TextLineDataset(src_spk_file_placeholder)
        tgt_spkr_dataset = tf.data.TextLineDataset(tgt_spk_file_placeholder)

        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            src_vocab_table,
            tgt_vocab_table,
            src_spkr_dataset,
            tgt_spkr_dataset,
            spkr_table,
            hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len_infer,
            tgt_max_len=hparams.tgt_max_len_infer)
        model = model_creator(hparams,
                              iterator=iterator,
                              mode=tf.contrib.learn.ModeKeys.EVAL,
                              source_vocab_table=src_vocab_table,
                              target_vocab_table=tgt_vocab_table,
                              speaker_table=spkr_table,
                              scope=scope,
                              extra_args=extra_args)
    return EvalModel(graph=graph,
                     model=model,
                     src_file_placeholder=src_file_placeholder,
                     tgt_file_placeholder=tgt_file_placeholder,
                     src_spk_file_placeholder=src_spk_file_placeholder,
                     tgt_spk_file_placeholder=tgt_spk_file_placeholder,
                     iterator=iterator)
Example #32
def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
  """Create inference model."""
  # Create the infer model graph
  infer_graph = tf.Graph()

  with infer_graph.as_default(), tf.container(scope or tf.estimator.ModeKeys.PREDICT):
    infer_model = model_creator(hparams, mode=tf.estimator.ModeKeys.PREDICT, scope=scope)

  return InferModel(
      graph=infer_graph,
      model=infer_model,
      placeholders=infer_model.placeholders)
Example #33
  def _Loop(self):
    with tf.container(
        self._container_id), self._GetSession(inline=False) as sess:
      # This initializes local tables
      sess.run(self.initialize_tables)

      path = None
      while True:
        path = self._FindNewCheckpoint(path, sess)
        if not path or self.DecodeCheckpoint(sess, path):
          break
      tf.logging.info('Decoding finished.')
Example #34
def create_train_model(model_creator, hparams):

    graph = tf.Graph()
    with graph.as_default(), tf.container("train"):
        model = model_creator(
            hparams,
            tf.contrib.learn.ModeKeys.TRAIN,
        )
        return TrainModel(
            graph=graph,
            model=model,
        )
Example #35
  def testMultipleContainers(self):
    with tf.container("test0"):
      v0 = tf.Variable(1.0, name="v0")
    with tf.container("test1"):
      v1 = tf.Variable(2.0, name="v0")
    server = tf.train.Server.create_local_server()
    sess = tf.Session(server.target)
    sess.run(tf.global_variables_initializer())
    self.assertAllEqual(1.0, sess.run(v0))
    self.assertAllEqual(2.0, sess.run(v1))

    # Resets container. Session aborts.
    tf.Session.reset(server.target, ["test0"])
    with self.assertRaises(tf.errors.AbortedError):
      sess.run(v1)

    # Connects to the same target. Device memory for the v0 would have
    # been released, so it will be uninitialized. But v1 should still
    # be valid.
    sess = tf.Session(server.target)
    with self.assertRaises(tf.errors.FailedPreconditionError):
      sess.run(v0)
    self.assertAllEqual(2.0, sess.run(v1))
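After a reset, the variables that lived in the reset container lose their backing memory while others survive, so recovery means reconnecting and re-running just the lost initializers. A short sketch continuing the scenario above:

# Reconnect, then re-create only the state that lived in the reset container.
sess = tf.Session(server.target)
sess.run(v0.initializer)        # v0 was in "test0", which was reset
assert sess.run(v0) == 1.0
assert sess.run(v1) == 2.0      # v1 in "test1" was untouched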
Example #36
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
    """Create train graph, model, src/tgt file holders, and iterator."""
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file
    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "eval"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)
        src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        src_dataset = tf.data.TextLineDataset(src_file_placeholder)
        tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)
        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            src_vocab_table,
            tgt_vocab_table,
            hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len_infer,
            tgt_max_len=hparams.tgt_max_len_infer)
        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.EVAL,
            source_vocab_table=src_vocab_table,
            target_vocab_table=tgt_vocab_table,
            scope=scope,
            extra_args=extra_args)
    return EvalModel(
        graph=graph,
        model=model,
        src_file_placeholder=src_file_placeholder,
        tgt_file_placeholder=tgt_file_placeholder,
        iterator=iterator,
        insert_op=(src_vocab_table.init, tgt_vocab_table.init))
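Because this variant threads the vocab-table initializers out through `insert_op`, an eval driver must run them together with the iterator initializer before pulling batches. A sketch under those assumptions (the file paths are illustrative, and `eval_loss` is assumed to be exposed by the model, as in the NMT codebase):

eval_model = create_eval_model(model_creator, hparams)
with tf.Session(graph=eval_model.graph) as sess:
    sess.run(eval_model.insert_op)  # initialize the src/tgt vocab tables
    # In practice variables are restored from a checkpoint here, e.g.
    # eval_model.model.saver.restore(sess, ckpt_path).
    sess.run(eval_model.iterator.initializer,
             feed_dict={eval_model.src_file_placeholder: "dev.src",
                        eval_model.tgt_file_placeholder: "dev.tgt"})
    while True:
        try:
            loss = sess.run(eval_model.model.eval_loss)
        except tf.errors.OutOfRangeError:
            break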
Example No. 37
def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
  """Create inference model."""
  graph = tf.Graph()
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file

  with graph.as_default(), tf.container(scope or "infer"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value=vocab_utils.UNK)

    src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

    src_dataset = tf.data.Dataset.from_tensor_slices(
        src_placeholder)
    iterator = iterator_utils.get_infer_iterator(
        src_dataset,
        src_vocab_table,
        batch_size=batch_size_placeholder,
        eos=hparams.eos,
        src_max_len=hparams.src_max_len_infer)
    model = model_creator(
        hparams,
        iterator=iterator,
        mode=tf.contrib.learn.ModeKeys.INFER,
        source_vocab_table=src_vocab_table,
        target_vocab_table=tgt_vocab_table,
        reverse_target_vocab_table=reverse_tgt_vocab_table,
        scope=scope,
        extra_args=extra_args)
  return InferModel(
      graph=graph,
      model=model,
      src_placeholder=src_placeholder,
      batch_size_placeholder=batch_size_placeholder,
      iterator=iterator)
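The matching inference driver feeds the raw sentences and the batch size when initializing the iterator; `model_creator`, `hparams`, and the NMT-style `decode(sess)` call are assumptions here:

infer_model = create_infer_model(model_creator, hparams)
sentences = ["a test sentence ."]  # illustrative input
with tf.Session(graph=infer_model.graph) as sess:
    sess.run(tf.tables_initializer())
    # A trained model would be restored from a checkpoint at this point.
    sess.run(infer_model.iterator.initializer,
             feed_dict={infer_model.src_placeholder: sentences,
                        infer_model.batch_size_placeholder: len(sentences)})
    nmt_outputs, _ = infer_model.model.decode(sess)  # assumed NMT-style API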
Example No. 38
def evaluate_model(hparams, data, train_dir, log, id_to_word,
                   data_ngram_counts):
  """Evaluate MaskGAN model.

  Args:
    hparams:  Hyperparameters for the MaskGAN.
    data: Data to evaluate.
    train_dir: Path to a directory containing checkpoints.
    log: Readable log for the experiment.
    id_to_word: Dictionary of indices to words.
    data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the
      data_set.
  """
  tf.logging.info('Evaluate model.')

  # Boolean indicating operational mode.
  is_training = False

  if FLAGS.mode == MODE_VALIDATION:
    logdir = FLAGS.base_directory + '/validation'
  elif FLAGS.mode == MODE_TRAIN_EVAL:
    logdir = FLAGS.base_directory + '/train_eval'
  elif FLAGS.mode == MODE_TEST:
    logdir = FLAGS.base_directory + '/test'
  else:
    raise NotImplementedError

  # Wait for a checkpoint to exist.
  print(train_dir)
  print(tf.train.latest_checkpoint(train_dir))
  while not tf.train.latest_checkpoint(train_dir):
    tf.logging.info('Waiting for checkpoint...')
    print('Waiting for checkpoint...')
    time.sleep(10)

  with tf.Graph().as_default():
    # A separate container could be used for each trial; here the default
    # (empty-named) container is used.
    container_name = ''
    with tf.container(container_name):

      # Construct the model.
      if FLAGS.num_rollouts == 1:
        model = create_MaskGAN(hparams, is_training)
      elif FLAGS.num_rollouts > 1:
        model = rollout.create_rollout_MaskGAN(hparams, is_training)
      else:
        raise ValueError

      # Create the supervisor.  It will take care of initialization, summaries,
      # checkpoints, and recovery.  We only pass the trainable variables
      # to load since things like baselines keep batch_size which may not
      # match between training and evaluation.
      evaluation_variables = tf.trainable_variables()
      evaluation_variables.append(model.global_step)
      eval_saver = tf.train.Saver(var_list=evaluation_variables)
      sv = tf.train.Supervisor(logdir=logdir)
      sess = sv.PrepareSession(FLAGS.eval_master, start_standard_services=False)

      tf.logging.info('Before sv.Loop.')
      sv.Loop(FLAGS.eval_interval_secs, evaluate_once,
              (data, sv, model, sess, train_dir, log, id_to_word,
               data_ngram_counts, eval_saver))

      sv.WaitForStop()
      tf.logging.info('sv.Stop().')
      sv.Stop()
Example No. 39
def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
    """Create inference model."""
    graph = tf.Graph()
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file

    # REvo added: materialize the target vocab as Python lists so the
    # reverse lookup table can be built from in-graph constants below.
    tgt_table = codecs.open(tgt_vocab_file, 'r').readlines()
    tmp_ids = []
    tmp_words = []
    for i, line in enumerate(tgt_table):
        tmp_ids.append(i)
        tmp_words.append(line.strip())

    with graph.as_default(), tf.container(scope or "infer"):
        # Constant vocab table
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)

        # reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        #     tgt_vocab_file, default_value=vocab_utils.UNK, name="reverse_table")
        # added: constant reverse table (id -> word) built from the lists above
        vals = tf.constant(tmp_words, dtype=tf.string)
        keys = tf.constant(tmp_ids, dtype=tf.int64)
        reverse_tgt_vocab_table = lookup_ops.HashTable(
            lookup_ops.KeyValueTensorInitializer(keys, vals),
            "<unk>", name="reverse_table")

        # debug
        print("SRC:", src_vocab_table)
        print("SRC type:", type(src_vocab_table))
        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string, name="src_place")
        # batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64, name="batch_place")
        batch_size_placeholder = tf.constant(1, dtype=tf.int64, name="batch_place")

        src_dataset = tf.data.Dataset.from_tensor_slices(
            src_placeholder)
        iterator = iterator_utils.get_infer_iterator(
            src_dataset,
            src_vocab_table,
            batch_size=batch_size_placeholder,
            eos=hparams.eos,
            src_max_len=hparams.src_max_len_infer)
        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.INFER,
            source_vocab_table=src_vocab_table,
            target_vocab_table=tgt_vocab_table,
            reverse_target_vocab_table=reverse_tgt_vocab_table,
            scope=scope,
            extra_args=extra_args)

        # Debug
        # with tf.Session() as sess:
        #     # init
        #     sess.run(
        #         iterator.initializer,
        #         feed_dict={
        #             src_placeholder: iterator.infer_data,
        #             batch_size_placeholder: 64
        #         })
        #     value = sess.run(iterator.source)
        #     print ("value:", value)
        # sys.exit()

    return InferModel(
        graph=graph,
        model=model,
        src_placeholder=src_placeholder,
        batch_size_placeholder=batch_size_placeholder,
        iterator=iterator,
        insert_op=(src_vocab_table.init, tgt_vocab_table.init, reverse_tgt_vocab_table.init))
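The constant KeyValueTensorInitializer table used above behaves like the file-backed lookup once its init op has run; a standalone check with a toy vocab:

import tensorflow as tf
from tensorflow.python.ops import lookup_ops

keys = tf.constant([0, 1, 2], dtype=tf.int64)
vals = tf.constant(['<unk>', 'hello', 'world'], dtype=tf.string)
table = lookup_ops.HashTable(
    lookup_ops.KeyValueTensorInitializer(keys, vals), '<unk>')
words = table.lookup(tf.constant([1, 2, 7], dtype=tf.int64))  # 7 is out of vocab
with tf.Session() as sess:
    sess.run(table.init)
    print(sess.run(words))  # [b'hello' b'world' b'<unk>']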
Example No. 40
  def testContainer(self):
    with tf.Graph().as_default():
      with tf.container("test"):
        q = tf.FIFOQueue(10, tf.float32)
    self.assertEqual(tf.compat.as_bytes("test"),
                     q.queue_ref.op.get_attr("container"))
Example No. 41
except KeyError as e:
    logging.error("Key '{k}' not in dictionary".format(k=e.message))
    raise

if args.start:
    import tensorflow as tf
    import numpy as np
    import time

    server = tf.train.Server(cluster, job_name=job, task_index=task)
    logging.info("Server Target is '{st}'".format(st=server.target))

    g = tf.Graph()

    with g.as_default():
        with tf.container('shared'):
            queue_in = tf.FIFOQueue(10, [tf.int32],
                name='queue_in',
                shared_name='master_queue_in')

            queue_out = tf.FIFOQueue(10, [tf.string],
                name='queue_out',
                shared_name='master_queue_out')

            tmp = tf.Variable(-1, dtype=tf.float32, name='master_tmp')

        do_deq = queue_in.dequeue()
        do_enq = queue_out.enqueue("Hello World")

    with tf.Session(server.target, graph=g) as S:
        S.run(tf.global_variables_initializer())  # also initializes master_tmp
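Because the queues above live in the 'shared' container with explicit shared_name values, a second process pointed at the same server can build matching handles and communicate through them. A sketch of such a peer; the target address is illustrative:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    with tf.container('shared'):
        queue_in = tf.FIFOQueue(10, [tf.int32],
            name='queue_in', shared_name='master_queue_in')
        queue_out = tf.FIFOQueue(10, [tf.string],
            name='queue_out', shared_name='master_queue_out')
    do_enq = queue_in.enqueue(42)  # hand the master a work item
    do_deq = queue_out.dequeue()   # read back its reply

with tf.Session('grpc://localhost:2222', graph=g) as S:
    S.run(do_enq)
    print(S.run(do_deq))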
Example No. 42
def run_training(config=None, tuner=None, logdir=None, trial_name=None,
                 is_chief=True):
  """Do all training runs.

  This is the top level training function for policy gradient based models.
  Run this from the main function.

  Args:
    config: config_lib.Config instance containing global config (agent and
        environment hparams). If None, config will be parsed from FLAGS.config.
    tuner: A tuner instance. Leave as None if not tuning.
    logdir: Parent directory where all data from all runs will be written. If
        None, FLAGS.logdir will be used.
    trial_name: If tuning, set this to a unique string that identifies this
        trial. If `tuner` is not None, this also must be set.
    is_chief: True if this worker is the chief.

  Returns:
    List of results dicts which were written to disk. Each training run gets a
    results dict. Results dict contains metrics, i.e. (name, value) pairs which
    give information about the training run.

  Raises:
    ValueError: If results dicts read from disk contain invalid data.
  """
  if not config:
    # If custom config is not given, get it from flags.
    config = defaults.default_config_with_updates(FLAGS.config)
  if not logdir:
    logdir = FLAGS.logdir
  if not tf.gfile.Exists(logdir):
    tf.gfile.MakeDirs(logdir)
  assert FLAGS.num_repetitions > 0
  results = results_lib.Results(logdir)
  results_list, _ = results.read_all()

  logging.info('Starting experiment. Directory: "%s"', logdir)

  if results_list:
    if results_list[0]['max_npe'] != FLAGS.max_npe:
      raise ValueError(
          'Cannot resume training. Max-NPE changed. Was %s, now %s' %
          (results_list[0]['max_npe'], FLAGS.max_npe))
    if results_list[0]['max_global_repetitions'] != FLAGS.num_repetitions:
      raise ValueError(
          'Cannot resume training. Number of repetitions changed. Was %s, '
          'now %s' % (results_list[0]['max_global_repetitions'],
                      FLAGS.num_repetitions))

  while len(results_list) < FLAGS.num_repetitions:
    run_number = len(results_list)
    rep_container_name = trial_name if trial_name else 'container'
    if FLAGS.num_repetitions > 1:
      rep_dir = os.path.join(logdir, 'run_%d' % run_number)
      rep_container_name = rep_container_name + '_run_' + str(run_number)
    else:
      rep_dir = logdir

    logging.info(
        'Starting repetition %d (%d out of %d)', run_number, run_number + 1,
        FLAGS.num_repetitions)

    # Train will write result to disk.
    with tf.container(rep_container_name):
      trainer = train(config, is_chief, tuner, rep_dir, run_number, results)
    logging.info('Done training.')

    if is_chief:
      # Destroy current container immediately (clears current graph).
      logging.info('Clearing shared variables.')
      tf.Session.reset(FLAGS.master, containers=[rep_container_name])
      logging.info('Shared variables cleared.')

      # Delete replay buffer on disk.
      assert trainer
      trainer.delete_replay_buffer()
    else:
      # Give chief worker time to clean up.
      sleep_sec = 30.0
      logging.info('Sleeping for %s sec.', sleep_sec)
      time.sleep(sleep_sec)
    tf.reset_default_graph()
    logging.info('Default graph reset.')

    # Expecting that train wrote new result to disk before returning.
    results_list, _ = results.read_all()
  return results_list
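Stripped of the experiment bookkeeping, the per-repetition container pattern above reduces to a small standalone sketch: build each run's state inside its own named container, then reset exactly that container on the server to release its variables.

import tensorflow as tf

server = tf.train.Server.create_local_server()
for run_number in range(2):
    name = 'container_run_%d' % run_number
    with tf.container(name):
        v = tf.Variable(run_number, name='run_state')
    with tf.Session(server.target) as sess:
        sess.run(v.initializer)
        print(sess.run(v))
    # Destroy only this run's server-side state, then drop the graph.
    tf.Session.reset(server.target, containers=[name])
    tf.reset_default_graph()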
Example No. 43
def train_model(hparams, data, log_dir, log, id_to_word, data_ngram_counts):
  """Train model.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    data: Data to evaluate.
    log_dir: Directory to save checkpoints.
    log: Readable log for the experiment.
    id_to_word: Dictionary of indices to words.
    data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the
      data_set.
  """
  print('Training model.')
  tf.logging.info('Training model.')

  # Boolean indicating operational mode.
  is_training = True

  # Write all the information to the logs.
  log.write('hparams\n')
  log.write(str(hparams))
  log.flush()

  is_chief = FLAGS.task == 0

  with tf.Graph().as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      container_name = ''
      with tf.container(container_name):
        # Construct the model.
        if FLAGS.num_rollouts == 1:
          model = create_MaskGAN(hparams, is_training)
        elif FLAGS.num_rollouts > 1:
          model = rollout.create_rollout_MaskGAN(hparams, is_training)
        else:
          raise ValueError

        print('\nTrainable Variables in Graph:')
        for v in tf.trainable_variables():
          print(v)

        ## Retrieve the initial savers.
        init_savers = model_utils.retrieve_init_savers(hparams)

        ## Initial saver function to supervisor.
        init_fn = partial(model_utils.init_fn, init_savers)

        # Create the supervisor.  It will take care of initialization,
        # summaries, checkpoints, and recovery.
        sv = tf.train.Supervisor(
            logdir=log_dir,
            is_chief=is_chief,
            saver=model.saver,
            global_step=model.global_step,
            save_model_secs=60,
            recovery_wait_secs=30,
            summary_op=None,
            init_fn=init_fn)

        # Get an initialized, and possibly recovered session.  Launch the
        # services: Checkpointing, Summaries, step counting.
        #
        # When multiple replicas of this program are running the services are
        # only launched by the 'chief' replica.
        with sv.managed_session(FLAGS.master) as sess:

          ## Pretrain the generator.
          if FLAGS.gen_pretrain_steps:
            pretrain_mask_gan.pretrain_generator(sv, sess, model, data, log,
                                                 id_to_word, data_ngram_counts,
                                                 is_chief)

          ## Pretrain the discriminator.
          if FLAGS.dis_pretrain_steps:
            pretrain_mask_gan.pretrain_discriminator(
                sv, sess, model, data, log, id_to_word, data_ngram_counts,
                is_chief)

          # Initial indicators for printing and summarizing.
          print_step_division = -1
          summary_step_division = -1

          # Run iterative computation in a loop.
          while not sv.ShouldStop():
            is_present_rate = FLAGS.is_present_rate

            if FLAGS.is_present_rate_decay is not None:
              is_present_rate *= (1. - FLAGS.is_present_rate_decay)

            model_utils.assign_percent_real(sess, model.percent_real_update,
                                            model.new_rate, is_present_rate)

            # GAN training.
            avg_epoch_gen_loss, avg_epoch_dis_loss = [], []
            cumulative_costs = 0.
            gen_iters = 0

            # Generator and Discriminator statefulness initial evaluation.
            # TODO(liamfedus): Throughout the code I am implicitly assuming
            # that the Generator and Discriminator are equal sized.
            [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
                [model.eval_initial_state, model.fake_gen_initial_state])
            dis_initial_state_eval = fake_gen_initial_state_eval

            # Save zeros state to reset later.
            zeros_state = fake_gen_initial_state_eval

            ## Offset Discriminator.
            if FLAGS.ps_tasks == 0:
              dis_offset = 1
            else:
              dis_offset = FLAGS.task * 1000 + 1
            dis_iterator = get_iterator(data)

            for i in range(dis_offset):
              try:
                dis_x, dis_y, _ = next(dis_iterator)
              except StopIteration:
                dis_iterator = get_iterator(data)
                dis_initial_state_eval = zeros_state
                dis_x, dis_y, _ = next(dis_iterator)

              p = model_utils.generate_mask()

              # Construct the train feed.
              train_feed = {
                  model.inputs: dis_x,
                  model.targets: dis_y,
                  model.present: p
              }

              if FLAGS.data_set == 'ptb':
                # Statefulness of the Generator being used for Discriminator.
                for i, (c, h) in enumerate(model.fake_gen_initial_state):
                  train_feed[c] = dis_initial_state_eval[i].c
                  train_feed[h] = dis_initial_state_eval[i].h

                # Determine the state had the Generator run over real data.  We
                # use this state for the Discriminator.
                [dis_initial_state_eval] = sess.run(
                    [model.fake_gen_final_state], train_feed)

            ## Training loop.
            iterator = get_iterator(data)
            gen_initial_state_eval = zeros_state

            if FLAGS.ps_tasks > 0:
              gen_offset = FLAGS.task * 1000 + 1
              for i in range(gen_offset):
                try:
                  next(iterator)
                except StopIteration:
                  dis_iterator = get_iterator(data)
                  dis_initial_state_eval = zeros_state
                  next(dis_iterator)

            for x, y, _ in iterator:
              for _ in range(hparams.dis_train_iterations):
                try:
                  dis_x, dis_y, _ = next(dis_iterator)
                except StopIteration:
                  dis_iterator = get_iterator(data)
                  dis_initial_state_eval = zeros_state
                  dis_x, dis_y, _ = next(dis_iterator)

                  if FLAGS.data_set == 'ptb':
                    [dis_initial_state_eval] = sess.run(
                        [model.fake_gen_initial_state])

                p = model_utils.generate_mask()

                # Construct the train feed.
                train_feed = {
                    model.inputs: dis_x,
                    model.targets: dis_y,
                    model.present: p
                }

                # Statefulness for the Discriminator.
                if FLAGS.data_set == 'ptb':
                  for i, (c, h) in enumerate(model.fake_gen_initial_state):
                    train_feed[c] = dis_initial_state_eval[i].c
                    train_feed[h] = dis_initial_state_eval[i].h

                _, dis_loss_eval, step = sess.run(
                    [model.dis_train_op, model.dis_loss, model.global_step],
                    feed_dict=train_feed)

                # Determine the state had the Generator run over real data.
                # Use this state for the Discriminator.
                [dis_initial_state_eval] = sess.run(
                    [model.fake_gen_final_state], train_feed)

              # Randomly mask out tokens.
              p = model_utils.generate_mask()

              # Construct the train feed.
              train_feed = {model.inputs: x, model.targets: y, model.present: p}

              # Statefulness for Generator.
              if FLAGS.data_set == 'ptb':
                tf.logging.info('Generator is stateful.')
                print('Generator is stateful.')
                # Statefulness for *evaluation* Generator.
                for i, (c, h) in enumerate(model.eval_initial_state):
                  train_feed[c] = gen_initial_state_eval[i].c
                  train_feed[h] = gen_initial_state_eval[i].h

                # Statefulness for Generator.
                for i, (c, h) in enumerate(model.fake_gen_initial_state):
                  train_feed[c] = fake_gen_initial_state_eval[i].c
                  train_feed[h] = fake_gen_initial_state_eval[i].h

              # Determine whether to decay learning rate.
              lr_decay = hparams.gen_learning_rate_decay**max(
                  step + 1 - hparams.gen_full_learning_rate_steps, 0.0)

              # Assign learning rate.
              gen_learning_rate = hparams.gen_learning_rate * lr_decay
              model_utils.assign_learning_rate(sess, model.learning_rate_update,
                                               model.new_learning_rate,
                                               gen_learning_rate)

              [_, gen_loss_eval, gen_log_perplexity_eval, step] = sess.run(
                  [
                      model.gen_train_op, model.gen_loss,
                      model.avg_log_perplexity, model.global_step
                  ],
                  feed_dict=train_feed)

              cumulative_costs += gen_log_perplexity_eval
              gen_iters += 1

              # Determine the state had the Generator run over real data.
              [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
                  [model.eval_final_state,
                   model.fake_gen_final_state], train_feed)

              avg_epoch_dis_loss.append(dis_loss_eval)
              avg_epoch_gen_loss.append(gen_loss_eval)

              ## Summaries.
              # Calculate rolling perplexity.
              perplexity = np.exp(cumulative_costs / gen_iters)

              if is_chief and (step // FLAGS.summaries_every >
                               summary_step_division):
                summary_step_division = step // FLAGS.summaries_every

                # Confirm perplexity is not infinite.
                if (not np.isfinite(perplexity) or
                    perplexity >= FLAGS.perplexity_threshold):
                  print('Training raising FloatingPointError.')
                  raise FloatingPointError(
                      'Training infinite perplexity: %.3f' % perplexity)

                # Graph summaries.
                summary_str = sess.run(
                    model.merge_summaries_op, feed_dict=train_feed)
                sv.SummaryComputed(sess, summary_str)

                # Summary:  n-gram
                avg_percent_captured = {'2': 0., '3': 0., '4': 0.}
                for n, data_ngram_count in data_ngram_counts.items():
                  batch_percent_captured = evaluation_utils.sequence_ngram_evaluation(
                      sess, model.fake_sequence, log, train_feed,
                      data_ngram_count, int(n))
                  summary_percent_str = tf.Summary(value=[
                      tf.Summary.Value(
                          tag='general/%s-grams_percent_correct' % n,
                          simple_value=batch_percent_captured)
                  ])
                  sv.SummaryComputed(
                      sess, summary_percent_str, global_step=step)

                # Summary:  geometric_avg
                geometric_avg = compute_geometric_average(avg_percent_captured)
                summary_geometric_avg_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/geometric_avg', simple_value=geometric_avg)
                ])
                sv.SummaryComputed(
                    sess, summary_geometric_avg_str, global_step=step)

                # Summary:  arithmetic_avg
                arithmetic_avg = compute_arithmetic_average(
                    avg_percent_captured)
                summary_arithmetic_avg_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/arithmetic_avg',
                        simple_value=arithmetic_avg)
                ])
                sv.SummaryComputed(
                    sess, summary_arithmetic_avg_str, global_step=step)

                # Summary:  perplexity
                summary_perplexity_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/perplexity', simple_value=perplexity)
                ])
                sv.SummaryComputed(
                    sess, summary_perplexity_str, global_step=step)

              ## Printing and logging
              if is_chief and (step // FLAGS.print_every > print_step_division):
                print_step_division = step // FLAGS.print_every
                print('global_step: %d' % step)
                print(' perplexity: %.3f' % perplexity)
                print(' gen_learning_rate: %.6f' % gen_learning_rate)
                log.write('global_step: %d\n' % step)
                log.write(' perplexity: %.3f\n' % perplexity)
                log.write(' gen_learning_rate: %.6f\n' % gen_learning_rate)

                # Average percent captured for each of the n-grams.
                avg_percent_captured = {'2': 0., '3': 0., '4': 0.}
                for n, data_ngram_count in data_ngram_counts.items():
                  batch_percent_captured = evaluation_utils.sequence_ngram_evaluation(
                      sess, model.fake_sequence, log, train_feed,
                      data_ngram_count, int(n))
                  avg_percent_captured[n] = batch_percent_captured
                  print(' percent of %s-grams captured: %.3f.' %
                        (n, batch_percent_captured))
                  log.write(' percent of %s-grams captured: %.3f.\n' %
                            (n, batch_percent_captured))
                geometric_avg = compute_geometric_average(avg_percent_captured)
                print(' geometric_avg: %.3f.' % geometric_avg)
                log.write(' geometric_avg: %.3f.\n' % geometric_avg)
                arithmetic_avg = compute_arithmetic_average(
                    avg_percent_captured)
                print(' arithmetic_avg: %.3f.' % arithmetic_avg)
                log.write(' arithmetic_avg: %.3f.\n' % arithmetic_avg)

                evaluation_utils.print_and_log_losses(
                    log, step, is_present_rate, avg_epoch_dis_loss,
                    avg_epoch_gen_loss)

                if FLAGS.gen_training_strategy == 'reinforce':
                  evaluation_utils.generate_RL_logs(sess, model, log,
                                                    id_to_word, train_feed)
                else:
                  evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                                 train_feed)
                log.flush()

  log.close()
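For reference, the Supervisor wiring in train_model follows the standard TF1 pattern; stripped of the GAN specifics it reduces to roughly the sketch below (the logdir and the toy objective are illustrative):

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
weights = tf.Variable([1.0, 2.0], name='weights')
loss = tf.reduce_sum(tf.square(weights))  # toy objective
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
    loss, global_step=global_step)

# The Supervisor owns initialization, checkpointing, and recovery.
sv = tf.train.Supervisor(logdir='/tmp/sv_demo', save_model_secs=60,
                         summary_op=None)
with sv.managed_session('') as sess:  # '' means in-process master
    while not sv.should_stop():
        _, step = sess.run([train_op, global_step])
        if step >= 10:
            sv.request_stop()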