def testSerialization(self):
  """Round-trips a Topology through its serialized string form."""
  source = topology.Topology(
      mesh_shape=[1, 1, 2],
      device_coordinates=[[[0, 0, 0], [0, 0, 1]]],
  )
  blob = source.serialized()
  restored = topology.Topology(serialized=blob)
  # The topology rebuilt from the serialized string must match the one we
  # started from, field for field.
  self.assertAllEqual(source.mesh_shape, restored.mesh_shape)
  self.assertAllEqual(source.device_coordinates,
                      restored.device_coordinates)
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices in a separate session and graph.

  Args:
    cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver,
      which provides information about the TPU cluster.

  Returns:
    The tf.contrib.tpu.Topology object for the topology of the TPU cluster.
  """
  if cluster_resolver is None:
    cluster_resolver = TPUClusterResolver("")
  tpu_master = cluster_resolver.master()

  logging.info("Initializing the TPU system.")

  if context.executing_eagerly():
    # tpu.initialize_system() only emits a placeholder op whose sole job is
    # to trigger DistributedTPURewritePass, which in turn inserts the real
    # initialization ops. Running it eagerly would therefore do nothing
    # useful, so we wrap it in a defun and force the rewrite passes to run
    # by dispatching it through TPUPartitionedCallOp.
    @function.defun
    def _initialize():
      return tpu.initialize_system()

    # The defun is never invoked directly (it holds just the dummy op, see
    # above); tracing it registers the function with the eager context and
    # gives us the name that TPUPartitionedCall needs.
    # pylint: disable=protected-access
    concrete = _initialize._get_concrete_function_internal()
    fn_name = compat.as_str(concrete._inference_function.name)
    # pylint: enable=protected-access
    result = tpu_functional_ops.TPUPartitionedCall(
        args=[], device_ordinal=0, Tout=[dtypes.string], f=fn_name)
    serialized_topology = result[0].numpy()
  else:
    # Graph mode: run the initialization op in a throwaway session so the
    # caller's default graph is left untouched.
    config = config_pb2.ConfigProto(allow_soft_placement=True)
    with ops.Graph().as_default():
      with session_lib.Session(config=config, target=tpu_master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())

  logging.info("Finished initializing TPU system.")
  return topology.Topology(serialized=serialized_topology)
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices in a separate session and graph.

  Args:
    cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver,
      which provides information about the TPU cluster.

  Returns:
    The tf.contrib.tpu.Topology object for the topology of the TPU cluster.
  """
  if cluster_resolver is None:
    cluster_resolver = resolver_lib.TPUClusterResolver("")
  tpu_master = cluster_resolver.master()

  logging.info("Initializing the TPU system.")

  # Run the initialization op inside a dedicated graph and session so the
  # caller's default graph is left untouched.
  config = config_pb2.ConfigProto(allow_soft_placement=True)
  with ops.Graph().as_default():
    with session_lib.Session(config=config, target=tpu_master) as sess:
      serialized_topology = sess.run(tpu.initialize_system())

  logging.info("Finished initializing TPU system.")
  return topology.Topology(serialized=serialized_topology)