Example #1
 def decode_loop_fn():
     if not self.num_batches:
         infinite_repeat(decode_fn, infeed)
     else:
         training_loop.repeat(self.num_batches,
                              decode_fn,
                              infeed_queue=infeed)
Example #2
 def TpuTrain():
     loop_result = tpu_training_loop.repeat(
         self._train_steps_per_loop,
         TpuTrainStep,
         inputs=[],
         name='train_loop')
     return loop_result
Example #3
 def TpuEval():
     loop_result = tpu_training_loop.repeat(
         self._steps_per_loop,
         TpuEvalStep,
         inputs=self._eval_metrics.initial_values,
         name='eval_loop')
     # Final metrics are the avg across self._steps_per_loop steps.
     return self._eval_metrics.FinalizeMetrics(loop_result)
Example #4
 def eval_loop(self):
     per_replica_eval_batch_size = self.eval_batch_size // self.num_replicas
     tf.get_variable_scope().reuse_variables()
     predictions = tf.zeros(
         [self.eval_steps, per_replica_eval_batch_size, 2])
     _, predictions = training_loop.repeat(int(self.eval_steps),
                                           self.eval_step,
                                           [tf.constant(0), predictions])
     with tf.control_dependencies(
         [tpu_ops.outfeed_enqueue_tuple([predictions])]):
         return tf.no_op()
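The loop above does not return its predictions; it enqueues them onto the TPU outfeed, and the host dequeues them separately. A minimal, hedged sketch of that host-side counterpart (the dtypes and shape are assumptions standing in for [eval_steps, per_replica_eval_batch_size, 2], not taken from the original code):

 import tensorflow as tf  # TF 1.x graph mode assumed
 from tensorflow.python.tpu.ops import tpu_ops

 # Host-side dequeue of the tuple enqueued by outfeed_enqueue_tuple above.
 # The dtypes/shapes must match the enqueued tensors exactly.
 predictions_host, = tpu_ops.outfeed_dequeue_tuple(
     dtypes=[tf.float32],
     shapes=[[100, 8, 2]],  # placeholder for [eval_steps, batch, 2]
     device_ordinal=0)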
Example #5
 def DecodeLoopFn():
     return tpu_training_loop.repeat(self._steps_per_loop,
                                     _DecodeStep,
                                     inputs=[])
Example #6
 def loop():
     return training_loop.repeat(5,
                                 training_step,
                                 infeed_queue=infeed)
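Several of these examples pass an infeed_queue so that the loop body receives its per-step inputs from TPU infeed. A minimal, hedged sketch of constructing such a queue (types and shapes are hypothetical; TF 1.x and tensorflow.python.tpu are assumed):

 import tensorflow as tf  # TF 1.x graph mode assumed
 from tensorflow.python.tpu import tpu_feed

 # The dequeued tensors are passed as extra arguments to the loop body on
 # every iteration of training_loop.repeat(..., infeed_queue=infeed).
 infeed = tpu_feed.InfeedQueue(
     tuple_types=[tf.float32, tf.int32],
     tuple_shapes=[[128, 224, 224, 3], [128]])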
Example #7
  def _experimental_run_steps_on_iterator(
      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
    output_shapes = multi_worker_iterator.output_shapes
    shapes = nest.flatten(output_shapes)
    if any(not s.is_fully_defined() for s in shapes):
      raise ValueError(
          "TPU currently requires fully defined shapes. Either use "
          "set_shape() on the input tensors or use "
          "dataset.batch(..., drop_remainder=True).")

    # Wrap `fn` for repeat.
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)
    ctx = input_lib.MultiStepContext()

    def run_fn(inputs):
      """Single step on the TPU device."""
      fn_result = fn(ctx, inputs)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      if flat_last_step_outputs:
        with ops.control_dependencies([fn_result]):
          return [array_ops.identity(f) for f in flat_last_step_outputs]
      else:
        return fn_result

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop and TPU replicate context. This is useful in cases
    # where we might need to exit these contexts and get back to the outer
    # context to do some things, e.g., create an op which should be
    # evaluated only once at the end of the loop on the host. One such usage
    # is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    def rewrite_fn(*args):
      """The rewritten step fn running on TPU."""
      del args

      per_replica_inputs = multi_worker_iterator.get_next()
      replicate_inputs = []
      for replica_id in range(self._num_replicas_in_sync):
        select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
        replicate_inputs.append((nest.map_structure(
            select_replica, per_replica_inputs),))

      replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

      # If run_fn has tensor outputs, tpu.replicate returns a list of lists. We
      # will flatten it in this case. If run_fn has no tensor outputs,
      # tpu.replicate returns a list of no_ops, we will keep the output as it
      # is.
      if isinstance(replicate_outputs[0], list):
        replicate_outputs = nest.flatten(replicate_outputs)

      return replicate_outputs

    # TODO(sourabhbajaj): The input to while loop should be based on the
    # output type of the step_fn
    assert isinstance(initial_loop_values, list)
    initial_loop_values = initial_loop_values * self._num_replicas_in_sync

    # Put the while loop op on TPU host 0.
    with ops.device(self._host_device):
      if self.steps_per_run == 1:
        replicate_outputs = rewrite_fn()
      else:
        replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
                                                 initial_loop_values)

    del self._outer_control_flow_context
    ctx.run_op = control_flow_ops.group(replicate_outputs)

    if isinstance(replicate_outputs, list):
      # Filter out any ops from the outputs; typically this would be the case
      # when there were no tensor outputs.
      last_step_tensor_outputs = [
          x for x in replicate_outputs if not isinstance(x, ops.Operation)
      ]

      # Outputs are currently of the structure (flattened)
      # [output0_device0, output1_device0, output2_device0,
      #  output0_device1, output1_device1, output2_device1,
      #  ...]
      # Convert this to the following structure instead: (grouped by output)
      # [[output0_device0, output0_device1],
      #  [output1_device0, output1_device1],
      #  [output2_device0, output2_device1]]
      output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
      last_step_tensor_outputs = [
          last_step_tensor_outputs[i::output_num] for i in range(output_num)
      ]
    else:
      # no tensors returned.
      last_step_tensor_outputs = []

    _set_last_step_outputs(ctx, last_step_tensor_outputs)
    return ctx
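The regrouping at the end of this function relies on a strided-slice trick to turn the replica-major flat list into per-output groups. A small plain-Python illustration with made-up values (2 replicas, 3 outputs per replica), not part of the original code:

 flat = ["o0_d0", "o1_d0", "o2_d0", "o0_d1", "o1_d1", "o2_d1"]
 num_replicas = 2
 output_num = len(flat) // num_replicas  # 3 outputs per replica
 grouped = [flat[i::output_num] for i in range(output_num)]
 # grouped == [["o0_d0", "o0_d1"], ["o1_d0", "o1_d1"], ["o2_d0", "o2_d1"]]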
Example #8
 def train_eval_loop():
     return training_loop.repeat(self.num_epochs_tensor,
                                 train_eval_step, [])
Example #9
 def train_eval_loop():
   return training_loop.repeat(self.hparams.max_train_epochs,
                               train_eval_step, [])
Example #10
 def train_loop():
     return training_loop.repeat(self.train_steps_tensor,
                                 tpu_train_step, [_INITIAL_LOSS])
Example #11
 def train_eval_loop():
     return training_loop.repeat(self.max_train_iterations,
                                 train_eval_step, [0])
Example #12
 def train_loop():
     return training_loop.repeat(self.iterations, tpu_train_step,
                                 [_INITIAL_LOSS])
Example #13
 def train_loop():
     return training_loop.repeat(self.iterations_per_loop, train_step,
                                 [tf.constant(0)])
Example #14
 def eval_loop(self):
     tf.get_variable_scope().reuse_variables()
     return training_loop.repeat(int(self.eval_steps), self.eval_step)
Example #15
 def eval_loop():
     return training_loop.repeat(self.eval_steps_tensor, tpu_eval_step,
                                 [])
Example #16
    def _experimental_run_steps_on_iterator(self,
                                            fn,
                                            multi_worker_iterator,
                                            iterations,
                                            initial_loop_values=None):
        # Wrap `fn` for repeat.
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)
        ctx = input_lib.MultiStepContext()

        def run_fn(inputs):
            """Single step on the TPU device."""
            fn_result = fn(ctx, inputs)
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            if flat_last_step_outputs:
                with ops.control_dependencies([fn_result]):
                    return [
                        array_ops.identity(f) for f in flat_last_step_outputs
                    ]
            else:
                return fn_result

        # We capture the control_flow_context at this point, before we run `fn`
        # inside a while_loop and TPU replicate context. This is useful in cases
        # where we might need to exit these contexts and get back to the outer
        # context to do some things, e.g., create an op which should be
        # evaluated only once at the end of the loop on the host. One such usage
        # is in creating metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        def rewrite_fn(*args):
            """The rewritten step fn running on TPU."""
            del args

            per_replica_inputs = multi_worker_iterator.get_next()
            replicate_inputs = []
            for replica_id in range(self._num_replicas_in_sync):
                select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
                replicate_inputs.append(
                    (nest.map_structure(select_replica, per_replica_inputs), ))

            replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

            # If run_fn has tensor outputs, tpu.replicate returns a list of lists. We
            # will flatten it in this case. If run_fn has no tensor outputs,
            # tpu.replicate returns a list of no_ops, we will keep the output as it
            # is.
            if isinstance(replicate_outputs[0], list):
                replicate_outputs = nest.flatten(replicate_outputs)

            return replicate_outputs

        # TODO(sourabhbajaj): The input to while loop should be based on the
        # output type of the step_fn
        assert isinstance(initial_loop_values, list)
        initial_loop_values = initial_loop_values * self._num_replicas_in_sync

        # Put the while loop op on TPU host 0.
        with ops.device(self._host_device):
            if self.steps_per_run == 1:
                replicate_outputs = rewrite_fn()
            else:
                replicate_outputs = training_loop.repeat(
                    iterations, rewrite_fn, initial_loop_values)

        del self._outer_control_flow_context
        ctx.run_op = control_flow_ops.group(replicate_outputs)

        if isinstance(replicate_outputs, list):
            # Filter out any ops from the outputs; typically this would be the case
            # when there were no tensor outputs.
            last_step_tensor_outputs = [
                x for x in replicate_outputs
                if not isinstance(x, ops.Operation)
            ]

            # Outputs are currently of the structure (flattened)
            # [output0_device0, output1_device0, output2_device0,
            #  output0_device1, output1_device1, output2_device1,
            #  ...]
            # Convert this to the following structure instead: (grouped by output)
            # [[output0_device0, output0_device1],
            #  [output1_device0, output1_device1],
            #  [output2_device0, output2_device1]]
            output_num = len(
                last_step_tensor_outputs) // self._num_replicas_in_sync
            last_step_tensor_outputs = [
                last_step_tensor_outputs[i::output_num]
                for i in range(output_num)
            ]
        else:
            # no tensors returned.
            last_step_tensor_outputs = []

        _set_last_step_outputs(ctx, last_step_tensor_outputs)
        return ctx
Example #17
 def eval_loop():
   if self.eval_steps > 0:
     return training_loop.repeat(self.eval_steps, tpu_eval_step, [])
   else:
     return tf.no_op()
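All of the examples above only define the loop body; running one also requires compiling it for the TPU and initializing the system. A minimal, hedged end-to-end sketch (TF 1.x graph mode and tensorflow.python.tpu are assumed; the step function and constants are made up):

 import tensorflow as tf  # TF 1.x graph mode assumed
 from tensorflow.python.tpu import tpu, training_loop

 def train_step(loss):
     # Hypothetical step: real code would run a forward/backward pass here.
     return loss * 0.99

 def train_loop():
     # Run train_step 100 times, threading the loss value through the loop.
     return training_loop.repeat(100, train_step, [tf.constant(1.0)])

 compiled = tpu.rewrite(train_loop)  # compile the whole loop for the TPU

 with tf.Session() as sess:
     sess.run(tpu.initialize_system())
     print(sess.run(compiled))  # final loop-carried loss after 100 steps
     sess.run(tpu.shutdown_system())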