def rewrite_fn(*args):
  """Runs one replicated step of `run_fn` across every TPU replica.

  Positional arguments (loop-carried values supplied by the enclosing loop)
  are ignored; each replica is invoked with an empty input list.

  Returns:
    A flat list of the per-replica outputs when `run_fn` produced tensors
    (i.e. `tpu.replicate` returned a list per replica); otherwise whatever
    `tpu.replicate` returned, unchanged (per the original note, a list of
    no_ops when there are no tensor outputs).
  """
  del args  # Loop-carried values are not consumed by the step.
  empty_inputs = [[]] * self._num_replicas_in_sync
  outputs = tpu.replicate(run_fn, empty_inputs)
  # One list per replica means tensor outputs: flatten into a single list.
  needs_flatten = isinstance(outputs[0], list)
  return nest.flatten(outputs) if needs_flatten else outputs
def experimental_run(self, fn, input_iterator=None):
  """See base class.

  Replicates `fn` over every TPU replica with `tpu.replicate` and regroups
  the per-replica results using the strategy's device map.

  Args:
    fn: Step function. Called with no arguments when `input_iterator` is
      None, otherwise with this replica's slice (via `values.select_replica`)
      of `input_iterator.get_next()`.
    input_iterator: Optional iterator supplying the inputs.

  Returns:
    The result of `values.regroup` over the re-packed per-replica outputs.

  Raises:
    NotImplementedError: When executing eagerly, or when the extended
      strategy has `_disable_training_loop_on_host` set.
  """
  if context.executing_eagerly():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy.")
  # pylint: disable=protected-access
  if self.extended._disable_training_loop_on_host:
    raise NotImplementedError(
        "`experimental_run` is not compatible with "
        "`_disable_training_loop_on_host=True`")
  # pylint: enable=protected-access

  has_inputs = input_iterator is not None
  inputs = input_iterator.get_next() if has_inputs else []

  # Holds the structure returned by the latest call to `fn`; it is used
  # below to re-pack the flattened outputs of `tpu.replicate`.
  result = [None]

  def replicated_fn(replica_id, replica_inputs):
    """Calls `fn` under the replica context for `replica_id`."""
    with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
      result[0] = fn(replica_inputs) if has_inputs else fn()
    return result[0]

  # One [replica_id, replica_inputs] argument list per replica.
  replicate_inputs = [
      [constant_op.constant(replica_idx, dtype=dtypes.int32),
       values.select_replica(replica_idx, inputs)]
      for replica_idx in range(self.num_replicas_in_sync)
  ]

  with self.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  repacked = []
  for per_replica in replicate_outputs:
    repacked.append(
        nest.pack_sequence_as(result[0], nest.flatten(per_replica)))

  return values.regroup(self.extended._device_map, repacked)  # pylint: disable=protected-access
def experimental_run(self, fn, input_iterator=None):
  """See base class.

  Calls `fn` once per TPU replica inside a single `tpu.replicate`
  computation and regroups the per-replica results.

  Args:
    fn: Step function. Invoked as `fn()` when `input_iterator` is None and as
      `fn(replica_inputs)` otherwise, where `replica_inputs` is this replica's
      slice (via `values.select_replica`) of `input_iterator.get_next()`.
    input_iterator: Optional input iterator.

  Returns:
    The output of `values.regroup` applied to the re-packed per-replica
    results.

  Raises:
    NotImplementedError: If running eagerly, or if the extended strategy has
      `_disable_training_loop_on_host` set.
  """
  if context.executing_eagerly():
    raise NotImplementedError("Eager mode not supported in TPUStrategy.")
  extended = self.extended
  if extended._disable_training_loop_on_host:  # pylint: disable=protected-access
    raise NotImplementedError(
        "`experimental_run` is not compatible with "
        "`_disable_training_loop_on_host=True`")

  if input_iterator is None:
    inputs = []
    invoke = lambda unused_replica_inputs: fn()
  else:
    inputs = input_iterator.get_next()
    invoke = fn

  # The most recent structure returned by `fn`, captured while tracing so
  # that the flat `tpu.replicate` outputs can be packed back into it.
  last_structure = [None]

  def replicated_fn(replica_id, replica_inputs):
    """Runs `fn` inside the replica context of `replica_id`."""
    with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
      last_structure[0] = invoke(replica_inputs)
    return last_structure[0]

  replicate_inputs = []  # By replica.
  for replica_idx in range(self.num_replicas_in_sync):
    replica_id_tensor = constant_op.constant(replica_idx, dtype=dtypes.int32)
    replicate_inputs.append(
        [replica_id_tensor, values.select_replica(replica_idx, inputs)])

  with self.scope():
    flat_outputs = tpu.replicate(replicated_fn, replicate_inputs)

  def _repack(per_replica_outputs):
    # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
    return nest.pack_sequence_as(last_structure[0],
                                 nest.flatten(per_replica_outputs))

  repacked_outputs = [_repack(outputs) for outputs in flat_outputs]
  return values.regroup(extended._device_map, repacked_outputs)  # pylint: disable=protected-access
def rewrite_fn(*args):
  """Runs one replicated step of `run_fn`, feeding each replica its inputs.

  Positional arguments (loop-carried values) are ignored. A single element is
  pulled from `multi_worker_iterator`, and each replica receives a 1-tuple
  holding its slice of that element (via `values.select_replica`).

  Returns:
    A flat list of per-replica outputs when `run_fn` produced tensors,
    otherwise the unmodified result of `tpu.replicate` (per the original
    note, a list of no_ops).
  """
  del args  # Loop-carried values are not consumed by the step.
  next_element = multi_worker_iterator.get_next()

  def _inputs_for(replica_idx):
    """Returns the 1-tuple of inputs `run_fn` receives on `replica_idx`."""
    return (nest.map_structure(
        lambda value: values.select_replica(replica_idx, value),
        next_element),)

  replicate_inputs = [
      _inputs_for(replica_idx)
      for replica_idx in range(self._num_replicas_in_sync)
  ]
  outputs = tpu.replicate(run_fn, replicate_inputs)
  # A list per replica means tensor outputs; flatten them into one list.
  if not isinstance(outputs[0], list):
    return outputs
  return nest.flatten(outputs)
def _run_steps_on_dataset(self, fn, iterator, iterations,
                          initial_loop_values=None):
  """Builds a graph that runs `fn` for `iterations` steps on every TPU core.

  Inputs are enqueued per host (via `_get_enqueue_op_per_host`) and dequeued
  on device with `infeed_dequeue_tuple`. The step loop is built with
  `training_loop.repeat` inside `tpu.replicate`.

  Args:
    fn: Step function called as `fn(ctx, *inputs)`. `ctx` is the returned
      `MultiStepContext`. `inputs` is the dequeued element, wrapped in a
      1-tuple when it is not already a tuple.
    iterator: Input iterator whose `output_shapes`/`output_types` describe
      the infeed.
    iterations: Number of steps to run.
    initial_loop_values: Optional nested structure of initial loop-carried
      values. It is flattened before use; None means no loop values.

  Returns:
    The `MultiStepContext`, with `run_op` set and its last-step outputs
    populated.

  Raises:
    ValueError: If any input shape is not fully defined.
  """
  shapes = nest.flatten(iterator.output_shapes)
  if any(not s.is_fully_defined() for s in shapes):
    raise ValueError(
        'TPU currently requires fully defined shapes. Either use '
        'set_shape() on the input tensors or use '
        'dataset.batch(..., drop_remainder=True).')
  types = nest.flatten(iterator.output_types)

  enqueue_ops = [
      self._get_enqueue_op_per_host(host_id, iterator, shapes, iterations)
      for host_id in range(self.num_hosts)
  ]

  def dequeue_fn():
    dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
    return nest.pack_sequence_as(iterator.output_shapes, dequeued)

  # Wrap `fn` for repeat.
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)
  ctx = values.MultiStepContext()

  def run_fn(*args, **kwargs):
    """Single step on the TPU device."""
    del args, kwargs
    fn_inputs = dequeue_fn()
    if not isinstance(fn_inputs, tuple):
      fn_inputs = (fn_inputs,)
    fn_result = fn(ctx, *fn_inputs)
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    if not flat_last_step_outputs:
      return fn_result
    # The returned identities carry a control dependency on `fn_result`.
    with ops.control_dependencies([fn_result]):
      return [array_ops.identity(f) for f in flat_last_step_outputs]

  # TODO(sourabhbajaj): The input to while loop should be based on the output
  # type of the step_fn
  def iterate_on_tpu():
    return training_loop.repeat(iterations, run_fn, initial_loop_values)

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop and TPU replicate context. This is useful in cases
  # where we might need to exit these contexts and get back to the outer
  # context to do some things, for e.g. create an op which should be
  # evaluated only once at the end of the loop on the host. One such usage
  # is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

  # A distinct (empty) input list per replica rather than N aliases of one.
  replicate_inputs = [[] for _ in range(self.num_towers)]
  try:
    replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
  finally:
    # Drop the captured context even if building the computation fails, so a
    # stale context is not left behind on `self`.
    del self._outer_control_flow_context
  ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

  # Filter out any ops from the outputs, typically this would be the case
  # when there were no tensor outputs.
  last_step_tensor_outputs = [
      x for x in replicate_outputs if not isinstance(x, ops.Operation)
  ]

  # Outputs are currently of the structure (grouped by device)
  # [[output0_device0, output1_device0, output2_device0],
  #  [output0_device1, output1_device1, output2_device1]]
  # Convert this to the following structure instead: (grouped by output)
  # [[output0_device0, output0_device1],
  #  [output1_device0, output1_device1],
  #  [output2_device0, output2_device1]]
  last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)]

  # Convert replicate_outputs to the original dict structure of
  # last_step_outputs.
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for name, aggregation in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that have already been aggregated, take the first value
    # from the list as each value should be the same. Else return the full
    # list of values.
    if aggregation is not variables_lib.VariableAggregation.NONE:
      # TODO(priyag): Should this return the element or a list with 1 element
      last_step_tensor_outputs_dict[name] = output[0]
  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

  return ctx
def _run_steps_on_dataset(self, fn, iterator, iterations,
                          initial_loop_values=None):
  """Builds a graph that runs `fn` for `iterations` steps on every TPU core.

  A host-side `while_loop` enqueues one element per core per iteration into
  the TPU infeed. On device, each step dequeues one element with
  `infeed_dequeue_tuple` and calls `fn`. The device loop is built with
  `training_loop.repeat` inside `tpu.replicate`.

  Args:
    fn: Step function called as `fn(ctx, *inputs)`. `ctx` is the returned
      `MultiStepContext`. `inputs` is the dequeued element, wrapped in a
      1-tuple when it is not already a tuple.
    iterator: Input iterator whose `output_shapes`/`output_types` describe
      the infeed.
    iterations: Number of steps to run.
    initial_loop_values: Optional nested structure of initial loop-carried
      values. It is flattened before use; None means no loop values.

  Returns:
    The `MultiStepContext`, with `run_op` set and its last-step outputs
    populated.

  Raises:
    ValueError: If any input shape is not fully defined.
  """
  shapes = nest.flatten(iterator.output_shapes)
  if any(not s.is_fully_defined() for s in shapes):
    raise ValueError(
        'TPU currently requires fully defined shapes. Either use '
        'set_shape() on the input tensors or use '
        'dataset.apply(map_and_batch(..., drop_remainder=True)).')
  types = nest.flatten(iterator.output_types)

  def enqueue_ops_fn():
    """Enqueue ops for one iteration."""
    control_deps = []
    sharded_inputs = []
    # TODO(sourabhbajaj): Add support for TPU pods
    with ops.device(self._host):
      for _ in range(self.num_towers):
        # Use control dependencies to ensure a deterministic ordering.
        with ops.control_dependencies(control_deps):
          inputs = nest.flatten(iterator.get_next())
          control_deps.extend(inputs)
          sharded_inputs.append(inputs)

    return [
        tpu_ops.infeed_enqueue_tuple(
            inputs=shard_input, shapes=shapes, device_ordinal=core_id)
        for core_id, shard_input in enumerate(sharded_inputs)
    ]

  def enqueue_ops_loop_body(i):
    with ops.control_dependencies(enqueue_ops_fn()):
      return i + 1

  with ops.device(self._host):
    enqueue_ops = control_flow_ops.while_loop(
        lambda i: i < iterations,
        enqueue_ops_loop_body,
        [constant_op.constant(0)],
        parallel_iterations=1)

  def dequeue_fn():
    dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
    return nest.pack_sequence_as(iterator.output_shapes, dequeued)

  # Wrap `fn` for repeat.
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)
  ctx = values.MultiStepContext()

  def run_fn(*args, **kwargs):
    del args, kwargs
    fn_inputs = dequeue_fn()
    if not isinstance(fn_inputs, tuple):
      fn_inputs = (fn_inputs,)
    fn_result = fn(ctx, *fn_inputs)
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    if not flat_last_step_outputs:
      return fn_result
    # The returned identities carry a control dependency on `fn_result`.
    with ops.control_dependencies([fn_result]):
      return [array_ops.identity(f) for f in flat_last_step_outputs]

  # TODO(sourabhbajaj): The input to while loop should be based on the output
  # type of the step_fn
  def iterate_on_tpu():
    return training_loop.repeat(iterations, run_fn, initial_loop_values)

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop and TPU replicate context. This is useful in cases
  # where we might need to exit these contexts and get back to the outer
  # context to do some things, for e.g. create an op which should be
  # evaluated only once at the end of the loop on the host. One such usage
  # is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

  # A distinct (empty) input list per replica rather than N aliases of one.
  replicate_inputs = [[] for _ in range(self.num_towers)]
  try:
    replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
  finally:
    # Drop the captured context even if building the computation fails, so a
    # stale context is not left behind on `self`.
    del self._outer_control_flow_context
  ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

  # Filter out any ops from the outputs, typically this would be the case
  # when there were no tensor outputs.
  last_step_tensor_outputs = [
      x for x in replicate_outputs if not isinstance(x, ops.Operation)
  ]

  # Outputs are currently of the structure (grouped by device)
  # [[output0_device0, output1_device0, output2_device0],
  #  [output0_device1, output1_device1, output2_device1]]
  # Convert this to the following structure instead: (grouped by output)
  # [[output0_device0, output0_device1],
  #  [output1_device0, output1_device1],
  #  [output2_device0, output2_device1]]
  last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)]

  # Convert replicate_outputs to the original dict structure of
  # last_step_outputs.
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for name, aggregation in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that have already been aggregated, take the first value
    # from the list as each value should be the same. Else return the full
    # list of values.
    if aggregation is not variables_lib.VariableAggregation.NONE:
      # TODO(priyag): Should this return the element or a list with 1 element
      last_step_tensor_outputs_dict[name] = output[0]
  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

  return ctx
def _experimental_run_steps_on_iterator(self, fn, multi_worker_iterator,
                                        iterations, initial_loop_values=None):
  """Builds a graph that runs `fn` for `iterations` steps on all replicas.

  Per-host enqueue ops feed the TPU infeed. On device, each step dequeues one
  element and calls `fn(ctx, element)`. The step loop is built in one of two
  ways, chosen by the container strategy's `_disable_training_loop_on_host`:

  * True: a `training_loop.repeat` inside `tpu.replicate` (loop on device).
  * False: a `training_loop.repeat` on host 0 whose body calls
    `tpu.replicate` once per step.

  Args:
    fn: Step function called as `fn(ctx, inputs)`, where `ctx` is the
      returned `MultiStepContext`.
    multi_worker_iterator: Input iterator whose `output_shapes` and
      `output_types` describe the infeed.
    iterations: Number of steps to run.
    initial_loop_values: Optional nested structure of initial loop-carried
      values. It is flattened before use; None means no loop values.

  Returns:
    The `MultiStepContext`, with `run_op` set and its last-step outputs
    populated.

  Raises:
    ValueError: If any input shape is not fully defined.
  """
  output_shapes = multi_worker_iterator.output_shapes
  shapes = nest.flatten(output_shapes)
  if any(not s.is_fully_defined() for s in shapes):
    raise ValueError(
        "TPU currently requires fully defined shapes. Either use "
        "set_shape() on the input tensors or use "
        "dataset.batch(..., drop_remainder=True).")
  types = nest.flatten(multi_worker_iterator.output_types)

  enqueue_ops = [
      self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes,
                                    iterations)
      for host_id in range(self.num_hosts)
  ]

  def dequeue_fn():
    dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
    return nest.pack_sequence_as(output_shapes, dequeued)

  # Wrap `fn` for repeat.
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)
  ctx = input_lib.MultiStepContext()

  def run_fn(*args, **kwargs):
    """Single step on the TPU device."""
    del args, kwargs
    fn_result = fn(ctx, dequeue_fn())
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    if not flat_last_step_outputs:
      return fn_result
    # The returned identities carry a control dependency on `fn_result`.
    with ops.control_dependencies([fn_result]):
      return [array_ops.identity(f) for f in flat_last_step_outputs]

  def iterate_on_tpu():
    return training_loop.repeat(iterations, run_fn, initial_loop_values)

  # Read the loop-placement flag once; it selects both how the loop is built
  # and how its outputs are laid out.
  # pylint: disable=protected-access
  loop_on_device = self._container_strategy()._disable_training_loop_on_host
  # pylint: enable=protected-access

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop and TPU replicate context. This is useful in cases
  # where we might need to exit these contexts and get back to the outer
  # context to do some things, for e.g. create an op which should be
  # evaluated only once at the end of the loop on the host. One such usage
  # is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

  try:
    if loop_on_device:
      replicate_inputs = [[] for _ in range(self._num_replicas_in_sync)]
      replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
    else:
      def rewrite_fn(*args):
        """The rewritten step fn running on TPU."""
        del args
        step_inputs = [[] for _ in range(self._num_replicas_in_sync)]
        step_outputs = tpu.replicate(run_fn, step_inputs)
        # If run_fn has tensor outputs, tpu.replicate returns a list of list.
        # We will flatten it in this case. If run_fn has no tensor outputs,
        # tpu.replicate returns a list of no_ops, we will keep the output as
        # it is.
        if isinstance(step_outputs[0], list):
          step_outputs = nest.flatten(step_outputs)
        return step_outputs

      # TODO(sourabhbajaj): The input to while loop should be based on the
      # output type of the step_fn
      assert isinstance(initial_loop_values, list)
      # The host loop carries one copy of the loop values per replica. A new
      # name is used so the variable captured by `iterate_on_tpu` is not
      # rebound.
      host_loop_values = initial_loop_values * self._num_replicas_in_sync

      # Put the while loop op on host 0.
      with ops.device(self.get_host_cpu_device(0)):
        replicate_outputs = training_loop.repeat(
            iterations, rewrite_fn, host_loop_values)
  finally:
    # Drop the captured context even if building the loop fails, so a stale
    # context is not left behind on `self`.
    del self._outer_control_flow_context

  ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

  if loop_on_device:
    # Filter out any ops from the outputs, typically this would be the case
    # when there were no tensor outputs.
    last_step_tensor_outputs = [
        x for x in replicate_outputs if not isinstance(x, ops.Operation)
    ]

    # Outputs are currently of the structure (grouped by device)
    # [[output0_device0, output1_device0, output2_device0],
    #  [output0_device1, output1_device1, output2_device1]]
    # Convert this to the following structure instead: (grouped by output)
    # [[output0_device0, output0_device1],
    #  [output1_device0, output1_device1],
    #  [output2_device0, output2_device1]]
    last_step_tensor_outputs = [
        list(x) for x in zip(*last_step_tensor_outputs)
    ]
  elif isinstance(replicate_outputs, list):
    # Filter out any ops from the outputs, typically this would be the case
    # when there were no tensor outputs.
    last_step_tensor_outputs = [
        x for x in replicate_outputs if not isinstance(x, ops.Operation)
    ]

    # Outputs are currently of the structure (flattened)
    # [output0_device0, output1_device0, output2_device0,
    #  output0_device1, output1_device1, output2_device1,
    #  ...]
    # Convert this to the following structure instead: (grouped by output)
    # [[output0_device0, output0_device1],
    #  [output1_device0, output1_device1],
    #  [output2_device0, output2_device1]]
    output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
    last_step_tensor_outputs = [
        last_step_tensor_outputs[i::output_num] for i in range(output_num)
    ]
  else:
    # no tensors returned.
    last_step_tensor_outputs = []

  # Convert replicate_outputs to the original dict structure of
  # last_step_outputs.
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that have already been reduced, take the first value
    # from the list as each value should be the same. Else return the full
    # list of values.
    # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
    # value.
    if reduce_op is not None:
      # TODO(priyag): Should this return the element or a list with 1 element
      last_step_tensor_outputs_dict[name] = output[0]
  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

  return ctx
def _run_steps_on_dataset(self, fn, iterator, iterations,
                          initial_loop_values=None):
  """Builds the multi-step TPU graph that consumes `iterator`.

  A host-side `while_loop` enqueues `iterations` rounds of inputs (one
  element per core) into the TPU infeed. On device, `training_loop.repeat`
  inside `tpu.replicate` dequeues one element per step and calls
  `fn(ctx, element)`.

  Args:
    fn: Step function called as `fn(ctx, inputs)`. It may set
      `ctx.last_step_outputs`.
    iterator: Input iterator whose `output_shapes`/`output_types` describe
      the infeed.
    iterations: Number of steps to run.
    initial_loop_values: Initial loop value; defaults to `[]`.

  Returns:
    A 3-tuple:
    * an op grouping the regrouped replicate outputs with the host enqueue
      loop;
    * element 0 of the output-major regrouped replicate outputs;
    * the `MultiStepContext`.

  Raises:
    ValueError: If any input shape is not fully defined.
  """
  shapes = nest.flatten(iterator.output_shapes)
  for shape in shapes:
    if not shape.is_fully_defined():
      raise ValueError(
          'TPU currently requires fully defined shapes. Either use '
          'set_shape() on the input tensors or use '
          'dataset.apply(map_and_batch(..., drop_remainder=True)).')
  types = nest.flatten(iterator.output_types)

  def enqueue_ops_fn():
    """Builds the infeed enqueue ops for a single iteration."""
    previous_inputs = []
    per_core_inputs = []
    with ops.device(self._host):
      for _ in range(self.num_towers):
        # Chain each get_next() on the earlier ones so that cores are fed in
        # a deterministic order.
        with ops.control_dependencies(previous_inputs):
          flat_inputs = nest.flatten(iterator.get_next())
          previous_inputs.extend(flat_inputs)
          per_core_inputs.append(flat_inputs)

    return [
        tpu_ops.infeed_enqueue_tuple(
            inputs=core_inputs, shapes=shapes, device_ordinal=ordinal)
        for ordinal, core_inputs in enumerate(per_core_inputs)
    ]

  def _not_done(counter):
    return counter < iterations

  def _enqueue_once(counter):
    with ops.control_dependencies(enqueue_ops_fn()):
      return counter + 1

  with ops.device(self._host):
    enqueue_ops = control_flow_ops.while_loop(
        _not_done,
        _enqueue_once,
        [constant_op.constant(0)],
        parallel_iterations=1)

  def dequeue_fn():
    """Dequeues one element from the infeed in the iterator's structure."""
    flat_element = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
    return nest.pack_sequence_as(iterator.output_shapes, flat_element)

  # Wrap `fn` for repeat.
  loop_start = [] if initial_loop_values is None else initial_loop_values
  ctx = values.MultiStepContext(loop_start)

  def run_fn(*args, **kwargs):
    """One device step: dequeue, call `fn`, emit the last-step outputs."""
    del args, kwargs
    step_result = fn(ctx, dequeue_fn())
    if ctx.last_step_outputs is None:
      ctx.last_step_outputs = []
    with ops.control_dependencies([step_result]):
      return array_ops.identity(ctx.last_step_outputs)

  # TODO(sourabhbajaj): The input to while loop should be based on the output
  # type of the step_fn
  def iterate_on_tpu():
    return training_loop.repeat(iterations, run_fn, [loop_start])

  per_replica_inputs = [[]] * self.num_towers
  replicated = tpu.replicate(iterate_on_tpu, per_replica_inputs)
  # Regroup from one list per device to one list per output.
  by_output = [list(group) for group in zip(*replicated)]

  # Take index [0] of by_output as we wrapped the initial loop values in a
  # list in the `repeat` call.
  run_op = control_flow_ops.group(by_output, enqueue_ops)
  return (run_op, by_output[0], ctx)
def _run_steps_on_iterator_with_device_loop(self, fn, multi_worker_iterator,
                                            iterations,
                                            initial_loop_values=None):
  """Builds a graph that runs `fn` for `iterations` steps in an on-device loop.

  Per-host enqueue ops feed the TPU infeed. `training_loop.repeat` is built
  inside `tpu.replicate`, so each replica loops on device, dequeuing one
  element per step and calling `fn(ctx, element)`.

  Args:
    fn: Step function called as `fn(ctx, inputs)`, where `ctx` is the
      returned `MultiStepContext`.
    multi_worker_iterator: Input iterator whose `output_shapes` and
      `output_types` describe the infeed.
    iterations: Number of steps to run.
    initial_loop_values: Optional nested structure of initial loop-carried
      values. It is flattened before use; None means no loop values.

  Returns:
    The `MultiStepContext`, with `run_op` set and last-step outputs recorded
    via `_set_last_step_outputs`.

  Raises:
    ValueError: If any input shape is not fully defined.
  """
  output_shapes = multi_worker_iterator.output_shapes
  shapes = nest.flatten(output_shapes)
  if any(not s.is_fully_defined() for s in shapes):
    raise ValueError(
        "TPU currently requires fully defined shapes. Either use "
        "set_shape() on the input tensors or use "
        "dataset.batch(..., drop_remainder=True).")
  types = nest.flatten(multi_worker_iterator.output_types)

  enqueue_ops = [
      self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes,
                                    iterations)
      for host_id in range(self.num_hosts)
  ]

  def dequeue_fn():
    dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
    return nest.pack_sequence_as(output_shapes, dequeued)

  # Wrap `fn` for repeat.
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)
  ctx = input_lib.MultiStepContext()

  def run_fn(*args, **kwargs):
    """Single step on the TPU device."""
    del args, kwargs
    fn_result = fn(ctx, dequeue_fn())
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    if not flat_last_step_outputs:
      return fn_result
    # The returned identities carry a control dependency on `fn_result`.
    with ops.control_dependencies([fn_result]):
      return [array_ops.identity(f) for f in flat_last_step_outputs]

  def iterate_on_tpu():
    return training_loop.repeat(iterations, run_fn, initial_loop_values)

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop and TPU replicate context. This is useful in cases
  # where we might need to exit these contexts and get back to the outer
  # context to do some things, for e.g. create an op which should be
  # evaluated only once at the end of the loop on the host. One such usage
  # is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

  # A distinct (empty) input list per replica rather than N aliases of one.
  replicate_inputs = [[] for _ in range(self._num_replicas_in_sync)]
  try:
    replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
  finally:
    # Drop the captured context even if building the computation fails, so a
    # stale context is not left behind on `self`.
    del self._outer_control_flow_context
  ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

  # Filter out any ops from the outputs, typically this would be the case
  # when there were no tensor outputs.
  last_step_tensor_outputs = [
      x for x in replicate_outputs if not isinstance(x, ops.Operation)
  ]

  # Outputs are currently of the structure (grouped by device)
  # [[output0_device0, output1_device0, output2_device0],
  #  [output0_device1, output1_device1, output2_device1]]
  # Convert this to the following structure instead: (grouped by output)
  # [[output0_device0, output0_device1],
  #  [output1_device0, output1_device1],
  #  [output2_device0, output2_device1]]
  last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)]
  _set_last_step_outputs(ctx, last_step_tensor_outputs)
  return ctx
def _run_steps_on_dataset(self, fn, iterator, iterations,
                          initial_loop_values=None):
  """Builds a graph that runs `fn` for `iterations` steps on every TPU core.

  A host-side `while_loop` enqueues one element per core per iteration into
  the TPU infeed. On device, `training_loop.repeat` inside `tpu.replicate`
  dequeues one element per step and calls `fn(ctx, element)`.

  Args:
    fn: Step function called as `fn(ctx, inputs)`, where `ctx` is the
      returned `MultiStepContext`.
    iterator: Input iterator whose `output_shapes`/`output_types` describe
      the infeed.
    iterations: Number of steps to run.
    initial_loop_values: Optional nested structure of initial loop-carried
      values. It is flattened before use; None means no loop values.

  Returns:
    The `MultiStepContext`, with `run_op` set and its last-step outputs
    populated.

  Raises:
    ValueError: If any input shape is not fully defined.
  """
  shapes = nest.flatten(iterator.output_shapes)
  for shape in shapes:
    if not shape.is_fully_defined():
      raise ValueError(
          'TPU currently requires fully defined shapes. Either use '
          'set_shape() on the input tensors or use '
          'dataset.apply(map_and_batch(..., drop_remainder=True)).')
  types = nest.flatten(iterator.output_types)

  def enqueue_ops_fn():
    """Builds the infeed enqueue ops for a single iteration."""
    previous_inputs = []
    per_core_inputs = []
    with ops.device(self._host):
      for _ in range(self.num_towers):
        # Chain each get_next() on the earlier ones so that cores are fed in
        # a deterministic order.
        with ops.control_dependencies(previous_inputs):
          flat_inputs = nest.flatten(iterator.get_next())
          previous_inputs.extend(flat_inputs)
          per_core_inputs.append(flat_inputs)

    return [
        tpu_ops.infeed_enqueue_tuple(
            inputs=core_inputs, shapes=shapes, device_ordinal=ordinal)
        for ordinal, core_inputs in enumerate(per_core_inputs)
    ]

  def _not_done(counter):
    return counter < iterations

  def _enqueue_once(counter):
    with ops.control_dependencies(enqueue_ops_fn()):
      return counter + 1

  with ops.device(self._host):
    enqueue_ops = control_flow_ops.while_loop(
        _not_done,
        _enqueue_once,
        [constant_op.constant(0)],
        parallel_iterations=1)

  def dequeue_fn():
    """Dequeues one element from the infeed in the iterator's structure."""
    flat_element = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
    return nest.pack_sequence_as(iterator.output_shapes, flat_element)

  # Wrap `fn` for repeat.
  loop_start = nest.flatten(
      {} if initial_loop_values is None else initial_loop_values)
  ctx = values.MultiStepContext()

  def run_fn(*args, **kwargs):
    """One device step: dequeue, call `fn`, emit the last-step outputs."""
    del args, kwargs
    step_result = fn(ctx, dequeue_fn())
    flat_outputs = nest.flatten(ctx.last_step_outputs)
    if not flat_outputs:
      return step_result
    with ops.control_dependencies([step_result]):
      return [array_ops.identity(t) for t in flat_outputs]

  # TODO(sourabhbajaj): The input to while loop should be based on the output
  # type of the step_fn
  def iterate_on_tpu():
    return training_loop.repeat(iterations, run_fn, loop_start)

  per_replica_inputs = [[]] * self.num_towers
  replicated = tpu.replicate(iterate_on_tpu, per_replica_inputs)
  ctx.run_op = control_flow_ops.group(replicated, enqueue_ops)

  # Drop Operations (present when the step produced no tensors), then
  # regroup from one list per device to one list per output.
  tensor_outputs = [
      out for out in replicated if not isinstance(out, ops.Operation)]
  by_output = [list(group) for group in zip(*tensor_outputs)]

  # Re-pack into the structure of `ctx.last_step_outputs`.
  outputs_by_name = nest.pack_sequence_as(ctx.last_step_outputs, by_output)
  aggregations = ctx._last_step_outputs_aggregations  # pylint: disable=protected-access
  for name, aggregation in aggregations.items():
    per_replica = outputs_by_name[name]
    # Aggregated outputs hold the same value on every replica, so keep just
    # the first; non-aggregated outputs keep the full per-replica list.
    if aggregation is not variables_lib.VariableAggregation.NONE:
      # TODO(priyag): Should this return the element or a list with 1 element
      outputs_by_name[name] = per_replica[0]
  ctx._set_last_step_outputs(outputs_by_name)  # pylint: disable=protected-access

  return ctx
def _experimental_run_steps_on_iterator(
    self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
  """Builds a graph that runs `fn` for `iterations` steps on all replicas.

  Per-host enqueue ops feed the TPU infeed. `training_loop.repeat` is built
  inside `tpu.replicate`, and each step dequeues one element and calls `fn`.

  Args:
    fn: Step function called as `fn(ctx, inputs)`. `ctx` is the returned
      `MultiStepContext`. `inputs` is the dequeued element, wrapped in a
      1-tuple when it is not already a tuple.
    multi_worker_iterator: Input iterator whose `output_shapes` and
      `output_types` describe the infeed.
    iterations: Number of steps to run.
    initial_loop_values: Optional nested structure of initial loop-carried
      values. It is flattened before use; None means no loop values.

  Returns:
    The `MultiStepContext`, with `run_op` set and its last-step outputs
    populated.

  Raises:
    ValueError: If any input shape is not fully defined.
  """
  output_shapes = multi_worker_iterator.output_shapes
  shapes = nest.flatten(output_shapes)
  if any(not s.is_fully_defined() for s in shapes):
    raise ValueError(
        "TPU currently requires fully defined shapes. Either use "
        "set_shape() on the input tensors or use "
        "dataset.batch(..., drop_remainder=True).")
  types = nest.flatten(multi_worker_iterator.output_types)

  enqueue_ops = [
      self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes,
                                    iterations)
      for host_id in range(self.num_hosts)
  ]

  def dequeue_fn():
    dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
    return nest.pack_sequence_as(output_shapes, dequeued)

  # Wrap `fn` for repeat.
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)
  ctx = values.MultiStepContext()

  def run_fn(*args, **kwargs):
    """Single step on the TPU device."""
    del args, kwargs
    fn_inputs = dequeue_fn()
    if not isinstance(fn_inputs, tuple):
      fn_inputs = (fn_inputs,)
    # NOTE(review): the tuple is passed to `fn` as ONE argument (not
    # unpacked), so `fn` always receives a tuple. Confirm this is what the
    # step functions expect.
    fn_result = fn(ctx, fn_inputs)
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    if not flat_last_step_outputs:
      return fn_result
    # The returned identities carry a control dependency on `fn_result`.
    with ops.control_dependencies([fn_result]):
      return [array_ops.identity(f) for f in flat_last_step_outputs]

  # TODO(sourabhbajaj): The input to while loop should be based on the output
  # type of the step_fn
  def iterate_on_tpu():
    return training_loop.repeat(iterations, run_fn, initial_loop_values)

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop and TPU replicate context. This is useful in cases
  # where we might need to exit these contexts and get back to the outer
  # context to do some things, for e.g. create an op which should be
  # evaluated only once at the end of the loop on the host. One such usage
  # is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

  # A distinct (empty) input list per replica rather than N aliases of one.
  replicate_inputs = [[] for _ in range(self._num_replicas_in_sync)]
  try:
    replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
  finally:
    # Drop the captured context even if building the computation fails, so a
    # stale context is not left behind on `self`.
    del self._outer_control_flow_context
  ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

  # Filter out any ops from the outputs, typically this would be the case
  # when there were no tensor outputs.
  last_step_tensor_outputs = [
      x for x in replicate_outputs if not isinstance(x, ops.Operation)
  ]

  # Outputs are currently of the structure (grouped by device)
  # [[output0_device0, output1_device0, output2_device0],
  #  [output0_device1, output1_device1, output2_device1]]
  # Convert this to the following structure instead: (grouped by output)
  # [[output0_device0, output0_device1],
  #  [output1_device0, output1_device1],
  #  [output2_device0, output2_device1]]
  last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)]

  # Convert replicate_outputs to the original dict structure of
  # last_step_outputs.
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that have already been reduced, take the first value
    # from the list as each value should be the same. Else return the full
    # list of values.
    # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
    # value.
    if reduce_op is not None:
      # TODO(priyag): Should this return the element or a list with 1 element
      last_step_tensor_outputs_dict[name] = output[0]
  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

  return ctx
def _run_steps_on_iterator_with_device_loop(
    self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
  """Builds a graph that runs `fn` for `iterations` steps in an on-device loop.

  Per-host enqueue ops feed the TPU infeed. `training_loop.repeat` is built
  inside `tpu.replicate`, so each replica loops on device, dequeuing one
  element per step and calling `fn(ctx, element)`.

  Args:
    fn: Step function called as `fn(ctx, inputs)`, where `ctx` is the
      returned `MultiStepContext`.
    multi_worker_iterator: Input iterator whose `output_shapes` and
      `output_types` describe the infeed.
    iterations: Number of steps to run.
    initial_loop_values: Optional nested structure of initial loop-carried
      values. It is flattened before use; None means no loop values.

  Returns:
    The `MultiStepContext`, with `run_op` set and last-step outputs recorded
    via `_set_last_step_outputs`.

  Raises:
    ValueError: If any input shape is not fully defined.
  """
  output_shapes = multi_worker_iterator.output_shapes
  shapes = nest.flatten(output_shapes)
  if any(not s.is_fully_defined() for s in shapes):
    raise ValueError(
        "TPU currently requires fully defined shapes. Either use "
        "set_shape() on the input tensors or use "
        "dataset.batch(..., drop_remainder=True).")
  types = nest.flatten(multi_worker_iterator.output_types)

  enqueue_ops = [
      self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes,
                                    iterations)
      for host_id in range(self.num_hosts)
  ]

  def dequeue_fn():
    dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
    return nest.pack_sequence_as(output_shapes, dequeued)

  # Wrap `fn` for repeat.
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)
  ctx = input_lib.MultiStepContext()

  def run_fn(*args, **kwargs):
    """Single step on the TPU device."""
    del args, kwargs
    fn_result = fn(ctx, dequeue_fn())
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    if not flat_last_step_outputs:
      return fn_result
    # The returned identities carry a control dependency on `fn_result`.
    with ops.control_dependencies([fn_result]):
      return [array_ops.identity(f) for f in flat_last_step_outputs]

  def iterate_on_tpu():
    return training_loop.repeat(iterations, run_fn, initial_loop_values)

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop and TPU replicate context. This is useful in cases
  # where we might need to exit these contexts and get back to the outer
  # context to do some things, for e.g. create an op which should be
  # evaluated only once at the end of the loop on the host. One such usage
  # is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

  # A distinct (empty) input list per replica rather than N aliases of one.
  replicate_inputs = [[] for _ in range(self._num_replicas_in_sync)]
  try:
    replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
  finally:
    # Drop the captured context even if building the computation fails, so a
    # stale context is not left behind on `self`.
    del self._outer_control_flow_context
  ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

  # Filter out any ops from the outputs, typically this would be the case
  # when there were no tensor outputs.
  last_step_tensor_outputs = [
      x for x in replicate_outputs if not isinstance(x, ops.Operation)
  ]

  # Outputs are currently of the structure (grouped by device)
  # [[output0_device0, output1_device0, output2_device0],
  #  [output0_device1, output1_device1, output2_device1]]
  # Convert this to the following structure instead: (grouped by output)
  # [[output0_device0, output0_device1],
  #  [output1_device0, output1_device1],
  #  [output2_device0, output2_device1]]
  last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)]
  _set_last_step_outputs(ctx, last_step_tensor_outputs)
  return ctx