Esempio n. 1
0
 def test_dense_param_sharing(self):
     """Reusing one Dense layer twice in Serial does not duplicate its params."""
     model1 = combinators.Serial(core.Dense(32), core.Dense(32))
     layer = core.Dense(32)
     model2 = combinators.Serial(layer, layer)
     rng = backend.random.get_prng(0)
     params1 = model1.initialize((1, 32), onp.float32, rng)
     params2 = model2.initialize((1, 32), onp.float32, rng)
     # The first parameters have 2 kernels of size (32, 32).
     self.assertEqual((32, 32), params1[0][0].shape)
     self.assertEqual((32, 32), params1[1][0].shape)
     # The second parameters have 1 kernel of size (32, 32); the repeated
     # (shared) layer contributes an empty tuple rather than a second copy.
     self.assertEqual((32, 32), params2[0][0].shape)
     self.assertEqual((), params2[1])
Esempio n. 2
0
def MultiHeadedAttention(feature_depth,
                         num_heads=8,
                         dropout=0.0,
                         mode='train'):
    """Transformer-style multi-headed self-attention.

    Takes inputs of the form (x, mask). The first input x is copied into three
    branches that serve as q, k and v, while the mask is copied through
    unchanged. The result is fed to MultiHeadedAttentionQKV.

    Args:
      feature_depth: int: depth of embedding
      num_heads: int: number of attention heads
      dropout: float: dropout rate
      mode: str: 'train' or 'eval'

    Returns:
      Multi-headed self-attention layer.
    """
    # q = k = v = first input.
    qkv_from_x = combinators.Branch(combinators.Copy(), combinators.Copy(),
                                    combinators.Copy())
    mask_passthrough = combinators.Copy()
    split_inputs = combinators.Parallel(qkv_from_x, mask_passthrough)
    attention = MultiHeadedAttentionQKV(  # pylint: disable=no-value-for-parameter
        feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
    return combinators.Serial(split_inputs, attention)
Esempio n. 3
0
def MultiHeadedAttentionQKV(
    feature_depth, num_heads=8, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention on explicit q, k, v inputs.

  Accepts inputs of the form (q, k, v), mask. Each of q, k and v is projected
  by its own Dense(feature_depth) while the mask passes through a NoOp. After
  PureMultiHeadedAttention, the attention output gets one more
  Dense(feature_depth) projection and the mask again passes through a NoOp.

  Args:
    feature_depth: int:  depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result and the mask.
  """
  input_projections = combinators.Parallel(
      core.Dense(feature_depth),  # q
      core.Dense(feature_depth),  # k
      core.Dense(feature_depth),  # v
      combinators.NoOp())         # mask
  attention = PureMultiHeadedAttention(  # pylint: disable=no-value-for-parameter
      feature_depth=feature_depth, num_heads=num_heads,
      dropout=dropout, mode=mode)
  output_projection = combinators.Parallel(
      core.Dense(feature_depth), combinators.NoOp())
  return combinators.Serial(input_projections, attention, output_projection)
Esempio n. 4
0
def GeneralGRUCell(candidate_transform,
                   memory_transform_fn=None,
                   gate_nonlinearity=core.Sigmoid,
                   candidate_nonlinearity=core.Tanh,
                   dropout_rate_c=0.1,
                   sigmoid_bias=0.5):
    r"""Parametrized Gated Recurrent Unit (GRU) cell construction.

    GRU update equations:
    $$ Update gate: u_t = \sigmoid(U' * s_{t-1} + B') $$
    $$ Reset gate: r_t = \sigmoid(U'' * s_{t-1} + B'') $$
    $$ Candidate memory: c_t = \tanh(U * (r_t \odot s_{t-1}) + B) $$
    $$ New State: s_t = u_t \odot s_{t-1} + (1 - u_t) \odot c_t $$

    See combinators.Gate for details on the gating function.

    Args:
      candidate_transform: Factory for the transform used inside the gate and
        candidate branches. Applied before nonlinearities.
      memory_transform_fn: Optional factory for a transformation on the memory
        before gating; when falsy, the memory branch is an empty layer list.
      gate_nonlinearity: Function to use as gate activation. Allows trying
        alternatives to Sigmoid, such as HardSigmoid.
      candidate_nonlinearity: Nonlinearity to apply after candidate branch.
        Allows trying alternatives to traditional Tanh, such as HardTanh
      dropout_rate_c: Amount of dropout on the transform (c) gate. Dropout works
        best in a GRU when applied exclusively to this branch.
      sigmoid_bias: Constant to add before sigmoid gates. Generally want to
        start off with a positive bias.

    Returns:
      A model representing a GRU cell with specified transforms.
    """
    def _biased_gate():
        # transform -> add sigmoid_bias (want it to start positive) -> gate.
        return [
            candidate_transform(),
            core.AddConstant(constant=sigmoid_bias),
            gate_nonlinearity(),
        ]

    update_gate = _biased_gate()  # u_t
    reset_gate = _biased_gate()   # r_t

    candidate = [
        cb.Branch([], reset_gate),
        # Gate S{t-1} with sigmoid(candidate_transform(S{t-1})).
        cb.Multiply(),
        # Final projection + nonlinearity to get Ct.
        candidate_transform(),
        candidate_nonlinearity(),
        # Dropout only on the C gate; the paper reports 0.1 as a good default.
        core.Dropout(rate=dropout_rate_c),
    ]

    if memory_transform_fn:
        memory_transform = memory_transform_fn()
    else:
        memory_transform = []

    return cb.Serial([
        cb.Branch(memory_transform, update_gate, candidate),
        cb.Gate(),
    ])
Esempio n. 5
0
def ChunkedCausalMultiHeadedAttention(feature_depth,
                                      num_heads=8,
                                      dropout=0.0,
                                      chunk_selector=None,
                                      mode='train'):
    """Transformer-style causal multi-headed attention over a list of chunks.

    Every chunk is first turned into projected (q, k, v) plus a causal mask.
    ChunkedAttentionSelector then decides, via chunk_selector, which chunks each
    chunk attends to. Attention is applied per chunk, the masks are dropped, and
    a final Dense(feature_depth) projection is applied to every chunk.

    Args:
      feature_depth: int:  depth of embedding
      num_heads: int: number of attention heads
      dropout: float: dropout rate
      chunk_selector: a function from chunk number to list of chunks to attend.
      mode: str: 'train' or 'eval'

    Returns:
      Multi-headed self-attention layer.
    """
    # q = k = v = the chunk itself, alongside a causal mask.
    qkv_copies = combinators.Branch(
        combinators.Copy(), combinators.Copy(), combinators.Copy())
    causal_mask = CausalMask(axis=-2)  # pylint: disable=no-value-for-parameter
    split_chunk = combinators.Branch(qkv_copies, causal_mask)

    # Separate Dense projections for q, k and v; the mask is copied through.
    qkv_projections = combinators.Parallel(
        core.Dense(feature_depth),
        core.Dense(feature_depth),
        core.Dense(feature_depth),
    )
    project_chunk = combinators.Parallel(qkv_projections, combinators.Copy())
    per_chunk_input = combinators.Serial(split_chunk, project_chunk)

    stages = [
        combinators.Map(per_chunk_input),
        ChunkedAttentionSelector(selector=chunk_selector),  # pylint: disable=no-value-for-parameter
        combinators.Map(
            PureMultiHeadedAttention(  # pylint: disable=no-value-for-parameter
                feature_depth=feature_depth,
                num_heads=num_heads,
                dropout=dropout,
                mode=mode),
            check_shapes=False),
        # Keep only the attention output of each chunk, dropping the masks.
        combinators.Map(combinators.Select(0), check_shapes=False),
        combinators.Map(core.Dense(feature_depth)),
    ]
    return combinators.Serial(*stages)
Esempio n. 6
0
    def test_serial_no_op_list(self):
        """An empty Serial passes its input stack through unchanged."""
        identity_layer = cb.Serial([])
        head = ((3, 2), (4, 7))
        # Check the bare two-item stack, then the same stack with
        # _REST_OF_STACK appended.
        for tail in ((), _REST_OF_STACK):
            stack_shape = head + tail
            self.assertEqual(
                base.check_shape_agreement(identity_layer, stack_shape),
                stack_shape)
Esempio n. 7
0
def MaskedScalar(metric_layer, mask_id=None, has_weights=False):
    """Metric as scalar compatible with Trax masking.

    The stack (inputs, targets) is turned into (metric, weight-mask): the
    targets are duplicated, one copy goes to metric_layer together with the
    inputs, and the other is turned into a mask by WeightMask(mask_id=mask_id).
    When has_weights is set, the given weights are multiplied by that mask
    first. The final stage is always WeightedMean().

    Args:
      metric_layer: layer computing a metric from (inputs, targets).
      mask_id: forwarded to WeightMask.
      has_weights: whether given weights must be combined with the mask.

    Returns:
      A Serial layer producing the weighted-mean metric.
    """
    metric_and_mask = [
        cb.Parallel([], cb.Dup()),  # Duplicate targets
        cb.Parallel(
            metric_layer,  # Metric: (inputs, targets) --> metric
            WeightMask(mask_id=mask_id)  # pylint: disable=no-value-for-parameter
        ),
    ]
    stages = [metric_and_mask]
    if has_weights:
        # Multiply given weights by mask_id weights.
        stages.append(cb.Parallel([], cb.Multiply()))
    stages.append(WeightedMean())  # pylint: disable=no-value-for-parameter
    return cb.Serial(stages)
Esempio n. 8
0
def GeneralGRUCell(candidate_transform,
                   memory_transform=combinators.Identity,
                   gate_nonlinearity=core.Sigmoid,
                   candidate_nonlinearity=core.Tanh,
                   dropout_rate_c=0.1,
                   sigmoid_bias=0.5):
  r"""Parametrized Gated Recurrent Unit (GRU) cell construction.

  GRU update equations:
  $$ Update gate: u_t = \sigmoid(U' * s_{t-1} + B') $$
  $$ Reset gate: r_t = \sigmoid(U'' * s_{t-1} + B'') $$
  $$ Candidate memory: c_t = \tanh(U * (r_t \odot s_{t-1}) + B) $$
  $$ New State: s_t = u_t \odot s_{t-1} + (1 - u_t) \odot c_t $$

  See combinators.GateBranches for details on the gating function.

  Dropout (dropout_rate_c) is applied at the end of the candidate (c_t)
  branch only.

  Args:
    candidate_transform: Transform to apply inside the Candidate branch. Applied
      before nonlinearities.
    memory_transform: Optional transformation on the memory before gating.
    gate_nonlinearity: Function to use as gate activation. Allows trying
      alternatives to Sigmoid, such as HardSigmoid.
    candidate_nonlinearity: Nonlinearity to apply after candidate branch. Allows
      trying alternatives to traditional Tanh, such as HardTanh
    dropout_rate_c: Amount of dropout on the transform (c) gate. Dropout works
      best in a GRU when applied exclusively to this branch.
    sigmoid_bias: Constant to add before sigmoid gates. Generally want to start
      off with a positive bias.

  Returns:
    A model representing a GRU cell with specified transforms.
  """
  return combinators.Serial(
      # Three branches: memory, update gate, candidate.
      combinators.Branch(num_branches=3),
      combinators.Parallel(
          # s_{t-1} branch - optionally transform
          # Typically is an identity.
          memory_transform(),

          # u_t (Update gate) branch
          combinators.Serial(
              candidate_transform(),
              # Want bias to start out positive before sigmoids.
              core.AddConstant(constant=sigmoid_bias),
              gate_nonlinearity()),

          # c_t (Candidate) branch
          combinators.Serial(
              combinators.Branch(num_branches=2),
              combinators.Parallel(
                  combinators.Identity(),
                  # r_t (Reset) Branch
                  combinators.Serial(
                      candidate_transform(),
                      # Want bias to start out positive before sigmoids.
                      core.AddConstant(constant=sigmoid_bias),
                      gate_nonlinearity())),
              ## Gate S{t-1} with sigmoid(candidate_transform(S{t-1}))
              combinators.MultiplyBranches(),

              # Final projection + tanh to get Ct
              candidate_transform(),
              candidate_nonlinearity(),  # Candidate gate

              # Only apply dropout on the C gate.
              # Paper reports that 0.1 is a good default.
              # The Dropout belongs inside this Serial: as a 4th entry of the
              # enclosing Parallel it does not correspond to any of the three
              # Branch outputs and so cannot act on c_t.
              core.Dropout(rate=dropout_rate_c))),

      # Gate memory and candidate
      combinators.GateBranches())
Esempio n. 9
0
def L2LossScalar(mask_id=None, has_weights=False):
    """L2 loss as scalar compatible with Trax masking.

    Built as L2Scalar(mask_id, has_weights) followed by multiplication by -1.0.
    """
    l2_metric = L2Scalar(mask_id=mask_id, has_weights=has_weights)
    negate = core.MulConstant(constant=-1.0)
    return cb.Serial(l2_metric, negate)
Esempio n. 10
0
def CrossEntropyLossScalar(mask_id=None, has_weights=False):
    """Cross-entropy loss as scalar compatible with Trax masking.

    Built as CrossEntropyScalar(mask_id, has_weights) followed by
    multiplication by -1.0.
    """
    xent_metric = CrossEntropyScalar(mask_id=mask_id, has_weights=has_weights)
    negate = core.MulConstant(constant=-1.0)
    return cb.Serial(xent_metric, negate)
Esempio n. 11
0
 def test_serial_dup_dup(self):
     """Two chained Dups turn one input into three copies of it."""
     two_dups = cb.Serial(cb.Dup(), cb.Dup())
     in_shape = (3, 2)
     self.assertEqual(base.check_shape_agreement(two_dups, in_shape),
                      (in_shape,) * 3)
Esempio n. 12
0
 def test_serial_div_div(self):
     """Chaining two Div layers leaves the input shape unchanged."""
     chained_divs = cb.Serial(core.Div(divisor=2.0), core.Div(divisor=5.0))
     in_shape = (3, 2)
     self.assertEqual(base.check_shape_agreement(chained_divs, in_shape),
                      in_shape)
Esempio n. 13
0
 def test_serial_no_op(self):
     """Serial(None) passes a two-item stack through unchanged."""
     no_op = cb.Serial(None)
     stack_shape = ((3, 2), (4, 7))
     self.assertEqual(base.check_shape_agreement(no_op, stack_shape),
                      stack_shape)
Esempio n. 14
0
 def test_serial_one_in_one_out(self):
     """A single one-in/one-out layer keeps the stack's shapes intact."""
     single_div = cb.Serial(core.Div(divisor=2.0))
     stack_shape = ((3, 2), (4, 7))
     self.assertEqual(base.check_shape_agreement(single_div, stack_shape),
                      stack_shape)