Example #1
    def postnet(self, mel: ndarray) -> ndarray:
        x = mel
        for conv, bn in zip(self.postnet_convs, self.postnet_bns):
            x = conv(x)
            if bn is not None:
                x = bn(x, is_training=self.is_training)
                x = jnp.tanh(x)
            x = hk.dropout(hk.next_rng_key(), 0.5,
                           x) if self.is_training else x
        return x
Example #2
def predict(params, inputs):
    activations = inputs
    for w, b in params[:-1]:
        outputs = np.dot(activations, w) + b
        # print("w has a shape{}".format(w.shape))
        activations = np.tanh(outputs)

    final_w, final_b = params[-1]
    logits = np.dot(activations, final_w) + final_b
    return logits - logsumexp(logits, axis=1, keepdims=True)
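A minimal usage sketch (not part of the original example), assuming params is a list of (weights, bias) pairs and logsumexp comes from scipy.special:

import numpy as np
from scipy.special import logsumexp  # assumed source of logsumexp in this example

# Hypothetical 4 -> 8 -> 3 network: one tanh hidden layer, then a linear classifier.
rng = np.random.RandomState(0)
layer_sizes = [4, 8, 3]
params = [(0.1 * rng.randn(m, n), np.zeros(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

inputs = rng.randn(5, 4)              # batch of 5 examples
log_probs = predict(params, inputs)   # (5, 3) log-probabilities
print(np.exp(log_probs).sum(axis=1))  # each row sums to 1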
Example #3
def sigmoid(x, version="tanh"):
    """Sigmoid activation function.

    Two versions are provided: "tanh" and "exp".
    "exp" is used in the re-implementation.
    """
    sigmoids = {
        "tanh": lambda x: 0.5 * np.tanh(x) + 0.5,
        "exp": lambda x: safe_sigmoid_exp(x),
    }
    return sigmoids[version](x)
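The helper safe_sigmoid_exp is not shown above; below is a minimal sketch of what a numerically stable exp-based sigmoid could look like (the name matches the reference above, but the implementation is an assumption, not the original code). Note also that 0.5 * np.tanh(x) + 0.5 equals sigmoid(2 * x), so the two versions differ by a factor of two in the argument.

import numpy as np

def safe_sigmoid_exp(x):
    # Hypothetical stand-in for the helper referenced above: a logistic sigmoid
    # that clips before exponentiating so exp() never overflows for large |x|.
    pos = 1.0 / (1.0 + np.exp(-np.clip(x, 0.0, None)))  # used where x >= 0
    ex = np.exp(np.clip(x, None, 0.0))                  # used where x < 0
    neg = ex / (1.0 + ex)
    return np.where(x >= 0, pos, neg)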
Example #4
def tanh(x):
    """
        Return the activation after a hyperbolic tangent function.

        Args:
            x (numpy.dtype): The input sum for the activation function.

        Returns:
            The activation value.
    """
    return jnp.tanh(x)
Example #5
def predict(params, inputs):
  activations = inputs
  for w, b in params[:-1]:
    #outputs = np.tanh(activations)
    #activations = np.dot(inputs, w) + b 
    outputs = np.dot(activations, w) + b 
    activations = np.tanh(outputs)

  final_w, final_b = params[-1]
  logits = np.dot(activations, final_w) + final_b
  return logits - logsumexp(logits, axis=1, keepdims=True)
Example #6
    def apply(self, params, inputs, state):
        update_gate = params['update_gate']
        reset_gate = params['reset_gate']
        cell_state = params['cell_state']

        update = jax.nn.sigmoid(
            update_gate.apply(inputs, state) + self.gate_bias)
        reset = jax.nn.sigmoid(reset_gate.apply(inputs, state))
        cell = jnp.tanh(cell_state.apply(inputs, reset * state))

        return update * state + (1 - update) * cell
Example #7
def planar_flow(params, z):
    """
    Transforms z using planar transformations.

    :param z: Samples or transformed samples.
    :param params: A dictionary of tunable parameters.
    """
    a = (np.dot(params["w"].T, z.T) + params["b"]
         )  # (w: (dim, 1), z: (n, dim), b: scalar, a: (1, n))
    return (z + (params["u"] * np.tanh(a)).T
            )  # (u: (dim, 1), a: (1, n), z: (n, dim))
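A quick shape check (illustrative values only), following the shapes noted in the comments above — w: (dim, 1), u: (dim, 1), b: scalar, z: (n, dim) — and calling the planar_flow defined above:

import numpy as np

dim, n = 3, 5
rng = np.random.RandomState(0)
params = {"w": rng.randn(dim, 1), "u": rng.randn(dim, 1), "b": rng.randn()}
z = rng.randn(n, dim)

z_new = planar_flow(params, z)
assert z_new.shape == (n, dim)  # the planar transform preserves the sample shape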
Example #8
    def test_jacobian(self):
        R = onp.random.RandomState(0).randn
        A = R(4, 3)
        x = R(3)

        f = lambda x: np.dot(A, x)
        assert onp.allclose(jacfwd(f)(x), A)
        assert onp.allclose(jacrev(f)(x), A)

        f = lambda x: np.tanh(np.dot(A, x))
        assert onp.allclose(jacfwd(f)(x), jacrev(f)(x))
Example #9
    def gru_cell(carry, x):
        def param(name):
            return parameter((x.shape[1] + carry_size, carry_size), param_init, name)

        both = np.concatenate((x, carry), axis=1)
        update = sigmoid(np.dot(both, param('update_kernel')))
        reset = sigmoid(np.dot(both, param('reset_kernel')))
        both_reset_carry = np.concatenate((x, reset * carry), axis=1)
        compute = np.tanh(np.dot(both_reset_carry, param('compute_kernel')))
        out = update * compute + (1 - update) * carry
        return out, out
Example #10
def forward(params, t, X):
    input = jnp.concatenate((t, X), 0)  # M x D+1
    activations = input

    for w, b in params[:-1]:
        outputs = jnp.dot(activations, w) + b
        activations = jnp.tanh(outputs)  # relu(outputs)

    final_w, final_b = params[-1]
    u = jnp.dot(activations, final_w) + final_b
    return jnp.reshape(u, ())  # need scalar for grad
Example #11
def logistic_preprocess(nn_out):
    *batch, h, w, _ = nn_out.shape
    assert nn_out.shape[-1] % 10 == 0
    k = nn_out.shape[-1] // 10
    logit_weights, nn_out = jnp.split(nn_out, [k], -1)
    m, s, t = jnp.moveaxis(jnp.reshape(nn_out,
                                       tuple(batch) + (h, w, 3, k, 3)),
                           (-2, -1), (-4, 0))
    assert m.shape == tuple(batch) + (k, h, w, 3)
    inv_scales = jnp.maximum(nn.softplus(s), 1e-7)
    return m, jnp.tanh(t), inv_scales, jnp.moveaxis(logit_weights, -1, -3)
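A shape walk-through with hypothetical sizes, calling the logistic_preprocess above (its own imports for jnp and nn are assumed to be in scope): the last axis must carry 10 * k channels for k mixture components, k mixture logits plus 3 * k * 3 values that become the means, the inverse scales, and the tanh-bounded coefficients.

import jax.numpy as jnp
from jax import random

nn_out = random.normal(random.PRNGKey(0), (2, 4, 4, 50))  # batch 2, 4x4 pixels, k = 5
m, t, inv_scales, logit_weights = logistic_preprocess(nn_out)
print(m.shape)              # (2, 5, 4, 4, 3)
print(t.shape)              # (2, 5, 4, 4, 3)
print(inv_scales.shape)     # (2, 5, 4, 4, 3)
print(logit_weights.shape)  # (2, 5, 4, 4)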
Example #12
    def apply_fun_scan(params, hidden_cell, inp):
        """ Perform single timestep update of the network. """
        _, (forget_W, forget_U, forget_b), (in_W, in_U, in_b), (
            out_W, out_U, out_b), (change_W, change_U, change_b) = params

        hidden, cell = hidden_cell
        input_gate = sigmoid(np.dot(inp, in_W) + np.dot(hidden, in_U) + in_b)
        change_gate = np.tanh(np.dot(inp, change_W) + np.dot(hidden, change_U)
                              + change_b)
        forget_gate = sigmoid(np.dot(inp, forget_W) + np.dot(hidden, forget_U)
                              + forget_b)

        cell = np.multiply(change_gate, input_gate) + np.multiply(cell,
                                                                  forget_gate)

        output_gate = sigmoid(np.dot(inp, out_W)
                              + np.dot(hidden, out_U) + out_b)
        output = np.multiply(output_gate, np.tanh(cell))
        hidden_cell = (hidden, cell)
        return hidden_cell, hidden_cell
Example #13
def decode_one_step(params, hps, keep_rate, use_mean, state, inputs):
    """Run the LFADS network from latent variables to log rates one time step.

  Args:
    params: a dictionary of LFADS parameters
    hps: a dictionary of LFADS hyperparameters
    keep_rate: dropout keep rate
    use_mean: Use the mean of the posterior dist, not a sample.
    state: dict of state variables for the decoder RNN
      (controller, inferred input sample, generator, factors)
    inputs: dict of inputs to decoder RNN (keys, inferred bias and data
      encoding).

  Returns:
    A dict of decode values at time t,
      (controller hidden state, inferred input (ii) sample,
       ii mean, ii log var, generator hidden state, factors, log rates)
  """
    key = inputs['keys_t']
    ccb = inputs['ccb_t']
    ib = inputs['ib_t']
    xenc = inputs['xenc_t']
    c = state['c']
    f = state['f']
    g = state['g']
    # A bit weird but the 'inferred input' is actually state because
    # samples are generated during the decoding pass.
    ii = state['ii']

    keys = random.split(key, 2)
    cin = np.concatenate([xenc, f, ii], axis=0)
    c = gru(params['con'], c, cin)
    cout = affine(params['con_out'], c)
    ii_mean, ii_logvar = np.split(cout, 2, axis=0)  # inferred input params
    ii_sample = dists.diag_gaussian_sample(keys[0], ii_mean, ii_logvar,
                                           hps['var_min'])
    ii = np.where(use_mean, ii_mean, ii_sample)
    ii = np.where(hps['do_tanh_latents'], np.tanh(ii), ii)
    g = gru(params['gen'], g, np.concatenate([ii, ib, ccb], axis=0))
    g = dropout(g, keys[1], keep_rate)
    f = normed_linear(params['factors'], g)
    lograte = affine(params['logrates'], f)
    return {
        'c': c,
        'ccb': ccb,
        'g': g,
        'f': f,
        'ib': ib,
        'ii': ii,
        'ii_mean': ii_mean,
        'ii_logvar': ii_logvar,
        'lograte': lograte
    }
Example #14
    def __call__(self, x, **kwargs):
        init = hk.initializers.RandomNormal()
        weight_logits = hk.get_parameter("weight_logits",
                                         (self.n_components, ),
                                         init=jnp.zeros)
        means = hk.get_parameter("means", (self.n_components, ), init=init)
        log_scales = hk.get_parameter("log_scales", (self.n_components, ),
                                      init=jnp.zeros)

        log_scales = 1.5 * jnp.tanh(log_scales)
        z = logistic_cdf_mixture_logit(weight_logits, means, log_scales, x)
        return z
Example #15
def predict(t, params_a, params_b, inputs):
    params = [((1 - t) * wa + t * wb, (1 - t) * ba + t * bb)
              for (wa, ba), (wb, bb) in zip(params_a, params_b)]

    activations = inputs
    for w, b in params[:-1]:
        outputs = jnp.dot(activations, w) + b
        activations = jnp.tanh(outputs)

    final_w, final_b = params[-1]
    logits = jnp.dot(activations, final_w) + final_b
    return logits - logsumexp(logits, axis=1, keepdims=True)
Example #16
    def _full_forward(self, theta):

        x = jnp.dot(theta[0], self.Phi_trans)
        for i in range(self.depth):
            if self.biases:
                x = x + theta[i + self.depth + 1]
            if self.activation == 'ReLU':
                x = jnp.maximum(x, 0.0)
            elif self.activation == 'tanh':
                x = jnp.tanh(x)
            x = jnp.dot(theta[i+1],x)
        return x[0,:]
Example #17
def OVM(veh, lead, p, leadlen, relax, *args):

    # regime denotes which regime the model is in.

    # regime = 0 is reserved for something that has no dependence on parameters:
    # this could be the shifted end, or it could be a model regime that has no
    # dependence on parameters (e.g. acceleration is bounded)

    regime = 1

    out = jnp.zeros((1, 2))

    # find and replace all tanh, then brackets to parentheses, then rename all the variables

    # jnp arrays are immutable, so use .at[].set() rather than in-place item assignment.
    out = out.at[0, 0].set(veh[1])

    out = out.at[0, 1].set(p[3] * (p[0] *
                           (jnp.tanh(p[1] * (lead[0] - leadlen - veh[0] + relax) -
                                     p[2] - p[4]) - jnp.tanh(-p[2])) - veh[1]))

    # could be a good idea to make another helper function which adds this to the
    # current value so we can constrain velocity?

    return out, regime
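An illustrative call (shapes and numbers are made up): veh and lead are (position, speed) pairs and p holds the five OVM parameters used above.

import jax.numpy as jnp

veh = jnp.array([0.0, 20.0])              # position, speed of the following vehicle
lead = jnp.array([50.0, 22.0])            # position, speed of the lead vehicle
p = jnp.array([1.0, 0.1, 2.0, 0.5, 0.2])  # arbitrary parameter values

out, regime = OVM(veh, lead, p, leadlen=5.0, relax=0.0)
print(out)  # 1x2 array: [current speed, OVM acceleration]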
Example #18
def embed_oar(features: Array, action: Array, reward: Array,
              num_actions: int) -> Array:
  """Embed each of the (observation, action, reward) inputs & concatenate."""
  chex.assert_rank([features, action, reward], [2, 1, 1])
  action = jax.nn.one_hot(action, num_classes=num_actions)  # [B, A]

  reward = jnp.tanh(reward)
  while reward.ndim < action.ndim:
    reward = jnp.expand_dims(reward, axis=-1)

  embedding = jnp.concatenate([features, action, reward], axis=-1)  # [B, D+A+1]
  return embedding
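An illustrative call (hypothetical sizes), using the embed_oar above: a batch of 4 observations with 6 features each, 3 discrete actions, and scalar rewards.

import jax.numpy as jnp

features = jnp.ones((4, 6))               # [B, D]
action = jnp.array([0, 2, 1, 2])          # [B] integer actions
reward = jnp.array([1.0, -0.5, 0.0, 3.0])

out = embed_oar(features, action, reward, num_actions=3)
print(out.shape)  # (4, 6 + 3 + 1) == (4, 10)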
Example #19
    def __call__(self, x):
        # transform to (-1, 1) interval
        t = jnp.tanh(x)

        # apply stick-breaking transform
        remainder = jnp.cumprod(1 - jnp.abs(t[..., :-1]), axis=-1)
        pad_width = [(0, 0)] * (t.ndim - 1) + [(1, 0)]
        remainder = jnp.pad(remainder,
                            pad_width,
                            mode="constant",
                            constant_values=1.0)
        return t * remainder
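For reference, the same transform written as a standalone function (a restatement of the method above, not additional original code), with a quick check that outputs land strictly inside the unit L1 ball:

import jax.numpy as jnp

def stick_breaking(x):
    # Identical computation to __call__ above, as a plain function.
    t = jnp.tanh(x)
    remainder = jnp.cumprod(1 - jnp.abs(t[..., :-1]), axis=-1)
    pad_width = [(0, 0)] * (t.ndim - 1) + [(1, 0)]
    remainder = jnp.pad(remainder, pad_width, mode="constant", constant_values=1.0)
    return t * remainder

y = stick_breaking(jnp.array([1.7, -0.3, 2.4, 0.1]))
print(jnp.abs(y).sum())  # < 1.0: each abs value takes a fraction of the remaining stick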
Example #20
  def __call__(self, inputs, state):
    prev_h, prev_c = state

    gates = conv.ConvND(
        num_spatial_dims=self._num_spatial_dims,
        output_channels=4*self.output_channels,
        kernel_shape=self.kernel_shape,
        name="input_to_hidden")(
            inputs)
    gates += conv.ConvND(
        num_spatial_dims=self._num_spatial_dims,
        output_channels=4*self.output_channels,
        kernel_shape=self.kernel_shape,
        name="hidden_to_hidden")(
            prev_h)
    i, g, f, o = jnp.split(gates, indices_or_sections=4, axis=-1)

    f = jax.nn.sigmoid(f + 1)
    c = f * prev_c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h = jax.nn.sigmoid(o) * jnp.tanh(c)
    return h, (h, c)
Example #21
def decode(params, hps, key, keep_rate, encodes, class_id=-1, use_mean=False):
    """Run the LFADS network from latent variables to log rates.

    Since the factors (and inferred input) feed back to the controller,
      factors_{t-1} -> controller_t -> ii_t -> generator_t -> factors_t
      is really one big loop and therefore one RNN.

  Args:
    params: a dictionary of LFADS parameters
    hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
    keep_rate: dropout keep rate
    encodes: dictionary of variables from lfads encoding, including
      (ib_mean, ib_logvar, ic_mean, ic_logvar, xenc_t)
    class_id: int, indicating one-hot encoding for class conditional bias
    use_mean: Use the mean of the posterior dist, not a sample.

  Returns:
    dictionary of lfads decoding variables, including:
      (controller state, generator state, factors, inferred input,
       inferred bias, inferred input mean, inferred input log var,
       log rates)
  """
    keys = random.split(key, 2)

    ii0 = params['ii0']
    ii0 = np.where(hps['do_tanh_latents'], np.tanh(ii0), ii0)
    c0 = params['con']['h0']

    # All the randomness for all T steps at once for efficiency.
    xenc_t = encodes['xenc_t']
    T = xenc_t.shape[0]
    keys_t = random.split(keys[0], T)
    ib_t = np.tile(encodes['ib'],
                   (T, 1))  # as time-dependent input in decoding.

    class_one_hot = one_hot(hps['nclasses'], class_id)
    class_one_hot_txc = np.tile(class_one_hot, (hps['ntimesteps'], 1))

    inputs = {
        'ccb_t': class_one_hot_txc,
        'ib_t': ib_t,
        'keys_t': keys_t,
        'xenc_t': xenc_t
    }
    # A bit weird but the 'inferred input' is actually state because
    # samples are generated during the decoding pass.
    state0 = {'c': c0, 'f': params['f0'], 'g': encodes['g0'], 'ii': ii0}
    decoder = functools.partial(decode_one_step_scan,
                                *(params, hps, keep_rate, use_mean))
    _, decodes = lax.scan(decoder, state0, inputs)
    return decodes
Example #22
    def forward(self, H, sample):
        num_layers = len(self.layers)
        for l in range(0, num_layers - 2):
            W = sample['w%d' % (l + 1)]
            b = sample['b%d' % (l + 1)]
            H = np.tanh(np.add(np.matmul(H, W), b))
        W = sample['w%d_mu' % (num_layers - 1)]
        b = sample['b%d_mu' % (num_layers - 1)]
        mu = np.add(np.matmul(H, W), b)
        W = sample['w%d_std' % (num_layers - 1)]
        b = sample['b%d_std' % (num_layers - 1)]
        sigma = np.exp(np.add(np.matmul(H, W), b))
        return mu, sigma
Example #23
    def __action_and_log_Pi(params, state, rng, clip):
        # action
        nn_out = SharedNetwork.apply_Pi(params, state)
        batch_size, out_num = nn_out.shape
        assert (out_num // 2 == EnAction.num)
        means = nn_out[:, :EnAction.num]
        lsigs = nn_out[:, EnAction.num:]
        lsigs = jnp.log(jax.nn.sigmoid(lsigs))
        assert (means.shape == (batch_size, EnAction.num))
        assert (lsigs.shape == (batch_size, EnAction.num))

        epss = jrandom.normal(rng, shape=(batch_size, EnAction.num))
        epss = jax.lax.cond(clip, SharedNetwork.__clip_eps, lambda x: x, epss)
        sigs = jnp.exp(lsigs)
        action = means + sigs * epss
        exploit_action = means
        assert (action.shape == (batch_size, EnAction.num))
        assert (exploit_action.shape == (batch_size, EnAction.num))

        log_pi = -(((action - means)**2) /
                   (2 *
                    (sigs**2))).sum(axis=-1) - EnAction.num * 0.5 * jnp.log(
                        2 * jnp.pi) - lsigs.sum(axis=-1)
        log_pi = log_pi.reshape((batch_size, 1))
        action = jnp.tanh(action)
        log_pi = log_pi - jnp.log((1.0 - action * action) + 1E-5).sum(
            axis=-1, keepdims=True)

        exploit_log_pi = -(
            ((exploit_action - means)**2) /
            (2 * (sigs**2))).sum(axis=-1) - EnAction.num * 0.5 * jnp.log(
                2 * jnp.pi) - lsigs.sum(axis=-1)
        exploit_log_pi = exploit_log_pi.reshape((batch_size, 1))
        exploit_action = jnp.tanh(exploit_action)
        exploit_log_pi = exploit_log_pi - jnp.log(
            (1.0 - exploit_action * exploit_action) + 1E-5).sum(axis=-1,
                                                                keepdims=True)

        return action, log_pi, exploit_action, exploit_log_pi, means, sigs
Example #24
    def __call__(
        self,
        inputs,
        state: LSTMState,
    ) -> Tuple[jnp.ndarray, LSTMState]:
        input_to_hidden = hk.ConvND(num_spatial_dims=self.num_spatial_dims,
                                    output_channels=4 * self.output_channels,
                                    kernel_shape=self.kernel_shape,
                                    name="input_to_hidden")

        hidden_to_hidden = hk.ConvND(num_spatial_dims=self.num_spatial_dims,
                                     output_channels=4 * self.output_channels,
                                     kernel_shape=self.kernel_shape,
                                     name="hidden_to_hidden")

        gates = input_to_hidden(inputs) + hidden_to_hidden(state.hidden)
        i, g, f, o = jnp.split(gates, indices_or_sections=4, axis=-1)

        f = jax.nn.sigmoid(f + 1)
        c = f * state.cell + jax.nn.sigmoid(i) * jnp.tanh(g)
        h = jax.nn.sigmoid(o) * jnp.tanh(c)
        return h, LSTMState(h, c)
Example #25
def main(_):
    # Define the total number of training steps
    training_iters = 200

    rng = random.PRNGKey(0)

    rng, key = random.split(rng)

    init_random_params, model_apply = stax.serial(stax.Dense(256), stax.Relu,
                                                  stax.Dense(256), stax.Relu,
                                                  stax.Dense(2))

    # init the model
    _, params = init_random_params(rng, (-1, 2))

    # Create the optimizer corresponding to the 0th hyperparameter configuration
    # with the specified amount of training steps.
    # opt = optix.adam(1e-4)
    opt = jax_optix_opt_list.optimizer_for_idx(0, training_iters)

    opt_state = opt.init(params)

    @jax.jit
    def loss_fn(params, batch):
        x, y = batch
        y_hat = model_apply(params, x)
        return jnp.mean(jnp.square(y_hat - y))

    @jax.jit
    def train_step(params, opt_state, batch):
        """Train for a single step."""
        value_and_grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = value_and_grad_fn(params, batch)

        # Note this is not the usual optix api as we additionally need parameter
        # values.
        # updates, opt_state = opt.update(grad, opt_state)
        updates, opt_state = opt.update_with_params(grad, params, opt_state)

        new_params = optix.apply_updates(params, updates)
        return new_params, opt_state, loss

    for _ in range(training_iters):
        # make a random batch of fake data
        rng, key = random.split(rng)
        inp = random.normal(key, [512, 2]) / 4.
        target = jnp.tanh(1 / (1e-6 + inp))

        # train the model a step
        params, opt_state, loss = train_step(params, opt_state, (inp, target))
        print(loss)
Example #26
    def tensor_constrain(self, u, min_bounds, max_bounds):
        """Constrains a control vector tensor variable between given bounds through
        a squashing function.
        This is implemented with Theano, so as to be auto-differentiable.
        Args:
            u: Control vector tensor variable [action_size].
            min_bounds: Minimum control bounds [action_size].
            max_bounds: Maximum control bounds [action_size].
        Returns:
            Constrained control vector tensor variable [action_size].
        """
        diff = (max_bounds - min_bounds) / 2.0
        mean = (max_bounds + min_bounds) / 2.0
        return diff * np.tanh(u) + mean
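An illustrative call (hypothetical instance and bounds): tanh squashes each component into (-1, 1), which diff and mean then map onto (min_bounds, max_bounds).

import numpy as np

u = np.array([-10.0, 0.0, 10.0])          # unbounded controls
min_bounds = np.array([-1.0, -1.0, -1.0])
max_bounds = np.array([2.0, 2.0, 2.0])

# `controller` is a hypothetical instance of the class that defines tensor_constrain.
constrained = controller.tensor_constrain(u, min_bounds, max_bounds)
print(constrained)  # approximately [-1.0, 0.5, 2.0]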
Example #27
def model_loss(params, inputs, l2_reg):
  """Evaluate the standard autoencoder."""
  h = inputs.reshape([inputs.shape[0], -1])
  for i, layer_params in enumerate(params):
    h = fully_connected_layer(layer_params, h)
    # Last layer does not have a nonlinearity
    if i % 4 != 3:
      h = jnp.tanh(h)
  l2_value = 0.5 * sum(jnp.square(p).sum() for p in jax.tree_leaves(params))
  error = jax.nn.sigmoid(h) - inputs.reshape([inputs.shape[0], -1])
  mean_squared_error = jnp.mean(jnp.sum(error * error, axis=1), axis=0)
  regularized_loss = mean_squared_error + l2_reg * l2_value

  return regularized_loss, dict(mean_squared_error=mean_squared_error)
Example #28
    def predict(params, inputs, with_classifier=True):
        x = inputs.reshape(
            (inputs.shape[0],
             np.prod(inputs.shape[1:])))  # flatten to f32[B, 784]
        for w, b in params[:-1]:
            x = jnp.dot(x, w) + b
            x = jnp.tanh(x)

        if not with_classifier:
            return x
        final_w, final_b = params[-1]
        logits = jnp.dot(x, final_w) + final_b
        return logits - jax.scipy.special.logsumexp(
            logits, axis=1, keepdims=True)
Example #29
        def update(state, x):
            nhid = self.nhid

            h, cell = state[0], state[1]
            h = h.squeeze()
            cell = cell.squeeze()

            x_components = self.i2h(x)
            h_components = self.h2h(h)

            preactivations = x_components + h_components

            gates_together = jax.nn.sigmoid(preactivations[:, 0:3 * nhid])
            forget_gate = gates_together[:, 0:nhid]
            input_gate = gates_together[:, nhid:2 * nhid]
            output_gate = gates_together[:, 2 * nhid:3 * nhid]
            new_cell = jnp.tanh(preactivations[:, 3 * nhid:4 * nhid])

            cell = forget_gate * cell + input_gate * new_cell
            h = output_gate * jnp.tanh(cell)

            new_state = jnp.stack([h, cell])
            return new_state, h
Example #30
  def __call__(self, inputs, rng, **kwargs):


    if self.triangular_jacobian:
      dx = jnp.ones_like(inputs)

    # Main autoregressive transform
    x = inputs
    layer_sizes = [self.dim] + self.hidden_layer_sizes + [self.dim]
    self.gen_masks(self.input_sel, layer_sizes[1:], rng)

    prev_sel = self.sels[0]
    for i, (mask, sel, input_size, output_size) in enumerate(zip(self.masks, \
                                                                 self.sels[1:], \
                                                                 layer_sizes[:-1], \
                                                                 layer_sizes[1:])):
      w, b = self.get_params(i, x, output_size)

      w_masked = w*mask
      x = jnp.dot(x, w_masked) + b

      if self.triangular_jacobian:
        # Build an elementwise derivative of the nonlinearity, vmapped over every
        # batch dimension. Use a throwaway loop variable so the enclosing layer
        # index i is not shadowed (it is reused below).
        nonlinearity_grad = jax.grad(self.nonlinearity)
        for _ in range(x.ndim):
          nonlinearity_grad = vmap(nonlinearity_grad)

        diag_mask = prev_sel[:,None] == sel
        dx = jnp.dot(dx, (w*diag_mask))

        if i < len(self.masks) - 1:
          dx *= nonlinearity_grad(x)

      if i < len(self.masks) - 1:
        x = self.nonlinearity(x)

      prev_sel = sel

    w_mu = hk.get_parameter("w_mu", [self.dim, self.dim], x.dtype, init=w_init)
    mu = jnp.dot(x, w_mu*self.out_mask)

    if self.triangular_jacobian:
      diag_mask = prev_sel[:,None] == self.input_sel
      dmu = jnp.dot(dx, (w_mu*diag_mask))
      return mu, dmu

    w_alpha = hk.get_parameter("w_alpha", [self.dim, self.dim], x.dtype, init=w_init)
    alpha = jnp.dot(x, w_alpha*self.out_mask)
    alpha_bounded = jnp.tanh(alpha)

    return mu, alpha_bounded
Example #31
def predict(params, inputs):
  for W, b in params:
    outputs = np.dot(inputs, W) + b
    inputs = np.tanh(outputs)
  return outputs
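A minimal usage sketch (hypothetical sizes) for the predict above. Note that this variant applies tanh between every pair of layers but returns the final layer's pre-activation outputs.

import numpy as np

rng = np.random.RandomState(0)
layer_sizes = [3, 16, 1]
params = [(0.1 * rng.randn(m, n), np.zeros(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

x = rng.randn(7, 3)
y = predict(params, x)
print(y.shape)  # (7, 1)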