def _run_steps_on_dataset(self, fn, iterator, iterations,
                          initial_loop_values=None):
  """Builds a graph-mode loop that calls `fn` `iterations` times.

  Each iteration calls `fn(ctx, iterator.get_next())`. Whatever `fn` records in
  `ctx.last_step_outputs` is flattened and carried as loop variables, so the
  values from the final iteration are available once the loop finishes.

  Args:
    fn: step function invoked as `fn(ctx, next_element)` once per iteration.
    iterator: object whose `get_next()` supplies each step's input.
    iterations: number of loop iterations; compared against a counter that
      starts at 0.
    initial_loop_values: optional nested structure giving the initial values
      of the carried loop variables (after the counter). Defaults to `{}`.

  Returns:
    The `values.MultiStepContext` handed to `fn`, with `run_op` set to a group
    of the loop outputs and its last-step outputs repacked into the structure
    `ctx.last_step_outputs` had inside the loop.
  """
  loop_init = nest.flatten(
      {} if initial_loop_values is None else initial_loop_values)
  ctx = values.MultiStepContext()

  def body(step, *unused_carried):
    """While-loop body: run one step of `fn` and advance the counter."""
    # The incoming carried values are not read; this step's outputs replace
    # them.
    del unused_carried
    step_result = fn(ctx, iterator.get_next())
    carried_out = nest.flatten(ctx.last_step_outputs)
    # The counter increment is created under a control dependency on the step
    # result, so it only runs after `step_result`.
    with ops.control_dependencies([step_result]):
      return [step + 1] + carried_out

  def cond(step, *unused_carried):
    """Keep looping while the counter is below `iterations`."""
    return step < iterations

  counter = constant_op.constant(0)
  # TODO(priyag): Use max_iterations instead of an explicit counter.
  loop_result = control_flow_ops.while_loop(
      cond, body, [counter] + loop_init,
      name="", parallel_iterations=1, back_prop=False, swap_memory=False,
      return_same_structure=True)
  ctx.run_op = control_flow_ops.group(loop_result)

  # Everything after the counter is the flat list of last-step outputs; repack
  # it into the structure of `ctx.last_step_outputs`.
  packed = nest.pack_sequence_as(ctx.last_step_outputs, loop_result[1:])
  ctx._set_last_step_outputs(packed)  # pylint: disable=protected-access
  return ctx
def _run_steps_on_dataset(self, fn, iterator, iterations,
                          initial_loop_values=None):
  """Runs `fn` for `iterations` steps inside a graph-mode while loop.

  Each step pulls the next element from `iterator`. A tuple element is splatted
  into `fn(ctx, *element)`; any other element is passed as the single argument
  `fn(ctx, element)`. Values `fn` records in `ctx.last_step_outputs` are passed
  through `self.unwrap` and carried as loop variables, so their final values
  are available after the loop.

  Args:
    fn: step function invoked as `fn(ctx, *inputs)` once per iteration.
    iterator: object whose `get_next()` supplies each step's inputs.
    iterations: number of loop iterations; compared against a counter that
      starts at 0.
    initial_loop_values: optional nested structure giving the initial values
      of the carried loop variables (after the counter). Defaults to `{}`.

  Returns:
    The `values.MultiStepContext` handed to `fn`. Its `run_op` is set to a
    group of the loop outputs. Its last-step outputs are repacked to their
    original structure and regrouped according to their aggregation.
  """
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)

  ctx = values.MultiStepContext()

  def body(i, *args):
    """A wrapper around `fn` to create the while loop body."""
    del args  # Carried values are replaced by this step's outputs.
    fn_inputs = iterator.get_next()
    if not isinstance(fn_inputs, tuple):
      fn_inputs = (fn_inputs,)
    fn_result = fn(ctx, *fn_inputs)
    for (name, output) in ctx.last_step_outputs.items():
      # Convert all outputs to tensors, potentially from `DistributedValues`.
      ctx.last_step_outputs[name] = self.unwrap(output)
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    # The counter increment is created under a control dependency on the step
    # result, so it only runs after `fn_result`.
    with ops.control_dependencies([fn_result]):
      return [i + 1] + flat_last_step_outputs

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop. This is useful in cases where we might need to exit
  # these contexts and get back to the outer context to do some things, for
  # e.g. create an op which should be evaluated only once at the end of the
  # loop on the host. One such usage is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

  cond = lambda i, *args: i < iterations
  i = constant_op.constant(0)
  try:
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)
  finally:
    # The captured context is only meaningful while the loop is being built.
    # Drop it even if building the loop raises, so no stale attribute is left
    # on `self`.
    del self._outer_control_flow_context

  ctx.run_op = control_flow_ops.group(loop_result)

  # Convert the last_step_outputs from a list to the original dict structure
  # of last_step_outputs.
  last_step_tensor_outputs = loop_result[1:]
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    if aggregation is variables_lib.VariableAggregation.NONE:
      # Unaggregated outputs: pair the values positionally with
      # `self._devices` and regroup them into a `values.PerDevice`.
      last_step_tensor_outputs_dict[name] = values.regroup(
          {d: t for d, t in zip(self._devices, output)}, values.PerDevice)
    else:
      # Already-aggregated outputs are expected to hold exactly one value,
      # which is returned unwrapped.
      assert len(output) == 1
      last_step_tensor_outputs_dict[name] = output[0]

  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
  return ctx
def _run_steps_on_dataset(self, fn, iterator, iterations,
                          initial_loop_values=None):
  """Runs `fn` for `iterations` steps inside a graph-mode while loop.

  Each step calls `fn(ctx, iterator.get_next())`. Every value `fn` records in
  `ctx.last_step_outputs` is passed through `self.unwrap` and carried as a
  loop variable, so its final value is available after the loop.

  Args:
    fn: step function invoked as `fn(ctx, next_element)` once per iteration.
    iterator: object whose `get_next()` supplies each step's input.
    iterations: number of loop iterations; compared against a counter that
      starts at 0.
    initial_loop_values: optional nested structure giving the initial values
      of the carried loop variables (after the counter). Defaults to `{}`.

  Returns:
    The `values.MultiStepContext` handed to `fn`. Its `run_op` groups the loop
    outputs. Its last-step outputs are repacked to their original structure:
    aggregated entries are unwrapped to their single value, and unaggregated
    entries are regrouped into a `values.PerDevice`.
  """
  loop_init = nest.flatten(
      {} if initial_loop_values is None else initial_loop_values)
  ctx = values.MultiStepContext()

  def body(step, *unused_carried):
    """While-loop body: run one step, unwrap its outputs, bump the counter."""
    del unused_carried
    step_result = fn(ctx, iterator.get_next())
    # Convert all outputs to tensors, potentially from `DistributedValues`.
    for key, value in ctx.last_step_outputs.items():
      ctx.last_step_outputs[key] = self.unwrap(value)
    carried_out = nest.flatten(ctx.last_step_outputs)
    with ops.control_dependencies([step_result]):
      return [step + 1] + carried_out

  def cond(step, *unused_carried):
    """Keep looping while the counter is below `iterations`."""
    return step < iterations

  counter = constant_op.constant(0)
  loop_result = control_flow_ops.while_loop(
      cond, body, [counter] + loop_init,
      name="", parallel_iterations=1, back_prop=False, swap_memory=False,
      return_same_structure=True)
  ctx.run_op = control_flow_ops.group(loop_result)

  # Drop the counter and restore the structure of `ctx.last_step_outputs`.
  repacked = nest.pack_sequence_as(ctx.last_step_outputs, loop_result[1:])
  aggregations = ctx._last_step_outputs_aggregations  # pylint: disable=protected-access
  for name, aggregation in aggregations.items():
    per_device = repacked[name]
    if aggregation is not variables_lib.VariableAggregation.NONE:
      # Already-aggregated outputs are expected to hold exactly one value.
      assert len(per_device) == 1
      repacked[name] = per_device[0]
      continue
    # Unaggregated outputs are paired positionally with `self._devices`.
    repacked[name] = values.regroup(
        dict(zip(self._devices, per_device)), values.PerDevice)
  ctx._set_last_step_outputs(repacked)  # pylint: disable=protected-access
  return ctx
def _run_steps_on_dataset(self, fn, iterator, iterations,
                          initial_loop_values=None):
  """Runs `fn` for `iterations` steps inside a graph-mode while loop.

  Each step pulls the next element from `iterator`. A tuple element is splatted
  into `fn(ctx, *element)`; any other element is passed as the single argument
  `fn(ctx, element)`. Values `fn` records in `ctx.last_step_outputs` are
  flattened and carried as loop variables, so the final iteration's values are
  available after the loop.

  Args:
    fn: step function invoked as `fn(ctx, *inputs)` once per iteration.
    iterator: object whose `get_next()` supplies each step's inputs.
    iterations: number of loop iterations; compared against a counter that
      starts at 0.
    initial_loop_values: optional nested structure giving the initial values
      of the carried loop variables (after the counter). Defaults to `{}`.

  Returns:
    The `values.MultiStepContext` handed to `fn`, with `run_op` set to a group
    of the loop outputs and its last-step outputs repacked into the structure
    of `ctx.last_step_outputs`.
  """
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)

  ctx = values.MultiStepContext()

  def body(i, *args):
    """A wrapper around `fn` to create the while loop body."""
    del args  # Carried values are replaced by this step's outputs.
    fn_inputs = iterator.get_next()
    if not isinstance(fn_inputs, tuple):
      fn_inputs = (fn_inputs,)
    fn_result = fn(ctx, *fn_inputs)
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    # The counter increment is created under a control dependency on the step
    # result, so it only runs after `fn_result`.
    with ops.control_dependencies([fn_result]):
      return [i + 1] + flat_last_step_outputs

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop. This is useful in cases where we might need to exit
  # these contexts and get back to the outer context to do some things, for
  # e.g. create an op which should be evaluated only once at the end of the
  # loop on the host. One such usage is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

  # TODO(priyag): Use max_iterations instead of an explicit counter.
  cond = lambda i, *args: i < iterations
  i = constant_op.constant(0)
  try:
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)
  finally:
    # The captured context is only meaningful while the loop is being built.
    # Drop it even if building the loop raises, so no stale attribute is left
    # on `self`.
    del self._outer_control_flow_context

  ctx.run_op = control_flow_ops.group(loop_result)

  # Convert the last_step_outputs from a list to the original dict structure
  # of last_step_outputs.
  last_step_tensor_outputs = loop_result[1:]
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
  return ctx
def _run_steps_on_dataset(self, fn, iterator, iterations,
                          initial_loop_values=None):
  """Runs `fn` for `iterations` steps on every TPU replica, fed via infeed.

  Host-side enqueue ops are built with `self._get_enqueue_op_per_host`, one per
  host. On each replica, `training_loop.repeat` calls `fn` once per iteration
  with inputs dequeued from infeed. A tuple element is splatted into
  `fn(ctx, *element)`; any other element is passed as `fn(ctx, element)`.

  Args:
    fn: step function invoked as `fn(ctx, *inputs)` once per iteration.
    iterator: input iterator; its `output_shapes` must be fully defined.
    iterations: number of steps each replica runs.
    initial_loop_values: optional nested structure giving the initial carried
      loop values. Defaults to `{}`.

  Returns:
    The `values.MultiStepContext` handed to `fn`. `run_op` groups the replica
    outputs with the enqueue ops. Last-step outputs are regrouped per output:
    aggregated entries keep only the first replica's value, and unaggregated
    entries keep the full list of per-replica values.

  Raises:
    ValueError: if any of the iterator's output shapes is not fully defined.
  """
  shapes = nest.flatten(iterator.output_shapes)
  # Generator form: short-circuits without building a throwaway list.
  if any(not s.is_fully_defined() for s in shapes):
    raise ValueError(
        'TPU currently requires fully defined shapes. Either use '
        'set_shape() on the input tensors or use '
        'dataset.batch(..., drop_remainder=True).')
  types = nest.flatten(iterator.output_types)

  enqueue_ops = [
      self._get_enqueue_op_per_host(host_id, iterator, shapes, iterations)
      for host_id in range(self.num_hosts)
  ]

  def dequeue_fn():
    """Dequeues one step's inputs and repacks them like the iterator output."""
    dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
    return nest.pack_sequence_as(iterator.output_shapes, dequeued)

  # Wrap `fn` for repeat.
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)
  ctx = values.MultiStepContext()

  def run_fn(*args, **kwargs):
    """Single step on the TPU device."""
    del args, kwargs
    fn_inputs = dequeue_fn()
    if not isinstance(fn_inputs, tuple):
      fn_inputs = (fn_inputs,)
    fn_result = fn(ctx, *fn_inputs)
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    if not flat_last_step_outputs:
      # Nothing to carry; return the step result itself.
      return fn_result
    # Identity copies created under a control dependency on the step result,
    # so the carried outputs are only produced after `fn_result`.
    with ops.control_dependencies([fn_result]):
      return [array_ops.identity(f) for f in flat_last_step_outputs]

  # TODO(sourabhbajaj): The input to while loop should be based on the output
  # type of the step_fn
  def iterate_on_tpu():
    return training_loop.repeat(iterations, run_fn, initial_loop_values)

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop and TPU replicate context. This is useful in cases
  # where we might need to exit these contexts and get back to the outer
  # context to do some things, for e.g. create an op which should be
  # evaluated only once at the end of the loop on the host. One such usage
  # is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access
  # One distinct (empty) input list per replica, rather than N aliases of the
  # same list object.
  replicate_inputs = [[] for _ in range(self.num_towers)]
  try:
    replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
  finally:
    # Drop the captured context even if replication raises, so no stale
    # attribute is left on `self`.
    del self._outer_control_flow_context

  ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

  # Filter out any ops from the outputs, typically this would be the case
  # when there were no tensor outputs.
  last_step_tensor_outputs = [
      x for x in replicate_outputs if not isinstance(x, ops.Operation)
  ]

  # Outputs are currently of the structure (grouped by device)
  # [[output0_device0, output1_device0, output2_device0],
  #  [output0_device1, output1_device1, output2_device1]]
  # Convert this to the following structure instead: (grouped by output)
  # [[output0_device0, output0_device1],
  #  [output1_device0, output1_device1],
  #  [output2_device0, output2_device1]]
  last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)]

  # Convert replicate_outputs to the original dict structure of
  # last_step_outputs.
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that have already been aggregated, take the first value
    # from the list as each value should be the same. Else return the full
    # list of values.
    if aggregation is not variables_lib.VariableAggregation.NONE:
      # TODO(priyag): Should this return the element or a list with 1 element
      last_step_tensor_outputs_dict[name] = output[0]
  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

  return ctx
def _run_steps_on_dataset(self, fn, iterator, iterations,
                          initial_loop_values=None):
  """Runs `fn` for `iterations` steps on every TPU replica, fed via infeed.

  A host-side while loop enqueues `iterations` rounds of inputs, one shard per
  replica per round, into infeed. Each replica runs `training_loop.repeat`,
  which calls `fn(ctx, dequeue_fn())` once per iteration.

  Args:
    fn: step function invoked as `fn(ctx, inputs)` once per iteration.
    iterator: input iterator; its `output_shapes` must be fully defined.
    iterations: number of rounds enqueued and number of steps each replica
      runs.
    initial_loop_values: initial value of the single carried loop variable.
      Defaults to `[]`.

  Returns:
    A 3-tuple with the following elements:
      - an op grouping the regrouped replica outputs with the enqueue loop;
      - element 0 of the outputs regrouped by output (see the comment before
        the return statement);
      - the `values.MultiStepContext` handed to `fn`.

  Raises:
    ValueError: if any of the iterator's output shapes is not fully defined.
  """
  shapes = nest.flatten(iterator.output_shapes)
  if any([not s.is_fully_defined() for s in shapes]):
    raise ValueError(
        'TPU currently requires fully defined shapes. Either use '
        'set_shape() on the input tensors or use '
        'dataset.apply(map_and_batch(..., drop_remainder=True)).')
  types = nest.flatten(iterator.output_types)

  def enqueue_ops_fn():
    """Enqueue ops for one iteration."""
    control_deps = []
    sharded_inputs = []
    with ops.device(self._host):
      # Pull one input shard per replica. Each `get_next()` is built under
      # control dependencies on every previously fetched tensor.
      for _ in range(self.num_towers):
        # Use control dependencies to ensure a deterministic ordering.
        with ops.control_dependencies(control_deps):
          inputs = nest.flatten(iterator.get_next())
          control_deps.extend(inputs)
          sharded_inputs.append(inputs)

    # Shard k is enqueued with device_ordinal=k.
    enqueue_ops = []
    for core_id, shard_input in enumerate(sharded_inputs):
      enqueue_ops.append(
          tpu_ops.infeed_enqueue_tuple(
              inputs=shard_input, shapes=shapes, device_ordinal=core_id))
    return enqueue_ops

  def enqueue_ops_loop_body(i):
    # The counter increment depends on this round's enqueue ops.
    with ops.control_dependencies(enqueue_ops_fn()):
      return i + 1

  # Host-side loop that runs `enqueue_ops_fn` `iterations` times.
  with ops.device(self._host):
    enqueue_ops = control_flow_ops.while_loop(
        lambda i: i < iterations,
        enqueue_ops_loop_body,
        [constant_op.constant(0)],
        parallel_iterations=1)

  def dequeue_fn():
    # Dequeue one step's flat tensors and repack them like the iterator
    # output.
    dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
    return nest.pack_sequence_as(iterator.output_shapes, dequeued)

  # Wrap `fn` for repeat.
  if initial_loop_values is None:
    initial_loop_values = []
  ctx = values.MultiStepContext(initial_loop_values)
  def run_fn(*args, **kwargs):
    # The carried loop value is not read; `fn` writes its outputs through
    # `ctx` instead.
    del args, kwargs
    fn_result = fn(ctx, dequeue_fn())
    # Default to an empty list when `fn` set no last-step outputs, so that
    # `identity` below always has something to wrap.
    if ctx.last_step_outputs is None:
      ctx.last_step_outputs = []
    # NOTE(review): `array_ops.identity` is applied to the whole
    # `last_step_outputs` object, which may be a Python list; presumably it
    # is converted to a single tensor. Confirm.
    with ops.control_dependencies([fn_result]):
      return array_ops.identity(ctx.last_step_outputs)

  # TODO(sourabhbajaj): The input to while loop should be based on the output
  # type of the step_fn
  def iterate_on_tpu():
    return training_loop.repeat(iterations, run_fn, [initial_loop_values])

  # One (empty) input list per replica. These are aliases of a single list,
  # which is only safe as long as nothing mutates them.
  replicate_inputs = [[]] * self.num_towers
  outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
  # Transpose from per-replica rows to per-output rows.
  last_step_tensor_outputs = [list(x) for x in zip(*outputs)]

  # Take index [0] of last_step_tensor_outputs as we wrapped
  # initial_loop_values in a list in the `repeat` call.
  return (control_flow_ops.group(last_step_tensor_outputs, enqueue_ops),
          last_step_tensor_outputs[0], ctx)
def _run_steps_on_dataset(self, fn, iterator, iterations,
                          initial_loop_values=None):
  """Runs `fn` for `iterations` steps on every TPU replica, fed via infeed.

  A host-side while loop enqueues `iterations` rounds of inputs, one shard per
  replica per round, into infeed. Each replica runs `training_loop.repeat`,
  which calls `fn(ctx, dequeue_fn())` once per iteration.

  Args:
    fn: step function invoked as `fn(ctx, inputs)` once per iteration.
    iterator: input iterator; its `output_shapes` must be fully defined.
    iterations: number of rounds enqueued and number of steps each replica
      runs.
    initial_loop_values: optional nested structure giving the initial carried
      loop values. Defaults to `{}`.

  Returns:
    The `values.MultiStepContext` handed to `fn`. `run_op` groups the replica
    outputs with the enqueue loop. Last-step outputs are regrouped per output:
    aggregated entries keep only the first replica's value, and unaggregated
    entries keep the full list of per-replica values.

  Raises:
    ValueError: if any of the iterator's output shapes is not fully defined.
  """
  shapes = nest.flatten(iterator.output_shapes)
  # Generator form: short-circuits without building a throwaway list.
  if any(not s.is_fully_defined() for s in shapes):
    raise ValueError(
        'TPU currently requires fully defined shapes. Either use '
        'set_shape() on the input tensors or use '
        'dataset.apply(map_and_batch(..., drop_remainder=True)).')
  types = nest.flatten(iterator.output_types)

  def enqueue_ops_fn():
    """Enqueue ops for one iteration."""
    control_deps = []
    sharded_inputs = []
    with ops.device(self._host):
      for _ in range(self.num_towers):
        # Use control dependencies to ensure a deterministic ordering.
        with ops.control_dependencies(control_deps):
          inputs = nest.flatten(iterator.get_next())
          control_deps.extend(inputs)
          sharded_inputs.append(inputs)

    # Shard k is enqueued with device_ordinal=k.
    enqueue_ops = []
    for core_id, shard_input in enumerate(sharded_inputs):
      enqueue_ops.append(
          tpu_ops.infeed_enqueue_tuple(
              inputs=shard_input, shapes=shapes, device_ordinal=core_id))
    return enqueue_ops

  def enqueue_ops_loop_body(i):
    """One host-loop iteration: enqueue a round, then bump the counter."""
    with ops.control_dependencies(enqueue_ops_fn()):
      return i + 1

  with ops.device(self._host):
    enqueue_ops = control_flow_ops.while_loop(
        lambda i: i < iterations,
        enqueue_ops_loop_body,
        [constant_op.constant(0)],
        parallel_iterations=1)

  def dequeue_fn():
    """Dequeues one step's inputs and repacks them like the iterator output."""
    dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
    return nest.pack_sequence_as(iterator.output_shapes, dequeued)

  # Wrap `fn` for repeat.
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)
  ctx = values.MultiStepContext()

  def run_fn(*args, **kwargs):
    """Single step on the TPU device."""
    del args, kwargs
    fn_result = fn(ctx, dequeue_fn())
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    if not flat_last_step_outputs:
      # Nothing to carry; return the step result itself.
      return fn_result
    # Identity copies created under a control dependency on the step result,
    # so the carried outputs are only produced after `fn_result`.
    with ops.control_dependencies([fn_result]):
      return [array_ops.identity(f) for f in flat_last_step_outputs]

  # TODO(sourabhbajaj): The input to while loop should be based on the output
  # type of the step_fn
  def iterate_on_tpu():
    return training_loop.repeat(iterations, run_fn, initial_loop_values)

  # One distinct (empty) input list per replica, rather than N aliases of the
  # same list object.
  replicate_inputs = [[] for _ in range(self.num_towers)]
  replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
  ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

  # Filter out any ops from the outputs, typically this would be the case
  # when there were no tensor outputs.
  last_step_tensor_outputs = [
      x for x in replicate_outputs if not isinstance(x, ops.Operation)
  ]

  # Outputs are currently of the structure (grouped by device)
  # [[output0_device0, output1_device0, output2_device0],
  #  [output0_device1, output1_device1, output2_device1]]
  # Convert this to the following structure instead: (grouped by output)
  # [[output0_device0, output0_device1],
  #  [output1_device0, output1_device1],
  #  [output2_device0, output2_device1]]
  last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)]

  # Convert replicate_outputs to the original dict structure of
  # last_step_outputs.
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that have already been aggregated, take the first value
    # from the list as each value should be the same. Else return the full
    # list of values.
    if aggregation is not variables_lib.VariableAggregation.NONE:
      # TODO(priyag): Should this return the element or a list with 1 element
      last_step_tensor_outputs_dict[name] = output[0]
  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

  return ctx