Example #1
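This example and the variants below are revisions of the `_run_steps_on_dataset` method from TensorFlow's contrib-era distribution strategies. They assume module-level imports roughly like the following sketch; the exact paths are an assumption and shifted between TF 1.x releases:

    # Contrib-era import block (paths are an assumption; adjust per release).
    from tensorflow.contrib.distribute.python import values
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.ops import variables as variables_lib
    from tensorflow.python.util import nest

    # The TPU variants (Examples #5-#7) additionally use:
    from tensorflow.contrib.tpu.python.ops import tpu_ops
    from tensorflow.contrib.tpu.python.tpu import tpu
    from tensorflow.contrib.tpu.python.tpu import training_loop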
  def _run_steps_on_dataset(self, fn, iterator, iterations,
                            initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = values.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_result = fn(ctx, iterator.get_next())
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    # TODO(priyag): Use max_iterations instead of an explicit counter.
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx
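A caller drives the returned `MultiStepContext` by running `ctx.run_op` and then reading `ctx.last_step_outputs`. A minimal sketch, assuming TF 1.x graph mode, a strategy object exposing this method, and user-defined `model_loss`/`make_train_op` helpers; treat `set_last_step_output` and its `aggregation` argument as contrib-era API assumptions:

    import tensorflow as tf

    def step_fn(ctx, inputs):
      loss = model_loss(inputs)  # assumed user-defined
      # MEAN keeps this output a single tensor that can be fetched directly.
      ctx.set_last_step_output("loss", loss,
                               aggregation=tf.VariableAggregation.MEAN)
      return make_train_op(loss)  # assumed user-defined

    iterator = dataset.make_initializable_iterator()  # `dataset` assumed
    ctx = strategy._run_steps_on_dataset(step_fn, iterator, iterations=10)
    with tf.Session() as sess:
      sess.run(iterator.initializer)
      sess.run(ctx.run_op)
      print(sess.run(ctx.last_step_outputs["loss"]))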
Example #2
  def _run_steps_on_dataset(self, fn, iterator, iterations,
                            initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = values.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_inputs = iterator.get_next()
      if not isinstance(fn_inputs, tuple):
        fn_inputs = (fn_inputs,)
      fn_result = fn(ctx, *fn_inputs)
      for (name, output) in ctx.last_step_outputs.items():
        # Convert all outputs to tensors, potentially from `DistributedValues`.
        ctx.last_step_outputs[name] = self.unwrap(output)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things,
    # e.g. create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)
    del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been aggregated, wrap them in a Mirrored
      # container, else in a PerDevice container.
      if aggregation is variables_lib.VariableAggregation.NONE:
        last_step_tensor_outputs_dict[name] = values.regroup(
            {d: t for d, t in zip(self._devices, output)}, values.PerDevice)
      else:
        assert len(output) == 1
        last_step_tensor_outputs_dict[name] = output[0]

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx
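The tail loop over `_last_step_outputs_aggregations` distinguishes outputs that were already reduced across devices from raw per-device outputs. A sketch of a step function exercising both branches (the `set_last_step_output` signature matches the contrib-era `MultiStepContext`, but treat it as an assumption):

    def step_fn(ctx, inputs):
      loss = compute_loss(inputs)  # assumed user-defined
      # Aggregated outputs (e.g. MEAN) are reduced before the loop exits, so
      # the tail loop above unwraps the single resulting tensor.
      ctx.set_last_step_output(
          "mean_loss", loss,
          aggregation=variables_lib.VariableAggregation.MEAN)
      # The default aggregation is NONE, so the tail loop regroups these
      # per-device tensors into a PerDevice container instead.
      ctx.set_last_step_output("per_device_loss", loss)
      return make_train_op(loss)  # assumed user-defined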
Example #3
    def _run_steps_on_dataset(self,
                              fn,
                              iterator,
                              iterations,
                              initial_loop_values=None):
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)

        ctx = values.MultiStepContext()

        def body(i, *args):
            """A wrapper around `fn` to create the while loop body."""
            del args
            fn_result = fn(ctx, iterator.get_next())
            for (name, output) in ctx.last_step_outputs.items():
                # Convert all outputs to tensors, potentially from `DistributedValues`.
                ctx.last_step_outputs[name] = self.unwrap(output)
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            with ops.control_dependencies([fn_result]):
                return [i + 1] + flat_last_step_outputs

        cond = lambda i, *args: i < iterations
        i = constant_op.constant(0)
        loop_result = control_flow_ops.while_loop(cond,
                                                  body,
                                                  [i] + initial_loop_values,
                                                  name="",
                                                  parallel_iterations=1,
                                                  back_prop=False,
                                                  swap_memory=False,
                                                  return_same_structure=True)

        ctx.run_op = control_flow_ops.group(loop_result)

        # Convert the last_step_outputs from a list to the original dict structure
        # of last_step_outputs.
        last_step_tensor_outputs = loop_result[1:]
        last_step_tensor_outputs_dict = nest.pack_sequence_as(
            ctx.last_step_outputs, last_step_tensor_outputs)

        for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
            output = last_step_tensor_outputs_dict[name]
            # For outputs that have already been aggregated, wrap them in a Mirrored
            # container, else in a PerDevice container.
            if aggregation is variables_lib.VariableAggregation.NONE:
                last_step_tensor_outputs_dict[name] = values.regroup(
                    {d: t
                     for d, t in zip(self._devices, output)}, values.PerDevice)
            else:
                assert len(output) == 1
                last_step_tensor_outputs_dict[name] = output[0]

        ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
        return ctx
Example #4
    def _run_steps_on_dataset(self,
                              fn,
                              iterator,
                              iterations,
                              initial_loop_values=None):
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)

        ctx = values.MultiStepContext()

        def body(i, *args):
            """A wrapper around `fn` to create the while loop body."""
            del args
            fn_inputs = iterator.get_next()
            if not isinstance(fn_inputs, tuple):
                fn_inputs = (fn_inputs,)
            fn_result = fn(ctx, *fn_inputs)
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            with ops.control_dependencies([fn_result]):
                return [i + 1] + flat_last_step_outputs

        # We capture the control_flow_context at this point, before we run `fn`
        # inside a while_loop. This is useful in cases where we might need to exit
        # these contexts and get back to the outer context to do some things,
        # e.g. create an op which should be evaluated only once at the end of the
        # loop on the host. One such usage is in creating metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        # TODO(priyag): Use max_iterations instead of an explicit counter.
        cond = lambda i, *args: i < iterations
        i = constant_op.constant(0)
        loop_result = control_flow_ops.while_loop(cond,
                                                  body,
                                                  [i] + initial_loop_values,
                                                  name="",
                                                  parallel_iterations=1,
                                                  back_prop=False,
                                                  swap_memory=False,
                                                  return_same_structure=True)
        del self._outer_control_flow_context

        ctx.run_op = control_flow_ops.group(loop_result)

        # Convert the last_step_outputs from a list to the original dict structure
        # of last_step_outputs.
        last_step_tensor_outputs = loop_result[1:]
        last_step_tensor_outputs_dict = nest.pack_sequence_as(
            ctx.last_step_outputs, last_step_tensor_outputs)

        ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
        return ctx
Example #5
    def _run_steps_on_dataset(self,
                              fn,
                              iterator,
                              iterations,
                              initial_loop_values=None):

        shapes = nest.flatten(iterator.output_shapes)
        if any(not s.is_fully_defined() for s in shapes):
            raise ValueError(
                'TPU currently requires fully defined shapes. Either use '
                'set_shape() on the input tensors or use '
                'dataset.batch(..., drop_remainder=True).')
        types = nest.flatten(iterator.output_types)

        enqueue_ops = [
            self._get_enqueue_op_per_host(host_id, iterator, shapes,
                                          iterations)
            for host_id in range(self.num_hosts)
        ]

        def dequeue_fn():
            dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types,
                                                    shapes=shapes)
            return nest.pack_sequence_as(iterator.output_shapes, dequeued)

        # Wrap `fn` for repeat.
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)
        ctx = values.MultiStepContext()

        def run_fn(*args, **kwargs):
            """Single step on the TPU device."""
            del args, kwargs
            fn_inputs = dequeue_fn()
            if not isinstance(fn_inputs, tuple):
                fn_inputs = (fn_inputs,)
            fn_result = fn(ctx, *fn_inputs)
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            if flat_last_step_outputs:
                with ops.control_dependencies([fn_result]):
                    return [
                        array_ops.identity(f) for f in flat_last_step_outputs
                    ]
            else:
                return fn_result

        # TODO(sourabhbajaj): The input to while loop should be based on the output
        # type of the step_fn
        def iterate_on_tpu():
            return training_loop.repeat(iterations, run_fn,
                                        initial_loop_values)

        # We capture the control_flow_context at this point, before we run `fn`
        # inside a while_loop and TPU replicate context. This is useful in cases
        # where we might need to exit these contexts and get back to the outer
        # context to do some things, e.g. create an op which should be
        # evaluated only once at the end of the loop on the host. One such usage
        # is in creating metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        replicate_inputs = [[]] * self.num_towers
        replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
        del self._outer_control_flow_context
        ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

        # Filter out any ops from the outputs, typically this would be the case
        # when there were no tensor outputs.
        last_step_tensor_outputs = [
            x for x in replicate_outputs if not isinstance(x, ops.Operation)
        ]

        # Outputs are currently of the structure (grouped by device)
        # [[output0_device0, output1_device0, output2_device0],
        #  [output0_device1, output1_device1, output2_device1]]
        # Convert this to the following structure instead: (grouped by output)
        # [[output0_device0, output0_device1],
        #  [output1_device0, output1_device1],
        #  [output2_device0, output2_device1]]
        last_step_tensor_outputs = [
            list(x) for x in zip(*last_step_tensor_outputs)
        ]

        # Convert replicate_outputs to the original dict structure of
        # last_step_outputs.
        last_step_tensor_outputs_dict = nest.pack_sequence_as(
            ctx.last_step_outputs, last_step_tensor_outputs)

        for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
            output = last_step_tensor_outputs_dict[name]
            # For outputs that have already been aggregated, take the first value
            # from the list as each value should be the same. Else return the full
            # list of values.
            if aggregation is not variables_lib.VariableAggregation.NONE:
                # TODO(priyag): Should this return the element or a list with 1 element?
                last_step_tensor_outputs_dict[name] = output[0]
        ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

        return ctx
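Example #5 calls a `self._get_enqueue_op_per_host` helper that is not shown. Reconstructed from the inline enqueue logic in Examples #6 and #7 below, a per-host version would look roughly like this sketch (the helper body, `get_host_cpu_device`, and `num_towers_per_host` are assumptions):

    def _get_enqueue_op_per_host(self, host_id, iterator, input_shapes,
                                 iterations):
      """Sketch of the assumed helper: infeed enqueue loop for one host."""
      host_device = self.get_host_cpu_device(host_id)  # assumed accessor

      def enqueue_ops_fn():
        """Enqueue ops for one iteration, one shard per core on this host."""
        control_deps = []
        sharded_inputs = []
        with ops.device(host_device):
          for _ in range(self.num_towers_per_host):  # assumed attribute
            # Control dependencies give a deterministic enqueue order.
            with ops.control_dependencies(control_deps):
              inputs = nest.flatten(iterator.get_next())
              control_deps.extend(inputs)
              sharded_inputs.append(inputs)
        return [
            tpu_ops.infeed_enqueue_tuple(
                inputs=shard_input, shapes=input_shapes,
                device_ordinal=core_id)
            for core_id, shard_input in enumerate(sharded_inputs)
        ]

      def enqueue_ops_loop_body(i):
        with ops.control_dependencies(enqueue_ops_fn()):
          return i + 1

      with ops.device(host_device):
        return control_flow_ops.while_loop(
            lambda i: i < iterations, enqueue_ops_loop_body,
            [constant_op.constant(0)], parallel_iterations=1)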
Example #6
    def _run_steps_on_dataset(self,
                              fn,
                              iterator,
                              iterations,
                              initial_loop_values=None):

        shapes = nest.flatten(iterator.output_shapes)
        if any(not s.is_fully_defined() for s in shapes):
            raise ValueError(
                'TPU currently requires fully defined shapes. Either use '
                'set_shape() on the input tensors or use '
                'dataset.apply(map_and_batch(..., drop_remainder=True)).')
        types = nest.flatten(iterator.output_types)

        def enqueue_ops_fn():
            """Enqueue ops for one iteration."""
            control_deps = []
            sharded_inputs = []
            with ops.device(self._host):
                for _ in range(self.num_towers):
                    # Use control dependencies to ensure a deterministic ordering.
                    with ops.control_dependencies(control_deps):
                        inputs = nest.flatten(iterator.get_next())
                        control_deps.extend(inputs)
                        sharded_inputs.append(inputs)

            enqueue_ops = []
            for core_id, shard_input in enumerate(sharded_inputs):
                enqueue_ops.append(
                    tpu_ops.infeed_enqueue_tuple(inputs=shard_input,
                                                 shapes=shapes,
                                                 device_ordinal=core_id))
            return enqueue_ops

        def enqueue_ops_loop_body(i):
            with ops.control_dependencies(enqueue_ops_fn()):
                return i + 1

        with ops.device(self._host):
            enqueue_ops = control_flow_ops.while_loop(
                lambda i: i < iterations,
                enqueue_ops_loop_body, [constant_op.constant(0)],
                parallel_iterations=1)

        def dequeue_fn():
            dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types,
                                                    shapes=shapes)
            return nest.pack_sequence_as(iterator.output_shapes, dequeued)

        # Wrap `fn` for repeat.
        if initial_loop_values is None:
            initial_loop_values = []
        ctx = values.MultiStepContext(initial_loop_values)

        def run_fn(*args, **kwargs):
            del args, kwargs
            fn_result = fn(ctx, dequeue_fn())
            if ctx.last_step_outputs is None:
                ctx.last_step_outputs = []
            with ops.control_dependencies([fn_result]):
                return array_ops.identity(ctx.last_step_outputs)

        # TODO(sourabhbajaj): The input to while loop should be based on the output
        # type of the step_fn
        def iterate_on_tpu():
            return training_loop.repeat(iterations, run_fn,
                                        [initial_loop_values])

        replicate_inputs = [[]] * self.num_towers
        outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
        last_step_tensor_outputs = [list(x) for x in zip(*outputs)]

        # Take index [0] of last_step_tensor_outputs as we wrapped
        # initial_loop_values in a list in the `repeat` call.
        return (control_flow_ops.group(last_step_tensor_outputs, enqueue_ops),
                last_step_tensor_outputs[0], ctx)
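Unlike the other revisions, this one returns a `(run op, last-step outputs, ctx)` tuple instead of the context alone, so a caller unpacks it directly. A sketch, under the same assumptions as the usage example after Example #1:

    run_op, last_step_outputs, ctx = strategy._run_steps_on_dataset(
        step_fn, iterator, iterations=10)
    with tf.Session() as sess:
      sess.run(iterator.initializer)
      _, outputs = sess.run((run_op, last_step_outputs))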
Example #7
  def _run_steps_on_dataset(self, fn, iterator, iterations,
                            initial_loop_values=None):

    shapes = nest.flatten(iterator.output_shapes)
    if any(not s.is_fully_defined() for s in shapes):
      raise ValueError(
          'TPU currently requires fully defined shapes. Either use '
          'set_shape() on the input tensors or use '
          'dataset.apply(map_and_batch(..., drop_remainder=True)).')
    types = nest.flatten(iterator.output_types)

    def enqueue_ops_fn():
      """Enqueue ops for one iteration."""
      control_deps = []
      sharded_inputs = []
      with ops.device(self._host):
        for _ in range(self.num_towers):
          # Use control dependencies to ensure a deterministic ordering.
          with ops.control_dependencies(control_deps):
            inputs = nest.flatten(iterator.get_next())
            control_deps.extend(inputs)
            sharded_inputs.append(inputs)

      enqueue_ops = []
      for core_id, shard_input in enumerate(sharded_inputs):
        enqueue_ops.append(
            tpu_ops.infeed_enqueue_tuple(
                inputs=shard_input, shapes=shapes, device_ordinal=core_id))
      return enqueue_ops

    def enqueue_ops_loop_body(i):
      with ops.control_dependencies(enqueue_ops_fn()):
        return i + 1

    with ops.device(self._host):
      enqueue_ops = control_flow_ops.while_loop(
          lambda i: i < iterations,
          enqueue_ops_loop_body,
          [constant_op.constant(0)],
          parallel_iterations=1)

    def dequeue_fn():
      dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
      return nest.pack_sequence_as(iterator.output_shapes, dequeued)

    # Wrap `fn` for repeat.
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)
    ctx = values.MultiStepContext()
    def run_fn(*args, **kwargs):
      del args, kwargs
      fn_result = fn(ctx, dequeue_fn())
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      if flat_last_step_outputs:
        with ops.control_dependencies([fn_result]):
          return [array_ops.identity(f) for f in flat_last_step_outputs]
      else:
        return fn_result

    # TODO(sourabhbajaj): The input to while loop should be based on the output
    # type of the step_fn
    def iterate_on_tpu():
      return training_loop.repeat(iterations, run_fn, initial_loop_values)

    replicate_inputs = [[]] * self.num_towers
    replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
    ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

    # Filter out any ops from the outputs, typically this would be the case
    # when there were no tensor outputs.
    last_step_tensor_outputs = [x for x in replicate_outputs
                                if not isinstance(x, ops.Operation)]

    # Outputs are currently of the structure (grouped by device)
    # [[output0_device0, output1_device0, output2_device0],
    #  [output0_device1, output1_device1, output2_device1]]
    # Convert this to the following structure instead: (grouped by output)
    # [[output0_device0, output0_device1],
    #  [output1_device0, output1_device1],
    #  [output2_device0, output2_device1]]
    last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)]

    # Convert replicate_outputs to the original dict structure of
    # last_step_outputs.
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been aggregated, take the first value
      # from the list as each value should be the same. Else return the full
      # list of values.
      if aggregation is not variables_lib.VariableAggregation.NONE:
        # TODO(priyag): Should this return the element or a list with 1 element?
        last_step_tensor_outputs_dict[name] = output[0]
    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

    return ctx
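The fully-defined-shapes check near the top of the TPU variants (Examples #5-#7) is normally satisfied on the input-pipeline side by keeping the batch dimension static. A minimal sketch, assuming a TF 1.x `tf.data` pipeline over some `features` tensor:

    import tensorflow as tf

    # Dropping the remainder keeps the batch dimension static, so every
    # flattened output shape is fully defined for the TPU infeed.
    dataset = tf.data.Dataset.from_tensor_slices(features)  # `features` assumed
    dataset = dataset.batch(32, drop_remainder=True)

    # Older releases spelled this via map_and_batch, as the error message in
    # Examples #6 and #7 suggests:
    # dataset = dataset.apply(
    #     tf.contrib.data.map_and_batch(map_fn, 32, drop_remainder=True))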