Example #1
def get_zero_batch(batch_size=None,
                   max_len=None,
                   key_size=2,
                   return_tgt_mask=False):
    """Returns zero batch.

  Args:
    batch_size: batch size.
    max_len: max length.
    key_size: key size.
    return_tgt_mask: if to return tgt_mask.
  Returns: a tuple of tensors
    key: int32 tensor [batch_size, key_size]
    tgt_id: int32 tensor [batch_size, max_len]
    tgt_segment_id: float32 tensor [batch_size, max_len]
    tgt_segment_pos: int32 tensor [batch_size, max_len]
    tgt_labels: int32 tensor [batch_size, max_len]
    tgt_sample_temperature: float32 tensor [batch_size]
    tgt_mask: optional float32 tensor [batch_size, max_len, max_len]
  """
    batch = preload_zero(n=1,
                         batch_size=batch_size,
                         max_len=max_len,
                         key_size=key_size)
    batch = py_utils.Transform(lambda x: np.squeeze(x, 0), batch)
    if return_tgt_mask:
        tgt_mask = np.zeros([batch_size, max_len, max_len], np.float32)
        batch = (*batch, tgt_mask)
    return batch
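
A note on the `py_utils.Transform` call above: it applies the given function to every leaf of a nested structure and returns a structure of the same shape. A minimal runnable sketch of the same squeeze step, with `tf.nest.map_structure` standing in for `py_utils.Transform` and a made-up two-field batch:

import numpy as np
import tensorflow as tf

# Hypothetical stand-in for preload_zero's output: a tuple of arrays with a
# leading n=1 axis.
batch = (np.zeros([1, 4, 2], np.int32),    # key
         np.zeros([1, 4, 16], np.int32))   # tgt_id

# Drop axis 0 of every leaf, exactly what the lambda in get_zero_batch does.
batch = tf.nest.map_structure(lambda x: np.squeeze(x, 0), batch)
print([x.shape for x in batch])  # [(4, 2), (4, 16)]
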
Example #2
    def testGraphLayer(self):
        g = tf.Graph()
        with g.as_default(), self.SetEval(True):
            tf.random.set_seed(24332)

            def _FnMeta(*shapes):
                return py_utils.NestedMap(flops=1, out_shapes=shapes)

            p = layers.GraphLayer.Params().Set(
                name='graph',
                input_endpoints=['x'],
                output_endpoints=['y'],
                sub=[
                    ('x.a->y.c',
                     layers.FnLayer.Params().Set(fn=lambda x: 2 * x,
                                                 fn_meta=_FnMeta)),
                    ('x.b->y.d',
                     layers.FnLayer.Params().Set(name='bar',
                                                 fn=lambda x: x + 2,
                                                 fn_meta=_FnMeta)),
                    ('y.c,y.d->y.e, y.f',
                     layers.FnLayer.Params().Set(name='baz',
                                                 fn=lambda x, y:
                                                 (x + y, x - y),
                                                 fn_meta=_FnMeta)),
                ])
            l = p.Instantiate()
            x = py_utils.NestedMap(a=tf.constant(1.0), b=tf.constant(2.0))
            y = l.FProp(l.theta, x)
            y_shape = l.FPropMeta(
                p, py_utils.Transform(lambda t: tshape.Shape(t.shape),
                                      x)).out_shapes[0]
            self.assertDictEqual(
                py_utils.Transform(lambda t: t.shape.as_list(), y),
                py_utils.Transform(lambda t: t.ToTensorShape().as_list(),
                                   y_shape))

        with self.session(graph=g) as sess:
            sess.run(tf.global_variables_initializer())
            y_val = sess.run(y)
            print(y_val)
            self.assertEqual(py_utils.NestedMap(c=2.0, d=4.0, e=6.0, f=-2.0),
                             y_val)
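
Each `sub` entry wires named endpoints: the spec left of `->` lists the inputs a `FnLayer` reads from the graph's `NestedMap`, and the right side names the outputs it writes back. A toy interpreter for that convention (simplified, not lingvo's actual parser) reproduces the values the test expects:

def run_graph(subs, env):
    # env maps 'scope.field' -> value, e.g. {'x.a': 1.0, 'x.b': 2.0}.
    for spec, fn in subs:
        in_spec, out_spec = spec.split('->')
        args = [env[k.strip()] for k in in_spec.split(',')]
        outs = fn(*args)
        outs = outs if isinstance(outs, tuple) else (outs,)
        for k, v in zip(out_spec.split(','), outs):
            env[k.strip()] = v
    return env

env = run_graph(
    [('x.a->y.c', lambda x: 2 * x),
     ('x.b->y.d', lambda x: x + 2),
     ('y.c,y.d->y.e, y.f', lambda x, y: (x + y, x - y))],
    {'x.a': 1.0, 'x.b': 2.0})
print(env['y.c'], env['y.d'], env['y.e'], env['y.f'])  # 2.0 4.0 6.0 -2.0
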
Example #3
 def testEmptySequentialLayerFPropMeta(self):
   g = tf.Graph()
   with g.as_default():
     p = layers.SequentialLayer.Params().Set(name='seq')
     l = p.Instantiate()
     x = py_utils.NestedMap(val=tf.random.normal(shape=[2, 32]))
     y = l.FPropDefaultTheta(x)
     self.assertIsInstance(y.val, tf.Tensor)
     y_shape = l.FPropMeta(
         p, py_utils.Transform(lambda t: tshape.Shape(t.shape),
                               x)).out_shapes[0]
     self.assertEqual(y.val.shape.as_list(),
                      y_shape.val.ToTensorShape().as_list())
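
The `FPropMeta` check relies on the same leaf-wise mapping: tensors in, shapes out. A minimal sketch of that conversion, with `tf.nest.map_structure` standing in for `py_utils.Transform`:

import tensorflow as tf

x = {'val': tf.random.normal(shape=[2, 32])}
shapes = tf.nest.map_structure(lambda t: t.shape.as_list(), x)
print(shapes)  # {'val': [2, 32]}
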
Example #4
    def _InternalGetTheta(self):
        ret = py_utils.Transform(lambda x: x.theta, self.children)

        private_theta = self._private_theta

        # When ExecutorTpu specifies the EMA (e.g. when running eval/decode program
        # with EMA enabled), use the EMA version of the variables if applicable.
        if self.cluster.is_executor_tpu and self.do_eval and self.ema:
            vars_loaded_as_ema = self.params.is_inference or (
                self.do_eval and not py_utils.use_tpu())
            assert not vars_loaded_as_ema, (
                'Not able to use EMA variables since the layer variables are '
                'potentially already loaded as EMA variables.')

            def MaybeUseEmaVar(x):
                if not isinstance(x, tf.Variable):
                    raise ValueError(
                        'EMA is used but self._private_theta contains '
                        f'non-variables: {x}.')
                ema_x = self.ema.average(x)
                return ema_x if ema_x is not None else x

            private_theta = py_utils.Transform(MaybeUseEmaVar, private_theta)

        if (self._params.fprop_dtype is not None
                and self._params.fprop_dtype != self._params.dtype):

            def MaybeCastToFPropDtype(x):
                # Need to check `.base_dtype` as x.dtype may be tf.float32_ref.
                if x is not None and x.dtype.base_dtype == self._params.dtype:
                    return tf.cast(x, self._params.fprop_dtype)
                else:
                    return x

            private_theta = py_utils.Transform(MaybeCastToFPropDtype,
                                               private_theta)

        ret.update(private_theta)
        return ret
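
`MaybeUseEmaVar` swaps each raw variable for its shadow copy when one exists and falls back to the variable otherwise. A self-contained sketch of that lookup, assuming TF2 eager execution (the variables here are made up):

import tensorflow as tf

w = tf.Variable(1.0, name='w')
b = tf.Variable(0.0, name='b')
ema = tf.train.ExponentialMovingAverage(decay=0.99)
ema.apply([w])  # creates (and eagerly updates) the shadow variable for w only

def maybe_use_ema_var(x):
    ema_x = ema.average(x)         # None when x has no shadow variable
    return ema_x if ema_x is not None else x

theta = {'w': w, 'b': b}
theta = tf.nest.map_structure(maybe_use_ema_var, theta)
# theta['w'] is now the EMA shadow variable; theta['b'] is unchanged.
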
Example #5
    def _TransformVarsInternal(self, fn):
        """Internal: replaces each variable v in self._private_vars with fn(v).

    Also recursively invokes _TransformVarsInternal() on self.children.

    Args:
      fn: A function that takes a variable and returns a variable or a wrapper
        of the variable.
    """
        self._private_vars_transform_restore_stack.append(self._private_vars)
        self._private_vars = {
            key: fn(x)
            for key, x in self._private_vars.items()
        }
        py_utils.Transform(
            lambda c: c._TransformVarsInternal(fn),  # pylint: disable=protected-access
            self.children)
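
Together with `_UndoTransformVarsInternal` in Example #14, this implements a save/transform/restore stack over a layer tree's variables. A pure-Python sketch of the pattern (the `Node` class is illustrative, not lingvo's):

class Node:
    """Illustrative layer stand-in with private vars and children."""

    def __init__(self, vars_, children=()):
        self._private_vars = dict(vars_)
        self._restore_stack = []
        self.children = list(children)

    def transform_vars(self, fn):
        self._restore_stack.append(self._private_vars)   # save current dict
        self._private_vars = {k: fn(v) for k, v in self._private_vars.items()}
        for c in self.children:                          # recurse into children
            c.transform_vars(fn)

    def undo_transform_vars(self):
        self._private_vars = self._restore_stack.pop()   # restore previous dict
        for c in self.children:
            c.undo_transform_vars()

root = Node({'w': 1.0}, children=[Node({'b': 2.0})])
root.transform_vars(lambda v: v * 10)  # w -> 10.0, b -> 20.0
root.undo_transform_vars()             # back to w=1.0 and b=2.0
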
Example #6
    def CreateChildren(
        self, name: str, params: Union[List[BaseLayerParamsT],
                                       Mapping[str, BaseLayerParamsT]]
    ) -> None:
        """Create a list or dict of sub layers.

    The created sub layer list can be accessed by `name`. E.g.::

        self.CreateChildren('foo', ...)
        self.foo[10].FProp...

    or::

        self.children['foo'][10].Fprop...
        self.children.foo[10].Fprop...

    Args:
      name: The name for the sub layers, which is used as the key into
        vars/theta.
      params: a list or dict of `Hyperparams` objects to create.
    """
        if hasattr(self,
                   '_disable_create_child') and self._disable_create_child:
            raise ValueError(
                'Attempting to call CreateChildren outside of __init__.')
        self._CheckName(name)

        uid = itertools.count()

        def Instantiate(p):
            p = self.CopyBaseParams(self.params, p.Copy())
            if not p.name:
                p.name = '%s_%d' % (name, next(uid))
            return p.Instantiate()

        with self._CreateChildContext(name):
            self._private_children[name] = py_utils.Transform(
                Instantiate, params)
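
The `uid = itertools.count()` trick gives each unnamed child a distinct `name_N` suffix while leaving explicit names untouched. A small sketch of just that naming step, with a hypothetical `Params` stub:

import itertools

class Params:
    """Hypothetical stub; lingvo's Hyperparams has a richer interface."""

    def __init__(self, name=''):
        self.name = name

def assign_names(name, params_list):
    uid = itertools.count()
    for p in params_list:
        if not p.name:  # only unnamed children consume a uid
            p.name = '%s_%d' % (name, next(uid))
    return params_list

ps = assign_names('foo', [Params(), Params('bar'), Params()])
print([p.name for p in ps])  # ['foo_0', 'bar', 'foo_1']
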
Example #7
  def _CalculateOutputShapes(self, input_shapes):
    """Calcuate the output shape of intermediate layers.

    Given the FPropMeta function in each FeatureExtractionLayer, calcuates
    the shapes of outputs of that layer. This is used to recover the shape
    information in StackedRecurrent.

    Args:
      input_shapes: NestedMap or tuple of input TensorShapes.

    Returns:
      A list of K + 1 NestedMaps or lists of tshape.Shape, where K is the
      number of partitions.
    """
    p = self.params

    # Converts TensorShape to tshape.Shape.
    def _ToTShape(x):
      if x is None:
        return None
      return tshape.Shape(x.as_list())

    shapes = py_utils.Transform(_ToTShape, input_shapes)
    shapes = _ToTuple(shapes)

    state_shapes = []
    for (_, cell) in self._before_layers:
      shapes = cell.FPropMeta(cell.params, *shapes).out_shapes

    state_shapes.append(shapes[0] if p.nested_map_fprop else shapes)

    for (_, cell) in self._cells:
      shapes = cell.FPropMeta(cell.params, *shapes).out_shapes
      state_shapes.append(shapes[0] if p.nested_map_fprop else shapes)

    return state_shapes
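
The loop above threads a shape tuple through each cell's `FPropMeta` and records K + 1 snapshots. A toy version with hand-written meta functions in place of real layers:

import tensorflow as tf

# Hypothetical meta functions standing in for each cell's FPropMeta.
def dense_meta(shape):     # [B, D] -> [B, 32]
    return tf.TensorShape([shape[0], 32])

def identity_meta(shape):  # a shape-preserving cell
    return shape

shapes = tf.TensorShape([8, 16])
state_shapes = [shapes]    # K + 1 snapshots for K cells
for meta in [dense_meta, identity_meta]:
    shapes = meta(shapes)
    state_shapes.append(shapes)
print(state_shapes)
# [TensorShape([8, 16]), TensorShape([8, 32]), TensorShape([8, 32])]
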
Example #8
 def fetch_shapes(self):
   # Conversion from dict to NestedMap required.
   return py_utils.Transform(
       lambda x: self._graph.get_tensor_by_name(x).shape.as_list(),
       py_utils.NestedMap(self._fetches))
Example #9
 def _VarNamesDebugString(vars_):
     return py_utils.Transform(lambda x: x.name, vars_).DebugString()
Example #10
  def _DecodeOnce(self, sess=None, path=''):
    """Decode a single checkpoint."""
    with self._cluster:
      # Attempt to restore the checkpoint
      self._checkpointer.RestoreFromPath(checkpoint_path=path)

      global_step = self._model.global_step.numpy()
      if global_step < self._task.params.eval.start_decoder_after:
        return

      if self._task.input.params.resettable:
        tf.logging.info('Resetting input_generator.')
        self._task.input_generator.Reset()

      dec_metrics = self._task.CreateDecoderMetrics()
      if not dec_metrics:
        tf.logging.info('Empty decoder metrics')
        return
      buffered_decode_out = []
      num_samples_metric = dec_metrics['num_samples_in_batch']

      samples_per_summary = self._task.params.eval.decoder_samples_per_summary
      if samples_per_summary is None:
        samples_per_summary = self._task.params.eval.samples_per_summary
      if samples_per_summary == 0:
        assert self._task.input.params.resettable

      start_time = time.time()
      while samples_per_summary == 0 or (num_samples_metric.total_value <
                                         samples_per_summary):
        try:
          tf.logging.info('Fetching dec_output.')
          fetch_start = time.time()
          # The decoder calls FProp multiple times for each checkpoint.
          # Multiple summaries at the same step are often confusing. Instead,
          # models should generate aggregate summaries using
          # PostProcessDecodeOut. Other types of summaries (images, audio,
          # etc.) will be generated for the first batch only.
          is_first_loop = num_samples_metric.total_value == 0
          decode_fn = (
              self._decode_fn_with_summary
              if is_first_loop else self._decode_fn)
          input_batch, dec_output = decode_fn()

          for key in self._task.input_generator.GetCpuPassthroughKeys():
            if key in input_batch:
              if key in dec_output:
                tf.logging.warning(
                    f'Key {key} already present in decode output. '
                    f'Not adding from input batch.')
              else:
                dec_output[key] = input_batch[key]

          dec_output = py_utils.Transform(lambda x: x.numpy(), dec_output)

          post_process_start = time.time()
          tf.logging.info('Done fetching (%f seconds)' %
                          (post_process_start - fetch_start))
          decode_out = self._task.PostProcessDecodeOut(dec_output, dec_metrics)

          if decode_out:
            if isinstance(decode_out, dict):
              decode_out = decode_out.items()

            if is_first_loop:
              # Add summaries only for the first batch of data.
              with self._summary_writer.as_default():
                for key, value in decode_out:
                  if isinstance(value, tf.Summary):
                    tf.logging.info(f'Adding summary {key} with tags '
                                    f'{[x.tag for x in value.value]}.')
                    tf.compat.v2.summary.experimental.write_raw_pb(
                        tf.constant(value.SerializeToString()), global_step)

            buffered_decode_out.extend(
                kv for kv in decode_out if not isinstance(kv[1], tf.Summary))

          tf.logging.info(
              'Total examples done: %d/%d '
              '(%f seconds decode postprocess)', num_samples_metric.total_value,
              samples_per_summary,
              time.time() - post_process_start)

        except tf.errors.OutOfRangeError:
          if not self._task.input.params.resettable:
            raise
          break

      tf.logging.info('Done decoding ckpt: %s', path)

      elapsed_secs = time.time() - start_time
      example_rate = num_samples_metric.total_value / elapsed_secs
      msg = 'step:%6d, elapsed_secs: %0.2f, examples/sec: %0.2f' % (
          global_step, elapsed_secs, example_rate)
      with self._summary_writer.as_default():
        tf.compat.v2.summary.scalar(
            'decode_secs', elapsed_secs, step=global_step)
        tf.compat.v2.summary.scalar(
            'examples/sec', example_rate, step=global_step)
        tf.compat.v2.summary.scalar(
            'total_samples', num_samples_metric.total_value, step=global_step)
        for key, metric in sorted(dec_metrics.items()):
          msg += ' %s:%.8g' % (key, metric.value)
          tf.compat.v2.summary.scalar(key, metric.value, step=global_step)
        self._summary_writer.flush()
      self._SetStatusMessage(msg)

      self._ExportMetrics(
          # Metrics expects python int, but global_step is numpy.int64.
          decode_checkpoint=int(global_step),
          dec_metrics=dec_metrics,
          example_rate=example_rate)

      decode_out_path = self.GetDecodeOutPath(self._decoder_dir, global_step)
      decode_finalize_args = base_model.DecodeFinalizeArgs(
          decode_out_path=decode_out_path, decode_out=buffered_decode_out)
      self._task.DecodeFinalize(decode_finalize_args)
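
The `py_utils.Transform(lambda x: x.numpy(), dec_output)` step converts every fetched tensor to a host numpy array in one pass. The same pattern with `tf.nest.map_structure` and a hypothetical output dict:

import tensorflow as tf

dec_output = {'ids': tf.constant([[1, 2]]), 'scores': tf.constant([0.5])}
dec_output = tf.nest.map_structure(lambda x: x.numpy(), dec_output)
print(dec_output)
# {'ids': array([[1, 2]], dtype=int32), 'scores': array([0.5], dtype=float32)}
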
Example #11
 def subgraph_feed_shapes(self, subgraph_name):
     # Conversion from dict to NestedMap required.
     return py_utils.Transform(
         lambda x: self._graph.get_tensor_by_name(x).shape.as_list(),
         py_utils.NestedMap(self._get_subgraph_feeds(subgraph_name)))
Example #12
 def AddIdentityToTheta(layer):
   # pylint: disable=protected-access
   layer._private_theta = py_utils.Transform(tf.identity,
                                             layer._private_theta)
   # pylint: enable=protected-access
   layer.children.Transform(AddIdentityToTheta)
Example #13
 def accumulators(self):
     """Returns `.NestedMap` of `Accumulator` instances for this and children."""
     ret = py_utils.Transform(lambda x: x.accumulators, self.children)
     for k, acc in self._private_accumulators.items():
         ret[k] = acc
     return ret
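
The accessor builds one nested view by mapping the same property over all children and then layering the local entries on top. A dict-based sketch of the recursion (names are illustrative):

class Layer:
    def __init__(self, own, children=None):
        self._own = dict(own)           # this layer's own accumulators
        self.children = children or {}  # name -> child Layer

    @property
    def accumulators(self):
        ret = {n: c.accumulators for n, c in self.children.items()}
        ret.update(self._own)           # local entries are layered on top
        return ret

net = Layer({'steps': 1}, {'enc': Layer({'loss': 0.3})})
print(net.accumulators)  # {'enc': {'loss': 0.3}, 'steps': 1}
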
Example #14
 def _UndoTransformVarsInternal(self):
     """Internal. Undoes _TransformVarsInternal()."""
     self._private_vars = self._private_vars_transform_restore_stack.pop()
     py_utils.Transform(
         lambda c: c._UndoTransformVarsInternal(),  # pylint: disable=protected-access
         self.children)
Example #15
  def _EvalOnce(self, sess=None, path=''):
    """Eval a single checkpoint."""
    with self._cluster:
      # Attempt to restore the checkpoint
      self._checkpointer.RestoreFromPath(checkpoint_path=path)

      # Save any additional information to disk before evaluation.
      if self._eval_type == 'train':
        self._task.Export(path)

      global_step = self._model.global_step.numpy()
      if global_step < self._task.params.eval.start_eval_after:
        return

      if self._task.input.params.resettable:
        tf.logging.info('Resetting input_generator.')
        self._task.input_generator.Reset()

      metrics_dict = None
      num_samples_metric = None
      samples_per_summary = self._task.params.eval.samples_per_summary
      if samples_per_summary == 0:
        assert self._task.input.params.resettable
      while (samples_per_summary == 0 or metrics_dict is None or
             num_samples_metric.total_value < samples_per_summary):
        try:
          # The evaler calls FProp multiple times for each checkpoint.
          # Multiple summaries at the same step are often confusing. Instead,
          # models should update eval_metrics and generate aggregate
          # summaries. Other types of summaries (images, audio, etc.) will be
          # generated for the first batch only.
          eval_fn = (
              self._eval_fn_with_summary
              if metrics_dict is None else self._eval_fn)
          eval_metrics = eval_fn()

          if metrics_dict is None:
            metrics_dict = {
                name: metrics.AverageMetric() for name in eval_metrics
            }
            num_samples_metric = metrics_dict['num_samples_in_batch']

          eval_metrics = py_utils.Transform(lambda x: x.numpy(), eval_metrics)
          for name, (value, weight) in eval_metrics.items():
            metrics_dict[name].Update(value, weight)
          tf.logging.info('Total examples done: %d/%d',
                          num_samples_metric.total_value, samples_per_summary)
        except tf.errors.OutOfRangeError:
          if not self._task.input.params.resettable:
            raise
          break

      if metrics_dict is None:
        metrics_dict = {}

      # Replace average values with total values for certain metrics.
      if 'num_predictions' in metrics_dict:
        metrics_dict['num_predictions'].total_weight = 1.0
      if 'num_words' in metrics_dict:
        metrics_dict['num_words'].total_weight = 1.0

      msg = 'step:%6d' % global_step
      with self._summary_writer.as_default():
        tf.compat.v2.summary.scalar(
            'total_samples', num_samples_metric.total_value, step=global_step)
        for key, metric in sorted(metrics_dict.items()):
          msg += ' %s:%.8g' % (key, metric.value)
          tf.compat.v2.summary.scalar(key, metric.value, step=global_step)
        self._summary_writer.flush()
      self._SetStatusMessage(msg)
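
Each eval step yields `(value, weight)` pairs that are folded into running weighted averages. A minimal stand-in for that accumulation (this `AverageMetric` mimics, but is not, lingvo's `metrics.AverageMetric`):

class AverageMetric:
    """Minimal stand-in; lingvo's metrics.AverageMetric is more complete."""

    def __init__(self):
        self.total_value = 0.0
        self.total_weight = 0.0

    def Update(self, value, weight):
        self.total_value += value * weight
        self.total_weight += weight

    @property
    def value(self):
        return self.total_value / max(self.total_weight, 1e-8)

m = AverageMetric()
for value, weight in [(0.5, 8.0), (0.7, 8.0)]:  # two hypothetical batches
    m.Update(value, weight)
print(m.value)  # ~0.6 (weighted average over 16 examples)
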
Example #16
  def FProp(self, theta, *args):
    """Run multiple cells in different devices in a pipelining manner.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      *args: Non-keyworded variable length argument list of input tensors.

    Returns:
      A list of output tensors.
    """
    # TODO(huangyp): handle optional None inputs.
    p = self.params
    if self.do_eval:
      outputs = copy.copy(args)
      for (name, l) in self._before_layers + self._cells:
        outputs = _ToTuple(outputs)
        outputs = l.FProp(theta[name], *outputs)
      return outputs

    num_cells = len(p.cell_tpl)
    cluster = self.cluster

    # Compute shapes of input and output tensors.
    input_shapes = self._get_input_shapes(*args)
    state_dtype = self._get_state_dtype(*args)
    state_shapes = self._CalculateOutputShapes(input_shapes)
    tf.logging.info('state_shapes={}'.format(state_shapes))

    def GetCellFn(i):
      """Get the ith feature extraction layer."""

      def CellFn(theta, state0, inputs):
        """A cell fn is exectued inside of StackedRecurrent."""
        del state0

        def _FPropInputSetShape(name, t_shape):
          if t_shape is None:
            return None
          inputs[name].set_shape(t_shape.ToTensorShape().as_list())
          return inputs[name]

        if p.nested_map_fprop:
          # pylint: disable=protected-access
          fprop_inputs = state_shapes[i]._RecursiveMap(_FPropInputSetShape)
          # pylint: enable=protected-access
        else:
          fprop_inputs = []
          for input_idx, input_shape in enumerate(state_shapes[i]):
            name = 's{}'.format(input_idx)
            fprop_inputs.append(_FPropInputSetShape(name, input_shape))

        with py_utils.RemoveAssertContext(remove=True):
          with CellFnFPropOpReplacementWrapper():
            tf.logging.info('cell {} input {}'.format(i, fprop_inputs))
            mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
            SetOverWriteGlobalStep(mb_tensor)
            _, cell = self._cells[i]
            fprop_inputs = _ToTuple(fprop_inputs)
            outputs = cell.FProp(theta, *fprop_inputs)

        if p.nested_map_fprop:
          assert py_utils.IsCompatible(outputs, state_shapes[i + 1])
          state1 = outputs.Filter(lambda x: x is not None)
        else:
          state1 = py_utils.NestedMap()
          outputs = _ToTuple(outputs)
          assert len(outputs) == len(state_shapes[i + 1])
          for output_idx in range(len(outputs)):
            if outputs[output_idx] is not None:
              name = 's{}'.format(output_idx)
              state1[name] = outputs[output_idx]
        state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
        return state1, py_utils.NestedMap()

      return CellFn

    cell_fns = []
    accumulator_layers = []
    thetas = []
    init_states = []
    devices = []
    for cell_idx in range(num_cells):
      cell_name, cell = self._cells[cell_idx]
      accumulator_layers.append(cell)
      cell_fns.append(GetCellFn(cell_idx))
      thetas.append(theta[cell_name])

      def _TfZeros(t_shape):
        if t_shape is None:
          return None
        return tf.zeros(t_shape.ToTensorShape().as_list(), dtype=state_dtype)

      if p.nested_map_fprop:
        init_state = py_utils.Transform(_TfZeros, state_shapes[cell_idx + 1])
        init_state = init_state.Filter(lambda x: x is not None)
      else:
        init_state = py_utils.NestedMap()
        for output_idx, state in enumerate(state_shapes[cell_idx + 1]):
          state = _TfZeros(state)
          if state is not None:
            name = 's{}'.format(output_idx)
            init_state[name] = state
      init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype)
      init_states.append(init_state)

      devices.append(cluster.WorkerDeviceInModelSplit(cell_idx))

    cell_grads = [None] * num_cells
    cell_outs = [lambda x: x] * num_cells
    cell_out_grads = [lambda x: x] * num_cells

    with tf.device(devices[0]):
      previous = _ToTuple(args)
      for (name, l) in self._before_layers:
        previous = l.FProp(theta[name], *previous)
        previous = _ToTuple(previous)

      def _StackAndSplit(x):
        # Split tensors into microbatches.
        if x is None:
          return None
        return tf.stack(tf.split(x, p.num_micro_batches, axis=p.batch_dim))

      if p.nested_map_fprop:
        inputs = py_utils.Transform(_StackAndSplit, previous[0])
        inputs = inputs.Filter(lambda x: x is not None)
      else:
        inputs = py_utils.NestedMap()
        for output_idx, output_tensor in enumerate(previous):
          output_tensor = _StackAndSplit(output_tensor)
          if output_tensor is not None:
            name = 's{}'.format(output_idx)
            inputs[name] = output_tensor
      gs_tensor = py_utils.GetGlobalStep()
      inputs[_MICRO_BATCH_STATE_NAME] = tf.stack([
          tf.cast(gs_tensor * p.num_micro_batches + t, dtype=state_dtype)
          for t in range(p.num_micro_batches)
      ])
    tf.logging.info('pipeline input = {}'.format(inputs))
    output_state, _ = recurrent.StackedRecurrent(
        devices=devices,
        cell_fns=cell_fns,
        cell_grads=cell_grads,
        cell_outs=cell_outs,
        cell_out_grads=cell_out_grads,
        thetas=thetas,
        init_states=init_states,
        inputs=inputs,
        accumulator_layers=accumulator_layers,
        unused_acc_state=True)

    with tf.device(devices[-1]):

      def _ReshapeRetVal(name, t_shape):
        """Restore shape for tensors in microbatches."""
        if t_shape is None:
          return None
        output_tensor = output_state[name]
        if p.batch_dim != 0:
          perm = list(range(1, p.batch_dim + 1)) + [0]
          perm += list(range(p.batch_dim + 1, t_shape.rank + 1))
          output_tensor = tf.transpose(output_tensor, perm=perm)
        output_shape = t_shape.ToTensorShape().as_list()
        output_shape[p.batch_dim] *= p.num_micro_batches
        output_tensor = tf.reshape(output_tensor, output_shape)
        return output_tensor

      # Construct the final return values from output_state.
      if p.nested_map_fprop:
        # pylint: disable=protected-access
        output_tensors = state_shapes[-1]._RecursiveMap(_ReshapeRetVal)
        # pylint: enable=protected-access
      else:
        output_tensors = []
        for output_idx, state_shape in enumerate(state_shapes[-1]):
          output_name = 's{}'.format(output_idx)
          output_tensor = _ReshapeRetVal(output_name, state_shape)
          output_tensors.append(output_tensor)
        if len(output_tensors) == 1:
          output_tensors = output_tensors[0]
        else:
          output_tensors = tuple(output_tensors)
      tf.logging.info('pipeline output = {}'.format(output_tensors))
      return output_tensors
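
`_StackAndSplit` moves microbatches onto a new leading axis, and `_ReshapeRetVal` merges them back at the far end of the pipeline. A shape-only sketch of the round trip for `batch_dim=0`:

import tensorflow as tf

num_micro_batches = 4
x = tf.zeros([8, 16])  # the full batch

# Split along the batch dim, then stack: [8, 16] -> [4, 2, 16].
mb = tf.stack(tf.split(x, num_micro_batches, axis=0))
print(mb.shape)  # (4, 2, 16)

# After StackedRecurrent, merge the microbatch axis back: [4, 2, 16] -> [8, 16].
y = tf.reshape(mb, [2 * num_micro_batches, 16])
print(y.shape)  # (8, 16)
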
Example #17
 def vars(self):
     """Returns variables of this layer and its children in a `.NestedMap`."""
     ret = py_utils.Transform(lambda x: x.vars, self.children)
     for k in self._private_vars.keys():
         ret[k] = self._private_vars[k]
     return ret