Example No. 1
def matmul_maybe_select(A, B):
    """Perform Matrix multiplication C = A * B but A could be an integer id vector.

    If A is an integer vector, we treat it as multiplying a one-hot encoded tensor.
    In this case, the expensive dense matrix multiply can be replaced by a much
    cheaper index lookup.

    For example,
    ::

        A = [2, 0, 1],
        B = [[0.1, 0.2],
             [0.3, 0.4],
             [0.5, 0.6]]

    then matmul_maybe_select(A, B) is equivalent to
    ::

        [[0, 0, 1],     [[0.1, 0.2],
         [1, 0, 0],  *   [0.3, 0.4],
         [0, 1, 0]]      [0.5, 0.6]]

    In all other cases, perform a normal matmul.

    Parameters
    ----------
    A : jnp.ndarray
        lhs tensor
    B : jnp.ndarray
        rhs tensor

    Returns
    -------
    C : jnp.ndarray
        result tensor
    """
    if A.dtype == jnp.int64 and len(A.shape) == 1:
        return jnp.take(B, A, 0)
    else:
        return jnp.matmul(A, B)
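
A minimal sketch (not from the original source) checking the lookup/one-hot equivalence the docstring describes. Note that under JAX's default 32-bit mode an integer index array is int32, so the jnp.int64 check above only takes the cheap path when x64 is enabled.

import jax
import jax.numpy as jnp

A = jnp.array([2, 0, 1])                     # integer id vector
B = jnp.array([[0.1, 0.2],
               [0.3, 0.4],
               [0.5, 0.6]])

lookup = jnp.take(B, A, 0)                   # cheap index lookup
dense = jnp.matmul(jax.nn.one_hot(A, 3), B)  # equivalent one-hot matmul
assert jnp.allclose(lookup, dense)
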
Example No. 2
    def extend(prev_three_coords, point, multi_m):
        """
        Aligns an atom or an entire fragment, depending on the value of
        `multi_m`, with the preceding three atoms.
        :param prev_three_coords: Named tuple storing the last three atom
        coordinates ("a", "b", "c") where "c" is the current end of the
        structure (i.e. closest to the atom/fragment that will be added now).
        Shape NUM_DIHEDRALS x [NUM_FRAGS/0, BATCH_SIZE, NUM_DIMENSIONS].
        First rank depends on the value of `multi_m`.
        :param point: Point describing the atom that is added to the structure.
        Shape [NUM_FRAGS/FRAG_SIZE, BATCH_SIZE, NUM_DIMENSIONS].
        First rank depends on the value of `multi_m`.
        :param multi_m: If True, a single atom is added to the chain for
        multiple fragments in parallel. If False, a single fragment is added.
        Note the different parameter dimensions.
        :return: Coordinates of the atom/fragment.
        """
        # Normalize rows: https://necromuralist.github.io/neural_networks/posts/normalizing-with-numpy/
        Xbc = (prev_three_coords.c - prev_three_coords.b)
        bc = Xbc / onp.linalg.norm(Xbc, axis=-1, keepdims=True)

        Xn = onp.cross(prev_three_coords.b - prev_three_coords.a,
                       bc,
                       axisa=-1,
                       axisb=-1,
                       axisc=-1)
        n = Xn / onp.linalg.norm(Xn, axis=-1, keepdims=True)

        if multi_m:  # multiple fragments, one atom at a time
            m = onp.transpose(onp.stack([bc, onp.cross(n, bc), n]),
                              (1, 2, 3, 0))
        else:  # single fragment, reconstructed entirely at once.
            s = point.shape + (3, )  # [FRAG_SIZE, BATCH_SIZE, NUM_DIMENSIONS, 3]
            m = onp.transpose(onp.stack([bc, onp.cross(n, bc), n]), (1, 2, 0))
            m = onp.tile(m, (s[0], 1, 1)).reshape(s)

        coord = onp.squeeze(onp.matmul(m, onp.expand_dims(point, axis=3)),
                            axis=3) + prev_three_coords.c

        return coord
Example No. 3
def group_utility(particle_weights,
                  particles,
                  groups,
                  group_sensitivities,
                  group_specificities,
                  utility_fun):
  """Compute the utility of a set of groups.

  This function computes the utility of a set of groups, given a distribution
  over the population status encoded as a weighted sum of Dirac measures on
  particles, the specificities and sensitivities of tests, and a utility
  function.

  Args:
   particle_weights: weights of particles
   particles: particles summarizing belief about infection status
   groups: set of groups to be tested
   group_sensitivities: sensitivities of the test for each group
   group_specificities: specificities of the test for each group
   utility_fun: a utility function that takes as input (particle_weights,
      particles) and outputs the utility of the distribution

  Returns:
   The expected utility (over the test results) of the posterior
  """
  num_groups = groups.shape[0]
  proba_y_is_one_given_x = (np.matmul(particles, np.transpose(groups))
                            * (group_sensitivities + group_specificities - 1)
                            + 1.0 - group_specificities)
  proba_y_is_one_given_x = np.expand_dims(proba_y_is_one_given_x, axis=2)
  test_res = np.array(list(itertools.product([0, 1], repeat=num_groups)))
  test_res = np.expand_dims(np.transpose(test_res), axis=0)
  proba_y_given_x = np.product(test_res * proba_y_is_one_given_x + (1-test_res)
                               * (1-proba_y_is_one_given_x), axis=1)
  proba_y_and_x = proba_y_given_x * np.expand_dims(particle_weights, 1)
  proba_y = np.sum(proba_y_and_x, axis=0)
  proba_x_given_y = proba_y_and_x / np.expand_dims(proba_y, 0)
  vutility_fun = jax.vmap(utility_fun, [1, None])
  utility_x_given_y = vutility_fun(proba_x_given_y, particles)
  return np.dot(proba_y, utility_x_given_y)
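
A small usage sketch (not from the original source, assuming group_utility above is in scope); neg_entropy and all shapes below are illustrative choices, not part of the original code.

import jax
import jax.numpy as jnp

def neg_entropy(weights, particles):
  # hypothetical utility: negative entropy of the particle weights
  w = weights / jnp.sum(weights)
  return jnp.sum(w * jnp.log(w + 1e-12))

num_particles, num_patients, num_groups = 8, 5, 2
key = jax.random.PRNGKey(0)
# each particle is a binary infection-status vector over the patients
particles = jax.random.bernoulli(key, 0.3, (num_particles, num_patients)).astype(jnp.float32)
particle_weights = jnp.ones(num_particles) / num_particles
groups = jnp.eye(num_groups, num_patients)   # each group tests a single patient
sensitivities = jnp.full(num_groups, 0.85)
specificities = jnp.full(num_groups, 0.97)

expected_utility = group_utility(particle_weights, particles, groups,
                                 sensitivities, specificities, neg_entropy)
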
Example No. 4
def ghq_2d_separable(f, locs=None, Ls=None, pts_and_weights=None, degree=5):
  """Computes an estimate of E[f(x)] using Gauss-Hermite quadrature.

  Args:
    f: The function to estimate E[f(x)] for. Called on a [n, 2] batch of
      points (one point per Gaussian component); may return a scalar or a
      vector per point.
    locs: A matrix of shape [n, 2], the means of the Normal distributions to
      integrate against.
    Ls: A tensor of shape [n, 2, 2], the Cholesky factors of the covariance
      matrices of the Normal distributions to integrate against.
    pts_and_weights: Optional tuple (pts, ws_nd) of precomputed quadrature
      points and weights as returned by a previous call; if given, locs and
      Ls are ignored.
    degree: Degree of the 1-D Gauss-Hermite rule used in each dimension.

  Returns:
    A tuple of the estimate of E[f(x)] (averaged over the n components) and
    the quadrature points and weights (pts, ws_nd) for reuse.
  """
  if pts_and_weights is None:
    n = locs.shape[0]
    # [degree]
    xs_1d, ws_1d = np.polynomial.hermite.hermgauss(degree)
    # [degree^2, 2]
    xs_2d = np.array(list(itertools.product(xs_1d, xs_1d)))
    # [degree^2, n, 2]
    xs_nd = np.stack([xs_2d]*n, axis=1)
    # [degree^2]
    ws_2d = np.prod(np.array(list(itertools.product(ws_1d, ws_1d))), axis=1)
    # [degree^2]
    ws_nd = ws_2d * n
    pts = np.matmul(xs_nd[:,:,np.newaxis,:], np.transpose(Ls[np.newaxis,:,:,:], axes=(0,1,3,2)))
    pts = np.squeeze(pts, axis=2)
    # [degree^2, n, 2]
    pts = pts*np.sqrt(2.) + locs[np.newaxis,:,:]
  else:
    pts, ws_nd = pts_and_weights
    n = pts.shape[1]

  # [degree^2, n, ...]
  gs = np.array([f(pts[i]) for i in range(pts.shape[0])])
  # [degree^2, ...]
  gs = np.sum(gs, axis=1)
  ws_shape = [gs.shape[0]] + [1]*(gs.ndim-1)
  return np.sum(gs * ws_nd.reshape(ws_shape), axis=0)/(n*np.pi), (pts, ws_nd)
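
A minimal sanity check (not from the original source, assuming ghq_2d_separable above is in scope and that np here refers to NumPy, since the function calls np.polynomial.hermite.hermgauss): for a single standard Normal in 2-D, E[x1^2 + x2^2] should come out close to 2.

import numpy as np

locs = np.zeros((1, 2))                    # one component, zero mean
Ls = np.eye(2)[np.newaxis, :, :]           # Cholesky factor of the identity covariance
f = lambda x: np.sum(x**2, axis=-1)        # E[x1^2 + x2^2] = 2 under N(0, I)

est, (pts, ws) = ghq_2d_separable(f, locs=locs, Ls=Ls, degree=5)
assert np.allclose(est, 2.0)
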
Example No. 5
 def model(self, batch):
     XL, XH = batch['XL'], batch['XH']
     y = batch['y']
     NL, NH = XL.shape[0], XH.shape[0]
     D = XH.shape[1]
     # set uninformative log-normal priors for low-fidelity kernel
     var_L = sample('kernel_var_L',
                    dist.LogNormal(0.0, 1.0),
                    sample_shape=(1, ))
     length_L = sample('kernel_length_L',
                       dist.LogNormal(0.0, 1.0),
                       sample_shape=(D, ))
     theta_L = np.concatenate([var_L, length_L])
     # set uninformative log-normal priors for high-fidelity kernel
     var_H = sample('kernel_var_H',
                    dist.LogNormal(0.0, 1.0),
                    sample_shape=(1, ))
     length_H = sample('kernel_length_H',
                       dist.LogNormal(0.0, 1.0),
                       sample_shape=(D, ))
     theta_H = np.concatenate([var_H, length_H])
     # prior for rho
     rho = sample('rho', dist.Normal(0.0, 10.0), sample_shape=(1, ))
     # Compute kernels
     K_LL = self.kernel(XL, XL, theta_L) + np.eye(NL) * 1e-8
     K_LH = rho * self.kernel(XL, XH, theta_L)
     K_HH = rho**2 * self.kernel(XH, XH, theta_L) + \
                     self.kernel(XH, XH, theta_H) + np.eye(NH)*1e-8
     K = np.vstack((np.hstack((K_LL, K_LH)), np.hstack((K_LH.T, K_HH))))
     L = cholesky(K, lower=True)
     # Generate latent function
     beta_L = sample('beta_L', dist.Normal(0.0, 1.0))
     beta_H = sample('beta_H', dist.Normal(0.0, 1.0))
     eta_L = sample('eta_L', dist.Normal(0.0, 1.0), sample_shape=(NL, ))
     eta_H = sample('eta_H', dist.Normal(0.0, 1.0), sample_shape=(NH, ))
     beta = np.concatenate([beta_L * np.ones(NL), beta_H * np.ones(NH)])
     eta = np.concatenate([eta_L, eta_H])
     f = np.matmul(L, eta) + beta
     # Bernoulli likelihood
     sample('y', dist.Bernoulli(logits=f), obs=y)
Example No. 6
    def __getitem__(self, slice_spec):
        """Basic indexing, returns a TTMatrix containing the specified element / slice."""
        d = self.ndim
        if len(slice_spec) != 2 * d:
            raise ValueError('Expected %d indices, got %d' %
                             (2 * d, len(slice_spec)))
        for i in range(d):
            if isinstance(slice_spec[i], slice) != isinstance(
                    slice_spec[d + i], slice):
                raise ValueError(
                    'Elements i_%d and j_%d should be the same type, '
                    'instead: %s and %s.' %
                    (i, i, slice_spec[i], slice_spec[d + i]))
        new_tt_cores = []
        remainder = None
        for i in range(self.ndim):
            curr_core = self.tt_cores[i]
            sliced_core = curr_core[..., :, slice_spec[i],
                                    slice_spec[d + i], :]
            if len(curr_core.shape) != len(sliced_core.shape):
                # These indices are specified exactly and we want to collapse this axis.
                if remainder is None:
                    remainder = sliced_core
                else:
                    remainder = jnp.matmul(remainder, sliced_core)
            else:
                if remainder is not None:
                    # Add remainder from the previously collapsed cores to the current core.
                    sliced_core = jnp.einsum('...ab,...bijd->...aijd',
                                             remainder, sliced_core)
                    remainder = None
                new_tt_cores.append(sliced_core)

        if remainder is not None:
            # The remainder obtained from collapsing the last cores.
            new_tt_cores[-1] = jnp.einsum('...aijb,...bd->...aijd',
                                          new_tt_cores[-1], remainder)
            remainder = None

        return TTMatrix(new_tt_cores)
Example No. 7
def driven_mu_cov(b, f, theta, dt):
    """
    Compute the twisted mean and covariance matrix from the twisting parameters.

    arguments
        b : jnp.array(N)
            twisting vector
        f : jnp.array(N)
            push vector
        theta : jnp.array(N,N)
            twist matrix
        dt : float
            time increment
    returns
        mu : jnp.array(N)
            twisted mean
        cov : jnp.array(N,N)
            twisted covariance matrix
    """
    mu = jnp.matmul(theta, f - dt*b)
    cov = dt*theta
    return mu, cov
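
A short usage sketch (not from the original source, assuming driven_mu_cov above is in scope) with arbitrary illustrative values:

import jax.numpy as jnp

N = 3
b = jnp.ones(N)                   # twisting vector
f = jnp.arange(1.0, N + 1.0)      # push vector
theta = 0.5 * jnp.eye(N)          # twist matrix
dt = 0.01

mu, cov = driven_mu_cov(b, f, theta, dt)   # mu: shape (3,), cov: shape (3, 3)
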
Example No. 8
def pool_forward(X, W, size=2, stride=2):
    n, d, h, w = X.shape

    h_out = (h - size) // stride + 1
    w_out = (w - size) // stride + 1

    h_out, w_out = int(h_out), int(w_out)

    X_reshaped = X.reshape(n * d, 1, h, w)

    X_col = im2col_indices(X_reshaped, size, size, padding=0, stride=stride)

    n_filter, v, h_filter, w_filter = W.shape
    W_col = W.reshape(n_filter, -1)

    out = jnp.matmul(W_col, X_col)
    out = jnp.mean(out,axis=0)
    out = out.reshape(h_out, w_out, n, d)
    out = jnp.transpose(out, (2, 3, 0, 1))
    out = jnp.array(out, dtype='float32')

    return out
Example No. 9
  def test_shape_error(self):
    """Some of the examples from the README."""
    with self.assertRaisesRegex(
        TypeError,
        re.escape("add got incompatible shapes for broadcasting: (v,), (4,)")):
      self.CheckShapePolymorphism(
          lambda x, y: x + y,
          input_signature=[tf.TensorSpec([None]),
                           tf.TensorSpec([4])],
          polymorphic_shapes=["(v,)", "(4,)"],
          expected_output_signature=tf.TensorSpec([None]))

    four_ones = np.ones((4,))
    # We get the error even if we use correct actual arguments
    with self.assertRaisesRegex(
        TypeError,
        re.escape("add got incompatible shapes for broadcasting: (v,), (4,)")):
      jax2tf.convert(
          lambda x, y: x + y, polymorphic_shapes=["(v,)", "(4,)"])(four_ones,
                                                                   four_ones)

    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                re.escape("Shape variable comparison v == 4 is inconclusive")):
      jax2tf.convert(lambda x: jnp.matmul(x, x),
                     polymorphic_shapes=["(v, 4)"])(np.ones((4, 4)))

    with self.assertRaisesRegex(TypeError,
                                re.escape("unsupported operand type(s) for *: 'DimVar' and 'int'")):
      jax2tf.convert(lambda x: jnp.reshape(x, np.prod(x.shape)),
                     polymorphic_shapes=["(b, ...)"])(np.ones((3, 4, 5)))

    jax2tf.convert(lambda x: jnp.reshape(x, (x.shape[0], np.prod(x.shape[1:]))),
                   polymorphic_shapes=["(b, _, _)"])(np.ones((3, 4, 5)))

    with self.assertRaisesRegex(
        TypeError,
        re.escape("unsupported operand type(s) for /: 'TensorFlowTracer' and 'DimVar'")):
      jax2tf.convert(lambda x: jnp.sum(x, axis=0) / x.shape[0],
                     polymorphic_shapes=["(v, _)"])(np.ones((4, 4)))
Example No. 10
  def test_full_range_integer_weights_with_float_scale_should_give_close_output(
      self, weight_prec):
    # If weights are ints (already quantized) with
    # max(abs(weights[:, ch])) == 2**(prec-1)-1 in each channel
    # and if these integer weights are multiplied by a float scale,
    # the resulting error should still be very small (just float rounding).

    num_features = 256
    input_dim = 1024
    inputs = random.uniform(self.rng_key, shape=(2, input_dim))
    model, state = self.init_model_with_1_layer(
        inputs, num_features, weight_prec=weight_prec)
    minval = -2**(weight_prec - 1) + 1
    maxval = 2**(weight_prec - 1) - 1

    full_range_integer_weights = random.randint(self.rng_key,
                                                (input_dim, num_features),
                                                minval, maxval + 1)

    # manually set one value in each output dim of weights to be exactly maxval
    full_range_integer_weights = jax.ops.index_update(
        full_range_integer_weights, jax.ops.index[0, :], maxval)

    float_scale = jax.random.uniform(self.rng_key, (1, num_features))
    state = state.unfreeze()
    state['params']['kernel'] = full_range_integer_weights * float_scale
    state = flax.core.freeze(state)
    outputs = model.apply(state, inputs, padding_mask=None)
    exp_outputs = jnp.matmul(inputs, state['params']['kernel'])
    # TODO(wanglisa): Determine how much noise is expected for following test.
    # We know that the noise should be proportional to the square root of
    # input_dim and inversely proportional to 2**weight_prec.
    # The following tol_const was obtained experimentally and should be derived
    # more systematically.
    tol_const = 8e-04
    onp.testing.assert_allclose(
        outputs,
        exp_outputs,
        rtol=jnp.sqrt(input_dim) * 2**(-weight_prec) * tol_const)
Example No. 11
    def build(self, sc: OrderedDict):
        gp_matrices = OrderedDict()
        gp_matrices["l_u_beta"] = aux_math.matrix_diag_transform(np.tril(sc["S_u_beta"]), nn.softplus)
        gp_matrices["l_u_gamma"] = aux_math.matrix_diag_transform(np.tril(sc["S_u_gamma"]), nn.softplus)

        # Kernel
        gp_matrices["k_beta_beta"] = self.kernel.matrix(sc["X_u_beta"], sc["X_u_beta"])
        gp_matrices["l_beta_beta"] = scipy.linalg.cholesky(gp_matrices["k_beta_beta"] +
                                                           self.jitter * np.eye(self.n_inducing_beta),
                                                           lower=True)

        k_gamma_gamma = \
            self.kernel.matrix(sc["X_u_gamma"], sc["X_u_gamma"]) + self.jitter * np.eye(self.n_inducing_gamma)

        k_beta_gamma = self.kernel.matrix(sc["X_u_beta"], sc["X_u_gamma"])

        l_beta_inv_k_beta_gamma = scipy.linalg.solve_triangular(gp_matrices["l_beta_beta"], k_beta_gamma,
                                                                lower=True)
        gp_matrices["l_beta_inv_k_beta_gamma"] = l_beta_inv_k_beta_gamma

        c_gamma_gamma = k_gamma_gamma - np.matmul(
            np.transpose(gp_matrices["l_beta_inv_k_beta_gamma"], (0, 2, 1)), gp_matrices["l_beta_inv_k_beta_gamma"])

        gp_matrices["l_gamma_gamma"] = scipy.linalg.cholesky(c_gamma_gamma +
                                                             self.jitter * np.eye(self.n_inducing_gamma), lower=True)

        # U_beta_dists
        gp_matrices["q_u_beta_mean"] = sc["mu_u_beta"]
        gp_matrices["q_u_beta_tril"] = gp_matrices["l_u_beta"]
        gp_matrices["p_u_beta_mean"] = np.zeros([self.ndims_out, self.n_inducing_beta])
        gp_matrices["p_u_beta_tril"] = gp_matrices["l_beta_beta"]

        # U_gamma_dists
        gp_matrices["q_u_gamma_mean"] = sc["mu_u_gamma"]
        gp_matrices["q_u_gamma_tril"] = gp_matrices["l_u_gamma"]
        gp_matrices["p_u_gamma_mean"] = np.zeros([self.ndims_out, self.n_inducing_gamma])
        gp_matrices["p_u_gamma_tril"] = gp_matrices["l_gamma_gamma"]

        return gp_matrices
Example No. 12
    def __call__(self, x, features, bias=True, kernel_init=None):
        def mu_init(key, shape):
            # Initialization of mean noise parameters (Section 3.2)
            low = -1 / jnp.power(x.shape[0], 0.5)
            high = 1 / jnp.power(x.shape[0], 0.5)
            return jax.random.uniform(key,
                                      minval=low,
                                      maxval=high,
                                      shape=shape)

        def sigma_init(key, shape, dtype=jnp.float32):  # pylint: disable=unused-argument
            # Initialization of sigma noise parameters (Section 3.2)
            return jnp.ones(shape, dtype) * (0.1 / onp.sqrt(x.shape[0]))

        if self.eval_mode:
            # Turn off noise during evaluation
            w_epsilon = onp.zeros(shape=(x.shape[0], features),
                                  dtype=onp.float32)
            b_epsilon = onp.zeros(shape=(features, ), dtype=onp.float32)
        else:
            # Factored gaussian noise in (10) and (11) in Fortunato et al. (2018).
            p = NoisyNetwork.sample_noise(self.rng_key, [x.shape[0], 1])
            q = NoisyNetwork.sample_noise(self.rng_key, [1, features])
            f_p = NoisyNetwork.f(p)
            f_q = NoisyNetwork.f(q)
            w_epsilon = f_p * f_q
            b_epsilon = jnp.squeeze(f_q)

        # See (8) and (9) in Fortunato et al. (2018) for output computation.
        w_mu = self.param('kernel_mu', mu_init, (x.shape[0], features))
        w_sigma = self.param('kernel_sigma', sigma_init,
                             (x.shape[0], features))
        w = w_mu + jnp.multiply(w_sigma, w_epsilon)
        ret = jnp.matmul(x, w)

        b_mu = self.param('bias_mu', mu_init, (features, ))
        b_sigma = self.param('bias_sigma', sigma_init, (features, ))
        b = b_mu + jnp.multiply(b_sigma, b_epsilon)
        return jnp.where(bias, ret + b, ret)
Example No. 13
def mean(
    gp: NonConjugatePosterior,
    param: dict,
    test_inputs: Array,
    train_inputs: Array,
    train_outputs: Array,
):
    ell, alpha, nu = param["lengthscale"], param["variance"], param["latent"]
    Kff = gram(gp.prior.kernel, train_inputs, param)
    Kfx = cross_covariance(gp.prior.kernel, train_inputs, test_inputs, param)
    Kxx = gram(gp.prior.kernel, test_inputs, param)
    L = jnp.linalg.cholesky(Kff + jnp.eye(train_inputs.shape[0]) * 1e-6)

    A = solve_triangular(L, Kfx.T, lower=True)
    latent_var = Kxx - jnp.sum(jnp.square(A), -2)
    latent_mean = jnp.matmul(A.T, nu)

    lvar = jnp.diag(latent_var)

    moment_fn = predictive_moments(gp.likelihood)
    pred_rv = moment_fn(latent_mean.ravel(), lvar)
    return pred_rv.mean()
Example No. 14
def _linear_correlate_color(t):
    """Multiply input by sqrt of empirical (ImageNet) color correlation matrix.

  If you interpret t's innermost dimension as describing colors in a
  decorrelated version of the color space (which is a very natural way to
  describe colors -- see the discussion in the Feature Visualization article),
  the way to map back to normal colors is to multiply by the square root of
  your color correlation matrix.

  Args:
    t: input whitened color array, with trailing dimension 3.

  Returns:
    t_correlated: RGB color array.
  """
    assert t.shape[-1] == 3
    t_flat = np.reshape(t, [-1, 3])
    color_correlation_normalized = (color_correlation_svd_sqrt /
                                    max_norm_svd_sqrt)
    t_flat = np.matmul(t_flat, color_correlation_normalized.T)
    t_correlated = np.reshape(t_flat, t.shape)
    return t_correlated
Example No. 15
    def test_shape_error(self):
        """Some of the examples from the README."""
        raise unittest.SkipTest("Failing after fixing Poly unsoundness #4878")
        with self.assertRaisesRegex(
                TypeError,
                re.escape(
                    "add got incompatible shapes for broadcasting: (v,), (4,)")
        ):
            self.CheckShapePolymorphism(
                lambda x, y: x + y,
                input_signature=[tf.TensorSpec([None]),
                                 tf.TensorSpec([4])],
                in_shapes=["(v,)", "(4,)"],
                expected_output_signature=tf.TensorSpec([None]))

        four_ones = np.ones((4, ))
        # We get the error even if we use correct actual arguments
        with self.assertRaisesRegex(
                TypeError,
                re.escape(
                    "add got incompatible shapes for broadcasting: (v,), (4,)")
        ):
            jax2tf.convert(lambda x, y: x + y,
                           in_shapes=["(v,)", "(4,)"])(four_ones, four_ones)

        with self.assertRaisesRegex(
                TypeError,
                re.escape(
                    "dot_general requires contracting dimensions to have the same shape, got [4] and [v]."
                )):
            jax2tf.convert(lambda x: jnp.matmul(x, x),
                           in_shapes=["(v, 4)"])(np.ones((4, 4)))

        # TODO: this is an opportunity to improve the translation, should not error
        with self.assertRaisesRegex(
                TypeError,
                "Only integers, .* tensors are valid indices, got 0"):
            jax2tf.convert(lambda x: jnp.split(x, 2),
                           in_shapes=["(2*v,)"])(four_ones)
Example No. 16
def parametric(subposteriors, diagonal=False):
    """
    Merges subposteriors following the (embarrassingly parallel) parametric Monte Carlo algorithm.

    **References:**

    1. *Asymptotically Exact, Embarrassingly Parallel MCMC*,
       Willie Neiswanger, Chong Wang, Eric Xing

    :param list subposteriors: a list in which each element is a collection of samples.
    :param bool diagonal: whether to compute weights using variance or covariance, defaults to
        `False` (using covariance).
    :return: the estimated mean and variance/covariance parameters of the joined posterior
    """
    joined_subposteriors = tree_multimap(lambda *args: np.stack(args),
                                         *subposteriors)
    joined_subposteriors = vmap(
        vmap(lambda sample: ravel_pytree(sample)[0]))(joined_subposteriors)

    submeans = np.mean(joined_subposteriors, axis=1)
    if diagonal:
        weights = vmap(lambda x: 1 / np.var(x, ddof=1, axis=0))(
            joined_subposteriors)
        var = 1 / np.sum(weights, axis=0)
        normalized_weights = var * weights

        # compared to the consensus implementation, we compute the weighted mean here
        mean = np.einsum('ij,ij->j', normalized_weights, submeans)
        return mean, var
    else:
        weights = vmap(lambda x: np.linalg.inv(np.cov(x.T)))(
            joined_subposteriors)
        cov = np.linalg.inv(np.sum(weights, axis=0))
        normalized_weights = np.matmul(cov, weights)

        # compared to the consensus implementation, we compute the weighted mean here
        mean = np.einsum('ijk,ik->j', normalized_weights, submeans)
        return mean, cov
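
A toy usage sketch (not from the original source, assuming parametric above is in scope): two synthetic subposteriors, each a pytree (here a dict) of 1000 draws of a 2-dimensional latent, merged with the full-covariance variant.

import numpy as onp

rng = onp.random.default_rng(0)
subposteriors = [
    {"theta": rng.normal(loc=shift, size=(1000, 2))} for shift in (0.0, 0.5)
]
mean, cov = parametric(subposteriors, diagonal=False)   # mean: (2,), cov: (2, 2)
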
Example No. 17
def model_wo_c(T, T_forecast, x, obs=None):
    # Define priors over beta, tau, sigma, z_1 (keep the shapes in mind)
    W = numpyro.sample(name="W",
                       fn=dist.Normal(loc=jnp.zeros((2, 4)),
                                      scale=jnp.ones((2, 4))))
    beta = numpyro.sample(name="beta",
                          fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2)))
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=0.1))
    z_prev = numpyro.sample(name="z_1",
                            fn=dist.Normal(loc=jnp.zeros(2),
                                           scale=jnp.ones(2)))
    # Define LKJ prior
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, 10.0))
    Sigma_lower = jnp.matmul(
        jnp.diag(jnp.sqrt(tau)),
        L_Omega)  # lower cholesky factor of the covariance matrix
    noises = numpyro.sample(
        "noises",
        fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower),
        sample_shape=(T + T_forecast, ),
    )
    # Propagate the dynamics forward using jax.lax.scan
    carry = (W, beta, z_prev, tau)
    z_collection = [z_prev]
    carry, zs_exp = lax.scan(f, carry, (x, noises), T + T_forecast)
    z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0)

    obs_mean = z_collection[:T, 1]
    pred_mean = z_collection[T:, 1]

    # Sample the observed y (y_obs)
    numpyro.sample(name="y_obs",
                   fn=dist.Normal(loc=obs_mean, scale=sigma),
                   obs=obs)
    numpyro.sample(name="y_pred",
                   fn=dist.Normal(loc=pred_mean, scale=sigma),
                   obs=None)
Example No. 18
    def _combine_gradients(self, px_grads_list, px_loss):
        """ Combines the per-example gradients into the batch gradient and
            applies the batch gradient transformation given as
            `batch_grad_manipulation_fn`.

        This is the third step of a full update iteration.

        :param px_grads_list: List of transformed per-example gradients as returned
            by `_apply_per_example_gradient_transformations`
        :param px_loss: Array of per-example loss values as output by
            `_compute_per_example_gradients`.
        :returns: tuple consisting of the updated svi state, the loss value for
            the batch and a jax tree of batch gradients per parameter site.
        """

        # get total loss and loss combiner vjp func
        loss_val, loss_combine_vjp = jax.vjp(self.loss.combiner_fn, px_loss)

        # loss_combine_vjp gives us the backward differentiation function
        #   from combined loss to per-example losses. we use it to get the
        #   (1xbatch_size) Jacobian and construct a function that takes
        #   per-example gradients and left-multiplies them with that jacobian
        #   to get the final combined gradient
        loss_jacobian = jnp.reshape(loss_combine_vjp(jnp.array(1.))[0], (1, -1))
        # loss_vjp = lambda px_grads: jnp.sum(jnp.multiply(loss_jacobian, px_grads))
        loss_vjp = lambda px_grads: jnp.matmul(loss_jacobian, px_grads)

        # we map the loss combination vjp func over all secondary dimensions
        #   of gradient sites. This is necessary since some gradient
        #   sites might be matrices in itself (e.g., for NN layers), so a stack
        #   of those would be 3-dimensional and not admittable to jnp.matmul
        loss_vjp = map_over_secondary_dims(loss_vjp)

        # combine gradients for all parameters in the gradient jax tree
        #   according to the loss combination vjp func
        grads_list = tuple(map(loss_vjp, px_grads_list))

        return loss_val, grads_list
Example No. 19
    def __getitem__(self, slice_spec):
        """Basic indexing, returns a TT containing the specified element / slice.
    
    Examples:
      >>> a = ttax.random.tensor(rng, [2, 3, 4])
      >>> a[1, :, :]
      is a 2D TensorTrain 3 x 4.
      >>> a[1:2, :, :]
      is a 3D TensorTrain 1 x 3 x 4
    """
        if len(slice_spec) != self.ndim:
            raise ValueError('Expected %d indices, got %d' %
                             (self.ndim, len(slice_spec)))
        new_tt_cores = []
        remainder = None
        for i in range(self.ndim):
            curr_core = self.tt_cores[i]
            sliced_core = curr_core[..., :, slice_spec[i], :]
            if len(curr_core.shape) != len(sliced_core.shape):
                # This index is specified exactly and we want to collapse this axis.
                if remainder is None:
                    remainder = sliced_core
                else:
                    remainder = jnp.matmul(remainder, sliced_core)
            else:
                if remainder is not None:
                    # Add remainder from the previously collapsed cores to the current core.
                    sliced_core = jnp.einsum('...ab,...bid->...aid', remainder,
                                             sliced_core)
                    remainder = None
                new_tt_cores.append(sliced_core)

        if remainder is not None:
            # The remainder obtained from collapsing the last cores.
            new_tt_cores[-1] = jnp.einsum('...aib,...bd->...aid',
                                          new_tt_cores[-1], remainder)
            remainder = None
        return TT(new_tt_cores)
Example No. 20
def logmarglike_lineargaussianmodel_onetransfer_jit(M_T, y, yinvvar,
                                                    logyinvvar):
    """
    Fit linear model to one Gaussian data set, with no (=uniform) prior on the linear components.

    Parameters
    ----------
    y, yinvvar, logyinvvar : ndarray (n_pix_y)
        data and data inverse variances.
        Zeros will be ignored.
    M_T : ndarray (n_components, n_pix_y)
        design matrix of linear model

    Returns
    -------
    logfml : ndarray scalar
        log likelihood values with parameters marginalised and at best fit
    theta_map : ndarray (n_components)
        Best fit MAP parameters
    theta_cov : ndarray (n_components, n_components)
        Parameter covariance

    """
    log2pi = np.log(2.0 * np.pi)
    nt = np.shape(M_T)[-2]
    ny = np.count_nonzero(yinvvar)
    M = np.transpose(M_T)  # (n_pix_y, n_components)
    Myinv = M * yinvvar[:, None]  # (n_pix_y, n_components)
    Hbar = np.matmul(M_T, Myinv)  #  (n_components, n_components)
    etabar = np.sum(Myinv * y[:, None], axis=0)  # (n_components)
    theta_map = np.linalg.solve(Hbar, etabar)  # (n_components)
    theta_cov = np.linalg.inv(Hbar)  # (n_components, n_components)
    logdetH = np.sum(logyinvvar)  # scalar
    xi1 = -0.5 * (ny * log2pi - logdetH + np.sum(y * y * yinvvar))  # scalar
    sign, logdetHbar = np.linalg.slogdet(Hbar)
    xi2 = -0.5 * (nt * log2pi - logdetHbar + np.sum(etabar * theta_map))
    logfml = xi1 - xi2
    return logfml, theta_map, theta_cov
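
A small usage sketch (not from the original source, assuming the function above is in scope): synthetic data from a 2-component linear model with known inverse variances and no masked (zero-inverse-variance) pixels.

import numpy as onp

rng = onp.random.default_rng(0)
n_components, n_pix = 2, 50
M_T = rng.normal(size=(n_components, n_pix))        # design matrix (transposed)
theta_true = onp.array([1.0, -2.0])
yinvvar = onp.full(n_pix, 4.0)                      # inverse variance 1 / 0.5**2
y = M_T.T @ theta_true + rng.normal(scale=0.5, size=n_pix)

logfml, theta_map, theta_cov = logmarglike_lineargaussianmodel_onetransfer_jit(
    M_T, y, yinvvar, onp.log(yinvvar))
# theta_map should be close to theta_true
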
Example No. 21
    def _test_lanczos_dynamic_vs_static_once(self, seed=0):
        def _safe_div(x1, x2):
            return jnp.where(jnp.logical_and(x1 == 0, x2 == 0), x1, x1 / x2)

        dim = 5
        key = jax.random.PRNGKey(seed)
        h_tmp = jax.random.normal(key, shape=(dim, dim))
        h = h_tmp + jnp.transpose(h_tmp)
        hv = lambda v: jnp.matmul(h, v)
        tr1, vecs1 = eigenvector_utils.lanczos_alg(hv,
                                                   dim,
                                                   dim,
                                                   key,
                                                   dynamic_unroll=True)
        tr2, vecs2 = eigenvector_utils.lanczos_alg(hv,
                                                   dim,
                                                   dim,
                                                   key,
                                                   dynamic_unroll=False)
        assert jnp.max(jnp.abs(_safe_div(tr1 - tr2, tr2))) < 1e-4, (
            f'Seed {seed}: large relative error in Lanczos tridiag')
        assert jnp.max(jnp.abs(_safe_div(vecs1 - vecs2, vecs2))) < 1e-4, (
            f'Seed {seed}: large relative error in Lanczos vecs')
Example No. 22
    def mll(
        params: dict,
        training: Dataset,
        priors: dict = {"latent": tfd.Normal(loc=0.0, scale=1.0)},
        static_params: dict = None,
    ):
        x, y = training.X, training.y
        n = training.n
        params = transform(params)
        if static_params:
            params = concat_dictionaries(params, transform(static_params))
        link = link_function(gp.likelihood)
        gram_matrix = gram(gp.prior.kernel, x, params)
        gram_matrix += I(n) * jitter
        L = jnp.linalg.cholesky(gram_matrix)
        F = jnp.matmul(L, params["latent"])
        rv = link(F)
        ll = jnp.sum(rv.log_prob(y))

        priors = prior_checks(gp, priors)
        log_prior_density = evaluate_prior(params, priors)
        constant = jnp.array(-1.0) if negative else jnp.array(1.0)
        return constant * (ll + log_prior_density)
Example No. 23
 def evaluate(self, params):
     """
     Evaluates the circuit and returns its unitary operator as a matrix. You must
     assemble the circuit before evaluating it. Uses a JIT compiler; for optimal
     performance, do not modify the circuit or its constituent gates between calls
     to this method.
     
     Parameters
     ----------
     params: 1D ndarray
         A vector of parameters for the gates.
     
     Returns
     -------
     2D ndarray. A unitary matrix representing the circuit.
     """
     assert len(params) == self.n_params
     mat = jnp.eye(self.regInfo.dim)
     for layer in self.gates:
         g = layer.gate(params)
         # reverse since operators act right-to-left
         mat = jnp.matmul(g, mat)
     return mat
Example No. 24
 def model(n, y, mu, tlist, AV, t0left, t0right):
     t0 = numpyro.sample('t0',
                         numpyro.distributions.Uniform(t0left, t0right))
     if Tau == 0:
         light_curve = numpyro.distributions.Normal(t0, scale=Sigma)
         pl = numpyro.primitives.deterministic(
             'pl',
             jnp.exp(light_curve.log_prob(tlist)) / n * mu)
     else:
         pl = numpyro.primitives.deterministic(
             'pl',
             Co * (1. - jax.scipy.special.erf(
                 (Alpha * Sigma**2 -
                  (tlist - t0)) / (math.sqrt(2.) * Sigma))) *
             jnp.exp(-Alpha * (tlist - t0)) / n * mu)
     A = numpyro.sample('A', mNormal(pl, mix0sigma, 1., gsigma / gmu))
     with numpyro.plate('observations', len(y)):
         obs = numpyro.sample('obs',
                              numpyro.distributions.Normal(jnp.matmul(
                                  AV, A),
                                                           scale=std),
                              obs=y)
     return obs
Example No. 25
def get_dYdt(
    P: Union[float, np.ndarray],
    R: float,
    gas_info: GasInfo,
    nasa_poly: NASAPolynomials,
    kinetics_coeffs: KineticsCoeffs,
    kinetics_data: KineticsData,
    t: Union[float, np.ndarray],
    state_vec: np.ndarray,
) -> np.ndarray:
    """
    Get dYdt, where dYdt[0] is dT/dt and the remaining entries are the species dY/dt.
    """
    T = state_vec[0]
    Y = state_vec[1:]
    T = np.clip(T, a_min=200.0, a_max=1e5)
    Y = np.clip(Y, a_min=0.0, a_max=1.0)
    mean_molecular_weight = get_mean_molecular_weight(
        Y, gas_info.molecular_weights)
    density_mass = P / R / T * mean_molecular_weight
    X = Y2X(Y, gas_info.molecular_weights, mean_molecular_weight)
    C = Y2C(Y, gas_info.molecular_weights, density_mass)
    cp_data = calculate_cp(T, X, R, nasa_poly)
    enthalpy_data = calculate_enthalpy(T, X, R, nasa_poly)
    entropy_data = calculate_entropy(T, P, X, R, nasa_poly)
    cp_mass = cp_data.cp_mole / mean_molecular_weight
    Kc = get_equilibirum_constants(T, P, R, entropy_data.sdivR,
                                   enthalpy_data.hdivRT, gas_info)
    kf = get_forward_rate_constants(T, R, C, kinetics_coeffs, kinetics_data)
    kr = get_reverse_rate_constants(kf, Kc, kinetics_data.is_reversible)
    production_rates = get_production_rates(kf, kr, C, gas_info)
    Ydot = (production_rates.wdot * gas_info.molecular_weights) / density_mass
    Tdot = -(np.matmul(enthalpy_data.partial_molar_enthalpies,
                       production_rates.wdot)) / (density_mass * cp_mass)
    dYdt = np.hstack((Tdot, Ydot))
    # dYdt = delete_small_numbers(dYdt)
    return dYdt
Example No. 26
    def update_curvature_matrix_estimate(self, info: _BlockInfo,
                                         batch_size: int,
                                         ema_old: Union[float, jnp.ndarray],
                                         ema_new: Union[float, jnp.ndarray],
                                         pmap_axis_name: str) -> None:
        del pmap_axis_name
        (x, ), (dy, ) = info["inputs"], info["outputs_tangent"]
        utils.check_first_dim_is_batch_size(batch_size, x, dy)

        grads = list()
        if self._has_scale:
            # Scale gradients
            scale_shape = info["params"][0].shape
            full_scale_shape = (1, ) * (len(x.shape) -
                                        len(scale_shape)) + scale_shape
            axis = [
                i for i, s in enumerate(full_scale_shape) if s == 1 and i != 0
            ]
            d_scale = jnp.sum(x * dy, axis=axis)
            d_scale = d_scale.reshape([batch_size, -1])
            grads.append(d_scale)

        if self._has_shift:
            # Shift gradients
            shift_shape = info["params"][1].shape
            full_shift_shape = (1, ) * (len(x.shape) -
                                        len(shift_shape)) + shift_shape
            axis = [
                i for i, s in enumerate(full_shift_shape) if s == 1 and i != 0
            ]
            d_shift = jnp.sum(dy, axis=axis)
            d_shift = d_shift.reshape([batch_size, -1])
            grads.append(d_shift)

        grads = jnp.concatenate(grads, axis=1)
        factor_update = jnp.matmul(grads.T, grads) / batch_size
        self.factor.update(factor_update, ema_old, ema_new)
Example No. 27
def test_dense_mass(kernel_cls, rho):
    num_warmup, num_samples = 20000, 10000

    true_cov = jnp.array([[10.0, rho], [rho, 0.1]])

    def model():
        numpyro.sample(
            "x",
            dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=true_cov))

    if kernel_cls is HMC or kernel_cls is NUTS:
        kernel = kernel_cls(model, trajectory_length=2.0, dense_mass=True)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model, dense_mass=True)

    mcmc = MCMC(kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0))

    mass_matrix_sqrt = mcmc.last_state.adapt_state.mass_matrix_sqrt
    if kernel_cls is HMC or kernel_cls is NUTS:
        mass_matrix_sqrt = mass_matrix_sqrt[("x", )]
    mass_matrix = jnp.matmul(mass_matrix_sqrt, jnp.transpose(mass_matrix_sqrt))
    estimated_cov = jnp.linalg.inv(mass_matrix)
    assert_allclose(estimated_cov, true_cov, rtol=0.10)

    samples = mcmc.get_samples()["x"]
    assert_allclose(jnp.mean(samples[:, 0]), jnp.array(0.0), atol=0.50)
    assert_allclose(jnp.mean(samples[:, 1]), jnp.array(0.0), atol=0.05)
    assert_allclose(jnp.mean(samples[:, 0] * samples[:, 1]),
                    jnp.array(rho),
                    atol=0.20)
    assert_allclose(jnp.var(samples, axis=0),
                    jnp.array([10.0, 0.1]),
                    rtol=0.20)
Example No. 28
def projective_inverse_warp(img,
                            transform,
                            mask_value,
                            intrinsics,
                            depth,
                            bilinear=True):
    """Inverse warp a source image to the target image plane based on projection.
    Args:
        img: the source image [batch, height_s, width_s, 3]
        transform: 4x4 transformation matrix
        mask_value: uint8 mask value for the RGB/A mask
        intrinsics: camera intrinsics [batch, 3, 3]
        depth: depth map of the target image [batch, height_t, width_t]
        bilinear: bool, use bilinear or nearest sampling.
    Returns:
        Source image inverse warped to the target image plane [batch, height_t,
        width_t, 3]
    """
    height, width, _ = img.shape
    # Construct pixel grid coordinates
    pixel_coords = meshgrid(height, width)
    # Convert pixel coordinates to the camera frame
    cam_coords = pixel2cam(depth, pixel_coords, intrinsics)
    # Construct a 4x4 intrinsic matrix
    filler = jnp.array([[0.0, 0.0, 0.0, 1.0]])
    # filler = jnp.tile(filler, [batch, 1, 1])
    intrinsics = jnp.concatenate([intrinsics, jnp.zeros([3, 1])], axis=1)
    intrinsics = jnp.concatenate([intrinsics, filler], axis=0)
    # Get a 4x4 transformation matrix from 'target' camera frame to 'source'
    # pixel frame.
    proj_tgt_cam_to_src_pixel = jnp.matmul(intrinsics, transform)
    src_pixel_coords = cam2pixel(cam_coords, proj_tgt_cam_to_src_pixel)

    output_img = jnp.where(bilinear,
                           bilinear_sampler(img, src_pixel_coords, mask_value),
                           nearest_sampler(img, src_pixel_coords, mask_value))
    return output_img.astype('uint8')
Example No. 29
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        return constraint.lower_bound - np.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound - poisson(key, 5, shape=size)
    elif isinstance(constraint, constraints._Interval):
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=upper_bound, maxval=upper_bound + 1.)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return lax.full(size, np.nan)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=np.ones((size[-1],)), size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=np.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1]) + 1
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1)) + 1e-2
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = 1e-2 + signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
        return np.matmul(cholesky, np.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return random.uniform(key, size)
    elif isinstance(constraint, constraints._PositiveDefinite):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._OrderedVector):
        x = np.cumsum(random.exponential(key, size), -1)
        return x[..., ::-1]
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Example No. 30
    def apply(self, x, features, bias=True, kernel_init=None):
        #print("NoisyNetwork")
        def sample_noise(shape):
            #tf.random_normal
            noise = jax.random.normal(random.PRNGKey(0), shape)
            ##noise = jax.random.normal(shape)
            return noise

        def f(x):
            return jnp.multiply(jnp.sign(x), jnp.power(jnp.abs(x), 0.5))

        # Initializer of \mu and \sigma

        def mu_init(key, shape):
            low = -1 * 1 / jnp.power(x.shape[1], 0.5)
            high = 1 * 1 / jnp.power(x.shape[1], 0.5)
            return onp.random.uniform(low, high, shape)

        def sigma_init(key, shape, dtype=jnp.float32):
            return jnp.ones(shape, dtype) * (0.1 / onp.sqrt(x.shape[1]))

        # Sample noise from gaussian
        p = sample_noise([x.shape[1], 1])
        q = sample_noise([1, features])
        f_p = f(p)
        f_q = f(q)
        w_epsilon = f_p * f_q
        b_epsilon = jnp.squeeze(f_q)
        w_mu = self.param('kernel', (x.shape[1], features), mu_init)
        w_sigma = self.param('kernell', (x.shape[1], features), sigma_init)
        w = w_mu + jnp.multiply(w_sigma, w_epsilon)
        ret = jnp.matmul(x, w)

        b_mu = self.param('bias', (features, ), mu_init)
        b_sigma = self.param('biass', (features, ), sigma_init)
        b = b_mu + jnp.multiply(b_sigma, b_epsilon)
        return jnp.where(bias, ret + b, ret)