Example #1
def GeneralGRUCell(candidate_transform,
                   memory_transform_fn=None,
                   gate_nonlinearity=activation_fns.Sigmoid,
                   candidate_nonlinearity=activation_fns.Tanh,
                   dropout_rate_c=0.1,
                   sigmoid_bias=0.5):
    r"""Parametrized Gated Recurrent Unit (GRU) cell construction.

  GRU update equations:
  $$ Update gate: u_t = \sigma(U' * s_{t-1} + B') $$
  $$ Reset gate: r_t = \sigma(U'' * s_{t-1} + B'') $$
  $$ Candidate memory: c_t = \tanh(U * (r_t \odot s_{t-1}) + B) $$
  $$ New State: s_t = u_t \odot s_{t-1} + (1 - u_t) \odot c_t $$

  See combinators.Gate for details on the gating function.


  Args:
    candidate_transform: Transform to apply inside the Candidate branch. Applied
      before nonlinearities.
    memory_transform_fn: Optional transformation on the memory before gating.
    gate_nonlinearity: Function to use as gate activation. Allows trying
      alternatives to Sigmoid, such as HardSigmoid.
    candidate_nonlinearity: Nonlinearity to apply after the candidate branch.
      Allows trying alternatives to the traditional Tanh, such as HardTanh.
    dropout_rate_c: Amount of dropout on the transform (c) gate. Dropout works
      best in a GRU when applied exclusively to this branch.
    sigmoid_bias: Constant to add before the sigmoid gates. It generally helps to
      start with a positive bias.

  Returns:
    A model representing a GRU cell with specified transforms.
  """
    gate_block = [  # u_t
        candidate_transform(),
        base.Fn(lambda x: x + sigmoid_bias),
        gate_nonlinearity(),
    ]
    reset_block = [  # r_t
        candidate_transform(),
        base.Fn(lambda x: x + sigmoid_bias),  # Want bias to start positive.
        gate_nonlinearity(),
    ]
    candidate_block = [
        cb.Dup(),
        reset_block,
        cb.Multiply(),  # Gate S{t-1} with sigmoid(candidate_transform(S{t-1}))
        candidate_transform(),  # Final projection + tanh to get Ct
        candidate_nonlinearity(),  # Candidate gate

        # Only apply dropout on the C gate. Paper reports 0.1 as a good default.
        core.Dropout(rate=dropout_rate_c)
    ]
    memory_transform = memory_transform_fn() if memory_transform_fn else []
    return cb.Serial(
        cb.Branch(memory_transform, gate_block, candidate_block),
        cb.Gate(),
    )
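As a quick illustration of the update equations in the docstring above, here is a plain-NumPy sketch of one GRU step, using hypothetical randomly initialized weight matrices (U, U1, U2) and biases (B, B1, B2); it mirrors only the math, not the Trax layer construction:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
d = 4                                  # hypothetical state depth
s_prev = rng.normal(size=d)            # s_{t-1}
U, U1, U2 = (rng.normal(size=(d, d)) for _ in range(3))
B, B1, B2 = (rng.normal(size=d) for _ in range(3))

u_t = sigmoid(U1 @ s_prev + B1)        # update gate
r_t = sigmoid(U2 @ s_prev + B2)        # reset gate
c_t = np.tanh(U @ (r_t * s_prev) + B)  # candidate memory
s_t = u_t * s_prev + (1 - u_t) * c_t   # new state
print(s_t.shape)                       # (4,)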
Example #2
 def test_fn_layer_varargs_n_in(self):
     with self.assertRaisesRegex(ValueError, 'variable arg'):
         base.Fn(lambda *args: args[0])
     # Check that varargs work when n_in is set.
     id_layer = base.Fn(lambda *args: args[0], n_in=1)
     input_signature = ShapeDtype((2, 7))
     expected_shape = (2, 7)
     output_shape = base.check_shape_agreement(id_layer, input_signature)
     self.assertEqual(output_shape, expected_shape)
Example #3
 def test_fn_layer_difficult_n_out(self):
     with self.assertRaisesRegex(ValueError, 'n_out'):
         # Determining the output of this layer is hard with dummies.
         base.Fn(lambda x: np.concatenate([x, x], axis=4))
     # Check that this layer works when n_out is set.
     layer = base.Fn(lambda x: np.concatenate([x, x], axis=4), n_out=1)
     input_signature = ShapeDtype((2, 1, 2, 2, 3))
     expected_shape = (2, 1, 2, 2, 6)
     output_shape = base.check_shape_agreement(layer, input_signature)
     self.assertEqual(output_shape, expected_shape)
Example #4
    def test_output_signature(self):
        input_signature = (ShapeDtype((2, 3, 5)), ShapeDtype((2, 3, 5)))
        layer = base.Fn(lambda x, y: x + y)  # n_in = 2, n_out = 1
        output_signature = layer.output_signature(input_signature)
        self.assertEqual(output_signature, ShapeDtype((2, 3, 5)))

        input_signature = ShapeDtype((5, 7))
        layer = base.Fn(lambda x: (x, 2 * x, 3 * x))  # n_in = 1, n_out = 3
        output_signature = layer.output_signature(input_signature)
        self.assertEqual(output_signature, (ShapeDtype((5, 7)), ) * 3)
        self.assertNotEqual(output_signature, (ShapeDtype((4, 7)), ) * 3)
        self.assertNotEqual(output_signature, (ShapeDtype((5, 7)), ) * 2)
Example #5
  def test_output_signature(self):
    input_signature = (shapes.ShapeDtype((2, 3, 5)),
                       shapes.ShapeDtype((2, 3, 5)))
    layer = base.Fn('2in1out', lambda x, y: x + y)
    output_signature = layer.output_signature(input_signature)
    self.assertEqual(output_signature, shapes.ShapeDtype((2, 3, 5)))

    input_signature = shapes.ShapeDtype((5, 7))
    layer = base.Fn('1in3out', lambda x: (x, 2 * x, 3 * x), n_out=3)
    output_signature = layer.output_signature(input_signature)
    self.assertEqual(output_signature, (shapes.ShapeDtype((5, 7)),) * 3)
    self.assertNotEqual(output_signature, (shapes.ShapeDtype((4, 7)),) * 3)
    self.assertNotEqual(output_signature, (shapes.ShapeDtype((5, 7)),) * 2)
Example #6
 def test_weights_state(self):
   layer = base.Fn(
       '2in2out',
       lambda x, y: (x + y, jnp.concatenate([x, y], axis=0)), n_out=2)
   weights, state = layer.new_weights_and_state(None)
   self.assertEmpty(weights)
   self.assertEmpty(state)
Example #7
def MakeZeroState(depth_multiplier=1):
  """Makes zeros of shape like x but removing the length (axis 1)."""
  def f(x):  # pylint: disable=invalid-name
    assert len(x.shape) == 3, 'Expecting x of shape [batch, length, depth].'
    return jnp.zeros((x.shape[0], depth_multiplier * x.shape[-1]),
                     dtype=jnp.float32)
  return base.Fn('MakeZeroState', f)
Example #8
def Bidirectional(forward_layer, axis=1, merge_layer=Concatenate()):
    """Bidirectional combinator for RNNs.

  Args:
    forward_layer: A layer, such as `trax.layers.LSTM` or `trax.layers.GRU`.
    axis: The time axis of the inputs. Default value is `1`.
    merge_layer: A combinator used to combine the outputs of the forward
      and backward RNNs. Default value is `trax.layers.Concatenate`.

  Example:
      Bidirectional(RNN(n_units=8))

  Returns:
    The Bidirectional combinator for RNNs.
  """
    backward_layer = copy.deepcopy(forward_layer)
    flip = base.Fn('_FlipAlongTimeAxis', lambda x: jnp.flip(x, axis=axis))
    backward = Serial(
        flip,
        backward_layer,
        flip,
    )

    return Serial(
        Branch(forward_layer, backward),
        merge_layer,
    )
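The flip-run-flip-merge pattern above can be checked outside Trax. Below is a minimal NumPy sketch in which a hypothetical `toy_rnn` (a cumulative sum over time) stands in for the recurrent layer:

import numpy as np

def toy_rnn(x):
    # Stand-in for a recurrent layer: cumulative sum over the time axis.
    return np.cumsum(x, axis=1)

x = np.arange(6, dtype=np.float32).reshape(2, 3, 1)      # [batch, time, depth]
forward = toy_rnn(x)
backward = np.flip(toy_rnn(np.flip(x, axis=1)), axis=1)  # flip, run, flip back
merged = np.concatenate([forward, backward], axis=-1)    # merge on the depth axis
print(merged.shape)                                      # (2, 3, 2)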
Example #9
def WeightedFScore(beta=1., initial_category_index=0):
    """Returns a layer that computes a weighted F-score.

  The weighted F-score summarizes how well the classifier's predictions for
  each class `k` align with the observed/gold instances of class `k`. It
  additionally weights each per-class summary by the number of observed/gold
  examples in that class.

  Args:
    beta: a parameter that determines the weight of recall in the F-score.
    initial_category_index: index of the first category to include in the
      score; categories with smaller indices are skipped.

  The layer takes two inputs:

    - Model output from one batch, an ndarray of float-valued elements.

    - A batch of element-wise target values, which matches the shape of the
      model output.

  The layer returns a weighted F-score across all the classes.
  """
    def f(model_output, targets):  # pylint: disable=invalid-name
        beta2 = beta**2
        predictions = jnp.argmax(model_output, axis=-1)
        n_categories = model_output.shape[-1]
        f_scores = jnp.empty(0)
        weights = jnp.empty(0)
        for k in range(initial_category_index, n_categories):
            _, _, n_k_targets, precision, recall = _precision_recall(
                predictions, targets, k)
            f_scores = jnp.append(f_scores, _f_score(precision, recall, beta2))
            weights = jnp.append(weights, n_k_targets)
        return jnp.average(f_scores, weights=weights)

    return base.Fn('WeightedFScore', f)
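A rough NumPy sketch of the per-class loop above, on made-up predictions and targets, with `_precision_recall` and `_f_score` inlined for illustration; as in the code, each class is weighted by its number of gold examples:

import numpy as np

def f_score(precision, recall, beta2):
    return (beta2 + 1) * precision * recall / (beta2 * precision + recall)

predictions = np.array([0, 1, 1, 2, 2, 2])   # made-up argmax outputs
targets     = np.array([0, 1, 0, 2, 2, 1])   # made-up gold labels
beta2 = 1.0
f_scores, weights = [], []
for k in range(3):
    n_correct = np.sum((predictions == k) & (targets == k))
    precision = n_correct / np.sum(predictions == k)
    recall = n_correct / np.sum(targets == k)
    f_scores.append(f_score(precision, recall, beta2))
    weights.append(np.sum(targets == k))      # weight class k by its gold count
print(np.average(f_scores, weights=weights))  # ~0.656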
Example #10
def CategoryAccuracy():
    r"""Returns a layer that computes category prediction accuracy.

  The layer takes two inputs:

    - A batch of activation vectors. The components in a given vector should
      be mappable to a probability distribution in the following loose sense:
      within a vector, a higher component value corresponds to a higher
      probability, such that argmax within a vector (``axis=-1``) picks the
      index (category) having the highest probability.

    - A batch of target categories; each target is an integer in
      :math:`\{0, ..., N-1\}`.

  The predicted category from each vector is the index of the highest-valued
  vector component. The layer returns the accuracy of these predictions
  averaged over the batch.
  """
    def f(model_output, targets):  # pylint: disable=invalid-name
        predictions = jnp.argmax(model_output, axis=-1)
        shapes.assert_same_shape(predictions, targets)
        n_total = predictions.size
        n_correct = jnp.sum(jnp.equal(predictions, targets))
        return n_correct / n_total

    return base.Fn('CategoryAccuracy', f)
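For reference, the same accuracy computation on a tiny made-up batch, written directly in NumPy:

import numpy as np

model_output = np.array([[0.2, 0.7, 0.1],    # predicts category 1
                         [0.9, 0.05, 0.05],  # predicts category 0
                         [0.1, 0.2, 0.7]])   # predicts category 2
targets = np.array([1, 2, 2])
predictions = np.argmax(model_output, axis=-1)  # [1, 0, 2]
print(np.mean(predictions == targets))          # 2 of 3 correct -> 0.666...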
Example #11
def SmoothL1Loss():
    r"""Returns a layer that computes a weighted, smoothed L1 loss for one batch.

  The layer takes three inputs:

    - Model output from one batch, an ndarray of float-valued elements.

    - A batch of element-wise target values, which matches the shape of the
      model output.

    - A batch of weights, which matches the shape of the model output.

  The layer computes a "smooth" L1 loss (a.k.a. Huber loss), for model output
  float :math:`y_i` and target float :math:`t_i`:

  .. math::
      \text{output} = \left\{ \begin{array}{cl}
          \frac 1 2 (y_i - t_i)^2, & \text{if}\ |y_i - t_i| < 1, \\
          |y_i - t_i| - \frac 1 2, & \text{otherwise}.
      \end{array} \right.

  The layer returns a weighted average of these element-wise values.
  """
    def f(model_output, targets, weights):  # pylint: disable=invalid-name
        shapes.assert_same_shape(model_output, targets)
        shapes.assert_same_shape(model_output, weights)
        l1_dist = jnp.abs(model_output - targets)
        smooth_dist = jnp.where(l1_dist < 1, 0.5 * l1_dist**2, l1_dist - 0.5)
        weighted_smooth_dist = weights * smooth_dist
        return jnp.sum(weighted_smooth_dist) / jnp.sum(weights)

    return base.Fn('SmoothL1Loss', f)
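A small NumPy check of the piecewise formula above, with made-up outputs and targets and unit weights:

import numpy as np

model_output = np.array([0.0, 0.5, 3.0])
targets      = np.array([0.0, 0.0, 0.0])
weights      = np.array([1.0, 1.0, 1.0])

l1_dist = np.abs(model_output - targets)                          # [0.0, 0.5, 3.0]
smooth  = np.where(l1_dist < 1, 0.5 * l1_dist**2, l1_dist - 0.5)  # [0.0, 0.125, 2.5]
print(np.sum(weights * smooth) / np.sum(weights))                 # 0.875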
Example #12
def _Accuracy():
  """Returns a layer that scores predicted versus target category."""
  def f(predicted_category, target_category):  # pylint: disable=invalid-name
    # TODO(pkozakowski): This assertion breaks some tests. Fix and uncomment.
    # shapes.assert_same_shape(predicted_category, target_category)
    return jnp.equal(predicted_category, target_category).astype(jnp.float32)
  return base.Fn('_Accuracy', f)
Example #13
def CategoryCrossEntropy():
    """Returns a layer that computes cross entropy from activations and integers.

  The layer takes two inputs:

    - A batch of activation vectors. The components in a given vector should
      be pre-softmax activations (mappable to a probability distribution via
      softmax). For performance reasons, the softmax and cross entropy
      computations are combined inside the layer.

    - A batch of target categories; each target is an integer in
      `{0, ..., N-1}`, where `N` is the activation vector depth/dimensionality.

  To compute cross-entropy, the layer derives probability distributions from
  its inputs:

    - activation vectors: vector --> SoftMax(vector)

    - target categories: integer --> OneHot(integer)

  (The conversion of integer category targets to one-hot vectors amounts to
  assigning all the probability mass to the target category.) Cross-entropy
  per batch item is computed between the resulting distributions; notionally:

      cross_entropy(one_hot(targets), softmax(model_output))

  The layer returns the average of these cross-entropy values over all items in
  the batch.
  """
    def f(model_output, targets):  # pylint: disable=invalid-name
        cross_entropies = _category_cross_entropy(model_output, targets)
        return jnp.average(cross_entropies)

    return base.Fn('CategoryCrossEntropy', f)
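The softmax/cross-entropy step hidden in `_category_cross_entropy` can be sketched in plain NumPy as below; the `log_softmax` helper is written out here for illustration and is not the Trax internal:

import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

logits  = np.array([[2.0, 0.5, 0.1],
                    [0.1, 0.2, 3.0]])   # made-up pre-softmax activations
targets = np.array([0, 2])
# Cross entropy per item: minus the log-softmax probability of the target category.
per_item = -np.take_along_axis(log_softmax(logits), targets[:, None], axis=-1)[:, 0]
print(per_item.mean())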
Example #14
def WeightedCategoryCrossEntropy():
    r"""Returns a layer like ``CategoryCrossEntropy``, with weights as third input.

  The layer takes three inputs:

    - A batch of activation vectors. The components in a given vector should
      be pre-softmax activations (mappable to a probability distribution via
      softmax). For performance reasons, the softmax and cross-entropy
      computations are combined inside the layer.

    - A batch of target categories; each target is an integer in
      :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector
      depth/dimensionality.

    - A batch of weights, which matches or can be broadcast to match the shape
      of the target ndarray. This arg can give uneven weighting to different
      items in the batch (depending, for instance, on the item's target
      category).

  The layer returns the weighted average of these cross-entropy values over all
  items in the batch.
  """
    def f(model_output, targets, weights):  # pylint: disable=invalid-name
        cross_entropies = _category_cross_entropy(model_output, targets)
        return jnp.sum(cross_entropies * weights) / jnp.sum(weights)

    return base.Fn('WeightedCategoryCrossEntropy', f)
Example #15
def SmoothL1Loss():
    """Returns a layer that computes total smooth L1 loss for one batch."""
    def smoothl1loss(model_output, targets, weights):  # pylint: disable=invalid-name
        r"""Returns weighted smooth L1 norm of `model_output - targets`.

    The smooth L1 loss, also known as the Huber loss, is defined as:
    .. math::
        z_i =
        \begin{cases}
        0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < 1 \\
        |x_i - y_i| - 0.5, & \text{otherwise }
        \end{cases}

    Args:
      model_output: Output from one batch, treated as an unanalyzed tensor.
      targets: Tensor of same shape as `model_output` containing element-wise
          target values.
      weights: Tensor of same shape as `model_output` and `targets`, containing
          element-wise weight values.
    """
        shapes.assert_same_shape(model_output, targets)
        shapes.assert_same_shape(targets, weights)
        l1_dist = jnp.abs(model_output - targets)
        smooth_dist = jnp.where(l1_dist < 1, 0.5 * l1_dist**2, l1_dist - 0.5)
        shapes.assert_same_shape(smooth_dist, weights)
        weighted_smooth_dist = weights * smooth_dist
        return jnp.sum(weighted_smooth_dist) / jnp.sum(weights)

    return base.Fn('SmoothL1Loss', smoothl1loss)
Example #16
def WeightedCategoryAccuracy():
    r"""Returns a layer that computes a weighted category prediction accuracy.

  The layer takes three inputs:

    - A batch of activation vectors. The components in a given vector should
      be mappable to a probability distribution in the following loose sense:
      within a vector, a higher component value corresponds to a higher
      probability, such that argmax within a vector (``axis=-1``) picks the
      index (category) having the highest probability.

    - A batch of target categories; each target is an integer in
      :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector
      depth/dimensionality.

    - A batch of weights, which matches or can be broadcast to match the shape
      of the target ndarray. This arg can give uneven weighting to different
      items in the batch (depending, for instance, on the item's target
      category).

  The predicted category from each vector is the index of the highest-valued
  vector component. The layer returns a weighted average accuracy of these
  predictions.
  """
    def f(model_output, targets, weights):  # pylint: disable=invalid-name
        predictions = jnp.argmax(model_output, axis=-1)
        shapes.assert_same_shape(predictions, targets)
        ones_and_zeros = jnp.equal(predictions, targets)
        return jnp.sum(ones_and_zeros * weights) / jnp.sum(weights)

    return base.Fn('WeightedCategoryAccuracy', f)
Example #17
def MacroAveragedFScore(beta=1., initial_category_index=0):
    r"""Returns a layer that computes a macro-averaged F-score.

  Args:
    beta: a parameter that determines the weight of recall in the F-score.
    initial_category_index: index of the first category to include in the
      score; categories with smaller indices are skipped.

  The layer takes two inputs:

    - Model output from one batch, an ndarray of float-valued elements.

    - A batch of element-wise target values, which matches the shape of the
      model output.

  The layer returns a macro-averaged F-score across all the classes.
  """
    def f(model_output, targets):  # pylint: disable=invalid-name
        def non_nan(x):  # pylint: disable=invalid-name
            return jnp.where(jnp.isnan(x), 0., x)

        beta2 = beta**2
        predictions = jnp.argmax(model_output, axis=-1)
        n_categories = model_output.shape[-1]
        f_scores = jnp.empty(0)
        for k in range(initial_category_index, n_categories):
            n_correct = sum((predictions == k) & (targets == k))
            precision = non_nan(n_correct / sum(predictions == k))
            recall = non_nan(n_correct / sum(targets == k))
            f_score = non_nan((beta2 + 1) * (precision * recall) /
                              ((beta2 * precision) + recall))
            f_scores = jnp.append(f_scores, f_score)
        return jnp.mean(f_scores)

    return base.Fn('MacroAveragedFScore', f)
Example #18
def InnerSRUCell():
    """The inner (non-parallel) computation of an SRU."""
    def f(cur_x_times_one_minus_f, cur_f, cur_state):  # pylint: disable=invalid-name
        res = cur_f * cur_state + cur_x_times_one_minus_f
        return res, res

    return base.Fn('InnerSRUCell', f, n_out=2)
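Scanning this cell over the time axis computes c_t = f_t * c_{t-1} + (1 - f_t) * y_t step by step. A plain-NumPy sketch with made-up pre-computed activations `y` and forget gates `f`:

import numpy as np

rng = np.random.default_rng(0)
batch, length, depth = 2, 5, 3
y = rng.normal(size=(batch, length, depth))                         # stands in for y_t = W x_t + B
f = 1.0 / (1.0 + np.exp(-rng.normal(size=(batch, length, depth))))  # forget gates in (0, 1)

state = np.zeros((batch, depth))       # c_0, as produced by MakeZeroState
outputs = []
for t in range(length):
    # Same update the scanned InnerSRUCell performs at each step.
    state = f[:, t] * state + y[:, t] * (1 - f[:, t])
    outputs.append(state)
c = np.stack(outputs, axis=1)
print(c.shape)                         # (2, 5, 3)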
Example #19
def _CrossEntropy():
  """Returns a layer that computes prediction-target cross entropies."""
  def f(model_output, target_category):  # pylint: disable=invalid-name
    # TODO(pkozakowski): This assertion breaks some tests. Fix and uncomment.
    # shapes.assert_shape_equals(target_category, model_output.shape[:-1])
    target_distribution = core.one_hot(target_category, model_output.shape[-1])
    return -1.0 * jnp.sum(model_output * target_distribution, axis=-1)
  return base.Fn('_CrossEntropy', f)
Example #20
 def TestModelSavingInputs():
   def f(inputs):
     # Save the inputs for a later check.
     test_model_inputs.append(inputs)
     # Change type to np.float32 and add the logit dimension.
     return jnp.broadcast_to(
         inputs.astype(np.float32)[:, :, None], inputs.shape + (vocab_size,)
     )
   return layers_base.Fn('TestModelSavingInputs', f)
Example #21
def MakeZeroState(depth_multiplier=1):
  """Makes zeros of shape like x but removing the length (axis 1)."""
  def f(x):  # pylint: disable=invalid-name
    if len(x.shape) != 3:
      raise ValueError(f'Layer input should be a rank 3 tensor representing'
                       f' (batch_size, sequence_length, feature_depth); '
                       f'instead got shape {x.shape}.')
    return jnp.zeros((x.shape[0], depth_multiplier * x.shape[-1]),
                     dtype=jnp.float32)
  return base.Fn('MakeZeroState', f)
Example #22
 def test_fn_layer_example(self):
     layer = base.Fn(lambda x, y: (x + y, np.concatenate([x, y], axis=0)))
     input_signature = (ShapeDtype((2, 7)), ShapeDtype((2, 7)))
     expected_shape = ((2, 7), (4, 7))
     output_shape = base.check_shape_agreement(layer, input_signature)
     self.assertEqual(output_shape, expected_shape)
     inp = (np.array([2]), np.array([3]))
     x, xs = layer(inp)
     self.assertEqual(int(x), 5)
     self.assertEqual([int(y) for y in xs], [2, 3])
Example #23
def _BinaryCrossEntropy():
  """Returns a layer that computes prediction-target cross entropies."""
  def f(model_output, target_category):  # pylint: disable=invalid-name
    shapes.assert_same_shape(model_output, target_category)
    batch_size = model_output.shape[0]
    j = jnp.dot(jnp.transpose(target_category), jnp.log(model_output))
    j += jnp.dot(jnp.transpose(1 - target_category), jnp.log(1 - model_output))
    j = -1.0/batch_size * jnp.squeeze(j)
    return j
  return base.Fn('_BinaryCrossEntropy', f)
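The dot-product form above is the summed binary cross-entropy. An element-wise NumPy sketch on made-up probabilities gives the same batch average:

import numpy as np

model_output    = np.array([[0.9], [0.2], [0.7]])  # predicted P(class == 1)
target_category = np.array([[1.0], [0.0], [1.0]])

per_item = -(target_category * np.log(model_output)
             + (1 - target_category) * np.log(1 - model_output))
print(per_item.mean())  # equals the -1/batch_size * squeeze(j) value computed above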
Example #24
def SRU(n_units, activation=None, mode='train'):
    r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
    y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
    f_t &= \sigma(Wf x_t + bf) \\
    r_t &= \sigma(Wr x_t + br) \\
    c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
    h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer; the paper
  recommends stacking at least two of them, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    mode: if 'predict', save the previous state for one-by-one inference.

  Returns:
    The SRU layer.
  """
    sigmoid_activation = activation_fns.Sigmoid()
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
        base.Fn(
            '',
            lambda r, f, y: (y * (1.0 - f), f, r),  # y * (1 - f), f, r, x
            n_out=3),
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        ScanSRUCell(mode=mode),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation if activation is not None else [],
        base.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) *
                (3**0.5)),
        # Set the name to SRU and don't print sublayers.
        name=f'SRU_{n_units}',
        sublayers_to_print=[])
Example #25
    def _merge_heads():
        """Returns a layer that undoes splitting, after multi-head computation."""
        def f(x):
            seq_len = x.shape[1]

            # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
            x = x.reshape((-1, n_heads, seq_len, d_head))
            x = x.transpose((0, 2, 1, 3))
            x = x.reshape((-1, seq_len, n_heads * d_head))
            return x

        return base.Fn('MergeHeads', f)
Example #26
def TestModel(extra_dim):
  """Dummy sequence model for testing."""
  def f(inputs):
    # Cast the input to float32 - this is for simulating discrete-input models.
    inputs = inputs.astype(np.float32)
    # Add an extra dimension if requested, e.g. the logit dimension for output
    # symbols.
    if extra_dim is not None:
      return jnp.broadcast_to(inputs[:, :, None], inputs.shape + (extra_dim,))
    else:
      return inputs
  return layers_base.Fn('TestModel', f)
Example #27
    def _split_into_heads():
        """Returns a layer that reshapes tensors for multi-headed computation."""
        def f(x):
            batch_size = x.shape[0]
            seq_len = x.shape[1]

            # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
            x = x.reshape((batch_size, seq_len, n_heads, d_head))
            x = x.transpose((0, 2, 1, 3))
            x = x.reshape((-1, seq_len, d_head))
            return x

        return base.Fn('SplitIntoHeads', f)
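The reshapes in `_split_into_heads` and `_merge_heads` (Example #25) are exact inverses. A NumPy sketch with hypothetical values of `n_heads` and `d_head`, verifying the round trip:

import numpy as np

n_heads, d_head = 2, 4
batch, seq_len = 3, 5
x = np.arange(batch * seq_len * n_heads * d_head, dtype=np.float32)
x = x.reshape(batch, seq_len, n_heads * d_head)        # (b_size, seq_len, d_feature)

# Split: (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
split = x.reshape(batch, seq_len, n_heads, d_head)
split = split.transpose(0, 2, 1, 3).reshape(-1, seq_len, d_head)

# Merge: (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
merged = split.reshape(-1, n_heads, seq_len, d_head)
merged = merged.transpose(0, 2, 1, 3).reshape(-1, seq_len, n_heads * d_head)

print(np.array_equal(merged, x))  # True: merging undoes the split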
Example #28
def SRU(n_units, activation=None):
    r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
    y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
    f_t &= \sigma(Wf x_t + bf) \\
    r_t &= \sigma(Wr x_t + br) \\
    c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
    h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer; the paper
  recommends stacking at least two of them, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
    sigmoid_activation = activation_fns.Sigmoid()
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
        base.Fn(
            '',
            lambda r, f, y: (y * (1.0 - f), f, r),  # y * (1 - f), f, r, x
            n_out=3),
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation or [],
        base.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) *
                (3**0.5)))
Example #29
def _WeightedSequenceMean():
  """Returns a layer that computes a weighted sequence accuracy mean."""
  def f(values, weights):  # pylint: disable=invalid-name
    # This function assumes weights are 0 or 1.
    # Then compute 1: not-correct, 0: correct or masked
    not_correct = (1.0 - values) * weights
    axis_to_sum = list(range(1, len(not_correct.shape)))
    # Summing not-correct on all axes but batch. We're summing 0s and 1s,
    # so the sum is 0 if it's all 0 and >=1 in all other cases.
    not_correct_seq = jnp.sum(not_correct, axis=axis_to_sum)
    # Sequence is correct if not_correct_seq is 0, reverting here.
    correct_seq = 1.0 - jnp.minimum(1.0, not_correct_seq)
    return jnp.mean(correct_seq)  # Mean over batch.
  return base.Fn('_WeightedSequenceMean', f)
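A small NumPy walk-through of the sequence-level logic above, with made-up per-token correctness values and a mask-style weight matrix:

import numpy as np

# values: 1.0 where the token prediction was correct, 0.0 otherwise.
values  = np.array([[1., 1., 1.],   # all tokens correct -> sequence correct
                    [1., 0., 1.]])  # one token wrong    -> sequence wrong
weights = np.array([[1., 1., 0.],   # third position masked out
                    [1., 1., 1.]])

not_correct = (1.0 - values) * weights
not_correct_seq = np.sum(not_correct, axis=1)         # [0., 1.]
correct_seq = 1.0 - np.minimum(1.0, not_correct_seq)  # [1., 0.]
print(np.mean(correct_seq))                           # 0.5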
Example #30
def SRU(n_units, activation=None, rescale=False, highway_bias=0):
    """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t * alpha

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer; the paper
  recommends stacking at least two of them, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    rescale: If True, apply the scaling correction alpha = (1 + exp(highway_bias) * 2)**0.5
      to the highway connection. This offsets gradient vanishing in h_t caused by the
      light recurrence and highway computation in deeper layers (see
      https://arxiv.org/abs/1709.02755, page 4, section 3.2, Initialization).
    highway_bias: initial bias of the highway gates.

  Returns:
    The SRU layer.
  """
    # pylint: disable=no-value-for-parameter
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(core.Sigmoid(), core.Sigmoid()),  # r, f, y, x
        base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation or [],
        base.Fn(lambda c, r, x: c * r + x * (1 - r) *
                ((1 + np.exp(highway_bias) * 2)**0.5 if rescale else 1)))
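For reference, a quick check of the scaling correction applied when `rescale=True`, assuming the default `highway_bias=0`:

import numpy as np

highway_bias = 0
alpha = (1 + np.exp(highway_bias) * 2) ** 0.5
print(alpha)  # ~1.732, i.e. sqrt(3) -- the same 3**0.5 factor used in the other SRU variants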