Example #1
0
  def forward(self, inputs):
    """Executes one LSTM step.

    Args:
      inputs: Pair of (activations, lstm_state); the state packs the cell
        vector c and the hidden vector h along the last axis.

    Returns:
      Pair of (new hidden vector, packed new state).
    """
    x, state = inputs

    # Unpack the carried state into its cell (c) and hidden (h) halves.
    c, h = jnp.split(state, 2, axis=-1)

    # One dense projection of [x; h] produces all four gate pre-activations.
    w, b = self.weights
    gates = jnp.dot(jnp.concatenate([x, h], axis=-1), w) + b

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = jnp.split(gates, 4, axis=-1)

    new_c = c * fastmath.sigmoid(f) + fastmath.sigmoid(i) * jnp.tanh(j)
    new_h = jnp.tanh(new_c) * fastmath.sigmoid(o)
    return new_h, jnp.concatenate([new_c, new_h], axis=-1)
Example #2
0
def Tanh():
  r"""Returns a layer that computes the hyperbolic tangent function.

  .. math::
      f(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}
  """
  # A named single-argument function keeps the same arity Fn infers
  # from the original lambda.
  def tanh(x):
    return jnp.tanh(x)
  return Fn('Tanh', tanh)
Example #3
0
File: rl.py  Project: stephenjfox/trax
 def ActionInjector(mode):
     """Builds the layer stack that mixes actions into the network body.

     Args:
       mode: Layer mode ('train'/'eval'/'predict'), forwarded to the MLP.

     Returns:
       An empty list (a no-op) when action injection is disabled;
       otherwise a tl.Serial stack taking (body output, actions).
     """
     if not inject_actions:
         return []

     # Encode actions: an embedding for discrete action spaces,
     # a dense projection otherwise.
     if is_discrete:
         action_encoder = tl.Embedding(vocab_size, inject_actions_dim)
     else:
         action_encoder = tl.Dense(inject_actions_dim)
     encoders = tl.Parallel(
         tl.Dense(inject_actions_dim),
         action_encoder,
     )

     # Combine body output with encoded actions, either multiplicatively
     # (tanh-gated product) or additively.
     if multiplicative_action_injection:
         action_injector = tl.Serial(
             tl.Fn('TanhMulGate', lambda x, a: x * jnp.tanh(a)),
             tl.LayerNorm()  # compensate for reduced variance
         )
     else:
         action_injector = tl.Add()

     return tl.Serial(
         # Input: (body output, actions).
         encoders,
         action_injector,
         models.MLP(
             layer_widths=(inject_actions_dim,) * inject_actions_n_layers,
             out_activation=True,
             flatten=False,
             mode=mode,
         ))
Example #4
0
  def forward(self, inputs):
    """Executes one GRU step.

    Args:
      inputs: Pair of (activations, gru_state); the state is the previous
        hidden vector.

    Returns:
      Pair (new_state, new_state) — a GRU's output equals its new state.
    """
    x, state = inputs

    w1, b1, w2, b2 = self.weights
    # One dense projection of [x; state] yields both gate pre-activations;
    # the sigmoid maps them into (0, 1).
    gates = fastmath.sigmoid(
        jnp.dot(jnp.concatenate([x, state], axis=-1), w1) + b1)
    u, r = jnp.split(gates, 2, axis=-1)  # u = update gate, r = reset gate

    # Candidate activation from x and the reset-scaled previous state.
    candidate = jnp.dot(jnp.concatenate([x, r * state], axis=-1), w2) + b2

    new_state = u * state + (1 - u) * jnp.tanh(candidate)
    return new_state, new_state
Example #5
0
 def f(x):  # pylint: disable=invalid-name
   """Tanh approximation of GELU: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))."""
   # 0.7978845608 == sqrt(2 / pi), the constant in the GELU tanh approximation.
   inner = x * 0.7978845608 * (1 + 0.044715 * x * x)
   return 0.5 * x * (1 + jnp.tanh(inner))