def _tpu_run(strategy, fn, args, kwargs):
  """Common implementation of TPUStrategy.experimental_run_v2.

  Calls `fn` once per replica through `tpu.replicate()` and regroups the
  per-replica outputs with `values.regroup()`.

  Args:
    strategy: Strategy object supplying `num_replicas_in_sync`, `scope()` and
      `extended._device_map`.
    fn: Callable invoked as `fn(*replica_args, **replica_kwargs)` inside a
      `_TPUReplicaContext`.
    args: Positional arguments for `fn`; the component for replica `i` is
      picked with `values.select_replica(i, args)`.
    kwargs: Keyword arguments for `fn`, or None for no keyword arguments.

  Returns:
    The result of `values.regroup()` applied to the per-replica outputs.

  Raises:
    NotImplementedError: If executing eagerly outside of a TF function.
  """
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  if kwargs is None:
    kwargs = {}

  # Used to re-structure flattened output tensors from `tpu.replicate()`
  # into a structured format.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(strategy.num_replicas_in_sync):
    replicate_inputs.append(
        [constant_op.constant(i, dtype=dtypes.int32),
         values.select_replica(i, args),
         values.select_replica(i, kwargs)])

  # Construct and pass `maximum_shapes` so that we could support dynamic
  # shapes using dynamic padder.
  if replicate_inputs:
    maximum_shapes = []
    for input_tensor in nest.flatten(replicate_inputs[0]):
      if tensor_util.is_tensor(input_tensor):
        maximum_shape = input_tensor.get_shape()
      else:
        # Non-Tensor inputs (Python scalars, numpy arrays, ...) have no
        # `get_shape()`; derive their static shape with numpy instead.
        maximum_shape = tensor_shape.TensorShape(np.shape(input_tensor))
      maximum_shapes.append(maximum_shape)
    maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                           maximum_shapes)
  else:
    maximum_shapes = None

  with strategy.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs,
                                      maximum_shapes=maximum_shapes)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  if isinstance(result[0], list):
    result[0] = [
        output for output in result[0] if tensor_util.is_tensor(output)
    ]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  replicate_outputs = [
      nest.pack_sequence_as(result[0], nest.flatten(replica_output))
      for replica_output in replicate_outputs
  ]

  device_map = strategy.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
def experimental_run_v2(self, fn, args=(), kwargs=None):
  """See base class."""
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  kwargs = {} if kwargs is None else kwargs

  # One-slot holder for the structured value returned by `fn`; it is used
  # below as the template for re-packing the flat `tpu.replicate()` outputs.
  fn_output = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Runs `fn` inside a replica context and records what it returned."""
    with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
      fn_output[0] = fn(*replica_args, **replica_kwargs)
    return fn_output[0]

  # One `[replica_id, args, kwargs]` entry per replica.
  replicate_inputs = [
      [
          constant_op.constant(replica, dtype=dtypes.int32),
          values.select_replica(replica, args),
          values.select_replica(replica, kwargs),
      ]
      for replica in range(self.num_replicas_in_sync)
  ]

  with self.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

  # Drop any non-tensor entries (no ops possibly added by `tpu.replicate()`)
  # from a list-valued result before using it as the packing template.
  structure = fn_output[0]
  if isinstance(structure, list):
    structure = [item for item in structure if tensor_util.is_tensor(item)]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  per_replica_results = []
  for replica_output in replicate_outputs:
    per_replica_results.append(
        nest.pack_sequence_as(structure, nest.flatten(replica_output)))

  device_map = self.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, per_replica_results)
def experimental_run(self, fn, input_iterator=None):
  """See base class."""
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions."
    )

  has_inputs = input_iterator is not None
  inputs = input_iterator.get_next() if has_inputs else []

  # One-slot holder for the structured value returned by `fn`; it is the
  # template used to re-pack the flat outputs of `tpu.replicate()`.
  fn_output = [None]

  def replicated_fn(replica_id, replica_input):
    """Runs `fn` inside a replica context and records what it returned."""
    with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
      # `fn` takes no arguments when no iterator was supplied.
      fn_output[0] = fn(replica_input) if has_inputs else fn()
    return fn_output[0]

  # One `[replica_id, input]` entry per replica.
  replicate_inputs = [
      [
          constant_op.constant(replica, dtype=dtypes.int32),
          values.select_replica(replica, inputs),
      ]
      for replica in range(self.num_replicas_in_sync)
  ]

  with self.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  per_replica_results = []
  for replica_outputs in replicate_outputs:
    per_replica_results.append(
        nest.pack_sequence_as(fn_output[0], nest.flatten(replica_outputs)))

  device_map = self.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, per_replica_results)
def rewrite_fn(*args):
  """The rewritten step fn running on TPU."""
  del args  # Unused.

  per_replica_inputs = multi_worker_iterator.get_next()

  def inputs_for(replica_id):
    """Picks replica `replica_id`'s component out of every input leaf."""
    return nest.map_structure(
        lambda x: values.select_replica(replica_id, x), per_replica_inputs)

  # Each replica receives a one-element tuple holding its input structure.
  replicate_inputs = [
      (inputs_for(replica_id),)
      for replica_id in range(self._num_replicas_in_sync)
  ]

  replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

  # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
  # will flatten it in this case. If run_fn has no tensor outputs,
  # tpu.replicate returns a list of no_ops, we will keep the output as it
  # is.
  if isinstance(replicate_outputs[0], list):
    return nest.flatten(replicate_outputs)
  return replicate_outputs
def rewrite_fn(*args):
  """The rewritten step fn running on TPU."""
  del args  # Unused.

  per_replica_inputs = multi_worker_iterator.get_next()

  replicate_inputs = []
  for replica_id in range(self._num_replicas_in_sync):

    def pick_replica(x, replica=replica_id):
      """Returns the component of `x` belonging to this replica."""
      return values.select_replica(replica, x)

    this_replica_inputs = nest.map_structure(pick_replica, per_replica_inputs)
    # `tpu.replicate` is given one single-element tuple per replica.
    replicate_inputs.append((this_replica_inputs,))

  replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

  # With tensor outputs `tpu.replicate` returns a list of lists, which is
  # flattened here. Without tensor outputs it returns a list of no_ops,
  # which is passed through untouched.
  outputs_are_nested = isinstance(replicate_outputs[0], list)
  if outputs_are_nested:
    replicate_outputs = nest.flatten(replicate_outputs)

  return replicate_outputs
def experimental_run(self, fn, input_iterator=None):
  """See base class."""
  running_eagerly_outside_function = (
      context.executing_eagerly() and not ops.inside_function())
  if running_eagerly_outside_function:
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  inputs = [] if input_iterator is None else input_iterator.get_next()

  # Holds the structured return value of `fn` so that the flat outputs of
  # `tpu.replicate()` can be packed back into the same structure.
  captured = [None]

  def replicated_fn(replica_id, replica_input):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
      if input_iterator is not None:
        captured[0] = fn(replica_input)
      else:
        captured[0] = fn()
    return captured[0]

  replicate_inputs = []  # By replica.
  for replica in range(self.num_replicas_in_sync):
    replica_id_tensor = constant_op.constant(replica, dtype=dtypes.int32)
    replica_input = values.select_replica(replica, inputs)
    replicate_inputs.append([replica_id_tensor, replica_input])

  with self.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  structure = captured[0]
  replicate_outputs = [
      nest.pack_sequence_as(structure, nest.flatten(replica_outputs))
      for replica_outputs in replicate_outputs
  ]

  device_map = self.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
def tpu_function(args, kwargs):
  """TF Function used to replicate the user computation.

  Args:
    args: Positional arguments for the user function `fn`. Trailing `None`
      entries are stripped before replication.
    kwargs: Keyword arguments for `fn`, or None for no keyword arguments.

  Returns:
    The result of `values.regroup()` over the per-replica outputs; every
    per-replica entry is `None` when `fn` returned `None` or a single
    `Operation`.
  """
  if kwargs is None:
    kwargs = {}

  # Remove None at the end of args as they are not replicatable
  # If there are None in the middle we can't do anything about it
  # so let those cases fail.
  # For example when Keras model predict is used they pass the targets as
  # None. We want to handle it here so all client libraries don't have to
  # do this as other strategies can handle None values better.
  while args and args[-1] is None:
    args = args[:-1]

  # Used to re-structure flattened output tensors from `tpu.replicate()`
  # into a structured format.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(strategy.num_replicas_in_sync):
    replicate_inputs.append([
        constant_op.constant(i, dtype=dtypes.int32),
        values.select_replica(i, args),
        values.select_replica(i, kwargs)
    ])

  # Construct and pass `maximum_shapes` so that we could support dynamic
  # shapes using dynamic padder.
  if self.experimental_enable_dynamic_batch_size and replicate_inputs:
    maximum_shapes = []
    flattened_list = nest.flatten(replicate_inputs[0])
    for input_tensor in flattened_list:
      if tensor_util.is_tensor(input_tensor):
        rank = input_tensor.get_shape().rank
      else:
        # `np.rank` was removed from NumPy; `np.ndim` is the supported
        # equivalent and also accepts Python scalars and sequences.
        rank = np.ndim(input_tensor)
      # Only the rank is pinned; every dimension is left unknown.
      maximum_shape = tensor_shape.TensorShape([None] * rank)
      maximum_shapes.append(maximum_shape)
    maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                           maximum_shapes)
  else:
    maximum_shapes = None

  with strategy.scope():
    replicate_outputs = tpu.replicate(
        replicated_fn,
        replicate_inputs,
        device_assignment=self._device_assignment,
        maximum_shapes=maximum_shapes)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  if isinstance(result[0], list):
    result[0] = [
        output for output in result[0]
        if not isinstance(output, ops.Operation)
    ]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  if result[0] is None or isinstance(result[0], ops.Operation):
    replicate_outputs = [None] * len(replicate_outputs)
  else:
    replicate_outputs = [
        nest.pack_sequence_as(result[0], nest.flatten(replica_output))
        for replica_output in replicate_outputs
    ]
  return values.regroup(replicate_outputs)
def _tpu_run(strategy, fn, args, kwargs):
  """Common implementation of TPUStrategy.experimental_run_v2."""
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  kwargs = {} if kwargs is None else kwargs

  # One-slot holder for the structured value returned by `fn`; it is used
  # below as the template for re-packing the flat `tpu.replicate()` outputs.
  fn_output = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Runs `fn` inside a replica context and records what it returned."""
    with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
      fn_output[0] = fn(*replica_args, **replica_kwargs)
    return fn_output[0]

  # One `[replica_id, args, kwargs]` entry per replica.
  replicate_inputs = [
      [
          constant_op.constant(replica, dtype=dtypes.int32),
          values.select_replica(replica, args),
          values.select_replica(replica, kwargs),
      ]
      for replica in range(strategy.num_replicas_in_sync)
  ]

  # Construct and pass `maximum_shapes` so that we could support dynamic
  # shapes using dynamic padder.
  maximum_shapes = None
  if replicate_inputs:
    first_replica_inputs = replicate_inputs[0]
    flat_shapes = []
    for flat_input in nest.flatten(first_replica_inputs):
      if tensor_util.is_tensor(flat_input):
        flat_shapes.append(flat_input.get_shape())
      else:
        # Non-Tensor inputs get their shape from numpy.
        flat_shapes.append(tensor_shape.TensorShape(np.shape(flat_input)))
    maximum_shapes = nest.pack_sequence_as(first_replica_inputs, flat_shapes)

  with strategy.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs,
                                      maximum_shapes=maximum_shapes)

  # Drop any non-tensor entries (no ops possibly added by `tpu.replicate()`)
  # from a list-valued result before using it as the packing template.
  structure = fn_output[0]
  if isinstance(structure, list):
    structure = [item for item in structure if tensor_util.is_tensor(item)]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  per_replica_results = []
  for replica_output in replicate_outputs:
    per_replica_results.append(
        nest.pack_sequence_as(structure, nest.flatten(replica_output)))

  device_map = strategy.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, per_replica_results)
def test_get_laidout_tensors(self, is_eval_mode):
  """Checks the per-core data fed by `SimdMeshImplInputReader`.

  A 4x4 identity matrix is split into two 2x4 sub-batches by a stateful
  dataset creator. In eval mode both cores are expected to see the first
  sub-batch; otherwise core 1 is expected to see the second one.
  """
  mesh_shape_spec = "mesh_x:2, mesh_y:1"
  layout_spec = "batch:mesh_x, io:mesh_y"
  batch_io_dim = 4

  with tf.Session() as sess:
    topology, num_cores = self.initialize_system(sess)

    # Get a device_assignment object for mtf.
    computation_shape = [1] * mtf.utils.topology_rank(topology)
    d_assignment = device_assignment.device_assignment(
        topology,
        computation_shape=computation_shape,
        num_replicas=num_cores)

    # Hacked dataset creator: creates different datasets for the first and
    # second call, in order to test SimdMeshImplInputReader.
    self.sub_batch_created_times = 0

    def stateful_ds_creator():
      whole_batch = tf.eye(batch_io_dim, dtype=tf.float32)
      row_offset = self.sub_batch_created_times * 2
      sub_batch = tf.slice(whole_batch, [row_offset, 0], [2, 4])
      self.sub_batch_created_times += 1
      dataset = tf.data.Dataset.from_tensors(sub_batch)
      return dataset.repeat().unbatch()

    mtf_input_shapes = [
        mtf.Shape([
            mtf.Dimension("batch", batch_io_dim),
            mtf.Dimension("io", batch_io_dim),
        ])
    ]

    # Get mesh_impl.
    mesh_shape = mtf.convert_to_shape(mesh_shape_spec)
    layout_rules = mtf.convert_to_layout_rules(layout_spec)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, None, d_assignment)

    simd_input_reader = input_reader.SimdMeshImplInputReader(
        mesh_impl,
        stateful_ds_creator,
        mtf_input_shapes,
        external_worker=False,
        is_eval_mode=is_eval_mode)

    def model_fn(features):
      return features

    replicated_computation = tpu.replicate(
        computation=model_fn,
        inputs=[[]] * num_cores,
        infeed_queue=simd_input_reader.infeed_queue,
        device_assignment=d_assignment)

    simd_input_reader.start_infeed_thread(sess, 1)
    results = sess.run(replicated_computation)
    print("results: {}".format(results))

    core_0_data = results[0][0]
    core_1_data = results[1][0]
    print("core_0_data: {}".format(core_0_data))
    print("core_1_data: {}".format(core_1_data))

    first_sub_batch = np.array(
        [[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32)
    second_sub_batch = np.array(
        [[0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32)

    # If there is only one dataset object (eval mode), then the
    # stateful_ds_creator() should be called only once, so both cores see the
    # first sub-batch. If there are two dataset objects, it should be called
    # twice and core 1 sees the second sub-batch.
    expected_core_1 = first_sub_batch if is_eval_mode else second_sub_batch

    self.assertAllClose(first_sub_batch, core_0_data)
    self.assertAllClose(expected_core_1, core_1_data)

    sess.run(tf.tpu.shutdown_system())
def tpu_function(args, kwargs):
  """TF Function used to replicate the user computation."""
  kwargs = {} if kwargs is None else kwargs

  # One-slot holder for the structured value returned by `fn`; it is used
  # below as the template for re-packing the flat `tpu.replicate()` outputs.
  fn_output = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Runs `fn` inside a replica context and records what it returned."""
    with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
      fn_output[0] = fn(*replica_args, **replica_kwargs)
    return fn_output[0]

  # One `[replica_id, args, kwargs]` entry per replica.
  replicate_inputs = [
      [
          constant_op.constant(replica, dtype=dtypes.int32),
          values.select_replica(replica, args),
          values.select_replica(replica, kwargs),
      ]
      for replica in range(strategy.num_replicas_in_sync)
  ]

  # Construct and pass `maximum_shapes` so that we could support dynamic
  # shapes using dynamic padder.
  maximum_shapes = None
  if replicate_inputs:
    first_replica_inputs = replicate_inputs[0]
    flat_shapes = [
        flat_input.get_shape() if tensor_util.is_tensor(flat_input)
        else tensor_shape.TensorShape(np.shape(flat_input))
        for flat_input in nest.flatten(first_replica_inputs)
    ]
    maximum_shapes = nest.pack_sequence_as(first_replica_inputs, flat_shapes)

  with strategy.scope():
    replicate_outputs = tpu.replicate(
        replicated_fn, replicate_inputs, maximum_shapes=maximum_shapes)

  # Drop any non-tensor entries (no ops possibly added by `tpu.replicate()`)
  # from a list-valued result before using it as the packing template.
  structure = fn_output[0]
  if isinstance(structure, list):
    structure = [item for item in structure if tensor_util.is_tensor(item)]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  if structure is not None:
    per_replica_results = [
        nest.pack_sequence_as(structure, nest.flatten(replica_output))
        for replica_output in replicate_outputs
    ]
  else:
    # `fn` returned nothing: report `None` for every replica.
    per_replica_results = [None] * len(replicate_outputs)

  device_map = self._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, per_replica_results)
def replicated_f():
  """Calls `tpu.replicate` on `f` with one inner input list of one constant."""
  replica_inputs = [constant_op.constant([1., 2., 3., 4.])]
  return tpu.replicate(f, inputs=[replica_inputs])