Example 1
def q1():
    global l_returnflag_group_size
    global l_linestatus_group_size
    returnflag_groups = np.unique(l_returnflag)
    linestatus_groups = np.unique(l_linestatus)
    l_returnflag_group_size = len(returnflag_groups)
    l_linestatus_group_size = len(linestatus_groups)
    inputs = [
        tf.convert_to_tensor(l_shipdate, np.float32),
        tf.convert_to_tensor(l_returnflag, np.float32),
        tf.convert_to_tensor(l_linestatus, np.float32),
        tf.convert_to_tensor(l_quantity, np.float32),
        tf.convert_to_tensor(l_extendedprice, np.float32),
        tf.convert_to_tensor(l_discount, np.float32),
        tf.convert_to_tensor(l_tax, np.float32),
        tf.convert_to_tensor(returnflag_groups, np.float32),
        tf.convert_to_tensor(linestatus_groups, np.float32)
    ]
    tpu_computation = tpu.rewrite(q1_computation, inputs)
    tpu_grpc_url = TPUClusterResolver(
        tpu=[os.environ['TPU_NAME']]).get_master()
    with tf.Session(tpu_grpc_url) as sess:
        sess.run(tpu.initialize_system())
        sess.run(tf.global_variables_initializer())
        for i in range(0, 5):
            res = sess.run(tpu_computation)
        sess.run(tpu.shutdown_system())
        print(res)
        return res
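Most of the snippets on this page target TensorFlow 1.x and rely on contrib-style imports that the excerpts omit. A minimal sketch of the imports they typically assume is below; the exact modules and aliases (for example contrib_tpu versus tpu) vary from project to project and are an assumption here, not part of any single source file.

import os
import numpy as np
import tensorflow as tf
from tensorflow.contrib import tpu
from tensorflow.contrib.cluster_resolver import TPUClusterResolver
from tensorflow.core.protobuf import rewriter_config_pb2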
Example 2
    def __init__(self, iterations, hparams, per_host_v1=False):
        tf.logging.info("TrainLowLevelRunner: constructor")

        self.feature_structure = {}
        self.loss = None
        self.infeed_queue = []
        self.enqueue_ops = []
        self.dataset_initializer = []
        self.is_local = ((hparams.master == "") and (hparams.tpu_name is None))
        self.per_host_v1 = per_host_v1
        self.iterations = iterations
        self.sess = None
        self.graph = tf.Graph()
        self.hparams = hparams
        with self.graph.as_default():
            self.tpu_init = [tpu.initialize_system()]
            self.tpu_shutdown = tpu.shutdown_system()

        self.resolver = get_resolver(hparams)
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        isolate_session_state=True)
        if self.hparams.tpu_name is None:
            master = self.hparams.master
        else:
            cluster_spec = self.resolver.cluster_spec()
            if cluster_spec:
                session_config.cluster_def.CopyFrom(
                    cluster_spec.as_cluster_def())
            master = self.resolver.get_master()
        self.sess = tf.Session(master, graph=self.graph, config=session_config)
        self.sess.run(self.tpu_init)
Example 3
    def __init__(self, iterations):
        tf.logging.info("TrainLowLevelRunner: constructor")

        self.feature_structure = {}
        self.loss = None
        self.infeed_queue = []
        self.enqueue_ops = []
        self.dataset_initializer = []
        self.iterations = iterations
        self.num_hosts = FLAGS.num_shards // FLAGS.num_shards_per_host
        self.scaffold_fn = None
        # Having two separate sessions and graphs to make the initialization faster.
        self.input_sess = None
        self.train_sess = None
        self.input_graph = tf.Graph()
        self.train_graph = None
        self.tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        # Disable grappler for better performance.
        self.session_config = tf.ConfigProto(
            allow_soft_placement=True,
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True)),
            isolate_session_state=True)
        cluster_spec = self.tpu_cluster_resolver.cluster_spec()
        if cluster_spec:
            self.session_config.cluster_def.CopyFrom(
                cluster_spec.as_cluster_def())
        self.tpu_init = [tpu.initialize_system()]
        self.tpu_shutdown = tpu.shutdown_system()
        self.init_sess = tf.Session(self.tpu_cluster_resolver.get_master(),
                                    config=self.session_config)
        self.init_sess.run(self.tpu_init)
        self.queue = Queue.Queue()
Example 4
    def execute_tpu(self, graph_fn, inputs):
        """Constructs the graph, executes it on TPU and returns the result.

    Args:
      graph_fn: a callable that constructs the tensorflow graph to test. The
        arguments of this function should correspond to `inputs`.
      inputs: a list of numpy arrays to feed input to the computation graph.

    Returns:
      A list of numpy arrays or a scalar returned from executing the tensorflow
      graph.
    """
        with self.test_session(graph=tf.Graph()) as sess:
            placeholders = [
                tf.placeholder_with_default(v, v.shape) for v in inputs
            ]
            tpu_computation = tpu.rewrite(graph_fn, placeholders)
            sess.run(tpu.initialize_system())
            sess.run([
                tf.global_variables_initializer(),
                tf.tables_initializer(),
                tf.local_variables_initializer()
            ])
            materialized_results = sess.run(tpu_computation,
                                            feed_dict=dict(
                                                zip(placeholders, inputs)))
            sess.run(tpu.shutdown_system())
            if (hasattr(materialized_results, '__len__')
                    and len(materialized_results) == 1
                    and (isinstance(materialized_results, list)
                         or isinstance(materialized_results, tuple))):
                materialized_results = materialized_results[0]
        return materialized_results
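A hypothetical call, assuming this method lives on a tf.test.TestCase subclass and using an illustrative graph_fn and inputs (none of these names come from the source):

def graph_fn(a, b):
    return tf.matmul(a, b)

a = np.ones((2, 2), dtype=np.float32)
b = np.ones((2, 2), dtype=np.float32)
result = self.execute_tpu(graph_fn, [a, b])  # a single 2x2 array, each entry 2.0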
Example 5
 def __init__(self, iterations, train_steps):
   tf.logging.info("TrainRunner: constructor")
   self.feature_structure = {}
   self.loss = None
   self.infeed_queue = []
   self.enqueue_ops = []
   self.dataset_initializer = []
   self.iterations = iterations
   self.sess = None
   self.input_sess = None
   self.infeed_thread = None
   if train_steps % iterations != 0:
     train_steps = iterations * int(math.ceil(train_steps / iterations))
   self.train_steps = train_steps
   self.input_graph = tf.Graph()
   tpu_init = [tpu.initialize_system()]
   self.tpu_shutdown = tpu.shutdown_system()
   self.cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
       FLAGS.tpu or FLAGS.master,
       zone=FLAGS.tpu_zone,
       project=FLAGS.gcp_project)
   self.config = tf.ConfigProto(operation_timeout_in_ms=600 * 60 * 1000,
                                graph_options=tf.GraphOptions(
                                    rewrite_options=rewriter_config_pb2.RewriterConfig(
                                        disable_meta_optimizer=True)),
                                isolate_session_state=True)
   cluster_spec = self.cluster_resolver.cluster_spec()
   if cluster_spec:
     self.config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
   self.init_sess = tf.Session(self.cluster_resolver.get_master(), config=self.config)
   self.init_sess.run(tpu_init)
Example 6
  def execute_tpu(self, graph_fn, inputs):
    """Constructs the graph, executes it on TPU and returns the result.

    Args:
      graph_fn: a callable that constructs the tensorflow graph to test. The
        arguments of this function should correspond to `inputs`.
      inputs: a list of numpy arrays to feed input to the computation graph.

    Returns:
      A list of numpy arrays or a scalar returned from executing the tensorflow
      graph.
    """
    with self.test_session(graph=tf.Graph()) as sess:
      placeholders = [tf.placeholder_with_default(v, v.shape) for v in inputs]
      tpu_computation = tpu.rewrite(graph_fn, placeholders)
      sess.run(tpu.initialize_system())
      sess.run([tf.global_variables_initializer(), tf.tables_initializer(),
                tf.local_variables_initializer()])
      materialized_results = sess.run(tpu_computation,
                                      feed_dict=dict(zip(placeholders, inputs)))
      sess.run(tpu.shutdown_system())
      if (hasattr(materialized_results, '__len__') and
          len(materialized_results) == 1 and
          (isinstance(materialized_results, list) or
           isinstance(materialized_results, tuple))):
        materialized_results = materialized_results[0]
    return materialized_results
Example 7
  def __init__(self, sess, use_tpu, mesh_shape, layout_rules):
    super(MeshContext, self).__init__()
    self._use_tpu = use_tpu
    self._mesh_shape = mtf.convert_to_shape(mesh_shape)
    self._layout_rules = layout_rules

    self._d_assignment = None
    self._num_hosts = None
    self._num_cores = None

    self._cpu_devices, self._gpu_devices = self._list_cpu_gpu_devices(sess)

    if self._use_tpu:
      topology = sess.run(tpu.initialize_system())
      topo_object = tpu.Topology(serialized=topology)
      self._num_cores = int(np.prod(topo_object.mesh_shape))
      self._num_hosts = int(topo_object.num_tasks)
      num_cores_per_host = int(self._num_cores // self._num_hosts)
      assert num_cores_per_host == int(topo_object.num_tpus_per_task)

      # Get a device_assignment object for mtf.
      self._d_assignment = device_assignment.device_assignment(
          topology, computation_shape=[1, 1, 1],
          num_replicas=self._num_cores)

      self._mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
          self._mesh_shape, self._layout_rules, None, self._d_assignment)
    else:
      self._mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
          self._mesh_shape, self._layout_rules, self._gpu_devices)
Example 8
def run(size):
    a_ = []
    b_ = []
    c_ = []
    for i in range(size):
        a_.append((i * 1.0 + 4.0) * 2.5)
        b_.append((i * 1.0 + 5.0) * 2.5)
        c_.append((i * 1.0 + 6.0) * 0.1)

    inputs = [tf.constant(a_), tf.constant(b_), tf.constant(c_)]

    tpu_computation = tpu.rewrite(expression, inputs)
    tpu_grpc_url = TPUClusterResolver(
        tpu=[os.environ['TPU_NAME']]).get_master()

    with tf.Session(tpu_grpc_url) as sess:
        sess.run(tpu.initialize_system())
        t1 = time()
        sess.run(tf.global_variables_initializer())
        sess.run(tpu_computation)
        t2 = time()
        print(str(size) + " : " + str(t2 - t1))
        sess.run(tpu.shutdown_system())

    print('Done !')
Example 9
def run_graph(master, graph_spec, epoch):
    """Run graph_spec.graph with master."""
    tf.logging.info("Running graph for epoch {}...".format(epoch))
    with tf.Session(master, graph_spec.graph) as sess:
        tf.logging.info("Initializing system for epoch {}...".format(epoch))
        sess.run(
            tpu.initialize_system(
                embedding_config=graph_spec.embedding.config_proto))

        tf.logging.info("Running before hook for epoch {}...".format(epoch))
        graph_spec.hook_before(sess, epoch)

        tf.logging.info("Running infeed for epoch {}...".format(epoch))
        infeed_thread_fn = graph_spec.get_infeed_thread_fn(sess)
        infeed_thread = threading.Thread(target=infeed_thread_fn)
        tf.logging.info("Staring infeed thread...")
        infeed_thread.start()

        tf.logging.info("Running TPU loop for epoch {}...".format(epoch))
        graph_spec.run_tpu_loop(sess, epoch)

        tf.logging.info("Joining infeed thread...")
        infeed_thread.join()

        tf.logging.info("Running after hook for epoch {}...".format(epoch))
        graph_spec.hook_after(sess, epoch)
Example 10
def initialize_tpu(session=None, timeout_in_ms=None):
  session = session or get_default_session()
  with session.as_default():
    op = tpu.initialize_system()
  options = None
  if timeout_in_ms:
    options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms)
  return session.run(op, options=options)
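A hypothetical way to call this helper, assuming the TPU address comes from the TPU_NAME environment variable (that choice is an assumption, not from the source):

resolver = TPUClusterResolver(tpu=[os.environ['TPU_NAME']])
with tf.Session(resolver.get_master()) as sess:
    # returns the serialized TPU topology produced by tpu.initialize_system()
    topology = initialize_tpu(session=sess, timeout_in_ms=10 * 60 * 1000)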
Example 11
def main(unused_argv):
    assert FLAGS.tpu_name
    if FLAGS.tpu_name.startswith('grpc://'):
        tpu_grpc_url = FLAGS.tpu_name
    else:
        tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=None, project=None)
        tpu_grpc_url = tpu_cluster_resolver.get_master()

    sess = tf.Session(tpu_grpc_url)
    with sess.graph.as_default():
      contrib_tpu.initialize_system()
      contrib_tpu.shutdown_system()

    output_names = ['ConfigureDistributedTPU', 'ShutdownDistributedTPU']
    model_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_names)
    print(model_def)
Example 12
 def __init__(self, iterations, train_steps=-1):
     tf.logging.info("TrainRunner: constructor")
     self.feature_structure = None
     self.loss = None
     self.infeed_queue = []
     self.enqueue_ops = []
     self.dataset_initializer = []
     self.iterations = iterations
     self.sess = None
     self.input_sess = None
     self.infeed_thread = None
     if train_steps < 0:
         train_steps = None
     if train_steps is not None:
         if train_steps % iterations != 0:
             train_steps = iterations * int(
                 math.ceil(train_steps / iterations))
     self.train_steps = train_steps
     self.input_graph = tf.Graph()
     with tf.Graph().as_default() as self.init_graph:
         self.tpu_init = tpu.initialize_system()
         self.tpu_shutdown = tpu.shutdown_system()
     #self.cluster_resolver = tflex.TPUClusterResolver(
     self.cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
         FLAGS.tpu or FLAGS.master,
         zone=FLAGS.tpu_zone,
         project=FLAGS.gcp_project)
     self.config = tf.ConfigProto(
         operation_timeout_in_ms=600 * 60 * 1000,
         graph_options=tf.GraphOptions(
             rewrite_options=rewriter_config_pb2.RewriterConfig(
                 disable_meta_optimizer=True)),
         isolate_session_state=True)
     cluster_spec = self.cluster_resolver.cluster_spec()
     if cluster_spec:
         self.config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
     self.master = self.cluster_resolver.get_master()
     self.init_sess = tf.Session(self.master,
                                 graph=self.init_graph,
                                 config=self.config)
     tf.logging.info("TrainRunner: initializing TPU session...")
     if not bool(int(os.environ.get('TPU_NO_INIT', '0'))):
         tflex.run(self.init_sess, self.tpu_init)
     tf.logging.info("TrainRunner: initializing TPU session (done)")
     self.devices = self.init_sess.list_devices()
     self.cores = sorted(
         [x.name for x in self.devices if ':TPU:' in x.name])
     self.num_cores = len(self.cores)
     self.tpu_cores_per_host = 8
     assert self.num_cores % self.tpu_cores_per_host == 0
     self.num_hosts = self.num_cores // self.tpu_cores_per_host
     print(self.config.cluster_def)
     print('cores: %d hosts: %d ip: %s' %
           (self.num_cores, self.num_hosts, self.master))
Example 13
 def __init__(self, iterations, train_steps, eval_steps):
   tf.logging.info("TrainAndEvalRunner: constructor")
   self.feature_structure = {}
   self.eval_feature_structure = {}
   self.loss = None
   self.eval_loss = None
   self.infeed_queue = []
   self.eval_infeed_queue = []
   self.enqueue_ops = []
   self.num_hosts = FLAGS.num_cores // FLAGS.tpu_cores_per_host
   self.dequeue_ops = []
   self.queue = Queue.Queue()
   self.eval_enqueue_ops = []
   self.dataset_initializer = []
   self.eval_dataset_initializer = []
   self.iterations = iterations
   self.steps_per_epoch = FLAGS.num_train_images // FLAGS.train_batch_size
   self.iterator = None
   self.sess = None
   self.input_sess = None
   self.eval_input_sess = None
   self.eval_output_sess = None
   self.infeed_thread = None
   self.train_eval_thread = None
   self.graph = tf.Graph()
   self.input_graph = tf.Graph()
   self.eval_input_graph = tf.Graph()
   self.eval_output_graph = tf.Graph()
   if train_steps % iterations != 0:
     train_steps = iterations * int(math.ceil(train_steps / iterations))
   self.train_steps = train_steps
   self.max_train_iterations = self.train_steps // iterations
   self.eval_steps = int(eval_steps)
   self.eval_batch_size = FLAGS.eval_batch_size
   tpu_init = [tpu.initialize_system()]
   self.tpu_shutdown = tpu.shutdown_system()
   self.tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
       FLAGS.tpu or FLAGS.master,
       zone=FLAGS.tpu_zone,
       project=FLAGS.gcp_project)
   self.config = tf.ConfigProto(
       operation_timeout_in_ms=600 * 60 * 1000,
       allow_soft_placement=True,
       graph_options=tf.GraphOptions(
           rewrite_options=rewriter_config_pb2.RewriterConfig(
               disable_meta_optimizer=True)),
       isolate_session_state=True)
   cluster_spec = self.tpu_cluster_resolver.cluster_spec()
   if cluster_spec:
     self.config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
   self.master = self.tpu_cluster_resolver.get_master()
   self.init_sess = tf.Session(self.master, config=self.config)
   self.init_sess.run(tpu_init)
Example 14
def filter_sum():
    inputs = [tf.convert_to_tensor(l_quantity, np.float32)]
    tpu_computation = tpu.rewrite(filter_sum_computation, inputs)
    tpu_grpc_url = TPUClusterResolver(
        tpu=[os.environ['TPU_NAME']]).get_master()
    with tf.Session(tpu_grpc_url) as sess:
        sess.run(tpu.initialize_system())
        sess.run(tf.global_variables_initializer())
        for i in range(0, 5):
            res = sess.run(tpu_computation)
        sess.run(tpu.shutdown_system())
        print(res)
        return res
Example 15
def apply_comp(inputs):
    tpu_computation = tpu.rewrite(apply, inputs)
    tpu_grpc_url = TPUClusterResolver(
        tpu=[os.environ['TPU_NAME']]).get_master()

    with tf.Session(tpu_grpc_url) as sess:
        sess.run(tpu.initialize_system())
        sess.run(tf.global_variables_initializer())
        t1 = time()
        sess.run(tpu_computation)
        t2 = time()
        sess.run(tpu.shutdown_system())
    print(t2 - t1)
Example 16
def init_tpu(name, host=None, timeout_in_ms=600 * 60 * 1000):
  tpu_init = [tpu.initialize_system()]
  cluster_resolver = TPUClusterResolver(name, host=host)
  config = tf.ConfigProto(operation_timeout_in_ms=timeout_in_ms,
                          graph_options=tf.GraphOptions(
                            rewrite_options=rewriter_config_pb2.RewriterConfig(
                              disable_meta_optimizer=True)),
                          isolate_session_state=True)
  cluster_spec = cluster_resolver.cluster_spec()
  if cluster_spec:
    config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
  init_sess = tf.Session(cluster_resolver.get_master(), config=config)
  init_sess.run(tpu_init)
  return init_sess, cluster_resolver
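A hypothetical usage of init_tpu, again reading the TPU name from the environment (an assumption for illustration):

init_sess, resolver = init_tpu(os.environ['TPU_NAME'])
# build and run the actual computation against resolver.get_master(), then release the TPU:
init_sess.run(tpu.shutdown_system())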
Example 17
def run(tpu_computation, tpu_grpc_url):

    reps = 1
    times = []

    for i in range(reps):
        with tf.Session(tpu_grpc_url) as sess:
            sess.run(tpu.initialize_system())
            t1 = time()
            sess.run(tf.global_variables_initializer())
            sess.run(tpu_computation)
            t2 = time()
            print(str(i) + "_ : " + str(t2 - t1))
            times.append(t2 - t1)
            sess.run(tpu.shutdown_system())

    print(sum(times) / reps)
Example 18
def group_by():
    unique_groups = np.unique(l_returnflag)
    inputs = [
        tf.convert_to_tensor(l_quantity, np.float32),
        tf.convert_to_tensor(l_returnflag, np.float32),
        tf.convert_to_tensor(unique_groups, np.float32)
    ]
    tpu_computation = tpu.rewrite(group_by_computation, inputs)
    tpu_grpc_url = TPUClusterResolver(
        tpu=[os.environ['TPU_NAME']]).get_master()
    with tf.Session(tpu_grpc_url) as sess:
        sess.run(tpu.initialize_system())
        sess.run(tf.global_variables_initializer())
        for i in range(0, 5):
            res = sess.run(tpu_computation)
        sess.run(tpu.shutdown_system())
        print(res)
Example 19
    def _run_tpu_computation(self):
        """Attempt to run computation graph directly on TPU."""
        def _computation_fn(alpha, x, y):
            return alpha * x + y

        alpha = tf.Variable(3.0, name='alpha')
        x = tf.Variable(tf.ones([3, 3], tf.float32), name='x')
        y = tf.Variable(tf.ones([3, 3], tf.float32), name='y')

        result = contrib_tpu.rewrite(_computation_fn, [alpha, x, y])

        with tf.Session('grpc://{0}:8470'.format(self.tpu_ip)) as sess:
            sess.run(contrib_tpu.initialize_system())
            sess.run(tf.global_variables_initializer())
            logging.info(sess.run(result))
            sess.run(contrib_tpu.shutdown_system())
            logging.info('Output should be a 3x3 matrix with all 4s.')
        self.tpu_computation = 'Passed'
        logging.info('Successfully ran a computation on the TPU')
Example 20
    def test_large_input(self):
        if test_case.FLAGS.tpu_test:
            input_size = 1408
            min_level = 2
            max_level = 6
            batch_size = 2
            num_boxes = 512
            num_filters = 256
            output_size = [7, 7]
            with self.test_session() as sess:
                features = []
                for level in range(min_level, max_level + 1):
                    feat_size = int(input_size / 2**level)
                    features.append(
                        tf.constant(np.reshape(
                            np.arange(batch_size * feat_size * feat_size *
                                      num_filters,
                                      dtype=np.float32),
                            [batch_size, feat_size, feat_size, num_filters]),
                                    dtype=tf.bfloat16))
                boxes = np.array([
                    [[0, 0, 256, 256]] * num_boxes,
                ],
                                 dtype=np.float32) / input_size
                boxes = np.tile(boxes, [batch_size, 1, 1])
                tf_boxes = tf.constant(boxes)
                tf_levels = tf.random_uniform([batch_size, num_boxes],
                                              maxval=5,
                                              dtype=tf.int32)

                def crop_and_resize_fn():
                    return spatial_ops.multilevel_roi_align(
                        features, tf_boxes, tf_levels, output_size)

                tpu_crop_and_resize_fn = contrib_tpu.rewrite(
                    crop_and_resize_fn)
                sess.run(contrib_tpu.initialize_system())
                sess.run(tf.global_variables_initializer())
                roi_features = sess.run(tpu_crop_and_resize_fn)
                self.assertEqual(roi_features[0].shape,
                                 (batch_size, num_boxes, output_size[0],
                                  output_size[1], num_filters))
                sess.run(contrib_tpu.shutdown_system())
Example 21
  def __init__(self, eval_steps):
    tf.logging.info("EvalLowLevelRunner: constructor")
    tf.logging.info("eval_steps: %s", eval_steps)

    self.feature_structure = {}
    self.infeed_queue = []
    self.enqueue_ops = []
    self.dataset_initializer = []
    self.eval_steps = eval_steps
    self.sess = None
    self.eval_op = None
    self.graph = tf.Graph()
    self.outfeed_tensors = []
    self.outfeed_names = []
    self.dequeue_ops = {}
    self.saver = None
    self.tpu_cluster_resolver = None
    with self.graph.as_default():
      self.tpu_init = [tpu.initialize_system()]
      self.tpu_shutdown = tpu.shutdown_system()
Example 22
def timer(inputs):
    reps = 2
    times = []

    for i in range(reps):
        t1 = time()
        tpu_computation = tpu.rewrite(blackscholes, inputs)
        tpu_grpc_url = TPUClusterResolver(
            tpu=[os.environ['TPU_NAME']]).get_master()

        with tf.Session(tpu_grpc_url) as sess:
            sess.run(tpu.initialize_system())
            sess.run(tf.global_variables_initializer())
            sess.run(tpu_computation)
            sess.run(tpu.shutdown_system())

        t2 = time()
        print(str(i) + "_ : " + str(t2 - t1))
        times.append(t2 - t1)

    print(sum(times) / reps)
Example 23
  def __init__(self, input_fn, model_fn, params, num_steps):
    self.feature_structure = {}
    self.loss = None
    self.enqueue_ops = None
    self.metric_initializer = None
    self.iterator = None
    self.batch_size = params["batch_size"]
    with tf.Graph().as_default() as self.graph:
      self.build_model(params, input_fn, model_fn, num_steps)
      self.tpu_init = tpu.initialize_system()
      initializer = tf.global_variables_initializer()
      self.tpu_shutdown = tpu.shutdown_system()
      self.local_initializer = tf.local_variables_initializer()
      self.saver = tf.train.Saver()

    cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu or FLAGS.master)
    self.sess = tf.Session(cluster_resolver.get_master(), graph=self.graph)
    self.sess.run(self.tpu_init)
    self.sess.run(initializer)
    self.sess.run(self.local_initializer)
    self.sess.run(self.iterator.initializer)
Example 24
    def execute_tpu_tf1(self, compute_fn, inputs, graph=None):
        """Executes compute_fn on TPU with Tensorflow 1.X.

    Args:
      compute_fn: a function containing Tensorflow computation that takes a list
        of input numpy tensors, performs computation and returns output numpy
        tensors.
      inputs: a list of numpy arrays to feed input to the `compute_fn`.
      graph: (optional) If not None, provided `graph` is used for computation
        instead of a brand new tf.Graph().

    Returns:
      A list of numpy arrays or a single numpy array.
    """
        with self.session(graph=(graph or tf.Graph())) as sess:
            placeholders = [
                tf.placeholder_with_default(v, v.shape) for v in inputs
            ]

            def wrap_graph_fn(*args, **kwargs):
                results = compute_fn(*args, **kwargs)
                if (not (isinstance(results, dict)
                         or isinstance(results, tf.Tensor))
                        and hasattr(results, '__iter__')):
                    results = list(results)
                return results

            tpu_computation = contrib_tpu.rewrite(wrap_graph_fn, placeholders)
            sess.run(contrib_tpu.initialize_system())
            sess.run([
                tf.global_variables_initializer(),
                tf.tables_initializer(),
                tf.local_variables_initializer()
            ])
            materialized_results = sess.run(tpu_computation,
                                            feed_dict=dict(
                                                zip(placeholders, inputs)))
            sess.run(contrib_tpu.shutdown_system())
        return self.maybe_extract_single_output(materialized_results)
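A hypothetical call with a compute_fn that returns two tensors (all names here are illustrative, not from the source):

def compute_fn(a, b):
    return a + b, a * b

a = np.arange(4, dtype=np.float32)
b = np.ones(4, dtype=np.float32)
sums, products = self.execute_tpu_tf1(compute_fn, [a, b])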
Example 25
    def __init__(self, eval_steps, hparams):
        tf.logging.info("EvalLowLevelRunner: constructor")
        tf.logging.info("eval_steps: %s", eval_steps)

        self.feature_structure = {}
        self.infeed_queue = []
        self.enqueue_ops = []
        self.dataset_initializer = []
        self.is_local = ((hparams.master == "") and (hparams.tpu_name is None))
        self.eval_steps = eval_steps
        self.sess = None
        self.eval_op = None
        self.graph = tf.Graph()
        self.hparams = hparams
        self.outfeed_tensors = []
        self.outfeed_names = []
        self.dequeue_ops = {}
        self.saver = None
        with self.graph.as_default():
            self.tpu_init = [tpu.initialize_system()]
            self.tpu_shutdown = tpu.shutdown_system()

        self.resolver = get_resolver(hparams)
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        operation_timeout_in_ms=600 * 60 *
                                        1000)  # 10 hours

        if self.hparams.tpu_name is None:
            master = self.hparams.master
        else:
            cluster_spec = self.resolver.cluster_spec()
            if cluster_spec:
                session_config.cluster_def.CopyFrom(
                    cluster_spec.as_cluster_def())
            master = self.resolver.get_master()

        self.sess = tf.Session(master, graph=self.graph, config=session_config)
        self.sess.run(self.tpu_init)
Example 26
def train_and_eval():
    """Trains and evaluates MeshTensorflow model without TPUEstimator.

  TODO(lehou): Pack everything nicely as a set of APIs.
  """
    tf.logging.info('FLAGS.master: {}'.format(FLAGS.master))

    # Open a session to get the list of CPU devices to hold master variables.
    with tf.Session(target=FLAGS.master,
                    config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        topology = sess.run(tpu.initialize_system())
        cpu_devices = _list_cpu_devices(sess)

    topo_object = tf.contrib.tpu.Topology(serialized=topology)
    num_cores = int(np.prod(topo_object.mesh_shape))
    num_hosts = int(topo_object.num_tasks)
    num_cores_per_host = int(num_cores // num_hosts)
    assert num_cores_per_host == int(topo_object.num_tpus_per_task)

    # Get a device_assignment object for mtf.
    d_assignment = device_assignment.device_assignment(
        topology, computation_shape=[1, 1, 1], num_replicas=num_cores)

    # Get mesh_impl.
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = unet.get_layout()
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules, None,
                                                d_assignment)

    for _ in range(FLAGS.num_training_loops):
        _train_phase(mesh_impl, cpu_devices, d_assignment, num_hosts,
                     num_cores)
        _eval_phase(mesh_impl, cpu_devices, d_assignment, num_hosts, num_cores)

    _shutdown()

    tf.logging.info('finished.')
Example 27
def run_monitored_session(cross_entropy, log_dir, required_steps, class_range,
                          save_checkpoint_steps, validation_steps,
                          train_step,
                          augmentation_info, device,
                          training_nn_params, training_tensor,
                          testing_nn_params, testing_tensor,
                          validation_nn_params, validation_tensor):
    read_op_value = None
    augmentation_restorer = None
    if augmentation_info.perform_shadow_augmentation:
        if augmentation_info.shadow_struct is not None and augmentation_info.shadow_struct.shadow_op_initializer is not None:
            augmentation_restorer = augmentation_info.shadow_struct.shadow_op_creater()
            # Ready ops are overridden, as the default ready ops wait for all variables to be initialized,
            # but some of the variables (such as the cycle-gan graphs) are restored rather than initialized
            read_op_value = constant([])

    is_gpu_or_cpu = (device == "gpu" or device == "cpu")
    if is_gpu_or_cpu:
        config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        set_all_gpu_config()
        master = ''
    else:
        config = None
        tpu_worker = 'grpc://' + os.environ['COLAB_TPU_ADDR']
        # master = TPUClusterResolver(tpu=tpu_worker).get_master()
        master = tpu_worker
        print("TPU master")
        print(master)

    validation_hook = ValidationHook(validation_nn_params, validation_tensor, class_range, required_steps,
                                     validation_steps,
                                     log_dir)
    test_iteration_count = 100
    test_hook = TestHook(testing_nn_params, testing_tensor, cross_entropy, test_iteration_count, class_range)
    initializer_hook = InitializerHook(training_nn_params, training_tensor, augmentation_info, augmentation_restorer)
    stop_on_step_hook = StopAtStepHook(last_step=required_steps - 1)
    nan_tensor_hook = NanTensorHook(loss_tensor=cross_entropy, fail_on_nan_loss=False)

    hooks = [initializer_hook,
             validation_hook,
             test_hook,
             stop_on_step_hook,
             nan_tensor_hook]

    if is_gpu_or_cpu:
        # Only restore nn core variables along with the optimizer and global step variables
        nn_core_restorer = tf.train.Saver(
            max_to_keep=20,
            var_list=slim.get_variables_to_restore(include=["nn_core"]) +
                     slim.get_variables_to_restore(include=["global_step"]) +
                     slim.get_variables_to_restore(include=["training_optimizer"]), name="nn_core_restorer")
        training_scaffold = Scaffold(saver=nn_core_restorer,
                                     ready_for_local_init_op=read_op_value,
                                     ready_op=read_op_value)

        session = tf.train.MonitoredTrainingSession(master=master,
                                                    checkpoint_dir=log_dir,
                                                    summary_dir=log_dir,
                                                    config=config, is_chief=True,
                                                    save_summaries_steps=test_iteration_count,
                                                    save_checkpoint_steps=save_checkpoint_steps,
                                                    scaffold=training_scaffold,
                                                    hooks=hooks)
        # session = LocalCLIDebugWrapperSession(session)
        with session as monitored_sess:
            while not monitored_sess.should_stop():
                monitored_sess.run([train_step])
    else:
        session = tf.Session(target=master, config=config)
        session.run(tpu.initialize_system())
        session.run(tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer()))
        initializer_hook.after_create_session(session, None)
        while session.run(test_hook._global_step_tensor) < required_steps:
            try:
                session.run(train_step)
                test_hook.after_run_with_session(session)
            except tf.errors.OutOfRangeError:
                break

        validation_hook.end(session)
        session.run(tpu.shutdown_system())
        session.close()

    result = TrainingResult(validation_accuracy=validation_hook.validation_accuracy,
                            test_accuracy=test_hook.testing_accuracy, loss=test_hook.loss)
    return result
Example 28
    def __init__(self,
                 iterations,
                 num_cores_per_shard=1,
                 input_partition_dims=None):
        tf.logging.info("TrainLowLevelRunner: constructor")

        self.feature_structure = {}
        self.loss = None
        self.infeed_queue = []
        self.enqueue_ops = []
        self.dataset_initializer = []
        self.iterations = iterations
        # TODO(wangtao): change FLAGS.num_shards_per_host to
        # FLAGS.num_cores_per_host after other low level API
        # support spatial partition. FLAGS.num_shards_per_host means number of TPU
        # cores for each host.
        self.replicas_per_worker = FLAGS.num_shards_per_host // num_cores_per_shard
        self.num_hosts = FLAGS.num_shards * num_cores_per_shard // FLAGS.num_shards_per_host
        self.num_shards = FLAGS.num_shards
        self.scaffold_fn = None
        # Having two separate sessions and graphs to make the initialization faster.
        self.input_sess = None
        self.train_sess = None
        self.input_graph = tf.Graph()
        self.train_graph = None
        self.tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        # Disable grappler for better performance.
        self.session_config = tf.ConfigProto(
            allow_soft_placement=True,
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True)),
            isolate_session_state=True)
        cluster_spec = self.tpu_cluster_resolver.cluster_spec()
        if cluster_spec:
            self.session_config.cluster_def.CopyFrom(
                cluster_spec.as_cluster_def())
        self.tpu_init = tpu.initialize_system()
        self.tpu_shutdown = tpu.shutdown_system()
        self.init_sess = tf.Session(self.tpu_cluster_resolver.get_master(),
                                    config=self.session_config)
        self.queue = Queue.Queue()

        # Init for spatial partitioning.
        self.device_topology = self.init_sess.run(self.tpu_init)
        self.input_partition_dims = input_partition_dims
        self.use_spatial_partition = (
            input_partition_dims is not None
            and int(np.prod(FLAGS.input_partition_dims)) > 1)
        self.num_cores_per_shard = num_cores_per_shard
        if self.use_spatial_partition:
            computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[
                self.num_cores_per_shard]
            self.device_assignment = tpu_device_assignment.device_assignment(
                topology=self.device_topology,
                computation_shape=computation_shape,
                num_replicas=self.num_shards)
            tf.logging.info("num_cores_per_shard: %d",
                            self.num_cores_per_shard)
            tf.logging.info("num_hosts: %d", self.num_hosts)
            tf.logging.info("replicas_per_worker: %d",
                            self.replicas_per_worker)
            tf.logging.info("computation_shape: %s", str(computation_shape))
            tf.logging.info("num_shards: %d", self.num_shards)
            tf.logging.info(
                "device_assignment.topology.device_coordinates: %s",
                str(self.device_assignment.topology.device_coordinates))
            tf.logging.info("device_assignment.core_assignment: %s",
                            str(self.device_assignment.core_assignment))
        else:
            self.device_assignment = None
Example 29
def tf_train_flow(
        train_once_fn,
        model_dir=None,
        log_dir=None,
        max_models_keep=1,
        save_interval_seconds=600,
        save_interval_steps=1000,
        num_epochs=None,
        num_steps=None,
        save_model=True,
        save_interval_epochs=None,
        freeze_graph=False,
        num_steps_per_epoch=0,
        restore_from_latest=True,
        metric_eval_fn=None,
        valid_interval_epochs=0,
        inference_fn=None,
        inference_interval_epochs=0,
        init_fn=None,
        restore_fn=None,
        restore_include=None,
        restore_exclude=None,
        save_all_scope=False,  #TODO save/load from restore scope only but save all
        variables_to_restore=None,
        variables_to_save=None,  #by default will be the same as variables_to_restore
        output_collection_names=None,
        output_node_names=None,
        learning_rate=None,  #not use yet, just use in train_once
        learning_rate_patience=None,
        learning_rate_decay_factor=None,
        write_during_train=True,
        model=None,
        sess=None):
    """
  similar flow to tf_flow, but also tries to reload and save the model
  """
    use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ

    model_dir_ = model_dir
    if use_horovod and hvd.rank() != 0:
        model_dir = None

    if sess is None:
        #TODO melt.get_session is a global session, which may not get closed at the end
        sess = melt.get_session()

    if FLAGS.use_tpu:
        sess.run(tpu.initialize_system())
    #logging.info('tf_train_flow start')
    #logging.info('max_models_keep:', max_models_keep)
    #logging.info('save_interval_seconds:', save_interval_seconds)

    if model_dir:
        if model:
            checkpoint = tf.train.Checkpoint(model=model)
            ckpt_dir = model_dir + '/ckpt'
            checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt')

        #this is useful when you use another model with another scope, and just load/restore/save and initialize your own scope vars!
        #this is not for finetuning, but mainly for cases like using another model for prediction, where the other model's scope is introduced into the graph and ignored here

        # var_list = None if not restore_scope else tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)
        # #logging.info('-------------var_list', var_list)

        # if not variables_to_restore:
        #   variables_to_restore = var_list

        if not variables_to_restore:
            variables_to_restore = slim.get_variables_to_restore(
                include=restore_include, exclude=restore_exclude)

        if not variables_to_save:
            variables_to_save = variables_to_restore
        if save_all_scope:
            variables_to_save = None

        #if variables_to_restore is None:
        logging.info('variables_to_restore from %s' % model_dir)
        #load all vars in the checkpoint; try to save all vars (might be more than the original checkpoint) if variables_to_save is not specified
        varnames_in_checkpoint = melt.get_checkpoint_varnames(model_dir)
        #logging.info('varnames_in_checkpoint: {}'.format(varnames_in_checkpoint))

        # TODO has someproblem say  tf.Variable 'r_net/text_encoder/cudnn_rnn/cu_dnngru/recurrent_kernel/adam_v:0' even though in checkpoint I have renated it as ignore/rnet
        variables_to_restore_from_model = slim.get_variables_to_restore(
            include=varnames_in_checkpoint)
        #logging.info('variables_to_restore_from_model: {}'.format(variables_to_restore_from_model))
        if not variables_to_restore:
            variables_to_restore = variables_to_restore_from_model
        else:
            variables_to_restore = [
                v for v in variables_to_restore
                if v in variables_to_restore_from_model
            ]
        if restore_exclude:
            for excl in restore_exclude:
                variables_to_restore = [
                    v for v in variables_to_restore if not excl in v.name
                ]
        #--tf 1.6 adadelta will have same vars...
        variables_to_restore = list(set(variables_to_restore))
        #logging.info('variables_to_restore', variables_to_restore[:100])
        logging.info('variables_to_restore', [
            x for x in variables_to_restore if not 'OptimizeLoss' in x.name
        ][:100])

    ##finally remove global_step since melt.apps.train will handle it!
    global_step = tf.train.get_or_create_global_step()

    #variables_to_restore = [v for v in variables_to_restore if not tf.GraphKeys.GLOBAL_STEP in v.name]
    #variables_to_restore = [v for v in variables_to_restore if not 'learning_rate' in v.name]

    # TODO fixme if step, step2.. and in checkpoint step then here will be step2...
    #print('------------', [v for v in variables_to_restore if 'step' in v.name])
    loader = tf.train.Saver(var_list=variables_to_restore)

    logging.info('max models to keep {}, keep every {} hours'.format(
        max_models_keep, save_interval_seconds / 3600.0))
    saver = tf.train.Saver(
        max_to_keep=max_models_keep,
        keep_checkpoint_every_n_hours=save_interval_seconds / 3600.0,
        var_list=variables_to_save)
    epoch_saver = tf.train.Saver(var_list=variables_to_save, max_to_keep=1000)
    best_epoch_saver = tf.train.Saver(var_list=variables_to_save)
    #logging.info('variables_to_save:{}'.format(variables_to_save))

    # # #TODO for safe restore all init will be ok ?
    # # if variables_to_restore is None:
    init_op = tf.group(
        tf.global_variables_initializer(
        ),  #variables_initializer(global_variables())
        tf.local_variables_initializer()
    )  #variables_initializer(local_variables())
    # # else:
    # #   init_op = tf.group(tf.variables_initializer(variables_to_restore),
    # #                      tf.local_variables_initializer())

    ##--mostly this will be fine except for using assistant predictor, initialize again! will make assistant predictor wrong
    ##so assume to all run init op! if using assistant predictor, make sure it use another session

    # https://stackoverflow.com/questions/35164529/in-tensorflow-is-there-any-way-to-just-initialize-uninitialised-variables
    # def guarantee_initialized_variables(session, list_of_variables = None):
    #   if list_of_variables is None:
    #       list_of_variables = tf.global_variables()
    #   uninitialized_variables = list(tf.get_variable(name) for name in
    #                                  session.run(tf.report_uninitialized_variables(list_of_variables)))
    #   return unintialized_variables

    # unintialized_variables = guarantee_initialized_variables(sess)
    # init_op = tf.group(tf.initialize_variables(uninitialized_vars), tf.local_variables_initializer())

    timer = gezi.Timer('sess run init_op in melt.tf_train_flow')
    #model.save('./weights')

    # notice
    sess.run(init_op)

    timer.print_elapsed()

    #melt.init_uninitialized_variables(sess)

    #pre_step means the last saved step; when training without a pretrained model it is -1
    pre_step = -1
    fixed_pre_step = -1  #fixed pre step is for epoch num to be correct if you change batch size
    #print(model_dir)
    pre_epoch = None
    if model_dir:
        model_path = _get_model_path(model_dir, save_model)
        # if not model_path:
        #   model_path = _get_model_path(os.path.join(model_dir, 'epoch'))
        #print(model_path)
        model_dir = gezi.get_dir(
            model_dir)  #incase you pass ./model/model-ckpt1000 -> ./model

        if model_path is not None:
            if not restore_from_latest:
                logging.info('using recent but not latest model')
                model_path = melt.recent_checkpoint(model_dir)
            model_name = os.path.basename(model_path)
            timer = gezi.Timer(
                'Loading and training from existing model [%s]' % model_path)
            if restore_fn is not None:
                restore_fn(sess)
            loader.restore(sess, model_path)
            ## not supported
            #model.save()
            #model.save_weights('./weights')
            timer.print()
            #pre_step = melt.get_model_step(model_path) - 1 if FLAGS.global_step is None else FLAGS.global_step -1
            # TODO check ..
            pre_step = sess.run(tf.train.get_global_step()) - 1
            pre_epoch = melt.get_model_epoch(
                model_path
            ) if FLAGS.global_epoch is None else FLAGS.global_epoch
            fixed_pre_step = pre_step
            # if pre_epoch is not None:
            #   #like using batch size 32, then reload train using batch size 64
            #   if abs(pre_step / num_steps_per_epoch - pre_epoch) > 0.1:
            #     fixed_pre_step = int(pre_epoch * num_steps_per_epoch)
            #     logging.info('Warning, epoch is diff with pre_step / num_steps_per_epoch:{}, pre_epoch:{},maybe you change batch size and we will adjust to set pre_step as {}'\
            #       .format(pre_step / num_steps_per_epoch, pre_epoch, fixed_pre_step))
        else:
            latest_checkpoint = None
            if not use_horovod:  #now will hang
                try:
                    latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
                    if latest_checkpoint:
                        logging.info(
                            'Try start from eager trained mode, latest checkpoint:',
                            latest_checkpoint)
                        checkpoint.restore(latest_checkpoint).run_restore_ops(
                            session=sess)

                        pre_epoch = int(latest_checkpoint.split('-')[-1])
                        #pre_step = pre_epoch * num_steps_per_epoch - 1
                        # TODO check
                        pre_step = sess.run(tf.train.get_global_step()) - 1
                        fixed_pre_step = pre_step
                        logging.info('Start step is:', pre_step)
                except Exception:
                    logging.info(
                        'Something wrong with restore from eager trained model'
                    )
                if latest_checkpoint is None:
                    logging.info('Train all start step 0')
                    #https://stackoverflow.com/questions/40220201/tensorflow-tf-initialize-all-variables-vs-tf-initialize-local-variables
                    #tf.initialize_all_variables() is a shortcut to tf.initialize_variables(tf.all_variables()),
                    #tf.initialize_local_variables() is a shortcut to tf.initialize_variables(tf.local_variables()),
                    #which initializes variables in GraphKeys.VARIABLES and GraphKeys.LOCAL_VARIABLE collections, respectively.
                    #init_op = tf.group(tf.global_variables_initializer(),
                    #                   tf.local_variables_initializer())
                    #[var for var in tf.all_variables() if var.op.name.startswith(restore_scope)] will be the same as tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)

                    #sess.run(init_op)

                    #like use image model, build image graph, reload first train, and then will go to same checkpoint all varaible just restore will ok
                    #for finetune from loading other model init
                    if init_fn is not None:
                        init_fn(sess)

    if gezi.env_has('METRIC'):
        l = metric_eval_fn(model_path)
        print(list(zip(l[1], l[0])))
        exit(0)

    #sess.run(tf.assign(global_step, tf.constant(global_step_val, dtype=tf.int64)))
    try:
        learning_rate = tf.get_collection('learning_rate')[-1]
        learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
        sess.run(tf.assign(learning_rate,
                           learning_rate * learning_rate_weight))
    except Exception:
        # if not using weight_decay but using optimizer decay, execution reaches here since the learning rate is a tensor and cannot be assigned
        pass

    try:
        logging.info('Actual start global step:',
                     sess.run(global_step), 'learning rate:',
                     sess.run(learning_rate), 'learning_rate_weight:',
                     sess.run(learning_rate_weight))
    except Exception:
        pass

    if model_dir_:
        #if save_interval_epochs and num_steps_per_epoch and num_steps >= 0:
        epoch_dir = os.path.join(model_dir_, 'epoch')
        gezi.try_mkdir(epoch_dir)
        checkpoint_path = os.path.join(model_dir_, 'model.ckpt')

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    if use_horovod:
        bcast = hvd.broadcast_global_variables(0)
        sess.run(bcast)

    #tf.train.write_graph(sess.graph_def, model_dir, 'train.pbtxt')
    only_one_step = False
    try:
        if use_horovod:
            ## TODO FIXME why bcast here not work ? simple test work see tests/bcast.py
            #comm.bcast(pre_step, root=0)
            temp = np.array([pre_step, fixed_pre_step])
            comm.Bcast(temp, root=0)
            pre_step = temp[0]
            fixed_pre_step = temp[1]

        step = start = pre_step + 1
        fixed_step = fixed_pre_step + 1

        #first = True

        #hack just for save one model after load
        if num_steps < 0 or (num_steps and num_steps < step):
            logging.info('just load and resave then exit')
            model_path_ = _get_checkpoint_path(checkpoint_path,
                                               step,
                                               num_steps_per_epoch,
                                               epoch=pre_epoch)
            saver.save(sess, model_path_, global_step=step + 1)
            if freeze_graph:
                melt.freeze_graph(sess, model_path_, step + 1,
                                  output_collection_names, output_node_names)
            sess.close()
            exit(0)

        if num_epochs < 0:
            only_one_step = True
            logging.info('just run one step')

        if FLAGS.work_mode != 'train':
            assert not os.path.isdir(FLAGS.model_dir), FLAGS.model_dir
            if 'valid' in FLAGS.work_mode:
                vals, names = metric_eval_fn(FLAGS.model_dir)
                logging.info(list(zip(names, vals)))
            if 'test' in FLAGS.work_mode:
                inference_fn(FLAGS.model_dir)
            exit(0)

        #early_stop = True #TODO allow config
        num_bad_epochs = 0
        pre_epoch_eval_loss = 1e20
        best_epoch_eval_loss = 1e20
        num_allowed_bad_epochs = 4  #allow 5 non decrease eval loss epochs  before stop
        epoch_saved_step = 0
        while not coord.should_stop():
            model_step_path = None
            if model_dir_:
                model_path_ = os.path.join(
                    epoch_dir, 'model.ckpt-%.2f' %
                    (fixed_step / float(num_steps_per_epoch)))
                model_step_path_ = model_path_ + '-' + str(step)
                if (write_during_train and metric_eval_fn is not None
                        and valid_interval_epochs and fixed_step %
                        int(num_steps_per_epoch * valid_interval_epochs) == 0):
                    model_step_path = model_step_path_
                else:
                    model_step_path = None

            if step == 0:
                model_step_path = None

            #print('--------------------step', step)
            stop = train_once_fn(
                sess,
                step,
                is_start=(step == start),
                fixed_step=fixed_step,
                num_epochs=num_epochs,
                model_path=model_step_path,
                use_horovod=use_horovod,
                ## TODO FIXME this line will cause   tensorflow.python.framework.errors_impl.NotFoundError: Resource localhost/save_counter/N10tensorflow3VarE does not exist.
            )

            #first = False

            if only_one_step:
                stop = True

            step += 1
            fixed_step += 1

            if save_model and step and model_dir:
                #step 0 is also saved! actually train one step and save
                if step % save_interval_steps == 0:
                    timer = gezi.Timer(
                        'save model step %d to %s' % (step, checkpoint_path),
                        False)
                    model_path_ = _get_checkpoint_path(checkpoint_path,
                                                       fixed_step,
                                                       num_steps_per_epoch)
                    saver.save(sess, model_path_, global_step=step)
                    if freeze_graph:
                        melt.freeze_graph(sess, model_path_, step,
                                          output_collection_names,
                                          output_node_names)
                    #if log_dir != model_dir:
                    #  assert log_dir
                    #  command = 'rsync -l -r -t %s/* %s' % (log_dir, model_dir)
                    #  print(command, file=sys.stderr)
                    #  os.system(command)
                    timer.print_elapsed()

                if save_interval_steps and num_steps_per_epoch and fixed_step % int(
                        num_steps_per_epoch * save_interval_epochs) == 0:
                    # TODO only epoch in name not sep ?
                    epoch_saved_step = step
                    model_path_ = os.path.join(
                        epoch_dir, 'model.ckpt-%.2f' %
                        (fixed_step / float(num_steps_per_epoch)))
                    model_step_path = model_path_ + '-' + str(step)
                    epoch_saver.save(sess, model_path_, global_step=step)
                    #epoch_saver.save(sess, model_path_)

                    ## TODO FIXME do not support tf.keras save currently with horovod
                    # if model:
                    #   #model.save_weights(epoch_dir + '/ckpt-%.2f' % (fixed_step / float(num_steps_per_epoch)))
                    #   # TODO FIXME if restart will save from 1... again..
                    #   checkpoint.save(checkpoint_prefix, session=sess)
                    #   #print(sess.run(checkpoint.save_counter))

                    if freeze_graph:
                        melt.freeze_graph(sess, model_path_, step,
                                          output_collection_names,
                                          output_node_names)

                if write_during_train:
                    if inference_fn is not None and inference_interval_epochs and fixed_step % int(
                            num_steps_per_epoch *
                            inference_interval_epochs) == 0:
                        model_step_path = model_path_ + '-' + str(step)
                        try:
                            #print('--------------inference fn')
                            inference_fn(model_path=model_step_path)
                        except Exception:
                            logging.info(traceback.format_exc())

                    # if metric_eval_fn is not None and valid_interval_epochs and fixed_step % int(num_steps_per_epoch * valid_interval_epochs) == 0:
                    #   model_step_path = model_path_ + '-' + str(step)
                    #   try:
                    #     metric_eval_fn(model_path=model_step_path)
                    #   except Exception:
                    #     logging.info(traceback.format_exc())

            if stop is True:
                print('Early stop running %d steps' % (step), file=sys.stderr)
                raise tf.errors.OutOfRangeError(
                    None, None, 'Early stop running %d steps' % (step))
            if num_steps and (step + 1) == start + num_steps:
                raise tf.errors.OutOfRangeError(None, None,
                                                'Reached max num steps')
            #max_num_epochs = 1000
            max_num_epochs = num_epochs
            #if max_num_epochs and num_steps_per_epoch and fixed_step // num_steps_per_epoch >= max_num_epochs:
            if max_num_epochs and num_steps_per_epoch and fixed_step / num_steps_per_epoch > max_num_epochs:
                raise tf.errors.OutOfRangeError(
                    None, None,
                    'Reached max num epochs of %d' % max_num_epochs)
    #except tf.errors.OutOfRangeError, e:
    except tf.errors.OutOfRangeError:
        # if run 2 epoch and we have just epoch saved, do not need to save only 1 step more model
        if (step - epoch_saved_step > 1) and not (
                step == start
        ) and save_model and step % save_interval_steps != 0 and model_dir:
            model_path_ = _get_checkpoint_path(checkpoint_path, step,
                                               num_steps_per_epoch)
            saver.save(sess, model_path_, global_step=step)
            if freeze_graph:
                melt.freeze_graph(sess, model_path_, step,
                                  output_collection_names, output_node_names)
            if log_dir != model_dir:
                assert log_dir
                command = 'rsync -l -r -t %s/* %s' % (log_dir, model_dir)
                print(command, file=sys.stderr)
                os.system(command)
        if only_one_step:
            logging.info('Done one step')
            exit(0)

        # if (step - epoch_saved_step > 1) and metric_eval_fn is not None:
        #   metric_eval_fn(model_path=model_step_path)

        if (num_epochs and fixed_step / num_steps_per_epoch >= num_epochs) or (
                num_steps and step == start + num_steps):
            logging.info('Done training for %.3f epochs, %d steps.' %
                         (fixed_step / num_steps_per_epoch, step))
            #FIXME because coord.join seems not to work, RuntimeError: Coordinator stopped with threads still running: Thread-9
            exit(0)
        else:
            logging.info('Should not stop, but stopped at epoch: %.3f' %
                         (fixed_step / num_steps_per_epoch))
            logging.info(traceback.format_exc())
            #raise e
    finally:
        coord.request_stop()

    coord.join(threads, stop_grace_period_secs=5)
    #FIXME due to use of melt.get_session (the global session does not handle __del__ well)
    #Done training for 3090020 steps.
    #Exception TypeError: "'NoneType' object is not callable" in <bound method Session.__del__ of <tensorflow.python.client.session.Session object at 0x7f6cf33cd450>> ignored
    if FLAGS.use_tpu:
        sess.run(tpu.shutdown_system())
    sess.close()
Example 30
def flops():
    x = tf.random_uniform([N, N])
    y = tf.random_uniform([N, N])

    def _matmul(x, y):
        return tf.tensordot(x, y, axes=[[1], [0]]), y

    return tf.reduce_sum(tpu.repeat(COUNT, _matmul, [x, y]))


tpu_ops = tpu.batch_parallel(flops, [], num_shards=8)

session = tf.Session(tpu_cluster)

try:
    print('Warming up...')
    session.run(tpu.initialize_system())
    session.run(tpu_ops)
    print('Profiling')
    start = time.time()
    session.run(tpu_ops)
    end = time.time()
    elapsed = end - start
    print(elapsed,
          'TFlops: {:.2f}'.format(1e-12 * 8 * COUNT * 2 * N * N * N / elapsed))
except Exception as e:
    print(e)
finally:
    session.run(tpu.shutdown_system())
    session.close()
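The TFLOPS figure printed above assumes each N x N matrix multiplication costs 2*N^3 floating-point operations; with COUNT multiplications on each of the 8 shards, the total work is 8 * COUNT * 2 * N^3 operations, and dividing by the elapsed time and scaling by 1e-12 converts the result to teraflops.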
Example 31
 def get_initialization_ops(self):
   return [tpu.initialize_system()]