def GeneralGRUCell(candidate_transform,
                   memory_transform_fn=None,
                   gate_nonlinearity=core.Sigmoid,
                   candidate_nonlinearity=core.Tanh,
                   dropout_rate_c=0.1,
                   sigmoid_bias=0.5):
  r"""Builds a parametrized Gated Recurrent Unit (GRU) cell.

  GRU update equations:
  $$ Update gate: u_t = \sigma(U' * s_{t-1} + B') $$
  $$ Reset gate: r_t = \sigma(U'' * s_{t-1} + B'') $$
  $$ Candidate memory: c_t = \tanh(U * (r_t \odot s_{t-1}) + B) $$
  $$ New State: s_t = u_t \odot s_{t-1} + (1 - u_t) \odot c_t $$

  See combinators.Gate for details on the gating function.

  Args:
    candidate_transform: Zero-argument callable returning the transform used
      inside the candidate branch (and inside both gates); it is applied
      before the nonlinearities. A fresh instance is created for each use.
    memory_transform_fn: Optional zero-argument callable returning a
      transformation to apply to the memory before gating. When falsy, the
      memory is passed through unchanged.
    gate_nonlinearity: Callable returning the gate activation layer. Allows
      trying alternatives to Sigmoid, such as HardSigmoid.
    candidate_nonlinearity: Callable returning the nonlinearity applied after
      the candidate branch. Allows trying alternatives to the traditional
      Tanh, such as HardTanh.
    dropout_rate_c: Dropout rate on the candidate (c) branch. Dropout works
      best in a GRU when applied exclusively to this branch.
    sigmoid_bias: Constant added before the gate nonlinearities. Generally
      want to start off with a positive bias.

  Returns:
    A model representing a GRU cell with the specified transforms.
  """

  def _biased_gate():
    # Shared recipe for u_t and r_t: transform, add the constant bias (we want
    # the bias to start positive), then squash with the gate activation.
    return [
        candidate_transform(),
        core.AddConstant(constant=sigmoid_bias),
        gate_nonlinearity(),
    ]

  update_gate = _biased_gate()  # u_t
  reset_gate = _biased_gate()  # r_t

  candidate = [
      cb.Dup(),
      reset_gate,
      # Gate S{t-1} with gate_nonlinearity(candidate_transform(S{t-1})).
      cb.Multiply(),
      # Final projection + nonlinearity to get Ct.
      candidate_transform(),
      candidate_nonlinearity(),
      # Dropout is applied only on the C branch. Paper reports 0.1 as a good
      # default.
      core.Dropout(rate=dropout_rate_c),
  ]

  if memory_transform_fn:
    memory = memory_transform_fn()
  else:
    memory = []  # Empty list: no transformation of the memory.

  return cb.Serial(
      cb.Dup(),
      cb.Dup(),
      cb.Parallel(memory, update_gate, candidate),
      cb.Gate(),
  )
def SumOfWeights(id_to_mask=None, has_weights=False):
  """Returns a layer to compute sum of weights of all non-masked elements.

  Args:
    id_to_mask: Forwarded to `_ElementMask` as its `id_to_mask` argument.
    has_weights: If true, a `cb.Multiply()` step is inserted between the
      element mask and the final sum; otherwise that step is an empty list.

  Returns:
    A `cb.Serial` layer that drops its first input, builds the element mask,
    optionally multiplies, and sums over all axes.
  """
  if has_weights:
    weighting = cb.Multiply()
  else:
    weighting = []

  drop_inputs = cb.Drop()
  element_mask = _ElementMask(id_to_mask=id_to_mask)
  total = core.Sum(axis=None)  # axis=None: sum over everything.
  return cb.Serial(drop_inputs, element_mask, weighting, total)
def _WeightedMaskedMean(metric_layer, id_to_mask, has_weights):
  """Computes weighted masked mean of metric_layer(predictions, targets).

  The returned layer has 2 or 3 inputs -- predictions, targets and, when
  `has_weights` is true, weights. It applies the given metric to a batch and
  gathers the results into a single scalar.

  Args:
    metric_layer: Layer applied to (predictions, targets).
    id_to_mask: Forwarded to `_ElementMask` as its `id_to_mask` argument.
    has_weights: If true, a `cb.Multiply()` step follows the mask; otherwise
      that step is an empty list.

  Returns:
    A `cb.Serial` layer ending in `_WeightedMean()`.
  """
  if has_weights:
    weighting = cb.Multiply()
  else:
    weighting = []

  # Index 1 is selected twice so both the metric and the mask receive it.
  fan_out = cb.Select([0, 1, 1])
  metric_and_mask = cb.Parallel(
      metric_layer, _ElementMask(id_to_mask=id_to_mask))
  apply_weights = cb.Parallel([], weighting)
  # Stack at this point: metric_values weights
  reduce_to_scalar = _WeightedMean()
  return cb.Serial(fan_out, metric_and_mask, apply_weights, reduce_to_scalar)
def CountWeights(mask_id=None, has_weights=False):
  """Sum the weights assigned to all elements.

  Args:
    mask_id: Forwarded to `WeightMask` as its `mask_id` argument.
    has_weights: If true, the mask is multiplied with the provided weights
      before summing.

  Returns:
    A `cb.Serial` layer that drops its first input, builds the weight mask,
    optionally multiplies, and sums over all axes.
  """
  layers = [
      cb.Drop(),  # Drop inputs.
      WeightMask(mask_id=mask_id),  # pylint: disable=no-value-for-parameter
  ]
  if has_weights:
    layers.append(cb.Multiply())  # Multiply with provided mask.
  layers.append(core.Sum(axis=None))  # Sum all weights.
  return cb.Serial(*layers)
def MaskedScalar(metric_layer, mask_id=None, has_weights=False):
  """Metric as scalar compatible with Trax masking.

  Args:
    metric_layer: Layer mapping (inputs, targets) --> metric.
    mask_id: Forwarded to `WeightMask` as its `mask_id` argument.
    has_weights: If true, the given weights are multiplied by the mask
      weights before the weighted mean is taken.

  Returns:
    A `cb.Serial` layer ending in `WeightedMean()`.
  """
  # Stack of (inputs, targets) --> (metric, weight-mask).
  duplicate_targets = cb.Parallel([], cb.Dup())
  metric_with_mask = cb.Parallel(
      metric_layer,  # Metric: (inputs, targets) --> metric
      WeightMask(mask_id=mask_id),  # pylint: disable=no-value-for-parameter
  )
  stages = [[duplicate_targets, metric_with_mask]]

  if has_weights:
    # Multiply given weights by mask_id weights.
    stages.append(cb.Parallel([], cb.Multiply()))

  # Take (metric, weights) and return the weighted mean.
  stages.append(WeightedMean())  # pylint: disable=no-value-for-parameter
  return cb.Serial(*stages)
def GeneralGRUCell(candidate_transform,
                   memory_transform_fn=None,
                   gate_nonlinearity=activation_fns.Sigmoid,
                   candidate_nonlinearity=activation_fns.Tanh,
                   dropout_rate_c=0.1,
                   sigmoid_bias=0.5):
  r"""Builds a parametrized Gated Recurrent Unit (GRU) cell.

  GRU update equations for update gate, reset gate, candidate memory, and new
  state:

  .. math::
    u_t &= \sigma(U' \times s_{t-1} + B') \\
    r_t &= \sigma(U'' \times s_{t-1} + B'') \\
    c_t &= \tanh(U \times (r_t \odot s_{t-1}) + B) \\
    s_t &= u_t \odot s_{t-1} + (1 - u_t) \odot c_t

  See `combinators.Gate` for details on the gating function.

  Args:
    candidate_transform: Zero-argument callable returning the transform used
      inside the candidate branch (and inside both gates); it is applied
      before the nonlinearities. A fresh instance is created for each use.
    memory_transform_fn: Optional zero-argument callable returning a
      transformation to apply to the memory before gating. When falsy, the
      memory is passed through unchanged.
    gate_nonlinearity: Callable returning the gate activation layer; allows
      trying alternatives to `Sigmoid`, such as `HardSigmoid`.
    candidate_nonlinearity: Callable returning the nonlinearity applied after
      the candidate branch; allows trying alternatives to the traditional
      `Tanh`, such as `HardTanh`.
    dropout_rate_c: Dropout rate on the candidate (c) branch. Dropout works
      best in a GRU when applied exclusively to this branch.
    sigmoid_bias: Constant added before the gate nonlinearities. Generally
      want to start off with a positive bias.

  Returns:
    A model representing a GRU cell with the specified transforms.
  """

  def _biased_gate():
    # Shared recipe for u_t and r_t: transform, add the sigmoid bias (we want
    # the bias to start positive), then squash with the gate activation.
    return [
        candidate_transform(),
        _AddSigmoidBias(sigmoid_bias),
        gate_nonlinearity(),
    ]

  update_gate = _biased_gate()  # u_t
  reset_gate = _biased_gate()  # r_t

  candidate = [
      cb.Dup(),
      reset_gate,
      # Gate S{t-1} with gate_nonlinearity(candidate_transform(S{t-1})).
      cb.Multiply(),
      # Final projection + nonlinearity to get Ct.
      candidate_transform(),
      candidate_nonlinearity(),
      # Dropout is applied only on the C branch. Paper reports 0.1 as a good
      # default.
      core.Dropout(rate=dropout_rate_c),
  ]

  if memory_transform_fn:
    memory = memory_transform_fn()
  else:
    memory = []  # Empty list: no transformation of the memory.

  return cb.Serial(
      cb.Branch(memory, update_gate, candidate),
      cb.Gate(),
  )