def GeneralGRUCell(candidate_transform,
                   memory_transform_fn=None,
                   gate_nonlinearity=core.Sigmoid,
                   candidate_nonlinearity=core.Tanh,
                   dropout_rate_c=0.1,
                   sigmoid_bias=0.5):
  r"""Parametrized Gated Recurrent Unit (GRU) cell construction.

  GRU update equations:

  $$ Update gate: u_t = \sigma(U' * s_{t-1} + B') $$
  $$ Reset gate: r_t = \sigma(U'' * s_{t-1} + B'') $$
  $$ Candidate memory: c_t = \tanh(U * (r_t \odot s_{t-1}) + B) $$
  $$ New state: s_t = u_t \odot s_{t-1} + (1 - u_t) \odot c_t $$

  See combinators.Gate for details on the gating function.

  Args:
    candidate_transform: Transform to apply inside the Candidate branch. Applied
      before nonlinearities.
    memory_transform_fn: Optional transformation on the memory before gating.
    gate_nonlinearity: Function to use as gate activation; allows trying
      alternatives to Sigmoid, such as HardSigmoid.
    candidate_nonlinearity: Nonlinearity to apply after the candidate branch;
      allows trying alternatives to the traditional Tanh, such as HardTanh.
    dropout_rate_c: Amount of dropout on the transform (c) gate. Dropout works
      best in a GRU when applied exclusively to this branch.
    sigmoid_bias: Constant to add before sigmoid gates. Generally want to start
      off with a positive bias.

  Returns:
    A model representing a GRU cell with the specified transforms.
  """
  gate_block = [  # u_t
      candidate_transform(),
      core.AddConstant(constant=sigmoid_bias),  # Want bias to start positive.
      gate_nonlinearity(),
  ]
  reset_block = [  # r_t
      candidate_transform(),
      core.AddConstant(constant=sigmoid_bias),  # Want bias to start positive.
      gate_nonlinearity(),
  ]
  candidate_block = [
      cb.Branch([], reset_block),
      cb.Multiply(),  # Gate s_{t-1} with sigmoid(candidate_transform(s_{t-1})).
      candidate_transform(),  # Final projection + tanh to get c_t.
      candidate_nonlinearity(),  # Candidate gate.
      # Only apply dropout on the c gate. Paper reports 0.1 as a good default.
      core.Dropout(rate=dropout_rate_c),
  ]
  memory_transform = memory_transform_fn() if memory_transform_fn else []
  return cb.Serial([
      cb.Branch(memory_transform, gate_block, candidate_block),
      cb.Gate(),
  ])
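
# For reference, the combinator pipeline above computes the same update as this
# plain-NumPy sketch of a single GRU step. It is illustrative only, not part of
# the layer library, and the weight/bias names (U_u, B_u, etc.) are
# hypothetical stand-ins for the parameters the candidate_transform learns.
def _gru_step_reference(s_prev, U_u, B_u, U_r, B_r, U_c, B_c):
  """One GRU update mirroring the docstring equations (NumPy sketch)."""
  import numpy as np

  def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

  u_t = sigmoid(U_u @ s_prev + B_u)          # Update gate.
  r_t = sigmoid(U_r @ s_prev + B_r)          # Reset gate.
  c_t = np.tanh(U_c @ (r_t * s_prev) + B_c)  # Candidate memory.
  return u_t * s_prev + (1.0 - u_t) * c_t    # New state s_t.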
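
# A minimal usage sketch (an assumption, not taken from this file): build a
# conventional GRU cell by passing a Dense projection as the shared transform.
# Assumes core.Dense is available alongside the other `core` layers used above.
# Note that each call to candidate_transform() constructs a fresh Dense layer,
# so the update, reset, and candidate projections get independent weights.
def GRUCellSketch(n_units):
  """GRU cell with Dense gate/candidate transforms (hypothetical helper)."""
  return GeneralGRUCell(
      candidate_transform=lambda: core.Dense(n_units),
      memory_transform_fn=None)  # Leave the memory branch untransformed.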