def _encode_bow(self, bow: jnp.ndarray) -> jnp.ndarray:
    """Encode the bag-of-words into tensors that can be used by the transformer.

    Args:
      bow: a [batch_size, bow_vocab_size] tensor, each row is a bow vector.

    Returns:
      embeddings: [batch_size, bow_n_tokens, bow_embedding_dim] tensor.
    """
    batch_size = bow.shape[0]
    bow = bow.astype(jnp.float32)

    # [B, D * n]
    embeddings = hk.Linear(self._bow_embedding_dim * self._bow_n_tokens)(bow)
    embeddings = transformer_block.layer_norm(jax.nn.gelu(embeddings))
    return jnp.reshape(
        embeddings, [batch_size, self._bow_n_tokens, self._bow_embedding_dim])
Example #2
def model(design_matrix: jnp.ndarray, outcome: jnp.ndarray = None) -> None:
    """
    Model definition: Log odds of making a purchase is a linear combination
    of covariates. Specify a Normal prior over regression coefficients.
    :param design_matrix: Covariates. All categorical variables have been one-hot
        encoded.
    :param outcome: Binary response variable. In this case, whether or not the
        customer made a purchase.
    """

    beta = numpyro.sample(
        'coefficients',
        dist.MultivariateNormal(loc=0.,
                                covariance_matrix=jnp.eye(
                                    design_matrix.shape[1])))
    logits = design_matrix.dot(beta)

    with numpyro.plate('data', design_matrix.shape[0]):
        numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=outcome)
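A minimal usage sketch for the model above, assuming `model` is in scope; the toy data and sampler settings below are illustrative, not from the original source.

# Hedged usage sketch: run NUTS over the model defined above.
import jax
import jax.numpy as jnp
from numpyro.infer import MCMC, NUTS

rng_key = jax.random.PRNGKey(0)
design_matrix = jax.random.normal(rng_key, (100, 3))            # toy covariates
outcome = jax.random.bernoulli(rng_key, 0.5, (100,)).astype(jnp.float32)

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(1), design_matrix, outcome)
coefficients = mcmc.get_samples()['coefficients']               # shape (1000, 3)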
Example #3
def if_points_inside_any_polygon(points: np.ndarray, scene: JaxScene) -> np.ndarray:
    """Checks if points is inside any polygon

    Args:
        points (np.ndarray): point or batch of points to check
        scene (JaxScene): scene instance

    Returns:
        np.ndarray: bool result
    """

    if points.ndim == 1:
        points = points.reshape(1, -1)
    signs = batch_segment_normal_projection_sign(points, scene.segments)
    signs = signs < 0
    signs = signs.reshape(len(points), len(scene.polygons), 3)  # all polygons have 3 segments
    is_inside = np.all(signs, axis=-1)
    result = np.any(is_inside, axis=-1)
    return result
Example #4
    def __call__(self, x: jnp.ndarray) -> VAEOutput:
        x = x.astype(jnp.float32)

        # q(z|x) = N(mean(x), covariance(x))
        mean, stddev = Encoder(self._hidden_size, self._latent_size)(x)
        variational_distrib = distrax.MultivariateNormalDiag(loc=mean,
                                                             scale_diag=stddev)
        z = variational_distrib.sample(seed=hk.next_rng_key())

        # p(x|z) = \Prod Bernoulli(logits(z))
        logits = Decoder(self._hidden_size, self._output_shape)(z)
        likelihood_distrib = distrax.Independent(
            distrax.Bernoulli(logits=logits),
            reinterpreted_batch_ndims=len(
                self._output_shape))  # 3 non-batch dims

        # Generate images from the likelihood
        image = likelihood_distrib.sample(seed=hk.next_rng_key())

        return VAEOutput(variational_distrib, likelihood_distrib, image)
Example #5
def symmetrize(
    data: jnp.ndarray,
    row: jnp.ndarray,
    col: jnp.ndarray,
    ncols: Optional[int] = None,
):
    """
    Get data of `(A + A.T) / 2` assuming `A` has symmetric sparsity.

    Args:
        data: values of `A`
        row: row indices of `A`
        col: col indices of `A`
        ncols: number of columns of `A`

    Returns:
        `sym_data`, same shape and dtype as `data`.
    """
    perm = reorder_perm(row=col, col=row, ncols=ncols)
    return (data + data.take(perm)) / 2
Example #6
def logpdf(x: jnp.ndarray, kappa: float, mu: jnp.ndarray) -> jnp.ndarray:
    """Log-density function of the power spherical distribution.

    Args:
        x: The set of points at which to evaluate the power spherical density.
        kappa: Concentration parameter.
        mu: Mean direction on the sphere. The dimensionality of the sphere is
            determined from this parameter.

    Returns:
        out: The log-density of the power spherical distribution with the
            specified concentration and mean parameter at the desired points.

    """
    d = mu.size
    alpha = (d - 1.) / 2. + kappa
    beta = (d - 1.) / 2.
    lognormalizer = ((alpha + beta) * jnp.log(2.) + beta * jnp.log(jnp.pi) +
                     jspsp.gammaln(alpha) - jspsp.gammaln(alpha + beta))
    unlogprob = kappa * jnp.log(1. + x.dot(mu))
    return unlogprob - lognormalizer
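A quick usage sketch, assuming `logpdf` above and its `jspsp = jax.scipy.special` import are in scope; the values are illustrative.

import jax.numpy as jnp

mu = jnp.array([0.0, 0.0, 1.0])            # mean direction on the 2-sphere
kappa = 5.0                                # concentration
x = jnp.array([[0.0, 0.0, 1.0],            # at the mode
               [0.0, 1.0, 0.0]])           # orthogonal to the mode
log_density = logpdf(x, kappa, mu)         # shape (2,), larger at the mode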
Example #7
def learning_schedule(
    global_step: jnp.ndarray,
    base_learning_rate: float,
    total_steps: int,
    warmup_steps: int,
    use_schedule: bool,
) -> float:
    """Cosine learning rate scheduler."""
    # Compute LR & Scaled LR
    if not use_schedule:
        return base_learning_rate
    warmup_learning_rate = (global_step.astype(jnp.float32) /
                            int(warmup_steps) * base_learning_rate
                            if warmup_steps > 0 else base_learning_rate)

    # Cosine schedule after warmup.
    decay_learning_rate = _cosine_decay(global_step - warmup_steps,
                                        total_steps - warmup_steps,
                                        base_learning_rate)
    return jnp.where(global_step < warmup_steps, warmup_learning_rate,
                     decay_learning_rate)
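`_cosine_decay` is not shown above; the stand-in below is an assumption (standard cosine decay to zero), used only to make the usage sketch self-contained.

import jax.numpy as jnp

# Assumed stand-in for the helper `_cosine_decay` (not the original implementation).
def _cosine_decay(step: jnp.ndarray, decay_steps: int, base_lr: float) -> jnp.ndarray:
    progress = jnp.clip(step.astype(jnp.float32) / decay_steps, 0.0, 1.0)
    return base_lr * 0.5 * (1.0 + jnp.cos(jnp.pi * progress))

# Warm up for 1k steps, then follow the cosine schedule.
lr = learning_schedule(jnp.array(500), base_learning_rate=0.1,
                       total_steps=10_000, warmup_steps=1_000, use_schedule=True)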
Example #8
def squeeze(x: np.ndarray, axis: Union[None, int, Tuple[int,
                                                        ...]]) -> np.ndarray:
    """`np.squeeze` analog working with 0-sized axes."""
    if isinstance(axis, int):
        axis = (axis, )

    non_zero_axes = tuple()
    shift = 0

    for a in sorted(axis):
        if x.shape[a - shift] == 0:
            new_shape = x.shape[:a] + x.shape[a + 1:]
            if size_at(new_shape) == 0:
                x = x.reshape(new_shape)
            else:
                x = np.zeros(new_shape, x.dtype)

            shift += 1
        else:
            non_zero_axes += (a - shift, )

    return np.squeeze(x, non_zero_axes)
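`size_at` is not shown above; the stand-in below (product of the shape) is an assumption, used only to exercise the zero-sized-axis path.

import numpy as np

# Assumed stand-in for the helper `size_at`.
size_at = lambda shape: int(np.prod(shape))

y = squeeze(np.zeros((3, 0, 4)), axis=1)   # drops the 0-sized axis: y.shape == (3, 4)
z = squeeze(np.ones((2, 1, 5)), axis=1)    # behaves like np.squeeze: z.shape == (2, 5)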
Example #9
    def apply_lse_kernel(self,
                         f: jnp.ndarray,
                         g: jnp.ndarray,
                         eps: float,
                         vec: Optional[jnp.ndarray] = None,
                         axis: int = 0):
        """Applies grid kernel in log space. See notes in parent class for use case.

    Reshapes vector inputs below as grids, applies kernels onto each slice, and
    then expands the outputs as vectors.

    More implementation details in https://arxiv.org/pdf/1708.01955.pdf

    Args:
      f: jnp.ndarray, a vector of potentials
      g: jnp.ndarray, a vector of potentials
      eps: float, regularization strength
      vec: jnp.ndarray, if needed, a vector onto which apply the kernel weighted
        by f and g.
      axis: axis (0 or 1) along which summation should be carried out.

    Returns:
      a vector, the result of kernel applied in lse space onto vec.
    """
        f, g = jnp.reshape(f, self.grid_size), jnp.reshape(g, self.grid_size)

        if vec is not None:
            vec = jnp.reshape(vec, self.grid_size)

        if axis == 0:
            f, g = g, f

        for dimension in range(self.grid_dimension):
            g, vec = self._apply_lse_kernel_one_dimension(
                dimension, f, g, eps, vec)
            g -= jnp.where(jnp.isfinite(f), f, 0)
        if vec is None:
            vec = jnp.array(1.0)
        return g.ravel(), vec.ravel()
Example #10
def mean_squared_logarithmic_error(y_true: jnp.ndarray,
                                   y_pred: jnp.ndarray) -> jnp.ndarray:
    """
    Computes the mean squared logarithmic error between labels and predictions.

    ```python
    loss = mean(square(log(y_true + 1) - log(y_pred + 1)), axis=-1)
    ```

    Usage:

    ```python
    rng = jax.random.PRNGKey(42)

    y_true = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
    y_pred = jax.random.uniform(rng, shape=(2, 3))

    loss = elegy.losses.mean_squared_logarithmic_error(y_true, y_pred)

    assert loss.shape == (2,)

    first_log = jnp.log(jnp.maximum(y_true, types.EPSILON) + 1.0)
    second_log = jnp.log(jnp.maximum(y_pred, types.EPSILON) + 1.0)
    assert jnp.array_equal(loss, jnp.mean(jnp.square(first_log - second_log), axis=-1))
    ```

    Arguments:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean squared logarithmic error values. shape = `[batch_size, d0, .. dN-1]`.
    """

    y_true = y_true.astype(y_pred.dtype)
    first_log = jnp.log(jnp.maximum(y_true, types.EPSILON) + 1.0)
    second_log = jnp.log(jnp.maximum(y_pred, types.EPSILON) + 1.0)

    return jnp.mean(jnp.square(first_log - second_log), axis=-1)
Example #11
def mean_squared_error(y_true: jnp.ndarray,
                       y_pred: jnp.ndarray) -> jnp.ndarray:
    """
    Computes the mean squared error between labels and predictions.
    
    After computing the squared distance between the inputs, the mean value over
    the last dimension is returned.
    
    ```python
    loss = mean(square(y_true - y_pred), axis=-1)
    ```

    Usage:
    
    ```python
    rng = jax.random.PRNGKey(42)

    y_true = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
    y_pred = jax.random.uniform(rng, shape=(2, 3))

    loss = elegy.losses.mean_squared_error(y_true, y_pred)

    assert loss.shape == (2,)

    assert jnp.array_equal(loss, jnp.mean(jnp.square(y_true - y_pred), axis=-1))
    ```
    
    Arguments:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
    
    Returns:
        Mean squared error values. shape = `[batch_size, d0, .. dN-1]`.
    """

    y_true = y_true.astype(y_pred.dtype)

    return jnp.mean(jnp.square(y_pred - y_true), axis=-1)
Example #12
def bgrl_loss(
    first_online_predictions: jnp.ndarray,
    second_target_projections: jnp.ndarray,
    second_online_predictions: jnp.ndarray,
    first_target_projections: jnp.ndarray,
    symmetrize: bool,
    valid_mask: jnp.ndarray,
) -> Tuple[jnp.ndarray, LogsDict]:
    """Implements BGRL loss."""
    first_side_node_loss = jnp.sum(jnp.square(
        _l2_normalize(first_online_predictions, axis=-1) -
        _l2_normalize(second_target_projections, axis=-1)),
                                   axis=-1)
    if symmetrize:
        second_side_node_loss = jnp.sum(jnp.square(
            _l2_normalize(second_online_predictions, axis=-1) -
            _l2_normalize(first_target_projections, axis=-1)),
                                        axis=-1)
        node_loss = first_side_node_loss + second_side_node_loss
    else:
        node_loss = first_side_node_loss
    loss = (node_loss * valid_mask).sum() / (valid_mask.sum() + 1e-6)
    return loss, dict(bgrl_loss=loss)
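`_l2_normalize` is not shown above; the stand-in and random inputs below are assumptions for a shape-level usage sketch of the loss.

import jax
import jax.numpy as jnp

# Assumed stand-in for the helper `_l2_normalize`.
def _l2_normalize(x: jnp.ndarray, axis: int = -1, eps: float = 1e-12) -> jnp.ndarray:
    return x * jax.lax.rsqrt(jnp.sum(jnp.square(x), axis=axis, keepdims=True) + eps)

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
num_nodes, dim = 5, 8
first_online = jax.random.normal(key1, (num_nodes, dim))
second_target = jax.random.normal(key2, (num_nodes, dim))

loss_value, logs = bgrl_loss(first_online, second_target,
                             second_online_predictions=second_target,
                             first_target_projections=first_online,
                             symmetrize=True,
                             valid_mask=jnp.ones((num_nodes,)))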
Example #13
def plot_final_bounds(x: np.ndarray,
                      y: np.ndarray,
                      xstar: np.ndarray,
                      bounds: np.ndarray,
                      data_xstar: np.ndarray,
                      data_ystar: np.ndarray,
                      coeff_2sls: np.ndarray = None,
                      x_kiv: np.ndarray = None,
                      y_kiv: np.ndarray = None) -> plt.Figure:
  fig = plt.figure()
  plt.scatter(x, y, **data_kwargs)
  plt.plot(xstar, bounds[:, 0], 'g--x', label="lower", lw=2, markersize=10)
  plt.plot(xstar, bounds[:, 1], 'r--x', label="upper", lw=2, markersize=10)
  if data_xstar is not None and data_ystar is not None:
    if data_ystar.ndim > 1:
      data_ystar = data_ystar.mean(0)
    plt.plot(data_xstar, data_ystar, label=f"$E[Y | do(x^*)]$", lw=2)
  if coeff_2sls is not None:
    tt = np.linspace(np.min(x), np.max(x), 10)
    y_2sls = coeff_2sls[0] + coeff_2sls[1] * tt
    plt.plot(tt, y_2sls, ls='dotted', c="tab:purple", lw=2, label="2sls")
  if x_kiv is not None and y_kiv is not None:
    plt.plot(x_kiv, y_kiv, ls='dashdot', c="tab:olive", lw=2, label="KIV")

  def get_limits(vals):
    lo = np.min(vals)
    hi = np.max(vals)
    extend = (hi - lo) / 15.
    return lo - extend, hi + extend

  plt.xlim(get_limits(x))
  plt.ylim(get_limits(y))
  plt.xlabel('x')
  plt.ylabel('y')
  plt.title("Lower and upper bound on actual effect")
  plt.legend()
  return fig
Example #14
def recall(
    y_true: jnp.ndarray,
    y_pred: jnp.ndarray,
    threshold: jnp.ndarray,
    class_id: jnp.ndarray,
    sample_weight: jnp.ndarray,
    true_positives: ReduceConfusionMatrix,
    false_negatives: ReduceConfusionMatrix,
) -> jnp.ndarray:

    # TODO: class_id behavior
    y_pred = (y_pred > threshold).astype(jnp.float32)

    if y_true.dtype != y_pred.dtype:
        y_pred = y_pred.astype(y_true.dtype)

    true_positives = true_positives(
        y_true=y_true, y_pred=y_pred, sample_weight=sample_weight
    )
    false_negatives = false_negatives(
        y_true=y_true, y_pred=y_pred, sample_weight=sample_weight
    )

    return jnp.nan_to_num(jnp.divide(true_positives, true_positives + false_negatives))
Example #15
def get_masked_array(x: np.ndarray,
                     mask_constant: float = None) -> MaskedArray:
  """Return `x` with entries equal to `mask_constant` zeroed-out, and the mask.

  The mask returned is a boolean `np.ndarray` with masked indices having `True`.

  Args:
    x: `np.ndarray` to mask. If `x` is a `MaskedArray`, treat it as
      `(masked_x, mask)` and pass it through.
    mask_constant: an optional `float`, the value in inputs to be considered as
      masked (e.g. padding in a batch of sentences). `None` means no masking.
      Can also be `np.nan`, `np.inf` etc.

  Returns:
    A `MaskedArray` of `(masked_x, boolean_mask)`.
  """

  if x is None:
    mask_mat = None

  elif isinstance(x, MaskedArray):
    x, mask_mat, _, _ = x.astuple()

  elif isinstance(x, np.ndarray):
    if mask_constant is None:
      mask_mat = None
    else:
      mask_mat = lax.cond(np.isnan(mask_constant),
                          np.isnan,
                          lambda x: x == mask_constant,
                          x)
  else:
    raise TypeError(x, type(x))

  x = mask(x, mask_mat)
  return MaskedArray(x, mask_mat)  # pytype: disable=wrong-arg-count
Example #16
def generate_positions(rng: rjax.PRNGKey,
                       genpcls: dict,
                       pos0: np.ndarray,
                       pcl: GenParticle = None):
    """ Generates position according to the momentum direction and lifetime
        Traverses decay tree recursively """
    if pcl is None:
        genpcls['root']['pos'] = Position.from_ndarray(pos0)
        pcl = genpcls['root']['gpcl']
    else:
        genpcls[pcl.name]['pos'] = Position.from_ndarray(pos0)

    for ch in pcl.children:
        particle = genpcls[ch.name]['pcl']
        if particle.lifetime > 0.0001 and particle.lifetime < 1:
            mom = genpcls[ch.name]['mom']
            nevt = mom.size
            rng, key = rjax.split(rng)
            time = particle.lifetime * rjax.exponential(key, (nevt, 1))
            # TODO: add gamma factor multiplier here (relativistic correction)
            chpos = pos0 + mom.velocity(particle.mass) * time
        else:
            chpos = pos0.copy()
        generate_positions(rng, genpcls, chpos, ch)
Example #17
def _mean_with_mask(array: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
    num_valid_rows = mask.sum(0)
    return sum_with_mask(array, mask) / num_valid_rows
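`sum_with_mask` is not shown above; the stand-in below is an assumption (masked sum over the leading axis) used only for this usage sketch.

import jax.numpy as jnp

# Assumed stand-in for the helper `sum_with_mask`.
def sum_with_mask(array: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
    return (mask * array).sum(0)

array = jnp.array([[1.0, 2.0],
                   [3.0, 4.0],
                   [100.0, 200.0]])
mask = jnp.array([[1.0, 1.0],
                  [1.0, 1.0],
                  [0.0, 0.0]])              # last row is invalid
col_means = _mean_with_mask(array, mask)    # -> [2., 3.]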
Example #18
def int_dequantize_jit(x: jnp.ndarray, scale: jnp.ndarray, offset: jnp.ndarray, max_int: int, to_type: str):
    return x.astype(to_type) * scale.astype(to_type) / max_int + offset.astype(to_type)
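A round-trip sketch; the forward quantization step below is an assumption added only to exercise the dequantizer.

import jax.numpy as jnp

x_float = jnp.array([-1.0, 0.0, 0.5, 1.0])
offset = x_float.min()
scale = x_float.max() - offset
max_int = 255

# Assumed forward quantization (not part of the original snippet).
x_q = jnp.round((x_float - offset) / scale * max_int).astype(jnp.uint8)
x_back = int_dequantize_jit(x_q, scale, offset, max_int, to_type='float32')
# x_back approximately recovers x_float, up to quantization error.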
Example #19
  def apply(self,
            x: jnp.ndarray,
            num_outputs: int,
            pyramid_alpha: int = 200,
            pyramid_depth: int = 272,
            train: bool = True,
            true_gradient: bool = False) -> jnp.ndarray:
    """Implements the forward pass in the module.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, 3]
        where dim is the resolution of the image.
      num_outputs: Dimension of the output of the model (ie number of classes
        for a classification problem).
      pyramid_alpha: See paper.
      pyramid_depth: See paper.
      train: If False, will use the moving average for batch norm statistics.
        Else, will use statistics computed on the batch.
      true_gradient: If true, the same mixing parameter will be used for the
        forward and backward pass (see paper for more details).

    Returns:
      The output of the PyramidNet model, a tensor of shape
        [batch_size, num_classes].
    """
    assert (pyramid_depth - 2) % 9 == 0

    # Shake-drop hyper-params
    mask_prob = 0.5
    alpha_min, alpha_max = (-1.0, 1.0)
    beta_min, beta_max = (0.0, 1.0)

    # Bottleneck network size
    blocks_per_group = (pyramid_depth - 2) // 9
    # See Eqn 2 in https://arxiv.org/abs/1610.02915
    num_channels = 16
    # N in https://arxiv.org/abs/1610.02915
    total_blocks = blocks_per_group * 3
    delta_channels = pyramid_alpha / total_blocks

    x = nn.Conv(
        x,
        16, (3, 3),
        padding='SAME',
        name='init_conv',
        bias=False,
        kernel_init=utils.conv_kernel_init_fn)
    x = utils.activation(x, apply_relu=False, train=train, name='init_bn')

    layer_num = 1

    for block_i in range(blocks_per_group):
      num_channels += delta_channels
      layer_mask_prob = _calc_shakedrop_mask_prob(layer_num, total_blocks,
                                                  mask_prob)
      x = BottleneckShakeDrop(
          x,
          int(round(num_channels)), (1, 1),
          layer_mask_prob,
          alpha_min,
          alpha_max,
          beta_min,
          beta_max,
          train=train,
          true_gradient=true_gradient)
      layer_num += 1

    for block_i in range(blocks_per_group):
      num_channels += delta_channels
      layer_mask_prob = _calc_shakedrop_mask_prob(
          layer_num, total_blocks, mask_prob)
      x = BottleneckShakeDrop(x, int(round(num_channels)),
                              ((2, 2) if block_i == 0 else (1, 1)),
                              layer_mask_prob,
                              alpha_min, alpha_max, beta_min, beta_max,
                              train=train,
                              true_gradient=true_gradient)
      layer_num += 1

    for block_i in range(blocks_per_group):
      num_channels += delta_channels
      layer_mask_prob = _calc_shakedrop_mask_prob(
          layer_num, total_blocks, mask_prob)
      x = BottleneckShakeDrop(x, int(round(num_channels)),
                              ((2, 2) if block_i == 0 else (1, 1)),
                              layer_mask_prob,
                              alpha_min, alpha_max, beta_min, beta_max,
                              train=train,
                              true_gradient=true_gradient)
      layer_num += 1

    assert layer_num - 1 == total_blocks
    x = utils.activation(x, train=train, name='final_bn')
    x = nn.avg_pool(x, (8, 8))
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
    return x
Example #20
def atom37_to_frames(
        aatype: jnp.ndarray,  # (...)
        all_atom_positions: jnp.ndarray,  # (..., 37, 3)
        all_atom_mask: jnp.ndarray,  # (..., 37)
) -> Dict[str, jnp.ndarray]:
    """Computes the frames for the up to 8 rigid groups for each residue.

  The rigid groups are defined by the possible torsions in a given amino acid.
  We group the atoms according to their dependence on the torsion angles into
  "rigid groups".  E.g., the position of atoms in the chi2-group depend on
  chi1 and chi2, but do not depend on chi3 or chi4.
  Jumper et al. (2021) Suppl. Table 2 and corresponding text.

  Args:
    aatype: Amino acid type, given as array with integers.
    all_atom_positions: atom37 representation of all atom coordinates.
    all_atom_mask: atom37 representation of mask on all atom coordinates.
  Returns:
    Dictionary containing:
      * 'rigidgroups_gt_frames': 8 Frames corresponding to 'all_atom_positions'
           represented as flat 12 dimensional array.
      * 'rigidgroups_gt_exists': Mask denoting whether the atom positions for
          the given frame are available in the ground truth, e.g. if they were
          resolved in the experiment.
      * 'rigidgroups_group_exists': Mask denoting whether given group is in
          principle present for given amino acid type.
      * 'rigidgroups_group_is_ambiguous': Mask denoting whether frame is
          affected by naming ambiguity.
      * 'rigidgroups_alt_gt_frames': 8 Frames with alternative atom renaming
          corresponding to 'all_atom_positions' represented as flat
          12 dimensional array.
  """
    # 0: 'backbone group',
    # 1: 'pre-omega-group', (empty)
    # 2: 'phi-group', (currently empty, because it defines only hydrogens)
    # 3: 'psi-group',
    # 4,5,6,7: 'chi1,2,3,4-group'
    aatype_in_shape = aatype.shape

    # If there is a batch axis, just flatten it away, and reshape everything
    # back at the end of the function.
    aatype = jnp.reshape(aatype, [-1])
    all_atom_positions = jnp.reshape(all_atom_positions, [-1, 37, 3])
    all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37])

    # Create an array with the atom names.
    # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
    restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object)

    # 0: backbone frame
    restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']

    # 3: 'psi-group'
    restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']

    # 4,5,6,7: 'chi1,2,3,4-group'
    for restype, restype_letter in enumerate(residue_constants.restypes):
        resname = residue_constants.restype_1to3[restype_letter]
        for chi_idx in range(4):
            if residue_constants.chi_angles_mask[restype][chi_idx]:
                atom_names = residue_constants.chi_angles_atoms[resname][
                    chi_idx]
                restype_rigidgroup_base_atom_names[restype, chi_idx +
                                                   4, :] = atom_names[1:]

    # Create mask for existing rigid groups.
    restype_rigidgroup_mask = np.zeros([21, 8], dtype=np.float32)
    restype_rigidgroup_mask[:, 0] = 1
    restype_rigidgroup_mask[:, 3] = 1
    restype_rigidgroup_mask[:20, 4:] = residue_constants.chi_angles_mask

    # Translate atom names into atom37 indices.
    lookuptable = residue_constants.atom_order.copy()
    lookuptable[''] = 0
    restype_rigidgroup_base_atom37_idx = np.vectorize(
        lambda x: lookuptable[x])(restype_rigidgroup_base_atom_names)

    # Compute the gather indices for all residues in the chain.
    # shape (N, 8, 3)
    residx_rigidgroup_base_atom37_idx = utils.batched_gather(
        restype_rigidgroup_base_atom37_idx, aatype)

    # Gather the base atom positions for each rigid group.
    base_atom_pos = utils.batched_gather(all_atom_positions,
                                         residx_rigidgroup_base_atom37_idx,
                                         batch_dims=1)

    # Compute the Rigids.
    gt_frames = r3.rigids_from_3_points(
        point_on_neg_x_axis=r3.vecs_from_tensor(base_atom_pos[:, :, 0, :]),
        origin=r3.vecs_from_tensor(base_atom_pos[:, :, 1, :]),
        point_on_xy_plane=r3.vecs_from_tensor(base_atom_pos[:, :, 2, :]))

    # Compute a mask whether the group exists.
    # (N, 8)
    group_exists = utils.batched_gather(restype_rigidgroup_mask, aatype)

    # Compute a mask whether ground truth exists for the group
    gt_atoms_exist = utils.batched_gather(  # shape (N, 8, 3)
        all_atom_mask.astype(jnp.float32),
        residx_rigidgroup_base_atom37_idx,
        batch_dims=1)
    gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists  # (N, 8)

    # Adapt backbone frame to old convention (mirror x-axis and z-axis).
    rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
    rots[0, 0, 0] = -1
    rots[0, 2, 2] = -1
    gt_frames = r3.rigids_mul_rots(gt_frames, r3.rots_from_tensor3x3(rots))

    # The frames for ambiguous rigid groups are just rotated by 180 degree around
    # the x-axis. The ambiguous group is always the last chi-group.
    restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32)
    restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32),
                                      [21, 8, 1, 1])

    for resname, _ in residue_constants.residue_atom_renaming_swaps.items():
        restype = residue_constants.restype_order[
            residue_constants.restype_3to1[resname]]
        chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1

    # Gather the ambiguity information for each residue.
    residx_rigidgroup_is_ambiguous = utils.batched_gather(
        restype_rigidgroup_is_ambiguous, aatype)
    residx_rigidgroup_ambiguity_rot = utils.batched_gather(
        restype_rigidgroup_rots, aatype)

    # Create the alternative ground truth frames.
    alt_gt_frames = r3.rigids_mul_rots(
        gt_frames, r3.rots_from_tensor3x3(residx_rigidgroup_ambiguity_rot))

    gt_frames_flat12 = r3.rigids_to_tensor_flat12(gt_frames)
    alt_gt_frames_flat12 = r3.rigids_to_tensor_flat12(alt_gt_frames)

    # reshape back to original residue layout
    gt_frames_flat12 = jnp.reshape(gt_frames_flat12, aatype_in_shape + (8, 12))
    gt_exists = jnp.reshape(gt_exists, aatype_in_shape + (8, ))
    group_exists = jnp.reshape(group_exists, aatype_in_shape + (8, ))
    residx_rigidgroup_is_ambiguous = jnp.reshape(
        residx_rigidgroup_is_ambiguous, aatype_in_shape + (8, ))
    alt_gt_frames_flat12 = jnp.reshape(alt_gt_frames_flat12,
                                       aatype_in_shape + (
                                           8,
                                           12,
                                       ))

    return {
        'rigidgroups_gt_frames': gt_frames_flat12,  # (..., 8, 12)
        'rigidgroups_gt_exists': gt_exists,  # (..., 8)
        'rigidgroups_group_exists': group_exists,  # (..., 8)
        'rigidgroups_group_is_ambiguous':
        residx_rigidgroup_is_ambiguous,  # (..., 8)
        'rigidgroups_alt_gt_frames': alt_gt_frames_flat12,  # (..., 8, 12)
    }
Example #21
def loss(params: optax.Params, batch: jnp.ndarray,
         labels: jnp.ndarray) -> jnp.ndarray:
    logits = net.apply(params, batch)
    loss_val = optax.sigmoid_binary_cross_entropy(logits, labels).sum(axis=-1)
    return loss_val.mean(), (logits.argmax(axis=-1) == labels.argmax(
        axis=-1)).sum(axis=-1) / batch_size
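The snippet above closes over `net` and `batch_size`; the stand-ins below (a small Haiku MLP and one-hot labels) are assumptions that show how the loss plugs into `jax.value_and_grad` with `has_aux=True`.

import haiku as hk
import jax
import jax.numpy as jnp
import optax

# Assumed stand-ins for the names the loss closes over.
batch_size, num_classes = 8, 4
net = hk.without_apply_rng(
    hk.transform(lambda x: hk.nets.MLP([16, num_classes])(x)))

batch = jnp.ones((batch_size, 10))
labels = jax.nn.one_hot(jnp.zeros(batch_size, dtype=jnp.int32), num_classes)
params = net.init(jax.random.PRNGKey(0), batch)

(loss_value, accuracy), grads = jax.value_and_grad(
    loss, has_aux=True)(params, batch, labels)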
Example #22
def _vmap_2d(fn: Callable[[float, float, float], float],
             cov12: np.ndarray,
             var1: np.ndarray,
             var2: Optional[np.ndarray],
             diagonal_batch: bool,
             diagonal_spatial: bool) -> np.ndarray:
  """Effectively a "2D vmap" of `fn(cov12, var1, var2)`.

  Applicable for all possible kernel layouts.

  Args:
    fn:
      scalar-valued, elementwise `fn(cov12, var1, var2)` function to apply.

    cov12:
      covariance tensor (`q12`), `nngp`/`ntk`/`cov1`/`cov2`, of shape
      `(N1[, N2])`, `(N1[, N2], X, Y, ...)`, `(N1[, N2], X, X, Y, Y, ...)`
      depending on `diagonal_batch`, `diagonal_spatial`, and the number of
      spatial dimensions.

    var1:
      variance tensor (`q11`), has shape `(N1[, X, Y, ...])`.

    var2:
      variance tensor (`q22`), has shape `(N2[, X, Y, ...])`.

    diagonal_batch:
      `True` if `cov12` has only one batch dimension.

    diagonal_spatial:
      `True` if `cov12` has spatial dimensions appearing once (vs twice).

  Returns:
    Resulting array `[fn(cov12[i, j], var1[i], var2[j])]_{i j}`. Has the same
    shape as `cov12`.
  """
  batch_ndim = 1 if diagonal_batch else 2
  start = 2 - batch_ndim
  cov_end = batch_ndim if diagonal_spatial else cov12.ndim
  _cov12 = utils.make_2d(cov12, start, cov_end)

  var_end = 1 if diagonal_spatial else var1.ndim
  var1 = var1.reshape(var1.shape[:start] + (-1,) + var1.shape[var_end:])
  var2 = var1 if var2 is None else var2.reshape(var2.shape[:start] + (-1,) +
                                                var2.shape[var_end:])

  fn = vmap(
      vmap(
          np.vectorize(fn),
          in_axes=(start, None, start),
          out_axes=start
      ),
      in_axes=(start, start, None),
      out_axes=start
  )
  out = fn(_cov12, var1, var2)  # type: np.ndarray
  out_shape = (cov12.shape[:start] +
               cov12.shape[start:cov_end:2] +
               cov12.shape[start + 1:cov_end:2] +
               cov12.shape[cov_end:])
  out = out.reshape(out_shape)
  out = utils.zip_axes(out, start, cov_end)
  return out
Example #23
 def _update_fn(g: jnp.ndarray, t: jnp.ndarray, m: jnp.ndarray) -> jnp.ndarray:
     m = m.astype(g.dtype)
     return g * (1. - m) + t * m
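A quick usage sketch with illustrative values: the mask selects, per element, whether to keep `g` or take the target `t`.

import jax.numpy as jnp

g = jnp.array([1.0, 2.0, 3.0])
t = jnp.array([10.0, 20.0, 30.0])
m = jnp.array([0.0, 1.0, 0.0])
out = _update_fn(g, t, m)   # -> [1., 20., 3.]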
Example #24
def poly_from_roots(roots: jnp.ndarray, x: jnp.ndarray):
    """Evaluate polynomial with given `roots` elementwise on `x`."""
    assert len(roots.shape) == 1

    roots = roots.reshape((-1,) + tuple(1 for _ in x.shape))
    return jnp.prod(x - roots, axis=0)
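A quick check with illustrative values: a polynomial with roots at 1 and 2, evaluated elementwise.

import jax.numpy as jnp

roots = jnp.array([1.0, 2.0])
x = jnp.array([0.0, 1.0, 3.0])
y = poly_from_roots(roots, x)   # (x - 1) * (x - 2) -> [2., 0., 2.]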
Example #25
 def __call__(self, x: np.ndarray) -> np.ndarray:
     x = x.reshape(x.shape[0], -1)
     x = nn.Dense(features=self.num_features)(x)
     x = nn.relu(x)
     return nn.Dense(features=self.num_outputs)(x)
Example #26
def _rbg_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
    return vmap(_threefry_fold_in, (0, None), 0)(key.reshape(2, 2),
                                                 data).reshape(4)
Example #27
def _rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
    return vmap(_threefry_split, (0, None), 1)(key.reshape(2, 2),
                                               num).reshape(num, 4)
Example #28
    def __call__(
        self,
        x: jnp.ndarray,
        scale: Optional[jnp.ndarray] = None,
        offset: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """Returns normalized inputs.

    Args:
      x: An n-D tensor of the ``data_format`` specified in the constructor
        on which the transformation is performed.
      scale: A tensor up to n-D. The shape of this tensor must be broadcastable
        to the shape of ``x``. This is the scale applied to the normalized
        x. This cannot be passed in if the module was constructed with
        ``create_scale=True``.
      offset: A tensor up to n-D. The shape of this tensor must be broadcastable
        to the shape of ``x``. This is the offset applied to the normalized
        ``x``. This cannot be passed in if the module was constructed with
        ``create_offset=True``.

    Returns:
      An n-d tensor of the same shape as x that has been normalized.
    """
        if self.rank is not None and x.ndim != self.rank:
            raise ValueError(
                "The rank of the inputs cannot change between calls, the"
                f" original call was rank={self.rank} but this call was "
                f"rank={x.ndim}.")

        if self.create_scale and scale is not None:
            raise ValueError(
                "Cannot pass `scale` at call time if `create_scale=True`.")

        if self.create_offset and offset is not None:
            raise ValueError(
                "Cannot pass `offset` at call time if `create_offset=True`.")

        channels = x.shape[self.channel_index]
        if channels % self.groups != 0:
            raise ValueError(
                "The number of channels must be divisible by the number of groups, "
                f"was channels={channels}, groups={self.groups}")

        if self.rank is None:
            self._initialize(x, channels)

        dtype = x.dtype
        if self.channel_index == -1:
            params_shape = (x.shape[-1], )
        else:
            assert self.channel_index == 1
            params_shape = (x.shape[1], ) + (1, ) * (self.rank - 2)

        if self.create_scale:
            scale = hk.get_parameter("scale", params_shape, dtype,
                                     self.scale_init)

        if self.create_offset:
            offset = hk.get_parameter("offset", params_shape, dtype,
                                      self.offset_init)

        x = x.reshape(self.group_shape)
        mean = jnp.mean(x, self.axis, keepdims=True)
        # TODO(tycai): Consider faster but less precise variance formulation.
        var = jnp.var(x, self.axis, keepdims=True)
        x = (x - mean) * jax.lax.rsqrt(var + self.eps)
        x = x.reshape(self.first_input_shape)

        if scale is not None:
            scale = jax.lax.broadcast_to_rank(scale, x.ndim)
            x = x * scale

        if offset is not None:
            offset = jax.lax.broadcast_to_rank(offset, x.ndim)
            x = x + offset

        return x
Example #29
def walsh_hadamard_transform(
        x: jnp.ndarray,
        small_n: int = 2**7,
        precision: Union[jax.lax.Precision, str] = 'highest') -> jnp.ndarray:
    """Efficient Walsh-Hadamard transform in JAX.

  An accelerator friendly O(n log n) Walsh-Hadamard transform.

  Args:
    x: A vector. len(x) must be a power of 2.
    small_n: Size to break x into. The default value is tuned on TPUv3. Must be
      a power of 2 and > 1.
    precision: Precision for general dot products.

  Returns:
    Transformed vector.
  """
    if small_n <= 1:
        raise ValueError(f'small_n must be > 1, got {small_n}')

    # Let
    # -   A ⊗ B be the Kronecker product of A and B;
    # -   flat(X) be the vector obtained by flattening the rows of X of shape
    #     [M, N].
    #
    # We can show the following:
    #
    #     (A ⊗ B^T) flat(X) = flat(A X B)
    #
    # Note that the Hadamard matrix H_{2^M 2^N} = H_{2^M} ⊗ H_{2^N}, and
    # Hadamard matrices are symmetrical. Therefore, for a [2^M, 2^N] matrix X,
    #
    #     H_{2^M 2^N} flat(X) = flat(H_{2^M} X H_{2^N})
    #
    # The idea can be generalized by breaking a Hadamard matrix into the Kronecker
    # product of many small Hadamard matrices, and reshaping the vector input into
    # a many-dimensional array, and running einsum on each dimension.
    #
    # Let the input vector be of length D, because our "small" Hadamard matrices
    # are of size at most small_n x small_n, a constant, each einsum is O(D). We
    # need to run log D einsums, thus the overall time complexity is O(D log D),
    # same as the classical divide and conquer algorithm.
    #
    # However, thanks to efficient software & hardware implementations of einsum,
    # we can often achieve far better speed than the classical algorithm on
    # accelerators, at the same time producing a far simpler XLA HLO graph.

    n = len(x)

    # Find out the shape to reshape x into.
    shape = []
    while n > 1:
        shape.append(min(n, small_n))
        n //= small_n
    shape.reverse()
    num_dims = len(shape)
    if num_dims + 1 >= 10:
        # We will run out of dimension names in einsums.
        # Note: `n` has already been consumed by the loop above, so report the
        # original input size instead.
        raise ValueError(
            f'small_n={small_n} is too small for input size {len(x)}')
    y = x.reshape(shape)

    # Hadamard matrices we will need.
    hadamards = dict((d, hadamard_matrix(d, x.dtype)) for d in set(shape))

    # einsum on each dimension.
    for i, d in enumerate(shape):
        y_dims = ''.join(str(j) for j in range(num_dims))
        h_dims = f'{i}{num_dims + 1}'
        out_dims = y_dims.replace(str(i), str(num_dims + 1), 1)
        operands = f'{y_dims},{h_dims}->{out_dims}'
        y = jnp.einsum(operands, y, hadamards[d], precision=precision)
    return y.flatten()
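`hadamard_matrix` is used above but not shown; the stand-in below (built from `scipy.linalg.hadamard`) is an assumption, and the sketch cross-checks the fast transform against an explicit matrix product.

import jax.numpy as jnp
import numpy as np
import scipy.linalg

# Assumed stand-in for the helper `hadamard_matrix`.
def hadamard_matrix(n: int, dtype) -> jnp.ndarray:
    return jnp.asarray(scipy.linalg.hadamard(n), dtype=dtype)

x = jnp.arange(16, dtype=jnp.float32)
fast = walsh_hadamard_transform(x, small_n=4)
direct = jnp.asarray(scipy.linalg.hadamard(16), dtype=jnp.float32) @ x
np.testing.assert_allclose(fast, direct, rtol=1e-5)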
Example #30
 def _split(
     self,
     x: jnp.ndarray,
 ) -> jnp.ndarray:
     return x.reshape(
         (*x.shape[:2], self.num_heads, self.model_size // self.num_heads))