Exemplo n.º 1
0
    def forward(self, inputs, weights):
        """Cut two inputs into sections and pair the sections up side by side.

        Both arrays are split into ``self._n_sections`` equal parts along
        ``self._axis``; the k-th part of the first array is then joined with the
        k-th part of the second array along the last axis.

        Args:
          inputs: a pair ``(x1, x2)`` of arrays.
          weights: ignored.

        Returns:
          A tuple of ``self._n_sections`` arrays, one per section.

        Raises:
          ValueError: from ``np.split`` when the size of ``self._axis`` is not
            evenly divisible by ``self._n_sections``.
        """
        del weights
        first, second = inputs
        sections, axis = self._n_sections, self._axis
        first_parts = np.split(first, sections, axis)
        second_parts = np.split(second, sections, axis)
        return tuple(
            np.concatenate([left, right], -1)
            for left, right in zip(first_parts, second_parts)
        )
Exemplo n.º 2
0
    def forward(self, inputs, weights):
        """Advance an LSTM cell by one step.

        Args:
          inputs: a pair ``(x, lstm_state)``. ``lstm_state`` packs the cell
            state ``c`` and the hidden state ``h`` side by side on the last
            axis; it is split into two equal halves here.
          weights: a pair ``(w, b)`` for the single dense projection that maps
            ``concat([x, h])`` (last axis) to the four gate pre-activations.

        Returns:
          ``(new_h, concat([new_c, new_h]))``: the new hidden state, and the new
          packed state in the same ``[c, h]`` layout that was read from
          ``inputs``.

        NOTE(review): ``math`` here is not the stdlib module (which has no
        ``sigmoid``); presumably the project's backend math shim — confirm.
        """
        x, lstm_state = inputs

        cell, hidden = jnp.split(lstm_state, 2, axis=-1)

        kernel, bias = weights
        joined = jnp.concatenate([x, hidden], axis=-1)
        gate_logits = jnp.dot(joined, kernel) + bias

        # Gate order along the last axis: input gate, new input (candidate),
        # forget gate, output gate.
        in_logit, cand_logit, forget_logit, out_logit = jnp.split(
            gate_logits, 4, axis=-1)

        forget = math.sigmoid(forget_logit)
        admit = math.sigmoid(in_logit)
        new_cell = cell * forget + admit * jnp.tanh(cand_logit)
        new_hidden = jnp.tanh(new_cell) * math.sigmoid(out_logit)
        packed_state = jnp.concatenate([new_cell, new_hidden], axis=-1)
        return new_hidden, packed_state
Exemplo n.º 3
0
    def reverse(self, output, weights=(), state=(), new_state=(), **kwargs):
        """Rebuild two arrays from pieces that were joined on the last axis.

        Each piece in ``output`` is cut into two halves along its last axis.
        The first halves are concatenated along ``self._axis`` to form ``x1``,
        and the second halves to form ``x2``.

        Args:
          output: a list/tuple of arrays, or a single array. A single array is
            treated as a one-element sequence rather than iterated over.
          weights: unused.
          state: unused.
          new_state: unused.
          **kwargs: unused.

        Returns:
          A tuple ``(x1, x2)``.

        NOTE(review): halving each piece on the last axis only undoes a join of
        two equal-width arrays; confirm callers always join equal widths.
        """
        del weights, kwargs
        # A bare array must not be iterated directly: that would walk its
        # leading axis and split individual rows instead of treating the whole
        # array as one section.
        if not isinstance(output, (list, tuple)):
            output = [output]

        halves = [np.split(y, 2, -1) for y in output]
        x1 = np.concatenate([first for first, _ in halves], self._axis)
        x2 = np.concatenate([second for _, second in halves], self._axis)
        return (x1, x2)
Exemplo n.º 4
0
    def forward(self, inputs, weights):
        """Advance a GRU cell by one step.

        Args:
          inputs: a pair ``(x, gru_state)``.
          weights: ``(w1, b1, w2, b2)``. ``w1``/``b1`` project
            ``concat([x, state])`` to the update/reset gate pre-activations;
            ``w2``/``b2`` project ``concat([x, reset * state])`` to the
            candidate pre-activation.

        Returns:
          ``(new_state, new_state)``: the same new state is returned both as the
          output and as the carried state.

        NOTE(review): ``math`` here is not the stdlib module (which has no
        ``sigmoid``); presumably the project's backend math shim — confirm.
        """
        x, gru_state = inputs
        gate_kernel, gate_bias, cand_kernel, cand_bias = weights

        gate_input = jnp.concatenate([x, gru_state], axis=-1)
        gates = math.sigmoid(jnp.dot(gate_input, gate_kernel) + gate_bias)
        update, reset = jnp.split(gates, 2, axis=-1)

        # The reset gate scales the old state before it feeds the candidate.
        cand_input = jnp.concatenate([x, reset * gru_state], axis=-1)
        cand_logit = jnp.dot(cand_input, cand_kernel) + cand_bias

        # The update gate keeps part of the old state; the rest comes from
        # the tanh candidate.
        blended = update * gru_state + (1 - update) * jnp.tanh(cand_logit)
        return blended, blended
Exemplo n.º 5
0
  def reverse(self, output, weights=(), state=(), new_state=(), **kwargs):
    """Rebuild two arrays from pieces that were joined on the last axis.

    Every piece is halved along its last axis. The left halves are stacked
    end to end along ``self._axis`` into the first result and the right
    halves into the second.

    Args:
      output: a list/tuple of arrays, or a single array (treated as one piece).
      weights: unused.
      state: unused.
      new_state: unused.
      **kwargs: unused.

    Returns:
      A tuple ``(x1, x2)``.
    """
    del weights, kwargs
    pieces = output if isinstance(output, (list, tuple)) else [output]

    halves = [np.split(piece, 2, -1) for piece in pieces]
    left = np.concatenate([pair[0] for pair in halves], self._axis)
    right = np.concatenate([pair[1] for pair in halves], self._axis)
    return left, right
Exemplo n.º 6
0
 def forward(self, inputs, weights):
   """Split ``inputs`` into ``self._n_items`` equal pieces along ``self._axis``.

   ``weights`` is ignored. Returns the pieces as a tuple; ``np.split`` raises
   ValueError when the axis size is not divisible by ``self._n_items``.
   """
   del weights
   pieces = np.split(inputs, self._n_items, axis=self._axis)
   return (*pieces,)
Exemplo n.º 7
0
 def forward(self, inputs):
     """Split ``inputs`` into ``self._n_items`` equal pieces along ``self._axis``.

     Returns the pieces as a tuple.
     """
     parts = jnp.split(inputs, self._n_items, axis=self._axis)
     return tuple(parts)