Example #1
  def testVarArgsAndDefaults(self):
    """Tests that arg checker works for a function with varargs and defaults."""

    def func(x, y, z=17, *q):
      return x + y + z + len(q)

    self.assertEqual(None,
                     tpu_function.check_function_argument_count(func, 2, None))
    self.assertEqual(None,
                     tpu_function.check_function_argument_count(func, 3, None))
    self.assertEqual(None,
                     tpu_function.check_function_argument_count(func, 4, None))
    self.assertEqual(None,
                     tpu_function.check_function_argument_count(func, 5, None))
    self.assertEqual("at least 2 arguments",
                     tpu_function.check_function_argument_count(func, 1, None))
    queue = tpu_feed.InfeedQueue(1)
    self.assertEqual(None,
                     tpu_function.check_function_argument_count(func, 1, queue))
    self.assertEqual(None,
                     tpu_function.check_function_argument_count(func, 2, queue))
    self.assertEqual(None,
                     tpu_function.check_function_argument_count(func, 3, queue))
    self.assertEqual(None,
                     tpu_function.check_function_argument_count(func, 4, queue))
    self.assertEqual("at least 2 arguments",
                     tpu_function.check_function_argument_count(func, 0, queue))
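The tests above pin down the contract of tpu_function.check_function_argument_count: it returns None when the computation can be called with the given number of explicit inputs plus the infeed queue's tuple elements, and otherwise an error fragment such as "at least 2 arguments", "at most 3 arguments", or "exactly 3 arguments". As a rough illustration only (this is not TensorFlow's actual implementation and ignores cases the tests above do not exercise), a minimal sketch of such a check using the standard inspect module:

import inspect

def check_argument_count_sketch(func, input_arity, infeed_queue=None):
    """Returns None if `func` accepts the supplied arguments, else an error fragment."""
    # Explicit inputs plus one argument per infeed tuple element, if any.
    supplied = input_arity
    if infeed_queue is not None:
        supplied += infeed_queue.number_of_tuple_elements

    spec = inspect.getfullargspec(func)
    num_params = len(spec.args)
    min_required = num_params - len(spec.defaults or ())
    has_varargs = spec.varargs is not None

    if min_required == num_params and not has_varargs:
        # Fixed arity: the caller must match it exactly.
        if supplied != num_params:
            return "exactly %d arguments" % num_params
    elif supplied < min_required:
        return "at least %d arguments" % min_required
    elif not has_varargs and supplied > num_params:
        return "at most %d arguments" % num_params
    return None

def func(x, y, z=17):
    return x + y + z

assert check_argument_count_sketch(func, 2) is None
assert check_argument_count_sketch(func, 1) == "at least 2 arguments"
assert check_argument_count_sketch(func, 4) == "at most 3 arguments"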
Example #4
  def testSimple(self):
    """Tests that arg checker works for functions with no varargs or defaults.
    """

    def func(x, y, z):
      return x + y + z

    self.assertEqual(None,
                     tpu_function.check_function_argument_count(func, 3, None))
    self.assertEqual("exactly 3 arguments",
                     tpu_function.check_function_argument_count(func, 2, None))
    queue = tpu_feed.InfeedQueue(2)
    self.assertEqual(None,
                     tpu_function.check_function_argument_count(func, 1, queue))
    self.assertEqual("exactly 3 arguments",
                     tpu_function.check_function_argument_count(func, 2, queue))
Example #6
  def testDefaultArgs(self):
    """Tests that arg checker works for a function with no varargs."""

    def func(x, y, z=17):
      return x + y + z

    self.assertEqual(None,
                     tpu_function.check_function_argument_count(func, 3, None))
    self.assertEqual(None,
                     tpu_function.check_function_argument_count(func, 2, None))
    self.assertEqual("at least 2 arguments",
                     tpu_function.check_function_argument_count(func, 1, None))
    self.assertEqual("at most 3 arguments",
                     tpu_function.check_function_argument_count(func, 4, None))
    queue = tpu_feed.InfeedQueue(1)
    self.assertEqual(None,
                     tpu_function.check_function_argument_count(func, 2, queue))
    self.assertEqual(None,
                     tpu_function.check_function_argument_count(func, 1, queue))
    self.assertEqual("at least 2 arguments",
                     tpu_function.check_function_argument_count(func, 0, queue))
    self.assertEqual("at most 3 arguments",
                     tpu_function.check_function_argument_count(func, 4, queue))
Example #7
def replicate(computation,
              inputs=None,
              infeed_queue=None,
              device_assignment=None,
              name=None):
    """Builds a graph operator that runs a replicated TPU computation.

  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: The name of the operator.
  Returns:
    A list of lists of output tensors, indexed by `[replica_num][output_num]`.
  Raises:
    ValueError: If all replicas do not have equal numbers of input tensors.
    ValueError: If the number of inputs per replica does not match
      the number of formal parameters to `computation`.
  """
    if name is None:
        name = "TPUReplicate"
    inputs = [[]] if inputs is None else inputs

    metadata_kwargs = {}
    if device_assignment is not None:
        # Turn the Numpy array into a flattened list so we can pass it as an
        # operator attribute.
        metadata_kwargs = {
            "topology":
            device_assignment.topology.serialized(),
            "device_assignment":
            device_assignment.core_assignment.flatten().tolist(),
            "computation_shape":
            device_assignment.computation_shape.tolist()
        }

    if ((not isinstance(inputs, list))
            or any(not isinstance(inp, (list, tuple)) for inp in inputs)):
        raise TypeError(
            "tpu.replicate() inputs must be a list of lists/tuples")

    num_replicas = len(inputs)

    # No replicas? Nothing to do.
    if num_replicas == 0:
        return []

    # Converts inputs to Tensors.
    inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]

    # Verifies that all replicas have matching numbers and types of inputs
    input_types = [x.dtype for x in inputs[0]]
    input_arity = len(input_types)
    for i in range(num_replicas):
        if len(inputs[i]) != input_arity:
            raise ValueError("Replicas must have the same number of inputs. "
                             "Replica 0 had {} inputs, replica {} had {} "
                             "inputs.".format(input_arity, i, len(inputs[i])))

        types = [x.dtype for x in inputs[i]]
        if types != input_types:
            raise ValueError(
                "Replicas must have matching input types. Replica 0 had "
                "input types {}, replica {} had input types {}".format(
                    input_types, i, types))

    arg_error = tpu_function.check_function_argument_count(
        computation, input_arity, infeed_queue)
    if arg_error is not None:
        if infeed_queue is None:
            raise TypeError(
                "Supplied computation cannot be called with the specified inputs. "
                "You specified %d inputs: %s, but the computation needs %s" %
                (input_arity, str([i.name for i in inputs[0]]), arg_error))
        else:
            raise TypeError(
                "Supplied computation cannot be called with the specified inputs. "
                "You specified %d inputs: %s and %d additional inputs from infeed,"
                " but the computation needs %s" %
                (input_arity, str([i.name for i in inputs[0]]),
                 infeed_queue.number_of_tuple_elements, arg_error))

    graph = ops.get_default_graph()

    with ops.name_scope(name, "replicate"):
        # Fan-in: Builds a TPUReplicatedInput node for each input.
        computation_inputs = []
        for i in range(0, input_arity):
            replicas = [inputs[replica][i] for replica in xrange(num_replicas)]
            computation_inputs.append(
                tpu_ops.tpu_replicated_input(replicas,
                                             name="input{}".format(i)))

        context = TPUReplicateContext(name=graph.unique_name("cluster"))
        try:
            context.Enter()

            metadata = tpu_ops.tpu_replicate_metadata(
                num_replicas=num_replicas, **metadata_kwargs)

            with tpu_function.tpu_shard_context(
                    num_replicas), ops.control_dependencies([metadata]):

                # The EncapsulateTPUComputations rewrite needs to identify the
                # replicated arguments inside each computation. Adds identity operators
                # tagged with an attribute _tpu_replicated_input to identify the
                # replicated inputs.
                # pylint: disable=protected-access
                with graph._attr_scope({
                        "_tpu_replicated_input":
                        attr_value_pb2.AttrValue(b=True)
                }):
                    computation_inputs = [
                        array_ops.identity(
                            x, name="replicated_input_{}".format(i))
                        for i, x in enumerate(computation_inputs)
                    ]
                # pylint: enable=protected-access

                # If there is an infeed queue, adds the dequeued values to the
                # computation's inputs.
                if infeed_queue is not None:
                    infeed_queue.set_number_of_shards(num_replicas)
                    for t in infeed_queue.generate_dequeue_op():
                        computation_inputs.append(t)

                # Only resource variables work inside a TPU computation, so turn on
                # resource variables for the computation.
                # TODO(phawkins): consider removing this code. It will
                # be less confusing to clients if they knowingly choose to use resource
                # variables.
                vscope = variable_scope.get_variable_scope()
                saved_use_resource = vscope.use_resource
                vscope.set_use_resource(True)

                outputs = computation(*computation_inputs)

                vscope.set_use_resource(saved_use_resource)

            # If the computation only returned one value, makes it a tuple.
            if not isinstance(outputs, (list, tuple)):
                outputs = (outputs, )

            try:
                with ops.device(core(0)):
                    outputs = [
                        o if isinstance(o, ops.Operation) else
                        ops.convert_to_tensor(o) for o in outputs
                    ]
            except Exception as e:
                raise ValueError(
                    "TPU function return values must all either be Operations or "
                    "convertible to Tensors. Got '%s'" % str(e))

            # Separates the returned Operations and Tensors.
            output_operations = [
                o for o in outputs if isinstance(o, ops.Operation)
            ]
            output_tensors = [
                o for o in outputs if not isinstance(o, ops.Operation)
            ]

            if outputs != output_tensors + output_operations:
                raise ValueError(
                    "TPU functions must return zero-or more Tensor values followed by "
                    "zero or more Operations.")
            output_arity = len(output_tensors)

            # Wraps outputs in Identity ops. Otherwise a replicated input copied
            # straight to an output would bypass the replicate(). This would be bad
            # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
            # be rewritten away, leading to a runtime error.
            # TODO(phawkins): extend the rewrite to elide these nodes instead.
            new_output_tensors = []
            for t in output_tensors:
                with ops.device(t.device if t.device else core(0)):
                    new_output_tensors.append(array_ops.identity(t))
            output_tensors = new_output_tensors
        finally:
            context.report_unsupported_operations()
            context.Exit()

        # Fan-out: Builds a TPUReplicatedOutput node for each output.
        outputs = [
            tpu_ops.tpu_replicated_output(output_tensors[i],
                                          num_replicas,
                                          name="output{}".format(i))
            for i in xrange(output_arity)
        ]

        with ops.control_dependencies(output_operations):
            if output_arity == 0:
                # Returns a list of NoOps dependent on the replication Op, indexed by
                # [replica_num].
                return [
                    control_flow_ops.no_op(name="%s_shard_%d" % (name, i))
                    for i in range(num_replicas)
                ]
            else:
                # Wraps the outputs in identity operators so the names of any possible
                # `fetch` nodes are preserved by the replication rewrite.
                return [[
                    array_ops.identity(outputs[out][replica],
                                       name="output_%d_shard_%d" %
                                       (out, replica))
                    for out in xrange(output_arity)
                ] for replica in xrange(num_replicas)]
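As the docstring above states, replicate takes inputs indexed by [replica_num][input_num] and returns outputs indexed the same way. A minimal usage sketch under assumed conditions (TF 1.x graph mode, an initialized TPU system, and the replicate function above in scope, e.g. as tpu.replicate); the computation and tensors here are purely illustrative:

import tensorflow.compat.v1 as tf  # assumes graph mode, not eager execution

def computation(x, y):
    # Runs once per replica, on that replica's pair of inputs.
    return x + y

# inputs[replica_num][input_num]: two replicas, two inputs each.
replica_inputs = [[tf.constant(1.0), tf.constant(2.0)],
                  [tf.constant(3.0), tf.constant(4.0)]]

# outputs[replica_num][output_num]: one output tensor per replica.
outputs = replicate(computation, inputs=replica_inputs)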
Example #8
def split_compile_and_replicate(computation,
                                inputs=None,
                                infeed_queue=None,
                                device_assignment=None,
                                name=None,
                                use_tpu=True):
    """Builds graph operators that runs compilation and replicated computation.

  This is a lower level interface than replicate that returns a separate compile
  and execute output tensor. In the generated graph the compile op feeds into
  the execute op and no additional compilation is incurred when running the
  compile op before the execute op. The compile op returns additional
  information about the compilation but does not return the compiled program.

  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
      backends. Currently, only supports a default placement (computation is
      placed on GPU if one is available, and on CPU if not).
  Returns:
    A list of lists with the first list corresponding to the compile op and the
    second a list of output tensors, indexed by `[replica_num][output_num]`.
  Raises:
    ValueError: If all replicas do not have equal numbers of input tensors.
    ValueError: If the number of inputs per replica does not match
      the number of formal parameters to `computation`.
  """
    del name
    inputs = [[]] if inputs is None else inputs

    metadata_kwargs = {}
    if device_assignment is not None:
        # Turn the Numpy array into a flattened list so we can pass it as an
        # operator attribute.
        metadata_kwargs = {
            "topology":
            device_assignment.topology.serialized(),
            "device_assignment":
            device_assignment.core_assignment.flatten().tolist()
        }
        # TODO(phawkins): remove this case after the forward compatibility window
        # expires on 2018-10-5.
        if api_compat.forward_compatible(2018, 10, 5):
            metadata_kwargs["num_cores_per_replica"] = (
                device_assignment.num_cores_per_replica)
        else:
            metadata_kwargs["computation_shape"] = [
                device_assignment.num_cores_per_replica
            ]

    if ((not isinstance(inputs, list))
            or any(not isinstance(inp, (list, tuple)) for inp in inputs)):
        raise TypeError(
            "tpu.replicate() inputs must be a list of lists/tuples")

    num_replicas = len(inputs)

    # No replicas? Nothing to do.
    if num_replicas == 0:
        return []

    # Converts inputs to Tensors.
    inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]

    # Verifies that all replicas have matching numbers and types of inputs
    input_types = [x.dtype for x in inputs[0]]
    input_arity = len(input_types)
    for i in range(num_replicas):
        if len(inputs[i]) != input_arity:
            raise ValueError("Replicas must have the same number of inputs. "
                             "Replica 0 had {} inputs, replica {} had {} "
                             "inputs.".format(input_arity, i, len(inputs[i])))

        types = [x.dtype for x in inputs[i]]
        if types != input_types:
            raise ValueError(
                "Replicas must have matching input types. Replica 0 had "
                "input types {}, replica {} had input types {}".format(
                    input_types, i, types))

    arg_error = tpu_function.check_function_argument_count(
        computation, input_arity, infeed_queue)
    if arg_error is not None:
        if infeed_queue is None:
            raise TypeError(
                "Supplied computation cannot be called with the specified inputs. "
                "You specified %d inputs: %s, but the computation needs %s" %
                (input_arity, str([i.name for i in inputs[0]]), arg_error))
        else:
            raise TypeError(
                "Supplied computation cannot be called with the specified inputs. "
                "You specified %d inputs: %s and %d additional inputs from infeed,"
                " but the computation needs %s" %
                (input_arity, str([i.name for i in inputs[0]]),
                 infeed_queue.number_of_tuple_elements, arg_error))

    graph = ops.get_default_graph()

    # Fan-in: Builds a TPUReplicatedInput node for each input.
    computation_inputs = []
    for i in range(0, input_arity):
        replicas = [inputs[replica][i] for replica in xrange(num_replicas)]
        computation_inputs.append(
            tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))

    cluster_name = graph.unique_name("cluster")
    pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
    context = TPUReplicateContext(name=cluster_name,
                                  num_replicas=num_replicas,
                                  pivot=pivot)
    try:
        context.Enter()

        metadata = tpu_ops.tpu_replicate_metadata(num_replicas=num_replicas,
                                                  use_tpu=use_tpu,
                                                  **metadata_kwargs)

        with tpu_function.tpu_shard_context(
                num_replicas), ops.control_dependencies([metadata]):

            # Add identity ops so even unused inputs are "consumed" by the
            # computation. This is to avoid orphaned TPUReplicatedInput nodes.
            # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
            # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
            computation_inputs = [
                array_ops.identity(x, name="replicated_input_{}".format(i))
                for i, x in enumerate(computation_inputs)
            ]

            # If there is an infeed queue, adds the dequeued values to the
            # computation's inputs.
            if infeed_queue is not None:
                infeed_queue.set_number_of_shards(num_replicas)
                for t in infeed_queue.generate_dequeue_op():
                    computation_inputs.append(t)

            # Only resource variables work inside a TPU computation, so turn on
            # resource variables for the computation.
            # TODO(phawkins): consider removing this code. It will
            # be less confusing to clients if they knowingly choose to use resource
            # variables.
            # Partitioned variables is not supported (b/112311320).
            vscope = variable_scope.get_variable_scope()
            saved_use_resource = vscope.use_resource
            saved_custom_getter = vscope.custom_getter

            def custom_getter(getter, name, *args, **kwargs):
                """Variables on TPU have a few restrictions."""
                partitioner = kwargs["partitioner"]
                if partitioner is not None:
                    kwargs["partitioner"] = None
                    logging.warning(
                        "Partitioned variables are not supported on TPU. Got "
                        "`partitioner` that is {} for variable {}. "
                        "Setting `partitioner` to `None`.".format(
                            partitioner, name))
                if saved_custom_getter is None:
                    return getter(name, *args, **kwargs)
                else:
                    return saved_custom_getter(getter, name, *args, **kwargs)

            vscope.set_use_resource(True)
            vscope.set_custom_getter(custom_getter)

            outputs = computation(*computation_inputs)

            vscope.set_use_resource(saved_use_resource)
            vscope.set_custom_getter(saved_custom_getter)

        # If the computation returns `None`, make it an empty tuple.
        if outputs is None:
            outputs = tuple()
        # If the computation only returned one value, makes it a tuple.
        if not isinstance(outputs, (list, tuple)):
            outputs = (outputs, )

        # Append `no_op` here so that fetching any return value of this function
        # will trigger TPUExecute node.
        outputs += (control_flow_ops.no_op(), )
        try:
            with ops.device(core(0)):
                outputs = [
                    o if isinstance(o, ops.Operation) else
                    ops.convert_to_tensor(o) for o in outputs
                ]
        except Exception as e:
            raise ValueError(
                "TPU function return values must all either be Operations or "
                "convertible to Tensors. Got '%s'" % str(e))

        # Separates the returned Operations and Tensors.
        output_operations = [
            o for o in outputs if isinstance(o, ops.Operation)
        ]
        output_tensors = [
            o for o in outputs if not isinstance(o, ops.Operation)
        ]

        if outputs != output_tensors + output_operations:
            raise ValueError(
                "TPU functions must return zero-or more Tensor values followed by "
                "zero or more Operations.")
        output_arity = len(output_tensors)

        # Wraps outputs in Identity ops. Otherwise a replicated input copied
        # straight to an output would bypass the replicate(). This would be bad
        # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
        # be rewritten away, leading to a runtime error.
        # TODO(phawkins): extend the rewrite to elide these nodes instead.
        new_output_tensors = []
        for t in output_tensors:
            with ops.device(t.device if t.device else core(0)):
                new_output_tensors.append(array_ops.identity(t))
        output_tensors = new_output_tensors
        context.ExitResult(output_tensors)
    finally:
        context.report_unsupported_operations()
        context.Exit()
        host_compute_core = context.HostComputeCore()

    if host_compute_core:
        attr_value = attr_value_pb2.AttrValue()
        attr_value.list.s.extend(
            [compat.as_bytes(x) for x in host_compute_core])
        metadata._set_attr("host_compute_core", attr_value)  # pylint: disable=protected-access

    # Fan-out: Builds a TPUReplicatedOutput node for each output.
    outputs = [
        tpu_ops.tpu_replicated_output(output_tensors[i],
                                      num_replicas,
                                      name="output{}".format(i))
        for i in xrange(output_arity)
    ]

    with ops.control_dependencies([metadata]):
        if use_tpu:
            compile_status = tpu_ops.tpu_compilation_result()
            op = compile_status.op
            attr_value = attr_value_pb2.AttrValue(
                s=compat.as_bytes(cluster_name))
            op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
        else:
            compile_status = control_flow_ops.no_op(name="compilation_status")

    with ops.control_dependencies(output_operations):
        if output_arity == 0:
            # Returns a list of NoOps dependent on the replication Op, indexed by
            # [replica_num].
            return [
                compile_status,
                [
                    control_flow_ops.no_op(name="shard_%d" % i)
                    for i in range(num_replicas)
                ]
            ]
        else:
            # Wraps the outputs in identity operators so the names of any possible
            # `fetch` nodes are preserved by the replication rewrite.
            return [
                compile_status,
                [[
                    array_ops.identity(outputs[out][replica],
                                       name="output_%d_shard_%d" %
                                       (out, replica))
                    for out in xrange(output_arity)
                ] for replica in xrange(num_replicas)]
            ]
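Compared with replicate, split_compile_and_replicate exposes the compile op separately from the per-replica outputs, so compilation problems can be surfaced before the replicas execute. A hedged sketch of unpacking the return value, reusing the illustrative computation and replica_inputs from the previous sketch and assuming a TF 1.x session with the TPU system already initialized:

compile_op, replica_outputs = split_compile_and_replicate(
    computation, inputs=replica_inputs, use_tpu=True)

with tf.Session() as sess:
    # Running the compile op first surfaces compilation failures early; the
    # execute ops already depend on it, so nothing is recompiled below.
    sess.run(compile_op)
    results = sess.run(replica_outputs)  # indexed by [replica_num][output_num]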
Example #9
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
  """Builds a training loop for TPUs.

  The set of loop-carried tensors corresponds to `inputs`.  Both
  `condition` and `body` take the current value of the loop-carried
  tensors. 'body' additionally takes a tuple of infeed from
  infeed_queue if infeed_queue is not None. `condition` must return a
  single boolean value that determines whether iteration
  continues. `body` must return an updated list of values for the
  loop-carried tensors.

  Args:
    condition: a Python function that builds the loop condition.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop, or
      None (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple
      of arguments as inputs to condition.
    name: an optional name for the loop.

  Returns:
    The final values of the loop-carried tensors.

  Raises:
    TypeError: if body or condition has the wrong signature.
  """

  # Converts inputs to Tensors.
  inputs = [] if inputs is None else [ops.convert_to_tensor(x) for
                                      x in inputs]
  input_types = [x.dtype for x in inputs]
  input_arity = len(inputs)

  body_arg_error = tpu_function.check_function_argument_count(
      body, input_arity, infeed_queue)
  if body_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s, but the loop body needs %s" % (
              input_arity, str([i.name for i in inputs]), body_arg_error))
    else:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s and %d additional inputs from "
          "infeed, but the computation needs %s" % (input_arity, str(
              [i.name for i in inputs]), infeed_queue.number_of_tuple_elements,
                                                    body_arg_error))
  condition_arg_error = tpu_function.check_function_argument_count(
      condition, input_arity, None)
  if condition_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop condition function cannot be called with the "
          "specified inputs. You specified %d inputs: %s, but the loop "
          "condition needs %s" % (input_arity, str([i.name for i in inputs]),
                                  condition_arg_error))
    else:
      raise TypeError(
          "Supplied loop condition function cannot be called with the "
          "specified inputs. You specified %d inputs: %s, but the loop "
          "condition needs %s. Note that infeed is not passed to the loop "
          "condition." % (input_arity, str([i.name for i in inputs]),
                          condition_arg_error))

  def condition_wrapper(*inputs):
    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []
    return condition(*inputs)

  def body_wrapper(*inputs):
    """Wrapper around `body` that handles infeed queues and control deps."""
    inputs = list(inputs)

    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []

    # Runs `body` with the dequeue_ops appended.
    if infeed_queue:
      number_of_shards = tpu_function.get_tpu_context().number_of_shards
      if number_of_shards is None:
        raise ValueError("Can't build training loop with infeed when there is "
                         "no tpu_shard_context. Are you building a loop or "
                         "graph directly rather than from inside tpu.rewrite, "
                         "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
      infeed_queue.set_number_of_shards(number_of_shards)
      dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()]
    else:
      dequeue_ops = []
    outputs = body(*(inputs + dequeue_ops))

    # If the computation only returned one value, make it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs
                      if not isinstance(o, ops.Operation)]

    if outputs != output_tensors + output_operations:
      raise ValueError(
          "TPU training loop body must return zero or more Tensor values "
          "followed by zero or more Operations.")

    output_types = [op.dtype for op in output_tensors]
    if input_types != output_types:
      raise TypeError(
          "Mismatch between input types and output types for training loop "
          "body: {} vs {}".format(input_types, output_types))

    # Add the dequeue operations to output_operations to ensure they are run
    # by the loop, even if the programmer's loop body does not use them.
    output_operations += dequeue_ops

    # Add a dummy output, if needed.
    if not output_tensors:
      output_tensors = array_ops.constant(0)

    if output_operations:
      # TODO(phawkins): in principle this is too restrictive since it serializes
      # the training loop steps. In practice it does not matter since this loop
      # will be compiled by XLA.
      return control_flow_ops.tuple(output_tensors,
                                    control_inputs=output_operations)
    else:
      return output_tensors

  # If the body has arity 0, add a dummy loop-carried value to which we can add
  # control dependencies from any side-effecting operations.
  if input_arity == 0:
    inputs = [array_ops.constant(0)]
  return control_flow_ops.while_loop(condition_wrapper, body_wrapper, inputs,
                                     name=name)
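Per the docstring, condition and body both receive the loop-carried tensors, body must return updated values of matching dtypes, and infeed (if any) is passed only to body. A minimal hedged sketch of a counter loop, assuming TF 1.x graph mode and that the while_loop defined above is in scope:

import tensorflow.compat.v1 as tf  # assumes graph mode, not eager execution

def condition(i):
    return i < 10                  # single boolean: keep iterating while True

def body(i):
    return [i + 1]                 # updated loop-carried values, same dtypes

# One loop-carried int32 tensor; the return value is its final value.
final_i = while_loop(condition, body, inputs=[tf.constant(0)])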
Example #10
def _compile_internal(computation, inputs=None):
  """Builds graph operators that compiles and symbolically executes computation.

  Args:
    computation: A Python function that builds the computation to compile and
      execute.
    inputs: A list of input tensors or `None` (equivalent to `[]`). Its order
      should match ordering of computation arguments.
  Returns:
    A list of output tensors from computation.
  Raises:
    ValueError: If any element in computation outputs is neither an operations
      or a value that can be converted to tensor.
    TypeError: If `inputs` is not a list or tuple.
  """
  if inputs is None:
    inputs = []

  if not isinstance(inputs, collections.Sequence):
    raise TypeError('inputs must be a list')

  # Converts inputs to Tensors.
  inputs = [ops.convert_to_tensor(x) for x in inputs]
  input_arity = len(inputs)

  arg_error = tpu_function.check_function_argument_count(
      computation, input_arity, infeed_queue=None)
  if arg_error is not None:
    raise TypeError(
        'Supplied computation cannot be called with the specified inputs. You '
        'specified %d inputs: %s, but the computation needs %s' %
        (input_arity, str([i.name for i in inputs]), arg_error))

  cluster_name = ops.get_default_graph().unique_name('cluster')
  pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
  context = XLACompileContext(name=cluster_name, pivot=pivot)
  try:
    context.Enter()

    # Add identity ops so even unused inputs are 'consumed' by the
    # computation.
    computation_inputs = [
        array_ops.identity(x, name='input_{}'.format(i))
        for i, x in enumerate(inputs)
    ]

    # Only resource variables work inside an XLA computation, so turn on
    # resource variables for the computation.
    vscope = variable_scope.get_variable_scope()
    saved_use_resource = vscope.use_resource
    vscope.set_use_resource(True)

    outputs = computation(*computation_inputs)

    # Restore variable scope after computation.
    vscope.set_use_resource(saved_use_resource)

    # If the computation returns `None`, make it an empty tuple.
    if outputs is None:
      outputs = tuple()
    # If the computation only returned one value, make it a tuple.
    if not isinstance(outputs, collections.Sequence):
      outputs = (outputs,)

    # Append `no_op` here so that return value of this function always contains
    # at least one op that can trigger XlaLaunch node.
    outputs += (control_flow_ops.no_op(),)
    try:
      outputs = [
          o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
          for o in outputs
      ]
    except Exception as e:
      raise ValueError(
          'XLA computation function return values must all either be Operations'
          ' or convertible to Tensors. Got error: "%s"' % str(e))

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]

    if outputs != output_tensors + output_operations:
      raise ValueError(
          'XLA computation function must return zero or more Tensor values '
          'followed by zero or more Operations.')
    output_arity = len(output_tensors)

    new_output_tensors = []
    for t in output_tensors:
      with ops.device(t.device if t.device else ''):
        new_output_tensors.append(array_ops.identity(t))

    output_tensors = new_output_tensors
    context.ExitResult(output_tensors)
  finally:
    context.report_unsupported_operations()
    context.Exit()

  outputs = [
      xla_ops.xla_cluster_output(output_tensors[i], name='output{}'.format(i))
      for i in xrange(output_arity)
  ]

  with ops.control_dependencies(output_operations):
    if output_arity == 0:
      # When XLA computation returns only operations and no tensors, a NoOp
      # dependent on the operations in outputs is returned. Otherwise final
      # outputs would be empty and there is no way to trigger returned
      # operations.
      return control_flow_ops.no_op(name='output_0')
    else:
      # Wraps the outputs in identity operators that carries control
      # dependencies.
      return [
          array_ops.identity(outputs[i], name='output_%d' % i)
          for i in xrange(output_arity)
      ]
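_compile_internal appends a no_op to the computation's outputs and wraps the whole body in an XLACompileContext, so fetching any returned value triggers the compiled cluster. It is a private helper (normally reached through a public compile-style wrapper), but a direct call clarifies the calling convention. A hedged sketch only; the computation and tensors below are illustrative and assume TF 1.x graph mode with XLA available:

import tensorflow.compat.v1 as tf  # assumes graph mode, not eager execution

def computation(x, y):
    return x * y + 1.0

a = tf.constant(2.0)
b = tf.constant(3.0)

# Returns the identity-wrapped output tensors; here there is exactly one.
(result,) = _compile_internal(computation, inputs=[a, b])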
Example #11
def split_compile_and_replicate(computation,
                                inputs=None,
                                infeed_queue=None,
                                device_assignment=None,
                                name=None,
                                use_tpu=True):
  """Builds graph operators that runs compilation and replicated computation.

  This is a lower level interface than replicate that returns a separate compile
  and execute output tensor. In the generated graph the compile op feeds into
  the execute op and no additional compilation is incurred when running the
  compile op before the execute op. The compile op returns additional
  information about the compilation but does not return the compiled program.

  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
      backends. Currently, only supports a default placement (computation is
      placed on GPU if one is available, and on CPU if not).
  Returns:
    A list of lists with the first list corresponding to the compile op and the
    second a list of output tensors, indexed by `[replica_num][output_num]`.
  Raises:
    ValueError: If all replicas do not have equal numbers of input tensors.
    ValueError: If the number of inputs per replica does not match
      the number of formal parameters to `computation`.
  """
  del name
  inputs = [[]] if inputs is None else inputs

  metadata_kwargs = {}
  if device_assignment is not None:
    # Turn the Numpy array into a flattened list so we can pass it as an
    # operator attribute.
    metadata_kwargs = {
        "topology":
            device_assignment.topology.serialized(),
        "device_assignment":
            device_assignment.core_assignment.flatten().tolist(),
        "computation_shape":
            device_assignment.computation_shape.tolist()
    }

  if ((not isinstance(inputs, list)) or
      any(not isinstance(inp, (list, tuple)) for inp in inputs)):
    raise TypeError("tpu.replicate() inputs must be a list of lists/tuples")

  num_replicas = len(inputs)

  # No replicas? Nothing to do.
  if num_replicas == 0:
    return []

  # Converts inputs to Tensors.
  inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]

  # Verifies that all replicas have matching numbers and types of inputs
  input_types = [x.dtype for x in inputs[0]]
  input_arity = len(input_types)
  for i in range(num_replicas):
    if len(inputs[i]) != input_arity:
      raise ValueError("Replicas must have the same number of inputs. "
                       "Replica 0 had {} inputs, replica {} had {} "
                       "inputs.".format(input_arity, i, len(inputs[i])))

    types = [x.dtype for x in inputs[i]]
    if types != input_types:
      raise ValueError(
          "Replicas must have matching input types. Replica 0 had "
          "input types {}, replica {} had input types {}".format(
              input_types, i, types))

  arg_error = tpu_function.check_function_argument_count(
      computation, input_arity, infeed_queue)
  if arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s, but the computation needs %s" % (
              input_arity, str([i.name for i in inputs[0]]), arg_error))
    else:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s and %d additional inputs from infeed,"
          " but the computation needs %s" % (input_arity, str(
              [i.name
               for i in inputs[0]]), infeed_queue.number_of_tuple_elements,
                                             arg_error))

  graph = ops.get_default_graph()

  # Fan-in: Builds a TPUReplicatedInput node for each input.
  computation_inputs = []
  for i in range(0, input_arity):
    replicas = [inputs[replica][i] for replica in xrange(num_replicas)]
    computation_inputs.append(
        tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))

  cluster_name = graph.unique_name("cluster")
  pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
  context = TPUReplicateContext(
      name=cluster_name, num_replicas=num_replicas, pivot=pivot)
  try:
    context.Enter()

    metadata = tpu_ops.tpu_replicate_metadata(
        num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs)

    with tpu_function.tpu_shard_context(
        num_replicas), ops.control_dependencies([metadata]):

      # Add identity ops so even unused inputs are "consumed" by the
      # computation. This is to avoid orphaned TPUReplicatedInput nodes.
      # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
      # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
      computation_inputs = [
          array_ops.identity(x, name="replicated_input_{}".format(i))
          for i, x in enumerate(computation_inputs)
      ]

      # If there is an infeed queue, adds the dequeued values to the
      # computation's inputs.
      if infeed_queue is not None:
        infeed_queue.set_number_of_shards(num_replicas)
        for t in infeed_queue.generate_dequeue_op():
          computation_inputs.append(t)

      # Only resource variables work inside a TPU computation, so turn on
      # resource variables for the computation.
      # TODO(phawkins): consider removing this code. It will
      # be less confusing to clients if they knowingly choose to use resource
      # variables.
      # Partitioned variables is not supported (b/112311320).
      def custom_getter(getter, name, *args, **kwargs):
        """Variables on TPU have a few restrictions."""
        partitioner = kwargs["partitioner"]
        if partitioner is not None:
          kwargs["partitioner"] = None
          logging.warning(
              "Partitioned variables are not supported on TPU. Got "
              "`partitioner` that is {} for variable {}. "
              "Setting `partitioner` to `None`."
              .format(partitioner, name))
        return getter(name, *args, **kwargs)

      vscope = variable_scope.get_variable_scope()

      saved_use_resource = vscope.use_resource
      saved_custom_getter = vscope.custom_getter

      vscope.set_use_resource(True)
      vscope.set_custom_getter(custom_getter)

      outputs = computation(*computation_inputs)

      vscope.set_use_resource(saved_use_resource)
      vscope.set_custom_getter(saved_custom_getter)

    # If the computation returns `None`, make it an empty tuple.
    if outputs is None:
      outputs = tuple()
    # If the computation only returned one value, makes it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    # Append `no_op` here so that fetching any return value of this function
    # will trigger TPUExecute node.
    outputs += (control_flow_ops.no_op(),)
    try:
      with ops.device(core(0)):
        outputs = [
            o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
            for o in outputs
        ]
    except Exception as e:
      raise ValueError(
          "TPU function return values must all either be Operations or "
          "convertible to Tensors. Got '%s'" % str(e))

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]

    if outputs != output_tensors + output_operations:
      raise ValueError(
          "TPU functions must return zero-or more Tensor values followed by "
          "zero or more Operations.")
    output_arity = len(output_tensors)

    # Wraps outputs in Identity ops. Otherwise a replicated input copied
    # straight to an output would bypass the replicate(). This would be bad
    # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
    # be rewritten away, leading to a runtime error.
    # TODO(phawkins): extend the rewrite to elide these nodes instead.
    new_output_tensors = []
    for t in output_tensors:
      with ops.device(t.device if t.device else core(0)):
        new_output_tensors.append(array_ops.identity(t))
    output_tensors = new_output_tensors
    context.ExitResult(output_tensors)
  finally:
    context.report_unsupported_operations()
    context.Exit()
    host_compute_core = context.HostComputeCore()

  if host_compute_core:
    attr_value = attr_value_pb2.AttrValue()
    attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core])
    metadata._set_attr("host_compute_core", attr_value)  # pylint: disable=protected-access

  # Fan-out: Builds a TPUReplicatedOutput node for each output.
  outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas,
                                           name="output{}".format(i))
             for i in xrange(output_arity)]

  with ops.control_dependencies([metadata]):
    if use_tpu:
      compile_status = tpu_ops.tpu_compilation_result()
      op = compile_status.op
      attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
      op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
    else:
      compile_status = control_flow_ops.no_op(name="compilation_status")

  with ops.control_dependencies(output_operations):
    if output_arity == 0:
      # Returns a list of NoOps dependent on the replication Op, indexed by
      # [replica_num].
      return [
          compile_status, [
              control_flow_ops.no_op(name="shard_%d" % i)
              for i in range(num_replicas)
          ]
      ]
    else:
      # Wraps the outputs in identity operators so the names of any possible
      # `fetch` nodes are preserved by the replication rewrite.
      return [
          compile_status, [[
              array_ops.identity(
                  outputs[out][replica],
                  name="output_%d_shard_%d" % (out, replica))
              for out in xrange(output_arity)
          ]
                           for replica in xrange(num_replicas)]
      ]
Example #12
def replicate(computation,
              inputs=None,
              infeed_queue=None,
              global_tpu_id=None,
              name=None):
  """Builds a graph operator that runs a replicated TPU computation.

  Args:
    computation: a Python function that builds the computation to replicate.
    inputs: a list of lists of input tensors or None (equivalent to
      [[]]), indexed by [replica_num][input_num]. All replicas must
      have the same number of inputs.
    infeed_queue: if not None, the InfeedQueue from which to append a tuple
      of arguments as inputs to computation.
    global_tpu_id: if not None, a Numpy 2D array indicating the global
      id of each TPU device in the system. The outer dimension of the
      array is host task id, and the inner dimension is device ordinal,
      so e.g., global_tpu_id[x][y] indicates the global id of device
      /task:x/device:TPU_NODE:y.
    name: name of the operator.
  Returns:
    A list of lists of output tensors, indexed by [replica_num][output_num].
  Raises:
    ValueError: if all replicas do not have equal numbers of input tensors.
    ValueError: if the number of inputs per replica does not match
      the number of formal parameters to `computation`.
  """
  if name is None:
    name = "TPUReplicate"
  inputs = [[]] if inputs is None else inputs

  if global_tpu_id is not None:
    # Turn the Numpy array into a flattened list.
    global_tpu_id = global_tpu_id.flatten().tolist()

  if ((not isinstance(inputs, list)) or
      any(not isinstance(inp, (list, tuple)) for inp in inputs)):
    raise TypeError("tpu.replicate() inputs must be a list of lists/tuples")

  num_replicas = len(inputs)

  # No replicas? Nothing to do.
  if num_replicas == 0:
    return []

  # Converts inputs to Tensors.
  inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]

  # Verifies that all replicas have matching numbers and types of inputs
  input_types = [x.dtype for x in inputs[0]]
  input_arity = len(input_types)
  for i in range(num_replicas):
    if len(inputs[i]) != input_arity:
      raise ValueError("Replicas must have the same number of inputs. "
                       "Replica 0 had {} inputs, replica {} had {} "
                       "inputs.".format(input_arity, i, len(inputs[i])))

    types = [x.dtype for x in inputs[i]]
    if types != input_types:
      raise ValueError(
          "Replicas must have matching input types. Replica 0 had "
          "input types {}, replica {} had input types {}".format(
              input_types, i, types))

  arg_error = tpu_function.check_function_argument_count(
      computation, input_arity, infeed_queue)
  if arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s, but the computation needs %s" % (
              input_arity, str([i.name for i in inputs[0]]), arg_error))
    else:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s and %d additional inputs from infeed,"
          " but the computation needs %s" % (input_arity, str(
              [i.name
               for i in inputs[0]]), infeed_queue.number_of_tuple_elements,
                                             arg_error))

  graph = ops.get_default_graph()

  with ops.name_scope(name, "replicate"):
    # Fan-in: Builds a TPUReplicatedInput node for each input.
    computation_inputs = []
    for i in range(0, input_arity):
      replicas = [inputs[replica][i] for replica in xrange(num_replicas)]
      computation_inputs.append(
          tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))

    context = TPUReplicateContext(name=graph.unique_name("cluster"))
    try:
      context.Enter()

      metadata = tpu_ops.tpu_replicate_metadata(
          num_replicas=num_replicas, global_tpu_id=global_tpu_id)

      with tpu_function.tpu_shard_context(
          num_replicas), ops.control_dependencies([metadata]):

        # The EncapsulateTPUComputations rewrite needs to identify the
        # replicated arguments inside each computation. Adds identity operators
        # tagged with an attribute _tpu_replicated_input to identify the
        # replicated inputs.
        # pylint: disable=protected-access
        with graph._attr_scope({"_tpu_replicated_input":
                                attr_value_pb2.AttrValue(b=True)}):
          computation_inputs = [
              array_ops.identity(x, name="replicated_input_{}".format(i))
              for i, x in enumerate(computation_inputs)]
        # pylint: enable=protected-access

        # If there is an infeed queue, adds the dequeued values to the
        # computation's inputs.
        if infeed_queue is not None:
          infeed_queue.set_number_of_shards(num_replicas)
          for t in infeed_queue.generate_dequeue_op():
            computation_inputs.append(t)

        # Only resource variables work inside a TPU computation, so turn on
        # resource variables for the computation.
        # TODO(phawkins): consider removing this code. It will
        # be less confusing to clients if they knowingly choose to use resource
        # variables.
        vscope = variable_scope.get_variable_scope()
        saved_use_resource = vscope.use_resource
        vscope.set_use_resource(True)

        outputs = computation(*computation_inputs)

        vscope.set_use_resource(saved_use_resource)

      # If the computation only returned one value, makes it a tuple.
      if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)

      try:
        with ops.device(core(0)):
          outputs = [
              o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
              for o in outputs
          ]
      except Exception as e:
        raise ValueError(
            "TPU function return values must all either be Operations or "
            "convertible to Tensors. Got '%s'" % str(e))

      # Separates the returned Operations and Tensors.
      output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
      output_tensors = [o for o in outputs
                        if not isinstance(o, ops.Operation)]

      if outputs != output_tensors + output_operations:
        raise ValueError(
            "TPU functions must return zero-or more Tensor values followed by "
            "zero or more Operations.")
      output_arity = len(output_tensors)

      # Wraps outputs in Identity ops. Otherwise a replicated input copied
      # straight to an output would bypass the replicate(). This would be bad
      # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
      # be rewritten away, leading to a runtime error.
      # TODO(phawkins): extend the rewrite to elide these nodes instead.
      new_output_tensors = []
      for t in output_tensors:
        with ops.device(t.device if t.device else core(0)):
          new_output_tensors.append(array_ops.identity(t))
      output_tensors = new_output_tensors
    finally:
      context.Exit()

    # Fan-out: Builds a TPUReplicatedOutput node for each output.
    outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas,
                                             name="output{}".format(i))
               for i in xrange(output_arity)]

    with ops.control_dependencies(output_operations):
      if output_arity == 0:
        # Returns a list of NoOps dependent on the replication Op, indexed by
        # [replica_num].
        return [
            control_flow_ops.no_op(name="%s_shard_%d" % (name, i))
            for i in range(num_replicas)
        ]
      else:
        # Wraps the outputs in identity operators so the names of any possible
        # `fetch` nodes are preserved by the replication rewrite.
        return [
            [array_ops.identity(outputs[out][replica],
                                name="output_%d_shard_%d" % (out, replica))
             for out in xrange(output_arity)]
            for replica in xrange(num_replicas)
        ]
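
The fan-in loop above regroups the `[replica_num][input_num]` input structure into one list per input position (all replicas' values for input i) before a `TPUReplicatedInput` node is built for each position. A TPU-free sketch of that regrouping, using illustrative stand-in values:

# Plain-Python illustration of the fan-in regrouping performed above;
# `fake_inputs` and its string stand-ins are illustrative only.
fake_inputs = [["r0_in0", "r0_in1"],   # replica 0's inputs
               ["r1_in0", "r1_in1"]]   # replica 1's inputs
num_replicas = len(fake_inputs)
input_arity = len(fake_inputs[0])

per_position = [
    [fake_inputs[replica][i] for replica in range(num_replicas)]
    for i in range(input_arity)
]
# per_position[0] == ["r0_in0", "r1_in0"]  -> would feed node "input0"
# per_position[1] == ["r0_in1", "r1_in1"]  -> would feed node "input1"
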
Example #13
0
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
  """Builds a training loop for TPUs.

  The set of loop-carried tensors corresponds to `inputs`.  Both
  `condition` and `body` take the current value of the loop-carried
  tensors. `body` additionally takes a tuple of values dequeued from
  `infeed_queue` if `infeed_queue` is not `None`. `condition` must return a
  single boolean value that determines whether iteration
  continues. `body` must return an updated list of values for the
  loop-carried tensors.

  Args:
    condition: a Python function that builds the loop condition.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop, or
      None (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple
      of arguments as inputs to `body`. Infeed is not passed to `condition`.
    name: an optional name for the loop.

  Returns:
    The final values of the loop-carried tensors.

  Raises:
    TypeError: if body or condition has the wrong signature.
  """

  # Converts inputs to Tensors.
  inputs = [] if inputs is None else [ops.convert_to_tensor(x) for
                                      x in inputs]
  input_types = [x.dtype for x in inputs]
  input_arity = len(inputs)

  body_arg_error = tpu_function.check_function_argument_count(
      body, input_arity, infeed_queue)
  if body_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s, but the loop body needs %s" % (
              input_arity, str([i.name for i in inputs]), body_arg_error))
    else:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s and %d additional inputs from "
          "infeed, but the computation needs %s" % (input_arity, str(
              [i.name for i in inputs]), infeed_queue.number_of_tuple_elements,
                                                    body_arg_error))
  condition_arg_error = tpu_function.check_function_argument_count(
      condition, input_arity, None)
  if condition_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop condition function cannot be called with the "
          "specified inputs. You specified %d inputs: %s, but the loop "
          "condition needs %s" % (input_arity, str([i.name for i in inputs]),
                                  condition_arg_error))
    else:
      raise TypeError(
          "Supplied loop condition function cannot be called with the "
          "specified inputs. You specified %d inputs: %s, but the loop "
          "condition needs %s. Note that infeed is not passed to the loop "
          "condition." % (input_arity, str([i.name for i in inputs]),
                          condition_arg_error))

  def condition_wrapper(*inputs):
    # Discards the dummy loop-carried value added for arity-0 loops.
    if input_arity == 0:
      inputs = []
    return condition(*inputs)

  def body_wrapper(*inputs):
    """Wrapper around `body` that handles infeed queues and control deps."""
    inputs = list(inputs)

    # Discards the dummy loop-carried value added for arity-0 loops.
    if input_arity == 0:
      inputs = []

    # Runs `body` with the dequeue_ops appended.
    if infeed_queue:
      number_of_shards = tpu_function.get_tpu_context().number_of_shards
      if number_of_shards is None:
        raise ValueError("Can't build training loop with infeed when there is "
                         "no tpu_shard_context. Are you building a loop or "
                         "graph directly rather than from inside tpu.rewrite, "
                         "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
      infeed_queue.set_number_of_shards(number_of_shards)
      dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()]
    else:
      dequeue_ops = []
    outputs = body(*(inputs + dequeue_ops))

    # If the computation only returned one value, make it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs
                      if not isinstance(o, ops.Operation)]

    if outputs != output_tensors + output_operations:
      raise ValueError(
          "TPU training loop body must return zero or more Tensor values "
          "followed by zero or more Operations.")

    output_types = [op.dtype for op in output_tensors]
    if input_types != output_types:
      raise TypeError(
          "Mismatch between input types and output types for training loop "
          "body: {} vs {}".format(input_types, output_types))

    # Add the dequeue operations to output_operations to ensure they are run
    # by the loop, even if the programmer's loop body does not use them.
    output_operations += dequeue_ops

    # Add a dummy output, if needed. Kept as a one-element list so it can be
    # passed to control_flow_ops.tuple below.
    if not output_tensors:
      output_tensors = [array_ops.constant(0)]

    if output_operations:
      # TODO(phawkins): in principle this is too restrictive since it serializes
      # the training loop steps. In practice it does not matter since this loop
      # will be compiled by XLA.
      return control_flow_ops.tuple(output_tensors,
                                    control_inputs=output_operations)
    else:
      return output_tensors

  # If the body has arity 0, add a dummy loop-carried value to which we can add
  # control dependencies from any side-effecting operations.
  if input_arity == 0:
    inputs = [array_ops.constant(0)]
  return control_flow_ops.while_loop(condition_wrapper, body_wrapper, inputs,
                                     name=name)
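
A minimal usage sketch for the `while_loop` above, without an infeed queue (so no surrounding tpu_shard_context is needed); the `cond`/`body` functions and the constants are illustrative, and a standard `import tensorflow as tf` is assumed in addition to this file's imports:

# Hypothetical usage sketch; assumes the `while_loop` defined above plus
# `import tensorflow as tf` (TF 1.x graph mode).
def cond(i, total):
  return i < 10                                # single boolean loop condition

def body(i, total):
  # Returns an updated value for every loop-carried tensor; the output dtypes
  # must match the input dtypes or the body wrapper raises TypeError.
  return i + 1, total + tf.cast(i, tf.float32)

# The loop-carried tensors are seeded from `inputs`; their final values are
# returned once `cond` evaluates to False.
final_i, final_total = while_loop(
    cond, body, inputs=[tf.constant(0), tf.constant(0.0)])
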
Example #14
0
def replicate(computation,
              inputs=None,
              infeed_queue=None,
              device_assignment=None,
              name=None):
  """Builds a graph operator that runs a replicated TPU computation.

  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation and physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.

  Returns:
    A list of lists of output tensors, indexed by `[replica_num][output_num]`.

  Raises:
    ValueError: If all replicas do not have equal numbers of input tensors.
    ValueError: If the number of inputs per replica does not match
      the number of formal parameters to `computation`.
  """
  del name
  inputs = [[]] if inputs is None else inputs

  metadata_kwargs = {}
  if device_assignment is not None:
    # Turn the Numpy array into a flattened list so we can pass it as an
    # operator attribute.
    metadata_kwargs = {
        "topology":
            device_assignment.topology.serialized(),
        "device_assignment":
            device_assignment.core_assignment.flatten().tolist(),
        "computation_shape":
            device_assignment.computation_shape.tolist()
    }

  if ((not isinstance(inputs, list)) or
      any(not isinstance(inp, (list, tuple)) for inp in inputs)):
    raise TypeError("tpu.replicate() inputs must be a list of lists/tuples")

  num_replicas = len(inputs)

  # No replicas? Nothing to do.
  if num_replicas == 0:
    return []

  # Converts inputs to Tensors.
  inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]

  # Verifies that all replicas have matching numbers and types of inputs
  input_types = [x.dtype for x in inputs[0]]
  input_arity = len(input_types)
  for i in range(num_replicas):
    if len(inputs[i]) != input_arity:
      raise ValueError("Replicas must have the same number of inputs. "
                       "Replica 0 had {} inputs, replica {} had {} "
                       "inputs.".format(input_arity, i, len(inputs[i])))

    types = [x.dtype for x in inputs[i]]
    if types != input_types:
      raise ValueError(
          "Replicas must have matching input types. Replica 0 had "
          "input types {}, replica {} had input types {}".format(
              input_types, i, types))

  arg_error = tpu_function.check_function_argument_count(
      computation, input_arity, infeed_queue)
  if arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s, but the computation needs %s" % (
              input_arity, str([i.name for i in inputs[0]]), arg_error))
    else:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s and %d additional inputs from infeed,"
          " but the computation needs %s" % (input_arity, str(
              [i.name
               for i in inputs[0]]), infeed_queue.number_of_tuple_elements,
                                             arg_error))

  graph = ops.get_default_graph()

  # Fan-in: Builds a TPUReplicatedInput node for each input.
  computation_inputs = []
  for i in range(0, input_arity):
    replicas = [inputs[replica][i] for replica in xrange(num_replicas)]
    computation_inputs.append(
        tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))

  context = TPUReplicateContext(
      name=graph.unique_name("cluster"), num_replicas=num_replicas)
  try:
    context.Enter()

    metadata = tpu_ops.tpu_replicate_metadata(
        num_replicas=num_replicas, **metadata_kwargs)

    with tpu_function.tpu_shard_context(
        num_replicas), ops.control_dependencies([metadata]):

      # The EncapsulateTPUComputations rewrite needs to identify the
      # replicated arguments inside each computation. Adds identity operators
      # tagged with an attribute _tpu_replicated_input to identify the
      # replicated inputs.
      # pylint: disable=protected-access
      with graph._attr_scope({"_tpu_replicated_input":
                              attr_value_pb2.AttrValue(b=True)}):
        computation_inputs = [
            array_ops.identity(x, name="replicated_input_{}".format(i))
            for i, x in enumerate(computation_inputs)]
      # pylint: enable=protected-access

      # If there is an infeed queue, adds the dequeued values to the
      # computation's inputs.
      if infeed_queue is not None:
        infeed_queue.set_number_of_shards(num_replicas)
        for t in infeed_queue.generate_dequeue_op():
          computation_inputs.append(t)

      # Only resource variables work inside a TPU computation, so turn on
      # resource variables for the computation.
      # TODO(phawkins): consider removing this code. It will
      # be less confusing to clients if they knowingly choose to use resource
      # variables.
      vscope = variable_scope.get_variable_scope()
      saved_use_resource = vscope.use_resource
      vscope.set_use_resource(True)

      outputs = computation(*computation_inputs)

      vscope.set_use_resource(saved_use_resource)

    # If the computation only returned one value, makes it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    try:
      with ops.device(core(0)):
        outputs = [
            o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
            for o in outputs
        ]
    except Exception as e:
      raise ValueError(
          "TPU function return values must all either be Operations or "
          "convertible to Tensors. Got '%s'" % str(e))

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs
                      if not isinstance(o, ops.Operation)]

    if outputs != output_tensors + output_operations:
      raise ValueError(
          "TPU functions must return zero-or more Tensor values followed by "
          "zero or more Operations.")
    output_arity = len(output_tensors)

    # Wraps outputs in Identity ops. Otherwise a replicated input copied
    # straight to an output would bypass the replicate(). This would be bad
    # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
    # be rewritten away, leading to a runtime error.
    # TODO(phawkins): extend the rewrite to elide these nodes instead.
    new_output_tensors = []
    for t in output_tensors:
      with ops.device(t.device if t.device else core(0)):
        new_output_tensors.append(array_ops.identity(t))
    output_tensors = new_output_tensors
  finally:
    context.report_unsupported_operations()
    context.Exit()
    host_compute_core = context.HostComputeCore()

  if host_compute_core:
    attr_value = attr_value_pb2.AttrValue()
    attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core])
    metadata._set_attr("host_compute_core", attr_value)  # pylint: disable=protected-access

  # Fan-out: Builds a TPUReplicatedOutput node for each output.
  outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas,
                                           name="output{}".format(i))
             for i in xrange(output_arity)]

  with ops.control_dependencies(output_operations):
    if output_arity == 0:
      # Returns a list of NoOps dependent on the replication Op, indexed by
      # [replica_num].
      return [
          control_flow_ops.no_op(name="shard_%d" % i)
          for i in range(num_replicas)
      ]
    else:
      # Wraps the outputs in identity operators so the names of any possible
      # `fetch` nodes are preserved by the replication rewrite.
      return [
          [array_ops.identity(outputs[out][replica],
                              name="output_%d_shard_%d" % (out, replica))
           for out in xrange(output_arity)]
          for replica in xrange(num_replicas)
      ]
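
A minimal usage sketch for `replicate` with an `InfeedQueue`; the contrib-era module paths, the queue dtypes/shapes, and the `step` function are illustrative assumptions rather than the only way to drive this API:

# Hypothetical sketch: two replicas, no explicit per-replica inputs, and a
# single float32 tuple element delivered through infeed. Module paths assume
# the contrib-era layout of this file.
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_feed

infeed = tpu_feed.InfeedQueue(tuple_types=[tf.float32],
                              tuple_shapes=[[8, 128]])

def step(features):
  # `features` is the single infeed tuple element dequeued for this replica.
  return tf.reduce_sum(features)

# inputs=[[], []] selects two replicas with zero explicit inputs each; the
# dequeued infeed tuple supplies step()'s argument.
per_replica = tpu.replicate(step, inputs=[[], []], infeed_queue=infeed)
sum_r0 = per_replica[0][0]   # output 0 of replica 0
sum_r1 = per_replica[1][0]   # output 0 of replica 1
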