Example #1
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    # TODO(tomhennigan) This is missing many things (e.g. ctx.run_op).
    ctx = input_lib.MultiStepContext()
    for _ in range(iterations):
      fn(ctx, iterator.get_next())
    return ctx
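For orientation, here is a minimal, self-contained sketch of the calling pattern this variant implements, with no TensorFlow dependency; `_ToyContext`, `step_fn`, and the list-backed iterator are hypothetical stand-ins for `MultiStepContext`, the user's step function, and the dataset iterator.

    class _ToyContext(object):
        """Hypothetical stand-in for input_lib.MultiStepContext."""

        def __init__(self):
            self.last_step_outputs = {}

    def step_fn(ctx, inputs):
        # Record the most recent input as a last-step output.
        ctx.last_step_outputs["x"] = inputs

    def run_steps(fn, iterator, iterations):
        ctx = _ToyContext()
        for _ in range(iterations):
            fn(ctx, next(iterator))
        return ctx

    ctx = run_steps(step_fn, iter([1, 2, 3]), iterations=3)
    assert ctx.last_step_outputs["x"] == 3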
Example #2
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = input_lib.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_result = fn(ctx, iterator.get_next())
      for (name, output) in ctx.last_step_outputs.items():
        # Convert all outputs to tensors, potentially from `DistributedValues`.
        ctx.last_step_outputs[name] = self._local_results(output)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to
    # exit these contexts and get back to the outer context to do some
    # things, e.g. create an op which should be evaluated only once at the
    # end of the loop on the host. One such usage is in creating metrics'
    # value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)
    del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have not been reduced, regroup the per-replica
      # values into a PerReplica container; for outputs that have already
      # been reduced, unwrap the single value.
      if reduce_op is None:
        last_step_tensor_outputs_dict[name] = values.regroup(self._device_map,
                                                             output)
      else:
        assert len(output) == 1
        last_step_tensor_outputs_dict[name] = output[0]

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx
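The while_loop above can only carry a flat list of tensors, so the dict of last-step outputs makes a flatten/pack round trip. A small sketch of that round trip, assuming TensorFlow 2.x is importable as `tf` (the values here are plain Python placeholders):

    import tensorflow as tf

    outputs = {"loss": 0.5, "accuracy": 0.9}
    # Dicts flatten in sorted-key order, so "accuracy" comes first.
    flat = tf.nest.flatten(outputs)                  # [0.9, 0.5]
    # pack_sequence_as restores the original dict structure from the flat list.
    restored = tf.nest.pack_sequence_as(outputs, flat)
    assert restored == {"loss": 0.5, "accuracy": 0.9}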
Example #3
    def _experimental_run_steps_on_iterator(self,
                                            fn,
                                            iterator,
                                            iterations,
                                            initial_loop_values=None):
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)

        ctx = input_lib.MultiStepContext()

        def body(i, *args):
            """A wrapper around `fn` to create the while loop body."""
            del args
            fn_result = fn(ctx, iterator.get_next())
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            with ops.control_dependencies([fn_result]):
                return [i + 1] + flat_last_step_outputs

        # We capture the control_flow_context at this point, before we run
        # `fn` inside a while_loop. This is useful in cases where we might
        # need to exit these contexts and get back to the outer context to do
        # some things, e.g. create an op which should be evaluated only once
        # at the end of the loop on the host. One such usage is in creating
        # metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        # TODO(priyag): Use max_iterations instead of an explicit counter.
        cond = lambda i, *args: i < iterations
        i = constant_op.constant(0)
        loop_result = control_flow_ops.while_loop(cond,
                                                  body,
                                                  [i] + initial_loop_values,
                                                  name="",
                                                  parallel_iterations=1,
                                                  back_prop=False,
                                                  swap_memory=False,
                                                  return_same_structure=True)
        del self._outer_control_flow_context

        ctx.run_op = control_flow_ops.group(loop_result)

        # Convert the last_step_outputs from a list to the original dict structure
        # of last_step_outputs.
        last_step_tensor_outputs = loop_result[1:]
        last_step_tensor_outputs_dict = nest.pack_sequence_as(
            ctx.last_step_outputs, last_step_tensor_outputs)

        ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
        return ctx
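A minimal sketch of the loop-carried-values pattern the graph-mode variants rely on, assuming TensorFlow 2.x as `tf`: the step counter `i` rides along with the other loop values, and `cond` and `body` each take the whole flat list.

    import tensorflow as tf

    iterations = 3
    cond = lambda i, total: i < iterations
    # Each iteration advances the counter and updates the carried value.
    body = lambda i, total: [i + 1, total + 10]
    i, total = tf.while_loop(cond, body, [tf.constant(0), tf.constant(0)])
    assert int(i) == 3 and int(total) == 30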
Example #4
    def _experimental_run_steps_on_iterator(self,
                                            fn,
                                            multi_worker_iterator,
                                            iterations,
                                            initial_loop_values=None):
        # Wrap `fn` for repeat.
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)
        ctx = input_lib.MultiStepContext()

        def run_fn(inputs):
            """Single step on the TPU device."""
            fn_result = fn(ctx, inputs)
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            if flat_last_step_outputs:
                with ops.control_dependencies([fn_result]):
                    return [
                        array_ops.identity(f) for f in flat_last_step_outputs
                    ]
            else:
                return fn_result

        # We capture the control_flow_context at this point, before we run
        # `fn` inside a while_loop and TPU replicate context. This is useful
        # in cases where we might need to exit these contexts and get back to
        # the outer context to do some things, e.g. create an op which should
        # be evaluated only once at the end of the loop on the host. One such
        # usage is in creating metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        def rewrite_fn(*args):
            """The rewritten step fn running on TPU."""
            del args

            per_replica_inputs = multi_worker_iterator.get_next()
            replicate_inputs = []
            for replica_id in range(self._num_replicas_in_sync):
                select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
                replicate_inputs.append(
                    (nest.map_structure(select_replica, per_replica_inputs), ))

            replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

            # If run_fn has tensor outputs, tpu.replicate returns a list of
            # lists, which we flatten here. If run_fn has no tensor outputs,
            # tpu.replicate returns a list of no_ops, and we keep the output
            # as is.
            if isinstance(replicate_outputs[0], list):
                replicate_outputs = nest.flatten(replicate_outputs)

            return replicate_outputs

        # TODO(sourabhbajaj): The input to the while loop should be based on
        # the output type of the step_fn.
        assert isinstance(initial_loop_values, list)
        initial_loop_values = initial_loop_values * self._num_replicas_in_sync

        # Put the while loop op on TPU host 0.
        with ops.device(self._host_device):
            if self.steps_per_run == 1:
                replicate_outputs = rewrite_fn()
            else:
                replicate_outputs = training_loop.repeat(
                    iterations, rewrite_fn, initial_loop_values)

        del self._outer_control_flow_context
        ctx.run_op = control_flow_ops.group(replicate_outputs)

        if isinstance(replicate_outputs, list):
            # Filter out any ops from the outputs; typically this is the case
            # when there were no tensor outputs.
            last_step_tensor_outputs = [
                x for x in replicate_outputs
                if not isinstance(x, ops.Operation)
            ]

            # Outputs are currently of the structure (flattened)
            # [output0_device0, output1_device0, output2_device0,
            #  output0_device1, output1_device1, output2_device1,
            #  ...]
            # Convert this to the following structure instead: (grouped by output)
            # [[output0_device0, output0_device1],
            #  [output1_device0, output1_device1],
            #  [output2_device0, output2_device1]]
            output_num = len(
                last_step_tensor_outputs) // self._num_replicas_in_sync
            last_step_tensor_outputs = [
                last_step_tensor_outputs[i::output_num]
                for i in range(output_num)
            ]
        else:
            # no tensors returned.
            last_step_tensor_outputs = []

        _set_last_step_outputs(ctx, last_step_tensor_outputs)
        return ctx
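A pure-Python sketch of the regrouping step above: `tpu.replicate` hands back the flattened outputs replica-major, and the strided slice regroups them output-major. The strings are illustrative placeholders.

    flat = ["out0_dev0", "out1_dev0", "out2_dev0",
            "out0_dev1", "out1_dev1", "out2_dev1"]
    num_replicas_in_sync = 2
    output_num = len(flat) // num_replicas_in_sync   # 3 outputs per replica
    # flat[i::output_num] picks output i from every replica.
    grouped = [flat[i::output_num] for i in range(output_num)]
    assert grouped == [["out0_dev0", "out0_dev1"],
                       ["out1_dev0", "out1_dev1"],
                       ["out2_dev0", "out2_dev1"]]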
Example #5
    def _experimental_run_steps_on_iterator(self,
                                            fn,
                                            multi_worker_iterator,
                                            iterations,
                                            initial_loop_values=None):
        output_shapes = multi_worker_iterator.output_shapes
        shapes = nest.flatten(output_shapes)
        if any(not s.is_fully_defined() for s in shapes):
            raise ValueError(
                "TPU currently requires fully defined shapes. Either use "
                "set_shape() on the input tensors or use "
                "dataset.batch(..., drop_remainder=True).")
        types = nest.flatten(multi_worker_iterator.output_types)

        enqueue_ops = [
            self._get_enqueue_op_per_host(host_id, multi_worker_iterator,
                                          shapes, iterations)
            for host_id in range(self.num_hosts)
        ]

        def dequeue_fn():
            dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types,
                                                    shapes=shapes)
            return nest.pack_sequence_as(output_shapes, dequeued)

        # Wrap `fn` for repeat.
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)
        ctx = input_lib.MultiStepContext()

        def run_fn(*args, **kwargs):
            """Single step on the TPU device."""
            del args, kwargs
            fn_result = fn(ctx, dequeue_fn())
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            if flat_last_step_outputs:
                with ops.control_dependencies([fn_result]):
                    return [
                        array_ops.identity(f) for f in flat_last_step_outputs
                    ]
            else:
                return fn_result

        def iterate_on_tpu():
            return training_loop.repeat(iterations, run_fn,
                                        initial_loop_values)

        # We capture the control_flow_context at this point, before we run
        # `fn` inside a while_loop and TPU replicate context. This is useful
        # in cases where we might need to exit these contexts and get back to
        # the outer context to do some things, e.g. create an op which should
        # be evaluated only once at the end of the loop on the host. One such
        # usage is in creating metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        # pylint: disable=protected-access
        if self._container_strategy()._disable_training_loop_on_host:
            replicate_inputs = [[]] * self._num_replicas_in_sync
            replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
        else:

            def rewrite_fn(*args):
                """The rewritten step fn running on TPU."""
                del args
                replicate_inputs = [[]] * self._num_replicas_in_sync
                replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

                # If run_fn has tensor outputs, tpu.replicate returns a list
                # of lists, which we flatten here. If run_fn has no tensor
                # outputs, tpu.replicate returns a list of no_ops, and we keep
                # the output as is.
                if isinstance(replicate_outputs[0], list):
                    replicate_outputs = nest.flatten(replicate_outputs)

                return replicate_outputs

            # TODO(sourabhbajaj): The input to the while loop should be based
            # on the output type of the step_fn.
            assert isinstance(initial_loop_values, list)
            initial_loop_values = initial_loop_values * self._num_replicas_in_sync

            # Put the while loop op on host 0.
            with ops.device(self.get_host_cpu_device(0)):
                replicate_outputs = training_loop.repeat(
                    iterations, rewrite_fn, initial_loop_values)

        del self._outer_control_flow_context
        ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

        if self._container_strategy()._disable_training_loop_on_host:
            # Filter out any ops from the outputs; typically this is the case
            # when there were no tensor outputs.
            last_step_tensor_outputs = [
                x for x in replicate_outputs
                if not isinstance(x, ops.Operation)
            ]

            # Outputs are currently of the structure (grouped by device)
            # [[output0_device0, output1_device0, output2_device0],
            #  [output0_device1, output1_device1, output2_device1]]
            # Convert this to the following structure instead: (grouped by output)
            # [[output0_device0, output0_device1],
            #  [output1_device0, output1_device1],
            #  [output2_device0, output2_device1]]
            last_step_tensor_outputs = [
                list(x) for x in zip(*last_step_tensor_outputs)
            ]
        else:
            if isinstance(replicate_outputs, list):
                # Filter out any ops from the outputs; typically this is the
                # case when there were no tensor outputs.
                last_step_tensor_outputs = [
                    x for x in replicate_outputs
                    if not isinstance(x, ops.Operation)
                ]

                # Outputs are currently of the structure (flattened)
                # [output0_device0, output1_device0, output2_device0,
                #  output0_device1, output1_device1, output2_device1,
                #  ...]
                # Convert this to the following structure instead: (grouped by output)
                # [[output0_device0, output0_device1],
                #  [output1_device0, output1_device1],
                #  [output2_device0, output2_device1]]
                output_num = len(
                    last_step_tensor_outputs) // self._num_replicas_in_sync
                last_step_tensor_outputs = [
                    last_step_tensor_outputs[i::output_num]
                    for i in range(output_num)
                ]
            else:
                # no tensors returned.
                last_step_tensor_outputs = []

        # Convert replicate_outputs to the original dict structure of
        # last_step_outputs.
        last_step_tensor_outputs_dict = nest.pack_sequence_as(
            ctx.last_step_outputs, last_step_tensor_outputs)

        for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
            output = last_step_tensor_outputs_dict[name]
            # For outputs that have already been reduced, take the first value
            # from the list, as each value should be the same. Otherwise,
            # return the full list of values.
            # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
            # value.
            if reduce_op is not None:
                # TODO(priyag): Should this return the element or a list with
                # one element?
                last_step_tensor_outputs_dict[name] = output[0]
        ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

        return ctx
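In the device-loop branch above the outputs arrive grouped by device rather than flattened, so the regrouping is a plain transpose. A pure-Python sketch with placeholder strings:

    per_device = [["out0_dev0", "out1_dev0", "out2_dev0"],
                  ["out0_dev1", "out1_dev1", "out2_dev1"]]
    # zip(*rows) transposes per-device rows into per-output columns.
    per_output = [list(x) for x in zip(*per_device)]
    assert per_output == [["out0_dev0", "out0_dev1"],
                          ["out1_dev0", "out1_dev1"],
                          ["out2_dev0", "out2_dev1"]]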
Example #6
    def _run_steps_on_iterator_with_device_loop(self,
                                                fn,
                                                multi_worker_iterator,
                                                iterations,
                                                initial_loop_values=None):
        output_shapes = multi_worker_iterator.output_shapes
        shapes = nest.flatten(output_shapes)
        if any(not s.is_fully_defined() for s in shapes):
            raise ValueError(
                "TPU currently requires fully defined shapes. Either use "
                "set_shape() on the input tensors or use "
                "dataset.batch(..., drop_remainder=True).")
        types = nest.flatten(multi_worker_iterator.output_types)

        enqueue_ops = [
            self._get_enqueue_op_per_host(host_id, multi_worker_iterator,
                                          shapes, iterations)
            for host_id in range(self.num_hosts)
        ]

        def dequeue_fn():
            dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types,
                                                    shapes=shapes)
            return nest.pack_sequence_as(output_shapes, dequeued)

        # Wrap `fn` for repeat.
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)
        ctx = input_lib.MultiStepContext()

        def run_fn(*args, **kwargs):
            """Single step on the TPU device."""
            del args, kwargs
            fn_result = fn(ctx, dequeue_fn())
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            if flat_last_step_outputs:
                with ops.control_dependencies([fn_result]):
                    return [
                        array_ops.identity(f) for f in flat_last_step_outputs
                    ]
            else:
                return fn_result

        def iterate_on_tpu():
            return training_loop.repeat(iterations, run_fn,
                                        initial_loop_values)

        # We capture the control_flow_context at this point, before we run
        # `fn` inside a while_loop and TPU replicate context. This is useful
        # in cases where we might need to exit these contexts and get back to
        # the outer context to do some things, e.g. create an op which should
        # be evaluated only once at the end of the loop on the host. One such
        # usage is in creating metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        replicate_inputs = [[]] * self._num_replicas_in_sync
        replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)

        del self._outer_control_flow_context
        ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

        # Filter out any ops from the outputs; typically this is the case
        # when there were no tensor outputs.
        last_step_tensor_outputs = [
            x for x in replicate_outputs if not isinstance(x, ops.Operation)
        ]

        # Outputs are currently of the structure (grouped by device)
        # [[output0_device0, output1_device0, output2_device0],
        #  [output0_device1, output1_device1, output2_device1]]
        # Convert this to the following structure instead: (grouped by output)
        # [[output0_device0, output0_device1],
        #  [output1_device0, output1_device1],
        #  [output2_device0, output2_device1]]
        last_step_tensor_outputs = [
            list(x) for x in zip(*last_step_tensor_outputs)
        ]

        _set_last_step_outputs(ctx, last_step_tensor_outputs)
        return ctx
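Finally, a small sketch of the fully-defined-shape check both TPU variants run before building infeed ops, assuming TensorFlow 2.x as `tf`: a `None` batch dimension fails `is_fully_defined()`, which is why `drop_remainder=True` (or an explicit `set_shape()`) is required.

    import tensorflow as tf

    # A dataset batched without drop_remainder=True leaves the batch dim None.
    shapes = [tf.TensorShape([None, 224, 224, 3]), tf.TensorShape([None])]
    assert any(not s.is_fully_defined() for s in shapes)

    # With a static batch size every dimension is known, so infeed can proceed.
    shapes = [tf.TensorShape([128, 224, 224, 3]), tf.TensorShape([128])]
    assert all(s.is_fully_defined() for s in shapes)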