def test_dense_param_sharing(self):
  model1 = combinators.Serial(core.Dense(32), core.Dense(32))
  layer = core.Dense(32)
  model2 = combinators.Serial(layer, layer)
  rng = backend.random.get_prng(0)
  params1 = model1.initialize((1, 32), onp.float32, rng)
  params2 = model2.initialize((1, 32), onp.float32, rng)
  # The first parameters have 2 kernels of size (32, 32).
  self.assertEqual((32, 32), params1[0][0].shape)
  self.assertEqual((32, 32), params1[1][0].shape)
  # The second parameters have 1 kernel of size (32, 32) and an empty tuple
  # for the shared (second) occurrence of the layer.
  self.assertEqual((32, 32), params2[0][0].shape)
  self.assertEqual((), params2[1])
def MultiHeadedAttention(
    feature_depth, num_heads=8, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention.

  Accepts inputs of the form (x, mask) and constructs (q, k, v) from x.

  Args:
    feature_depth: int: depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
  return combinators.Serial(
      combinators.Parallel(
          # q = k = v = first input
          combinators.Branch(
              combinators.Copy(), combinators.Copy(), combinators.Copy()),
          combinators.Copy()  # pass the mask
      ),
      MultiHeadedAttentionQKV(  # pylint: disable=no-value-for-parameter
          feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
  )
def MultiHeadedAttentionQKV(
    feature_depth, num_heads=8, dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention.

  Accepts inputs of the form (q, k, v), mask.

  Args:
    feature_depth: int: depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result and the mask.
  """
  return combinators.Serial(
      combinators.Parallel(
          core.Dense(feature_depth),
          core.Dense(feature_depth),
          core.Dense(feature_depth),
          combinators.NoOp()
      ),
      PureMultiHeadedAttention(  # pylint: disable=no-value-for-parameter
          feature_depth=feature_depth, num_heads=num_heads,
          dropout=dropout, mode=mode),
      combinators.Parallel(core.Dense(feature_depth), combinators.NoOp())
  )
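# Illustrative sketch (added here for clarity; not part of the original
# module). Because MultiHeadedAttentionQKV consumes (q, k, v) and a mask
# explicitly, it can also express encoder-decoder attention, where queries
# come from the decoder and keys/values from the encoder. The helper name
# below is hypothetical; it assumes the caller has already arranged the stack
# as ((decoder_activations, encoder_activations, encoder_activations), mask).
def _example_encoder_decoder_attention(feature_depth, num_heads=8,
                                       dropout=0.0, mode='train'):
  return MultiHeadedAttentionQKV(
      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)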
def GeneralGRUCell(candidate_transform,
                   memory_transform_fn=None,
                   gate_nonlinearity=core.Sigmoid,
                   candidate_nonlinearity=core.Tanh,
                   dropout_rate_c=0.1,
                   sigmoid_bias=0.5):
  r"""Parametrized Gated Recurrent Unit (GRU) cell construction.

  GRU update equations:
  $$ Update gate: u_t = \sigmoid(U' * s_{t-1} + B') $$
  $$ Reset gate: r_t = \sigmoid(U'' * s_{t-1} + B'') $$
  $$ Candidate memory: c_t = \tanh(U * (r_t \odot s_{t-1}) + B) $$
  $$ New State: s_t = u_t \odot s_{t-1} + (1 - u_t) \odot c_t $$

  See combinators.Gate for details on the gating function.

  Args:
    candidate_transform: Transform to apply inside the Candidate branch.
      Applied before nonlinearities.
    memory_transform_fn: Optional transformation on the memory before gating.
    gate_nonlinearity: Function to use as gate activation. Allows trying
      alternatives to Sigmoid, such as HardSigmoid.
    candidate_nonlinearity: Nonlinearity to apply after candidate branch.
      Allows trying alternatives to traditional Tanh, such as HardTanh.
    dropout_rate_c: Amount of dropout on the transform (c) gate. Dropout works
      best in a GRU when applied exclusively to this branch.
    sigmoid_bias: Constant to add before sigmoid gates. Generally want to
      start off with a positive bias.

  Returns:
    A model representing a GRU cell with specified transforms.
  """
  gate_block = [  # u_t
      candidate_transform(),
      core.AddConstant(constant=sigmoid_bias),
      gate_nonlinearity(),
  ]
  reset_block = [  # r_t
      candidate_transform(),
      core.AddConstant(constant=sigmoid_bias),  # Want bias to start positive.
      gate_nonlinearity(),
  ]
  candidate_block = [
      cb.Branch([], reset_block),
      cb.Multiply(),  # Gate S{t-1} with sigmoid(candidate_transform(S{t-1}))
      candidate_transform(),  # Final projection + tanh to get Ct
      candidate_nonlinearity(),  # Candidate gate
      # Only apply dropout on the C gate. Paper reports 0.1 as a good default.
      core.Dropout(rate=dropout_rate_c)
  ]
  memory_transform = memory_transform_fn() if memory_transform_fn else []
  return cb.Serial([
      cb.Branch(memory_transform, gate_block, candidate_block),
      cb.Gate(),
  ])
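# Construction sketch (added for illustration; not from the original file).
# candidate_transform is a zero-argument callable that builds a fresh layer on
# each call, since the cell instantiates it separately in the gate, reset and
# candidate branches. The Dense width below is an arbitrary example value.
def _example_gru_cell(n_units=64):
  return GeneralGRUCell(candidate_transform=lambda: core.Dense(n_units))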
def ChunkedCausalMultiHeadedAttention(
    feature_depth, num_heads=8, dropout=0.0, chunk_selector=None,
    mode='train'):
  """Transformer-style causal multi-headed attention operating on chunks.

  Accepts inputs that are a list of chunks and applies causal attention.

  Args:
    feature_depth: int: depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    chunk_selector: a function from chunk number to list of chunks to attend
      to.
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
  prepare_attention_input = combinators.Serial(
      combinators.Branch(
          combinators.Branch(  # q = k = v = first input
              combinators.Copy(), combinators.Copy(), combinators.Copy()),
          CausalMask(axis=-2),  # pylint: disable=no-value-for-parameter
      ),
      combinators.Parallel(
          combinators.Parallel(
              core.Dense(feature_depth),
              core.Dense(feature_depth),
              core.Dense(feature_depth),
          ),
          combinators.Copy()
      )
  )
  return combinators.Serial(
      combinators.Map(prepare_attention_input),
      ChunkedAttentionSelector(selector=chunk_selector),  # pylint: disable=no-value-for-parameter
      combinators.Map(
          PureMultiHeadedAttention(  # pylint: disable=no-value-for-parameter
              feature_depth=feature_depth, num_heads=num_heads,
              dropout=dropout, mode=mode),
          check_shapes=False),
      combinators.Map(combinators.Select(0), check_shapes=False),  # drop masks
      combinators.Map(core.Dense(feature_depth))
  )
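# Illustrative sketch (an assumption, not from the original file):
# chunk_selector maps a chunk index to the list of chunk indices that chunk
# may attend to. The selector below lets each chunk attend to itself and to
# the immediately preceding chunk, which keeps attention causal across chunk
# boundaries.
def _example_chunk_selector(chunk_number):
  if chunk_number == 0:
    return [chunk_number]
  return [chunk_number - 1, chunk_number]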
def test_serial_no_op_list(self):
  layer = cb.Serial([])
  input_shape = ((3, 2), (4, 7))
  expected_shape = ((3, 2), (4, 7))
  output_shape = base.check_shape_agreement(layer, input_shape)
  self.assertEqual(output_shape, expected_shape)

  input_shape = ((3, 2), (4, 7)) + _REST_OF_STACK
  expected_shape = ((3, 2), (4, 7)) + _REST_OF_STACK
  output_shape = base.check_shape_agreement(layer, input_shape)
  self.assertEqual(output_shape, expected_shape)
def MaskedScalar(metric_layer, mask_id=None, has_weights=False):
  """Metric as scalar compatible with Trax masking."""
  # Stack of (inputs, targets) --> (metric, weight-mask).
  metric_and_mask = [
      cb.Parallel(
          [],
          cb.Dup()  # Duplicate targets
      ),
      cb.Parallel(
          metric_layer,  # Metric: (inputs, targets) --> metric
          WeightMask(mask_id=mask_id)  # pylint: disable=no-value-for-parameter
      )
  ]
  if not has_weights:
    # Take (metric, weight-mask) and return the weighted mean.
    return cb.Serial([metric_and_mask, WeightedMean()])  # pylint: disable=no-value-for-parameter
  return cb.Serial([
      metric_and_mask,
      cb.Parallel(
          [],
          cb.Multiply()  # Multiply given weights by mask_id weights
      ),
      WeightedMean()  # pylint: disable=no-value-for-parameter
  ])
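# Usage sketch (added for illustration, not from the original file). The
# metric_layer argument is an already-constructed layer that maps
# (inputs, targets) to a per-example metric value; `accuracy_layer` below is a
# hypothetical such layer supplied by the caller. mask_id=0 treats token id 0
# as padding to be excluded from the weighted mean.
def _example_masked_metric(accuracy_layer):
  return MaskedScalar(accuracy_layer, mask_id=0)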
def GeneralGRUCell(candidate_transform,
                   memory_transform=combinators.Identity,
                   gate_nonlinearity=core.Sigmoid,
                   candidate_nonlinearity=core.Tanh,
                   dropout_rate_c=0.1,
                   sigmoid_bias=0.5):
  r"""Parametrized Gated Recurrent Unit (GRU) cell construction.

  GRU update equations:
  $$ Update gate: u_t = \sigmoid(U' * s_{t-1} + B') $$
  $$ Reset gate: r_t = \sigmoid(U'' * s_{t-1} + B'') $$
  $$ Candidate memory: c_t = \tanh(U * (r_t \odot s_{t-1}) + B) $$
  $$ New State: s_t = u_t \odot s_{t-1} + (1 - u_t) \odot c_t $$

  See combinators.GateBranches for details on the gating function.

  Args:
    candidate_transform: Transform to apply inside the Candidate branch.
      Applied before nonlinearities.
    memory_transform: Optional transformation on the memory before gating.
    gate_nonlinearity: Function to use as gate activation. Allows trying
      alternatives to Sigmoid, such as HardSigmoid.
    candidate_nonlinearity: Nonlinearity to apply after candidate branch.
      Allows trying alternatives to traditional Tanh, such as HardTanh.
    dropout_rate_c: Amount of dropout on the transform (c) gate. Dropout works
      best in a GRU when applied exclusively to this branch.
    sigmoid_bias: Constant to add before sigmoid gates. Generally want to
      start off with a positive bias.

  Returns:
    A model representing a GRU cell with specified transforms.
  """
  return combinators.Serial(
      combinators.Branch(num_branches=3),
      combinators.Parallel(
          # s_{t-1} branch - optionally transform.
          # Typically is an identity.
          memory_transform(),
          # u_t (Update gate) branch
          combinators.Serial(
              candidate_transform(),
              # Want bias to start out positive before sigmoids.
              core.AddConstant(constant=sigmoid_bias),
              gate_nonlinearity()),
          # c_t (Candidate) branch
          combinators.Serial(
              combinators.Branch(num_branches=2),
              combinators.Parallel(
                  combinators.Identity(),
                  # r_t (Reset) branch
                  combinators.Serial(
                      candidate_transform(),
                      # Want bias to start out positive before sigmoids.
                      core.AddConstant(constant=sigmoid_bias),
                      gate_nonlinearity())),
              ## Gate S{t-1} with sigmoid(candidate_transform(S{t-1}))
              combinators.MultiplyBranches(),
              # Final projection + tanh to get Ct
              candidate_transform(),
              candidate_nonlinearity(),  # Candidate gate
              # Only apply dropout on the C gate.
              # Paper reports that 0.1 is a good default.
              core.Dropout(rate=dropout_rate_c))),
      # Gate memory and candidate
      combinators.GateBranches())
def L2LossScalar(mask_id=None, has_weights=False):
  """L2 loss as scalar compatible with Trax masking."""
  return cb.Serial(
      L2Scalar(mask_id=mask_id, has_weights=has_weights),
      core.MulConstant(constant=-1.0)
  )
def CrossEntropyLossScalar(mask_id=None, has_weights=False):
  """Cross-entropy loss as scalar compatible with Trax masking."""
  return cb.Serial(
      CrossEntropyScalar(mask_id=mask_id, has_weights=has_weights),
      core.MulConstant(constant=-1.0)
  )
def test_serial_dup_dup(self):
  layer = cb.Serial(cb.Dup(), cb.Dup())
  input_shape = (3, 2)
  expected_shape = ((3, 2), (3, 2), (3, 2))
  output_shape = base.check_shape_agreement(layer, input_shape)
  self.assertEqual(output_shape, expected_shape)
def test_serial_div_div(self):
  layer = cb.Serial(core.Div(divisor=2.0), core.Div(divisor=5.0))
  input_shape = (3, 2)
  expected_shape = (3, 2)
  output_shape = base.check_shape_agreement(layer, input_shape)
  self.assertEqual(output_shape, expected_shape)
def test_serial_no_op(self):
  layer = cb.Serial(None)
  input_shape = ((3, 2), (4, 7))
  expected_shape = ((3, 2), (4, 7))
  output_shape = base.check_shape_agreement(layer, input_shape)
  self.assertEqual(output_shape, expected_shape)
def test_serial_one_in_one_out(self):
  layer = cb.Serial(core.Div(divisor=2.0))
  input_shape = ((3, 2), (4, 7))
  expected_shape = ((3, 2), (4, 7))
  output_shape = base.check_shape_agreement(layer, input_shape)
  self.assertEqual(output_shape, expected_shape)