Example #1
 def _BeamSearchDecode(self, input_batch):
     p = self.params
     with tf.name_scope('fprop'), tf.name_scope(p.name):
         encoder_outputs = self.enc.FPropDefaultTheta(input_batch.src)
         encoder_outputs = self.dec.AddExtraDecodingInfo(
             encoder_outputs, input_batch.tgt)
         decoder_outs = self.dec.BeamSearchDecode(encoder_outputs)
         return self._ProcessBeamSearchDecodeOut(input_batch,
                                                 encoder_outputs,
                                                 decoder_outs)
Example #2
  def FProp(self, theta, *args):
    """FProp through multiple devices in the split.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      *args: A tuple of Tensors (one or more). Every tensor's first dimension is
        the same (the batch dimension).

    Returns:
      The sub layer's output.
    """
    p = self.params
    with tf.name_scope(p.name):
      assert all(isinstance(x, tf.Tensor) for x in args)
      cluster = self.cluster
      num = cluster.num_devices_per_split
      if num == 1:
        return self.sub.FProp(theta.sub, *args)
      inps = py_utils.SplitRecursively(list(args), num, axis=0)
      outs = []
      for i, xs in enumerate(inps):
        device = cluster.WorkerDeviceInModelSplit(i)
        tf.logging.info('%d on device %s', i, device)
        with tf.device(device):
          ys = self.sub.FProp(theta.sub, *xs)
          if isinstance(ys, tuple):
            outs += [list(ys)]
          else:
            outs += [ys]  # ys is a single tensor
      ret = py_utils.ConcatRecursively(outs, axis=0)
      if isinstance(ret, list):
        return tuple(ret)
      else:
        return ret  # ret is a single tensor
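For context, a minimal self-contained sketch of the same split/concat data-parallel pattern in plain TensorFlow. The helper name split_fprop is hypothetical; it stands in for the Lingvo py_utils.SplitRecursively / ConcatRecursively machinery above and handles only single-tensor outputs.

import tensorflow as tf

def split_fprop(fn, args, devices):
  # Split every argument into len(devices) equal slices along the batch axis.
  sliced = [tf.split(x, len(devices), axis=0) for x in args]
  outs = []
  for i, device in enumerate(devices):
    with tf.device(device):
      # Run the sub-layer's forward pass on this device's slice of the batch.
      outs.append(fn(*[s[i] for s in sliced]))
  # Stitch the per-device outputs back together along the batch axis.
  return tf.concat(outs, axis=0)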
Example #3
 def FProp(self, theta, source_id, source_paddings, target_id,
           target_paddings, source_segment_id, target_segment_id,
           source_pos_id, target_pos_id, source_task_id, target_task_id):
     p = self.params
     with tf.name_scope(p.name):
         src_task_emb, src_task_emb_theta = None, None
         if p.enc_task_emb:
             src_task_emb, src_task_emb_theta = self.src_task_emb, theta.src_task_emb
         source_vecs = self.GetEmbeddings(
             theta.src_token_emb, self.src_token_emb, theta.src_pos_emb,
             self.src_pos_emb, theta.src_dropout, self.src_dropout,
             source_id, source_pos_id, src_task_emb_theta, src_task_emb,
             source_task_id)
         target_vecs = None
         if p.add_tgt_embedding_layer:
             tgt_task_emb, tgt_task_emb_theta = None, None
             if p.enc_task_emb:
                 tgt_task_emb, tgt_task_emb_theta = (self.tgt_task_emb,
                                                     theta.tgt_task_emb)
             target_vecs = self.GetEmbeddings(
                 theta.tgt_token_emb, self.tgt_token_emb, theta.tgt_pos_emb,
                 self.tgt_pos_emb, theta.tgt_dropout, self.tgt_dropout,
                 target_id, target_pos_id, tgt_task_emb_theta, tgt_task_emb,
                 target_task_id)
         rets = (source_vecs, source_paddings, target_vecs, target_paddings,
                 source_segment_id, target_segment_id, None, None)
         rets += (source_task_id, target_task_id) if p.ret_task_ids else ()
         return rets
Example #4
 def FProp(self, theta, *args):
   p = self.params
   with tf.name_scope(p.name):
     args = _ToTuple(self.body.FProp(theta.body, *args))
     for fetch in p.fetches:
       args += (self.body.GetDescendant(fetch).activation,)
     return args
Example #5
  def FProp(self, theta, *args):
    """Runs p.repeat copies of self.body.FProp independently.

    Args:
      theta: Layer model parameters. The shape of each variable in theta is
        always [p.repeat, ...]. And the i-th slice theta[i] becomes theta of the
        i-th copy of self.body.
      *args: Input arguments. The shape of each tensor in args is always
        [p.repeat, ....]. And the list [arg[i] for arg in args] becomes inputs
        to the i-th copy of self.body.FProp.

    Returns:
      The accumulated output_tensors. Each tensor t in the return has the shape
      [p.repeat, ....] and the tuple (t[i] for t in output_tensors) is the
      return tuple of the i-th self.body.FProp.
    """
    p = self.params
    for arg in args:
      if arg is not None:
        arg = py_utils.HasShape(arg, [p.repeat], ndims=1)

    theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars, p.repeat)
    inputs = py_utils.NestedMap(theta=theta_stack, args=list(args))
    # Infer out_shapes from FPropMeta.
    out_shapes = self._InferOutShapes(args)

    def _CellFn(unused_theta, unused_state0, inputs):
      """Recurrent cell function wrapper of body.FProp."""
      # Sets shapes for both theta and inputs to self.body.FProp.
      for dst, src in zip(inputs.args + inputs.theta.Flatten(),
                          list(args) + theta_stack.Flatten()):
        if src is not None:
          dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

      # Runs the actual body.FProp
      fprop_outputs = self.body.FProp(inputs.theta, *inputs.args)
      fprop_outputs = _ToTuple(fprop_outputs)
      assert len(fprop_outputs) == len(out_shapes)
      # Passes fprop outputs to the next layer through state.
      state1 = py_utils.NestedMap(outputs=list(fprop_outputs))
      return state1, py_utils.NestedMap()

    with tf.name_scope(p.name):
      # Initiate state0 with inferred output shapes.
      state0 = py_utils.NestedMap(
          outputs=[tf.zeros(shape, args[0].dtype) for shape in out_shapes])
      # Runs body.FProp p.repeat times using Recurrent.
      acc_states, _ = recurrent.Recurrent(
          theta=py_utils.NestedMap(),
          state0=state0,
          inputs=inputs,
          cell_fn=_CellFn)

      # Retrieves fprop outputs from the accumulated states and sets shapes.
      output_tensors = tuple(acc_states.outputs)
      for out_idx in range(len(output_tensors)):
        output_tensors[out_idx].set_shape(
            tf.TensorShape([p.repeat] + out_shapes[out_idx].as_list()))

      return output_tensors[0] if len(args) == 1 else tuple(output_tensors)
Example #6
 def FProp(self, theta, *args):
   r"""Applies lambda(x, \*kwargs) for every non-None arg."""
   del theta
   p = self.params
   with tf.name_scope(p.name):
     ret = [None if x is None else p.fn(x, **p.kwargs) for x in args]
     return tuple(ret) if len(ret) > 1 else ret[0]
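A toy usage sketch of the same apply-to-each-arg pattern outside of a layer; apply_to_args, fn, and kwargs are illustrative names standing in for p.fn and p.kwargs.

import tensorflow as tf

def apply_to_args(fn, *args, **kwargs):
  # Apply fn(x, **kwargs) to every non-None arg, passing Nones through.
  ret = [None if x is None else fn(x, **kwargs) for x in args]
  return tuple(ret) if len(ret) > 1 else ret[0]

# E.g. scale one tensor while skipping a None placeholder.
scaled, skipped = apply_to_args(lambda x, scale: x * scale,
                                tf.ones([2]), None, scale=2.0)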
Example #7
    def PrepareExternalInputs(self, theta, external_inputs):
        """Prepares external inputs for each sub-step.

    The external_inputs passed to this method are selected for each sub-step
    according to that sub-step's external_signature, and the selected tensors
    are then processed by the sub-step's own PrepareExternalInputs method.

    Args:
      theta: variables used by sub-steps.
      external_inputs: A NestedMap of [n_batch, ...] tensors.

    Returns:
      A NestedMap of prepared inputs, where the keys are the names of
        each sub-step.
    """
        graph_tensors = builder_layers.GraphTensors()
        graph_tensors.StoreTensor('external_inputs', external_inputs)
        prepared_inputs = py_utils.NestedMap()
        with tf.name_scope(self.params.name):
            for seq in self._seq:
                if seq.external_signature:
                    template = py_utils.NestedMap(
                        inputs=seq.external_signature.inputs)
                    packed = template.Transform(graph_tensors.GetTensor)
                    seq_external_inputs = packed.inputs[0]
                    prepared_inputs[seq.name] = seq.step.PrepareExternalInputs(
                        theta[seq.name], seq_external_inputs)
                else:
                    prepared_inputs[seq.name] = py_utils.NestedMap()
        return prepared_inputs
Example #8
 def FProp(self,
           theta,
           source_vecs,
           source_paddings,
           target_vecs,
           target_paddings,
           source_segment_id,
           target_segment_id,
           transparent_acc,
           transparent_acc_helper,
           source_task_id=None,
           target_task_id=None):
     p = self.params
     with tf.name_scope(p.name):
         if p.has_aux_atten:  # Decoder FProp
             return _common_gpipe_transformer_decoder_fprop(
                 self, GPipeTransformerLayer, theta, source_vecs,
                 source_paddings, target_vecs, target_paddings,
                 source_segment_id, target_segment_id, transparent_acc,
                 transparent_acc_helper, source_task_id, target_task_id)
         else:  # Encoder FProp
             return _common_gpipe_transformer_encoder_fprop(
                 self, GPipeTransformerLayer, theta, source_vecs,
                 source_paddings, target_vecs, target_paddings,
                 source_segment_id, target_segment_id, transparent_acc,
                 transparent_acc_helper, source_task_id, target_task_id)
Example #9
def fast_gather(values,
                ids,
                ids_size,
                max_value=None,
                axis=0,
                batch_major_state=True):
    """Fast implementation of gather on TPUs.

  Args:
    values: Values to gather from.
    ids: ids (rows to gather)
    ids_size: id space size.
    max_value: Optional hint on the maximum int32 value, which allows speeding
      up the gather operation.
    axis: axis to gather on. Defaults to 0 (rows).
    batch_major_state: Whether the values to gather from use batch major or not.
      Defaults to True.

  Returns:
    Gathered values.
  Raises:
    ValueError: if `values` is of type int64.
  """
    values = tf.convert_to_tensor(values)
    ids = tf.convert_to_tensor(ids)
    with tf.name_scope("fast_gather"):
        return _Gatherer(ids, ids_size)(values,
                                        max_value=max_value,
                                        axis=axis,
                                        batch_major_state=batch_major_state)
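The heavy lifting above happens inside an internal _Gatherer that is not shown in this snippet. For context, a common TPU-friendly way to express a row gather is as a one-hot matmul; the sketch below illustrates that general trick only and is not the Lingvo implementation.

import tensorflow as tf

def one_hot_gather(values, ids, ids_size):
  # values: [ids_size, depth] float tensor; ids: [batch] int32 row indices.
  one_hot = tf.one_hot(ids, ids_size, dtype=values.dtype)  # [batch, ids_size]
  # A matmul with a one-hot matrix selects rows and lowers well on TPUs.
  return tf.matmul(one_hot, values)  # [batch, depth]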
Example #10
 def Step(recurrent_theta, state0, inputs):
     """Computes one decoder step."""
     del inputs
     with tf.name_scope('single_sampler_step'):
         # Compute logits and states.
         bs_result, bs_state1 = pre_step_callback(
             recurrent_theta.theta,
             recurrent_theta.encoder_outputs,
             tf.expand_dims(state0.ids, 1),  # [batch, 1].
             state0.bs_state,
             num_hyps_per_beam=1)
         batch = tf.shape(bs_result.log_probs)[0]
         state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
         state1.logits = bs_result.log_probs
         # Sample ids from logits. [batch].
         state1.ids = tf.reshape(
             tf.random.stateless_categorical(
                 state1.logits / p.temperature,
                 num_samples=1,
                 seed=tf.stack(
                     [recurrent_theta.random_seed, state0.timestep]),
                 dtype=state0.ids.dtype,
                 name='sample_next_id'), [batch])
         if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
             state1.ids = tf.where(
                 tf.math.logical_and(
                     bs_result.is_last_chunk,
                     tf.equal(state1.ids, p.target_eoc_id)),
                 tf.fill(tf.shape(state1.ids), p.target_eos_id),
                 state1.ids)
         state1.bs_state = post_step_callback(
             recurrent_theta.theta, recurrent_theta.encoder_outputs,
             state1.ids, bs_state1)
     return state1, py_utils.NestedMap()
Example #11
 def FProp(self, theta, x):
   tf.logging.vlog(1, 'layer %s', self.params.name)
   with tf.name_scope(self.params.name):
     for (name, ch) in self._seq:
       th = theta[name]
       tf.logging.vlog(1, '  call %s %s %s', ch.params.name, ch, x)
       x = ch.FProp(th, x)
     return x
Example #12
    def FProp(self, theta, input_batch):
        """Encodes source as represented by `inputs` and `paddings`.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      input_batch: A `.NestedMap` with fields:
        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].

    Returns:
      A NestedMap containing:

      - encoded: The encoded features, a tensor of shape [time, batch, depth]
      - padding: of shape [time, batch]
      - segment_id: [time, batch] if packed inputs are supported by the model
        (and all layers), or None otherwise.
    """
        p = self.params
        src_segment_id = None
        with tf.name_scope(p.name):
            # Reshape ids to [time, batch].
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            [-1, -1]),
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings))
            ], tf.transpose(input_batch.ids))
            paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
            xs = self.emb.EmbLookup(theta.emb, inputs)
            xs = self.ApplyClipping(theta, xs)
            self._emb_out = xs
            ps = paddings
            # When cc_schedule is specified, make sure lstm_tpl is QuantizedLSTMCell
            # with the same cc_schedule so that the RNN layer output is within
            # clipping range.
            xs = self.rnn[0].FProp(theta.rnn[0], xs, ps)
            xs = self.dropout.FProp(theta.dropout, xs)
            for i in range(1, p.num_lstm_layers):
                layer = self.rnn[i]
                ys, _ = layer.FProp(theta.rnn[i], xs, ps)
                ys = self.dropout.FProp(theta.dropout, ys)
                if hasattr(layer.params, 'cell'):
                    layer_params = layer.params.cell
                else:
                    layer_params = layer.params
                if layer_params.num_input_nodes == layer_params.num_output_nodes:
                    xs += ys  # Residual skip
                    xs = self.ApplyClipping(theta, xs)
                else:
                    # When cc_schedule is specified, make sure lstm_tpl is
                    # QuantizedLSTMCell with the same cc_schedule so that the RNN layer
                    # output is within clipping range.
                    xs = ys
            return py_utils.NestedMap(encoded=xs,
                                      padding=tf.squeeze(ps, [2]),
                                      segment_id=src_segment_id)
Example #13
  def FProp(self, theta, *args):
    p = self.params
    # Collects all variable key and values into sets.
    theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars, p.repeat)

    def _ArgsToState(arg_list):
      """Returns a NestedMap from a list of FProp args."""
      state = py_utils.NestedMap()
      # Maintains a mapping from arg_idx to tensor; the state cannot contain
      # None tensors.
      for idx in range(len(args)):
        if arg_list[idx] is not None:
          state['_s{}'.format(idx)] = arg_list[idx]
      return state

    def _StateToArgs(state):
      """Returns a list of FProp args from a NestedMap."""
      arg_list = []
      for idx in range(len(args)):
        attr = '_s{}'.format(idx)
        arg_list.append(state[attr] if attr in state else None)
        if arg_list[-1] is not None:
          arg_list[-1].set_shape(args[idx].shape)
      return arg_list

    def _CellFn(unused_theta, state0, theta_i):
      """Recurrent cell function wrapper of body.FProp."""
      # Retrieves fprop arguments from state and sets shapes.
      fprop_inputs = _StateToArgs(state0)

      # Sets shapes for theta_i as well.
      for dst, src in zip(theta_i.Flatten(), theta_stack.Flatten()):
        if src is not None:
          dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

      # Runs the actual body.FProp
      fprop_outputs = self.body.FProp(theta_i, *fprop_inputs)
      fprop_outputs = _ToTuple(fprop_outputs)
      assert len(fprop_outputs) == len(fprop_inputs)

      # Passes fprop outputs to the next layer through state.
      state1 = _ArgsToState(fprop_outputs)
      return state1, py_utils.NestedMap()

    with tf.name_scope(p.name):
      # Add FProp arg list to state0.
      state0 = _ArgsToState(args)
      # Runs body.FProp k times using Recurrent, where k = p.repeat (dim 0 of theta_stack).
      _, state1 = recurrent.Recurrent(
          theta=py_utils.NestedMap(),
          state0=state0,
          inputs=theta_stack,  # Pass cell_fn theta through inputs.
          cell_fn=_CellFn)

      # Retrieves fprop outputs from state1 and sets shapes.
      output_tensors = _StateToArgs(state1)
      return output_tensors[0] if len(args) == 1 else tuple(output_tensors)
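Unlike Example #5, which runs independent copies of body.FProp on stacked inputs, this variant threads each copy's outputs into the next copy through the recurrent state. Ignoring the memory and rematerialization benefits of recurrent.Recurrent, the behavior is roughly the following loop (illustrative only; chained_repeat is a hypothetical name):

def chained_repeat(body_fn, per_copy_thetas, *args):
  # per_copy_thetas: one theta per copy; each iteration consumes the previous
  # iteration's outputs as its inputs, so input/output arities must match.
  for theta_i in per_copy_thetas:
    outs = body_fn(theta_i, *args)
    args = outs if isinstance(outs, tuple) else (outs,)
  return args[0] if len(args) == 1 else args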
Example #14
 def FProp(self, theta, current_step):
     p = self.params
     assert p.total_steps > 0
     assert p.initial_value > p.final_value
     with tf.name_scope(p.name):
         decay_gap = p.initial_value - p.final_value
         return p.final_value + 0.5 * decay_gap * (1 + tf.cos(
             math.pi *
             tf.minimum(1.0,
                        tf.cast(current_step, tf.float32) / p.total_steps)))
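Spelled out, the schedule above returns final_value + 0.5 * (initial_value - final_value) * (1 + cos(pi * min(1, step / total_steps))), i.e. a standard cosine decay from initial_value down to final_value. A minimal standalone sketch with assumed parameter names:

import math
import tensorflow as tf

def cosine_decay(step, total_steps, initial_value, final_value):
  # Fraction of training completed, clipped to 1.0 once past total_steps.
  frac = tf.minimum(1.0, tf.cast(step, tf.float32) / total_steps)
  decay_gap = initial_value - final_value
  return final_value + 0.5 * decay_gap * (1.0 + tf.cos(math.pi * frac))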
Example #15
  def FProp(self, theta, *args):
    r"""Applies a function (p.fn) on args.

    Args:
      theta: Unused.
      *args: A tuple of Tensors (one or more).

    Returns:
      fn(\*args).
    """
    with tf.name_scope(self.params.name):
      return self.params.fn(*args)
Example #16
    def FProp(self, theta, input_batch):
        p = self.params
        with tf.name_scope(p.name):
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            [-1, -1]),
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings))
            ], tf.transpose(input_batch.ids))
            paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
            if p.packed_input:
                src_segment_id = tf.expand_dims(
                    tf.transpose(input_batch.segment_ids), 2)
            else:
                src_segment_id = None
            xs = self.emb.EmbLookup(theta.emb, inputs)
            xs = self.ApplyClipping(theta, xs)
            summary_utils.histogram('input_emb', xs)
            xs = self.dropout.FProp(theta.dropout, xs)
            ps = paddings
            # Now the rnn layers.
            outputs_list = []
            for i in range(0, p.num_lstm_layers):
                layer = self.rnn[i]
                ys = layer.FProp(theta.rnn[i],
                                 xs,
                                 ps,
                                 segment_id=src_segment_id)
                ys = self.dropout.FProp(theta.dropout, ys)
                if i >= p.residual_start:
                    xs += ys  # Residual skip
                    xs = self.ApplyClipping(theta, xs)
                else:
                    xs = ys
                outputs_list.append(xs)
                summary_utils.histogram('layer_out_%s' % i, xs)

            if p.is_transparent:
                xs = self.transparent_merger.FProp(theta.transparent_merger,
                                                   outputs_list)

            if p.lstm_cell_size * 2 != p.encoder_out_dim:
                # Project to the right depth.
                xs = self.final_proj.FProp(theta.final_proj, xs, ps)
                summary_utils.histogram('final_proj_out', xs)

            if src_segment_id is not None:
                src_segment_id = tf.squeeze(src_segment_id, [2])

            return py_utils.NestedMap(encoded=xs,
                                      padding=tf.squeeze(ps, [2]),
                                      segment_id=src_segment_id)
Example #17
  def FProp(self, theta, inputs, paddings):
    """Apply convolution to inputs.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor. It is expected to be of shape [batch, time,
        frequency, channel]. The time dimension corresponds to the height
        dimension as in images and the frequency dimension corresponds to the
        width dimension as in images.
      paddings: The paddings tensor, expected to be of shape [batch, time].

    Returns:
      outputs, out_paddings pair.
    """
    p = self.params
    with tf.name_scope(p.name):
      inputs = py_utils.with_dependencies([
          py_utils.assert_shape_match(tf.shape(paddings), [-1, -1]),
          py_utils.assert_shape_match(
              tf.shape(inputs),
              tf.concat([
                  tf.shape(paddings),
                  [-1, symbolic.ToStatic(self.input_channels)]
              ], 0))
      ], inputs)

      def _ApplyPadding(tensor_in, padding_in):
        padding_expanded = tf.expand_dims(tf.expand_dims(padding_in, -1), -1)
        return tensor_in * (1.0 - padding_expanded)

      # Zeroing out padded inputs.
      inputs = _ApplyPadding(inputs, paddings)

      # Apply conv on 'inputs'.
      out = self._ApplyConv(theta, inputs)

      if p.partial_conv:
        out = self._RescaleBoundary(out, paddings)
      # NOTE: this may be slightly inaccurate when p.dilation_rate[0] > 1.
      # But there are likely no real problems. Trying to set it gives an error:
      # pooling with SAME padding is not implemented for dilation_rate > 1.
      # NOTE: we use window=p.filter_stride[0] to be compatible with legacy
      # implementation.  Consider updating it to be the actual shape.
      conv_padding = ComputeConvOutputPadding(
          paddings, window=p.filter_stride[0], stride=p.filter_stride[0])
      # Assuming padded nodes will be properly zeroed out if necessary by
      # subsequent layers.
      # out = _ApplyPadding(out, conv_padding)
      out = py_utils.HasShape(
          out, symbolic.ToStatic(self.OutShape(tf.shape(inputs))))
      return out, conv_padding
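The _ApplyPadding helper above zeroes out padded time steps by broadcasting the [batch, time] paddings against the [batch, time, frequency, channel] inputs. A small numeric illustration of that masking, with shapes assumed:

import tensorflow as tf

# paddings: 1.0 marks padded time steps; inputs: [batch, time, freq, channel].
inputs = tf.ones([2, 4, 3, 1])
paddings = tf.constant([[0., 0., 1., 1.],
                        [0., 0., 0., 1.]])
# Expand paddings to [batch, time, 1, 1] so they broadcast over freq/channel.
mask = 1.0 - tf.expand_dims(tf.expand_dims(paddings, -1), -1)
masked_inputs = inputs * mask  # padded time steps become all zeros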
Example #18
  def FProp(self, theta, inputs):
    """Adds bias to inputs.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      inputs: The inputs tensor.  Shaped [..., dims].

    Returns:
      Inputs plus bias.
    """
    with tf.name_scope(self.params.name):
      return inputs + theta.b
Example #19
 def _Traverse(layer):
     """Adds accumulators to layer and its descendant layers."""
     if isinstance(layer, (list, tuple)):
         for layer_i in layer:
             _Traverse(layer_i)
         return
     with tf.name_scope(layer.params.name):
         for cost_metric_name in COST_METRICS:
             dtype = COST_METRICS[cost_metric_name]
             layer.RegisterAccumulator(
                 cost_metric_name,
                 bn_layers.AddingAccumulator(shape=[], dtype=dtype))
         for _, child in sorted(layer.children.items()):
             _Traverse(child)
Example #20
  def FProp(self, theta, *args):
    p = self.params

    with tf.name_scope(p.name):
      # Computes sub layers in parallel.
      outputs = []
      for (name, ch) in self._seq:
        th = theta[name]
        out = ch.FProp(th, *args)
        if isinstance(out, (list, tuple)):
          outputs.append(tuple(out))
        else:
          outputs.append((out,))
      rets = p.merge(outputs)
      return rets if len(rets) > 1 else rets[0]
Example #21
def CollectVarHistogram(vs_gs):
  """Adds histogram summaries for variables and gradients."""

  for name, (var, grad) in vs_gs.FlattenItems():
    name = py_utils.SanitizeScopeKey(name)
    with tf.device(var.device), tf.name_scope(name + '/summary'):
      if isinstance(grad, tf.IndexedSlices):
        var = tf.gather(var, grad.indices)
        grad = grad.values
      if var.dtype.is_complex:
        var = tf.abs(var)
        grad = tf.abs(grad)

      histogram('var_hist/' + name, var)
      histogram('grad_hist/' + name, grad)
Example #22
    def FProp(self, theta, input_batch, state0=None):
        p = self.params
        src_segment_id = None
        with tf.name_scope(p.name):
            # Reshape to [t, b]
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            [-1, -1]),
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings))
            ], tf.transpose(input_batch.ids))
            paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

            # Setup streaming states.
            if not state0:
                state0 = self.zero_state(theta, tf.shape(inputs)[1])
            state1 = py_utils.NestedMap(rnn=[None] * p.num_lstm_layers)

            xs = self.emb.EmbLookup(theta.emb, inputs)
            xs = self.ApplyClipping(theta, xs)
            summary_utils.histogram('input_emb', xs)
            xs = self.dropout.FProp(theta.dropout, xs)
            ps = paddings
            # Now the rnn layers.
            outputs_list = []
            for i in range(0, p.num_lstm_layers):
                layer = self.rnn[i]
                ys, state1.rnn[i] = layer.FProp(theta.rnn[i],
                                                xs,
                                                ps,
                                                state0=state0.rnn[i])
                ys = self.dropout.FProp(theta.dropout, ys)
                if i >= p.residual_start:
                    xs += ys  # Residual skip
                    xs = self.ApplyClipping(theta, xs)
                else:
                    xs = ys
                outputs_list.append(xs)
                summary_utils.histogram('layer_out_%s' % i, xs)

            if p.is_transparent:
                xs = self.transparent_merger.FProp(theta.transparent_merger,
                                                   outputs_list)

            return py_utils.NestedMap(encoded=xs,
                                      padding=tf.squeeze(ps, [2]),
                                      segment_id=src_segment_id,
                                      state=state1)
Example #23
 def FProp(self, theta, *args):
   p = self.params
   with tf.name_scope(p.name) as name_scope:
     for i, arg in enumerate(args):
       if not isinstance(arg, tf.Tensor):
         tf.logging.info(
             'FProp non-Tensor input in {}: arg_{} arg = {}'.format(
                 name_scope, i, arg))
       else:
         tf.logging.info(
             'FProp inputs in {}: arg_{} shape = {} dtype = {}'.format(
                 name_scope, i, arg.shape, arg.dtype.name))
   if len(args) == 1:
     return args[0]
   else:
     return args
Example #24
def TraverseLayer(layer, fn):
  """Traverses the layer tree and invokes fn(node) on each node.

  Args:
    layer: a BaseLayer.
    fn: a function of layer -> None, invoked on each traversed layer.
  """
  if isinstance(layer, (list, tuple)):
    for layer_i in layer:
      TraverseLayer(layer_i, fn)
    return

  with tf.name_scope(layer.params.name):
    fn(layer)
    # Traverse all children in alphabetical order.
    for _, child in sorted(layer.children.items()):
      TraverseLayer(child, fn)
Example #25
    def ZeroState(self, theta, prepared_inputs, batch_size):
        """Creates a zero state NestedMap for this step.

    Args:
      theta: variables used by sub-steps.
      prepared_inputs: Output from a call to PrepareExternalInputs.
      batch_size: The number of items in the batch that FProp will process.

    Returns:
      A NestedMap of ZeroState results for each sub-step.
    """
        state0 = py_utils.NestedMap()
        with tf.name_scope(self.params.name):
            for seq in self._seq:
                state0[seq.name] = seq.step.ZeroState(
                    theta[seq.name], prepared_inputs[seq.name], batch_size)
        return state0
Example #26
  def FProp(self, theta, inputs, *args):
    p = self.params
    with tf.name_scope(p.name) as scope:
      expert_dist = self._GetExpertDist(theta, inputs, *args)
      if not self.do_eval:
        summary_utils.histogram('soft_cond_{}'.format(scope), expert_dist)

      # Excludes non-variable extra_theta like global_step.
      var_set = set([key for key, _ in self.body.vars.FlattenItems()])
      values = []
      for key, value in theta.body.FlattenItems():
        if key in var_set and value is not None:
          # Weighted average for all variables created in the body layer.
          value = tf.einsum('i,i...->...', expert_dist, value)
        values.append(value)
      weighted_theta = theta.body.Pack(values)
      return self.body.FProp(weighted_theta, inputs, *args)
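The einsum('i,i...->...') above computes a mixture-weighted average over the leading expert dimension of each stacked variable. A small numeric illustration with assumed shapes:

import tensorflow as tf

expert_dist = tf.constant([0.25, 0.75])    # [num_experts] mixture weights
value = tf.stack([tf.ones([2, 3]),         # expert 0 weights
                  3.0 * tf.ones([2, 3])])  # expert 1 weights
# Weighted average over the expert axis: every entry becomes 0.25*1 + 0.75*3 = 2.5.
averaged = tf.einsum('i,i...->...', expert_dist, value)  # [2, 3]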
Example #27
    def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
        """A single inference step for this step graph.

    Args:
      theta: variables used by sub-steps.
      prepared_inputs: A NestedMap containing external_inputs that were
        pre-processed by the PrepareExternalInputs method of each sub-step. The
        keys are the names of the sub-steps.
      step_inputs: A NestedMap of [batch, ...] tensors. The structure of this
        depends on the graph implementation.
      padding: A 0/1 float tensor of shape [batch_size]; 1.0 means that this
        batch element is empty in this step.
      state0: A NestedMap of state variables produced by either ZeroState or a
        previous invocation of this FProp step. The keys are the names of the
        sub-steps.

    Returns:
      (output, state1), both of which are NestedMaps.
      output is implementation-dependent and is defined by the output_signature
      parameter.
      state1 is a NestedMap where the keys are names of sub-steps and the values
      are state outputs from their FProp methods.
    """
        p = self.params
        graph_tensors = builder_layers.GraphTensors()
        graph_tensors.StoreTensor('prepared_inputs', prepared_inputs)
        graph_tensors.StoreTensor('step_inputs', step_inputs)
        state1 = py_utils.NestedMap()
        with tf.name_scope(p.name):
            for seq in self._seq:
                tf.logging.vlog(1, 'GraphStep: call %s', seq.name)
                external = None
                if seq.external_signature:
                    external = prepared_inputs[seq.name]
                template = py_utils.NestedMap(inputs=seq.signature.inputs)
                packed = template.Transform(graph_tensors.GetTensor)
                input_args = packed.inputs[0]
                out, seq_state1 = seq.step.FProp(theta[seq.name], external,
                                                 input_args, padding,
                                                 state0[seq.name])
                graph_tensors.StoreTensor(seq.signature.outputs[0], out)
                state1[seq.name] = seq_state1
        template = py_utils.NestedMap(inputs=self.output_signature.inputs)
        output_tensors = template.Transform(graph_tensors.GetTensor).inputs[0]
        return output_tensors, state1
Example #28
    def FProp(self, theta, inputs, paddings=None):
        """Apply batch normalization.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [..., dim].
      paddings: The paddings tensor.  Shaped [..., 1], with the same rank as the
        input tensor.

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
        p = self.params
        if paddings is None:
            paddings = self._GetDefaultPaddings(inputs)
        with tf.name_scope(p.name):
            norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
                theta, inputs, paddings)
            with tf.control_dependencies([
                    py_utils.assert_greater_equal(
                        norm_variance, tf.zeros_like(norm_variance)),
                    py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                                tf.shape(norm_mean)),
                    py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                                tf.shape(norm_variance)),
            ]):
                if p.use_fused_batch_norm_for_eval and self.do_eval:
                    bn_output, _, _ = nn.fused_batch_norm(inputs,
                                                          gamma,
                                                          beta,
                                                          norm_mean,
                                                          norm_variance,
                                                          self._epsilon,
                                                          is_training=False)
                else:
                    bn_output = tf.nn.batch_normalization(
                        inputs, norm_mean, norm_variance, beta, gamma,
                        self._epsilon)

                if p.set_padded_output_to_zero:
                    bn_output *= 1.0 - paddings

            return bn_output
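The non-fused branch above is plain batch normalization, i.e. gamma * (x - mean) / sqrt(variance + epsilon) + beta. A minimal sketch equivalent to tf.nn.batch_normalization, with an assumed epsilon:

import tensorflow as tf

def batch_norm(x, mean, variance, beta, gamma, epsilon=1e-3):
  # Normalize, then scale and shift; matches tf.nn.batch_normalization.
  inv = tf.math.rsqrt(variance + epsilon)
  return gamma * (x - mean) * inv + beta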
Example #29
 def FProp(self,
           theta,
           source_vecs,
           source_paddings,
           target_vecs,
           target_paddings,
           source_segment_id,
           target_segment_id,
           transparent_acc,
           transparent_acc_helper,
           source_task_id=None,
           target_task_id=None):
     with tf.name_scope(self.params.name):
         return _common_gpipe_transformer_decoder_fprop(
             self, GPipeEvolvedTransformerDecoderLayer, theta, source_vecs,
             source_paddings, target_vecs, target_paddings,
             source_segment_id, target_segment_id, transparent_acc,
             transparent_acc_helper, source_task_id, target_task_id)
Example #30
    def FProp(self, theta, current_step):
        p = self.params
        with tf.name_scope(p.name):

            steps = self._best_step
            best_step = steps[0]
            last_step = steps[1]

            ref_step = tf.maximum(self._ref_step, best_step)
            f = self._cur_factor

            # Decay if no improvement within window.
            new_factor = tf.where(last_step - ref_step < p.window, f,
                                  tf.maximum(p.min_factor, f * p.decay))
            # Update ref_step if we decayed.
            new_step = tf.where(tf.equal(new_factor, f), ref_step, last_step)
            update_step = tf.assign(self._ref_step, new_step)
            with tf.control_dependencies([update_step]):
                return tf.assign(self._cur_factor, new_factor)