def build_model(self, model_fn, params):
        """Build the TPU model and infeed enqueue ops."""
        tf.logging.info("EvalLowLevelRunner: build_model method")

        def tpu_eval_step():
            """Generate the TPU graph."""
            values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
            unflattened_inputs = data_nest.pack_sequence_as(
                self.feature_structure, values)
            features = unflattened_inputs["features"]
            estimator_spec = model_fn(features, None,
                                      tf.estimator.ModeKeys.PREDICT, params)
            for k, v in six.iteritems(estimator_spec.predictions):
                self.outfeed_names.append(k)
                self.outfeed_tensors.append(v)

            with tf.device(utils.device_for_tpu_core(self._get_host(0))):
                outfeed_enqueue_ops = tpu.outfeed_enqueue_tuple(
                    self.outfeed_tensors)
            with tf.control_dependencies([outfeed_enqueue_ops]):
                return tf.no_op()

        def eval_loop():
            return tpu.repeat(self.eval_steps, tpu_eval_step, [])

        def create_dequeue_ops():
            """Create outfeed dequeue ops."""
            dequeue_ops = []
            tensor_dtypes = []
            tensor_shapes = []
            for v in self.outfeed_tensors:
                dequeue_ops.append([])
                tensor_dtypes.append(v.dtype)
                tensor_shapes.append(v.shape)
            for i in range(FLAGS.num_shards):
                with tf.device(utils.device_for_host(self._get_host(0))):
                    outfeed_tensors = tpu.outfeed_dequeue_tuple(
                        dtypes=tensor_dtypes,
                        shapes=tensor_shapes,
                        device_ordinal=i)
                    for j, item in enumerate(outfeed_tensors):
                        dequeue_ops[j].append(item)
            for j in range(len(outfeed_tensors)):
                dequeue_ops[j] = tf.concat(dequeue_ops[j], axis=0)
            return dequeue_ops

        with self.graph.as_default():
            (self.eval_op, ) = tpu.shard(
                eval_loop,
                inputs=[],
                num_shards=FLAGS.num_shards,
                outputs_from_all_shards=False,
            )

            for i, dequeue_tensor in enumerate(create_dequeue_ops()):
                self.dequeue_ops[self.outfeed_names[i]] = dequeue_tensor

            self.saver = tf.train.Saver()
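
A minimal sketch of how the ops built above might be driven, assuming a runner object that exposes the attributes created by this build_model (eval_op, dequeue_ops, eval_steps) plus a session (sess) and infeed enqueue_ops; those driver details are assumptions, not part of the example.

import threading

def run_eval(runner):
    """Hypothetical driver: run the sharded eval loop and drain the outfeed."""
    # Feed the device and run the on-device loop in background threads so the
    # host can dequeue outfeed results while the loop is still executing.
    infeed_thread = threading.Thread(
        target=lambda: runner.sess.run([runner.enqueue_ops]))
    eval_thread = threading.Thread(
        target=lambda: runner.sess.run([runner.eval_op]))
    infeed_thread.start()
    eval_thread.start()

    # Each session.run of dequeue_ops pops one step's worth of predictions,
    # already concatenated across shards by create_dequeue_ops.
    predictions = {name: [] for name in runner.dequeue_ops}
    for _ in range(runner.eval_steps):
        for name, value in runner.sess.run(runner.dequeue_ops).items():
            predictions[name].append(value)

    infeed_thread.join()
    eval_thread.join()
    return predictions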
Example #2
    def build_model(self, model_fn, params):
        """Build the TPU model and infeed enqueue ops."""
        tf.logging.info("TrainLowLevelRunner: build_model method")

        def tpu_train_step(loss):
            """Generate the TPU graph."""
            del loss
            values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
            unflattened_inputs = data_nest.pack_sequence_as(
                self.feature_structure, values)
            features = unflattened_inputs["features"]
            labels = unflattened_inputs["labels"]
            estimator_spec = model_fn(features, labels,
                                      tf.estimator.ModeKeys.TRAIN, params)
            loss, train_op = estimator_spec.loss, estimator_spec.train_op
            with tf.control_dependencies([train_op]):
                return tf.identity(loss)

        def train_loop():
            return tpu.repeat(self.iterations, tpu_train_step, [_INITIAL_LOSS])

        with self.graph.as_default():
            (self.loss, ) = tpu.shard(
                train_loop,
                inputs=[],
                num_shards=self.hparams.num_shards,
                outputs_from_all_shards=False,
            )
            global_initializer = tf.global_variables_initializer()
            local_initializer = tf.local_variables_initializer()
            graph_io.write_graph(self.graph.as_graph_def(add_shapes=True),
                                 self.hparams.out_dir, "graph.pbtxt")
            self.saver = tf.train.Saver()

        self.sess.run(global_initializer)
        self.sess.run(local_initializer)

        checkpoint_path = tf.train.latest_checkpoint(self.hparams.out_dir)
        if checkpoint_path:
            self.saver.restore(self.sess, checkpoint_path)

        with self.graph.as_default():
            for hook in self.hooks:
                hook.after_create_session(self.sess, None)
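
A sketch of a training driver that could sit on top of the graph built above; the _INITIAL_LOSS value, the checkpoint path, and the chunking by runner.iterations are assumptions rather than part of the example.

import os
import tensorflow as tf  # TF 1.x, as assumed throughout these examples

# Hypothetical value for the loss seed carried through tpu.repeat; the real
# constant is defined elsewhere in the runner module.
_INITIAL_LOSS = 1e7

def train(runner, total_steps):
    """Hypothetical driver assuming build_model() above has already run."""
    for step in range(0, total_steps, runner.iterations):
        # One session.run executes runner.iterations steps on the device and
        # returns the final loss produced by tpu.repeat.
        loss = runner.sess.run(runner.loss)
        tf.logging.info("step %d: loss = %f", step + runner.iterations, loss)
    runner.saver.save(runner.sess,
                      os.path.join(runner.hparams.out_dir, "model.ckpt"))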
Example #3
def build_tpu_loop(infeed_queue, params, batch_count, embedding, mode):
    """Build op to run loops on TPU."""
    if mode == tpu_embedding.TRAINING:

        def tpu_step_fn(labels):
            """Create one step in training."""
            logits = logits_fn(embedding, params)

            if FLAGS.lazy_adam:
                optimizer = tf.contrib.opt.LazyAdamOptimizer(
                    learning_rate=params["learning_rate"],
                    beta1=params["beta1"],
                    beta2=params["beta2"],
                    epsilon=params["epsilon"])
            else:
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=params["learning_rate"],
                    beta1=params["beta1"],
                    beta2=params["beta2"],
                    epsilon=params["epsilon"])
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

            # Softmax with the first column of ones is equivalent to sigmoid.
            softmax_logits = tf.concat(
                [tf.ones(logits.shape, dtype=logits.dtype), logits], axis=1)

            loss = tf.losses.sparse_softmax_cross_entropy(
                labels=labels, logits=softmax_logits)

            minimize_op = optimizer.minimize(loss)
            with tf.control_dependencies([minimize_op]):
                send_gradient_op = embedding.generate_send_gradients_op()

            return send_gradient_op

        def tpu_loop_fn():
            return tpu.repeat(batch_count,
                              tpu_step_fn,
                              infeed_queue=infeed_queue)

        tpu_loop = tpu.shard(tpu_loop_fn, num_shards=embedding.num_cores)
        return tpu_loop
    else:

        def tpu_step_fn(total, count, duplicate_mask):
            """One step in evaluation."""
            logits = logits_fn(embedding, params)
            in_top_k, _, metric_weights, _ = neumf_model.compute_top_k_and_ndcg(
                logits, duplicate_mask, FLAGS.ml_perf)
            metric_weights = tf.cast(metric_weights, tf.float32)
            total += tf.reduce_sum(tf.multiply(in_top_k, metric_weights))
            count += tf.reduce_sum(metric_weights)
            return total, count

        inputs = [tf.constant(0.), tf.constant(0.)]

        def tpu_loop_fn():
            return tpu.repeat(batch_count,
                              tpu_step_fn,
                              inputs,
                              infeed_queue=infeed_queue)

        tpu_loop = tpu.shard(tpu_loop_fn, num_shards=embedding.num_cores)
        return tpu_loop
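
The softmax_logits construction in the training branch above leans on the identity softmax([0, x])[1] = sigmoid(x); prepending a column of ones instead yields sigmoid(x - 1), a constant shift the model's bias can absorb. A quick numerical check of both facts, independent of the TPU code:

import numpy as np

def softmax(v):
    e = np.exp(v - np.max(v))
    return e / e.sum()

sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))

x = 0.7
assert np.isclose(softmax([0.0, x])[1], sigmoid(x))        # column of zeros
assert np.isclose(softmax([1.0, x])[1], sigmoid(x - 1.0))  # column of ones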
Example #4
    def initialize(self, input_fn, model_fn, params):
        """Build graph and do initialization for training."""
        tf.logging.info("TrainLowLevelRunner: initialize method")

        for i in range(self.num_hosts):
            self.build_enqueue_ops(input_fn, params, host_id=i)

        def infeed_thread_fn():
            """Build and infeed session.run calls in a background thread."""
            # Build infeed sesssion
            self.input_sess = tf.Session(
                self.tpu_cluster_resolver.get_master(),
                graph=self.input_graph,
                config=self.session_config)
            # Initialize dataset variables
            self.input_sess.run(self.dataset_initializer)
            # Run infeed session.run calls
            while True:
                iterations = self.queue.get(block=True)
                if iterations == _STOP:
                    return
                tf.logging.info("Start to infeed %d batches", iterations)
                self.input_sess.run([self.enqueue_ops])

        def tpu_train_step(loss):
            """Generate the TPU graph."""
            del loss
            values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
            unflattened_inputs = data_nest.pack_sequence_as(
                self.feature_structure, values)
            features = unflattened_inputs["features"]
            labels = unflattened_inputs["labels"]
            estimator_spec = model_fn(features, labels,
                                      tf.estimator.ModeKeys.TRAIN, params)
            loss, train_op = estimator_spec.loss, estimator_spec.train_op
            self.scaffold_fn = estimator_spec.scaffold_fn
            with tf.control_dependencies([train_op]):
                return tf.identity(loss)

        @tpu_function.on_device_training_loop
        def train_loop():
            return tpu.repeat(self.iterations, tpu_train_step, [_INITIAL_LOSS])

        self.train_graph = tf.Graph()
        with self.train_graph.as_default():
            (self.loss, ) = tpu.shard(
                train_loop,
                inputs=[],
                num_shards=self.num_shards,
                outputs_from_all_shards=False,
                device_assignment=self.device_assignment,
            )
            if self.scaffold_fn:
                self.scaffold_fn()
            global_initializer = tf.global_variables_initializer()
            local_initializer = tf.local_variables_initializer()
            graph_io.write_graph(
                self.input_graph.as_graph_def(add_shapes=True),
                FLAGS.model_dir, "input_graph.pbtxt")
            graph_io.write_graph(
                self.train_graph.as_graph_def(add_shapes=True),
                FLAGS.model_dir, "graph.pbtxt")
            self.saver = tf.train.Saver()

        # Build tpu train model session and initialize graph
        self.train_sess = tf.Session(self.tpu_cluster_resolver.get_master(),
                                     graph=self.train_graph,
                                     config=self.session_config)

        self.train_sess.run(global_initializer)
        self.train_sess.run(local_initializer)

        # Complete infeed graph generation and session.run calls
        self.infeed_thread = threading.Thread(target=infeed_thread_fn)
        self.infeed_thread.start()
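
The queue consumed by infeed_thread_fn above is a plain producer/consumer handoff; a hedged sketch of how a training method might feed it, where the _STOP sentinel value and the attribute names on runner are assumptions:

import tensorflow as tf  # TF 1.x, as assumed throughout these examples

# Hypothetical sentinel telling the infeed thread to exit; the real value is
# defined elsewhere in the runner module.
_STOP = -1

def train(runner, train_steps):
    """Hypothetical driver for the runner initialized above."""
    cur_step = 0
    while cur_step < train_steps:
        # Ask the infeed thread to enqueue the next chunk of batches, then run
        # the matching on-device loop of runner.iterations steps.
        runner.queue.put(runner.iterations)
        loss = runner.train_sess.run(runner.loss)
        tf.logging.info("step %d: loss = %f", cur_step + runner.iterations, loss)
        cur_step += runner.iterations
    runner.queue.put(_STOP)
    runner.infeed_thread.join()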
    def initialize_eval(self, params, eval_input_fn, model_fn):
        """Initialize eval."""

        self.eval_infeed_queue = []

        for i in range(0, self.num_hosts):
            self.build_enqueue_ops(eval_input_fn,
                                   params,
                                   host_id=i,
                                   is_training=False)

        eval_step = self.get_tpu_step(params, model_fn, is_training=False)

        @tpu_function.on_device_training_loop
        def eval_loop():
            with tf.variable_scope("resnet", reuse=tf.AUTO_REUSE):
                return tpu.repeat(int(self.eval_steps), eval_step,
                                  [_INITIAL_LOSS])

        def train_eval_step(loss):
            del loss
            with tf.control_dependencies(self.train_loop()):
                return eval_loop()

        @on_device_train_and_eval_loops
        def train_eval_loop():
            return tpu.repeat(self.max_train_iterations, train_eval_step,
                              [_INITIAL_LOSS])

        def create_dequeue_ops(host_id):
            """Create deque ops graph function."""
            dequeue_ops = []
            tensor_dtypes = []
            tensor_shapes = []
            for v in self.eval_tensors:
                dequeue_ops.append([])
                tensor_dtypes.append(v.dtype)
                tensor_shapes.append(v.shape)
            for i in range(FLAGS.tpu_cores_per_host):
                with tf.device(device_for_host(self.get_host(host_id))):
                    outfeed_tensors = tpu.outfeed_dequeue_tuple(
                        dtypes=tensor_dtypes,
                        shapes=tensor_shapes,
                        device_ordinal=i)
                    for j, item in enumerate(outfeed_tensors):
                        dequeue_ops[j].append(item)
            for j in range(len(outfeed_tensors)):
                dequeue_ops[j] = tf.concat(dequeue_ops[j], axis=0)
            return dequeue_ops

        with self.graph.as_default():
            with tf.variable_scope("resnet", reuse=True):
                (self.train_eval_op, ) = tpu.shard(
                    train_eval_loop,
                    inputs=[],
                    num_shards=self.num_cores,
                    outputs_from_all_shards=False)

                graph_io.write_graph(self.graph.as_graph_def(add_shapes=True),
                                     FLAGS.model_dir, "graph.pbtxt")

        with self.eval_output_graph.as_default():
            with tf.variable_scope("resnet", reuse=True):
                for i in range(0, self.num_hosts):
                    host_dequeue_ops = create_dequeue_ops(i)
                    for j, dequeue_tensor in enumerate(host_dequeue_ops):
                        self.dequeue_ops[j].append(dequeue_tensor)

                for j, _ in enumerate(self.eval_tensors):
                    self.dequeue_ops[j] = tf.concat(self.dequeue_ops[j],
                                                    axis=0)

                with tf.device(device_for_host(self.get_host(0))):
                    metrics = self.eval_metrics[0](*self.dequeue_ops)
                metric_update_ops = []
                metric_value_ops = {}
                for (k, v) in metrics.items():
                    metric_update_ops.append(v[1])
                    metric_value_ops[k] = v[0]
                self.metric_update_ops = metric_update_ops
                self.metric_value_ops = metric_value_ops

                self.metric_initializer = tf.variables_initializer(
                    tf.get_collection(tf.GraphKeys.METRIC_VARIABLES))
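
A sketch of how the metric ops created above might be consumed once the train-and-eval loop has pushed results to the outfeed; the session arguments and the use of eval_steps dequeue rounds are assumptions:

import tensorflow as tf  # TF 1.x, as assumed throughout these examples

def run_eval_metrics(runner):
    """Hypothetical consumer of the metric ops built in initialize_eval."""
    eval_sess = tf.Session(runner.tpu_cluster_resolver.get_master(),
                           graph=runner.eval_output_graph,
                           config=runner.session_config)
    # Reset the streaming-metric accumulators, then fold in each eval step's
    # outfeed results through the update ops.
    eval_sess.run(runner.metric_initializer)
    for _ in range(int(runner.eval_steps)):
        eval_sess.run(runner.metric_update_ops)
    # Read out the final metric values, keyed by metric name.
    return eval_sess.run(runner.metric_value_ops)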
  def initialize(self, input_fn, model_fn, params):
    """Build graphs for the TPU device and the input pipelines.

    Args:
      input_fn: Dataset input graph generation function
      model_fn: Model definition function
      params:  Parameters to input and model functions
    """

    tf.logging.info("TrainRunner: initialize method")

    def infeed_thread_fn():
      """Build and infeed session.run calls in a background thread."""
      i = 1
      while i < FLAGS.num_cores // FLAGS.tpu_cores_per_host:
        self.build_enqueue_ops(input_fn, params, i)
        i += 1
      # Build the infeed session
      self.input_sess = tf.Session(
          self.cluster_resolver.get_master(),
          graph=self.input_graph,
          config=self.config)
      self.input_sess.run(self.dataset_initializer)
      # Run infeed session.run calls
      self.input_sess.run([self.enqueue_ops])

    self.build_enqueue_ops(input_fn, params, 0)

    def get_tpu_step(mparams):
      """Get the TPU graph generation function."""

      def tpu_step(loss):
        """Generate the TPU graph."""
        del loss
        values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
        unflattened_inputs = data_nest.pack_sequence_as(self.feature_structure,
                                                        values)
        features = unflattened_inputs["features"]
        labels = unflattened_inputs["labels"]
        estimator_spec = model_fn(features, labels, tf.estimator.ModeKeys.TRAIN,
                                  mparams)
        loss, train_op = estimator_spec.loss, estimator_spec.train_op
        with tf.device(device_for_tpu_core()):
          with tf.control_dependencies([train_op]):
            return tf.identity(loss)

      return tpu_step

    tpu_step = get_tpu_step(params)

    @tpu_function.on_device_training_loop
    def tpu_loop():
      return tpu.repeat(self.iterations, tpu_step, [_INITIAL_LOSS])

    (self.loss,) = tpu.shard(
        tpu_loop,
        inputs=[],
        num_shards=FLAGS.num_cores,
        outputs_from_all_shards=False,
    )
    initializer = tf.global_variables_initializer()
    self.saver = tf.train.Saver()
    graph_io.write_graph(tf.get_default_graph().as_graph_def(add_shapes=True),
                         FLAGS.model_dir, "graph.pbtxt")

    # Build tpu train model session and initialize graph
    self.sess = tf.Session(
        self.cluster_resolver.get_master(),
        config=self.config)
    self.sess.run(initializer)

    # Complete infeed graph generation and session.run calls
    self.infeed_thread = threading.Thread(target=infeed_thread_fn)
    self.infeed_thread.start()
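
Several of these runners call device_for_host() and device_for_tpu_core() without showing them; a hedged sketch of what such helpers typically look like (the job and task strings depend on the deployment and are assumptions here):

def device_for_host(host_name="/job:tpu_worker/task:0"):
    """Place an op on the CPU of the given TPU host."""
    return host_name + "/device:CPU:0"

def device_for_tpu_core(host_name="/job:tpu_worker/task:0", core=0):
    """Place an op on a replicated TPU core attached to the given host."""
    return host_name + "/device:TPU_REPLICATED_CORE:%d" % core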
Example #7
    def initialize(self, input_fn, model_fn, params):
        """Build graphs for the TPU device and the input pipelines.

    Args:
      input_fn: Dataset input graph generation function
      model_fn: Model definition function
      params:  Parameters to input and model functions
    """

        tf.logging.info("TrainRunner: initialize method")

        with tf.device(self.device_for_host()):
            self.global_step = tflex.get_or_create_global_step()

        def infeed_thread_fn():
            """Build and infeed session.run calls in a background thread."""
            tf.logging.info('Run infeed session.run calls')
            tflex.run(self.input_sess, [self.enqueue_ops])
            tf.logging.info('infeed session.run finished')

        for i in tqdm.trange(self.num_hosts):
            self.build_enqueue_ops(input_fn, params, i)

        # Build the infeed session
        self.input_sess = tf.Session(self.master,
                                     graph=self.input_graph,
                                     config=self.config)
        self.input_sess.run(self.dataset_initializer)
        tf.logging.info('Ensure infeed data has fully uploaded')
        tflex.flush(self.input_sess)

        def get_tpu_step(mparams):
            """Get the TPU graph generation function."""
            def tpu_pre(loss):
                """Generate the TPU graph."""
                del loss
                values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
                unflattened_inputs = data_nest.pack_sequence_as(
                    self.feature_structure, values)
                if "features" in unflattened_inputs and "labels" in unflattened_inputs:
                    features = unflattened_inputs["features"]
                    labels = unflattened_inputs["labels"]
                else:
                    features = unflattened_inputs
                    labels = None
                estimator_spec = model_fn(features, labels,
                                          tf.estimator.ModeKeys.TRAIN, mparams)
                return estimator_spec

            def tpu_make(estimator_spec):
                loss, train_op = estimator_spec.loss, estimator_spec.train_op
                with tf.device(device_for_tpu_core()):
                    with tf.control_dependencies([train_op]):
                        return tf.identity(loss, name="tpu_loss_op")

            def tpu_step(loss):
                estimator_spec = tpu_pre(loss)
                return tpu_make(estimator_spec)

            return tpu_pre, tpu_make, tpu_step

        tpu_pre, tpu_make, tpu_step = get_tpu_step(params)

        if False:  # Disabled debugging path; never executes.
            with tf.Graph().as_default() as self.tpu_graph:
                params['use_tpu'] = False
                self.tpu_global_step = tflex.get_or_create_global_step()
                self.tpu_spec = tpu_pre(_INITIAL_LOSS)
                self.tpu_sess = tf.Session(self.master,
                                           config=self.config,
                                           graph=self.graph)
                import pdb
                pdb.set_trace()
                self.tpu_op = tpu_make(self.tpu_spec)
                params['use_tpu'] = True

        @tpu_function.on_device_training_loop
        def tpu_loop():
            return tpu.repeat(self.iterations, tpu_step, [_INITIAL_LOSS])
            #return tpu_step(_INITIAL_LOSS)

        (self.loss, ) = tpu.shard(
            tpu_loop,
            inputs=[],
            num_shards=self.num_cores,
            outputs_from_all_shards=False,
        )
        if FLAGS.restore_trainable_variables:
            self.var_list = tf.trainable_variables()
        else:
            self.var_list = tf.global_variables()
        self.model_dir = FLAGS.model_dir or train_flags.options().get(
            'model_dir')
        self.saver = None
        if self.model_dir is not None:
            self.saver = tf.train.Saver(var_list=self.var_list,
                                        keep_checkpoint_every_n_hours=0.5)
            # Why do this?
            # graph_io.write_graph(tf.Graph().as_graph_def(add_shapes=True), self.model_dir, "graph.pbtxt")

        # Build tpu train model session and initialize graph
        self.sess = tf.Session(self.master, config=self.config)
        self.initializer = [
            tf.local_variables_initializer(),
            tf.global_variables_initializer()
        ]
        self.extra_initializers = tf.get_collection(
            'tftorch_initializers'
        )  # a bit of a hack, but it gets the job done
        tflex.run(self.sess, self.initializer)
        tflex.run(self.sess, self.extra_initializers)

        if FLAGS.restore_dir is not None:
            ckpt = tf.train.latest_checkpoint(FLAGS.restore_dir)
            if ckpt is None:
                #raise ValueError("restore_dir has no latest_checkpoint: %s" % repr(FLAGS.restore_dir))
                pass
            else:
                step = tflex.checkpoint_step(ckpt) or 0
                saver = tf.train.Saver(var_list=self.var_list,
                                       restore_sequentially=True)
                for x in self.var_list:
                    tf.logging.info('\t%s', repr(x))
                tf.logging.info('Restoring %s step %d', ckpt, step)
                saver.restore(self.sess, ckpt)
                tf.logging.info('Setting step %d', step)
                self.global_step.load(step, self.sess)
                tf.logging.info('Restoring %s step %d (done)', ckpt, step)
        self.cur_step = tflex.run(self.sess, self.global_step)

        # Complete infeed graph generation and session.run calls
        def null_fn():
            pass

        self.infeed_thread = threading.Thread(target=null_fn, daemon=True)
        self.infeed_thread.start()
        self.infeed_thread_fn = infeed_thread_fn
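
Because the initialize above only stashes infeed_thread_fn and starts a placeholder thread, a later training method presumably launches the real infeed thread itself; a hedged sketch of that step:

import threading

def start_infeed(runner):
    """Hypothetical helper: launch the deferred infeed thread before training."""
    # Replace the placeholder thread created in initialize() with one that
    # actually issues the infeed session.run calls.
    runner.infeed_thread = threading.Thread(target=runner.infeed_thread_fn,
                                            daemon=True)
    runner.infeed_thread.start()
    return runner.infeed_thread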
  def build_model(self, params, input_fn, model_fn, num_steps):
    """Build the TPU model and infeed enqueue ops."""

    iparams = {}
    iparams["batch_size"] = params["batch_size"] // FLAGS.num_cores

    def get_tpu_step(mparams):
      """Get the TPU graph generation function."""

      def tpu_step(loss, *args):
        """Generate the TPU graph."""
        del loss
        unflattened_inputs = data_nest.pack_sequence_as(self.feature_structure,
                                                        args)
        features = unflattened_inputs["features"]
        labels = unflattened_inputs["labels"]
        estimator_spec = model_fn(features, labels, tf.estimator.ModeKeys.EVAL,
                                  mparams)
        loss = estimator_spec.loss
        self.eval_metrics = estimator_spec.eval_metrics
        self.eval_tensors = estimator_spec.eval_metrics[1]
        with tf.device(device_for_tpu_core()):
          outfeed_enqueue_ops = tpu.outfeed_enqueue_tuple(self.eval_tensors)
          with tf.control_dependencies([outfeed_enqueue_ops]):
            return tf.identity(loss)

      return tpu_step

    infeed_queue = []

    def get_enqueue_ops_fn():
      """Generate the enqueue ops graph function."""

      def enqueue_ops_fn():
        """Generate the infeed enqueue ops graph."""

        per_host_sharded_inputs = []
        control_deps = []
        with tf.device(device_for_host()):
          for _ in range(FLAGS.num_cores):
            with tf.control_dependencies(control_deps):
              features, labels = self.iterator.get_next()
              self.feature_structure["features"] = features
              self.feature_structure["labels"] = labels
              flattened_inputs = data_nest.flatten(self.feature_structure)
              control_deps.extend(flattened_inputs)
              per_host_sharded_inputs.append(flattened_inputs)

          infeed = tpu.InfeedQueue(
              number_of_tuple_elements=len(per_host_sharded_inputs[0]))
          infeed_queue.append(infeed)
          return infeed.generate_enqueue_ops(per_host_sharded_inputs,
                                             tpu_ordinal_function=tpu_ordinal_fn)

      return enqueue_ops_fn

    with tf.device(device_for_host()):
      dataset = input_fn(iparams)
      dataset = dataset.cache()  # Cache the fully-generated eval dataset.
      dataset = dataset.repeat()  # Repeat indefinitely for unknown # of evals.
      self.iterator = dataset.make_initializable_iterator()
      self.enqueue_ops = wrap_computation_in_while_loop(
          get_enqueue_ops_fn(), n=num_steps, parallel_iterations=1)

    tpu_step = get_tpu_step(params)

    @tpu_function.on_device_training_loop
    def tpu_loop():
      return tpu.repeat(
          num_steps, tpu_step, [_INITIAL_LOSS], infeed_queue=infeed_queue[0])

    def create_dequeue_ops():
      dequeue_ops = []
      tensor_dtypes = []
      tensor_shapes = []
      for v in self.eval_tensors:
        dequeue_ops.append([])
        tensor_dtypes.append(v.dtype)
        tensor_shapes.append(v.shape)
        tf.logging.info("appending %s" % v.name)
      for i in range(FLAGS.num_cores):
        with tf.device(device_for_host()):
          outfeed_tensors = tpu.outfeed_dequeue_tuple(
              dtypes=tensor_dtypes,
              shapes=tensor_shapes,
              device_ordinal=i)
          for j, item in enumerate(outfeed_tensors):
            dequeue_ops[j].append(item)
      for j in range(len(outfeed_tensors)):
        dequeue_ops[j] = tf.concat(dequeue_ops[j], axis=0)
      return dequeue_ops

    (self.loss,) = tpu.shard(
        tpu_loop,
        inputs=[],
        num_shards=FLAGS.num_cores,
        outputs_from_all_shards=False)

    self.dequeue_ops = create_dequeue_ops()
    with tf.device(device_for_host()):
      metrics = self.eval_metrics[0](*self.dequeue_ops)
    metric_update_ops = []
    metric_value_ops = {}
    for (k, v) in metrics.items():
      # print("k: ", k)
      # print("v: ", v)
      metric_update_ops.append(v[1])
      metric_value_ops[k] = v[0]
    self.metric_update_ops = metric_update_ops
    self.metric_value_ops = metric_value_ops

    self.metric_initializer = tf.variables_initializer(
        tf.get_collection(tf.GraphKeys.METRIC_VARIABLES))
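
wrap_computation_in_while_loop, used above to build the enqueue ops, is not shown; a hedged sketch of the usual implementation, which reruns an op-producing function n times behind a tf.while_loop counter:

import tensorflow as tf  # TF 1.x, as assumed throughout these examples

def wrap_computation_in_while_loop(op_fn, n, parallel_iterations=10):
    """Run the ops returned by op_fn() n times inside a tf.while_loop."""
    def computation(i):
        # Force the enqueue ops to execute before the counter advances.
        with tf.control_dependencies(op_fn()):
            return i + 1
    return tf.while_loop(lambda i: tf.less(i, n),
                         computation, [tf.constant(0)],
                         parallel_iterations=parallel_iterations)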
Example #9
    def initialize(self, input_fn, model_fn, params):
        """Build graphs for the TPU device and the input pipelines.

    Args:
      input_fn: Dataset input graph generation function
      model_fn: Model definition function
      params:  Parameters to input and model functions
    """

        tf.logging.info("TrainRunner: initialize method")

        with tf.device(self.device_for_host()):
            self.global_step = tflex.get_or_create_global_step()

        def infeed_thread_fn():
            """Build and infeed session.run calls in a background thread."""
            i = 1
            while i < FLAGS.num_cores // FLAGS.tpu_cores_per_host:
                self.build_enqueue_ops(input_fn, params, i)
                i += 1
            # Build the infeed session
            self.input_sess = tf.Session(self.cluster_resolver.get_master(),
                                         graph=self.input_graph,
                                         config=self.config)
            self.input_sess.run(self.dataset_initializer)
            tf.logging.info('Ensure infeed data has fully uploaded')
            tflex.flush(self.input_sess)
            tf.logging.info('Run infeed session.run calls')
            tflex.run(self.input_sess, [self.enqueue_ops])

        self.build_enqueue_ops(input_fn, params, 0)

        def get_tpu_step(mparams):
            """Get the TPU graph generation function."""
            def tpu_pre(loss):
                """Generate the TPU graph."""
                del loss
                values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
                unflattened_inputs = data_nest.pack_sequence_as(
                    self.feature_structure, values)
                features = unflattened_inputs["features"]
                labels = unflattened_inputs["labels"]
                estimator_spec = model_fn(features, labels,
                                          tf.estimator.ModeKeys.TRAIN, mparams)
                return estimator_spec

            def tpu_make(estimator_spec):
                loss, train_op = estimator_spec.loss, estimator_spec.train_op
                with tf.device(device_for_tpu_core()):
                    with tf.control_dependencies([train_op]):
                        return tf.identity(loss, name="tpu_loss_op")

            def tpu_step(loss):
                estimator_spec = tpu_pre(loss)
                return tpu_make(estimator_spec)

            return tpu_pre, tpu_make, tpu_step

        tpu_pre, tpu_make, tpu_step = get_tpu_step(params)

        if False:  # Disabled debugging path; never executes.
            with tf.Graph().as_default() as self.tpu_graph:
                params['use_tpu'] = False
                self.tpu_global_step = tflex.get_or_create_global_step()
                self.tpu_spec = tpu_pre(_INITIAL_LOSS)
                self.tpu_sess = tf.Session(self.cluster_resolver.get_master(),
                                           config=self.config,
                                           graph=self.graph)
                import pdb
                pdb.set_trace()
                self.tpu_op = tpu_make(self.tpu_spec)
                params['use_tpu'] = True

        @tpu_function.on_device_training_loop
        def tpu_loop():
            return tpu.repeat(self.iterations, tpu_step, [_INITIAL_LOSS])
            #return tpu_step(_INITIAL_LOSS)

        (self.loss, ) = tpu.shard(
            tpu_loop,
            inputs=[],
            num_shards=FLAGS.num_cores,
            outputs_from_all_shards=False,
        )
        initializer = tf.global_variables_initializer()
        self.saver = tf.train.Saver(keep_checkpoint_every_n_hours=0.5)
        graph_io.write_graph(tf.get_default_graph().as_graph_def(add_shapes=True),
                             FLAGS.model_dir, "graph.pbtxt")

        # Build tpu train model session and initialize graph
        self.sess = tf.Session(self.cluster_resolver.get_master(),
                               config=self.config)
        tflex.run(self.sess, initializer)

        if FLAGS.restore_dir is not None:
            ckpt = tf.train.latest_checkpoint(FLAGS.restore_dir)
            if ckpt is not None:
                if FLAGS.restore_trainable_variables:
                    var_list = tf.trainable_variables()
                    if params['n_ctx'] != 1024:
                        var_list = [
                            x for x in var_list if '/wpe' not in x.name
                        ]
                else:
                    var_list = tf.global_variables()
                saver = tf.train.Saver(var_list=var_list,
                                       restore_sequentially=True)
                tf.logging.info('Restoring %s', ckpt)
                for x in var_list:
                    tf.logging.info('\t%s', repr(x))
                saver.restore(self.sess, ckpt)
                tf.logging.info('Restoring %s (done)', ckpt)
        self.cur_step = tflex.run(self.sess, self.global_step)

        # Complete infeed graph generation and session.run calls
        self.infeed_thread = threading.Thread(target=infeed_thread_fn,
                                              daemon=True)
        self.infeed_thread.start()
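
build_enqueue_ops, called per host by this and the other runners, is never shown; a hedged per-host sketch in the spirit of the enqueue_ops_fn from the earlier eval example, reusing the module aliases (tf, tpu, data_nest, FLAGS) and helpers (device_for_host, tpu_ordinal_fn) these examples assume; the attribute names on self are also assumptions:

def build_enqueue_ops(self, input_fn, params, host_id):
    """Hypothetical per-host infeed construction."""
    with self.input_graph.as_default():
        with tf.device(device_for_host(self.get_host(host_id))):
            dataset = input_fn(params)
            iterator = dataset.make_initializable_iterator()
            self.dataset_initializer.append(iterator.initializer)

            # One flattened input tuple per core attached to this host.
            per_host_sharded_inputs = []
            for _ in range(FLAGS.tpu_cores_per_host):
                features, labels = iterator.get_next()
                self.feature_structure = {"features": features,
                                          "labels": labels}
                per_host_sharded_inputs.append(
                    data_nest.flatten(self.feature_structure))

            infeed = tpu.InfeedQueue(
                number_of_tuple_elements=len(per_host_sharded_inputs[0]))
            self.infeed_queue.append(infeed)
            self.enqueue_ops.append(
                infeed.generate_enqueue_ops(
                    per_host_sharded_inputs,
                    tpu_ordinal_function=tpu_ordinal_fn))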
    def build_model(self, model_fn, params):
        """Build the TPU model and infeed enqueue ops."""
        tf.logging.info("DistEvalLowLevelRunner: build_model method")

        def tpu_eval_step():
            """Generate the TPU graph."""
            values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
            unflattened_inputs = data_nest.pack_sequence_as(
                self.feature_structure, values)
            features = unflattened_inputs["features"]
            estimator_spec = model_fn(features, None,
                                      tf.estimator.ModeKeys.PREDICT, params)
            for k, v in six.iteritems(estimator_spec.predictions):
                self.outfeed_names.append(k)
                self.outfeed_tensors.append(v)

            with tf.device(utils.device_for_tpu_core(self._get_host(0))):
                outfeed_enqueue_ops = tpu.outfeed_enqueue_tuple(
                    self.outfeed_tensors)
            with tf.control_dependencies([outfeed_enqueue_ops]):
                return tf.no_op()

        @tpu_function.on_device_training_loop
        def eval_loop():
            return tpu.repeat(self.eval_steps, tpu_eval_step, [])

        def create_dequeue_ops(host_id):
            """Create outfeed dequeue ops."""
            dequeue_ops = []
            tensor_dtypes = []
            tensor_shapes = []
            for v in self.outfeed_tensors:
                tensor_dtypes.append(v.dtype)
                tensor_shapes.append(v.shape)
            with tf.device(utils.device_for_host(self._get_host(host_id))):
                for i in range(FLAGS.num_shards_per_host):
                    outfeed = tpu.outfeed_dequeue_tuple(dtypes=tensor_dtypes,
                                                        shapes=tensor_shapes,
                                                        device_ordinal=i)
                    if len(outfeed) == 2:
                        if outfeed[0].shape.ndims == 3:
                            detections, is_pad = outfeed
                        else:
                            is_pad, detections = outfeed
                        num_non_pad = tf.shape(is_pad)[0] - tf.reduce_sum(
                            tf.cast(is_pad, tf.int32))
                        dequeue_ops.append(
                            tf.slice(detections, [0, 0, 0],
                                     [num_non_pad, -1, -1]))
                    else:
                        dequeue_ops.append(outfeed)
                dequeue_ops = tf.concat(dequeue_ops, axis=0)
            return dequeue_ops

        with self.graph.as_default():
            (self.eval_op, ) = tpu.shard(
                eval_loop,
                inputs=[],
                num_shards=FLAGS.num_shards,
                outputs_from_all_shards=False,
            )

            # Get dequeue ops from each host.
            for i in range(self.num_hosts):
                self.dequeue_ops.append(create_dequeue_ops(i))

            self.saver = tf.train.Saver()
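
Finally, a sketch of how the per-host dequeue ops collected above might be drained into a single detections array after the eval op has run; the session handling and the number of dequeue rounds are assumptions:

import numpy as np

def collect_detections(runner):
    """Hypothetical consumer of the dequeue_ops built in build_model above."""
    detections = []
    # Each session.run pops one step's worth of (already de-padded) detections
    # from every host; repeat once per eval step.
    for _ in range(runner.eval_steps):
        per_host_results = runner.sess.run(runner.dequeue_ops)
        detections.extend(per_host_results)
    return np.concatenate(detections, axis=0)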