@contextlib.contextmanager
def _DeviceAssignment(self):
  """A context for tpu device assignment of a JF 8x8 slice."""
  mesh_shape = [8, 8, 1, 2]
  device_coordinates = np.zeros([16, 8, 4], dtype=np.int32)
  for i in range(np.prod(mesh_shape)):
    x = i // 16
    y = i % 16 // 2
    core = i % 2
    task = x // 2 * 4 + y // 2
    device = x % 2 * 4 + y % 2 * 2 + core
    device_coordinates[task, device] = [x, y, 0, core]
  topology = tf.tpu.experimental.Topology(
      mesh_shape=mesh_shape, device_coordinates=device_coordinates)
  assignment = device_assignment.device_assignment(
      topology, computation_shape=[1, 1, 1, 1], num_replicas=128)
  py_utils.SetTpuDeviceAssignment(assignment)
  try:
    yield
  finally:
    py_utils.SetTpuDeviceAssignment(None)
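# The arithmetic in the loop above is the core of _DeviceAssignment: it walks
# the 128 linear core indices of the 8x8x1x2 slice and derives, for each, the
# mesh coordinate (x, y, 0, core) plus the owning (task, device) slot. A
# minimal standalone sketch of that mapping (illustration only, not part of
# the original code):
for i in (0, 1, 17, 127):
  x = i // 16                            # column in the 8x8 chip mesh
  y = i % 16 // 2                        # row in the 8x8 chip mesh
  core = i % 2                           # core index within a chip
  task = x // 2 * 4 + y // 2             # host task owning a 2x2 chip block
  device = x % 2 * 4 + y % 2 * 2 + core  # device slot within that task
  print('i=%3d -> coords=(%d, %d, 0, %d) task=%2d device=%d' %
        (i, x, y, core, task, device))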
def _WaitTillInit(job=None):
  """Wait until the model is ready."""
  try:
    if py_utils.IsEagerMode():
      topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    else:
      # tpu.initialize_system() is called with None as embedding_config, as
      # embedding_config is not available yet. Later in _Loop, it is called
      # with the correct embedding_config. Since it cannot be called twice
      # in the same graph with different embedding_config, we use a
      # dummy_graph here.
      dummy_graph = tf.Graph()
      with dummy_graph.as_default():
        tpu_initialize_system_op = tf.tpu.initialize_system(
            embedding_config=None, job=job)
      with self._GetSession(graph=dummy_graph) as sess:
        topology = sess.run(tpu_initialize_system_op)

    if train_cfg.train.tpu_computation_shape is None:
      computation_shape = py_utils.ComputationShape(num_devices_per_split,
                                                    topology)
    else:
      computation_shape = train_cfg.train.tpu_computation_shape
      assert num_devices_per_split == np.prod(computation_shape)

    if train_cfg.train.tpu_device_order_mode is None:
      self.device_assignment = device_assignment_lib.device_assignment(
          topology,
          computation_shape=computation_shape,
          num_replicas=data_parallelism)
    else:
      self.device_assignment = device_assignment_lib.device_assignment(
          topology,
          computation_shape=computation_shape,
          num_replicas=data_parallelism,
          device_order_mode=train_cfg.train.tpu_device_order_mode)
    py_utils.SetTpuDeviceAssignment(self.device_assignment, job)
    tf.logging.info('device_assignment.core_assignment: %s',
                    str(self.device_assignment.core_assignment))
    tf.logging.info('device_assignment.topology.device_coordinates: %s',
                    str(self.device_assignment.topology.device_coordinates))
  except py_utils.transient_tf_errors as e:
    tf.logging.info('TPU initialization failed: %s', e)
    raise
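# For reference, a self-contained sketch of the eager-mode branch above. In
# _WaitTillInit, `resolver` comes from the enclosing scope; here we build one
# explicitly. The tpu='local' argument is an assumption (a machine with
# locally attached TPUs); on Cloud TPU you would pass the worker address.
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
print('mesh_shape:', topology.mesh_shape)
print('num_tasks:', topology.num_tasks)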
def _WaitTillInit():
  """Wait until the model is ready."""
  try:
    with self._GetSession() as sess:
      topology = sess.run(
          tf.contrib.tpu.initialize_system(embedding_config=None, job=None))
      device_assignment = tf.contrib.tpu.device_assignment(
          topology,
          computation_shape=ComputationShape(num_devices_per_split),
          num_replicas=data_parallelism)
      py_utils.SetTpuDeviceAssignment(device_assignment)
      tf.logging.info('device_assignment.core_assignment: %s',
                      str(device_assignment.core_assignment))
      tf.logging.info('device_assignment.topology.device_coordinates: %s',
                      str(device_assignment.topology.device_coordinates))
  except py_utils.transient_tf_errors as e:
    tf.logging.info('TPU initialization failed: %s', e)
    raise
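# A device_assignment() call like the one above can only succeed when the
# requested replicas fit on the slice: each replica consumes
# prod(computation_shape) cores, so num_replicas * prod(computation_shape)
# must not exceed the total core count of the topology. A quick arithmetic
# sketch using the 8x8x1x2 slice from _DeviceAssignment above:
import numpy as np

mesh_shape = [8, 8, 1, 2]         # 128 cores in total
computation_shape = [1, 1, 1, 1]  # one core per replica
num_replicas = 128
assert num_replicas * np.prod(computation_shape) <= np.prod(mesh_shape)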
def _init_tpu(self, num_partitions, device_order_mode):
  """Initialize tpu device assignment."""
  tf.logging.info('Initializing TPU to get device assignment: start')
  graph = tf.Graph()
  with graph.as_default():
    init_tpu_op = tf.tpu.initialize_system()
  try:
    sess = tf.Session(target=self._tpu, graph=graph,
                      config=self._no_opt_sess_cfg())
    topology = sess.run(init_tpu_op)
  except Exception as e:
    tf.logging.fatal('TPU initialization failed: %s', e)
    raise

  topology_proto = topology_pb2.TopologyProto()
  topology_proto.ParseFromString(topology)
  tf.logging.info('topology.num_tasks: %r', topology_proto.num_tasks)
  tf.logging.info('topology.num_tpu_devices_per_task: %r',
                  topology_proto.num_tpu_devices_per_task)
  tf.logging.info('topology.mesh_shape: %r', topology_proto.mesh_shape)
  self.cluster_params = self._configure_cluster_params(
      tpu_cores=(topology_proto.num_tpu_devices_per_task *
                 topology_proto.num_tasks),
      cpu_hosts=topology_proto.num_tasks)

  # We assume the topology and device assignment do not change
  # for a single address space.
  device_assignment = tpu_device_assignment.device_assignment(
      topology,
      computation_shape=py_utils.ComputationShape(num_partitions, topology),
      num_replicas=1,
      device_order_mode=device_order_mode)
  py_utils.SetTpuDeviceAssignment(device_assignment)

  tf.logging.info('Initializing TPU to get device assignment: done')
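# Note on the raw proto parsing above: the value returned by
# sess.run(tf.tpu.initialize_system()) is a serialized TopologyProto string.
# tf.tpu.experimental.Topology can rehydrate it directly, exposing the same
# fields without touching topology_pb2. Sketch only; `topology` is assumed to
# be the serialized string obtained from a real session run:
topology_obj = tf.tpu.experimental.Topology(serialized=topology)
print('num_tasks:', topology_obj.num_tasks)
print('devices per task:', topology_obj.num_tpus_per_task)
print('mesh_shape:', topology_obj.mesh_shape)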
def _WaitTillInit():
  """Wait until the model is ready."""
  try:
    with self._graph.as_default(), self._GetSession(
        cluster_def=self._cluster_def,
        disable_meta_optimizer=FLAGS.disable_meta_optimizer_in_executor
    ) as sess:
      topology = sess.run(
          tf.tpu.initialize_system(embedding_config=None, job=None))
      device_assignment = device_assignment_lib.device_assignment(
          topology,
          computation_shape=py_utils.ComputationShape(num_devices_per_split),
          num_replicas=data_parallelism)
      py_utils.SetTpuDeviceAssignment(device_assignment)
      tf.logging.info('device_assignment.core_assignment: %s',
                      str(device_assignment.core_assignment))
      tf.logging.info('device_assignment.topology.device_coordinates: %s',
                      str(device_assignment.topology.device_coordinates))
  except py_utils.transient_tf_errors as e:
    tf.logging.info('TPU initialization failed: %s', e)
    raise
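# py_utils.ComputationShape, used in several snippets above, maps a per-split
# device count to a 4-D computation shape whose product equals that count.
# The stand-in below is hypothetical -- it only illustrates that invariant,
# not Lingvo's actual shape choices, which also depend on the topology:
import numpy as np

def _computation_shape_sketch(num_devices_per_split):
  # Assumption: power-of-two split sizes up to 8; the real helper covers more.
  shapes = {
      1: [1, 1, 1, 1],
      2: [1, 1, 1, 2],
      4: [1, 2, 1, 2],
      8: [2, 2, 1, 2],
  }
  shape = shapes[num_devices_per_split]
  assert np.prod(shape) == num_devices_per_split
  return shape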