Example #1
  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    if not inputs.shape:
      raise ValueError("Input must not be scalar.")

    input_size = self.input_size = inputs.shape[-1]
    output_size = self.output_size
    dtype = inputs.dtype

    w_mu_init = self.w_mu_init
    
    if w_mu_init is None:
      stddev = 1. / np.sqrt(self.input_size)
      w_mu_init = hk.initializers.TruncatedNormal(stddev=stddev)
    w_mu = hk.get_parameter("w_mu", [input_size, output_size], dtype, init=w_mu_init)
    
    w_sigma_init = self.w_sigma_init
    if w_sigma_init is None:
      stddev = 1. / np.sqrt(self.input_size)
      w_sigma_init = hk.initializers.TruncatedNormal(stddev=stddev)
    w_sigma = hk.get_parameter("w_sigma", [input_size, output_size], dtype, init=w_sigma_init)

    if self.factorized:
      e_noise_input = jax.random.normal(next(self.rng), (w_sigma.shape[0], 1))
      e_noise_output = jax.random.normal(next(self.rng), (1, w_sigma.shape[1]))
      e_noise_input = jnp.multiply(jnp.sign(e_noise_input), jnp.sqrt(jnp.abs(e_noise_input)))
      e_noise_output = jnp.multiply(jnp.sign(e_noise_output), jnp.sqrt(jnp.abs(e_noise_output)))
      w_noise = jnp.matmul(e_noise_input, e_noise_output)
    else:
      w_noise = jax.random.normal(next(self.rng), w_sigma.shape)
    
    out_noisy = jnp.dot(inputs, jnp.add(w_mu, jnp.multiply(w_sigma, w_noise)))

    if self.with_bias:
      b_mu = hk.get_parameter("b_mu", [self.output_size], dtype, init=self.b_mu_init)
      b_sigma = hk.get_parameter("b_sigma", [self.output_size], dtype, init=self.b_sigma_init)
      b_mu = jnp.broadcast_to(b_mu, out_noisy.shape)
      b_sigma = jnp.broadcast_to(b_sigma, out_noisy.shape)
      b_noise = e_noise_output if self.factorized else jax.random.normal(next(self.rng), b_sigma.shape)
      out_noisy = out_noisy + jnp.add(b_mu, jnp.multiply(b_sigma, b_noise))
      
    return out_noisy
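A minimal standalone sketch of the factorized-noise branch above: two Gaussian vectors of length p and q are transformed with f(e) = sign(e) * sqrt(|e|) and combined by an outer product, so only p + q noise samples are drawn instead of p * q. The shapes and key below are illustrative.

import jax
import jax.numpy as jnp

key_in, key_out = jax.random.split(jax.random.PRNGKey(0))
p, q = 4, 3
f = lambda e: jnp.sign(e) * jnp.sqrt(jnp.abs(e))
e_in = f(jax.random.normal(key_in, (p, 1)))
e_out = f(jax.random.normal(key_out, (1, q)))
w_noise = jnp.matmul(e_in, e_out)  # (p, q) weight-noise matrix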
Example #2
def step_directed_jit(kappa, pre_filtered, post_filtered, weights):
    error = pre_filtered - post_filtered
    w = jnp.minimum(weights, 1.0)
    nweights = weights/jnp.max(weights)
    r = jnp.where(error<0.,
                  jnp.where(w>0.0, error/(0.25*nweights), 0.0),
                  jnp.where(w>0.0, error*(1.01 - nweights), 0.0))
    h = jnp.where(w>0.0, (1.1 - nweights)*post_filtered, 0.0)
    #n = jnp.where(w>0.0, weights*pre_filtered, 0.0)
    d = kappa * (r - h)
    delta_sum = jnp.add(d, weights)
    return jnp.where(delta_sum < 0, 0., d)
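A minimal usage sketch of the rule above; the array values are illustrative and the function is simply wrapped in jax.jit here.

import jax
import jax.numpy as jnp

pre = jnp.array([0.2, 0.8, 0.5])
post = jnp.array([0.1, 0.9, 0.5])
weights = jnp.array([0.5, 1.5, 0.2])
delta = jax.jit(step_directed_jit)(0.01, pre, post, weights)  # per-weight update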
Example #3
def func(arg):
    divider = 0  # denominator
    numerator = 0
    for i in range(NUM_ARG):
        temp = np.dot(arg[i], var)
        temp1 = np.sin(temp)
        temp2 = np.cos(temp)

        divid = np.add(temp1, temp2)
        divid = np.power(divid, 2)
        divid = np.sum(divid)

        numer = np.add(temp1, temp2)
        numer = np.sum(numer)
        numer = np.power(numer, 2)
        numerator = np.add(numer, numerator)

        divider = np.add(divider, divid)
    divider = np.power(divider, 1 / 2)

    return np.log(np.divide(numerator, divider))
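A minimal usage sketch, assuming NUM_ARG and var are module-level globals (their real definitions are not shown above); the values here are purely illustrative.

import numpy as np

NUM_ARG = 3
var = np.array([0.5, -0.2, 1.0])           # hypothetical parameter vector
arg = np.random.rand(NUM_ARG, len(var))    # arg[i] is dotted with var inside func

print(func(arg))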
Example #4
def conv2d(inputs, weights, strides, padding, dimension_numbers, bias=None):
    output = lax.conv_general_dilated(
        lhs=inputs,
        rhs=weights,
        window_strides=strides,
        padding=padding,
        dimension_numbers=dimension_numbers,
    )

    # `bias` may be an array, so test for None explicitly rather than relying on truthiness.
    if bias is not None:
        output = jnp.add(output, bias)
    return output
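A minimal usage sketch of conv2d above with NHWC inputs and HWIO weights; the shapes and dimension numbers are illustrative.

import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 8, 8, 3))    # batch, height, width, channels
w = jnp.ones((3, 3, 3, 16))   # filter height, filter width, in channels, out channels
dn = lax.conv_dimension_numbers(x.shape, w.shape, ('NHWC', 'HWIO', 'NHWC'))
y = conv2d(x, w, strides=(1, 1), padding='SAME', dimension_numbers=dn,
           bias=jnp.zeros((16,)))  # bias broadcasts over batch and spatial dims
print(y.shape)                     # (1, 8, 8, 16)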
Example #5
 def scale_tril(self):
     # The following identity is used to improve the numerical stability of the
     # Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
     #     W @ W.T + D = D^{1/2} @ (I + D^{-1/2} @ W @ W.T @ D^{-1/2}) @ D^{1/2}
     # The matrix "I + D^{-1/2} @ W @ W.T @ D^{-1/2}" has eigenvalues bounded from below by 1,
     # so it is well-conditioned and safe to Cholesky-decompose.
     cov_diag_sqrt_unsqueeze = np.expand_dims(np.sqrt(self.cov_diag), axis=-1)
     Dinvsqrt_W = self.cov_factor / cov_diag_sqrt_unsqueeze
     K = np.matmul(Dinvsqrt_W, np.swapaxes(Dinvsqrt_W, -1, -2))
     K = np.add(K, np.identity(K.shape[-1]))
     scale_tril = cov_diag_sqrt_unsqueeze * np.linalg.cholesky(K)
     return scale_tril
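A standalone numpy sketch verifying the identity used above: for a low-rank-plus-diagonal covariance W @ W.T + diag(d), the Cholesky factor equals diag(sqrt(d)) @ cholesky(I + D^{-1/2} @ W @ W.T @ D^{-1/2}). All names and sizes are illustrative.

import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(5, 2))         # low-rank factor
d = rng.uniform(1.0, 2.0, size=5)   # positive diagonal

d_sqrt = np.sqrt(d)[:, None]
K = (W / d_sqrt) @ (W / d_sqrt).T + np.eye(5)  # well-conditioned inner matrix
L = d_sqrt * np.linalg.cholesky(K)

assert np.allclose(L @ L.T, W @ W.T + np.diag(d))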
Example #6
def step_undirected_jit(kappa, pre_filtered, post_filtered, weights):
    error = pre_filtered - post_filtered
    abs_error = jnp.abs(error)
    max_abs_error = jnp.max(abs_error)
    w = jnp.minimum(weights, 1.0)
    nweights = weights/jnp.max(weights)
    r = (max_abs_error - abs_error)*(1.0 - nweights)
    h = jnp.where(w>0.0, weights*post_filtered, 0.0)
    #n = jnp.where(w>0.0, weights*pre_filtered, 0.0)
    d = kappa * (r - h)
    delta_sum = jnp.add(d, weights)
    return jnp.where(delta_sum < 0, 0., d)
Example #7
def step_heun(state, t, params, diffusivity, stimuli, dt, dx):
    def euler(y, dy, h):
        return jax.tree_multimap(lambda v, dv: jnp.add(v, dv * h), y, dy)

    d_state = step(state, t, params, diffusivity, stimuli, dx)
    new_state = euler(state, d_state, dt)
    d_new_state = step(new_state, t, params, diffusivity, stimuli, dx)
    new_state = euler(
        state,
        jax.tree_multimap(lambda x, y: jnp.add(x, y), d_state, d_new_state),
        dt * 0.5,
    )
    return new_state
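The function above is Heun's predictor-corrector step applied to a pytree state; a textbook scalar sketch of the same structure looks like the following (note that the textbook form evaluates the second slope at t + dt, while the snippet above reuses t).

def heun_step(f, y, t, dt):
    d1 = f(y, t)                      # slope at the current state
    y_pred = y + dt * d1              # Euler predictor
    d2 = f(y_pred, t + dt)            # slope at the predicted state
    return y + 0.5 * dt * (d1 + d2)   # trapezoidal corrector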
Example #8
File: loops_test.py Project: xiaoral2/jax
    def test_add_vectors(self):
        def add_vec(x, y):
            with loops.Scope() as s:
                n = x.shape[0]
                assert n == y.shape[0]
                s.out = jnp.zeros(shape=[n], dtype=jnp.float32)
                for i in s.range(n):
                    s.out = ops.index_add(s.out, i, x[i] + y[i])
                return s.out

        x = jnp.array([1., 2., 3.], dtype=jnp.float32)
        y = jnp.array([4., 5., 6.], dtype=jnp.float32)
        self.assertAllClose(jnp.add(x, y), add_vec(x, y))
Example #9
def onnx_add(a, b, axis=None, broadcast=False):
    if broadcast:
        b_shape = []
        b_shape.extend(a.shape[:axis])
        b_shape.append(a.shape[axis])
        b_shape.extend([1] * len(a.shape[axis + 1:]))
        b = jnp.reshape(b, b_shape)
    elif len(a.shape) != len(b.shape):
        b_shape = [1] * len(a.shape)
        b_shape[1] = -1
        b = jnp.reshape(b, b_shape)

    return jnp.add(a, b)
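A minimal usage sketch of the rank-mismatch branch above (broadcast left at False, ranks differ): a 1-D b is placed on axis 1, matching the NCHW channel-bias convention. The shapes are illustrative.

import jax.numpy as jnp

a = jnp.ones((2, 3, 4, 4))       # e.g. an NCHW feature map
b = jnp.array([0.1, 0.2, 0.3])   # one value per channel
out = onnx_add(a, b)             # b is reshaped to (1, 3, 1, 1) and broadcast
print(out.shape)                 # (2, 3, 4, 4)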
Example #10
    def convolution_op(self, params: Tuple, inputs: DeviceArray):
        output = lax.conv_general_dilated(
            lhs=inputs,
            rhs=params[0],
            window_strides=self.strides,
            padding=self.padding,
            dimension_numbers=self.dn,
        )
        if self.use_bias:
            output = jnp.add(output, params[1])

        if self.activation:
            output = self.activation(output)
        return output
Example #11
 def model(self, batch):
     X = batch['X']
     y = batch['y']
     N, D = X.shape
     H = X
     # Forward pass
     num_layers = len(self.layers)
     for l in range(0,num_layers-2):
         D_X, D_H = self.layers[l], self.layers[l+1]
         W = sample('w%d' % (l+1), dist.Normal(np.zeros((D_X, D_H)), np.ones((D_X, D_H))))
         b = sample('b%d' % (l+1), dist.Normal(np.zeros(D_H), np.ones(D_H)))
         H = np.tanh(np.add(np.matmul(H, W), b))
     D_X, D_H = self.layers[-2], self.layers[-1]
     # Output mean
     W = sample('w%d_mu' % (num_layers-1), dist.Normal(np.zeros((D_X, D_H)), np.ones((D_X, D_H))))
     b = sample('b%d_mu' % (num_layers-1), dist.Normal(np.zeros(D_H), np.ones(D_H)))
     mu = np.add(np.matmul(H, W), b)
     # Output std
     W = sample('w%d_std' % (num_layers-1), dist.Normal(np.zeros((D_X, D_H)), np.ones((D_X, D_H))))
     b = sample('b%d_std' % (num_layers-1), dist.Normal(np.zeros(D_H), np.ones(D_H)))
     sigma = np.exp(np.add(np.matmul(H, W), b))
     mu, sigma = mu.flatten(), sigma.flatten()
     # Likelihood
     sample("y", dist.Normal(mu, sigma), obs=y)
Example #12
    def first_step(self, rho: float, grads: List[JaxArray]):
        assert len(grads) == len(
            self.train_vars
        ), 'Expecting as many gradients as trainable variables'
        # Create empty state dict
        self.state = defaultdict(dict)
        # norm grads
        grad_norm = self._grad_norm(grads)
        # create a scale factor
        scale = rho / (grad_norm + 1e-12)

        # loop through grads and params
        for g, p in zip(grads, self.train_vars):
            e_w = g * scale
            p.value = jn.add(p.value, e_w)
            self.state[str(p.ref)]["e_w"] = e_w
Example #13
  def test_dim_vars_symbolic_equal(self):
    a, b = shape_poly.parse_spec("a, b", (2, 3))
    self.assertTrue(core.symbolic_equal_dim(a, a))
    self.assertFalse(core.symbolic_equal_dim(a, 1))
    self.assertFalse(core.symbolic_equal_dim(a, b))

    self.assertTrue(core.symbolic_equal_one_of_dim(a, [2, a]))
    self.assertFalse(core.symbolic_equal_one_of_dim(a, [2, b]))
    self.assertFalse(core.symbolic_equal_one_of_dim(a, []))

    self.assertTrue(core.symbolic_equal_one_of_dim(2, [a, 3, 2]))
    self.assertFalse(core.symbolic_equal_one_of_dim(1, [2, b]))
    self.assertFalse(core.symbolic_equal_one_of_dim(3, []))

    self.assertTrue(core.symbolic_equal_dim(1, jnp.add(0, 1)))  # A DeviceArray
    with self.assertRaisesRegex(TypeError,
                                re.escape("Shapes must be 1D sequences of concrete values of integer type, got (1, 'a').")):
      self.assertTrue(core.symbolic_equal_dim(1, "a"))
Example #14
    def apply(self, x, noise_rng, n_actions):
        dense_layer_1 = flax.nn.Dense(x, 64)
        activation_layer_1 = flax.nn.relu(dense_layer_1)
        noisy_layer = NoisyDense(activation_layer_1, noise_rng, 64)
        activation_layer_2 = flax.nn.relu(noisy_layer)

        noisy_value = NoisyDense(activation_layer_2, noise_rng, 64)
        value = flax.nn.relu(noisy_value)
        value = NoisyDense(value, noise_rng, 1)

        noisy_advantage = NoisyDense(activation_layer_2, noise_rng, 64)
        advantage = flax.nn.relu(noisy_advantage)
        advantage = NoisyDense(advantage, noise_rng, n_actions)

        advantage_average = jnp.mean(advantage, keepdims=True)

        q_values_layer = jnp.subtract(jnp.add(advantage, value),
                                      advantage_average)
        return q_values_layer
Example #15
    def apply(self, x, n_actions):
        dense_layer_1 = flax.nn.Dense(x, 64)
        activation_layer_1 = flax.nn.relu(dense_layer_1)
        dense_layer_2 = flax.nn.Dense(activation_layer_1, 64)
        activation_layer_2 = flax.nn.relu(dense_layer_2)

        value_dense = flax.nn.Dense(activation_layer_2, 64)
        value = flax.nn.relu(value_dense)
        value = flax.nn.Dense(value, 1)

        advantage_dense = flax.nn.Dense(activation_layer_2, 64)
        advantage = flax.nn.relu(advantage_dense)
        advantage = flax.nn.Dense(advantage, n_actions)

        advantage_average = jnp.mean(advantage, keepdims=True)

        q_values_layer = jnp.subtract(jnp.add(advantage, value),
                                      advantage_average)
        return q_values_layer
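A small standalone sketch of the dueling aggregation Q = V + A - mean(A); here the mean is taken over the action axis, whereas the two snippets above average over all elements. Shapes and values are illustrative.

import jax.numpy as jnp

value = jnp.array([[2.0]])                  # (batch, 1)
advantage = jnp.array([[1.0, -1.0, 0.0]])   # (batch, n_actions)
q_values = value + advantage - jnp.mean(advantage, axis=-1, keepdims=True)
print(q_values)                             # [[3. 1. 2.]]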
Example #16
def stream_assemble(nu, beta_fric, n, dx):
    # Assemble the tridiagonal system matrix for the stream-velocity solve.
    nu_plus1 = jnp.roll(nu, -1)
    beta_fric_plus1 = jnp.roll(beta_fric, -1)
    # Main diagonal entries (combining neighbouring nu and beta_fric contributions).
    A1 = jnp.array(
        jnp.add(
            jnp.add(
                4 * nu[0:n - 1] / dx +
                dx / 3. * jnp.square(beta_fric[0:n - 1]),
                4 * nu_plus1[0:n - 1] / dx),
            dx / 3. * jnp.square(beta_fric_plus1[0:n - 1])))
    A1 = jnp.append(A1, 4 * nu[n - 1] / dx + dx / 3. * beta_fric[n - 1]**2)
    A = jnp.diag(A1)
    # Sub- and super-diagonals.
    AL = jnp.diag(
        jnp.add(-4 * nu[1:n] / dx, dx / 6. * jnp.square(beta_fric[1:n])), -1)
    AU = jnp.diag(
        jnp.add(-4 * nu_plus1[0:n - 1] / dx,
                dx / 6. * jnp.square(beta_fric_plus1[0:n - 1])), 1)
    A = jnp.add(jnp.add(A, AL), AU)
    return A
Example #17
File: sum.py Project: gglin001/onnx-jax
def onnx_sum_4(x0, x1, x2, x3):
    return jnp.add(onnx_sum_3(x0, x1, x2), x3)
Example #18
File: jax2tf_test.py Project: alonfnt/jax
 def f(xs, y):
     return [jnp.add(x, y) for x in xs]
Example #19
 def add(x):
   return (lambda x: jnp.add(x, x))(x)
Example #20
 def add(x, y):
   return jnp.add(x, y)
Example #21
 def add(x):
   return jnp.add(x, x)
Example #22
def stream_vel_visc(h, u, n, dx):
    # Effective viscosity from Glen's flow law; Aglen, nglen and ep_glen are
    # presumably module-level constants (rate factor, flow-law exponent and
    # strain-rate regularisation).
    ux = jnp.diff(u) / dx
    tmp = jnp.add(jnp.square(ux), jnp.square(ep_glen))
    nu = .5 * h * Aglen**(-1. / nglen) * tmp**((1 - nglen) / 2. / nglen)
    return nu
Example #23
 def f_jax(x, y):
     return jnp.add(x, y)
Example #24
 def add(x, y):
     return np.add(x, y)
Example #25
 def g(x, y):
   return lnp.add(x, y)  # `lnp` is an older alias for jax.numpy used in legacy JAX tests
Example #26
def step_jit(kappa, pre_filtered, post_filtered, weights):
    on = jnp.where(pre_filtered > 0, 0., 1.)
    d = kappa * post_filtered * on
    delta_sum = jnp.add(d, weights)
    return jnp.where(delta_sum <= 0, weights, d)
Example #27
 def f():
     return np.add(3., 4.)
Example #28
def _add_n(args):
    # Sum a sequence of arrays, accumulating with jnp.add.
    start = args[0]
    for arg in args[1:]:
        start = jnp.add(start, arg)
    return start
Example #29
 def np_fn(a, b):
     return np.add(a, b)
Example #30
File: sum.py Project: gglin001/onnx-jax
def onnx_sum_5(x0, x1, x2, x3, x4):
    return jnp.add(onnx_sum_4(x0, x1, x2, x3), x4)