Example #1
0
 def _DeviceAssignment(self):
     """A context for tpu device assignment of a JF 8x8 slice."""
     mesh_shape = [8, 8, 1, 2]
     # One row per task (host), one column per device on that task; the last
     # axis holds the (x, y, z, core) mesh coordinates of each device.
     device_coordinates = np.zeros([16, 8, 4], dtype=np.int32)
     for idx in range(np.prod(mesh_shape)):
         # Decompose the linear index into mesh coordinates.
         x, rem = divmod(idx, 16)
         y, core = divmod(rem, 2)
         # Fold each (x, y, core) coordinate into its (task, device) slot.
         task = (x // 2) * 4 + y // 2
         device = (x % 2) * 4 + (y % 2) * 2 + core
         device_coordinates[task, device] = [x, y, 0, core]
     topology = tf.tpu.experimental.Topology(
         mesh_shape=mesh_shape, device_coordinates=device_coordinates)
     assignment = device_assignment.device_assignment(
         topology, computation_shape=[1, 1, 1, 1], num_replicas=128)
     py_utils.SetTpuDeviceAssignment(assignment)
     try:
         yield
     finally:
         # Always clear the global assignment when leaving the context.
         py_utils.SetTpuDeviceAssignment(None)
Example #2
0
        def _WaitTillInit(job=None):
            """Wait until the model is ready.

            Initializes the TPU system (eagerly or via a throwaway graph),
            derives the computation shape and device assignment, and registers
            the assignment globally via `py_utils.SetTpuDeviceAssignment`.

            Args:
              job: Optional tf job name, forwarded to
                `tf.tpu.initialize_system` and
                `py_utils.SetTpuDeviceAssignment`.

            Raises:
              Re-raises any transient TF error encountered during
              initialization after logging it.
            """
            try:
                if py_utils.IsEagerMode():
                    topology = tf.tpu.experimental.initialize_tpu_system(
                        resolver)
                else:
                    # tpu.initialize_system() is called with None as embedding_config, as
                    # embedding_config is not available yet. Later in _Loop, it is called
                    # with the correct embedding_config. Since it cannot be called twice
                    # in the same graph with different embedding_config, we use a
                    # dummy_graph here.
                    dummy_graph = tf.Graph()
                    with dummy_graph.as_default():
                        tpu_initialize_system_op = tf.tpu.initialize_system(
                            embedding_config=None, job=job)

                    with self._GetSession(graph=dummy_graph) as sess:
                        topology = sess.run(tpu_initialize_system_op)

                if train_cfg.train.tpu_computation_shape is None:
                    computation_shape = py_utils.ComputationShape(
                        num_devices_per_split, topology)
                else:
                    computation_shape = train_cfg.train.tpu_computation_shape
                    # Sanity check: the configured shape must cover exactly the
                    # devices of one split.
                    assert num_devices_per_split == np.prod(computation_shape)

                # Only forward device_order_mode when explicitly configured so
                # the library default applies otherwise. This replaces two
                # duplicated device_assignment() calls that differed in a
                # single kwarg and could drift apart.
                extra_kwargs = {}
                if train_cfg.train.tpu_device_order_mode is not None:
                    extra_kwargs['device_order_mode'] = (
                        train_cfg.train.tpu_device_order_mode)
                self.device_assignment = device_assignment_lib.device_assignment(
                    topology,
                    computation_shape=computation_shape,
                    num_replicas=data_parallelism,
                    **extra_kwargs)
                py_utils.SetTpuDeviceAssignment(self.device_assignment, job)
                tf.logging.info('device_assignment.core_assignment: %s',
                                str(self.device_assignment.core_assignment))
                tf.logging.info(
                    'device_assignment.topology.device_coordinates: %s',
                    str(self.device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise
Example #3
0
 def _WaitTillInit():
   """Wait until the model is ready."""
   try:
     with self._GetSession() as sess:
       # Initializing the TPU system returns its topology, from which the
       # device assignment is derived.
       init_op = tf.contrib.tpu.initialize_system(
           embedding_config=None, job=None)
       topology = sess.run(init_op)
       shape = ComputationShape(num_devices_per_split)
       device_assignment = tf.contrib.tpu.device_assignment(
           topology,
           computation_shape=shape,
           num_replicas=data_parallelism)
       py_utils.SetTpuDeviceAssignment(device_assignment)
       # Log the computed assignment for debugging.
       tf.logging.info('device_assignment.core_assignment: %s',
                       str(device_assignment.core_assignment))
       tf.logging.info('device_assignment.topology.device_coordinates: %s',
                       str(device_assignment.topology.device_coordinates))
   except py_utils.transient_tf_errors as e:
     tf.logging.info('TPU initialization failed: %s', e)
     raise
Example #4
0
    def _init_tpu(self, num_partitions, device_order_mode):
        """Initialize tpu device assignment.

        Initializes the TPU system once, logs its topology, configures the
        cluster params from the topology, and registers a single-replica
        device assignment globally.

        Args:
            num_partitions: Number of partitions; forwarded to
                `py_utils.ComputationShape` to derive the computation shape.
            device_order_mode: Forwarded to
                `tpu_device_assignment.device_assignment`.

        Raises:
            Re-raises any exception from TPU system initialization after
            logging it at fatal level.
        """
        tf.logging.info('Initializing TPU to get device assignment: start')

        graph = tf.Graph()
        with graph.as_default():
            init_tpu_op = tf.tpu.initialize_system()
        try:
            # Use the session as a context manager so it is always closed --
            # the original leaked the session, including on the error path.
            with tf.Session(target=self._tpu,
                            graph=graph,
                            config=self._no_opt_sess_cfg()) as sess:
                topology = sess.run(init_tpu_op)
        except Exception as e:
            tf.logging.fatal('TPU initialization failed: %s', e)
            raise

        # `topology` is a serialized TopologyProto; parse it to inspect the
        # mesh and task layout.
        topology_proto = topology_pb2.TopologyProto()
        topology_proto.ParseFromString(topology)
        tf.logging.info('topology.num_tasks: %r', topology_proto.num_tasks)
        tf.logging.info('topology.num_tpu_devices_per_task: %r',
                        topology_proto.num_tpu_devices_per_task)
        tf.logging.info('topology.mesh_shape: %r', topology_proto.mesh_shape)
        self.cluster_params = self._configure_cluster_params(
            tpu_cores=(topology_proto.num_tpu_devices_per_task *
                       topology_proto.num_tasks),
            cpu_hosts=topology_proto.num_tasks)

        # We assume the topology and device assignment does not change
        # for a single address space.
        device_assignment = tpu_device_assignment.device_assignment(
            topology,
            computation_shape=py_utils.ComputationShape(
                num_partitions, topology),
            num_replicas=1,
            device_order_mode=device_order_mode)
        py_utils.SetTpuDeviceAssignment(device_assignment)

        tf.logging.info('Initializing TPU to get device assignment: done')
Example #5
0
 def _WaitTillInit():
   """Wait until the model is ready."""
   try:
     # NOTE: the session is created after the graph default is entered, so
     # the `with` item order below is load-bearing and kept as-is.
     with self._graph.as_default(), self._GetSession(
         cluster_def=self._cluster_def,
         disable_meta_optimizer=FLAGS.disable_meta_optimizer_in_executor
     ) as sess:
       # Initializing the TPU system also yields its topology, from which
       # the device assignment is computed.
       init_op = tf.tpu.initialize_system(embedding_config=None, job=None)
       topology = sess.run(init_op)
       computation_shape = py_utils.ComputationShape(num_devices_per_split)
       device_assignment = device_assignment_lib.device_assignment(
           topology,
           computation_shape=computation_shape,
           num_replicas=data_parallelism)
       py_utils.SetTpuDeviceAssignment(device_assignment)
       tf.logging.info('device_assignment.core_assignment: %s',
                       str(device_assignment.core_assignment))
       tf.logging.info(
           'device_assignment.topology.device_coordinates: %s',
           str(device_assignment.topology.device_coordinates))
   except py_utils.transient_tf_errors as e:
     tf.logging.info('TPU initialization failed: %s', e)
     raise