Example 1
    def test_var_args_and_defaults(self):
        """Tests that arg checker works for a function with varargs and defaults."""
        def func(x, y, z=17, *q):  # pylint: disable=keyword-arg-before-vararg
            return x + y + z + len(q)

        self.assertEqual(None,
                         xla.check_function_argument_count(func, 2, None))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 3, None))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 4, None))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 5, None))
        self.assertEqual('at least 2 arguments',
                         xla.check_function_argument_count(func, 1, None))
        queue = tpu_feed.InfeedQueue(1)
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 1, queue))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 2, queue))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 3, queue))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 4, queue))
        self.assertEqual('at least 2 arguments',
                         xla.check_function_argument_count(func, 0, queue))
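The assertions above pin down the checker's contract for a function with both varargs and defaults: only the two non-default positional parameters are mandatory, varargs remove any upper bound, and an infeed queue appears to contribute one implicit argument per tuple element. A short illustration of that arithmetic (assumed semantics, inferred from the assertions rather than from the checker's source):

# func(x, y, z=17, *q) needs at least 2 arguments; *q removes the upper bound.
# With an infeed queue the effective count is assumed to be
# param_count + infeed_queue.number_of_tuple_elements, so with InfeedQueue(1):
#   check_function_argument_count(func, 1, queue)  ->  None  (1 + 1 >= 2)
#   check_function_argument_count(func, 0, queue)  ->  'at least 2 arguments'  (0 + 1 < 2)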
Example 2
  def test_simple(self):
    """Tests that arg checker works for functions with no varargs or defaults.
    """

    def func(x, y, z):
      return x + y + z

    self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
    self.assertEqual('exactly 3 arguments',
                     xla.check_function_argument_count(func, 2, None))
    queue = tpu_feed.InfeedQueue(2)
    self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
    self.assertEqual('exactly 3 arguments',
                     xla.check_function_argument_count(func, 2, queue))
Example 3
  def testSimple(self):
    """Tests that arg checker works for functions with no varargs or defaults.
    """

    def func(x, y, z):
      return x + y + z

    self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
    self.assertEqual('exactly 3 arguments',
                     xla.check_function_argument_count(func, 2, None))
    queue = tpu_feed.InfeedQueue(2)
    self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
    self.assertEqual('exactly 3 arguments',
                     xla.check_function_argument_count(func, 2, queue))
Example 4
  def test_var_args_and_defaults(self):
    """Tests that arg checker works for a function with varargs and defaults."""

    def func(x, y, z=17, *q):  # pylint: disable=keyword-arg-before-vararg
      return x + y + z + len(q)

    self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
    self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
    self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
    self.assertEqual(None, xla.check_function_argument_count(func, 5, None))
    self.assertEqual('at least 2 arguments',
                     xla.check_function_argument_count(func, 1, None))
    queue = tpu_feed.InfeedQueue(1)
    self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
    self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
    self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
    self.assertEqual(None, xla.check_function_argument_count(func, 4, queue))
    self.assertEqual('at least 2 arguments',
                     xla.check_function_argument_count(func, 0, queue))
Example 5
  def test_var_args(self):
    """Tests that arg checker works for a function with varargs."""

    def func(x, y, *z):
      return x + y + len(z)

    self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
    self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
    self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
    self.assertEqual('at least 2 arguments',
                     xla.check_function_argument_count(func, 1, None))
    queue = tpu_feed.InfeedQueue(1)
    self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
    self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
    self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
    self.assertEqual('at least 2 arguments',
                     xla.check_function_argument_count(func, 0, queue))
Example 6
  def testVarArgs(self):
    """Tests that arg checker works for a function with varargs."""

    def func(x, y, *z):
      return x + y + len(z)

    self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
    self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
    self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
    self.assertEqual('at least 2 arguments',
                     xla.check_function_argument_count(func, 1, None))
    queue = tpu_feed.InfeedQueue(1)
    self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
    self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
    self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
    self.assertEqual('at least 2 arguments',
                     xla.check_function_argument_count(func, 0, queue))
Example 7
def repeat(n, body, inputs=None, infeed_queue=None, use_while_v1=True):
  """Builds a loop that executes a fixed number of iterations.

  The set of loop-carried tensors corresponds to `inputs`.
  `body` must be a function that takes and returns the values of the
  loop-carried tensors.

  Args:
    n: the number of loop iterations.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the loop, or None
      (equivalent to an empty list).
    infeed_queue: if not None, the IPUInfeedQueue from which data is consumed.
    use_while_v1: if True, then use a TensorFlow v1.x dataflow while loop.

  Returns:
    The final values of the loop-carried tensors.

  Raises:
    ValueError: if there is a type error.
    TypeError: if body has the wrong signature.
  """
  if inputs is None:
    inputs = []

  input_arity = len(_convert_to_list(inputs))
  body_arg_error = xla.check_function_argument_count(body, input_arity,
                                                     infeed_queue)
  if body_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s, but the loop body needs %s." %
          (input_arity, str(inputs), body_arg_error))
    else:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s and %d additional inputs from "
          "infeed, but the computation needs %s." %
          (input_arity, str(inputs), infeed_queue.number_of_tuple_elements,
           body_arg_error))
  return while_loop(
      lambda *args: True,
      body,
      inputs=inputs,
      infeed_queue=infeed_queue,
      maximum_iterations=n,
      use_while_v1=use_while_v1)
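A minimal usage sketch for `repeat`, under the assumption that `body` takes and returns the loop-carried tensors as the docstring describes (the accumulator name below is illustrative, not from the source):

def count_up(total):
  # Loop body: takes the loop-carried tensor and returns its updated value.
  return total + 1


# Runs the body 10 times; the final value of the loop-carried tensor is returned.
final_total = repeat(10, count_up, inputs=[array_ops.constant(0)])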
Example 8
  def testDefaultArgs(self):
    """Tests that arg checker works for a function with no varargs."""

    def func(x, y, z=17):
      return x + y + z

    self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
    self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
    self.assertEqual('at least 2 arguments',
                     xla.check_function_argument_count(func, 1, None))
    self.assertEqual('at most 3 arguments',
                     xla.check_function_argument_count(func, 4, None))
    queue = tpu_feed.InfeedQueue(1)
    self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
    self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
    self.assertEqual('at least 2 arguments',
                     xla.check_function_argument_count(func, 0, queue))
    self.assertEqual('at most 3 arguments',
                     xla.check_function_argument_count(func, 4, queue))
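Across the test examples above, the checker reports three kinds of requirement strings: 'exactly N arguments' for fixed-arity functions, 'at least N arguments' when too few arguments are supplied, and 'at most N arguments' when defaults cap the total. A minimal re-implementation sketch of that contract using `inspect` (illustrative only, not the TensorFlow source; the helper name is made up), consistent with every assertion above:

import inspect


def check_arg_count_sketch(func, param_count, infeed_queue=None):
  """Returns None if `func` accepts the supplied arguments, else a requirement string."""
  if infeed_queue is not None:
    # Assumption: the infeed queue supplies one extra argument per tuple element.
    param_count += infeed_queue.number_of_tuple_elements
  spec = inspect.getfullargspec(func)
  num_args = len(spec.args)
  num_defaults = len(spec.defaults or ())
  min_args = num_args - num_defaults
  if spec.varargs is not None:
    # Varargs: no upper bound, only a lower bound.
    return None if param_count >= min_args else 'at least %d arguments' % min_args
  if num_defaults == 0:
    # Fixed arity: the count must match exactly.
    return None if param_count == num_args else 'exactly %d arguments' % num_args
  if param_count < min_args:
    return 'at least %d arguments' % min_args
  if param_count > num_args:
    return 'at most %d arguments' % num_args
  return None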
Example 9
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
    """Builds a training loop for TPUs.

  The set of loop-carried tensors corresponds to `inputs`.  Both
  `condition` and `body` take the current value of the loop-carried
  tensors. `body` additionally takes a tuple of values dequeued from
  `infeed_queue` if `infeed_queue` is not None. `condition` must return a
  single boolean value that determines whether iteration
  continues. `body` must return an updated list of values for the
  loop-carried tensors.

  Args:
    condition: a Python function that builds the loop condition.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop, or
      None (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple
      of dequeued values as additional inputs to `body` (infeed is not passed
      to the loop condition).
    name: (Deprecated) Does nothing.

  Returns:
    The final values of the loop-carried tensors.

  Raises:
    TypeError: if body or condition has the wrong signature.
  """
    del name
    # Converts inputs to Tensors.
    inputs = [] if inputs is None else [
        ops.convert_to_tensor(x) for x in inputs
    ]
    input_types = [x.dtype for x in inputs]
    input_arity = len(inputs)

    body_arg_error = xla.check_function_argument_count(body, input_arity,
                                                       infeed_queue)
    if body_arg_error is not None:
        if infeed_queue is None:
            raise TypeError(
                "Supplied loop body function cannot be called with the specified "
                "inputs. You specified %d inputs: %s, but the loop body needs %s"
                % (input_arity, str([i.name for i in inputs]), body_arg_error))
        else:
            raise TypeError(
                "Supplied loop body function cannot be called with the specified "
                "inputs. You specified %d inputs: %s and %d additional inputs from "
                "infeed, but the computation needs %s" %
                (input_arity, str([i.name for i in inputs]),
                 infeed_queue.number_of_tuple_elements, body_arg_error))
    condition_arg_error = xla.check_function_argument_count(
        condition, input_arity, None)
    if condition_arg_error is not None:
        if infeed_queue is None:
            raise TypeError(
                "Supplied loop condition function cannot be called with the "
                "specified inputs. You specified %d inputs: %s, but the loop "
                "condition needs %s" %
                (input_arity, str([i.name
                                   for i in inputs]), condition_arg_error))
        else:
            raise TypeError(
                "Supplied loop condition function cannot be called with the "
                "specified inputs. You specified %d inputs: %s, but the loop "
                "condition needs %s. Note that infeed is not passed to the loop "
                "condition." %
                (input_arity, str([i.name
                                   for i in inputs]), condition_arg_error))

    def condition_wrapper(*inputs):
        # Discards the dummy output added for arity-0 loops.
        if input_arity == 0:
            inputs = []
        return condition(*inputs)

    def body_wrapper(*inputs):
        """Wrapper around `body` that handles infeed queues and control deps."""
        inputs = list(inputs)

        # Discards the dummy output added for arity-0 loops.
        if input_arity == 0:
            inputs = []

        # Runs `body` with the dequeue_ops appended.
        if infeed_queue:
            number_of_shards = tpu_function.get_tpu_context().number_of_shards
            if number_of_shards is None:
                raise ValueError(
                    "Can't build training loop with infeed when there is "
                    "no tpu_shard_context. Are you building a loop or "
                    "graph directly rather than from inside tpu.rewrite, "
                    "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
            infeed_queue.set_number_of_shards(number_of_shards)
            dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()]
        else:
            dequeue_ops = []
        outputs = body(*(inputs + dequeue_ops))

        # If the computation only returned one value, make it a tuple.
        if not isinstance(outputs, (list, tuple)):
            outputs = (outputs, )

        outputs = [
            o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
            for o in outputs
        ]

        # Separates the returned Operations and Tensors.
        output_operations = [
            o for o in outputs if isinstance(o, ops.Operation)
        ]
        output_tensors = [
            o for o in outputs if not isinstance(o, ops.Operation)
        ]

        if outputs != output_tensors + output_operations:
            raise ValueError(
                "TPU training loop body must return zero or more Tensor values "
                "followed by zero or more Operations.")

        output_types = [op.dtype for op in output_tensors]
        if input_types != output_types:
            raise TypeError(
                "Mismatch between input types and output types for training loop "
                "body: {} vs {}".format(input_types, output_types))

        # Add the dequeue operations to output_operations to ensure they are run
        # by the loop, even if the programmer's loop body does not use them.
        output_operations += dequeue_ops

        # Add a dummy output, if needed.
        if not output_tensors:
            output_tensors = array_ops.constant(0)

        if output_operations:
            # TODO(phawkins): in principle this is too restrictive since it serializes
            # the training loop steps. In practice it does not matter since this loop
            # will be compiled by XLA.
            output_tensors = control_flow_ops.tuple(
                output_tensors, control_inputs=output_operations)

        if tensor_tracer.TensorTracer.is_enabled():
            num_replicas = tpu_function.get_tpu_context().number_of_shards
            if num_replicas is None:
                num_replicas = 1
            tt = tensor_tracer.TensorTracer()
            output_tensors = tt.trace_tpu(ops.get_default_graph(),
                                          output_tensors, None, num_replicas)
        return output_tensors

    # If the body has arity 0, add a dummy loop-carried value to which we can add
    # control dependencies from any side-effecting operations.
    if input_arity == 0:
        inputs = [array_ops.constant(0)]
    return control_flow_ops.while_loop(condition_wrapper,
                                       body_wrapper,
                                       inputs,
                                       name="",
                                       parallel_iterations=1)
Example 10
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
  """Builds a training loop for TPUs.

  The set of loop-carried tensors corresponds to `inputs`.  Both
  `condition` and `body` take the current value of the loop-carried
  tensors. `body` additionally takes a tuple of values dequeued from
  `infeed_queue` if `infeed_queue` is not None. `condition` must return a
  single boolean value that determines whether iteration
  continues. `body` must return an updated list of values for the
  loop-carried tensors.

  Args:
    condition: a Python function that builds the loop condition.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop, or
      None (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple
      of dequeued values as additional inputs to `body` (infeed is not passed
      to the loop condition).
    name: (Deprecated) Does nothing.

  Returns:
    The final values of the loop-carried tensors.

  Raises:
    TypeError: if body or condition has the wrong signature.
  """
  del name
  # Converts inputs to Tensors.
  inputs = [] if inputs is None else [ops.convert_to_tensor(x) for
                                      x in inputs]
  input_types = [x.dtype for x in inputs]
  input_arity = len(inputs)

  body_arg_error = xla.check_function_argument_count(
      body, input_arity, infeed_queue)
  if body_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s, but the loop body needs %s" % (
              input_arity, str([i.name for i in inputs]), body_arg_error))
    else:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s and %d additional inputs from "
          "infeed, but the computation needs %s" % (input_arity, str(
              [i.name for i in inputs]), infeed_queue.number_of_tuple_elements,
                                                    body_arg_error))
  condition_arg_error = xla.check_function_argument_count(
      condition, input_arity, None)
  if condition_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop condition function cannot be called with the "
          "specified inputs. You specified %d inputs: %s, but the loop "
          "condition needs %s" % (input_arity, str([i.name for i in inputs]),
                                  condition_arg_error))
    else:
      raise TypeError(
          "Supplied loop condition function cannot be called with the "
          "specified inputs. You specified %d inputs: %s, but the loop "
          "condition needs %s. Note that infeed is not passed to the loop "
          "condition." % (input_arity, str([i.name for i in inputs]),
                          condition_arg_error))

  def condition_wrapper(*inputs):
    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []
    return condition(*inputs)

  def body_wrapper(*inputs):
    """Wrapper around `body` that handles infeed queues and control deps."""
    inputs = list(inputs)

    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []

    # Runs `body` with the dequeue_ops appended.
    if infeed_queue:
      number_of_shards = tpu_function.get_tpu_context().number_of_shards
      if number_of_shards is None:
        raise ValueError("Can't build training loop with infeed when there is "
                         "no tpu_shard_context. Are you building a loop or "
                         "graph directly rather than from inside tpu.rewrite, "
                         "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
      infeed_queue.set_number_of_shards(number_of_shards)
      dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()]
    else:
      dequeue_ops = []
    outputs = body(*(inputs + dequeue_ops))

    # If the computation only returned one value, make it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs
                      if not isinstance(o, ops.Operation)]

    if outputs != output_tensors + output_operations:
      raise ValueError(
          "TPU training loop body must return zero or more Tensor values "
          "followed by zero or more Operations.")

    output_types = [op.dtype for op in output_tensors]
    if input_types != output_types:
      raise TypeError(
          "Mismatch between input types and output types for training loop "
          "body: {} vs {}".format(input_types, output_types))

    # Add the dequeue operations to output_operations to ensure they are run
    # by the loop, even if the programmer's loop body does not use them.
    output_operations += dequeue_ops

    # Add a dummy output, if needed.
    if not output_tensors:
      output_tensors = array_ops.constant(0)

    if output_operations:
      # TODO(phawkins): in principle this is too restrictive since it serializes
      # the training loop steps. In practice it does not matter since this loop
      # will be compiled by XLA.
      output_tensors = control_flow_ops.tuple(output_tensors,
                                              control_inputs=output_operations)

    if tensor_tracer.TensorTracer.is_enabled():
      num_replicas = tpu_function.get_tpu_context().number_of_shards
      if num_replicas is None:
        num_replicas = 1
      tt = tensor_tracer.TensorTracer()
      output_tensors = tt.trace_tpu(ops.get_default_graph(),
                                    output_tensors, None,
                                    num_replicas)
    return output_tensors

  # If the body has arity 0, add a dummy loop-carried value to which we can add
  # control dependencies from any side-effecting operations.
  if input_arity == 0:
    inputs = [array_ops.constant(0)]
  return control_flow_ops.while_loop(
      condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)
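A hedged usage sketch for the TPU `while_loop` above, without an infeed queue (the condition, body, and initial value are illustrative, not from the source):

def cond(i):
  # Loop condition: a single boolean tensor; iterate while the counter is below 10.
  return i < 10


def body(i):
  # Loop body: returns the updated loop-carried tensor, matching the input dtype.
  return i + 1


# The final value of the loop-carried counter is returned.
final_i = while_loop(cond, body, inputs=[array_ops.constant(0)])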
Example 11
def while_loop(condition,
               body,
               inputs=None,
               infeed_queue=None,
               maximum_iterations=None,
               use_while_v1=True):
    """Builds a while loop for IPUs.

  The set of loop-carried tensors corresponds to `inputs`.  Both
  `condition` and `body` take the current value of the loop-carried
  tensors. `condition` must return a single boolean value that determines
  whether iteration continues. `body` must return an updated list of values for
  the loop-carried tensors.

  Args:
    condition: a Python function that builds the loop condition.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the loop, or
      None (equivalent to an empty list).
    infeed_queue: if not None, the IPUInfeedQueue from which data is consumed.
    maximum_iterations: if not None, the maximum number of loop iterations to
      run, in addition to the `condition` check.
    use_while_v1: if True, then use a TensorFlow v1.x dataflow while loop.

  Returns:
    The final values of the loop-carried tensors.

  Raises:
    TypeError: if body or condition has the wrong signature.
  """

    # Converts inputs to Tensors if not TensorArray.
    inputs = [] if inputs is None else [
        ops.convert_to_tensor(x) if not isinstance(x, TensorArray) else x
        for x in _convert_to_list(inputs)
    ]
    input_types = [x.dtype for x in inputs]
    input_arity = len(inputs)
    body_arg_error = xla.check_function_argument_count(body, input_arity,
                                                       infeed_queue)
    if body_arg_error is not None:
        if infeed_queue is None:
            raise TypeError(
                "Supplied loop body function cannot be called with the specified "
                "inputs. You specified %d inputs: %s, but the loop body needs %s."
                % (input_arity, str(inputs), body_arg_error))
        raise TypeError(
            "Supplied loop body function cannot be called with the specified "
            "inputs. You specified %d inputs: %s and %d additional inputs from "
            "infeed, but the computation needs %s." %
            (input_arity, str(inputs), infeed_queue.number_of_tuple_elements,
             body_arg_error))
    condition_arg_error = xla.check_function_argument_count(
        condition, input_arity, None)
    if condition_arg_error is not None:
        if infeed_queue is None:
            raise TypeError(
                "Supplied loop condition function cannot be called with the "
                "specified inputs. You specified %d inputs: %s, but the loop "
                "condition needs %s." %
                (input_arity, str(inputs), condition_arg_error))
        raise TypeError(
            "Supplied loop condition function cannot be called with the "
            "specified inputs. You specified %d inputs: %s, but the loop "
            "condition needs %s. Note that infeed is not passed to the loop "
            "condition." % (input_arity, str(inputs), condition_arg_error))

    def condition_wrapper(*inputs):
        # Discards the dummy output added for arity-0 loops.
        if input_arity == 0:
            inputs = []
        return condition(*inputs)

    def body_wrapper(*inputs):
        """Wrapper around `body` that handles infeed queues and control deps."""
        inputs = list(inputs)

        # Discards the dummy output added for arity-0 loops.
        if input_arity == 0:
            inputs = []

        # Runs `body` with the dequeue_ops appended.
        if infeed_queue:
            dequeue_ops = infeed_queue._dequeue()
        else:
            dequeue_ops = []

        body_args, body_kwargs = _body_arguments(dequeue_ops)
        outputs = body(*(inputs + body_args), **body_kwargs)

        # If the computation only returned one value, make it a tuple.
        if not isinstance(outputs, (list, tuple)):
            outputs = (outputs, )

        outputs = [
            o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
            for o in outputs
        ]

        # Separates the returned Operations and Tensors.
        output_operations = [
            o for o in outputs if isinstance(o, ops.Operation)
        ]
        output_tensors = [
            o for o in outputs if not isinstance(o, ops.Operation)
        ]

        if outputs != output_tensors + output_operations:
            raise ValueError(
                "IPU loop body must return zero or more Tensor values "
                "followed by zero or more Operations.")

        output_types = [op.dtype for op in output_tensors]
        if input_types != output_types:
            raise TypeError(
                "Mismatch between input types and output types for loop "
                "body: {} vs {}.".format(input_types, output_types))

        # Add a dummy output, if needed.
        if not output_tensors:
            output_tensors = [array_ops.constant(0)]

        if output_operations:
            output_tensors = control_flow_ops.tuple(
                output_tensors, control_inputs=output_operations)

        return output_tensors[0] if len(
            output_tensors) == 1 else output_tensors

    # If the body has arity 0, add a dummy loop-carried value to which we can add
    # control dependencies from any side-effecting operations.
    if input_arity == 0:
        inputs = [array_ops.constant(0)]

    if use_while_v1:
        while_fn = control_flow_ops.while_loop
    else:
        while_fn = while_v2.while_loop
        logging.warning("Usage of while_v2 is still experimental.")

    outputs = while_fn(condition_wrapper,
                       body_wrapper,
                       inputs,
                       maximum_iterations=maximum_iterations,
                       name="",
                       parallel_iterations=1)

    # Check the infeed queue has been used - this is more of a courtesy to the
    # user.
    if infeed_queue is not None and not infeed_queue.dequeued:
        raise ValueError("The infeed queue has not been dequeued.")

    outputs = _convert_to_list(outputs)
    if len(outputs) == 1:
        # If there were no inputs, only return the op for the dummy output.
        return outputs[0].op if input_arity == 0 else outputs[0]
    return outputs
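A corresponding usage sketch for the IPU variant, again without an infeed queue; `maximum_iterations` is shown only to illustrate the extra parameter (names and values are illustrative, not from the source):

def cond(i):
  return i < 10


def body(i):
  return i + 1


# Caps the loop at 100 iterations even if `cond` would keep returning True;
# use_while_v1 defaults to True, so a v1-style dataflow while loop is built.
final_i = while_loop(cond, body,
                     inputs=[array_ops.constant(0)],
                     maximum_iterations=100)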