def _encode_bow(self, bow: jnp.ndarray) -> jnp.ndarray:
    """Project bag-of-words vectors into a short sequence of embeddings.

    The result is meant to be consumed by the transformer as extra tokens.

    Args:
        bow: a [batch_size, bow_vocab_size] tensor; each row is one
            bag-of-words vector.

    Returns:
        A [batch_size, bow_n_tokens, bow_embedding_dim] tensor of embeddings.
    """
    n_tokens = self._bow_n_tokens
    embed_dim = self._bow_embedding_dim
    num_rows = bow.shape[0]

    # One linear map emits every token embedding at once: [B, D * n].
    projection = hk.Linear(embed_dim * n_tokens)
    flat = projection(bow.astype(jnp.float32))
    flat = transformer_block.layer_norm(jax.nn.gelu(flat))

    # Split the flat projection into n_tokens vectors of size embed_dim.
    return jnp.reshape(flat, [num_rows, n_tokens, embed_dim])
def model(design_matrix: jnp.ndarray, outcome: jnp.ndarray = None) -> None:
    """Bayesian logistic regression model for purchase prediction.

    The log-odds of a purchase are a linear combination of the covariates.
    The regression coefficients get a multivariate Normal prior with zero
    mean and identity covariance.

    :param design_matrix: Covariate matrix indexed as (observations, features).
        All categorical variables have already been one-hot encoded.
    :param outcome: Binary response: whether the customer made a purchase.
        May be ``None`` (the default), in which case ``obs`` is unobserved.
    """
    num_features = design_matrix.shape[1]
    num_observations = design_matrix.shape[0]

    coefficient_prior = dist.MultivariateNormal(
        loc=0., covariance_matrix=jnp.eye(num_features))
    beta = numpyro.sample('coefficients', coefficient_prior)

    log_odds = design_matrix.dot(beta)
    with numpyro.plate('data', num_observations):
        numpyro.sample('obs', dist.Bernoulli(logits=log_odds), obs=outcome)
def if_points_inside_any_polygon(points: np.ndarray, scene: JaxScene) -> np.ndarray:
    """Test whether each point lies inside at least one polygon of the scene.

    Args:
        points (np.ndarray): a single point (1-D) or a batch of points.
        scene (JaxScene): scene whose ``segments`` and ``polygons`` are tested.

    Returns:
        np.ndarray: one boolean per point, True when the point is inside any
        polygon.
    """
    batch = points.reshape(1, -1) if points.ndim == 1 else points

    projection_signs = batch_segment_normal_projection_sign(batch, scene.segments)
    # All polygons have exactly 3 segments, so the flat per-segment results
    # regroup as (points, polygons, 3).
    negative = (projection_signs < 0).reshape(len(batch), len(scene.polygons), 3)

    # Inside a polygon means a negative sign on every one of its segments.
    inside_each_polygon = np.all(negative, axis=-1)
    return np.any(inside_each_polygon, axis=-1)
def __call__(self, x: jnp.ndarray) -> VAEOutput:
    """Encode ``x``, sample a latent, decode it and sample an image.

    Args:
        x: input batch; cast to float32 before encoding.

    Returns:
        VAEOutput holding the approximate posterior q(z|x), the Bernoulli
        likelihood p(x|z) and an image sampled from that likelihood.
    """
    inputs = x.astype(jnp.float32)

    # q(z|x) = N(mean(x), diag(stddev(x))) from the encoder.
    encoder = Encoder(self._hidden_size, self._latent_size)
    mean, stddev = encoder(inputs)
    posterior = distrax.MultivariateNormalDiag(loc=mean, scale_diag=stddev)
    latent = posterior.sample(seed=hk.next_rng_key())

    # p(x|z) = prod of Bernoulli(logits(z)); every output dim is an event dim.
    decoder = Decoder(self._hidden_size, self._output_shape)
    logits = decoder(latent)
    event_ndims = len(self._output_shape)
    likelihood = distrax.Independent(
        distrax.Bernoulli(logits=logits),
        reinterpreted_batch_ndims=event_ndims)

    # Draw an image from the likelihood.
    sampled_image = likelihood.sample(seed=hk.next_rng_key())
    return VAEOutput(posterior, likelihood, sampled_image)
def symmetrize(
    data: jnp.ndarray,
    row: jnp.ndarray,
    col: jnp.ndarray,
    ncols: Optional[int] = None,
):
    """Return the stored values of ``(A + A.T) / 2`` in ``A``'s entry order.

    Assumes ``A`` has a symmetric sparsity pattern, so ``A.T`` has non-zeros
    at exactly the same positions as ``A``.

    Args:
        data: values of ``A``.
        row: row indices of ``A``.
        col: column indices of ``A``.
        ncols: number of columns of ``A``.

    Returns:
        The symmetrized values, same shape as ``data``.
    """
    # Calling reorder_perm with row/col swapped presumably yields the
    # permutation that lines A.T's entries up with A's entry order.
    transpose_perm = reorder_perm(row=col, col=row, ncols=ncols)
    transposed_values = data.take(transpose_perm)
    return (data + transposed_values) / 2
def logpdf(x: jnp.ndarray, kappa: float, mu: jnp.ndarray) -> jnp.ndarray:
    """Log-density of the power spherical distribution.

    log p(x) = kappa * log(1 + mu . x) - log C, where
    log C = (alpha + beta) log 2 + beta log pi + lgamma(alpha)
            - lgamma(alpha + beta),
    alpha = (d - 1) / 2 + kappa and beta = (d - 1) / 2.

    Args:
        x: The set of points at which to evaluate the density.
        kappa: Concentration parameter.
        mu: Mean direction on the sphere. Its size ``d`` determines the
            dimensionality of the sphere.

    Returns:
        The log-density at ``x`` for the given concentration and mean.
    """
    dim = mu.size
    beta = (dim - 1.) / 2.
    alpha = beta + kappa
    log_normalizer = ((alpha + beta) * jnp.log(2.)
                      + beta * jnp.log(jnp.pi)
                      + jspsp.gammaln(alpha)
                      - jspsp.gammaln(alpha + beta))
    log_unnormalized = kappa * jnp.log(1. + x.dot(mu))
    return log_unnormalized - log_normalizer
def learning_schedule(
    global_step: jnp.ndarray,
    base_learning_rate: float,
    total_steps: int,
    warmup_steps: int,
    use_schedule: bool,
) -> float:
    """Linear warmup followed by a cosine learning-rate decay.

    Args:
        global_step: current training step.
        base_learning_rate: learning rate reached at the end of warmup.
        total_steps: total number of steps, warmup included.
        warmup_steps: number of linear-warmup steps; ``<= 0`` means no warmup.
        use_schedule: when False, ``base_learning_rate`` is returned as is.

    Returns:
        The learning rate for ``global_step``.
    """
    if not use_schedule:
        return base_learning_rate

    if warmup_steps > 0:
        warmup_fraction = global_step.astype(jnp.float32) / int(warmup_steps)
        warmup_lr = warmup_fraction * base_learning_rate
    else:
        warmup_lr = base_learning_rate

    # Cosine schedule over the steps remaining after warmup.
    decayed_lr = _cosine_decay(global_step - warmup_steps,
                               total_steps - warmup_steps,
                               base_learning_rate)

    still_warming_up = global_step < warmup_steps
    return jnp.where(still_warming_up, warmup_lr, decayed_lr)
def squeeze(x: np.ndarray, axis: Union[None, int, Tuple[int, ...]]) -> np.ndarray:
    """`np.squeeze` analog that can also drop 0-sized axes.

    `np.squeeze` only removes axes of size 1. This function also removes axes
    of size 0. If the remaining shape has a non-zero size after dropping an
    empty axis, the result is a zero-filled array of `x.dtype`.

    Args:
        x: input array.
        axis: axis or axes to remove; negative values count from the end.
            Each must have size 0 or 1. `None` behaves like `np.squeeze(x)`,
            removing every size-1 axis.

    Returns:
        `x` with the requested axes removed.
    """
    if axis is None:
        return np.squeeze(x)
    if isinstance(axis, int):
        axis = (axis,)

    ndim = x.ndim
    # Normalize negative axes so that ascending order and shifting are valid.
    axes = sorted(a + ndim if a < 0 else a for a in axis)

    non_zero_axes = ()
    removed = 0  # 0-sized axes dropped so far; later indices shift left.
    for a in axes:
        current = a - removed
        if x.shape[current] == 0:
            new_shape = x.shape[:current] + x.shape[current + 1:]
            if 0 in new_shape:
                # Still empty: a plain reshape works.
                x = x.reshape(new_shape)
            else:
                # The array had no elements but the new shape does:
                # materialize it as zeros.
                x = np.zeros(new_shape, x.dtype)
            removed += 1
        else:
            non_zero_axes += (current,)
    return np.squeeze(x, non_zero_axes)
def apply_lse_kernel(self,
                     f: jnp.ndarray,
                     g: jnp.ndarray,
                     eps: float,
                     vec: Optional[jnp.ndarray] = None,
                     axis: int = 0):
    """Apply the grid kernel in log-space, one grid dimension at a time.

    See notes in the parent class for the use case. Vector inputs are viewed
    as grids of shape ``self.grid_size``. A kernel is applied to each slice,
    and the outputs are flattened back to vectors. More implementation
    details: https://arxiv.org/pdf/1708.01955.pdf

    Args:
        f: jnp.ndarray, a vector of potentials.
        g: jnp.ndarray, a vector of potentials.
        eps: float, regularization strength.
        vec: jnp.ndarray, if needed, a vector onto which to apply the kernel
            weighted by f and g.
        axis: axis (0 or 1) along which the summation is carried out.

    Returns:
        A pair ``(g, vec)`` of flattened arrays: the result of the kernel
        applied in lse space, and ``vec`` (``jnp.array(1.0)`` flattened when
        no ``vec`` was given).
    """
    grid_shape = self.grid_size
    f = jnp.reshape(f, grid_shape)
    g = jnp.reshape(g, grid_shape)
    if vec is not None:
        vec = jnp.reshape(vec, grid_shape)

    # For axis 0 the roles of the two potentials are swapped.
    if axis == 0:
        f, g = g, f

    for dim in range(self.grid_dimension):
        g, vec = self._apply_lse_kernel_one_dimension(dim, f, g, eps, vec)
        # Subtract f after each pass, treating non-finite entries of f as 0.
        g -= jnp.where(jnp.isfinite(f), f, 0)

    out_vec = jnp.array(1.0) if vec is None else vec
    return g.ravel(), out_vec.ravel()
def mean_squared_logarithmic_error(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
    """
    Computes the mean squared logarithmic error between labels and predictions.

    ```python
    loss = mean(square(log(y_true + 1) - log(y_pred + 1)), axis=-1)
    ```

    Both inputs are first clamped from below by `types.EPSILON`.

    Usage:

    ```python
    rng = jax.random.PRNGKey(42)

    y_true = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
    y_pred = jax.random.uniform(rng, shape=(2, 3))

    loss = elegy.losses.mean_squared_logarithmic_error(y_true, y_pred)

    assert loss.shape == (2,)

    first_log = jnp.log(jnp.maximum(y_true, types.EPSILON) + 1.0)
    second_log = jnp.log(jnp.maximum(y_pred, types.EPSILON) + 1.0)
    assert jnp.array_equal(loss, jnp.mean(jnp.square(first_log - second_log), axis=-1))
    ```

    Arguments:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean squared logarithmic error values. shape = `[batch_size, d0, .. dN-1]`.
    """
    y_true = y_true.astype(y_pred.dtype)

    def _clamped_log1p(values: jnp.ndarray) -> jnp.ndarray:
        # Clamp from below by EPSILON before taking log(value + 1).
        return jnp.log(jnp.maximum(values, types.EPSILON) + 1.0)

    log_gap = _clamped_log1p(y_true) - _clamped_log1p(y_pred)
    return jnp.mean(jnp.square(log_gap), axis=-1)
def mean_squared_error(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
    """
    Computes the mean squared error between labels and predictions.

    The squared difference is averaged over the last dimension:

    ```python
    loss = mean(square(y_true - y_pred), axis=-1)
    ```

    `y_true` is cast to `y_pred`'s dtype first.

    Usage:

    ```python
    rng = jax.random.PRNGKey(42)

    y_true = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
    y_pred = jax.random.uniform(rng, shape=(2, 3))

    loss = elegy.losses.mean_squared_error(y_true, y_pred)

    assert loss.shape == (2,)

    assert jnp.array_equal(loss, jnp.mean(jnp.square(y_true - y_pred), axis=-1))
    ```

    Arguments:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean squared error values. shape = `[batch_size, d0, .. dN-1]`.
    """
    residual = y_pred - y_true.astype(y_pred.dtype)
    return jnp.square(residual).mean(axis=-1)
def bgrl_loss(
    first_online_predictions: jnp.ndarray,
    second_target_projections: jnp.ndarray,
    second_online_predictions: jnp.ndarray,
    first_target_projections: jnp.ndarray,
    symmetrize: bool,
    valid_mask: jnp.ndarray,
) -> Tuple[jnp.ndarray, LogsDict]:
    """Implements the BGRL loss.

    The loss is the squared distance between L2-normalized online predictions
    and target projections of the other view. With ``symmetrize`` it also
    includes the mirrored pair. It is averaged over the entries selected by
    ``valid_mask``.

    Returns:
        ``(loss, {'bgrl_loss': loss})``.
    """
    def _side_loss(online: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
        # Squared L2 distance between normalized vectors, summed over the
        # last axis.
        diff = (_l2_normalize(online, axis=-1)
                - _l2_normalize(target, axis=-1))
        return jnp.sum(jnp.square(diff), axis=-1)

    node_loss = _side_loss(first_online_predictions, second_target_projections)
    if symmetrize:
        node_loss = node_loss + _side_loss(second_online_predictions,
                                           first_target_projections)

    # Masked mean; the 1e-6 keeps the division finite for an all-zero mask.
    loss = (node_loss * valid_mask).sum() / (valid_mask.sum() + 1e-6)
    return loss, dict(bgrl_loss=loss)
def plot_final_bounds(x: np.ndarray,
                      y: np.ndarray,
                      xstar: np.ndarray,
                      bounds: np.ndarray,
                      data_xstar: np.ndarray,
                      data_ystar: np.ndarray,
                      coeff_2sls: np.ndarray = None,
                      x_kiv: np.ndarray = None,
                      y_kiv: np.ndarray = None) -> plt.Figure:
    """Plot observed data together with lower/upper bounds on the effect.

    Args:
        x, y: observed data, drawn as a scatter plot. They also determine the
            axis limits.
        xstar: points at which the bounds are evaluated.
        bounds: column 0 is the lower bound and column 1 the upper bound at
            each point of ``xstar``.
        data_xstar, data_ystar: optional reference curve E[Y | do(x*)]. Skipped
            when either is None. A multi-dimensional ``data_ystar`` is averaged
            over its first axis.
        coeff_2sls: optional (intercept, slope) of a 2SLS fit, drawn as a line.
        x_kiv, y_kiv: optional KIV estimate, drawn when both are given.

    Returns:
        The created matplotlib figure.
    """
    def padded_limits(values):
        # Widen [min, max] by 1/15 of its span on each side.
        lo, hi = np.min(values), np.max(values)
        margin = (hi - lo) / 15.
        return lo - margin, hi + margin

    fig = plt.figure()
    plt.scatter(x, y, **data_kwargs)
    plt.plot(xstar, bounds[:, 0], 'g--x', label="lower", lw=2, markersize=10)
    plt.plot(xstar, bounds[:, 1], 'r--x', label="upper", lw=2, markersize=10)

    if data_xstar is not None and data_ystar is not None:
        reference = data_ystar.mean(0) if data_ystar.ndim > 1 else data_ystar
        plt.plot(data_xstar, reference, label=f"$E[Y | do(x^*)]$", lw=2)

    if coeff_2sls is not None:
        grid = np.linspace(np.min(x), np.max(x), 10)
        intercept, slope = coeff_2sls[0], coeff_2sls[1]
        plt.plot(grid, intercept + slope * grid,
                 ls='dotted', c="tab:purple", lw=2, label="2sls")

    if x_kiv is not None and y_kiv is not None:
        plt.plot(x_kiv, y_kiv, ls='dashdot', c="tab:olive", lw=2, label="KIV")

    plt.xlim(padded_limits(x))
    plt.ylim(padded_limits(y))
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title("Lower and upper bound on actual effect")
    plt.legend()
    return fig
def recall(
    y_true: jnp.ndarray,
    y_pred: jnp.ndarray,
    threshold: jnp.ndarray,
    class_id: jnp.ndarray,
    sample_weight: jnp.ndarray,
    true_positives: ReduceConfusionMatrix,
    false_negatives: ReduceConfusionMatrix,
) -> jnp.ndarray:
    """Recall, TP / (TP + FN), of predictions binarized at ``threshold``.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted scores; values above ``threshold`` count as positive.
        threshold: decision threshold.
        class_id: currently unused (see TODO below).
        sample_weight: forwarded to both confusion-matrix reducers.
        true_positives: reducer returning the true-positive count.
        false_negatives: reducer returning the false-negative count.

    Returns:
        The recall, with NaN (e.g. from 0 / 0) replaced by ``jnp.nan_to_num``.
    """
    # TODO: class_id behavior
    binarized = (y_pred > threshold).astype(jnp.float32)
    if binarized.dtype != y_true.dtype:
        binarized = binarized.astype(y_true.dtype)

    reducer_kwargs = dict(y_true=y_true, y_pred=binarized,
                          sample_weight=sample_weight)
    tp = true_positives(**reducer_kwargs)
    fn = false_negatives(**reducer_kwargs)
    return jnp.nan_to_num(jnp.divide(tp, tp + fn))
def get_masked_array(x: np.ndarray,
                     mask_constant: float = None) -> MaskedArray:
    """Return `x` with entries equal to `mask_constant` zeroed-out, and the mask.

    The mask returned is a boolean array with masked indices set to `True`.

    Args:
        x: `np.ndarray` to mask, or `None`. If `x` is a `MaskedArray`, it is
            treated as `(masked_x, mask)` and its mask is reused.
        mask_constant: an optional `float`, the value in inputs to be
            considered as masked (e.g. padding in a batch of sentences). `None`
            means no masking. Can also be `np.nan` (matched with `isnan`),
            `np.inf`, etc.

    Returns:
        A `MaskedArray` of `(masked_x, boolean_mask)`.

    Raises:
        TypeError: if `x` is not `None`, a `MaskedArray` or an `np.ndarray`.
    """
    if x is None:
        mask_mat = None
    elif isinstance(x, MaskedArray):
        # Already masked: unpack and reuse the existing mask.
        x, mask_mat, _, _ = x.astuple()
    elif not isinstance(x, np.ndarray):
        raise TypeError(x, type(x))
    elif mask_constant is None:
        mask_mat = None
    else:
        # NaN never compares equal to itself, so a NaN constant needs isnan.
        mask_mat = lax.cond(np.isnan(mask_constant),
                            np.isnan,
                            lambda arr: arr == mask_constant,
                            x)

    x = mask(x, mask_mat)
    return MaskedArray(x, mask_mat)  # pytype: disable=wrong-arg-count
def generate_positions(rng: rjax.PRNGKey, genpcls: dict, pos0: np.ndarray,
                       pcl: GenParticle = None):
    """Assign production positions to every particle of a decay tree.

    Traverses the decay tree recursively, starting at ``genpcls['root']``
    when ``pcl`` is None. Each visited particle's position is stored under
    ``genpcls[name]['pos']``. A child whose lifetime lies in (0.0001, 1) is
    displaced from ``pos0`` by its velocity times an exponentially sampled
    decay time. Any other child starts at a copy of ``pos0``.

    Args:
        rng: PRNG key. It is split so that every child subtree gets its own
            key.
        genpcls: maps particle name to a dict with 'pcl' (particle
            properties), 'mom' (momentum) and, for the root, 'gpcl'.
            Updated in place with 'pos'.
        pos0: position assigned to ``pcl``.
        pcl: current node of the decay tree; None means the root.
    """
    if pcl is None:
        genpcls['root']['pos'] = Position.from_ndarray(pos0)
        pcl = genpcls['root']['gpcl']
    else:
        genpcls[pcl.name]['pos'] = Position.from_ndarray(pos0)

    for ch in pcl.children:
        particle = genpcls[ch.name]['pcl']
        if 0.0001 < particle.lifetime < 1:
            mom = genpcls[ch.name]['mom']
            nevt = mom.size
            rng, key = rjax.split(rng)
            time = particle.lifetime * rjax.exponential(key, (nevt, 1))
            # TODO: add gamma factor multiplier here (relativistic correction)
            chpos = pos0 + mom.velocity(particle.mass) * time
        else:
            chpos = pos0.copy()
        # Hand the subtree a fresh key. Passing `rng` itself would let the
        # recursive call derive the same keys this loop uses for the next
        # sibling, correlating their sampled decay times.
        rng, child_rng = rjax.split(rng)
        generate_positions(child_rng, genpcls, chpos, ch)
def _mean_with_mask(array: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
    """Average ``array`` over the rows selected by ``mask``.

    The divisor is ``mask.sum(0)``. There is no guard for an empty selection,
    in which case the division is by zero.
    """
    masked_total = sum_with_mask(array, mask)
    return masked_total / mask.sum(0)
def int_dequantize_jit(x: jnp.ndarray, scale: jnp.ndarray, offset: jnp.ndarray,
                       max_int: int, to_type: str):
    """Convert integer codes back to ``to_type`` values.

    Computes ``x * scale / max_int + offset``, with every operand cast to
    ``to_type`` first.
    """
    values = x.astype(to_type)
    step = scale.astype(to_type)
    shift = offset.astype(to_type)
    return values * step / max_int + shift
def apply(self,
          x: jnp.ndarray,
          num_outputs: int,
          pyramid_alpha: int = 200,
          pyramid_depth: int = 272,
          train: bool = True,
          true_gradient: bool = False) -> jnp.ndarray:
    """Implements the forward pass in the module.

    Args:
        x: Input to the module. Should have shape [batch_size, dim, dim, 3]
            where dim is the resolution of the image.
        num_outputs: Dimension of the output of the model (ie number of
            classes for a classification problem).
        pyramid_alpha: See paper.
        pyramid_depth: See paper. ``pyramid_depth - 2`` must be divisible by 9.
        train: If False, will use the moving average for batch norm
            statistics. Else, will use statistics computed on the batch.
        true_gradient: If true, the same mixing parameter will be used for
            the forward and backward pass (see paper for more details).

    Returns:
        The output of the PyramidNet model, a tensor of shape
        [batch_size, num_classes].
    """
    assert (pyramid_depth - 2) % 9 == 0, (
        f'pyramid_depth - 2 must be divisible by 9, got {pyramid_depth}')

    # Shake-drop hyper-params
    mask_prob = 0.5
    alpha_min, alpha_max = (-1.0, 1.0)
    beta_min, beta_max = (0.0, 1.0)

    # Bottleneck network size
    blocks_per_group = (pyramid_depth - 2) // 9
    # See Eqn 2 in https://arxiv.org/abs/1610.02915
    num_channels = 16
    # N in https://arxiv.org/abs/1610.02915
    total_blocks = blocks_per_group * 3
    delta_channels = pyramid_alpha / total_blocks

    x = nn.Conv(
        x,
        16, (3, 3),
        padding='SAME',
        name='init_conv',
        bias=False,
        kernel_init=utils.conv_kernel_init_fn)
    x = utils.activation(x, apply_relu=False, train=train, name='init_bn')

    # Three groups of bottleneck blocks, built in the same order as before.
    # The first block of the second and third groups downsamples with
    # stride 2; every other block uses stride 1.
    layer_num = 1
    for group_i in range(3):
        for block_i in range(blocks_per_group):
            # The channel count grows additively by delta_channels per block.
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            downsample = group_i > 0 and block_i == 0
            x = BottleneckShakeDrop(
                x,
                int(round(num_channels)),
                (2, 2) if downsample else (1, 1),
                layer_mask_prob,
                alpha_min,
                alpha_max,
                beta_min,
                beta_max,
                train=train,
                true_gradient=true_gradient)
            layer_num += 1

    assert layer_num - 1 == total_blocks
    x = utils.activation(x, train=train, name='final_bn')
    x = nn.avg_pool(x, (8, 8))
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
    return x
def atom37_to_frames(
    aatype: jnp.ndarray,  # (...)
    all_atom_positions: jnp.ndarray,  # (..., 37, 3)
    all_atom_mask: jnp.ndarray,  # (..., 37)
) -> Dict[str, jnp.ndarray]:
    """Computes the frames for the up to 8 rigid groups for each residue.

    The rigid groups are defined by the possible torsions in a given amino acid.
    We group the atoms according to their dependence on the torsion angles into
    "rigid groups". E.g., the position of atoms in the chi2-group depend on
    chi1 and chi2, but do not depend on chi3 or chi4.
    Jumper et al. (2021) Suppl. Table 2 and corresponding text.

    Args:
        aatype: Amino acid type, given as array with integers.
        all_atom_positions: atom37 representation of all atom coordinates.
        all_atom_mask: atom37 representation of mask on all atom coordinates.

    Returns:
        Dictionary containing:
          * 'rigidgroups_gt_frames': 8 Frames corresponding to
            'all_atom_positions' represented as flat 12 dimensional array.
          * 'rigidgroups_gt_exists': Mask denoting whether the atom positions
            for the given frame are available in the ground truth, e.g. if they
            were resolved in the experiment.
          * 'rigidgroups_group_exists': Mask denoting whether given group is in
            principle present for given amino acid type.
          * 'rigidgroups_group_is_ambiguous': Mask denoting whether frame is
            affected by naming ambiguity.
          * 'rigidgroups_alt_gt_frames': 8 Frames with alternative atom renaming
            corresponding to 'all_atom_positions' represented as flat
            12 dimensional array.
    """
    # 0: 'backbone group',
    # 1: 'pre-omega-group', (empty)
    # 2: 'phi-group', (currently empty, because it defines only hydrogens)
    # 3: 'psi-group',
    # 4,5,6,7: 'chi1,2,3,4-group'
    aatype_in_shape = aatype.shape

    # If there is a batch axis, just flatten it away, and reshape everything
    # back at the end of the function.
    aatype = jnp.reshape(aatype, [-1])
    all_atom_positions = jnp.reshape(all_atom_positions, [-1, 37, 3])
    all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37])

    # Create an array with the atom names.
    # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
    restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object)

    # 0: backbone frame
    restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']

    # 3: 'psi-group'
    restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']

    # 4,5,6,7: 'chi1,2,3,4-group'
    for restype, restype_letter in enumerate(residue_constants.restypes):
        resname = residue_constants.restype_1to3[restype_letter]
        for chi_idx in range(4):
            if residue_constants.chi_angles_mask[restype][chi_idx]:
                atom_names = residue_constants.chi_angles_atoms[resname][chi_idx]
                restype_rigidgroup_base_atom_names[
                    restype, chi_idx + 4, :] = atom_names[1:]

    # Create mask for existing rigid groups.
    restype_rigidgroup_mask = np.zeros([21, 8], dtype=np.float32)
    restype_rigidgroup_mask[:, 0] = 1
    restype_rigidgroup_mask[:, 3] = 1
    restype_rigidgroup_mask[:20, 4:] = residue_constants.chi_angles_mask

    # Translate atom names into atom37 indices ('' padding maps to index 0).
    lookuptable = residue_constants.atom_order.copy()
    lookuptable[''] = 0
    restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])(
        restype_rigidgroup_base_atom_names)

    # Compute the gather indices for all residues in the chain.
    # shape (N, 8, 3)
    residx_rigidgroup_base_atom37_idx = utils.batched_gather(
        restype_rigidgroup_base_atom37_idx, aatype)

    # Gather the base atom positions for each rigid group.
    base_atom_pos = utils.batched_gather(
        all_atom_positions,
        residx_rigidgroup_base_atom37_idx,
        batch_dims=1)

    # Compute the Rigids.
    gt_frames = r3.rigids_from_3_points(
        point_on_neg_x_axis=r3.vecs_from_tensor(base_atom_pos[:, :, 0, :]),
        origin=r3.vecs_from_tensor(base_atom_pos[:, :, 1, :]),
        point_on_xy_plane=r3.vecs_from_tensor(base_atom_pos[:, :, 2, :]))

    # Compute a mask whether the group exists.
    # (N, 8)
    group_exists = utils.batched_gather(restype_rigidgroup_mask, aatype)

    # Compute a mask whether ground truth exists for the group
    gt_atoms_exist = utils.batched_gather(  # shape (N, 8, 3)
        all_atom_mask.astype(jnp.float32),
        residx_rigidgroup_base_atom37_idx,
        batch_dims=1)
    gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists  # (N, 8)

    # Adapt backbone frame to old convention (mirror x-axis and z-axis).
    rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
    rots[0, 0, 0] = -1
    rots[0, 2, 2] = -1
    gt_frames = r3.rigids_mul_rots(gt_frames, r3.rots_from_tensor3x3(rots))

    # The frames for ambiguous rigid groups are just rotated by 180 degree around
    # the x-axis. The ambiguous group is always the last chi-group.
    restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32)
    restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1])

    for resname in residue_constants.residue_atom_renaming_swaps:
        restype = residue_constants.restype_order[
            residue_constants.restype_3to1[resname]]
        chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1

    # Gather the ambiguity information for each residue.
    residx_rigidgroup_is_ambiguous = utils.batched_gather(
        restype_rigidgroup_is_ambiguous, aatype)
    residx_rigidgroup_ambiguity_rot = utils.batched_gather(
        restype_rigidgroup_rots, aatype)

    # Create the alternative ground truth frames.
    alt_gt_frames = r3.rigids_mul_rots(
        gt_frames, r3.rots_from_tensor3x3(residx_rigidgroup_ambiguity_rot))

    gt_frames_flat12 = r3.rigids_to_tensor_flat12(gt_frames)
    alt_gt_frames_flat12 = r3.rigids_to_tensor_flat12(alt_gt_frames)

    # Reshape back to the original residue layout (each output exactly once).
    gt_frames_flat12 = jnp.reshape(gt_frames_flat12, aatype_in_shape + (8, 12))
    gt_exists = jnp.reshape(gt_exists, aatype_in_shape + (8,))
    group_exists = jnp.reshape(group_exists, aatype_in_shape + (8,))
    residx_rigidgroup_is_ambiguous = jnp.reshape(
        residx_rigidgroup_is_ambiguous, aatype_in_shape + (8,))
    alt_gt_frames_flat12 = jnp.reshape(alt_gt_frames_flat12,
                                       aatype_in_shape + (8, 12))

    return {
        'rigidgroups_gt_frames': gt_frames_flat12,  # (..., 8, 12)
        'rigidgroups_gt_exists': gt_exists,  # (..., 8)
        'rigidgroups_group_exists': group_exists,  # (..., 8)
        'rigidgroups_group_is_ambiguous':
            residx_rigidgroup_is_ambiguous,  # (..., 8)
        'rigidgroups_alt_gt_frames': alt_gt_frames_flat12,  # (..., 8, 12)
    }
def loss(params: optax.Params, batch: jnp.ndarray,
         labels: jnp.ndarray) -> "tuple[jnp.ndarray, jnp.ndarray]":
    """Mean sigmoid cross-entropy loss and accuracy for one batch.

    Args:
        params: network parameters passed to ``net.apply``.
        batch: input batch for ``net``.
        labels: targets with the same shape as the logits; the class is taken
            as their arg-max along the last axis.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the per-example binary
        cross-entropy summed over the last axis, then averaged. ``accuracy``
        is the fraction of examples whose arg-max logit matches the arg-max
        label.
    """
    logits = net.apply(params, batch)
    per_example_loss = optax.sigmoid_binary_cross_entropy(
        logits, labels).sum(axis=-1)
    correct = logits.argmax(axis=-1) == labels.argmax(axis=-1)
    # Divide by the actual number of examples rather than a global
    # `batch_size`, so a smaller final batch still yields a value in [0, 1].
    accuracy = correct.mean(axis=-1)
    return per_example_loss.mean(), accuracy
def _vmap_2d(fn: Callable[[float, float, float], float],
             cov12: np.ndarray,
             var1: np.ndarray,
             var2: Optional[np.ndarray],
             diagonal_batch: bool,
             diagonal_spatial: bool) -> np.ndarray:
    """Effectively a "2D vmap" of `fn(cov12, var1, var2)`.

    Applicable for all possible kernel layouts.

    Args:
        fn: scalar-valued, elementwise `fn(cov12, var1, var2)` function to apply.
        cov12: covariance tensor (`q12`), `nngp`/`ntk`/`cov1`/`cov2`, of shape
            `(N1[, N2])`, `(N1[, N2], X, Y, ...)`, `(N1[, N2], X, X, Y, Y, ...)`
            depending on `diagonal_batch`, `diagonal_spatial`, and the number of
            spatial dimensions.
        var1: variance tensor (`q11`), has shape `(N1[, X, Y, ...])`.
        var2: variance tensor (`q22`), has shape `(N1[, X, Y, ...])`. `None`
            means reuse `var1`.
        diagonal_batch: `True` if `cov12` has only one batch dimension.
        diagonal_spatial: `True` if `cov12` has spatial dimensions appearing once
            (vs twice).

    Returns:
        Resulting array `[fn(cov12[i, j], var1[i], var2[j])]_{i j}`. Has the same
        shape as `cov12`.
    """
    n_batch_axes = 1 if diagonal_batch else 2
    first = 2 - n_batch_axes
    cov_stop = n_batch_axes if diagonal_spatial else cov12.ndim
    cov12_2d = utils.make_2d(cov12, first, cov_stop)

    # Computed from the original var1, before it is reshaped below.
    var_stop = 1 if diagonal_spatial else var1.ndim

    def _merge_axes(v):
        # Collapse axes [first, var_stop) of `v` into a single axis.
        return v.reshape(v.shape[:first] + (-1,) + v.shape[var_stop:])

    var1 = _merge_axes(var1)
    var2 = var1 if var2 is None else _merge_axes(var2)

    # Inner vmap maps cov12 and var2 together (var1 fixed); the outer one maps
    # cov12 and var1 (var2 fixed), giving all (i, j) pairs.
    inner = vmap(np.vectorize(fn), in_axes=(first, None, first), out_axes=first)
    pairwise = vmap(inner, in_axes=(first, first, None), out_axes=first)
    result = pairwise(cov12_2d, var1, var2)  # type: np.ndarray

    spatial = cov12.shape[first:cov_stop]
    result_shape = (cov12.shape[:first]
                    + spatial[::2]
                    + spatial[1::2]
                    + cov12.shape[cov_stop:])
    result = result.reshape(result_shape)
    return utils.zip_axes(result, first, cov_stop)
def _update_fn(g: jnp.ndarray, t: jnp.ndarray, m: jnp.ndarray) -> jnp.ndarray:
    """Blend ``g`` and ``t`` with weights ``1 - m`` and ``m``.

    ``m`` is cast to ``g``'s dtype first. With a 0/1 mask this takes ``t``
    where ``m`` is 1 and keeps ``g`` elsewhere.
    """
    weight = m.astype(g.dtype)
    keep = 1. - weight
    return g * keep + t * weight
def poly_from_roots(roots: jnp.ndarray, x: jnp.ndarray):
    """Evaluate the monic polynomial with the given `roots` elementwise on `x`.

    Computes ``prod_i (x - roots[i])`` for every element of `x`.

    Args:
        roots: 1-D array-like of roots. An empty array gives the constant 1.
        x: array-like of evaluation points, of any shape (scalars included).

    Returns:
        Array with the same shape as `x`.

    Raises:
        ValueError: if `roots` is not 1-D.
    """
    roots = jnp.asarray(roots)
    x = jnp.asarray(x)
    if roots.ndim != 1:
        raise ValueError(f'roots must be 1-D, got shape {roots.shape}')
    # Append one singleton axis per axis of x so roots broadcast against x.
    roots = roots.reshape((-1,) + (1,) * x.ndim)
    return jnp.prod(x - roots, axis=0)
def __call__(self, x: np.ndarray) -> np.ndarray:
    """Flatten each example, then apply Dense -> ReLU -> Dense.

    The hidden layer has ``self.num_features`` units and the output layer
    ``self.num_outputs``.
    """
    flat = x.reshape(x.shape[0], -1)
    hidden = nn.relu(nn.Dense(features=self.num_features)(flat))
    return nn.Dense(features=self.num_outputs)(hidden)
def _rbg_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
    """Fold ``data`` into a 4-word RBG key.

    The key is viewed as two 2-word halves, and ``_threefry_fold_in`` is
    mapped over them with the same ``data``.
    """
    halves = key.reshape(2, 2)
    fold_each = vmap(_threefry_fold_in, in_axes=(0, None), out_axes=0)
    return fold_each(halves, data).reshape(4)
def _rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
    """Split a 4-word RBG key into ``num`` new 4-word keys.

    The key is viewed as two 2-word halves, and ``_threefry_split`` is mapped
    over them. ``out_axes=1`` places the half index second, so each of the
    ``num`` rows flattens into one 4-word key. This assumes
    ``_threefry_split`` returns one row per split — confirm.
    """
    halves = key.reshape(2, 2)
    split_each = vmap(_threefry_split, in_axes=(0, None), out_axes=1)
    return split_each(halves, num).reshape(num, 4)
def __call__(
    self,
    x: jnp.ndarray,
    scale: Optional[jnp.ndarray] = None,
    offset: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    """Returns normalized inputs.

    Args:
        x: An n-D tensor of the ``data_format`` specified in the constructor
            on which the transformation is performed.
        scale: A tensor up to n-D. The shape of this tensor must be
            broadcastable to the shape of ``x``. This is the scale applied to
            the normalized x. This cannot be passed in if the module was
            constructed with ``create_scale=True``.
        offset: A tensor up to n-D. The shape of this tensor must be
            broadcastable to the shape of ``x``. This is the offset applied to
            the normalized ``x``. This cannot be passed in if the module was
            constructed with ``create_offset=True``.

    Returns:
        An n-d tensor of the same shape as x that has been normalized.

    Raises:
        ValueError: if the input rank differs from the first call, if
            ``scale``/``offset`` are passed while the module creates its own,
            or if the channel count is not divisible by ``self.groups``.
    """
    # --- Argument validation -------------------------------------------
    if self.rank is not None and x.ndim != self.rank:
        raise ValueError(
            "The rank of the inputs cannot change between calls, the"
            f" original call was rank={self.rank} but this call was "
            f"rank={x.ndim}.")
    if self.create_scale and scale is not None:
        raise ValueError(
            "Cannot pass `scale` at call time if `create_scale=True`.")
    if self.create_offset and offset is not None:
        raise ValueError(
            "Cannot pass `offset` at call time if `create_offset=True`.")

    channels = x.shape[self.channel_index]
    if channels % self.groups != 0:
        raise ValueError(
            "The number of channels must be divisible by the number of groups, "
            f"was channels={channels}, groups={self.groups}")

    # The first call fixes the module's rank and shapes.
    if self.rank is None:
        self._initialize(x, channels)

    # --- Learned affine parameters --------------------------------------
    if self.channel_index == -1:
        params_shape = (x.shape[-1],)
    else:
        assert self.channel_index == 1
        params_shape = (x.shape[1],) + (1,) * (self.rank - 2)

    param_dtype = x.dtype
    if self.create_scale:
        scale = hk.get_parameter("scale", params_shape, param_dtype,
                                 self.scale_init)
    if self.create_offset:
        offset = hk.get_parameter("offset", params_shape, param_dtype,
                                  self.offset_init)

    # --- Normalization over self.axis of the grouped view ----------------
    grouped = x.reshape(self.group_shape)
    group_mean = jnp.mean(grouped, self.axis, keepdims=True)
    # TODO(tycai): Consider faster but less precise variance formulation.
    group_var = jnp.var(grouped, self.axis, keepdims=True)
    normalized = (grouped - group_mean) * jax.lax.rsqrt(group_var + self.eps)
    out = normalized.reshape(self.first_input_shape)

    if scale is not None:
        out = out * jax.lax.broadcast_to_rank(scale, out.ndim)
    if offset is not None:
        out = out + jax.lax.broadcast_to_rank(offset, out.ndim)
    return out
def walsh_hadamard_transform(
    x: jnp.ndarray,
    small_n: int = 2**7,
    precision: Union[jax.lax.Precision, str] = 'highest') -> jnp.ndarray:
    """Efficient Walsh-Hadamard transform in JAX.

    An accelerator friendly O(n log n) Walsh-Hadamard transform.

    Args:
        x: A vector. len(x) must be a power of 2.
        small_n: Size to break x into. The default value is tuned on TPUv3.
            Must be a power of 2 and > 1.
        precision: Precision for general dot products.

    Returns:
        Transformed vector.

    Raises:
        ValueError: if ``small_n`` or ``len(x)`` is not a power of 2, if
            ``small_n <= 1``, or if ``small_n`` is so small that more than 8
            einsum dimensions would be needed.
    """
    if small_n <= 1:
        raise ValueError(f'small_n must be > 1, got {small_n}')
    if small_n & (small_n - 1):
        raise ValueError(f'small_n must be a power of 2, got {small_n}')
    size = len(x)
    if size < 1 or size & (size - 1):
        raise ValueError(f'len(x) must be a power of 2, got {size}')

    # Let
    # - A ⊗ B be the Kronecker product of A and B;
    # - flat(X) be the vector obtained by flattening the rows of X of shape
    #   [M, N].
    #
    # We can show the following:
    #
    #   (A ⊗ B^T) flat(X) = flat(A X B)
    #
    # Note that the Hadamard matrix H_{2^M 2^N} = H_{2^M} ⊗ H_{2^N}, and
    # Hadamard matrices are symmetrical. Therefore, for a [2^M, 2^N] matrix X,
    #
    #   H_{2^M 2^N} flat(X) = flat(H_{2^M} X H_{2^N})
    #
    # The idea can be generalized by breaking a Hadamard matrix into the
    # Kronecker product of many small Hadamard matrices, and reshaping the
    # vector input into a many-dimensional array, and running einsum on each
    # dimension.
    #
    # Let the input vector be of length D, because our "small" Hadamard matrices
    # are of size at most small_n x small_n, a constant, each einsum is O(D). We
    # need to run log D einsums, thus the overall time complexity is O(D log D),
    # same as the classical divide and conquer algorithm.
    #
    # However, thanks to efficient software & hardware implementations of
    # einsum, we can often achieve far better speed than the classical algorithm
    # on accelerators, at the same time producing a far simpler XLA HLO graph.

    # Factor `size` into a shape whose dims are each at most small_n.
    shape = []
    n = size
    while n > 1:
        shape.append(min(n, small_n))
        n //= small_n
    shape.reverse()
    num_dims = len(shape)
    if num_dims + 1 >= 10:
        # Einsum subscripts below are single digits 0-9; we would run out.
        raise ValueError(
            f'small_n={small_n} is too small for input size {size}')
    y = x.reshape(shape)

    # Hadamard matrices we will need, one per distinct dimension size.
    hadamards = {d: hadamard_matrix(d, x.dtype) for d in set(shape)}

    # einsum on each dimension.
    y_dims = ''.join(str(j) for j in range(num_dims))
    for i, d in enumerate(shape):
        h_dims = f'{i}{num_dims + 1}'
        out_dims = y_dims.replace(str(i), str(num_dims + 1), 1)
        operands = f'{y_dims},{h_dims}->{out_dims}'
        y = jnp.einsum(operands, y, hadamards[d], precision=precision)
    return y.flatten()
def _split(
    self,
    x: jnp.ndarray,
) -> jnp.ndarray:
    """Reshape everything after the first two axes of ``x`` into heads.

    The result has shape
    ``x.shape[:2] + (num_heads, model_size // num_heads)``.
    """
    head_size = self.model_size // self.num_heads
    leading = tuple(x.shape[:2])
    return x.reshape(leading + (self.num_heads, head_size))