Example #1
 def infeed_func_wrapper(*args):
   args = functional_ops._convert_to_list(args)  # pylint: disable=protected-access
   dequeue_ops = functional_ops._convert_to_list(infeed_queue._dequeue())  # pylint: disable=protected-access
   # Deal with the dequeue depending on whether it's a list or dict.
   if len(dequeue_ops) == 1 and isinstance(dequeue_ops[0], dict):
     kwargs = dequeue_ops[0]
     return func(*args, **kwargs)
   return func(*(args + dequeue_ops))
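
In effect, the wrapper widens the stage's calling convention: the stage receives the explicit `args` first, followed by whatever the infeed dequeues. A minimal sketch of the two resulting call shapes (all names here are illustrative, not from the source):

 # Dequeue yields a flat list: tensors are appended positionally.
 def stage(x, images, labels):
   return x, images, labels

 # Dequeue yields a single dict: it is splatted as keyword arguments.
 def stage_kw(x, images=None, labels=None):
   return x, images, labels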
Example #2
 def pipeline(*args):
     # TF2 replacement for: iterator = dataset.make_one_shot_iterator()
     iterator = compat_v1_data.make_one_shot_iterator(dataset)
     next_example, next_label = iterator.get_next()
     outputs = functional_ops._convert_to_list(args)  # pylint: disable=W0212
     outputs.append(next_example)
     outputs.append(next_label)
     for stage in stages:
         outputs = stage(
             *functional_ops._convert_to_list(outputs))  # pylint: disable=W0212
     return outputs
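
This builder threads `outputs` through the stages in turn, so each stage must return exactly the inputs of the next one. A hypothetical `stages` list compatible with the loop above (names and bodies are placeholders):

 def stage1(x, example, label):
     partial = x + example
     return partial, label

 def stage2(partial, label):
     return partial - label

 stages = [stage1, stage2]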
Example #3
 def outfeed_func_wrapper(*args, **kwargs):
   outputs = func(*args, **kwargs)
   # Check if there are output tensors - if there are then enqueue them.
   if not isinstance(outputs, ops.Operation):
     if not isinstance(outputs, dict):
       outputs = functional_ops._convert_to_list(outputs)  # pylint: disable=protected-access
     outputs = outfeed_queue.enqueue(outputs)
   control_outputs.append(outputs)
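
A stage returning a `tf.Operation` (for example an optimizer step) is used directly as a control output; tensor outputs are enqueued instead, with dicts keeping their keys. A hypothetical last stage taking the dict path (`compute_loss` is a placeholder):

 def last_stage(partial, labels):
   loss = compute_loss(partial, labels)  # placeholder helper
   # The returned dict is enqueued as-is, so the host dequeues by key.
   return {"loss": loss}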
Example #4
 def model(*args):
   loss = fwd_fn(*functional_ops._convert_to_list(args))  # pylint: disable=W0212
   enqueue_op = outfeed_queue.enqueue(loss)
   opt = gradient_accumulation_optimizer.GradientAccumulationOptimizerV2(
       optimizer, num_batches_to_accumulate)
   outs = list(args[:len(args) - infeed_queue.number_of_tuple_elements])
   outs.append(enqueue_op)
   outs.append(opt.minimize(loss))
   return outs
Example #5
 def pipeline(*args):
     outputs = args
     for i, stage in zip(device_mapping, stages):
         with scopes.ipu_shard(i):
             outputs = stage(
                 *functional_ops._convert_to_list(outputs))  # pylint: disable=W0212
     loss = outputs
     enqueue_op = outfeed_queue.enqueue(loss)
     opt = gradient_accumulation_optimizer.GradientAccumulationOptimizer(
         optimizer, num_batches_to_accumulate)
     outs = list(args[:len(args) -
                      infeed_queue.number_of_tuple_elements])
     outs.append(enqueue_op)
     outs.append(opt.minimize(loss))
     return outs
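
A model like this is typically compiled and driven with a repeat loop so the infeed and outfeed are serviced on-device. A sketch assuming the usual IPU loop API (`iterations` is illustrative; `model` and the queues come from the example above):

 from tensorflow.python.ipu import ipu_compiler, loops, scopes

 def my_net():
     # Run `model` repeatedly; the infeed is dequeued on every iteration
     # and the returned ops become control dependencies of the loop.
     return loops.repeat(iterations, model, inputs=[],
                         infeed_queue=infeed_queue)

 with scopes.ipu_scope("/device:IPU:0"):
     run_op = ipu_compiler.compile(my_net, inputs=[])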
Example #6
    def multi_conv_wrapper(*args):
      inner_options = options if options else {}

      if not isinstance(inner_options, dict):
        raise TypeError(
            "Expected the multi_conv `options` to be a `dict`, but got %s "
            "instead." % (str(inner_options)))

      option_proto = option_flag_pb2.PoplarOptionFlags()
      for key, value in inner_options.items():
        flag = option_proto.flags.add()
        flag.option = key
        flag.value = value

      def func_wrapper(*args):
        with op_util.gradient_override_scope(training=False):
          return inner_func(*args)

      args = functional_ops._convert_to_list(args)  # pylint: disable=protected-access
      with ops.name_scope("multi_conv") as scope:
        func_graph, captured_args = functional_ops._compile_function(  # pylint: disable=protected-access
            func_wrapper,
            args,
            scope, [],
            allow_external_captures=True)

        with ops.control_dependencies(list(func_graph.control_captures)):
          outputs = gen_functional_ops.multi_conv(
              captured_args,
              to_apply=util.create_new_tf_function(func_graph),
              Tout=func_graph.output_types,
              output_shapes=func_graph.output_shapes,
              option_flags=json_format.MessageToJson(option_proto))

      return func_graph_module.pack_sequence_as(func_graph.structured_outputs,
                                                outputs)
Example #7
def _pipeline_stage(func,
                    stage_id,
                    device_id,
                    args,
                    training,
                    infeed_queue=None,
                    outfeed_queue=None,
                    name=None):
  """Internal function for compiling a pipeline stage. This should not be called
  directly and doing so will result in undefined behaviour.

  Creates a pipeline stage.

  Args:
    func: function which will be executed as a stage.
    stage_id: Stage number.
    device_id: IPU the stage will be mapped to.
    args: arguments to the function.
    infeed_queue: optional IPUInfeedQueue, if passed, it is dequeued as part of
      this function.
    outfeed_queue: optional IPUOutfeedQueue, if passed, it is enqueued as part
      of this function.
    name: name of this pipeline stage.

  Returns:
    The values after executing func(args), or the control dependency if
    outfeed_queue is not None.

  """
  name = name if name else "pipeline_stage"
  args = functional_ops._convert_to_list(args)  # pylint: disable=protected-access

  func_to_compile = func
  control_outputs = []
  # If we have an infeed, then we wrap the function in another function which
  # dequeues the infeed.
  if infeed_queue:

    def infeed_func_wrapper(*args):
      args = functional_ops._convert_to_list(args)  # pylint: disable=protected-access
      dequeue_ops = functional_ops._convert_to_list(infeed_queue._dequeue())  # pylint: disable=protected-access
      # Deal with the dequeue depending on whether it's a list or dict.
      if len(dequeue_ops) == 1 and isinstance(dequeue_ops[0], dict):
        kwargs = dequeue_ops[0]
        return func(*args, **kwargs)
      return func(*(args + dequeue_ops))

    func_to_compile = infeed_func_wrapper

  # If we have an outfeed, then we wrap the function in another function which
  # enqueues the outfeed.
  if outfeed_queue:
    func = func_to_compile

    def outfeed_func_wrapper(*args, **kwargs):
      outputs = func(*args, **kwargs)
      # Check if there are output tensors - if there are then enqueue them.
      if not isinstance(outputs, ops.Operation):
        if not isinstance(outputs, dict):
          outputs = functional_ops._convert_to_list(outputs)  # pylint: disable=protected-access
        outputs = outfeed_queue.enqueue(outputs)
      control_outputs.append(outputs)

    func_to_compile = outfeed_func_wrapper

  def gradient_override_wrapper(*args, **kwargs):
    with op_util.gradient_override_scope(training):
      return func_to_compile(*args, **kwargs)

  with ops.name_scope(name) as scope:
    # pylint: disable=protected-access
    try:
      func_graph, captured_args = functional_ops._compile_function(
          gradient_override_wrapper, args, scope, control_outputs)
    except functional_ops._InvalidCaptureException as e:
      raise ValueError(
          "Trying to capture the tensor %s which is not a resource. This tensor"
          " needs to be passed as either part of the `input` or `infeed_queue`"
          " of the pipeline." % (str(e)))
    # pylint: enable=protected-access

    # Create the pipeline stage and lower the function into XLA.
    with ops.control_dependencies(list(func_graph.control_captures)):
      with scopes.ipu_shard(device_id):
        outputs = gen_functional_ops.pipeline_stage(
            captured_args,
            to_apply=util.create_new_tf_function(func_graph),
            Tout=func_graph.output_types,
            output_shapes=func_graph.output_shapes,
            stage_id=stage_id)
    if isinstance(outputs, ops.Operation):
      return outputs
    return func_graph_module.pack_sequence_as(func_graph.structured_outputs,
                                              outputs)
Example #8
  def _pipeline(*args):
    outputs = args
    for stage_id, stage in enumerate(computational_stages):
      stage_infeed_queue = infeed_queue if stage_id == 0 else None
      if stage_id == len(computational_stages) - 1 and not optimizer_function:
        stage_outfeed_queue = outfeed_queue
      else:
        stage_outfeed_queue = None

      stage_name = name + "_stage_" + str(stage_id)
      outputs = _pipeline_stage(stage,
                                stage_id,
                                device_mapping[stage_id],
                                outputs,
                                training=optimizer_function is not None,
                                infeed_queue=stage_infeed_queue,
                                outfeed_queue=stage_outfeed_queue,
                                name=stage_name)

    if optimizer_function:
      outputs = functional_ops._convert_to_list(outputs)  # pylint: disable=protected-access

      # Get the output from the optimizer function
      opt_fn = optimizer_function(*outputs)
      loss = opt_fn.loss
      opt = opt_fn.opt

      # Enqueue loss or any output tensors to the outfeed.
      if outfeed_loss:
        if not outfeed_queue:
          raise ValueError(
              "An outfeed_queue must be provided when outfeed_loss is True")
        control_outputs.append(outfeed_queue.enqueue(opt_fn.loss))
      elif outputs:
        if not outfeed_queue:
          raise ValueError(
              "The last computational stage has tensor outputs: %s, but no"
              " outfeed_queue has been provided." %
              (', '.join(str(t) for t in outputs)))
        control_outputs.append(outfeed_queue.enqueue(outputs))

      # Call the compute gradients function - this will be automatically put
      # into pipeline stages.
      grads_and_vars = opt.compute_gradients(loss)
      # Insert gradient accumulation ops.
      accumulated_grads_and_vars = []
      for grad, var in grads_and_vars:
        if grad is not None:
          with ops.colocate_with(grad):
            # Create an accumulator - variable is used as reference for shape/layout.
            accumulator = gen_poputil_ops.gradient_accumulator_create(var)
            # Add the gradients to the accumulator.
            accumulator = gen_poputil_ops.gradient_accumulator_add(
                accumulator, grad)
            # Sink the accumulators.
            grad = gen_poputil_ops.gradient_accumulator_sink(
                accumulator, num_mini_batches=gradient_accumulation_count)
        # Use the accumulated gradients.
        accumulated_grads_and_vars.append((grad, var))

      # Create an explicit function call for the apply gradients - note that we
      # allow external captures here.
      apply_grad_ops = []

      def resource_update_():
        apply_grads = opt.apply_gradients(accumulated_grads_and_vars)
        apply_grad_ops.append(apply_grads)

      with ops.name_scope(name + "/WU") as scope:
        func_graph, captured_args = functional_ops._compile_function(  # pylint: disable=protected-access
            resource_update_, [], scope, apply_grad_ops, True)

      # Create the pipeline resource update stage and lower the function into XLA.
      with ops.control_dependencies(list(func_graph.control_captures)):
        outputs = gen_functional_ops.resource_update(
            captured_args,
            to_apply=util.create_new_tf_function(func_graph),
            Tout=func_graph.output_types,
            output_shapes=func_graph.output_shapes,
            offload_weight_update_variables=offload_weight_update_variables,
            replicated_optimizer_state_sharding=
            replicated_optimizer_state_sharding,
            num_batches_to_accumulate=gradient_accumulation_count)

    if not isinstance(outputs, ops.Operation):
      if not outfeed_queue:
        raise ValueError(
            "The last computational stage has tensor outputs: %s, but no"
            " outfeed_queue has been provided." % (', '.join(
                str(t) for t in functional_ops._convert_to_list(outputs))))  # pylint: disable=protected-access

      else:
        raise ValueError(
            "Expected the pipeline resource update stage to output a "
            "tf.Operation, got %s instead." % (str(output)))

    control_outputs.append(outputs)
Example #9
def pipeline(computational_stages,
             pipeline_depth=None,
             gradient_accumulation_count=None,
             repeat_count=1,
             batch_serialization_iterations=1,
             inputs=None,
             infeed_queue=None,
             outfeed_queue=None,
             optimizer_function=None,
             device_mapping=None,
             pipeline_schedule=None,
             forward_propagation_stages_poplar_options=None,
             backward_propagation_stages_poplar_options=None,
             weight_update_poplar_options=None,
             offload_weight_update_variables=None,
             replicated_optimizer_state_sharding=False,
             offload_activations=None,
             offload_gradient_accumulation_buffers=None,
             replicated_weight_sharding=None,
             offload_weights=None,
             continuous_weight_updates=False,
             outfeed_loss=False,
             name=None):
  """
  Sets up a series of computational stages, where the outputs of one stage are
  the inputs to the next one. These stages are then executed in parallel across
  multiple IPUs. This approach can be used to split the model where layer(s)
  are executed on different IPUs.

  The first stage takes the `inputs` and the `infeed_queue` (if provided) as
  its inputs. If the `infeed_queue` is provided, it is automatically dequeued
  (similar to the ipu.loops API), so care needs to be taken to make sure the
  the signature of the first pipeline stage matches both the arguments from
  `inputs` and the `infeed_queue`, otherwise an error is thrown.

  All tensors which are used in the pipeline which are not TensorFlow
  Variables need to be explicitly passed as inputs to the pipeline. If an
  input does not change its value during the execution of the pipeline op
  (for example hyperparameters such as learning rate), it needs to be passed
  as part of `inputs`. Alternatively, if these values change during execution
  (for example the model processes different batches of data) the input should
  be passed through the `infeed_queue`
  (see :class:`~tensorflow.python.ipu.ipu_infeed_queue.IPUInfeedQueue`).
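
  For example, a learning rate can be passed through `inputs` while the
  batches flow through the `infeed_queue` (an illustrative sketch; `lr`,
  `stage1`, `stage2` and the queues are placeholders):

  .. code-block:: python

    pipeline_op = pipelining_ops.pipeline(
        computational_stages=[stage1, stage2],
        gradient_accumulation_count=8,
        inputs=[lr],                  # constant for the whole pipeline run
        infeed_queue=infeed_queue,    # one batch per stage execution
        outfeed_queue=outfeed_queue,
        name="Pipeline")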

  When training a model, an optional `optimizer_function` function can be
  provided. This function takes all the outputs from the last computational
  stage as inputs, and returns an instance of `OptimizerFunctionOutput` that
  is used to generate the backwards pass of the model using the TensorFlow
  Optimizer API. This will internally create corresponding backpropagation
  pipeline stages for each pipeline stage and colocate them such that the
  activations and weights required for the gradient calculation and
  application stay on the device in order to minimise the number of copies
  between IPUs.

  Note that the gradients, which are calculated by the `compute_gradients`
  function, will be accumulated automatically during the execution of the
  pipeline, unless `continuous_weight_updates` is enabled.

  If the last computational stage has any outputs, then an `outfeed_queue`
  (see :class:`~tensorflow.python.ipu.ipu_outfeed_queue.IPUOutfeedQueue`)
  is required and all the outputs from the last computational stage are enqueued
  to the `outfeed_queue`.

  Note that pipelining also supports recomputation; to enable it, use the
  `tensorflow.ipu.utils.set_recomputation_options()` function when configuring
  the device.
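
  For example (a minimal configuration sketch, assuming the standard
  `ipu.utils` API):

  .. code-block:: python

    from tensorflow.python.ipu import utils

    cfg = utils.create_ipu_config()
    cfg = utils.set_recomputation_options(cfg, allow_recompute=True)
    utils.configure_ipu_system(cfg)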

  For example, a simple inference network for MNIST can be split across two
  IPUs:

  .. code-block:: python

    from tensorflow import keras

    # Create the dataset
    #...

    # Create the data queues from/to IPU.
    infeed_queue = ipu_infeed_queue.IPUInfeedQueue(dataset, "infeed")
    outfeed_queue = ipu_outfeed_queue.IPUOutfeedQueue("outfeed")

    # Create a pipelined model which is split across two stages.
    def stage1(image):
      partial = keras.layers.Dense(256, activation=tf.nn.relu)(image)
      partial = keras.layers.Dense(128, activation=tf.nn.relu)(partial)
      return partial

    def stage2(partial):
      logits = keras.layers.Dense(10)(partial)
      probabilities = tf.nn.softmax(logits)
      classes = tf.argmax(input=logits, axis=1)
      return probabilities, classes

    def model():
      with variable_scope.variable_scope("vs", use_resource=True):
        pipeline_op = pipelining_ops.pipeline(
                          computational_stages=[stage1, stage2],
                          gradient_accumulation_count=250,
                          repeat_count=2,
                          inputs=[],
                          infeed_queue=infeed_queue,
                          outfeed_queue=outfeed_queue,
                          device_mapping=[3,1],
                          name="Pipeline")
      return pipeline_op

    with ops.device("/device:IPU:0"):
      compiled_model = ipu_compiler.compile(model, inputs=[])

    outfeed_op = outfeed_queue.dequeue()
    with tf.Session() as sess:
      result = sess.run(compiled_model)
      probabilities, classes = sess.run(outfeed_op)

  In this set up, the model is split across two IPUs. By default the first two
  layers would be executed on the first IPU, and the third layer and the
  probabilities and classes on the second IPU, but here `device_mapping` is
  used to override the default IPU allocation: the first two layers will
  instead be executed on the fourth IPU, and the third layer, the
  probabilities and the classes on the second IPU.

  This creates a pipeline of depth 250 (specified by the
  `gradient_accumulation_count`), which means each pipeline stage is executed
  250 times.

  This pipeline is then executed 2 times (specified by the `repeat_count`).
  The results of the pipeline (probabilities and classes) are returned to the
  host by the outfeed queue.

  We can also train this network by providing `optimizer_function`:

  .. code-block:: python

    from tensorflow import keras

    # Create the dataset
    #...

    # Create the data queues from/to IPU.
    infeed_queue = ipu_infeed_queue.IPUInfeedQueue(dataset, "infeed")
    outfeed_queue = ipu_outfeed_queue.IPUOutfeedQueue("outfeed")

    # Create a pipelined model which is split across two stages.
    def stage1(lr, images, labels):
      partial = keras.layers.Dense(256, activation=tf.nn.relu)(images)
      partial = keras.layers.Dense(128, activation=tf.nn.relu)(partial)
      return lr, partial, labels

    def stage2(lr, partial, labels):
      logits = keras.layers.Dense(10)(partial)
      cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                            labels=labels, logits=logits)
      loss = tf.reduce_mean(cross_entropy)
      return lr, loss

    def optimizer_function(lr, loss):
      optimizer = tf.train.GradientDescentOptimizer(lr)
      return pipelining_ops.OptimizerFunctionOutput(optimizer, loss)

    def model(lr):
      with variable_scope.variable_scope("vs", use_resource=True):
        pipeline_op = pipelining_ops.pipeline(
                          computational_stages=[stage1, stage2],
                          gradient_accumulation_count=128,
                          repeat_count=10,
                          inputs=[lr],
                          infeed_queue=infeed_queue,
                          outfeed_queue=outfeed_queue,
                          optimizer_function=optimizer_function,
                          name="Pipeline")
      return pipeline_op

    with ops.device('cpu'):
      lr = tf.placeholder(np.float16, [])

    with ops.device("/device:IPU:0"):
      compiled_model = ipu_compiler.compile(model, inputs=[lr])

    outfeed_op = outfeed_queue.dequeue()
    with tf.Session() as sess:
      result = sess.run(compiled_model, {lr: 0.01})
      losses = sess.run(outfeed_op)

  Here the `tf.train.GradientDescentOptimizer` generates the pipeline stages
  which calculate the gradients and apply them to the weights. Note how the
  loss is returned to the host by the outfeed queue.

  If a model requires multiple computational pipeline stages to access the same
  `tf.Variable`, then all of these computational stages need to be placed on the
  same IPU using the `device_mapping` argument.
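
  For example (hypothetical stages; `stage1` and `stage3` share a
  variable, so both are mapped to IPU 0):

  .. code-block:: python

    pipeline_op = pipelining_ops.pipeline(
        computational_stages=[stage1, stage2, stage3],
        gradient_accumulation_count=16,
        inputs=[],
        infeed_queue=infeed_queue,
        outfeed_queue=outfeed_queue,
        device_mapping=[0, 1, 0],
        name="Pipeline")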

  Note that modifying `tf.Variable` values in a pipeline stage and/or during the
  gradient calculation will result in undefined behavior. These variables can
  only be modified by the `apply_gradients` member function of the applied
  Optimizer.

  Args:
    computational_stages: a list of python functions, where each function
      represents a computational pipeline stage. The function takes the
      outputs of the previous pipeline stage as its inputs.
    gradient_accumulation_count: the number of times each pipeline stage will
      be executed.
    repeat_count: the number of times the pipeline will be executed.
    batch_serialization_iterations: number of times a loop executes to compute
      a batch on each pipeline stage execution. Currently only supported with
      the `PipelineSchedule.Sequential` schedule.
    inputs: arguments passed to the first pipeline stage.
    infeed_queue: optional IPUInfeedQueue, if passed, it is dequeued and
      passed as an input in the first pipeline stage.
    outfeed_queue: IPUOutfeedQueue, required if the last computational stage
      has any outputs. The outputs of these are enqueued to this queue and
      they can be accessed on the host.
    optimizer_function: optional Python function which takes the output of the
      last computational stage as parameters and returns an instance of
      `pipelining_ops.OptimizerFunctionOutput` in order to generate the
      back-propagation and weight-update parts of the model suitable for
      training.
    device_mapping: If provided, a list of length equal to the number of
      computational stages. An element at index `i` in the list represents which
      IPU the computational stage `computational_stages[i]` should reside on.
      This can be used to make sure computational stages which share
      `tf.Variable`s are resident on the same IPU.
    pipeline_schedule: Which scheduling algorithm to use for pipeline
      lowering. Defaults to `PipelineSchedule.Grouped`.
    forward_propagation_stages_poplar_options: If provided, a list of length
      equal to the number of computational stages. Each element is a
      PipelineStageOptions object which allows for fine grain control of the
      Poplar options for a given forward propagation computational stage.
    backward_propagation_stages_poplar_options: If provided, a list of length
      equal to the number of computational stages. Each element is a
      PipelineStageOptions object which allows for fine grained control of the
      Poplar options for a given backward propagation computational stage.
    weight_update_poplar_options: If provided, a PipelineStageOptions object
      which allows for fine grained control of the Poplar options for the
      weight update stage.
    offload_weight_update_variables: When enabled, any `tf.Variable` which is
      only used by the weight update of the pipeline (for example the
      accumulator variable when using the `tf.MomentumOptimizer`), will be
      stored in the remote memory. During the weight update this variable will
      be streamed onto the device and then streamed back to the remote memory
      after it has been updated. Requires the machine to be configured with
      support for `Poplar remote buffers`. Offloading variables into remote
      memory can reduce maximum memory liveness, but can also increase the
      computation time of the weight update.
      When set to `None` the variables will be placed in either in-processor or
      remote memory automatically based on the current best placement strategy.
      Note that this option has no effect for inference only pipelines.
    replicated_optimizer_state_sharding: If True, any `tf.Variable` which is
      offloaded (for example the accumulator variable when using the
      `tf.MomentumOptimizer`), will be partitioned across the replicas. This
      can exploit the additional bandwidth of the IPU-Links to improve overall
      throughput.
      Note that this option has no effect for inference only pipelines.
    offload_activations: When enabled, all the activations for the batches which
      are not being executed by the pipeline stages at the given time are stored
      in remote memory. Requires the machine to be configured with support for
      `Poplar remote buffers`. Offloading activations into remote memory can
      reduce maximum memory liveness, but can also increase the computation time
      as activations have to be copied from/to the device(s).
      When set to `None`, the activations might be offloaded when beneficial.
      This feature is currently only supported when the pipeline schedule is
      `PipelineSchedule.Sequential` and `batch_serialization_iterations > 1`.
    offload_gradient_accumulation_buffers: When enabled, all the gradient
      accumulation buffers are stored in remote memory. Offloading gradient
      accumulation buffers into remote memory can reduce maximum memory
      liveness, but can also increase the computation time as the buffers have
      to be copied to the device, updated and the copied off the device.
      Requires the machine to be configured with support for `Poplar remote
      buffers`.
      When set to `None`, the gradient accumulation buffers might be
      offloaded when beneficial.
      Note that this option has no effect for inference only pipelines.
    replicated_weight_sharding: When enabled and running a replicated model, any
      `tf.Variable`s used by the pipeline stage computations (excluding those
      only used by the weight update), will be partitioned across the replicas.
      Whenever a partitioned `tf.Variable` is accessed, it will first be
      all-gathered across replicas to make sure each replica has access to the
      whole `tf.Variable`. This can exploit the additional bandwidth of the
      IPU-Links to improve overall throughput.
      When set to `None`, the weights might be sharded when beneficial.
      This feature is enabled by default when the pipeline schedule is
      `PipelineSchedule.Sequential` and `batch_serialization_iterations > 1`,
      where this option can reduce the memory usage at the cost of extra
      communication.
    offload_weights: When enabled and `replicated_weight_sharding` is enabled,
      any `tf.Variable` which are partitioned across replicas will be stored in
      `Poplar remote buffers`.  Offloading variables into remote memory can
      further reduce maximum memory liveness, but can also increase the
      computation time due to extra communication. When set to `None` the
      variables will be placed in either in-processor or remote memory
      automatically based on the current best placement strategy.
    continuous_weight_updates: ** CURRENTLY UNIMPLEMENTED ** When training,
      this option will apply the gradients to the resource variables
      immediately, rather than accumulating the gradients and applying them
      at the end of each execution of the pipeline.
    outfeed_loss: If True, the loss given by the `optimizer_function` will
      be enqueued on the outfeed, instead of the outputs from the last
      computational stage.
    name: name of this pipeline.

  Returns:
    An `Operation` that executes the pipeline.

  """
  name = name if name else "pipeline"

  if pipeline_depth:
    gradient_accumulation_count = pipeline_depth

  if not gradient_accumulation_count:
    raise ValueError("gradient_accumulation_count must be specified.")

  # Ensure inputs is a list, without casting inputs to a boolean. Casting
  # a tf.Tensor to a boolean will be interpreted as an operation in the
  # graph by Autograph.
  inputs = inputs if inputs is not None else []
  inputs = functional_ops._convert_to_list(inputs)  # pylint: disable=protected-access
  inputs = ops.convert_n_to_tensor(inputs)

  if continuous_weight_updates:
    raise NotImplementedError(
        "Continuous weight updates are currently not supported.")

  for i, input in enumerate(inputs):
    if input.dtype == dtypes.resource:
      logging.warn("Passing tensor {} by value.".format(str(input)))
      inputs[i] = input.value()

  if pipeline_schedule is None:
    pipeline_schedule = (PipelineSchedule.Sequential
                         if batch_serialization_iterations > 1 else
                         PipelineSchedule.Grouped)
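
  # Note: batch serialization requires all the stages to be mapped to a
  # single IPU (validated below), so `Sequential` is the only sensible
  # default in that case.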

  if not isinstance(pipeline_schedule, PipelineSchedule):
    raise TypeError("The given pipeline_schedule is not a member of the "
                    "PipelineSchedule enumeration.")

  if (batch_serialization_iterations > 1
      and pipeline_schedule != PipelineSchedule.Sequential):
    raise NotImplementedError("Batch serialization is only supported with the "
                              "`Sequential` schedule.")

  if offload_activations and (
      batch_serialization_iterations < 2
      or pipeline_schedule != PipelineSchedule.Sequential):
    raise NotImplementedError("Activation offloading is only supported with "
                              "the `Sequential` schedule and when "
                              "`batch_serialization_iterations > 1`.")

  if device_mapping is None:
    if batch_serialization_iterations > 1:
      device_mapping = [0] * len(computational_stages)
    else:
      device_mapping = list(range(len(computational_stages)))

  if not isinstance(computational_stages, (list, tuple)):
    raise TypeError(
        "computational_stages argument needs to be a list or a tuple.")

  if infeed_queue:
    if not isinstance(infeed_queue, ipu_infeed_queue.IPUInfeedQueue):
      raise TypeError("infeed_queue is not an instance of "
                      "ipu_infeed_queue.IPUInfeedQueue")

  if outfeed_queue:
    if not isinstance(outfeed_queue, ipu_outfeed_queue.IPUOutfeedQueue):
      raise TypeError("outfeed_queue is not an instance of "
                      "ipu_outfeed_queue.IPUOutfeedQueue")

  # We expect at least two stages.
  if len(computational_stages) < 2:
    raise ValueError("Pipeline requires at least two computational stages.")

  if not isinstance(device_mapping, (list, tuple)):
    raise TypeError("device_mapping argument needs to be a list or a tuple.")

  if len(device_mapping) != len(computational_stages):
    raise ValueError(
        "Each stage must be mapped to an IPU: %d mappings != %d stages" %
        (len(device_mapping), len(computational_stages)))

  # TODO(T18660) interleaved schedule does not support multiple stages on the
  # same IPU during training.
  if pipeline_schedule == PipelineSchedule.Interleaved and len(
      device_mapping) != len(set(device_mapping)) and optimizer_function:
    raise NotImplementedError(
        "The pipelining schedule 'Interleaved' does not currently support "
        "multiple pipeline stages on the same device for training graphs. "
        "Please use a different pipeline schedule.")

  if (pipeline_schedule == PipelineSchedule.Sequential
      and batch_serialization_iterations > 1
      and len(set(device_mapping)) != 1):
    raise NotImplementedError(
        "When using batch serialization, all the pipeline stages need to be "
        "mapped to a single IPU.")

  def bool_to_three_state(value, default=None):
    if value is None:
      return default if default else backend_config_pb2.ThreeState.Name(
          backend_config_pb2.THREESTATE_UNDEFINED)
    elif value:
      return backend_config_pb2.ThreeState.Name(
          backend_config_pb2.THREESTATE_ON)
    return backend_config_pb2.ThreeState.Name(
        backend_config_pb2.THREESTATE_OFF)
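
  # For reference, the mapping produced by bool_to_three_state (enum names
  # from backend_config_pb2):
  #   None  -> "THREESTATE_UNDEFINED" (or the provided default)
  #   True  -> "THREESTATE_ON"
  #   False -> "THREESTATE_OFF"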

  # Convert some of the binary options into three states.
  offload_weight_update_variables = bool_to_three_state(
      offload_weight_update_variables)
  replicated_optimizer_state_sharding = bool_to_three_state(
      replicated_optimizer_state_sharding,
      default=offload_weight_update_variables)
  offload_activations = bool_to_three_state(offload_activations)
  offload_gradient_accumulation_buffers = bool_to_three_state(
      offload_gradient_accumulation_buffers)
  replicated_weight_sharding = bool_to_three_state(replicated_weight_sharding)
  offload_weights = bool_to_three_state(offload_weights,
                                        default=replicated_weight_sharding)

  # Function for setting up and validating the per stage Poplar options.
  def validate_stage_options_and_populate_proto(stages_poplar_options,
                                                proto_list, name):
    if stages_poplar_options is None:
      stages_poplar_options = [
          PipelineStageOptions() for _ in range(len(computational_stages))
      ]

    if not isinstance(stages_poplar_options, (list, tuple)):
      raise TypeError(
          "%s must be a list or a tuple of PipelineStageOptions objects." %
          (name))

    if len(stages_poplar_options) != len(computational_stages):
      raise ValueError(
          "%s must be a list or a tuple of PipelineStageOptions objects of "
          "length %d (same number as the number of computational stages) but "
          "is %d." %
          (name, len(computational_stages), len(stages_poplar_options)))

    for stage_options in stages_poplar_options:
      if not isinstance(stage_options, PipelineStageOptions):
        raise TypeError(
            "Expected all elements of %s to be of type PipelineStageOptions, "
            "but got %s instead." % (name, str(stage_options)))

    for stage_options in stages_poplar_options:
      proto_list.append(stage_options.get_proto())

  pipeline_poplar_config = pipeline_config_pb2.PipelinePoplarConfig()

  validate_stage_options_and_populate_proto(
      forward_propagation_stages_poplar_options,
      pipeline_poplar_config.forward_stages,
      "forward_propagation_stages_poplar_options")

  if optimizer_function:
    validate_stage_options_and_populate_proto(
        backward_propagation_stages_poplar_options,
        pipeline_poplar_config.backward_stages,
        "backward_propagation_stages_poplar_options")

    if weight_update_poplar_options is None:
      weight_update_poplar_options = PipelineStageOptions()

    if not isinstance(weight_update_poplar_options, PipelineStageOptions):
      raise TypeError(
          "weight_update_poplar_options to be of type PipelineStageOptions, "
          "but got %s instead." % (str(weight_update_poplar_options)))

    pipeline_poplar_config.resource_update.CopyFrom(
        weight_update_poplar_options.get_proto())

  if outfeed_loss and not optimizer_function:
    raise ValueError(
        "An optimizer_function must be provided when outfeed_loss is True")

  control_outputs = []

  def _pipeline(*args):
    outputs = args
    for stage_id, stage in enumerate(computational_stages):
      stage_infeed_queue = infeed_queue if stage_id == 0 else None
      if stage_id == len(computational_stages) - 1 and not optimizer_function:
        stage_outfeed_queue = outfeed_queue
      else:
        stage_outfeed_queue = None

      stage_name = name + "_stage_" + str(stage_id)
      outputs = _pipeline_stage(stage,
                                stage_id,
                                device_mapping[stage_id],
                                outputs,
                                training=optimizer_function is not None,
                                infeed_queue=stage_infeed_queue,
                                outfeed_queue=stage_outfeed_queue,
                                name=stage_name)

    if optimizer_function:
      outputs = functional_ops._convert_to_list(outputs)  # pylint: disable=protected-access

      # Get the output from the optimizer function
      opt_fn = optimizer_function(*outputs)
      loss = opt_fn.loss
      opt = opt_fn.opt

      # Enqueue loss or any output tensors to the outfeed.
      if outfeed_loss:
        if not outfeed_queue:
          raise ValueError(
              "An outfeed_queue must be provided when outfeed_loss is True")
        control_outputs.append(outfeed_queue.enqueue(opt_fn.loss))
      elif outputs:
        if not outfeed_queue:
          raise ValueError(
              "The last computational stage has tensor outputs: %s, but no"
              " outfeed_queue has been provided." %
              (', '.join(str(t) for t in outputs)))
        control_outputs.append(outfeed_queue.enqueue(outputs))

      # Call the compute gradients function - this will be automatically put
      # into pipeline stages.
      grads_and_vars = opt.compute_gradients(loss)
      # Insert gradient accumulation ops.
      accumulated_grads_and_vars = []
      for grad, var in grads_and_vars:
        if grad is not None:
          with ops.colocate_with(grad):
            # Create an accumulator - variable is used as reference for shape/layout.
            accumulator = gen_poputil_ops.gradient_accumulator_create(var)
            # Add the gradients to the accumulator.
            accumulator = gen_poputil_ops.gradient_accumulator_add(
                accumulator, grad)
            # Sink the accumulators.
            grad = gen_poputil_ops.gradient_accumulator_sink(
                accumulator, num_mini_batches=gradient_accumulation_count)
        # Use the accumulated gradients.
        accumulated_grads_and_vars.append((grad, var))

      # Create an explicit function call for the apply gradients - note that we
      # allow external captures here.
      apply_grad_ops = []

      def resource_update_():
        apply_grads = opt.apply_gradients(accumulated_grads_and_vars)
        apply_grad_ops.append(apply_grads)

      with ops.name_scope(name + "/WU") as scope:
        func_graph, captured_args = functional_ops._compile_function(  # pylint: disable=protected-access
            resource_update_, [], scope, apply_grad_ops, True)

      # Create the pipeline resource update stage and lower the function into XLA.
      with ops.control_dependencies(list(func_graph.control_captures)):
        outputs = gen_functional_ops.resource_update(
            captured_args,
            to_apply=util.create_new_tf_function(func_graph),
            Tout=func_graph.output_types,
            output_shapes=func_graph.output_shapes,
            offload_weight_update_variables=offload_weight_update_variables,
            replicated_optimizer_state_sharding=
            replicated_optimizer_state_sharding,
            num_batches_to_accumulate=gradient_accumulation_count)

    if not isinstance(outputs, ops.Operation):
      if not outfeed_queue:
        raise ValueError(
            "The last computational stage has tensor outputs: %s, but no"
            " outfeed_queue has been provided." % (', '.join(
                str(t) for t in functional_ops._convert_to_list(outputs))))  # pylint: disable=protected-access

      else:
        raise ValueError(
            "Expected the pipeline resource update stage to output a "
            "tf.Operation, got %s instead." % (str(output)))

    control_outputs.append(outputs)

  with ops.name_scope(name) as scope:
    # pylint: disable=protected-access
    try:
      func_graph, captured_args = functional_ops._compile_function(
          _pipeline, inputs, scope, control_outputs)
    except functional_ops._InvalidCaptureException as e:
      raise ValueError(
          "Trying to capture the tensor %s which is not a resource. This tensor"
          " needs to be passed as either part of the `input` or `infeed_queue`"
          " of the pipeline." % (str(e)))
    # pylint: enable=protected-access

    # Create the pipeline and lower the function into XLA.
    with ops.control_dependencies(list(func_graph.control_captures)):
      output = gen_functional_ops.pipeline(
          captured_args,
          to_apply=util.create_new_tf_function(func_graph),
          Tout=func_graph.output_types,
          output_shapes=func_graph.output_shapes,
          gradient_accumulation_count=gradient_accumulation_count,
          batch_serialization_iterations=batch_serialization_iterations,
          repeat_count=repeat_count,
          schedule=int(pipeline_schedule),
          pipeline_poplar_config=json_format.MessageToJson(
              pipeline_poplar_config),
          offload_activations=offload_activations,
          offload_gradient_accumulation_buffers=
          offload_gradient_accumulation_buffers,
          replicated_weight_sharding=replicated_weight_sharding,
          offload_weights=offload_weights)
    if not isinstance(output, ops.Operation):
      raise ValueError(
          "Expected the pipeline to output a tf.Operation, got %s instead." %
          (str(output)))

    return output