Example #1
            def rewrite_fn(*args):
                """The rewritten step fn running on TPU."""
                del args
                replicate_inputs = [[]] * self._num_replicas_in_sync
                replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

                # If run_fn has tensor outputs, tpu.replicate returns a list of
                # lists, which we flatten here. If run_fn has no tensor outputs,
                # tpu.replicate returns a list of no_ops, and we keep the output
                # as is.
                if isinstance(replicate_outputs[0], list):
                    replicate_outputs = nest.flatten(replicate_outputs)

                return replicate_outputs
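
The isinstance(replicate_outputs[0], list) check above tells the two return shapes of tpu.replicate apart: a list of per-replica output lists when run_fn returns tensors, or a flat list of no_ops when it returns nothing. Below is a minimal pure-Python sketch of that normalization, with strings standing in for tensors and ops; the helper flatten_outputs is illustrative, not part of TensorFlow.

def flatten_outputs(replicate_outputs):
    """Flatten per-replica tensor lists; leave a flat list of ops alone."""
    if replicate_outputs and isinstance(replicate_outputs[0], list):
        # One inner list of outputs per replica -> flatten across replicas.
        return [t for per_replica in replicate_outputs for t in per_replica]
    # No tensor outputs: tpu.replicate returned a list of no_ops.
    return replicate_outputs

# Two replicas, each producing two tensors (strings stand in for tensors).
print(flatten_outputs([["loss_0", "acc_0"], ["loss_1", "acc_1"]]))
# -> ['loss_0', 'acc_0', 'loss_1', 'acc_1']

# Two replicas with no tensor outputs (strings stand in for no_ops).
print(flatten_outputs(["no_op_0", "no_op_1"]))
# -> ['no_op_0', 'no_op_1']
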
Example #2
      def rewrite_fn(*args):
        """The rewritten step fn running on TPU."""
        del args
        replicate_inputs = [[]] * self._num_replicas_in_sync
        replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

        # If run_fn has tensor outputs, tpu.replicate returns a list of
        # lists, which we flatten here. If run_fn has no tensor outputs,
        # tpu.replicate returns a list of no_ops, and we keep the output
        # as is.
        if isinstance(replicate_outputs[0], list):
          replicate_outputs = nest.flatten(replicate_outputs)

        return replicate_outputs
Example #3
    def experimental_run(self, fn, input_iterator=None):
        """See base class."""
        if context.executing_eagerly():
            raise NotImplementedError(
                "Eager mode not supported in TPUStrategy.")

        if self.extended._disable_training_loop_on_host:  # pylint: disable=protected-access
            raise NotImplementedError(
                "`experimental_run` is not compatible with "
                "`_disable_training_loop_on_host=True`")

        if input_iterator is None:
            inputs = []
        else:
            inputs = input_iterator.get_next()

        result = [None]

        def replicated_fn(replica_id, inputs):
            """Wraps user function to provide replica ID and `Tensor` inputs."""
            with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
                if input_iterator is None:
                    result[0] = fn()
                else:
                    result[0] = fn(inputs)
            return result[0]

        replicate_inputs = []  # By replica.
        for i in range(self.num_replicas_in_sync):
            replicate_inputs.append([
                constant_op.constant(i, dtype=dtypes.int32),
                values.select_replica(i, inputs)
            ])

        with self.scope():
            replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

        # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
        replicate_outputs = [
            nest.pack_sequence_as(result[0], nest.flatten(replica_outputs))
            for replica_outputs in replicate_outputs
        ]

        device_map = self.extended._device_map  # pylint: disable=protected-access
        return values.regroup(device_map, replicate_outputs)
Example #4
  def experimental_run(self, fn, input_iterator=None):
    """See base class."""
    if context.executing_eagerly():
      raise NotImplementedError("Eager mode not supported in TPUStrategy.")

    if self.extended._disable_training_loop_on_host:  # pylint: disable=protected-access
      raise NotImplementedError(
          "`experimental_run` is not compatible with "
          "`_disable_training_loop_on_host=True`")

    if input_iterator is None:
      inputs = []
    else:
      inputs = input_iterator.get_next()

    result = [None]
    def replicated_fn(replica_id, inputs):
      """Wraps user function to provide replica ID and `Tensor` inputs."""
      with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
        if input_iterator is None:
          result[0] = fn()
        else:
          result[0] = fn(inputs)
      return result[0]

    replicate_inputs = []  # By replica.
    for i in range(self.num_replicas_in_sync):
      replicate_inputs.append(
          [constant_op.constant(i, dtype=dtypes.int32),
           values.select_replica(i, inputs)])

    with self.scope():
      replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

    # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
    replicate_outputs = [
        nest.pack_sequence_as(result[0], nest.flatten(replica_outputs))
        for replica_outputs in replicate_outputs]

    device_map = self.extended._device_map  # pylint: disable=protected-access
    return values.regroup(device_map, replicate_outputs)
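
As a side note on the pattern above: experimental_run selects each replica's slice of the inputs, runs the wrapped fn once per replica, and regroups the per-replica results. Below is a rough pure-Python sketch of that select-then-regroup flow; select_replica_value and regroup are illustrative stand-ins, not the values module used above.

def select_replica_value(replica_id, per_replica):
    # Stand-in for values.select_replica: pick this replica's component.
    return per_replica[replica_id]

def regroup(per_replica_results):
    # Stand-in for values.regroup: bundle per-replica results back together.
    return tuple(per_replica_results)

def run_replicated(fn, per_replica_inputs, num_replicas):
    results = []
    for replica_id in range(num_replicas):
        inputs = select_replica_value(replica_id, per_replica_inputs)
        results.append(fn(replica_id, inputs))
    return regroup(results)

# Two replicas, each summing its own shard of the batch.
print(run_replicated(lambda rid, x: sum(x), [[1, 2], [3, 4]], num_replicas=2))
# -> (3, 7)
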
Example #5
    def rewrite_fn(*args):
      """The rewritten step fn running on TPU."""
      del args

      per_replica_inputs = multi_worker_iterator.get_next()
      replicate_inputs = []
      for replica_id in range(self._num_replicas_in_sync):
        select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
        replicate_inputs.append((nest.map_structure(
            select_replica, per_replica_inputs),))

      replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

      # If run_fn has tensor outputs, tpu.replicate returns a list of
      # lists, which we flatten here. If run_fn has no tensor outputs,
      # tpu.replicate returns a list of no_ops, and we keep the output
      # as is.
      if isinstance(replicate_outputs[0], list):
        replicate_outputs = nest.flatten(replicate_outputs)

      return replicate_outputs
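
Example #5 differs from Examples #1 and #2 in that it pulls a nested per-replica structure from the iterator and maps select_replica over every leaf before replicating. Below is a rough sketch of that leaf-wise mapping over dicts and tuples; map_structure here is a simplified stand-in for nest.map_structure, and it deliberately treats lists as leaves so that a per-replica list reads as a single value (the real nest has richer semantics).

def map_structure(fn, structure):
    """Apply fn to every leaf of a nested dict/tuple structure."""
    if isinstance(structure, dict):
        return {k: map_structure(fn, v) for k, v in structure.items()}
    if isinstance(structure, tuple):
        return tuple(map_structure(fn, v) for v in structure)
    return fn(structure)  # leaf value

# Each leaf is a list with one entry per replica; pick replica 1's entry.
per_replica = {"images": ["img_r0", "img_r1"], "labels": ["y_r0", "y_r1"]}
print(map_structure(lambda leaf: leaf[1], per_replica))
# -> {'images': 'img_r1', 'labels': 'y_r1'}
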
Example #6
    def _run_steps_on_dataset(self,
                              fn,
                              iterator,
                              iterations,
                              initial_loop_values=None):

        shapes = nest.flatten(iterator.output_shapes)
        if any([not s.is_fully_defined() for s in shapes]):
            raise ValueError(
                'TPU currently requires fully defined shapes. Either use '
                'set_shape() on the input tensors or use '
                'dataset.batch(..., drop_remainder=True).')
        types = nest.flatten(iterator.output_types)

        enqueue_ops = [
            self._get_enqueue_op_per_host(host_id, iterator, shapes,
                                          iterations)
            for host_id in range(self.num_hosts)
        ]

        def dequeue_fn():
            dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types,
                                                    shapes=shapes)
            return nest.pack_sequence_as(iterator.output_shapes, dequeued)

        # Wrap `fn` for repeat.
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)
        ctx = values.MultiStepContext()

        def run_fn(*args, **kwargs):
            """Single step on the TPU device."""
            del args, kwargs
            fn_inputs = dequeue_fn()
            if not isinstance(fn_inputs, tuple):
                fn_inputs = (fn_inputs, )
            fn_result = fn(ctx, *fn_inputs)
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            if flat_last_step_outputs:
                with ops.control_dependencies([fn_result]):
                    return [
                        array_ops.identity(f) for f in flat_last_step_outputs
                    ]
            else:
                return fn_result

        # TODO(sourabhbajaj): The input to while loop should be based on the output
        # type of the step_fn
        def iterate_on_tpu():
            return training_loop.repeat(iterations, run_fn,
                                        initial_loop_values)

        # We capture the control_flow_context at this point, before we run `fn`
        # inside a while_loop and TPU replicate context. This is useful in cases
        # where we might need to exit these contexts and get back to the outer
        # context to do some things, for e.g. create an op which should be
        # evaluated only once at the end of the loop on the host. One such usage
        # is in creating metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        replicate_inputs = [[]] * self.num_towers
        replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
        del self._outer_control_flow_context
        ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

        # Filter out any ops from the outputs, typically this would be the case
        # when there were no tensor outputs.
        last_step_tensor_outputs = [
            x for x in replicate_outputs if not isinstance(x, ops.Operation)
        ]

        # Outputs are currently of the structure (grouped by device)
        # [[output0_device0, output1_device0, output2_device0],
        #  [output0_device1, output1_device1, output2_device1]]
        # Convert this to the following structure instead: (grouped by output)
        # [[output0_device0, output0_device1],
        #  [output1_device0, output1_device1],
        #  [output2_device0, output2_device1]]
        last_step_tensor_outputs = [
            list(x) for x in zip(*last_step_tensor_outputs)
        ]

        # Convert replicate_outputs to the original dict structure of
        # last_step_outputs.
        last_step_tensor_outputs_dict = nest.pack_sequence_as(
            ctx.last_step_outputs, last_step_tensor_outputs)

        for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
            output = last_step_tensor_outputs_dict[name]
            # For outputs that have already been aggregated, take the first value
            # from the list as each value should be the same. Else return the full
            # list of values.
            if aggregation is not variables_lib.VariableAggregation.NONE:
                # TODO(priyag): Should this return the element or a list with 1 element
                last_step_tensor_outputs_dict[name] = output[0]
        ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

        return ctx
Example #7
  def _run_steps_on_dataset(self, fn, iterator, iterations,
                            initial_loop_values=None):

    shapes = nest.flatten(iterator.output_shapes)
    if any([not s.is_fully_defined() for s in shapes]):
      raise ValueError(
          'TPU currently requires fully defined shapes. Either use '
          'set_shape() on the input tensors or use '
          'dataset.apply(map_and_batch(..., drop_remainder=True)).')
    types = nest.flatten(iterator.output_types)

    def enqueue_ops_fn():
      """Enqueue ops for one iteration."""
      control_deps = []
      sharded_inputs = []
      # TODO(sourabhbajaj): Add support for TPU pods
      with ops.device(self._host):
        for _ in range(self.num_towers):
          # Use control dependencies to ensure a deterministic ordering.
          with ops.control_dependencies(control_deps):
            inputs = nest.flatten(iterator.get_next())
            control_deps.extend(inputs)
            sharded_inputs.append(inputs)

      enqueue_ops = []
      for core_id, shard_input in enumerate(sharded_inputs):
        enqueue_ops.append(
            tpu_ops.infeed_enqueue_tuple(
                inputs=shard_input, shapes=shapes, device_ordinal=core_id))
      return enqueue_ops

    def enqueue_ops_loop_body(i):
      with ops.control_dependencies(enqueue_ops_fn()):
        return i + 1

    with ops.device(self._host):
      enqueue_ops = control_flow_ops.while_loop(
          lambda i: i < iterations,
          enqueue_ops_loop_body,
          [constant_op.constant(0)],
          parallel_iterations=1)

    def dequeue_fn():
      dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
      return nest.pack_sequence_as(iterator.output_shapes, dequeued)

    # Wrap `fn` for repeat.
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)
    ctx = values.MultiStepContext()
    def run_fn(*args, **kwargs):
      del args, kwargs
      fn_inputs = dequeue_fn()
      if not isinstance(fn_inputs, tuple):
        fn_inputs = (fn_inputs,)
      fn_result = fn(ctx, *fn_inputs)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      if flat_last_step_outputs:
        with ops.control_dependencies([fn_result]):
          return [array_ops.identity(f) for f in flat_last_step_outputs]
      else:
        return fn_result

    # TODO(sourabhbajaj): The input to while loop should be based on the output
    # type of the step_fn
    def iterate_on_tpu():
      return training_loop.repeat(iterations, run_fn, initial_loop_values)

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop and TPU replicate context. This is useful in cases
    # where we might need to exit these contexts and get back to the outer
    # context to do some things, for e.g. create an op which should be
    # evaluated only once at the end of the loop on the host. One such usage
    # is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    replicate_inputs = [[]] * self.num_towers
    replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
    del self._outer_control_flow_context
    ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

    # Filter out any ops from the outputs, typically this would be the case
    # when there were no tensor outputs.
    last_step_tensor_outputs = [x for x in replicate_outputs
                                if not isinstance(x, ops.Operation)]

    # Outputs are currently of the structure (grouped by device)
    # [[output0_device0, output1_device0, output2_device0],
    #  [output0_device1, output1_device1, output2_device1]]
    # Convert this to the following structure instead: (grouped by output)
    # [[output0_device0, output0_device1],
    #  [output1_device0, output1_device1],
    #  [output2_device0, output2_device1]]
    last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)]

    # Convert replicate_outputs to the original dict structure of
    # last_step_outputs.
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been aggregated, take the first value
      # from the list as each value should be the same. Else return the full
      # list of values.
      if aggregation is not variables_lib.VariableAggregation.NONE:
        # TODO(priyag): Should this return the element or a list with 1 element
        last_step_tensor_outputs_dict[name] = output[0]
    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

    return ctx
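
The zip(*last_step_tensor_outputs) step in the two examples above converts replica-major output lists into output-major lists, exactly as the inline comment describes. A tiny standalone illustration of that transposition:

# Grouped by device: one inner list with every output of that replica.
by_device = [["out0_d0", "out1_d0", "out2_d0"],
             ["out0_d1", "out1_d1", "out2_d1"]]

# Same transposition as in _run_steps_on_dataset: group by output instead.
by_output = [list(x) for x in zip(*by_device)]
print(by_output)
# -> [['out0_d0', 'out0_d1'], ['out1_d0', 'out1_d1'], ['out2_d0', 'out2_d1']]
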
Example #8
    def _experimental_run_steps_on_iterator(self,
                                            fn,
                                            multi_worker_iterator,
                                            iterations,
                                            initial_loop_values=None):
        output_shapes = multi_worker_iterator.output_shapes
        shapes = nest.flatten(output_shapes)
        if any(not s.is_fully_defined() for s in shapes):
            raise ValueError(
                "TPU currently requires fully defined shapes. Either use "
                "set_shape() on the input tensors or use "
                "dataset.batch(..., drop_remainder=True).")
        types = nest.flatten(multi_worker_iterator.output_types)

        enqueue_ops = [
            self._get_enqueue_op_per_host(host_id, multi_worker_iterator,
                                          shapes, iterations)
            for host_id in range(self.num_hosts)
        ]

        def dequeue_fn():
            dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types,
                                                    shapes=shapes)
            return nest.pack_sequence_as(output_shapes, dequeued)

        # Wrap `fn` for repeat.
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)
        ctx = input_lib.MultiStepContext()

        def run_fn(*args, **kwargs):
            """Single step on the TPU device."""
            del args, kwargs
            fn_result = fn(ctx, dequeue_fn())
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            if flat_last_step_outputs:
                with ops.control_dependencies([fn_result]):
                    return [
                        array_ops.identity(f) for f in flat_last_step_outputs
                    ]
            else:
                return fn_result

        def iterate_on_tpu():
            return training_loop.repeat(iterations, run_fn,
                                        initial_loop_values)

        # We capture the control_flow_context at this point, before we run `fn`
        # inside a while_loop and TPU replicate context. This is useful in cases
        # where we might need to exit these contexts and get back to the outer
        # context to do some things, for e.g. create an op which should be
        # evaluated only once at the end of the loop on the host. One such usage
        # is in creating metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        # pylint: disable=protected-access
        if self._container_strategy()._disable_training_loop_on_host:
            replicate_inputs = [[]] * self._num_replicas_in_sync
            replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
        else:

            def rewrite_fn(*args):
                """The rewritten step fn running on TPU."""
                del args
                replicate_inputs = [[]] * self._num_replicas_in_sync
                replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

                # If run_fn has tensor outputs, tpu.replicate returns a list of
                # lists, which we flatten here. If run_fn has no tensor outputs,
                # tpu.replicate returns a list of no_ops, and we keep the output
                # as is.
                if isinstance(replicate_outputs[0], list):
                    replicate_outputs = nest.flatten(replicate_outputs)

                return replicate_outputs

            # TODO(sourabhbajaj): The input to while loop should be based on the
            # output type of the step_fn
            assert isinstance(initial_loop_values, list)
            initial_loop_values = initial_loop_values * self._num_replicas_in_sync

            # Put the while loop op on host 0.
            with ops.device(self.get_host_cpu_device(0)):
                replicate_outputs = training_loop.repeat(
                    iterations, rewrite_fn, initial_loop_values)

        del self._outer_control_flow_context
        ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

        if self._container_strategy()._disable_training_loop_on_host:
            # Filter out any ops from the outputs, typically this would be the case
            # when there were no tensor outputs.
            last_step_tensor_outputs = [
                x for x in replicate_outputs
                if not isinstance(x, ops.Operation)
            ]

            # Outputs are currently of the structure (grouped by device)
            # [[output0_device0, output1_device0, output2_device0],
            #  [output0_device1, output1_device1, output2_device1]]
            # Convert this to the following structure instead: (grouped by output)
            # [[output0_device0, output0_device1],
            #  [output1_device0, output1_device1],
            #  [output2_device0, output2_device1]]
            last_step_tensor_outputs = [
                list(x) for x in zip(*last_step_tensor_outputs)
            ]
        else:
            if isinstance(replicate_outputs, list):
                # Filter out any ops from the outputs, typically this would be the case
                # when there were no tensor outputs.
                last_step_tensor_outputs = [
                    x for x in replicate_outputs
                    if not isinstance(x, ops.Operation)
                ]

                # Outputs are currently of the structure (flattened)
                # [output0_device0, output1_device0, output2_device0,
                #  output0_device1, output1_device1, output2_device1,
                #  ...]
                # Convert this to the following structure instead: (grouped by output)
                # [[output0_device0, output0_device1],
                #  [output1_device0, output1_device1],
                #  [output2_device0, output2_device1]]
                output_num = len(
                    last_step_tensor_outputs) // self._num_replicas_in_sync
                last_step_tensor_outputs = [
                    last_step_tensor_outputs[i::output_num]
                    for i in range(output_num)
                ]
            else:
                # no tensors returned.
                last_step_tensor_outputs = []

        # Convert replicate_outputs to the original dict structure of
        # last_step_outputs.
        last_step_tensor_outputs_dict = nest.pack_sequence_as(
            ctx.last_step_outputs, last_step_tensor_outputs)

        for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
            output = last_step_tensor_outputs_dict[name]
            # For outputs that have already been reduced, take the first value
            # from the list as each value should be the same. Else return the full
            # list of values.
            # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
            # value.
            if reduce_op is not None:
                # TODO(priyag): Should this return the element or a list with 1 element
                last_step_tensor_outputs_dict[name] = output[0]
        ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

        return ctx
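
In the host-loop branch of Example #8 the per-replica outputs come back as one flat list rather than one list per replica, so they are regrouped by stride slicing instead of zip. A quick standalone check of that slicing with two replicas and three outputs each:

# Flattened: all of replica 0's outputs, then all of replica 1's.
flat = ["out0_d0", "out1_d0", "out2_d0", "out0_d1", "out1_d1", "out2_d1"]
num_replicas_in_sync = 2
output_num = len(flat) // num_replicas_in_sync  # outputs per replica = 3

# Same stride slicing as in _experimental_run_steps_on_iterator.
by_output = [flat[i::output_num] for i in range(output_num)]
print(by_output)
# -> [['out0_d0', 'out0_d1'], ['out1_d0', 'out1_d1'], ['out2_d0', 'out2_d1']]
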
Example #9
    def _run_steps_on_dataset(self,
                              fn,
                              iterator,
                              iterations,
                              initial_loop_values=None):

        shapes = nest.flatten(iterator.output_shapes)
        if any([not s.is_fully_defined() for s in shapes]):
            raise ValueError(
                'TPU currently requires fully defined shapes. Either use '
                'set_shape() on the input tensors or use '
                'dataset.apply(map_and_batch(..., drop_remainder=True)).')
        types = nest.flatten(iterator.output_types)

        def enqueue_ops_fn():
            """Enqueue ops for one iteration."""
            control_deps = []
            sharded_inputs = []
            with ops.device(self._host):
                for _ in range(self.num_towers):
                    # Use control dependencies to ensure a deterministic ordering.
                    with ops.control_dependencies(control_deps):
                        inputs = nest.flatten(iterator.get_next())
                        control_deps.extend(inputs)
                        sharded_inputs.append(inputs)

            enqueue_ops = []
            for core_id, shard_input in enumerate(sharded_inputs):
                enqueue_ops.append(
                    tpu_ops.infeed_enqueue_tuple(inputs=shard_input,
                                                 shapes=shapes,
                                                 device_ordinal=core_id))
            return enqueue_ops

        def enqueue_ops_loop_body(i):
            with ops.control_dependencies(enqueue_ops_fn()):
                return i + 1

        with ops.device(self._host):
            enqueue_ops = control_flow_ops.while_loop(
                lambda i: i < iterations,
                enqueue_ops_loop_body, [constant_op.constant(0)],
                parallel_iterations=1)

        def dequeue_fn():
            dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types,
                                                    shapes=shapes)
            return nest.pack_sequence_as(iterator.output_shapes, dequeued)

        # Wrap `fn` for repeat.
        if initial_loop_values is None:
            initial_loop_values = []
        ctx = values.MultiStepContext(initial_loop_values)

        def run_fn(*args, **kwargs):
            del args, kwargs
            fn_result = fn(ctx, dequeue_fn())
            if ctx.last_step_outputs is None:
                ctx.last_step_outputs = []
            with ops.control_dependencies([fn_result]):
                return array_ops.identity(ctx.last_step_outputs)

        # TODO(sourabhbajaj): The input to while loop should be based on the output
        # type of the step_fn
        def iterate_on_tpu():
            return training_loop.repeat(iterations, run_fn,
                                        [initial_loop_values])

        replicate_inputs = [[]] * self.num_towers
        outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
        last_step_tensor_outputs = [list(x) for x in zip(*outputs)]

        # Take index [0] of last_step_tensor_outputs as we wrapped
        # initial_loop_values in a list in the `repeat` call.
        return (control_flow_ops.group(last_step_tensor_outputs, enqueue_ops),
                last_step_tensor_outputs[0], ctx)
Example #10
    def _run_steps_on_iterator_with_device_loop(self,
                                                fn,
                                                multi_worker_iterator,
                                                iterations,
                                                initial_loop_values=None):
        output_shapes = multi_worker_iterator.output_shapes
        shapes = nest.flatten(output_shapes)
        if any(not s.is_fully_defined() for s in shapes):
            raise ValueError(
                "TPU currently requires fully defined shapes. Either use "
                "set_shape() on the input tensors or use "
                "dataset.batch(..., drop_remainder=True).")
        types = nest.flatten(multi_worker_iterator.output_types)

        enqueue_ops = [
            self._get_enqueue_op_per_host(host_id, multi_worker_iterator,
                                          shapes, iterations)
            for host_id in range(self.num_hosts)
        ]

        def dequeue_fn():
            dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types,
                                                    shapes=shapes)
            return nest.pack_sequence_as(output_shapes, dequeued)

        # Wrap `fn` for repeat.
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)
        ctx = input_lib.MultiStepContext()

        def run_fn(*args, **kwargs):
            """Single step on the TPU device."""
            del args, kwargs
            fn_result = fn(ctx, dequeue_fn())
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            if flat_last_step_outputs:
                with ops.control_dependencies([fn_result]):
                    return [
                        array_ops.identity(f) for f in flat_last_step_outputs
                    ]
            else:
                return fn_result

        def iterate_on_tpu():
            return training_loop.repeat(iterations, run_fn,
                                        initial_loop_values)

        # We capture the control_flow_context at this point, before we run `fn`
        # inside a while_loop and TPU replicate context. This is useful in cases
        # where we might need to exit these contexts and get back to the outer
        # context to do some things, for e.g. create an op which should be
        # evaluated only once at the end of the loop on the host. One such usage
        # is in creating metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        replicate_inputs = [[]] * self._num_replicas_in_sync
        replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)

        del self._outer_control_flow_context
        ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

        # Filter out any ops from the outputs, typically this would be the case
        # when there were no tensor outputs.
        last_step_tensor_outputs = [
            x for x in replicate_outputs if not isinstance(x, ops.Operation)
        ]

        # Outputs are currently of the structure (grouped by device)
        # [[output0_device0, output1_device0, output2_device0],
        #  [output0_device1, output1_device1, output2_device1]]
        # Convert this to the following structure instead: (grouped by output)
        # [[output0_device0, output0_device1],
        #  [output1_device0, output1_device1],
        #  [output2_device0, output2_device1]]
        last_step_tensor_outputs = [
            list(x) for x in zip(*last_step_tensor_outputs)
        ]

        _set_last_step_outputs(ctx, last_step_tensor_outputs)
        return ctx
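
Several of the examples here finish by packing the output-major lists back into the dict structure of ctx.last_step_outputs via nest.pack_sequence_as. Below is a simplified stand-in for that step on a one-level dict template; the real nest handles arbitrary nesting, but like this sketch it matches values to dict keys in sorted-key order.

def pack_flat_dict(template, flat_values):
    # Simplified stand-in for nest.pack_sequence_as on a flat dict template.
    keys = sorted(template)
    assert len(keys) == len(flat_values)
    return dict(zip(keys, flat_values))

template = {"loss": None, "accuracy": None}   # shape of ctx.last_step_outputs
by_output = [["acc_d0", "acc_d1"], ["loss_d0", "loss_d1"]]  # sorted-key order
print(pack_flat_dict(template, by_output))
# -> {'accuracy': ['acc_d0', 'acc_d1'], 'loss': ['loss_d0', 'loss_d1']}
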
Example #11
  def _run_steps_on_dataset(self, fn, iterator, iterations,
                            initial_loop_values=None):

    shapes = nest.flatten(iterator.output_shapes)
    if any([not s.is_fully_defined() for s in shapes]):
      raise ValueError(
          'TPU currently requires fully defined shapes. Either use '
          'set_shape() on the input tensors or use '
          'dataset.apply(map_and_batch(..., drop_remainder=True)).')
    types = nest.flatten(iterator.output_types)

    def enqueue_ops_fn():
      """Enqueue ops for one iteration."""
      control_deps = []
      sharded_inputs = []
      with ops.device(self._host):
        for _ in range(self.num_towers):
          # Use control dependencies to ensure a deterministic ordering.
          with ops.control_dependencies(control_deps):
            inputs = nest.flatten(iterator.get_next())
            control_deps.extend(inputs)
            sharded_inputs.append(inputs)

      enqueue_ops = []
      for core_id, shard_input in enumerate(sharded_inputs):
        enqueue_ops.append(
            tpu_ops.infeed_enqueue_tuple(
                inputs=shard_input, shapes=shapes, device_ordinal=core_id))
      return enqueue_ops

    def enqueue_ops_loop_body(i):
      with ops.control_dependencies(enqueue_ops_fn()):
        return i + 1

    with ops.device(self._host):
      enqueue_ops = control_flow_ops.while_loop(
          lambda i: i < iterations,
          enqueue_ops_loop_body,
          [constant_op.constant(0)],
          parallel_iterations=1)

    def dequeue_fn():
      dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
      return nest.pack_sequence_as(iterator.output_shapes, dequeued)

    # Wrap `fn` for repeat.
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)
    ctx = values.MultiStepContext()
    def run_fn(*args, **kwargs):
      del args, kwargs
      fn_result = fn(ctx, dequeue_fn())
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      if flat_last_step_outputs:
        with ops.control_dependencies([fn_result]):
          return [array_ops.identity(f) for f in flat_last_step_outputs]
      else:
        return fn_result

    # TODO(sourabhbajaj): The input to while loop should be based on the output
    # type of the step_fn
    def iterate_on_tpu():
      return training_loop.repeat(iterations, run_fn, initial_loop_values)

    replicate_inputs = [[]] * self.num_towers
    replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
    ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

    # Filter out any ops from the outputs, typically this would be the case
    # when there were no tensor outputs.
    last_step_tensor_outputs = [x for x in replicate_outputs
                                if not isinstance(x, ops.Operation)]

    # Outputs are currently of the structure (grouped by device)
    # [[output0_device0, output1_device0, output2_device0],
    #  [output0_device1, output1_device1, output2_device1]]
    # Convert this to the following structure instead: (grouped by output)
    # [[output0_device0, output0_device1],
    #  [output1_device0, output1_device1],
    #  [output2_device0, output2_device1]]
    last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)]

    # Convert replicate_outputs to the original dict structure of
    # last_step_outputs.
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been aggregated, take the first value
      # from the list as each value should be the same. Else return the full
      # list of values.
      if aggregation is not variables_lib.VariableAggregation.NONE:
        # TODO(priyag): Should this return the element or a list with 1 element
        last_step_tensor_outputs_dict[name] = output[0]
    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

    return ctx
Example #12
  def _experimental_run_steps_on_iterator(
      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
    output_shapes = multi_worker_iterator.output_shapes
    shapes = nest.flatten(output_shapes)
    if any([not s.is_fully_defined() for s in shapes]):
      raise ValueError(
          "TPU currently requires fully defined shapes. Either use "
          "set_shape() on the input tensors or use "
          "dataset.batch(..., drop_remainder=True).")
    types = nest.flatten(multi_worker_iterator.output_types)

    enqueue_ops = [
        self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes,
                                      iterations)
        for host_id in range(self.num_hosts)]

    def dequeue_fn():
      dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
      return nest.pack_sequence_as(output_shapes, dequeued)

    # Wrap `fn` for repeat.
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)
    ctx = values.MultiStepContext()
    def run_fn(*args, **kwargs):
      """Single step on the TPU device."""
      del args, kwargs
      fn_inputs = dequeue_fn()
      if not isinstance(fn_inputs, tuple):
        fn_inputs = (fn_inputs,)
      fn_result = fn(ctx, fn_inputs)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      if flat_last_step_outputs:
        with ops.control_dependencies([fn_result]):
          return [array_ops.identity(f) for f in flat_last_step_outputs]
      else:
        return fn_result

    # TODO(sourabhbajaj): The input to while loop should be based on the output
    # type of the step_fn
    def iterate_on_tpu():
      return training_loop.repeat(iterations, run_fn, initial_loop_values)

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop and TPU replicate context. This is useful in cases
    # where we might need to exit these contexts and get back to the outer
    # context to do some things, for e.g. create an op which should be
    # evaluated only once at the end of the loop on the host. One such usage
    # is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    replicate_inputs = [[]] * self._num_replicas_in_sync
    replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
    del self._outer_control_flow_context
    ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

    # Filter out any ops from the outputs, typically this would be the case
    # when there were no tensor outputs.
    last_step_tensor_outputs = [x for x in replicate_outputs
                                if not isinstance(x, ops.Operation)]

    # Outputs are currently of the structure (grouped by device)
    # [[output0_device0, output1_device0, output2_device0],
    #  [output0_device1, output1_device1, output2_device1]]
    # Convert this to the following structure instead: (grouped by output)
    # [[output0_device0, output0_device1],
    #  [output1_device0, output1_device1],
    #  [output2_device0, output2_device1]]
    last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)]

    # Convert replicate_outputs to the original dict structure of
    # last_step_outputs.
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been reduced, take the first value
      # from the list as each value should be the same. Else return the full
      # list of values.
      # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
      # value.
      if reduce_op is not None:
        # TODO(priyag): Should this return the element or a list with 1 element
        last_step_tensor_outputs_dict[name] = output[0]
    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

    return ctx
Example #13
  def _run_steps_on_iterator_with_device_loop(
      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
    output_shapes = multi_worker_iterator.output_shapes
    shapes = nest.flatten(output_shapes)
    if any(not s.is_fully_defined() for s in shapes):
      raise ValueError(
          "TPU currently requires fully defined shapes. Either use "
          "set_shape() on the input tensors or use "
          "dataset.batch(..., drop_remainder=True).")
    types = nest.flatten(multi_worker_iterator.output_types)

    enqueue_ops = [
        self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes,
                                      iterations)
        for host_id in range(self.num_hosts)]

    def dequeue_fn():
      dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
      return nest.pack_sequence_as(output_shapes, dequeued)

    # Wrap `fn` for repeat.
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)
    ctx = input_lib.MultiStepContext()

    def run_fn(*args, **kwargs):
      """Single step on the TPU device."""
      del args, kwargs
      fn_result = fn(ctx, dequeue_fn())
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      if flat_last_step_outputs:
        with ops.control_dependencies([fn_result]):
          return [array_ops.identity(f) for f in flat_last_step_outputs]
      else:
        return fn_result

    def iterate_on_tpu():
      return training_loop.repeat(iterations, run_fn, initial_loop_values)

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop and TPU replicate context. This is useful in cases
    # where we might need to exit these contexts and get back to the outer
    # context to do some things, for e.g. create an op which should be
    # evaluated only once at the end of the loop on the host. One such usage
    # is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    replicate_inputs = [[]] * self._num_replicas_in_sync
    replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)

    del self._outer_control_flow_context
    ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

    # Filter out any ops from the outputs, typically this would be the case
    # when there were no tensor outputs.
    last_step_tensor_outputs = [x for x in replicate_outputs
                                if not isinstance(x, ops.Operation)]

    # Outputs are currently of the structure (grouped by device)
    # [[output0_device0, output1_device0, output2_device0],
    #  [output0_device1, output1_device1, output2_device1]]
    # Convert this to the following structure instead: (grouped by output)
    # [[output0_device0, output0_device1],
    #  [output1_device0, output1_device1],
    #  [output2_device0, output2_device1]]
    last_step_tensor_outputs = [list(x) for x in
                                zip(*last_step_tensor_outputs)]

    _set_last_step_outputs(ctx, last_step_tensor_outputs)
    return ctx
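
Across these examples, outputs that the step function already reduced are collapsed to a single representative value (every replica holds the same number), while unreduced outputs keep the full per-replica list. A compact standalone sketch of that post-processing; the dict contents and the reduce_ops mapping below are made up for illustration.

last_step_outputs = {
    "loss": [0.25, 0.25],              # already reduced: identical per replica
    "predictions": [[1, 0], [0, 1]],   # not reduced: keep per-replica values
}
reduce_ops = {"loss": "MEAN", "predictions": None}

for name, reduce_op in reduce_ops.items():
    if reduce_op is not None:
        # Every replica holds the same reduced value; keep only the first.
        last_step_outputs[name] = last_step_outputs[name][0]

print(last_step_outputs)
# -> {'loss': 0.25, 'predictions': [[1, 0], [0, 1]]}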