Example #1
def _train_on_tpu_system(model_fn_wrapper, dequeue_fn):
  """Executes `model_fn_wrapper` multiple times on all TPU shards."""
  config = model_fn_wrapper.config.tpu_config
  iterations_per_loop = config.iterations_per_loop
  num_shards = config.num_shards

  single_tpu_train_step = model_fn_wrapper.convert_to_single_tpu_train_step(
      dequeue_fn)

  multi_tpu_train_steps_on_single_shard = (lambda: training_loop.repeat(  # pylint: disable=g-long-lambda
      iterations_per_loop, single_tpu_train_step, [_INITIAL_LOSS], name='loop'))

  (loss,) = tpu.shard(multi_tpu_train_steps_on_single_shard,
                      inputs=[],
                      num_shards=num_shards,
                      outputs_from_all_shards=False)
  return loss
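All of the snippets in this listing build on the same call, training_loop.repeat(n, body, inputs, ...), which runs body n times on the TPU and threads loop-carried values (here the loss, seeded with _INITIAL_LOSS) from one iteration to the next, returning the values from the final iteration. As a rough illustration of that contract only (plain Python, not TensorFlow's implementation, with a hypothetical fake_train_step standing in for the real step function):

def repeat_sketch(n, body, inputs=None):
  """Pure-Python sketch: call `body` n times, threading loop-carried values."""
  values = list(inputs or [])
  for _ in range(n):
    result = body(*values)
    # A body may return a single value or a list/tuple of loop-carried values.
    values = list(result) if isinstance(result, (list, tuple)) else [result]
  return values

def fake_train_step(loss):
  # Stand-in for one training step that returns an updated loss.
  return [loss * 0.5]

(final_loss,) = repeat_sketch(10, fake_train_step, [1e7])
print(final_loss)  # loss after 10 sketched steps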
Example #3
 def iterate_on_tpu():
     return training_loop.repeat(iterations, run_fn,
                                 initial_loop_values)
Example #4
 def iterate_on_tpu():
   return training_loop.repeat(iterations, run_fn, initial_loop_values)
Example #5
 def loop():
     return training_loop.repeat(5, training_step, infeed_queue=infeed)
Example #6
  def _experimental_run_steps_on_iterator(
      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
    output_shapes = multi_worker_iterator.output_shapes
    shapes = nest.flatten(output_shapes)
    if any(not s.is_fully_defined() for s in shapes):
      raise ValueError(
          "TPU currently requires fully defined shapes. Either use "
          "set_shape() on the input tensors or use "
          "dataset.batch(..., drop_remainder=True).")

    # Wrap `fn` for repeat.
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)
    ctx = input_lib.MultiStepContext()

    def run_fn(inputs):
      """Single step on the TPU device."""
      fn_result = fn(ctx, inputs)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      if flat_last_step_outputs:
        with ops.control_dependencies([fn_result]):
          return [array_ops.identity(f) for f in flat_last_step_outputs]
      else:
        return fn_result

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop and TPU replicate context. This is useful in cases
    # where we might need to exit these contexts and get back to the outer
    # context to do something, e.g. create an op that should be evaluated
    # only once at the end of the loop on the host. One such usage is
    # creating metrics' value ops.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    def rewrite_fn(*args):
      """The rewritten step fn running on TPU."""
      del args

      per_replica_inputs = multi_worker_iterator.get_next()
      replicate_inputs = []
      for replica_id in range(self._num_replicas_in_sync):
        select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
        replicate_inputs.append((nest.map_structure(
            select_replica, per_replica_inputs),))

      replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

      # If run_fn has tensor outputs, tpu.replicate returns a list of lists,
      # which we flatten here. If run_fn has no tensor outputs, tpu.replicate
      # returns a list of no_ops, and we keep the output as is.
      if isinstance(replicate_outputs[0], list):
        replicate_outputs = nest.flatten(replicate_outputs)

      return replicate_outputs

    # TODO(sourabhbajaj): The input to while loop should be based on the
    # output type of the step_fn
    assert isinstance(initial_loop_values, list)
    initial_loop_values = initial_loop_values * self._num_replicas_in_sync

    # Put the while loop op on host 0.
    with ops.device(self.get_host_cpu_device(0)):
      replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
                                               initial_loop_values)

    del self._outer_control_flow_context
    ctx.run_op = control_flow_ops.group(replicate_outputs)

    if isinstance(replicate_outputs, list):
      # Filter out any ops from the outputs; typically ops appear here when
      # there were no tensor outputs.
      last_step_tensor_outputs = [
          x for x in replicate_outputs if not isinstance(x, ops.Operation)
      ]

      # Outputs are currently of the structure (flattened)
      # [output0_device0, output1_device0, output2_device0,
      #  output0_device1, output1_device1, output2_device1,
      #  ...]
      # Convert this to the following structure instead: (grouped by output)
      # [[output0_device0, output0_device1],
      #  [output1_device0, output1_device1],
      #  [output2_device0, output2_device1]]
      output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
      last_step_tensor_outputs = [
          last_step_tensor_outputs[i::output_num] for i in range(output_num)
      ]
    else:
      # no tensors returned.
      last_step_tensor_outputs = []

    _set_last_step_outputs(ctx, last_step_tensor_outputs)
    return ctx
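The regrouping near the end of Example #6 relies on Python slice strides: the flattened outputs arrive replica by replica, and last_step_tensor_outputs[i::output_num] collects the i-th output from every replica. A small standalone check of just that slicing (plain Python, with placeholder strings standing in for tensors):

flat = ['out0_dev0', 'out1_dev0', 'out2_dev0',
        'out0_dev1', 'out1_dev1', 'out2_dev1']
num_replicas = 2
output_num = len(flat) // num_replicas  # 3 outputs per replica
grouped = [flat[i::output_num] for i in range(output_num)]
assert grouped == [['out0_dev0', 'out0_dev1'],
                   ['out1_dev0', 'out1_dev1'],
                   ['out2_dev0', 'out2_dev1']]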
Example #7
    def _experimental_run_steps_on_iterator(self,
                                            fn,
                                            multi_worker_iterator,
                                            iterations,
                                            initial_loop_values=None):
        output_shapes = multi_worker_iterator.output_shapes
        shapes = nest.flatten(output_shapes)
        if any(not s.is_fully_defined() for s in shapes):
            raise ValueError(
                "TPU currently requires fully defined shapes. Either use "
                "set_shape() on the input tensors or use "
                "dataset.batch(..., drop_remainder=True).")
        types = nest.flatten(multi_worker_iterator.output_types)

        enqueue_ops = [
            self._get_enqueue_op_per_host(host_id, multi_worker_iterator,
                                          shapes, iterations)
            for host_id in range(self.num_hosts)
        ]

        def dequeue_fn():
            dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types,
                                                    shapes=shapes)
            return nest.pack_sequence_as(output_shapes, dequeued)

        # Wrap `fn` for repeat.
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)
        ctx = input_lib.MultiStepContext()

        def run_fn(*args, **kwargs):
            """Single step on the TPU device."""
            del args, kwargs
            fn_result = fn(ctx, dequeue_fn())
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            if flat_last_step_outputs:
                with ops.control_dependencies([fn_result]):
                    return [
                        array_ops.identity(f) for f in flat_last_step_outputs
                    ]
            else:
                return fn_result

        def iterate_on_tpu():
            return training_loop.repeat(iterations, run_fn,
                                        initial_loop_values)

        # We capture the control_flow_context at this point, before we run `fn`
        # inside a while_loop and TPU replicate context. This is useful in cases
        # where we might need to exit these contexts and get back to the outer
        # context to do something, e.g. create an op that should be evaluated
        # only once at the end of the loop on the host. One such usage is
        # creating metrics' value ops.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        # pylint: disable=protected-access
        if self._container_strategy()._disable_training_loop_on_host:
            replicate_inputs = [[]] * self._num_replicas_in_sync
            replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
        else:

            def rewrite_fn(*args):
                """The rewritten step fn running on TPU."""
                del args
                replicate_inputs = [[]] * self._num_replicas_in_sync
                replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

                # If run_fn has tensor outputs, tpu.replicate returns a list
                # of lists, which we flatten here. If run_fn has no tensor
                # outputs, tpu.replicate returns a list of no_ops, and we keep
                # the output as is.
                if isinstance(replicate_outputs[0], list):
                    replicate_outputs = nest.flatten(replicate_outputs)

                return replicate_outputs

            # TODO(sourabhbajaj): The input to while loop should be based on the
            # output type of the step_fn
            assert isinstance(initial_loop_values, list)
            initial_loop_values = initial_loop_values * self._num_replicas_in_sync

            # Put the while loop op on host 0.
            with ops.device(self.get_host_cpu_device(0)):
                replicate_outputs = training_loop.repeat(
                    iterations, rewrite_fn, initial_loop_values)

        del self._outer_control_flow_context
        ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

        if self._container_strategy()._disable_training_loop_on_host:
            # Filter out any ops from the outputs; typically ops appear here
            # when there were no tensor outputs.
            last_step_tensor_outputs = [
                x for x in replicate_outputs
                if not isinstance(x, ops.Operation)
            ]

            # Outputs are currently of the structure (grouped by device)
            # [[output0_device0, output1_device0, output2_device0],
            #  [output0_device1, output1_device1, output2_device1]]
            # Convert this to the following structure instead: (grouped by output)
            # [[output0_device0, output0_device1],
            #  [output1_device0, output1_device1],
            #  [output2_device0, output2_device1]]
            last_step_tensor_outputs = [
                list(x) for x in zip(*last_step_tensor_outputs)
            ]
        else:
            if isinstance(replicate_outputs, list):
                # Filter out any ops from the outputs; typically ops appear
                # here when there were no tensor outputs.
                last_step_tensor_outputs = [
                    x for x in replicate_outputs
                    if not isinstance(x, ops.Operation)
                ]

                # Outputs are currently of the structure (flattened)
                # [output0_device0, output1_device0, output2_device0,
                #  output0_device1, output1_device1, output2_device1,
                #  ...]
                # Convert this to the following structure instead: (grouped by output)
                # [[output0_device0, output0_device1],
                #  [output1_device0, output1_device1],
                #  [output2_device0, output2_device1]]
                output_num = len(
                    last_step_tensor_outputs) // self._num_replicas_in_sync
                last_step_tensor_outputs = [
                    last_step_tensor_outputs[i::output_num]
                    for i in range(output_num)
                ]
            else:
                # no tensors returned.
                last_step_tensor_outputs = []

        # Convert replicate_outputs to the original dict structure of
        # last_step_outputs.
        last_step_tensor_outputs_dict = nest.pack_sequence_as(
            ctx.last_step_outputs, last_step_tensor_outputs)

        for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
            output = last_step_tensor_outputs_dict[name]
            # For outputs that have already been reduced, take the first value
            # from the list, as each value should be the same. Otherwise, return
            # the full list of values.
            # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
            # value.
            if reduce_op is not None:
                # TODO(priyag): Should this return the element or a list with 1 element
                last_step_tensor_outputs_dict[name] = output[0]
        ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

        return ctx
Example #8
 def loop():
   return training_loop.repeat(5, training_step, infeed_queue=infeed)
Example #9
 def train_shard():
     return training_loop.repeat(
         run_config.tpu_config.iterations_per_loop,
         train_step,
         [1e7],  # initial_loss
         name='loop')
Example #10
  def _experimental_run_steps_on_iterator(
      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
    output_shapes = multi_worker_iterator.output_shapes
    shapes = nest.flatten(output_shapes)
    if any(not s.is_fully_defined() for s in shapes):
      raise ValueError(
          "TPU currently requires fully defined shapes. Either use "
          "set_shape() on the input tensors or use "
          "dataset.batch(..., drop_remainder=True).")
    types = nest.flatten(multi_worker_iterator.output_types)

    enqueue_ops = [
        self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes,
                                      iterations)
        for host_id in range(self.num_hosts)]

    def dequeue_fn():
      dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
      return nest.pack_sequence_as(output_shapes, dequeued)

    # Wrap `fn` for repeat.
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)
    ctx = input_lib.MultiStepContext()

    def run_fn(*args, **kwargs):
      """Single step on the TPU device."""
      del args, kwargs
      fn_result = fn(ctx, dequeue_fn())
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      if flat_last_step_outputs:
        with ops.control_dependencies([fn_result]):
          return [array_ops.identity(f) for f in flat_last_step_outputs]
      else:
        return fn_result

    def iterate_on_tpu():
      return training_loop.repeat(iterations, run_fn, initial_loop_values)

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop and TPU replicate context. This is useful in cases
    # where we might need to exit these contexts and get back to the outer
    # context to do something, e.g. create an op that should be evaluated
    # only once at the end of the loop on the host. One such usage is
    # creating metrics' value ops.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    # pylint: disable=protected-access
    if self._container_strategy()._disable_training_loop_on_host:
      replicate_inputs = [[]] * self._num_replicas_in_sync
      replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
    else:
      def rewrite_fn(*args):
        """The rewritten step fn running on TPU."""
        del args
        replicate_inputs = [[]] * self._num_replicas_in_sync
        replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

        # If run_fn has tensor outputs, tpu.replicate returns a list of lists,
        # which we flatten here. If run_fn has no tensor outputs, tpu.replicate
        # returns a list of no_ops, and we keep the output as is.
        if isinstance(replicate_outputs[0], list):
          replicate_outputs = nest.flatten(replicate_outputs)

        return replicate_outputs

      # TODO(sourabhbajaj): The input to while loop should be based on the
      # output type of the step_fn
      assert isinstance(initial_loop_values, list)
      initial_loop_values = initial_loop_values * self._num_replicas_in_sync

      # Put the while loop op on host 0.
      with ops.device(self.get_host_cpu_device(0)):
        replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
                                                 initial_loop_values)

    del self._outer_control_flow_context
    ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

    if self._container_strategy()._disable_training_loop_on_host:
      # Filter out any ops from the outputs; typically ops appear here when
      # there were no tensor outputs.
      last_step_tensor_outputs = [x for x in replicate_outputs
                                  if not isinstance(x, ops.Operation)]

      # Outputs are currently of the structure (grouped by device)
      # [[output0_device0, output1_device0, output2_device0],
      #  [output0_device1, output1_device1, output2_device1]]
      # Convert this to the following structure instead: (grouped by output)
      # [[output0_device0, output0_device1],
      #  [output1_device0, output1_device1],
      #  [output2_device0, output2_device1]]
      last_step_tensor_outputs = [list(x) for x in
                                  zip(*last_step_tensor_outputs)]
    else:
      if isinstance(replicate_outputs, list):
        # Filter out any ops from the outputs; typically ops appear here when
        # there were no tensor outputs.
        last_step_tensor_outputs = [
            x for x in replicate_outputs if not isinstance(x, ops.Operation)
        ]

        # Outputs are currently of the structure (flattened)
        # [output0_device0, output1_device0, output2_device0,
        #  output0_device1, output1_device1, output2_device1,
        #  ...]
        # Convert this to the following structure instead: (grouped by output)
        # [[output0_device0, output0_device1],
        #  [output1_device0, output1_device1],
        #  [output2_device0, output2_device1]]
        output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
        last_step_tensor_outputs = [
            last_step_tensor_outputs[i::output_num] for i in range(output_num)
        ]
      else:
        # no tensors returned.
        last_step_tensor_outputs = []

    # Convert replicate_outputs to the original dict structure of
    # last_step_outputs.
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been reduced, take the first value
      # from the list, as each value should be the same. Otherwise, return
      # the full list of values.
      # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
      # value.
      if reduce_op is not None:
        # TODO(priyag): Should this return the element or a list with 1 element
        last_step_tensor_outputs_dict[name] = output[0]
    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

    return ctx
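Example #10's _disable_training_loop_on_host branch receives the outputs already grouped by device and uses zip(*...) rather than strided slicing to regroup them by output. The same transpose, checked standalone with placeholder strings in place of tensors:

per_device = [['out0_dev0', 'out1_dev0', 'out2_dev0'],
              ['out0_dev1', 'out1_dev1', 'out2_dev1']]
per_output = [list(x) for x in zip(*per_device)]
assert per_output == [['out0_dev0', 'out0_dev1'],
                      ['out1_dev0', 'out1_dev1'],
                      ['out2_dev0', 'out2_dev1']]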
Example #11
 def train_shard():
   return training_loop.repeat(run_config.tpu_config.iterations_per_loop,
                               train_step,
                               [1e7],  # initial_loss
                               name='loop')