def forward(self, inputs):
  """Runs one step of an LSTM cell.

  Args:
    inputs: Pair `(x, lstm_state)`; `lstm_state` packs the cell value `c`
      and the hidden value `h` along the last axis.

  Returns:
    Pair `(new_h, new_lstm_state)` where the new state again packs
    `(new_c, new_h)` along the last axis.
  """
  x, lstm_state = inputs
  # Recover the cell (c) and hidden (h) halves of the packed state.
  c, h = jnp.split(lstm_state, 2, axis=-1)
  # One dense projection of [x; h] yields all four gate pre-activations.
  w, b = self.weights
  gates = jnp.dot(jnp.concatenate([x, h], axis=-1), w) + b
  # Gate order: input gate (i), candidate input (j), forget gate (f),
  # output gate (o).
  i, j, f, o = jnp.split(gates, 4, axis=-1)
  retained = c * fastmath.sigmoid(f)
  written = fastmath.sigmoid(i) * jnp.tanh(j)
  new_c = retained + written
  new_h = jnp.tanh(new_c) * fastmath.sigmoid(o)
  return new_h, jnp.concatenate([new_c, new_h], axis=-1)
def Tanh():
  r"""Returns a layer that computes the hyperbolic tangent function.

  .. math::
      f(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}
  """
  def tanh(x):
    return jnp.tanh(x)
  return Fn('Tanh', tanh)
def ActionInjector(mode):
  """Builds the optional action-injection sublayer.

  Reads `inject_actions`, `is_discrete`, `vocab_size`, `inject_actions_dim`,
  `multiplicative_action_injection` and `inject_actions_n_layers` from the
  enclosing scope. When injection is disabled, returns `[]`, which acts as a
  no-op inside a Serial combinator.

  Args:
    mode: Mode passed through to the trailing MLP (e.g. 'train' / 'eval').

  Returns:
    A layer taking (body output, actions) and emitting their combination,
    or `[]` when `inject_actions` is false.
  """
  if not inject_actions:
    return []

  # Discrete action spaces get an embedding; continuous ones a dense layer.
  if is_discrete:
    encode_actions = tl.Embedding(vocab_size, inject_actions_dim)
  else:
    encode_actions = tl.Dense(inject_actions_dim)

  # Project both streams to the same width before combining.
  pair_encoder = tl.Parallel(
      tl.Dense(inject_actions_dim),
      encode_actions,
  )

  if multiplicative_action_injection:
    combiner = tl.Serial(
        tl.Fn('TanhMulGate', lambda x, a: x * jnp.tanh(a)),
        tl.LayerNorm()  # compensate for reduced variance
    )
  else:
    combiner = tl.Add()

  return tl.Serial(
      # Input: (body output, actions).
      pair_encoder,
      combiner,
      models.MLP(
          layer_widths=(inject_actions_dim,) * inject_actions_n_layers,
          out_activation=True,
          flatten=False,
          mode=mode,
      ))
def forward(self, inputs):
  """Runs one step of a GRU cell.

  Args:
    inputs: Pair `(x, gru_state)`.

  Returns:
    Pair `(new_gru_state, new_gru_state)` — the cell's output is its
    new state.
  """
  x, gru_state = inputs
  w1, b1, w2, b2 = self.weights
  # A joint projection of [x; state] yields the update (u) and reset (r)
  # gate pre-activations.
  joint = jnp.concatenate([x, gru_state], axis=-1)
  u, r = jnp.split(fastmath.sigmoid(jnp.dot(joint, w1) + b1), 2, axis=-1)
  # Candidate state is computed from x and the reset-gated previous state.
  gated = jnp.concatenate([x, r * gru_state], axis=-1)
  c = jnp.dot(gated, w2) + b2
  new_gru_state = u * gru_state + (1 - u) * jnp.tanh(c)
  return new_gru_state, new_gru_state
def f(x):  # pylint: disable=invalid-name
  """Tanh approximation of GELU: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))."""
  # 0.7978845608 == sqrt(2 / pi); the cubic term is factored as
  # x * (1 + 0.044715 * x^2).
  cubic = 1 + 0.044715 * jnp.square(x)
  return 0.5 * x * (1 + jnp.tanh(0.7978845608 * x * cubic))