Example #1
0
    def testSerialization(self):
        """Round-trips a Topology through its serialized string form.

        Builds a topology from an explicit mesh shape and device coordinates,
        serializes it, constructs a second topology from that serialized
        string, and checks that both expose the same mesh shape and device
        coordinates.
        """
        source = topology.Topology(
            mesh_shape=[1, 1, 2],
            device_coordinates=[[[0, 0, 0], [0, 0, 1]]],
        )
        restored = topology.Topology(serialized=source.serialized())

        # The mesh shape must survive the round trip ...
        self.assertAllEqual(source.mesh_shape, restored.mesh_shape)
        # ... and so must the device coordinates.
        self.assertAllEqual(
            source.device_coordinates, restored.device_coordinates)
Example #2
0
def initialize_tpu_system(cluster_resolver=None):
  """Initializes the TPU system and returns the resulting topology.

  In graph mode the initialization op is run in a fresh graph and session
  that targets the resolver's master. When executing eagerly, the op is
  wrapped in a defun and launched through `TPUPartitionedCall` instead.

  Args:
    cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster. If None, a
        `TPUClusterResolver("")` is constructed.
  Returns:
    The tf.contrib.tpu.Topology object for the topology of the TPU cluster.
  """
  if cluster_resolver is None:
    cluster_resolver = TPUClusterResolver("")
  # Resolved in both modes; only the graph-mode session uses it as target.
  master = cluster_resolver.master()

  logging.info("Initializing the TPU system.")

  if not context.executing_eagerly():
    # Graph mode: run the init op in its own graph and session.
    soft_placement_config = config_pb2.ConfigProto(allow_soft_placement=True)
    with ops.Graph().as_default():
      with session_lib.Session(
          config=soft_placement_config, target=master) as sess:
        init_op = tpu.initialize_system()
        serialized_topology = sess.run(init_op)
  else:
    # Eager mode. tpu.initialize_system only emits a placeholder op; the ops
    # that really initialize the TPU system are inserted by
    # DistributedTPURewritePass. Running the placeholder eagerly would
    # therefore do nothing useful, so it is wrapped in a defun and the
    # rewrite is triggered by invoking that function via TPUPartitionedCallOp.
    @function.defun
    def _tpu_init_fn():
      return tpu.initialize_system()

    # _tpu_init_fn is never called directly (it only holds the placeholder
    # op). It is materialized solely so that it gets registered with the
    # eager context and its assigned name can be looked up.
    # pylint: disable=protected-access
    concrete_fn = _tpu_init_fn._get_concrete_function_internal()
    registered_name = compat.as_str(concrete_fn._inference_function.name)
    # pylint: enable=protected-access

    call_outputs = tpu_functional_ops.TPUPartitionedCall(
        args=[], device_ordinal=0, Tout=[dtypes.string], f=registered_name)
    serialized_topology = call_outputs[0].numpy()

  logging.info("Finished initializing TPU system.")
  return topology.Topology(serialized=serialized_topology)
Example #3
0
def initialize_tpu_system(cluster_resolver=None):
    """Initializes the TPU devices in a dedicated graph and session.

    A fresh graph is made the default, a session with soft placement enabled
    is opened against the resolver's master, and the TPU initialization op is
    run inside it. The value produced by that run is used to build the
    returned topology.

    Args:
      cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver,
          which provides information about the TPU cluster. If None, a
          `resolver_lib.TPUClusterResolver("")` is constructed.
    Returns:
      The tf.contrib.tpu.Topology object for the topology of the TPU cluster.
    """
    if cluster_resolver is None:
        cluster_resolver = resolver_lib.TPUClusterResolver("")
    session_target = cluster_resolver.master()

    logging.info("Initializing the TPU system.")
    soft_placement_config = config_pb2.ConfigProto(allow_soft_placement=True)

    # The init op is created inside the new graph so it does not touch
    # whatever graph is the caller's default.
    with ops.Graph().as_default():
        with session_lib.Session(
                config=soft_placement_config, target=session_target) as sess:
            init_op = tpu.initialize_system()
            serialized_topology = sess.run(init_op)

    logging.info("Finished initializing TPU system.")
    return topology.Topology(serialized=serialized_topology)