Exemple #1
0
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    """Builds a `while_loop` that calls `fn` once per step for `iterations` steps.

    Args:
      fn: Step function, called as `fn(ctx, inputs)` where `ctx` is the
        `MultiStepContext` returned by this method and `inputs` is a tuple
        built from `iterator.get_next()`.
      iterator: Object whose `get_next()` supplies the inputs of each step.
      iterations: Loop bound; the loop runs while the counter is below it.
      initial_loop_values: Optional (possibly nested) structure holding the
        initial values of the loop variables that carry the last step
        outputs. Defaults to an empty structure.

    Returns:
      The `MultiStepContext`, with `run_op` set to an op grouping the loop
      results and with its last step outputs populated.
    """
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = values.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      # The previous values of the loop variables are not needed: the outputs
      # are re-read from `ctx` after every step.
      del args
      fn_inputs = iterator.get_next()
      if not isinstance(fn_inputs, tuple):
        fn_inputs = (fn_inputs,)
      fn_result = fn(ctx, fn_inputs)
      for (name, output) in ctx.last_step_outputs.items():
        # Convert all outputs to tensors, potentially from `DistributedValues`.
        ctx.last_step_outputs[name] = self._unwrap(output)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      # The counter increment is created under a control dependency on
      # `fn_result`.
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things, for
    # e.g. create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    # The captured context must not outlive loop construction, even when
    # building the loop raises (e.g. an error inside `fn`); otherwise a stale
    # context would stay attached to this object.
    try:
      cond = lambda i, *args: i < iterations
      i = constant_op.constant(0)
      loop_result = control_flow_ops.while_loop(
          cond, body, [i] + initial_loop_values, name="",
          parallel_iterations=1, back_prop=False, swap_memory=False,
          return_same_structure=True)
    finally:
      del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs. The first loop variable is the step counter.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # Outputs without a reduce op are regrouped, one tensor per device, into
      # a PerReplica container. Outputs that have already been reduced must
      # consist of exactly one tensor, which is used directly.
      if reduce_op is None:
        last_step_tensor_outputs_dict[name] = values.regroup(
            dict(zip(self._devices, output)), values.PerReplica)
      else:
        assert len(output) == 1
        last_step_tensor_outputs_dict[name] = output[0]

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx
Exemple #2
0
    def _experimental_run_steps_on_iterator(self,
                                            fn,
                                            iterator,
                                            iterations,
                                            initial_loop_values=None):
        """Builds a `while_loop` that calls `fn` once per step.

        Args:
          fn: Step function, called as `fn(ctx, inputs)` where `ctx` is the
            `MultiStepContext` returned by this method and `inputs` is a
            tuple built from `iterator.get_next()`.
          iterator: Object whose `get_next()` supplies the inputs of each
            step.
          iterations: Loop bound; the loop runs while the counter is below
            it.
          initial_loop_values: Optional (possibly nested) structure holding
            the initial values of the loop variables that carry the last
            step outputs. Defaults to an empty structure.

        Returns:
          The `MultiStepContext`, with `run_op` set to an op grouping the
          loop results and with its last step outputs populated.
        """
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)

        ctx = values.MultiStepContext()

        def body(i, *args):
            """A wrapper around `fn` to create the while loop body."""
            # The previous values of the loop variables are not needed: the
            # outputs are re-read from `ctx` after every step.
            del args
            fn_inputs = iterator.get_next()
            if not isinstance(fn_inputs, tuple):
                fn_inputs = (fn_inputs, )
            fn_result = fn(ctx, fn_inputs)
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            # The counter increment is created under a control dependency on
            # `fn_result`.
            with ops.control_dependencies([fn_result]):
                return [i + 1] + flat_last_step_outputs

        # We capture the control_flow_context at this point, before we run `fn`
        # inside a while_loop. This is useful in cases where we might need to exit
        # these contexts and get back to the outer context to do some things, for
        # e.g. create an op which should be evaluated only once at the end of the
        # loop on the host. One such usage is in creating metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        # The captured context must not outlive loop construction, even when
        # building the loop raises (e.g. an error inside `fn`); otherwise a
        # stale context would stay attached to this object.
        try:
            # TODO(priyag): Use max_iterations instead of an explicit counter.
            cond = lambda i, *args: i < iterations
            i = constant_op.constant(0)
            loop_result = control_flow_ops.while_loop(
                cond,
                body, [i] + initial_loop_values,
                name="",
                parallel_iterations=1,
                back_prop=False,
                swap_memory=False,
                return_same_structure=True)
        finally:
            del self._outer_control_flow_context

        ctx.run_op = control_flow_ops.group(loop_result)

        # Convert the last_step_outputs from a list to the original dict structure
        # of last_step_outputs. The first loop variable is the step counter.
        last_step_tensor_outputs = loop_result[1:]
        last_step_tensor_outputs_dict = nest.pack_sequence_as(
            ctx.last_step_outputs, last_step_tensor_outputs)

        ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
        return ctx
Exemple #3
0
    def _experimental_run_steps_on_iterator(self,
                                            fn,
                                            multi_worker_iterator,
                                            iterations,
                                            initial_loop_values=None):
        """Builds a replicated TPU loop that calls `fn` once per step.

        Inputs are enqueued per host through `_get_enqueue_op_per_host` and
        dequeued inside the step with `infeed_dequeue_tuple`.

        Args:
          fn: Step function, called as `fn(ctx, inputs)` where `ctx` is the
            `MultiStepContext` returned by this method and `inputs` is a
            tuple built from the dequeued values.
          multi_worker_iterator: Input iterator; its `output_shapes` and
            `output_types` describe the dequeued tensors, and it is handed
            to `_get_enqueue_op_per_host` for every host.
          iterations: Number of repetitions passed to `training_loop.repeat`.
          initial_loop_values: Optional (possibly nested) structure of
            initial loop values passed to `training_loop.repeat`. Defaults
            to an empty structure.

        Returns:
          The `MultiStepContext`, with `run_op` set to an op grouping the
          replicated outputs and the enqueue ops, and with its last step
          outputs populated.

        Raises:
          ValueError: If any of the iterator's output shapes is not fully
            defined.
        """
        output_shapes = multi_worker_iterator.output_shapes
        shapes = nest.flatten(output_shapes)
        if any(not s.is_fully_defined() for s in shapes):
            raise ValueError(
                "TPU currently requires fully defined shapes. Either use "
                "set_shape() on the input tensors or use "
                "dataset.batch(..., drop_remainder=True).")
        types = nest.flatten(multi_worker_iterator.output_types)

        enqueue_ops = [
            self._get_enqueue_op_per_host(host_id, multi_worker_iterator,
                                          shapes, iterations)
            for host_id in range(self.num_hosts)
        ]

        def dequeue_fn():
            """Dequeues one infeed tuple and restores the input structure."""
            dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types,
                                                    shapes=shapes)
            return nest.pack_sequence_as(output_shapes, dequeued)

        # Wrap `fn` for repeat.
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)
        ctx = values.MultiStepContext()

        def run_fn(*args, **kwargs):
            """Single step on the TPU device."""
            # Loop-carried arguments are ignored; inputs come from the infeed.
            del args, kwargs
            fn_inputs = dequeue_fn()
            if not isinstance(fn_inputs, tuple):
                fn_inputs = (fn_inputs, )
            fn_result = fn(ctx, fn_inputs)
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            if flat_last_step_outputs:
                # The identity ops are created under a control dependency on
                # `fn_result`.
                with ops.control_dependencies([fn_result]):
                    return [
                        array_ops.identity(f) for f in flat_last_step_outputs
                    ]
            else:
                return fn_result

        # TODO(sourabhbajaj): The input to while loop should be based on the output
        # type of the step_fn
        def iterate_on_tpu():
            """Repeats `run_fn` for `iterations` steps."""
            return training_loop.repeat(iterations, run_fn,
                                        initial_loop_values)

        # We capture the control_flow_context at this point, before we run `fn`
        # inside a while_loop and TPU replicate context. This is useful in cases
        # where we might need to exit these contexts and get back to the outer
        # context to do some things, for e.g. create an op which should be
        # evaluated only once at the end of the loop on the host. One such usage
        # is in creating metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        # The captured context must not outlive graph construction, even when
        # replication raises (e.g. an error inside `fn`); otherwise a stale
        # context would stay attached to this object.
        try:
            replicate_inputs = [[]] * self._num_replicas_in_sync
            replicate_outputs = tpu.replicate(iterate_on_tpu,
                                              replicate_inputs)
        finally:
            del self._outer_control_flow_context
        ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

        # Filter out any ops from the outputs, typically this would be the case
        # when there were no tensor outputs.
        last_step_tensor_outputs = [
            x for x in replicate_outputs if not isinstance(x, ops.Operation)
        ]

        # Outputs are currently of the structure (grouped by device)
        # [[output0_device0, output1_device0, output2_device0],
        #  [output0_device1, output1_device1, output2_device1]]
        # Convert this to the following structure instead: (grouped by output)
        # [[output0_device0, output0_device1],
        #  [output1_device0, output1_device1],
        #  [output2_device0, output2_device1]]
        last_step_tensor_outputs = [
            list(x) for x in zip(*last_step_tensor_outputs)
        ]

        # Convert replicate_outputs to the original dict structure of
        # last_step_outputs.
        last_step_tensor_outputs_dict = nest.pack_sequence_as(
            ctx.last_step_outputs, last_step_tensor_outputs)

        for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
            output = last_step_tensor_outputs_dict[name]
            # For outputs that have already been reduced, take the first value
            # from the list as each value should be the same. Else return the full
            # list of values.
            # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
            # value.
            if reduce_op is not None:
                # TODO(priyag): Should this return the element or a list with 1 element
                last_step_tensor_outputs_dict[name] = output[0]
        ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

        return ctx