Example #1
def test_dense_mass(kernel_cls, rho):
    num_warmup, num_samples = 20000, 10000

    true_cov = jnp.array([[10.0, rho], [rho, 0.1]])

    def model():
        numpyro.sample(
            "x",
            dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=true_cov))

    if kernel_cls is HMC or kernel_cls is NUTS:
        kernel = kernel_cls(model, trajectory_length=2.0, dense_mass=True)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model, dense_mass=True)

    mcmc = MCMC(kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0))

    mass_matrix_sqrt = mcmc.last_state.adapt_state.mass_matrix_sqrt
    if kernel_cls is HMC or kernel_cls is NUTS:
        mass_matrix_sqrt = mass_matrix_sqrt[("x", )]
    mass_matrix = jnp.matmul(mass_matrix_sqrt, jnp.transpose(mass_matrix_sqrt))
    estimated_cov = jnp.linalg.inv(mass_matrix)
    assert_allclose(estimated_cov, true_cov, rtol=0.10)

    samples = mcmc.get_samples()["x"]
    assert_allclose(jnp.mean(samples[:, 0]), jnp.array(0.0), atol=0.50)
    assert_allclose(jnp.mean(samples[:, 1]), jnp.array(0.0), atol=0.05)
    assert_allclose(jnp.mean(samples[:, 0] * samples[:, 1]),
                    jnp.array(rho),
                    atol=0.20)
    assert_allclose(jnp.var(samples, axis=0),
                    jnp.array([10.0, 0.1]),
                    rtol=0.20)
Example #2
File: project.py Project: choltz95/imax
def meshgrid(height, width, is_homogeneous=True):
    """Construct a 2D meshgrid.
    Args:
        height: height of the grid
        width: width of the grid
        is_homogeneous: whether to return in homogeneous coordinates
    Returns:
        x,y grid coordinates [2 (3 if homogeneous), height, width]
    """
    x_t = jnp.matmul(
        jnp.ones(shape=[height, 1]),
        jnp.transpose(jnp.expand_dims(jnp.linspace(-1.0, 1.0, width), 1),
                      [1, 0]))
    y_t = jnp.matmul(jnp.expand_dims(jnp.linspace(-1.0, 1.0, height), 1),
                     jnp.ones(shape=[1, width]))
    x_t = (x_t + 1.0) * 0.5 * jnp.array(width - 1, dtype='float32')
    y_t = (y_t + 1.0) * 0.5 * jnp.array(height - 1, dtype='float32')
    if is_homogeneous:
        ones = jnp.ones_like(x_t)
        coords = jnp.stack([x_t, y_t, ones], axis=0)
    else:
        coords = jnp.stack([x_t, y_t], axis=0)
    # coords = jnp.tile(jnp.expand_dims(coords, 0), [batch, 1, 1, 1])
    return coords
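A brief usage check, not from the source, assuming the `meshgrid` above is in scope and `jax.numpy` is imported as `jnp`; because the batch `tile` is commented out, the grid is stacked only along the leading axis.

```python
coords = meshgrid(4, 5)                           # (3, 4, 5): x, y and ones
coords_xy = meshgrid(4, 5, is_homogeneous=False)  # (2, 4, 5): x and y only
assert coords.shape == (3, 4, 5) and coords_xy.shape == (2, 4, 5)
```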
Example #3
def train_step(optimizer, state, batch, prev_metrics, learning_rate_fn):
    """Single training step."""
    images, labels = batch['image'], batch['label']
    if FLAGS.transpose_images:
        images = jnp.transpose(images, [3, 0, 1, 2])
    images = normalize_images(images)
    if images.shape[1:] != (224, 224, 3):
        raise ValueError('images has shape {}'.format(images.shape))

    def loss_fn(model):
        with nn.stateful(state) as new_state:
            logits = model(images)
        loss = cross_entropy_loss(logits, labels, FLAGS.label_smoothing)
        return loss / logits.shape[0], (new_state, logits)

    lr = learning_rate_fn(optimizer.state[0].step)
    new_optimizer, _, (new_state,
                       logits) = optimizer.optimize(loss_fn, learning_rate=lr)
    if FLAGS.train_metrics:
        metrics = compute_metrics(logits, labels)
        metrics = jax.tree_multimap(jnp.add, prev_metrics, metrics)
    else:
        metrics = {}
    return new_optimizer, new_state, metrics
Example #4
    def net_fn(inputs):
        """Function representing Rainbow Q-network."""
        inputs = jnp.transpose(inputs, [0, 3, 1, 2])
        inputs = dqn_torso_delta()(inputs.reshape(-1, 84, 84, 1))
        inputs = inputs.reshape(-1, 4, 1568)
        current = inputs[:, 1:, :]
        prev = jax.lax.stop_gradient(inputs[:, :-1, :])
        inputs = jax.numpy.concatenate([current, current - prev], axis=1)
        inputs = hk.Flatten()(inputs)
        inputs = linear(256)(inputs)
        inputs = layer_norm(inputs)

        # Advantage head.
        advantage = noisy_linear(512, noisy_weight_init,
                                 with_bias=True)(inputs)
        advantage = jax.nn.relu(advantage)
        advantage = noisy_linear(num_actions * num_atoms,
                                 noisy_weight_init,
                                 with_bias=False)(advantage)
        advantage = jnp.reshape(advantage, (-1, num_actions, num_atoms))

        # Value head.
        value = noisy_linear(512, noisy_weight_init, with_bias=True)(inputs)
        value = jax.nn.relu(value)
        value = noisy_linear(num_atoms, noisy_weight_init,
                             with_bias=False)(value)
        value = jnp.reshape(value, (-1, 1, num_atoms))

        # Q-distribution and values.
        q_logits = value + advantage - jnp.mean(
            advantage, axis=-2, keepdims=True)
        assert q_logits.shape[1:] == (num_actions, num_atoms)
        q_dist = jax.nn.softmax(q_logits)
        q_values = jnp.sum(q_dist * support, axis=2)
        q_values = jax.lax.stop_gradient(q_values)
        return C51NetworkOutputs(q_logits=q_logits, q_values=q_values)
Example #5
    def ntk_fun(x1, x2, params):
        """Computes the empirical ntk.

    Args:
      x1: A first `np.ndarray` of inputs, of shape [n1, ...], over which we
        would like to compute the NTK.
      x2: A second `np.ndarray` of inputs, of shape [n2, ...], over which we
        would like to compute the NTK.
      params: A PyTree of parameters about which we would like to compute the
        neural tangent kernel.
    Returns:
      A `np.ndarray` of shape [n1, n2] + output_shape + output_shape.
    """
        j1 = jac_fn(params, x1)

        if x2 is None:
            j2 = j1
        else:
            j2 = jac_fn(params, x2)

        ntk = sum_and_contract(j1, j2)
        # TODO(schsam): If we care, this will not work if the output is not of
        # shape [n, output_dim].
        return np.transpose(ntk, (0, 2, 1, 3))
Example #6
    def _get_2d_latent_grid(self, paths_x):

        num_points_grid = 30

        def scan_fn(carry, paths):
            x, drift, diffusion, index = carry
            time = index * self.config["delta_t"] - 0.5
            max_x = np.amax(paths_x[:, :, 0])
            min_x = np.amin(paths_x[:, :, 0])
            max_y = np.amax(paths_x[:, :, 1])
            min_y = np.amin(paths_x[:, :, 1])
            xx, yy = np.meshgrid(np.linspace(min_x, max_x, num_points_grid),
                                 np.linspace(min_y, max_y, num_points_grid))
            temp = np.transpose(np.vstack([xx.reshape(-1), yy.reshape(-1)]))

            gp_matrices, temp_drift_function, temp_diffusion_function = self.model.build(
                self.model.model_vars())
            temp_drift = temp_drift_function(temp, time)
            temp_diffusion = np.linalg.det(temp_diffusion_function(temp, time))

            x = ops.index_add(x, ops.index[index], temp)
            drift = ops.index_add(drift, ops.index[index], temp_drift)
            diffusion = ops.index_add(diffusion, ops.index[index],
                                      temp_diffusion)
            index += 1

            return (x, drift, diffusion, index), np.array([0.])

        x_grid = np.zeros((paths_x.shape[1], num_points_grid**2, 2))
        drift_grid = np.zeros((paths_x.shape[1], num_points_grid**2, 2))
        diffusion_grid = np.zeros((paths_x.shape[1], num_points_grid**2))
        (x_grid, drift_grid, diffusion_grid,
         index), _ = lax.scan(scan_fn, (x_grid, drift_grid, diffusion_grid, 0),
                              np.transpose(paths_x, (1, 0, 2)))

        return x_grid, drift_grid, diffusion_grid
Example #7
def _flatten_batch_dimensions(
        k: np.ndarray,
        is_parallel: bool,
        discard_axis: Optional[int] = None) -> np.ndarray:
    """Takes a kernel that has been evaluated in batches and flattens."""

    if discard_axis is not None:
        if not is_parallel:
            k = np.take(k, 0, axis=discard_axis)
            return np.reshape(k, (-1, ) + k.shape[2:])

        if discard_axis == 1:
            return np.reshape(k, (k.shape[0] * k.shape[1], ) + k.shape[2:])

        return k[0]

    else:
        if is_parallel:
            return np.reshape(k, (k.shape[0] * k.shape[1], ) + k.shape[2:])

        k = np.transpose(k, (0, 2, 1, 3) + tuple(range(4, k.ndim)))
        return np.reshape(k,
                          (k.shape[0] * k.shape[1], k.shape[2] * k.shape[3]) +
                          k.shape[4:])
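A small shape sketch, not from the source, assuming `_flatten_batch_dimensions` above is in scope; the block layout `[n1_blocks, n2_blocks, block_n1, block_n2]` is an assumption inferred from the transpose/reshape in the serial branch.

```python
import jax.numpy as np

k = np.zeros((3, 4, 5, 6))  # assumed [n1_blocks, n2_blocks, block_n1, block_n2]
out = _flatten_batch_dimensions(k, is_parallel=False)
assert out.shape == (15, 24)  # (3 * 5, 4 * 6)
```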
Example #8
    def test_periodic_general_wrapped_vs_unwrapped(self, spatial_dimension,
                                                   dtype):
        key = random.PRNGKey(0)

        eye = np.eye(spatial_dimension, dtype=dtype)

        tol = 1e-13
        if dtype is f32:
            tol = 2e-5

        for _ in range(STOCHASTIC_SAMPLES):
            key, split_R, split_T = random.split(key, 3)

            dT = random.normal(split_T, (spatial_dimension, spatial_dimension),
                               dtype=dtype)
            T = eye + dT + np.transpose(dT)

            R = random.uniform(split_R, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            R0 = R
            unwrapped_R = R

            displacement, shift = space.periodic_general(T)
            _, unwrapped_shift = space.periodic_general(T, wrapped=False)

            displacement = space.map_product(displacement)

            for _ in range(SHIFT_STEPS):
                key, split = random.split(key)
                dR = random.normal(split, (PARTICLE_COUNT, spatial_dimension),
                                   dtype=dtype)
                R = shift(R, dR)
                unwrapped_R = unwrapped_shift(unwrapped_R, dR)
                self.assertAllClose(displacement(R, R0),
                                    displacement(unwrapped_R, R0))
            assert not (np.all(unwrapped_R > 0) and np.all(unwrapped_R < 1))
Example #9
    def __call__(self, inputs, Q_inputs=None):
        dimensionality = self.dimensionality
        # are queries generated by a different set of inputs? (for output stack)
        if Q_inputs is None:
            Q_inputs = inputs

        Q = hk.Linear(dimensionality)(Q_inputs)
        K = hk.Linear(dimensionality)(inputs)
        V = hk.Linear(dimensionality)(inputs)

        attention_scale = jnp.sqrt(dimensionality)

        attention_weights = jnp.matmul(Q, jnp.transpose(
            K, axes=[0, 2, 1])) / attention_scale
        if self.causal_mask:  # can't attend to later times
            mask = np.tril(np.ones(attention_weights.shape[-2:]))
            mask = np.expand_dims(mask, axis=0)
            attention_weights = attention_weights * mask - 1e10 * (1 - mask)

        attention_weights = jax.nn.softmax(attention_weights, axis=-1)

        outputs = jnp.matmul(attention_weights, V)

        return outputs
Example #10
 def setup(self):
     """Initialize P, L, U, s"""
     # W = PL(U + s)
     # Based on https://github.com/openai/glow/blob/master/model.py#L485
     dim = self.input_dim
     # Sample random rotation matrix
     q, _ = np.linalg.qr(jax.random.normal(self.rng, (dim, dim)),
                         mode="complete")
     p, l, u = jax.scipy.linalg.lu(q)
     # Fixed Permutation (non-trainable)
     self.P = p
     self.P_inv = jax.scipy.linalg.inv(p)
     # Init value from LU decomposition
     L_init = l
     U_init = np.triu(u, k=1)
     s = np.diag(u)
     self.sign_s = np.sign(s)
     S_log_init = np.log(np.abs(s))
     self.l_mask = np.tril(np.ones((dim, dim)), k=-1)
     self.u_mask = np.transpose(self.l_mask)
     # Define trainable variables
     self.L = self.param("L", lambda k, sh: L_init, (dim, dim))
     self.U = self.param("U", lambda k, sh: U_init, (dim, dim))
     self.log_s = self.param("log_s", lambda k, sh: S_log_init, (dim, ))
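A minimal sketch, not from the source, of how the invertible weight might be rebuilt from these parameters, following the Glow-style `W = P L (U + diag(s))` factorization referenced in the comments above; the helper name and the use of the masks are assumptions.

```python
import jax.numpy as jnp

def reconstruct_weight(P, L, U, log_s, sign_s, l_mask, u_mask, dim):
    # Unit lower-triangular factor and upper-triangular factor with learned diagonal.
    L_full = L * l_mask + jnp.eye(dim)
    U_full = U * u_mask + jnp.diag(sign_s * jnp.exp(log_s))
    return P @ L_full @ U_full
```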
Example #11
    def ntk_fn(x1, x2, params, keys=None, **apply_fn_kwargs):
        """Computes the empirical ntk.

    Args:
      x1: A first `np.ndarray` of inputs, of shape [n1, ...], over which we
        would like to compute the NTK.
      x2: A second `np.ndarray` of inputs, of shape [n2, ...], over which we
        would like to compute the NTK.
      params: A PyTree of parameters about which we would like to compute the
        neural tangent kernel.
      keys: None or a PRNG key or a tuple of PRNG keys or a (2, 2) array and
        dtype uint32. If `key == None`, then the function `f` is deterministic
        and requires no PRNG key; else if `keys` is a single PRNG key, then x1
        and x2 share the same PRNG key; else x1 and x2 use two different PRNG
        keys.
      **apply_fn_kwargs: keyword arguments passed to `apply_fn`.

    Returns:
      A `np.ndarray` of shape [n1, n2] + output_shape + output_shape.
    """
        key1, key2 = _read_keys(keys)

        f1 = _get_f_params(f, x1, key1, **apply_fn_kwargs)
        jac_fn1 = jacobian(f1)
        j1 = jac_fn1(params)
        if x2 is None:
            j2 = j1
        else:
            f2 = _get_f_params(f, x2, key2, **apply_fn_kwargs)
            jac_fn2 = jacobian(f2)
            j2 = jac_fn2(params)

        ntk = sum_and_contract(j1, j2)
        # TODO: If we care, this will not work if the output is not of
        # shape [n, output_dim].
        return np.transpose(ntk, (0, 2, 1, 3))
Example #12
    def sample(self,
               batchSize,
               key,
               L,
               hiddenSize=10,
               depth=1,
               inputDim=2,
               actFun=nn.elu,
               initScale=1.0,
               logProbFactor=0.5):
        """sampler
        """

        rnnCell = RNNCellStack.shared(hiddenSize=hiddenSize,
                                      outDim=inputDim,
                                      actFun=actFun,
                                      initScale=initScale,
                                      name="myCell")

        outputs = jnp.asarray(np.zeros((batchSize, L, L)))

        state = jnp.zeros((batchSize, depth, hiddenSize))

        def rnn_cell(carry, x):
            newCarry, logits = jax.vmap(rnnCell)(carry[0], carry[1])
            sampleOut = jax.random.categorical(x, logits)
            sample = jax.nn.one_hot(sampleOut, inputDim)
            logProb = jnp.sum(nn.log_softmax(logits) * sample, axis=1)
            return (newCarry, sample), (jnp.nan_to_num(logProb,
                                                       nan=-35), sampleOut)

        keys = jax.random.split(key, L)
        _, res = jax.lax.scan(rnn_cell, (state, jnp.zeros(
            (batchSize, inputDim))), keys)

        return jnp.transpose(res[1])  #, 0.5 * jnp.sum(res[0], axis=0)
Example #13
    def _apply_lse_kernel_one_dimension(self, dimension, f, g, eps, vec=None):
        """Helper function to permute axis & apply the kernel on a single slice."""
        indices = np.arange(self.grid_dimension)
        indices[dimension], indices[0] = 0, dimension
        f, g = jnp.transpose(f, indices), jnp.transpose(g, indices)
        centered_cost = (
            f[:, jnp.newaxis, ...] + g[jnp.newaxis, ...] - jnp.expand_dims(
                self.cost_matrices[dimension],
                axis=tuple(range(2, 1 + self.grid_dimension)))) / eps

        if vec is not None:
            vec = jnp.transpose(vec, indices)
            softmax_res, softmax_sgn = jax.scipy.special.logsumexp(
                centered_cost, b=vec, axis=1, return_sign=True)
            return eps * jnp.transpose(softmax_res, indices), jnp.transpose(
                softmax_sgn, indices)
        else:
            softmax_res = jax.scipy.special.logsumexp(centered_cost, axis=1)
            return eps * jnp.transpose(softmax_res, indices), None
Example #14
        # x = params[0][0].reshape(5, 5, 10)
        # # print(np.transpose(x))
        # print(np.transpose(x).swapaxes(1, 2))
        # break

        train_loss = loss(params, (train_images, train_labels))
        test_loss = loss(params, (test_images, test_labels))
        print("Training set loss {}".format(train_loss))
        print("Test set loss {}".format(test_loss))

        train_loss_arr.append(train_loss)
        test_loss_arr.append(test_loss)
        test_class_loss_arr.append(1-test_acc)
        print()
        
        kernels = np.transpose(params[0][0].reshape(5, 5, 10)).swapaxes(1, 2)

        # params[0][0].shape = (5, 5, 1, 10)    -> Input edge weights
        # params[4][1].shape = (10, )           -> Output layer
        eigv0 = onp.asarray(get_eigenvalues(kernels[0]))
        eigv1 = onp.asarray(get_eigenvalues(kernels[1]))
        eigv2 = onp.asarray(get_eigenvalues(kernels[2]))
        eigv3 = onp.asarray(get_eigenvalues(kernels[3]))
        eigv4 = onp.asarray(get_eigenvalues(kernels[4]))
        eigv5 = onp.asarray(get_eigenvalues(kernels[5]))
        eigv6 = onp.asarray(get_eigenvalues(kernels[6]))
        eigv7 = onp.asarray(get_eigenvalues(kernels[7]))
        eigv8 = onp.asarray(get_eigenvalues(kernels[8]))
        eigv9 = onp.asarray(get_eigenvalues(kernels[9]))

        eigenvalues0.append(np.sort(eigv0))
Example #15
 def fun(x):
   return np.transpose(x, (2, 0, 1))
Example #16
 def conv_var(x):
     x = _conv_var_3d(x, filter_shape_nngp, strides_nngp, padding)
     if x is not None:
         x = np.transpose(x, (0, 2, 1))
     x = _affine(x, W_std, b_std)
     return x
Example #17
 def f(x):
   # x: [b1, b2, d1, d2] (a batch of matrices)
   # res: [b1, b2, d1, d1]
   return jnp.matmul(x, jnp.transpose(x, axes=(0, 1, 3, 2)))
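A self-contained check, not from the source, that the transpose-and-matmul pattern above computes the batched Gram matrix x @ x^T (equivalently an einsum contraction over the last axis).

```python
import jax.numpy as jnp

def f(x):
    # x: [b1, b2, d1, d2] -> [b1, b2, d1, d1]
    return jnp.matmul(x, jnp.transpose(x, axes=(0, 1, 3, 2)))

x = jnp.arange(2 * 3 * 4 * 5, dtype=jnp.float32).reshape(2, 3, 4, 5)
assert f(x).shape == (2, 3, 4, 4)
assert jnp.allclose(f(x), jnp.einsum('abij,abkj->abik', x, x))
```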
Example #18
 def contract(x, y):
     param_count = int(np.prod(x.shape[2:]))
     x = np.reshape(x, x.shape[:2] + (param_count, ))
     y = np.reshape(y, y.shape[:2] + (param_count, ))
     return np.dot(x, np.transpose(y, (0, 2, 1)))
Example #19
def _inputs_to_kernel(x1, x2, use_pooling, compute_ntk):
    """Transforms (batches of) inputs to a `Kernel`.

  This is a private method. Docstring and example are for internal reference.

   The kernel contains the empirical covariances between different inputs and
     their entries (pixels) necessary to compute the covariance of the Gaussian
     Process corresponding to an infinite Bayesian or gradient-flow-trained
     neural network.

   The smallest necessary number of covariance entries is tracked. For example,
     all networks are assumed to have i.i.d. weights along the channel / feature
     / logits dimensions, hence covariance between different entries along these
     dimensions is known to be 0 and is not tracked.

  Args:
    x1: a 2D `np.ndarray` of shape `[batch_size_1, n_features]` (dense
      network) or 4D of shape `[batch_size_1, height, width, channels]`
      (conv-nets).
    x2: an optional `np.ndarray` with the same shape as `x1` apart
      from possibly different leading batch size. `None` means
      `x2 == x1`.
    use_pooling: a boolean, indicating whether pooling will be used somewhere in
      the model. If so, more covariance entries need to be tracked. Is set
      automatically based on the network topology. Specifically, is set to
      `False` if a `serial` or `parallel` networks contain a `Flatten` layer
      and no pooling layers (`AvgPool` or `GlobalAvgPool`). Has no effect for
      non-convolutional models.
    compute_ntk: a boolean, `True` to compute both NTK and NNGP kernels,
        `False` to only compute NNGP.

    Example:
      ```python
          >>> x = np.ones((10, 32, 16, 3))
          >>> _inputs_to_kernel(x, None, use_pooling=True,
          >>>                   compute_ntk=True).ntk.shape
          (10, 10, 32, 32, 16, 16)
          >>> _inputs_to_kernel(x, None, use_pooling=False,
          >>>                   compute_ntk=True).ntk.shape
          (10, 10, 32, 16)
          >>> x1 = np.ones((10, 128))
          >>> x2 = np.ones((20, 128))
          >>> _inputs_to_kernel(x1, x2, use_pooling=True,
          >>>                   compute_ntk=False).nngp.shape
          (10, 20)
          >>> _inputs_to_kernel(x1, x2, use_pooling=False,
          >>>                   compute_ntk=False).nngp.shape
          (10, 20)
          >>> _inputs_to_kernel(x1, x2, use_pooling=False,
          >>>                   compute_ntk=False).ntk
          None
      ```

  Returns:
    a `Kernel` object.
  """
    x1 = x1.astype(xla_bridge.canonicalize_dtype(np.float64))
    var1 = _get_variance(x1)

    if x2 is None:
        x2 = x1
        var2 = None
    else:
        if x1.shape[1:] != x2.shape[1:]:
            raise ValueError(
                '`x1` and `x2` are expected to be batches of'
                ' inputs with the same shape (apart from the batch size),'
                ' got %s and %s.' % (str(x1.shape), str(x2.shape)))

        x2 = x2.astype(xla_bridge.canonicalize_dtype(np.float64))
        var2 = _get_variance(x2)

    if use_pooling and x1.ndim == 4:
        x2 = np.expand_dims(x2, -1)
        nngp = np.dot(x1, x2) / x1.shape[-1]
        nngp = np.transpose(np.squeeze(nngp, -1), (0, 3, 1, 4, 2, 5))

    elif x1.ndim == 4 or x1.ndim == 2:
        nngp = _batch_uncentered_covariance(x1, x2)

    else:
        raise ValueError('Inputs must be 2D or 4D `np.ndarray`s of shape '
                         '`[batch_size, n_features]` or '
                         '`[batch_size, height, width, channels]`, '
                         'got %s.' % str(x1.shape))

    ntk = 0. if compute_ntk else None
    is_gaussian = False
    is_height_width = True
    return Kernel(var1, nngp, var2, ntk, is_gaussian, is_height_width)
Example #20
    r = p + np.array([
        fbm(p + 4 * q + np.array([1.7, 9.2]), seed),
        fbm(p + 4 * q + np.array([8.3, 2.8]), seed)
    ])
    s = np.array([
        fbm(p + 4 * q + r, seed),
        fbm(p + 4 * q + r + np.array([1.3, 5.2]), seed)
    ])
    return s


# additional params

shift = np.array([3, 2])
img = np.array(cv2.imread('images/lena.jpg'))
img = np.transpose(img, (1, 0, 2))
N, M = img.shape[0], img.shape[1]
grid = np.stack(np.meshgrid(np.arange(N), np.arange(M)), -1)


@jit
def render_one(seed):
    xys = np.stack(np.meshgrid(np.linspace(0, 1, N), np.linspace(0, 1, M)), -1)
    flat = np.reshape(xys, (-1, 2))
    seed = np.ones_like(flat) * seed
    perturb = vmap(compose2dout)(flat + shift, seed)
    perturb = np.reshape(perturb, (N, M, 2))
    xys = xys + 0.25 * perturb
    xys = np.clip(xys, 0., 1. - 1e-9)
    grid = (xys * N).astype(int)
    pimg = img[grid[..., 0], grid[..., 1]]
Example #21
def oei_arrays(geom, basis, charges):
    """
    Build one electron integral arrays (overlap, kinetic, and potential integrals)
    """
    coeffs, exps, atoms, ams, indices, dims = flatten_basis_data(basis)
    nbf = get_nbf(basis)
    nprim = coeffs.shape[0]
    max_am = jnp.max(ams)
    A_vals = jnp.zeros(2 * max_am + 1)

    # Save various AM distributions for indexing
    # Obtain all possible primitive quartet index combinations
    primitive_duets = cartesian_product(jnp.arange(nprim), jnp.arange(nprim))

    with loops.Scope() as s:
        s.oei = jnp.zeros((3, nbf, nbf))
        s.a = 0  # center A angular momentum iterator
        s.b = 0  # center B angular momentum iterator

        for prim_duet in s.range(primitive_duets.shape[0]):
            p1, p2 = primitive_duets[prim_duet]
            coef = coeffs[p1] * coeffs[p2]
            aa, bb = exps[p1], exps[p2]
            atom1, atom2 = atoms[p1], atoms[p2]
            am1, am2 = ams[p1], ams[p2]
            A, B = geom[atom1], geom[atom2]
            ld1, ld2 = am_leading_indices[am1], am_leading_indices[am2]

            gamma = aa + bb
            prefactor = jnp.exp(-aa * bb * jnp.dot(A - B, A - B) / gamma)
            P = (aa * A + bb * B) / gamma
            # Maximum angular momentum: hard coded
            #max_am = 3 # f function support
            # Precompute all powers up to 2+max_am of Pi-Ai, Pi-Bi.
            # We need 2+max_am since kinetic requires incrementing angular momentum by +2
            PA_pow = jnp.power(
                jnp.broadcast_to(P - A, (max_am + 3, 3)).T,
                jnp.arange(max_am + 3))
            PB_pow = jnp.power(
                jnp.broadcast_to(P - B, (max_am + 3, 3)).T,
                jnp.arange(max_am + 3))

            # For potential integrals, we need the difference between
            # the gaussian product center P and ALL atoms in the molecule,
            # and then take all possible powers up to 2*max_am.
            # We pre-collect this into a 3d array, and then just pull out what we need via indexing in the loops, so they need not be recomputed.
            # The resulting array has dimensions (atom, cartesian component, power) so index (0, 1, 3) would return (Py - atom0_y)^3
            P_minus_geom = jnp.broadcast_to(P, geom.shape) - geom
            Pgeom_pow = jnp.power(
                jnp.transpose(
                    jnp.broadcast_to(
                        P_minus_geom,
                        (2 * max_am + 1, geom.shape[0], geom.shape[1])),
                    (1, 2, 0)), jnp.arange(2 * max_am + 1))
            # All possible jnp.dot(P-atom,P-atom)
            rcp2 = jnp.einsum('ij,ij->i', P_minus_geom, P_minus_geom)
            # All needed (and unneeded, for am < max_am) boys function evaluations
            boys_arg = jnp.broadcast_to(rcp2 * gamma,
                                        (2 * max_am + 1, geom.shape[0]))
            boys_nu = jnp.tile(jnp.arange(2 * max_am + 1),
                               (geom.shape[0], 1)).T
            boys_eval = boys(boys_nu, boys_arg)

            s.a = 0
            for _ in s.while_range(lambda: s.a < dims[p1]):
                s.b = 0
                for _ in s.while_range(lambda: s.b < dims[p2]):
                    # Gather angular momentum and index
                    la, ma, na = angular_momentum_combinations[s.a + ld1]
                    lb, mb, nb = angular_momentum_combinations[s.b + ld2]
                    # To only create unique indices, need to have separate indices arrays for i and j.
                    i = indices[p1] + s.a
                    j = indices[p2] + s.b
                    # Compute one electron integrals and add to appropriate index
                    overlap_int = overlap(la, ma, na, lb, mb, nb, aa, bb,
                                          PA_pow, PB_pow, prefactor) * coef
                    kinetic_int = kinetic(la, ma, na, lb, mb, nb, aa, bb,
                                          PA_pow, PB_pow, prefactor) * coef
                    potential_int = potential(la, ma, na, lb, mb, nb, aa, bb,
                                              PA_pow, PB_pow, Pgeom_pow,
                                              boys_eval, prefactor, charges,
                                              A_vals) * coef
                    s.oei = jax.ops.index_add(
                        s.oei, ([0, 1, 2], [i, i, i], [j, j, j]),
                        (overlap_int, kinetic_int, potential_int))

                    s.b += 1
                s.a += 1
    S, T, V = s.oei[0], s.oei[1], s.oei[2]
    return S, T, V
Example #22
    def apply(
        self,
        x,
        num_classes=1000,
        train=False,
        resnet=None,
        patches=None,
        hidden_size=None,
        transformer=None,
        representation_size=None,
        classifier="gap",
    ):

        n, h, w, c = x.shape

        # Embed the grid or patches of the grid.
        fh, fw = patches.size
        gh, gw = h // fh, w // fw
        if hidden_size:  # We can merge s2d+emb into a single conv; it's the same.
            x = nn.Conv(
                x,
                hidden_size,
                (fh, fw),
                strides=(fh, fw),
                padding="VALID",
                name="embedding",
            )
        else:
            # This path often results in excessive padding.
            x = jnp.reshape(x, [n, gh, fh, gw, fw, c])
            x = jnp.transpose(x, [0, 1, 3, 2, 4, 5])
            x = jnp.reshape(x, [n, gh, gw, -1])

        # Here, x is a grid of embeddings.

        # (Possibly partial) Transformer.
        if transformer is not None:
            n, h, w, c = x.shape
            x = jnp.reshape(x, [n, h * w, c])

            # If we want to add a class token, add it here.
            if classifier == "token":
                cls = self.param("cls", (1, 1, c), nn.initializers.zeros)
                cls = jnp.tile(cls, [n, 1, 1])
                x = jnp.concatenate([cls, x], axis=1)

            x = Encoder(x, train=train, name="Transformer", **transformer)

        if classifier == "token":
            x = x[:, 0]
        elif classifier == "gap":
            x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)

        if representation_size is not None:
            x = nn.Dense(x, representation_size, name="pre_logits")
            x = nn.tanh(x)
        else:
            x = IdentityLayer(x, name="pre_logits")

        x = nn.Dense(x, num_classes, name="head", kernel_init=nn.initializers.zeros)
        return x
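A shape sketch, not from the source, of the non-convolutional patch-embedding path above: the reshape/transpose/reshape sequence cuts the image into non-overlapping fh x fw patches and flattens each one.

```python
import jax.numpy as jnp

n, h, w, c, fh, fw = 2, 8, 8, 3, 4, 4
gh, gw = h // fh, w // fw
x = jnp.zeros((n, h, w, c))
x = jnp.reshape(x, [n, gh, fh, gw, fw, c])
x = jnp.transpose(x, [0, 1, 3, 2, 4, 5])
x = jnp.reshape(x, [n, gh, gw, -1])
assert x.shape == (2, 2, 2, fh * fw * c)  # each 4x4x3 patch flattened to 48 features
```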
Example #23
    def __init__(self, S, X, ctgInd, amaOpts, consOpts):
        # NOTE S should already be reshaped [Nlvl,Ni,nPix]

        self.bJax = amaOpts.bJax
        if self.bJax:
            # NOTE OVERWRITING np
            self.rng = jxrandom.PRNGKey(123)  #
            S = np.asarray(S)

        else:
            self.rng = 123

        self.consOpts = consOpts

        #cdef const int    self.nPix, self.xPix, self.yPix self.nStim, self.Nf, self.nX, self.Ni, self.nLvl
        self.S = S
        self.Nf = int(amaOpts.Nf)  # also known as Nq
        self.Nlvl = int(np.max(ctgInd) + 1)
        self.nStim = int(ctgInd.shape[0])
        self.Ni = int(self.nStim / self.Nlvl)  # num stim per lvl

        #
        self.nPix = int(S.shape[2])
        self.xPix = int(np.sqrt(self.nPix))  #TODO
        self.yPix = int(np.sqrt(self.nPix))  #TODO

        #cdef const np.ndarray[cnp.double_t, ndim=3] self.S
        # S  [Nlvl, Ni, nPix]

        #cdef const np.ndarray[cnp.double_t, ndim=2] self.X # TODO dim?
        #cdef const np.ndarray[cnp.int_t,    ndim=2] self.labels
        #cdef const np.ndarray[cnp.int_t,    ndim=1] self.ctgInd

        self.X = np.transpose(X)
        self.ctgInd = np.array([ctgInd])
        self.labels = np.reshape(X[ctgInd], (self.Nlvl, self.Ni))

        self.get_Ac()
        self.get_filter_ind()

        #cdef const double self.alpha, self.s0, self.Nf, rmax
        #cdef const int errType, normType, bRectify
        self.alpha = amaOpts.alpha
        self.s0 = amaOpts.s0
        self.rmax = amaOpts.rmax
        self.errType = amaOpts.errType
        self.normType = amaOpts.normType
        self.bRectify = amaOpts.bRectify
        self.bMean = amaOpts.bMean
        self.bPrint = amaOpts.bPrint
        self.bNormF = consOpts.bNormF
        self.bPrintCons = consOpts.bPrint

        self.eps = np.sqrt(np.finfo(float).eps)

        ## in/out variable
        #cdef np.ndarray[cnp.double_t, ndim=2] self.f
        #cdef double self.mCost

        ## variable init
        #cdef np.ndarray[cnp.double_t, ndim=3] self.R
        #cdef np.ndarray[cnp.double_t, ndim=3] self.r
        #cdef np.ndarray[cnp.double_t, ndim=3] self.var
        #cdef np.ndarray[cnp.double_t, ndim=3] self.PPstm
        #cdef np.ndarray[cnp.double_t, ndim=3] self.Y

        #np.empty((self.Nlvl,self,Ni,self.Nf),  dtype=double) self.R
        #np.empty((self.Nlvl,self,Ni,self.Nf),  dtype=double) self.r
        #np.empty((self.Nlvl,self,Ni,self.Nf),  dtype=double) self.var
        #np.empty((self.Nlvl,self,Ni,self.Nlvl),  dtype=double) self.PPstm
        #np.empty((self.Nlvl,self,Ni,self.Nlvl),  dtype=double) self.Y

        self.x = np.array([0, 0])

        self.R = np.empty([self.Nlvl, self.Ni, self.Nf])
        self.r = np.empty([self.Nlvl, self.Ni,
                           self.Nf])  # mean rsp foreach stm
        self.var = np.empty([self.Nlvl, self.Ni,
                             self.Nf])  # var  rsp foreach stm
        self.PPstm = np.empty([self.Nlvl, self.Ni,
                               self.Nlvl])  # Posterior foreach stm @ each lvl
        self.Y = np.empty([self.Nlvl, self.Ni,
                           self.Nlvl])  # Posterior foreach stm @ each lvl

        #if self.bJax:
        if self.bJax:
            self.jac = value_and_grad(self.objective_fun_core)
Example #24
 def swapaxes(x):
   transposed_axes = list(range(len(x.shape)))
   transposed_axes[axis] = 0
   transposed_axes[0] = axis
   return jnp.transpose(x, axes=transposed_axes)
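A usage sketch, not from the source; `axis` is assumed to be an integer captured from the enclosing scope, and the helper simply exchanges that axis with axis 0.

```python
import jax.numpy as jnp

axis = 2

def swapaxes(x):
    transposed_axes = list(range(len(x.shape)))
    transposed_axes[axis] = 0
    transposed_axes[0] = axis
    return jnp.transpose(x, axes=transposed_axes)

x = jnp.zeros((3, 4, 5))
assert swapaxes(x).shape == (5, 4, 3)  # same result as jnp.swapaxes(x, 0, axis)
```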
Example #25
    def __call__(
        self,
        pixel_values: jnp.ndarray,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, GPT2Tokenizer
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

        >>> # load output tokenizer
        >>> tokenizer_output = GPT2Tokenizer.from_pretrained("gpt2")

        >>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized
        >>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "google/vit-base-patch16-224-in21k", "gpt2"
        ... )

        >>> pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values

        >>> # use GPT2's eos_token as the pad as well as eos token
        >>> model.config.eos_token_id = model.config.decoder.eos_token_id
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> # generation
        >>> sequences = model.generate(pixel_values, num_beams=4, max_length=12).sequences

        >>> captions = tokenizer_output.batch_decode(sequences, skip_special_tokens=True)
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # prepare encoder inputs

        # `FlaxViTModel` expects channel first format, but `FlaxViTModule` expects channel last format.
        # Currently, we assume this holds for all Flax vision models, and perform a transpose here.
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # prepare decoder inputs
        if decoder_input_ids is None:
            raise ValueError("`decoder_input_ids` can't be `None`.")
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        if decoder_position_ids is None:
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :],
                (batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            pixel_values=jnp.array(pixel_values, dtype=self.dtype),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask,
                                             dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )
Example #26
def make_canonical_transform(
        n_xyz: jnp.ndarray, ca_xyz: jnp.ndarray,
        c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Returns translation and rotation matrices to canonicalize residue atoms.

  Note that this method does not take care of symmetries. If you provide the
  atom positions in the non-standard way, the N atom will end up not at
  [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
  need to take care of such cases in your code.

  Args:
    n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
    ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
    c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.

  Returns:
    A tuple (translation, rotation) where:
      translation is an array of shape [batch, 3] defining the translation.
      rotation is an array of shape [batch, 3, 3] defining the rotation.
    After applying the translation and rotation to all atoms in a residue:
      * All atoms will be shifted so that CA is at the origin,
      * All atoms will be rotated so that C is at the x-axis,
      * All atoms will be shifted so that N is in the xy plane.
  """
    assert len(n_xyz.shape) == 2, n_xyz.shape
    assert n_xyz.shape[-1] == 3, n_xyz.shape
    assert n_xyz.shape == ca_xyz.shape == c_xyz.shape, (n_xyz.shape,
                                                        ca_xyz.shape,
                                                        c_xyz.shape)

    # Place CA at the origin.
    translation = -ca_xyz
    n_xyz = n_xyz + translation
    c_xyz = c_xyz + translation

    # Place C on the x-axis.
    c_x, c_y, c_z = [c_xyz[:, i] for i in range(3)]
    # Rotate by angle c1 in the x-y plane (around the z-axis).
    sin_c1 = -c_y / jnp.sqrt(1e-20 + c_x**2 + c_y**2)
    cos_c1 = c_x / jnp.sqrt(1e-20 + c_x**2 + c_y**2)
    zeros = jnp.zeros_like(sin_c1)
    ones = jnp.ones_like(sin_c1)
    # pylint: disable=bad-whitespace
    c1_rot_matrix = jnp.stack([
        jnp.array([cos_c1, -sin_c1, zeros]),
        jnp.array([sin_c1, cos_c1, zeros]),
        jnp.array([zeros, zeros, ones])
    ])

    # Rotate by angle c2 in the x-z plane (around the y-axis).
    sin_c2 = c_z / jnp.sqrt(1e-20 + c_x**2 + c_y**2 + c_z**2)
    cos_c2 = jnp.sqrt(c_x**2 + c_y**2) / jnp.sqrt(1e-20 + c_x**2 + c_y**2 +
                                                  c_z**2)
    c2_rot_matrix = jnp.stack([
        jnp.array([cos_c2, zeros, sin_c2]),
        jnp.array([zeros, ones, zeros]),
        jnp.array([-sin_c2, zeros, cos_c2])
    ])

    c_rot_matrix = _multiply(c2_rot_matrix, c1_rot_matrix)
    n_xyz = jnp.stack(apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True)).T

    # Place N in the x-y plane.
    _, n_y, n_z = [n_xyz[:, i] for i in range(3)]
    # Rotate by angle alpha in the y-z plane (around the x-axis).
    sin_n = -n_z / jnp.sqrt(1e-20 + n_y**2 + n_z**2)
    cos_n = n_y / jnp.sqrt(1e-20 + n_y**2 + n_z**2)
    n_rot_matrix = jnp.stack([
        jnp.array([ones, zeros, zeros]),
        jnp.array([zeros, cos_n, -sin_n]),
        jnp.array([zeros, sin_n, cos_n])
    ])
    # pylint: enable=bad-whitespace

    return (translation,
            jnp.transpose(_multiply(n_rot_matrix, c_rot_matrix), [2, 0, 1]))
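A usage sketch, not from the source, of how the returned (translation, rotation) pair could be applied to a batch of atom positions as described in the docstring; the helper name is illustrative only.

```python
import jax.numpy as jnp

def apply_canonical_transform(atom_xyz, translation, rotation):
    # atom_xyz, translation: [batch, 3]; rotation: [batch, 3, 3]
    # Shift so CA sits at the origin, then rotate into the canonical frame.
    return jnp.einsum('bij,bj->bi', rotation, atom_xyz + translation)
```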
Example #27
    def encode(
        self,
        pixel_values: jnp.ndarray,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import ViTFeatureExtractor, FlaxVisionEncoderDecoderModel
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

        >>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized
        >>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "google/vit-base-patch16-224-in21k", "gpt2"
        ... )

        >>> pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
        >>> encoder_outputs = model.encode(pixel_values)
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # `FlaxViTModel` expects channel first format, but `FlaxViTModule` expects channel last format.
        # Currently, we assume this holds for all Flax vision models, and perform a transpose here.
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, pixel_values, **kwargs):
            encode_module = module._get_encoder_module()
            return encode_module(pixel_values, **kwargs)

        outputs = self.module.apply(
            {"params": params or self.params},
            pixel_values=jnp.array(pixel_values, dtype=self.dtype),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

        if return_dict:
            outputs = FlaxBaseModelOutput(
                last_hidden_state=outputs.last_hidden_state,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        return outputs
Example #28
 def trans(x):
     perm = tuple(range(x.ndim))
     perm = (ax, ) + tuple(np.delete(perm, ax))
     return jnp.transpose(x, perm)
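A small check, not from the source; `ax` is assumed to be an integer captured from the surrounding scope, and the permutation moves that axis to the front while keeping the relative order of the remaining axes.

```python
import numpy as np
import jax.numpy as jnp

ax = 1

def trans(x):
    perm = tuple(range(x.ndim))
    perm = (ax,) + tuple(np.delete(perm, ax))
    return jnp.transpose(x, perm)

x = jnp.zeros((2, 3, 4))
assert trans(x).shape == (3, 2, 4)
```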
Example #29
def point_to_coordinate(pt, num_fragments=6):
    """ Takes points from dihedral_to_point and sequentially converts them into the coordinates of a 3D structure.

        Reconstruction is done in parallel, by independently reconstructing num_fragments fragments and then 
        reconstituting the chain at the end in reverse order. The core reconstruction algorithm is NeRF, based on 
        DOI: 10.1002/jcc.20237 by Parsons et al. 2005. The parallelized version is described in 
        https://www.biorxiv.org/content/early/2018/08/06/385450.

    Args:
        pt: [NUM_STEPS x NUM_DIHEDRALS, BATCH_SIZE, NUM_DIMENSIONS]

    Opts:
        num_fragments: Number of fragments to reconstruct in parallel. If None, the number is chosen adaptively

    Returns:
            [NUM_STEPS x NUM_DIHEDRALS, BATCH_SIZE, NUM_DIMENSIONS] 
    """

    # compute optimal number of fragments if needed
    s = pt.shape[0]  # NUM_STEPS x NUM_DIHEDRALS
    if num_fragments is None:
        # np.cast is not callable like tf.cast; plain Python casts suffice here.
        num_fragments = int(np.sqrt(float(s)))

    # initial three coordinates (specifically chosen to eliminate need for extraneous matmul)
    Triplet = collections.namedtuple('Triplet', 'a, b, c')
    batch_size = pt.shape[1]  # BATCH_SIZE
    init_mat = np.array(
        [[-np.sqrt(1.0 / 2.0), np.sqrt(3.0 / 2.0), 0], [-np.sqrt(2.0), 0, 0],
         [0, 0, 0]],
        dtype='float32')
    init_coords = Triplet(*[
        np.reshape(
            np.tile(row[np.newaxis], np.stack([num_fragments *
                                               batch_size, 1])),
            [num_fragments, batch_size, NUM_DIMENSIONS]) for row in init_mat
    ])
    # NUM_DIHEDRALS x [NUM_FRAGS, BATCH_SIZE, NUM_DIMENSIONS]

    # pad points to yield equal-sized fragments
    r = ((num_fragments - (s % num_fragments)) % num_fragments
         )  # (NUM_FRAGS x FRAG_SIZE) - (NUM_STEPS x NUM_DIHEDRALS)
    pt = np.pad(pt, [[0, r], [0, 0], [0, 0]
                     ])  # [NUM_FRAGS x FRAG_SIZE, BATCH_SIZE, NUM_DIMENSIONS]
    pt = np.reshape(pt,
                    [num_fragments, -1, batch_size, NUM_DIMENSIONS
                     ])  # [NUM_FRAGS, FRAG_SIZE,  BATCH_SIZE, NUM_DIMENSIONS]
    pt = np.transpose(
        pt,
        [1, 0, 2, 3])  # [FRAG_SIZE, NUM_FRAGS,  BATCH_SIZE, NUM_DIMENSIONS]

    # extension function used for single atom reconstruction and whole fragment alignment
    def extend(tri, pt, multi_m):
        """
        Args:
            tri: NUM_DIHEDRALS x [NUM_FRAGS/0,         BATCH_SIZE, NUM_DIMENSIONS]
            pt:                  [NUM_FRAGS/FRAG_SIZE, BATCH_SIZE, NUM_DIMENSIONS]
            multi_m: bool indicating whether m (and tri) is higher rank. pt is always higher rank; what changes is what the first rank is.
        """

        bc = normalize(tri.c - tri.b,
                       axis=-1)  # [NUM_FRAGS/0, BATCH_SIZE, NUM_DIMS]
        n = normalize(np.cross(tri.b - tri.a, bc),
                      axis=-1)  # [NUM_FRAGS/0, BATCH_SIZE, NUM_DIMS]
        if multi_m:  # multiple fragments, one atom at a time.
            m = np.transpose(
                np.stack([bc, np.cross(n, bc), n]),
                [1, 2, 3, 0])  # [NUM_FRAGS,   BATCH_SIZE, NUM_DIMS, 3 TRANS]
        else:  # single fragment, reconstructed entirely at once.
            s = onp.pad(
                pt.shape, [[0, 1]],
                constant_values=3)  # FRAG_SIZE, BATCH_SIZE, NUM_DIMS, 3 TRANS
            m = np.transpose(np.stack([bc, np.cross(n, bc), n]),
                             [1, 2, 0])  # [BATCH_SIZE, NUM_DIMS, 3 TRANS]
            m = np.reshape(np.tile(m, [s[0], 1, 1]),
                           s)  # [FRAG_SIZE, BATCH_SIZE, NUM_DIMS, 3 TRANS]
        coord = np.squeeze(
            np.matmul(m, np.expand_dims(pt, 3)),
            axis=3) + tri.c  # [NUM_FRAGS/FRAG_SIZE, BATCH_SIZE, NUM_DIMS]
        return coord

    # loop over FRAG_SIZE in NUM_FRAGS parallel fragments, sequentially generating the coordinates for each fragment across all batches
    coords = np.zeros_like(
        pt)  # FRAG_SIZE x [NUM_FRAGS, BATCH_SIZE, NUM_DIMENSIONS]

    def loop_extend(i,
                    dt):  # FRAG_SIZE x [NUM_FRAGS, BATCH_SIZE, NUM_DIMENSIONS]
        tri, coords = dt
        coord = extend(tri, pt[i], True)
        return (Triplet(tri.b, tri.c, coord), index_update(coords, i, coord))

    tris, coords_pretrans = fori_loop(0, pt.shape[0], loop_extend,
                                      (init_coords, coords))
    # NUM_DIHEDRALS x [NUM_FRAGS, BATCH_SIZE, NUM_DIMENSIONS],
    # FRAG_SIZE x [NUM_FRAGS, BATCH_SIZE, NUM_DIMENSIONS]
    # loop over NUM_FRAGS in reverse order, bringing all the downstream fragments in alignment with current fragment
    coords_pretrans = np.transpose(
        coords_pretrans,
        [1, 0, 2, 3])  # [NUM_FRAGS, FRAG_SIZE, BATCH_SIZE, NUM_DIMENSIONS]
    n = coords_pretrans.shape[0]  # NUM_FRAGS
    fs = coords_pretrans.shape[1]  # FRAG_SIZE

    res_array = np.zeros(
        (n * coords_pretrans.shape[1], *coords_pretrans.shape[2:]))

    def loop_trans(j, coords):
        i = (n - j) - 1
        transformed_coords = extend(Triplet(*[di[i] for di in tris]), coords,
                                    False)
        return dynamic_update_slice(transformed_coords, coords_pretrans[i],
                                    [fs * i] + [0] *
                                    (transformed_coords.ndim - 1))

    res_array = index_update(res_array, index[fs * (n - 1):fs * n],
                             coords_pretrans[-1])
    coords_trans = fori_loop(
        0, n, loop_trans, res_array
    )  # coords_pretrans[-1]) # [NUM_FRAGS x FRAG_SIZE, BATCH_SIZE, NUM_DIMENSIONS]

    # lose last atom and pad from the front to gain an atom ([0,0,0], consistent with init_mat), to maintain correct atom ordering
    coords = np.pad(
        coords_trans[:s - 1],
        [[1, 0], [0, 0], [0, 0]
         ])  # [NUM_STEPS x NUM_DIHEDRALS, BATCH_SIZE, NUM_DIMENSIONS]

    return coords
Example #30
    def fit_model(self, particle_weights, particles):
        """Fits a binary model using weighted particles.

    The model will be a sparse lower triangular logistic regression as in
    Procedure 5 from
    https://arxiv.org/pdf/1101.6037.pdf

    Args:
      particle_weights: a np.array<float> of simplicial weights
      particles: np.array<bool>[groups, n_patients]

    Returns:
     A np.array<float>[n_patients, n_patients] model.
    """
        n_groups, n_patients = particles.shape
        model = np.zeros((n_patients, n_patients))
        eps = 1e-5
        # keep track of basic stats
        xbar = (1 - eps) * np.sum(particle_weights[:, np.newaxis] * particles,
                                  axis=0) + eps * 0.5
        xcov = np.matmul(np.transpose(particles),
                         particle_weights[:, np.newaxis] * particles)
        xb1mxb = xbar * (1.0 - xbar)
        cov_matrix = (xcov -
                      xbar[:, np.newaxis] * xbar[np.newaxis, :]) / np.sqrt(
                          xb1mxb[:, np.newaxis] * xb1mxb[np.newaxis, :])

        # TODO(oliviert): turn this into parameters.
        eps = 0.01
        delta = 0.05
        indices_model = np.logical_and(xbar > eps, xbar < 1 - eps)
        indices_single = np.logical_or(xbar <= eps, xbar >= 1 - eps)
        # no regression for first variable
        indices_single = jax.ops.index_update(indices_single, 0, True)
        indices_model = jax.ops.index_update(indices_model, 0, False)

        # look for sparse blocks of variables to regress on others
        if self.sparse_model_lr:
            regressed, regressor = np.where(np.abs(cov_matrix) > delta)
            dic_regressors = collections.defaultdict(list)
            for i, j in zip(regressed, regressor):
                if j < i:
                    dic_regressors[i].append(j)

        # Where there exists cross-correlation we estimate a model
        # TODO(cuturi): switch to a predefined number of regressors (i.e. top k
        # correlated variables). From the kth patient we can then jit this regression.
        for i in np.where(indices_model)[0]:
            if self.sparse_model_lr:
                indices_i = dic_regressors[i]
            else:
                indices_i = list(range(i))

            regressors = np.concatenate(
                (particles[:, indices_i], np.ones((n_groups, 1))), axis=-1)
            y = particles[:, i]

            # initialize loop
            # TODO(oliviert): turn those hard coded constants into parameters
            b = np.zeros((regressors.shape[1], ))
            diff = 1e10
            iterations = 0
            reg = .05

            while diff > 1e-2 and iterations < 30:
                iterations += 1
                regressorsb = np.dot(regressors, b)
                p = jax.scipy.special.expit(regressorsb)
                q = p * (1 - p)
                cov = np.matmul(
                    particle_weights[np.newaxis, :] * q[np.newaxis, :] *
                    np.transpose(regressors), regressors)
                cov = cov + reg * np.eye(len(indices_i) + 1)
                c = np.dot(
                    np.transpose(regressors) * particle_weights[np.newaxis, :],
                    q * regressorsb + y - p)
                bnew = np.linalg.solve(cov, c)
                diff = np.sum((bnew - b)**2)
                b = bnew
            # add constant, to list of indices, to be stored in [i,i]
            indices_i.append(i)
            # update line i of model
            model = jax.ops.index_update(
                model, jax.ops.index[i, np.asarray(indices_i)], bnew)

        # Where there are no cross-correlations, or posterior is very peaked,
        # we flip randomly and indvidually
        v = np.zeros((n_patients, ))
        v = jax.ops.index_update(v, jax.ops.index[indices_single],
                                 jax.scipy.special.logit(xbar[indices_single]))
        model = model + np.diag(v)
        self.model = model