Example #1
import jax.numpy as np  # assumed alias for the np used below
from jax import ops as jo  # legacy jax.ops module providing index_update


def adj_arr(shape, params):
    """Constructs a sliced weighted acyclic-digraph adjacency matrix.

    Assigns weights from `params` to an acyclic-digraph adjacency matrix
    in row-major order. Nodes are assumed to be topologically sorted.

    Args:
        shape: Tuple containing number of nodes per type.
        params: Sequence of weights.

    Returns:
        Weighted adjacency matrix where the rows and columns represent
        source and target nodes respectively. Only arcs with non-output
        source and non-input target are represented.

    Raises:
        AttributeError: If weighted adjacency matrix cannot be
            completely initialized.
    """
    inb, _, out = shape
    Σ = np.sum(shape)

    outbound = np.array([Σ - max(inb, n + 1) for n in range(Σ - out)])
    arr = np.flip(np.arange(Σ - inb), 0) < outbound[:, None]

    if np.count_nonzero(arr) != len(params):
        msg = "{} weights required to initialize adj matrix of a {} network, got {}."
        raise AttributeError(
            msg.format(np.count_nonzero(arr), shape, len(params)))

    return jo.index_update(arr.astype(float), arr, params)
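A minimal usage sketch, assuming `shape` is ordered (inputs, hidden, outputs): a (2, 2, 1) network has 5 topologically sorted nodes and exactly 9 feed-forward arcs, so 9 weights are required.

weights = np.arange(1.0, 10.0)
adj = adj_arr((2, 2, 1), weights)  # (4, 3): non-output sources x non-input targets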
Example #2
    def test_sparse_init(self):
        """Test that sparse_init produces sparse params."""

        rng = jax.random.PRNGKey(0)
        flax_module, params, input_shape, model_hps = _load_model(
            'fully_connected')
        non_zero_connection_weights = 3
        init_hps = sparse_init.DEFAULT_HPARAMS
        init_hps['non_zero_connection_weights'] = non_zero_connection_weights
        init_hps.update(model_hps)
        loss_name = 'cross_entropy'
        loss_fn = losses.get_loss_fn(loss_name)
        new_params = sparse_init.sparse_init(loss_fn=loss_fn,
                                             flax_module=flax_module,
                                             params=params,
                                             hps=init_hps,
                                             input_shape=input_shape[1:],
                                             output_shape=OUTPUT_SHAPE,
                                             rng_key=rng)

        # Check new params are sparse
        for key in new_params:
            num_units = new_params[key]['kernel'].shape[0]
            self.assertEqual(jnp.count_nonzero(new_params[key]['kernel']),
                             num_units * non_zero_connection_weights)
            self.assertEqual(jnp.count_nonzero(new_params[key]['bias']), 0)
Example #3
    def test_shuffled_neuron_no_input_ablation_mask_sparsity_full(self):
        """Tests shuffled mask generation, for 100% sparsity."""
        mask = masked.shuffled_neuron_no_input_ablation_mask(
            self._masked_model, self._rng, 1.0)

        with self.subTest(name='shuffled_full_mask'):
            self.assertIn('MaskedModule_0', mask)

        with self.subTest(name='shuffled_full_mask_values'):
            self.assertEqual(
                jnp.count_nonzero(mask['MaskedModule_0']['kernel']),
                jnp.prod(jnp.array(self._input_dimensions)))

        with self.subTest(name='shuffled_full_no_input_ablation'):
            # Check no row (neurons are columns) is completely ablated.
            self.assertTrue((jnp.count_nonzero(
                mask['MaskedModule_0']['kernel'], axis=0) != 0).all())

        with self.subTest(name='shuffled_full_mask_not_masked_values'):
            self.assertIsNone(mask['MaskedModule_0']['bias'])

        masked_output = self._masked_model(self._input, mask=mask)

        with self.subTest(name='shuffled_full_mask_dense_shape'):
            self.assertSequenceEqual(masked_output.shape,
                                     self._unmasked_output.shape)
Example #4
def logmarglike_lineargaussianmodel_twotransfers(
    M_T,  #  (n_components, n_pix_y)
    y,  # (n_pix_y)
    yinvvar,  # (n_pix_y)
    mu,  # (n_components)
    muinvvar,  #  (n_components)
    logyinvvar=None,
    logmuinvvar=None,
):
    """
    Fit linear model to one Gaussian data set, with Gaussian prior on linear components.

    Parameters
    ----------
    y, yinvvar : ndarray (n_pix_y)
        data and data inverse variances
    M_T : ndarray (n_components, n_pix_y)
        design matrix of linear model
    mu, muinvvar : ndarray (n_components)
        prior means and prior inverse variances on the linear components

    Returns
    -------
    logfml : ndarray scalar
        log likelihood values with parameters marginalised and at best fit
    theta_map : ndarray (n_components)
        Best fit MAP parameters
    theta_cov : ndarray (n_components, n_components)
        Parameter covariance

    """
    log2pi = np.log(2.0 * np.pi)
    nt = np.shape(M_T)[-2]
    ny = np.count_nonzero(yinvvar)
    nm = np.count_nonzero(muinvvar)
    M = np.transpose(M_T)  # (n_pix_y, n_components)
    Myinv = M * yinvvar[:, None]  # (n_pix_y, n_components)
    Hbar = (np.matmul(M_T, Myinv) + np.eye(nt) * muinvvar[:, None]
            )  #  (n_components, n_components)
    etabar = np.sum(Myinv * y[:, None],
                    axis=0) + mu * muinvvar  # (n_components)
    theta_map = np.linalg.solve(Hbar, etabar)  # (n_components)
    theta_cov = np.linalg.inv(Hbar)  # (n_components, n_components)
    if logyinvvar is None:
        logyinvvar = np.where(yinvvar == 0, 0, np.log(yinvvar))
    if logmuinvvar is None:
        logmuinvvar = np.where(muinvvar == 0, 0, np.log(muinvvar))
    logdetH = np.sum(logyinvvar) + np.sum(logmuinvvar)  # scalar
    xi1 = -0.5 * ((ny + nm) * log2pi - logdetH + np.sum(y * y * yinvvar) +
                  np.sum(mu * mu * muinvvar))  # scalar
    sign, logdetHbar = np.linalg.slogdet(Hbar)
    xi2 = -0.5 * (nt * log2pi - logdetHbar + np.sum(etabar * theta_map))
    logfml = xi1 - xi2
    print("my Cinv_X", np.sum(y * y * yinvvar) - np.sum(etabar * theta_map))
    print("my logdet", -logdetH + logdetHbar)
    print("my counts", (ny + nm - nt) * log2pi)
    return logfml, theta_map, theta_cov
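A hypothetical usage sketch (assuming `np` is jax.numpy or numpy): two components fit under a weak Gaussian prior.

M_T = np.stack([np.ones(4), np.arange(4.0)])  # (2, 4) design matrix
y = np.transpose(M_T) @ np.array([0.5, -1.0])  # noiseless data
yinvvar = np.ones(4)
mu, muinvvar = np.zeros(2), 1e-2 * np.ones(2)  # weak prior
logfml, theta_map, theta_cov = logmarglike_lineargaussianmodel_twotransfers(
    M_T, y, yinvvar, mu, muinvvar)
# theta_map is pulled slightly from [0.5, -1.0] toward the prior mean 0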
Example #5
def logmarglike_lineargaussianmodel_onetransfer(M_T,
                                                y,
                                                yinvvar,
                                                logyinvvar=None):
    """
    Fit linear model to one Gaussian data set, with no (=uniform) prior on the linear components.

    Parameters
    ----------
    y, yinvvar, logyinvvar : ndarray (n_pix_y)
        data and data inverse variances.
        Zeros will be ignored.
    M_T : ndarray (n_components, n_pix_y)
        design matrix of linear model

    Returns
    -------
    logfml : ndarray scalar
        log likelihood values with parameters marginalised and at best fit
    theta_map : ndarray (n_components)
        Best fit MAP parameters
    theta_cov : ndarray (n_components, n_components)
        Parameter covariance

    """
    assert y.shape[-1] == yinvvar.shape[-1]
    assert M_T.shape[-1] == yinvvar.shape[-1]
    assert np.all(np.isfinite(yinvvar))  # all finite
    assert np.all(np.isfinite(y))  # all finite
    assert np.all(np.isfinite(M_T))  # all finite
    # Require more than two valid (nonzero inverse variance) pixels.
    assert np.count_nonzero(yinvvar) > 2

    log2pi = np.log(2.0 * np.pi)
    nt = np.shape(M_T)[-2]
    ny = np.count_nonzero(yinvvar)
    M = np.transpose(M_T)  # (n_pix_y, n_components)
    Myinv = M * yinvvar[:, None]  # (n_pix_y, n_components)
    Hbar = np.matmul(M_T, Myinv)  #  (n_components, n_components)
    etabar = np.sum(Myinv * y[:, None], axis=0)  # (n_components)
    theta_map = np.linalg.solve(Hbar, etabar)  # (n_components)
    theta_cov = np.linalg.inv(Hbar)  # (n_components, n_components)
    if logyinvvar is None:
        logyinvvar = np.where(yinvvar == 0, 0, np.log(yinvvar))
    logdetH = np.sum(logyinvvar)  # scalar
    xi1 = -0.5 * (ny * log2pi - logdetH + np.sum(y * y * yinvvar))  # scalar
    sign, logdetHbar = np.linalg.slogdet(Hbar)
    xi2 = -0.5 * (nt * log2pi - logdetHbar + np.sum(etabar * theta_map))
    logfml = xi1 - xi2
    return logfml, theta_map, theta_cov
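A minimal usage sketch (assuming `np` is jax.numpy or numpy):

M_T = np.stack([np.ones(5), np.arange(5.0)])  # (2, 5) design matrix
y = np.transpose(M_T) @ np.array([1.0, 2.0])  # noiseless data
yinvvar = np.ones(5)
logfml, theta_map, theta_cov = logmarglike_lineargaussianmodel_onetransfer(
    M_T, y, yinvvar)
# theta_map recovers [1.0, 2.0] since the data are noiseless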
Example #6
  def testCountNonzero(self, shape, dtype, axis):
    rng = jtu.rand_some_zero()
    onp_fun = lambda x: onp.count_nonzero(x, axis)
    lnp_fun = lambda x: lnp.count_nonzero(x, axis)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
Example #7
    def test_shuffled_neuron_no_input_ablation_mask_sparsity_full_twolayer(
            self):
        """Tests shuffled mask generation for two layers, and 100% sparsity."""
        mask = masked.shuffled_neuron_no_input_ablation_mask(
            self._masked_model_twolayer, self._rng, 1.0)

        with self.subTest(name='shuffled_full_mask_layer1'):
            self.assertIn('MaskedModule_0', mask)

        with self.subTest(name='shuffled_full_mask_values_layer1'):
            self.assertEqual(
                jnp.count_nonzero(mask['MaskedModule_0']['kernel']),
                jnp.prod(jnp.array(self._input_dimensions)))

        with self.subTest(name='shuffled_full_mask_not_masked_values_layer1'):
            self.assertIsNone(mask['MaskedModule_0']['bias'])

        with self.subTest(name='shuffled_full_no_input_ablation_layer1'):
            # Check no row (neurons are columns) is completely ablated.
            self.assertTrue((jnp.count_nonzero(
                mask['MaskedModule_0']['kernel'], axis=0) != 0).all())

        with self.subTest(name='shuffled_full_mask_layer2'):
            self.assertIn('MaskedModule_1', mask)

        with self.subTest(name='shuffled_full_mask_values_layer2'):
            self.assertEqual(
                jnp.count_nonzero(mask['MaskedModule_1']['kernel']),
                jnp.prod(MaskedTwoLayerDense.NUM_FEATURES[0]))

        with self.subTest(name='shuffled_full_mask_not_masked_values_layer2'):
            self.assertIsNone(mask['MaskedModule_1']['bias'])

        with self.subTest(name='shuffled_full_no_input_ablation_layer2'):
            # Note: check no *inputs* are ablated, and inputs < num_neurons.
            self.assertEqual(
                jnp.sum(
                    jnp.count_nonzero(mask['MaskedModule_1']['kernel'],
                                      axis=0)),
                MaskedTwoLayerDense.NUM_FEATURES[0])

        masked_output = self._masked_model_twolayer(self._input, mask=mask)

        with self.subTest(name='shuffled_full_mask_dense_shape'):
            self.assertSequenceEqual(masked_output.shape,
                                     self._unmasked_output_twolayer.shape)
Example #8
import jax.numpy as jnp


def add_eos_token(indices, eos_token):
    """Appends an EOS token to each sequence, assuming 0 marks padding."""
    batch_size = indices.shape[0]
    lengths = jnp.count_nonzero(indices, axis=1)

    indices = jnp.pad(indices, pad_width=[(0, 0), (0, 1)], mode='constant')
    # Add EOS token.
    indices = indices.at[jnp.arange(batch_size), lengths].set(eos_token)
    return indices.astype(jnp.int32)
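A small usage example: with 0 as the padding id, count_nonzero recovers each sequence length, and the EOS id lands just past it.

batch = jnp.array([[5, 7, 0], [3, 0, 0]])
add_eos_token(batch, eos_token=2)
# -> [[5, 7, 2, 0],
#     [3, 2, 0, 0]]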
Example #9
    def example_helper(model, padded_example_and_rng):
        """Run the model on one example."""
        # Compute the same estimates as in training, for ease of comparison.
        # Instead of aggregating with nanmean per-batch, we aggregate over
        # the full validation set.
        if use_sampling_model:
            (output_logits, targets, valid_mask, loss,
             batch_metrics) = sample_loss_fn(model, padded_example_and_rng,
                                             target_edge_index,
                                             num_edge_types)
        else:
            (output_logits, targets, valid_mask, num_nodes,
             captured) = extract_outputs_and_targets(model,
                                                     padded_example_and_rng,
                                                     target_edge_index,
                                                     num_edge_types)
            loss, batch_metrics = loss_fn(output_logits, targets, valid_mask,
                                          num_nodes, captured)
        batch_metrics_non_nan = jax.tree_map(jnp.nan_to_num, batch_metrics)
        batch_metrics_non_nan_counts = jax.tree_map(
            lambda x: jnp.count_nonzero(~jnp.isnan(x)), batch_metrics)
        # Compute additional metrics by counting how many predictions cross
        # our thresholds.
        output_probs = jax.scipy.special.expit(output_logits)
        preds = (output_probs[None, :, :] >
                 candidate_thresholds[:, None, None])
        # Count true/false target/pred pairs.
        valid = valid_mask.astype(bool)
        count_t_target_t_pred = jnp.count_nonzero(valid & targets & preds,
                                                  axis=(1, 2))
        count_t_target_f_pred = jnp.count_nonzero(valid & targets & (~preds),
                                                  axis=(1, 2))
        count_f_target_t_pred = jnp.count_nonzero(valid & (~targets) & preds,
                                                  axis=(1, 2))
        count_f_target_f_pred = jnp.count_nonzero(
            valid & (~targets) & (~preds), axis=(1, 2))
        counts = (count_t_target_t_pred, count_t_target_f_pred,
                  count_f_target_t_pred, count_f_target_f_pred)
        return (loss, 1, batch_metrics_non_nan, batch_metrics_non_nan_counts,
                counts)
Example #10
    def test_shuffled_neuron_no_input_ablation_mask_sparsity_half_full(self):
        """Tests shuffled mask generation, for a half-full mask."""
        mask = masked.shuffled_neuron_no_input_ablation_mask(
            self._masked_model, self._rng, 0.5)
        param_shape = self._masked_model.params['MaskedModule_0']['unmasked'][
            'kernel'].shape

        with self.subTest(name='shuffled_mask_values'):
            self.assertEqual(jnp.sum(mask['MaskedModule_0']['kernel']),
                             param_shape[0] // 2 * param_shape[1])

        with self.subTest(name='shuffled_half_no_input_ablation'):
            # Check no row (neurons are columns) is completely ablated.
            self.assertTrue((jnp.count_nonzero(
                mask['MaskedModule_0']['kernel'], axis=0) != 0).all())
Example #11
File: data.py Project: jackd/grax
def get_largest_component_indices(
    graph: JAXSparse,
    dtype: jnp.dtype = jnp.int32,
    directed: bool = True,
    connection="weak",
) -> jnp.ndarray:
    """Returns indices of the nodes in the largest connected component."""
    ncomponents, labels = get_component_indices(graph,
                                                dtype=dtype,
                                                directed=directed,
                                                connection=connection)
    if ncomponents == 1:
        return jnp.arange(graph.shape[0], dtype=dtype)
    sizes = jnp.asarray(
        [jnp.count_nonzero(labels == i) for i in range(ncomponents)])
    i = jnp.argmax(sizes)
    (indices, ) = jnp.where(labels == i)
    return jnp.asarray(indices, dtype)
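A hypothetical usage sketch, assuming `get_component_indices` accepts a BCOO sparse adjacency matrix:

from jax.experimental import sparse as jsparse
adj = jsparse.BCOO.fromdense(jnp.array([[0., 1., 0.],
                                        [1., 0., 0.],
                                        [0., 0., 0.]]))
get_largest_component_indices(adj)  # -> [0, 1]; node 2 is isolated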
Example #12
    def __call__(self, param_name, param):
        """Shuffles the weight matrix/mask for a given parameter, per-neuron.

        This is to be used with mask_map, and accepts the standard mask_map
        function parameters.

        Args:
          param_name: The parameter's name.
          param: The parameter's weight or mask matrix.

        Returns:
          A shuffled weight/mask matrix, with each neuron shuffled
          independently.
        """
        del param_name  # Unused.

        incoming_connections = jnp.prod(jnp.array(param.shape[:-1]))
        num_neurons = param.shape[-1]

        # Ensure each input neuron has at least one connection unmasked.
        mask = _fill_diagonal_wrap((incoming_connections, num_neurons),
                                   1,
                                   dtype=jnp.uint8)

        # Randomly shuffle which of the neurons have these connections.
        mask = jax.random.shuffle(self._get_rng(), mask, axis=0)

        # Add extra required random connections to mask to satisfy sparsity.
        mask_cols = []
        for col in range(mask.shape[-1]):
            neuron_mask = mask[:, col]
            off_diagonal_count = max(
                round((1 - self._sparsity) * incoming_connections) -
                jnp.count_nonzero(neuron_mask), 0)

            zero_indices = jnp.flatnonzero(neuron_mask == 0)
            random_entries = _random_neuron_mask(len(zero_indices),
                                                 off_diagonal_count,
                                                 self._get_rng())

            neuron_mask = neuron_mask.at[zero_indices].set(random_entries)
            mask_cols.append(neuron_mask)

        return jnp.column_stack(mask_cols).reshape(param.shape)
Example #13
def logmarglike_scalingmodel_flatprior_jit(
        ymod,  #  (n_components, n_pix_y)
        y,  # (n_components, n_pix_y)
        yinvvar,  # (n_components, n_pix_y)
        logyinvvar,  # (n_components, n_pix_y)
):
    """
    Fit a scaling model to one Gaussian data set, with a flat prior on the scaling.

    Parameters
    ----------
    y, yinvvar, logyinvvar : ndarray (n_components, n_pix_y)
        data and data inverse variances
    ymod : ndarray (n_components, n_pix_y)
        design matrix of linear model

    Returns
    -------
    logfml : ndarray (n_components)
        log likelihood values with parameters marginalised and at best fit
    theta_map : ndarray (n_components)
        Best fit MAP parameters
    theta_cov : ndarray (n_components)
        Parameter covariance

    """
    log2pi = np.log(2.0 * np.pi)
    ny = np.count_nonzero(yinvvar)
    # n_components
    FOT = np.sum(ymod * y * yinvvar, axis=-1)
    FTT = np.sum(ymod**2 * yinvvar, axis=-1)
    FOO = np.sum(y**2 * yinvvar, axis=-1)
    logSigma_det = np.sum(logyinvvar, axis=-1)
    theta_map = FOT / FTT  # maximum-likelihood scaling
    chi2 = FOO - theta_map * FOT
    logfml = -0.5 * (chi2 + np.log(FTT) - logSigma_det + ny * log2pi)
    theta_cov = FTT**-1
    return logfml, theta_map, theta_cov
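A quick sanity check (assuming `np` is jax.numpy): recover a known scaling of 2.

ymod = np.ones((1, 5))
y = 2.0 * ymod
yinvvar = np.ones((1, 5))
logfml, ell_map, ell_cov = logmarglike_scalingmodel_flatprior_jit(
    ymod, y, yinvvar, np.log(yinvvar))
# ell_map ≈ [2.0]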
Example #14
def mean_and_var(
    x: Optional[np.ndarray],
    axis: Optional[Axes] = None,
    dtype: Optional[np.dtype] = None,
    out: Optional[None] = None,
    ddof: int = 0,
    keepdims: bool = False,
    mask: Optional[np.ndarray] = None,
    get_var: bool = False
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """`np.mean` and `np.var` taking the `mask` information into account."""
    var = None
    if x is None:
        return x, var

    if mask is None:
        mean = np.mean(x, axis, dtype, out, keepdims)
        if get_var:
            var = np.var(x, axis, dtype, out, ddof, keepdims)

    else:
        axis = tuple(utils.canonicalize_axis(axis, x))
        size = utils.size_at(x, axis)
        mask = np.broadcast_to(mask, x.shape)
        mask_size = np.count_nonzero(mask, axis)
        for i in axis:
            mask_size = np.expand_dims(mask_size, i)
        size -= mask_size
        size = np.maximum(size, 1)

        mean = np.sum(x, axis=axis, keepdims=True) / size
        if not keepdims:
            mean = np.squeeze(mean, axis)

        if get_var:
            var = np.sum(
                (x - mean)**2, axis=axis, keepdims=True) / (size - ddof)
            if not keepdims:
                var = np.squeeze(var, axis)

    return mean, var
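A hypothetical usage sketch, assuming the library's `utils` helpers are in scope, that `mask` is True at entries to exclude, and that masked entries of `x` are already zeroed:

x = np.array([[1.0, 2.0, 0.0]])
mask = np.array([[False, False, True]])
mean, var = mean_and_var(x, axis=1, mask=mask, get_var=True)
# mean == [1.5]: the sum is divided by the two unmasked entries only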
Example #15
import numpy as np


def smooth_and_normalize(vec, normalize=True):
    """
    Parameters:
    * vec : np.array of shape (n,)
    * normalize : bool

    Returns:
    out : np.array of shape (n,).
    If vec_i = 0, then out_i = epsilon. If vec_i != 0, then out_i = vec_i - c,
    where c is chosen so that the output sums to the same total as vec
    (1 after normalization).
    """
    vec = np.asarray(vec, dtype=np.float32)

    if normalize:
        vec = vec / vec.sum()
    n = len(vec)
    epsilon = 0.0001
    num_nonzero = np.count_nonzero(vec)
    c = epsilon * (n - num_nonzero) / num_nonzero
    perturbation = (vec == 0) * epsilon - (vec != 0) * c
    return vec + perturbation
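A small usage example:

v = smooth_and_normalize(np.array([0.5, 0.5, 0.0]))
# v ≈ [0.49995, 0.49995, 0.0001]; the total still sums to 1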
Example #16
    def validation_function(model):
        with contextlib.ExitStack() as exit_stack:
            valid_iterator = valid_iterator_factory()
            if prefetch:
                valid_iterator = exit_stack.enter_context(
                    data_loading.ThreadedPrefetcher(valid_iterator, 4))
            accumulated = None
            example_count = 0
            for batch in valid_iterator:
                results = parallel_metrics_batch(model, batch.example,
                                                 batch.mask,
                                                 batch.static_metadata)
                metrics = jax.tree_map(float,
                                       flax.jax_utils.unreplicate(results))
                metrics["epoch"] = np.sum(batch.epoch)
                if accumulated is None:
                    accumulated = metrics
                else:
                    accumulated = jax.tree_multimap(operator.add, accumulated,
                                                    metrics)
                example_count += jnp.count_nonzero(batch.mask)

            assert example_count > 0, "Validation iterator must be nonempty"
            accumulated = typing.cast(Dict[str, Any], accumulated)

            final_metrics = {}
            for k, v in accumulated.items():
                if isinstance(v, RatioMetric):
                    final_metrics[k] = v.numerator / v.denominator
                    if include_total_counts:
                        final_metrics[k + "_numerator"] = v.numerator
                        final_metrics[k + "_denominator"] = v.denominator
                else:
                    final_metrics[k] = v / example_count

            objective = final_metrics[objective_metric_name]
            if include_total_counts:
                final_metrics["validation_total_example_count"] = example_count
            return (objective, final_metrics)
Example #17
def logmarglike_lineargaussianmodel_onetransfer_jit(M_T, y, yinvvar,
                                                    logyinvvar):
    """
    Fit linear model to one Gaussian data set, with no (=uniform) prior on the linear components.

    Parameters
    ----------
    y, yinvvar, logyinvvar : ndarray (n_pix_y)
        data and data inverse variances.
        Zeros will be ignored.
    M_T : ndarray (n_components, n_pix_y)
        design matrix of linear model

    Returns
    -------
    logfml : ndarray scalar
        log likelihood values with parameters marginalised and at best fit
    theta_map : ndarray (n_components)
        Best fit MAP parameters
    theta_cov : ndarray (n_components, n_components)
        Parameter covariance

    """
    log2pi = np.log(2.0 * np.pi)
    nt = np.shape(M_T)[-2]
    ny = np.count_nonzero(yinvvar)
    M = np.transpose(M_T)  # (n_pix_y, n_components)
    Myinv = M * yinvvar[:, None]  # (n_pix_y, n_components)
    Hbar = np.matmul(M_T, Myinv)  #  (n_components, n_components)
    etabar = np.sum(Myinv * y[:, None], axis=0)  # (n_components)
    theta_map = np.linalg.solve(Hbar, etabar)  # (n_components)
    theta_cov = np.linalg.inv(Hbar)  # (n_components, n_components)
    logdetH = np.sum(logyinvvar)  # scalar
    xi1 = -0.5 * (ny * log2pi - logdetH + np.sum(y * y * yinvvar))  # scalar
    sign, logdetHbar = np.linalg.slogdet(Hbar)
    xi2 = -0.5 * (nt * log2pi - logdetHbar + np.sum(etabar * theta_map))
    logfml = xi1 - xi2
    return logfml, theta_map, theta_cov
Example #18
def unique(rng, binary_vectors: np.ndarray) -> int:
    """Counts the number of unique rows of a binary matrix."""
    # Project each row onto a random Gaussian vector: distinct rows map to
    # distinct projections with probability one, so counting distinct sorted
    # values counts distinct rows.
    alpha = jax.random.normal(rng, shape=(1, binary_vectors.shape[1]))
    return 1 + np.count_nonzero(
        np.diff(np.sort(np.sum(binary_vectors * alpha, axis=-1))))
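A quick check (assuming `np` is jax.numpy):

rng = jax.random.PRNGKey(0)
vs = np.array([[1, 0, 1], [1, 0, 1], [0, 1, 1]])
unique(rng, vs)  # -> 2: the first two rows coincide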
Example #19
def count_permutations_mask_layer(mask_layer,
                                  next_mask_layer=None,
                                  parameter_key='kernel'):
    """Calculates the number of permutations for a layer, given binary masks.

  Args:
   mask_layer: The binary weight mask of a dense/conv layer, where last
     dimension is number of neurons/filters.
   next_mask_layer: The binary weight mask of the following a dense/conv layer,
     or None if this is the last layer.
   parameter_key: The name of the parameter to count the permutations of in each
     layer.

  Returns:
   A dictionary with stats on the permutation structure of a mask, including
   the number of symmetric permutations of the mask, number of unique mask
   columns, count of the zeroed out (structurally pruned) neurons, and total
   number of neurons/filters.
  """
    # Have to check 'is None' since mask_layer[parameter_key] is jnp.array.
    if not mask_layer or parameter_key not in mask_layer or mask_layer[
            parameter_key] is None:
        return {
            'permutations': 1,
            'zeroed_neurons': 0,
            'total_neurons': 0,
            'unique_neurons': 0,
        }

    mask = mask_layer[parameter_key]

    num_neurons = mask.shape[-1]

    # Initialize with stats for an empty mask.
    mask_stats = {
        'permutations': 0,
        'zeroed_neurons': num_neurons,
        'total_neurons': num_neurons,
        'unique_neurons': 0,
    }

    # Re-shape masks as 1D, in case they are 2D (e.g. convolutional).
    connection_mask = mask.reshape(-1, num_neurons)

    # Only consider non-zero columns (in JAX neurons/filters are last index).
    non_zero_neurons = ~jnp.all(connection_mask == 0, axis=0)

    # Count only zeroed neurons in the current layer.
    zeroed_count = num_neurons - jnp.count_nonzero(non_zero_neurons)

    # Special case where all neurons in current layer are ablated.
    if zeroed_count == num_neurons:
        return mask_stats

    # Have to check is None since next_mask_layer[parameter_key] is jnp.array.
    if next_mask_layer and parameter_key in next_mask_layer and next_mask_layer[
            parameter_key] is not None:
        next_mask = next_mask_layer[parameter_key]

        # Re-shape masks as 1D, in case they are 2D (e.g. convolutional).
        next_connection_mask = next_mask.T.reshape(-1, num_neurons)

        # Update with neurons that are non-zero in outgoing connections too.
        non_zero_neurons &= ~jnp.all(next_connection_mask == 0, axis=0)

        # Remove rows corresponding to neurons that are ablated.
        next_connection_mask = next_connection_mask[:, non_zero_neurons]

        connection_mask = connection_mask[:, non_zero_neurons]

        # Combine the outgoing and incoming masks in one vector per-neuron.
        connection_mask = jnp.concatenate(
            (connection_mask, next_connection_mask), axis=0)

    else:
        connection_mask = connection_mask[:, non_zero_neurons]

    # Effectively no connections between these two layers.
    if not connection_mask.size:
        return mask_stats

    # Note: np.unique not implemented in JAX numpy yet.
    _, unique_counts = np.unique(connection_mask, axis=-1, return_counts=True)

    # Convert from device array.
    mask_stats['zeroed_neurons'] = int(zeroed_count)

    mask_stats['permutations'] = functools.reduce(operator.mul,
                                                  (np.math.factorial(t)
                                                   for t in unique_counts))
    mask_stats['unique_neurons'] = len(unique_counts)

    return mask_stats
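A small worked example, with the snippet's imports in scope: two identical incoming-mask columns can be swapped without changing the network, giving 2! = 2 permutations.

mask = {'kernel': jnp.array([[1, 1, 0],
                             [0, 0, 1]])}
count_permutations_mask_layer(mask)
# -> {'permutations': 2, 'zeroed_neurons': 0, 'total_neurons': 3,
#     'unique_neurons': 2}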
Example #20
                            flax.serialization.to_state_dict(
                                replicated_optimizer),
                            "batch":
                            batch,
                            "metrics":
                            metrics,
                            "agg_grads":
                            flax.serialization.to_state_dict(agg_grads),
                        }
                        pickle.dump(postmortem,
                                    fp,
                                    protocol=pickle.HIGHEST_PROTOCOL)

                bad_grads = jax.tree_map(
                    lambda x: float(
                        jnp.count_nonzero(~jnp.isfinite(x)) / x.size),
                    flax.serialization.to_state_dict(agg_grads))
                bad_grads_str = json.dumps(bad_grads, indent=2)
                raise RuntimeError(f"Non-finite gradients at step {step}!\n\n"
                                   f"Bad fraction:\n{bad_grads_str}")

            elapsed_sec = time.time() - start_time
            metrics["elapsed_hours"] = elapsed_sec / 3600
            if max_seconds is not None and elapsed_sec > max_seconds:
                logging.info("Hit max train timeout at step %d", step)
                shutdown_after_this_iteration = True

            # Do a validation step.
            if validation_fn and step % steps_per_validate == 0:
                logging.info("Running validation at step %d", step)
                objective, valid_metrics = validation_fn(
Example #21
    def evaluate_batch(self, images, labels):
        logits = self.model(images, training=False)
        num_correct = jn.count_nonzero(
            jn.equal(jn.argmax(logits, axis=1), labels))
        return num_correct
Example #22
def logmarglike_lineargaussianmodel_threetransfers_jit(
        ell,  # scalar
        M_T,  #  (n_components, n_pix_y)
        R_T,  # (n_components, n_pix_z)
        y,  # (n_pix_y)
        yinvvar,  # (n_pix_y),
        logyinvvar,  # (n_pix_y),
        z,  #  (n_pix_z)
        zinvvar,  #  (n_pix_z)
        logzinvvar,  #  (n_pix_z)
        mu,  # (n_components)
        muinvvar,  # (n_components)
        logmuinvvar,  # (n_components)
):
    """
    Fit linear model to two Gaussian data sets, with Gaussian prior on components.

    Parameters
    ----------
    ell : ndarray scalar
        scaling between the data: y = ell * z
    y, yinvvar, logyinvvar : ndarray (n_pix_y)
        data and data inverse variances
    M_T : ndarray (n_components, n_pix_y)
        design matrix of linear model
    z, zinvvar, logzinvvar : ndarray (n_pix_z)
        second data set and its inverse variances
    R_T : ndarray (n_components, n_pix_z)
        design matrix of linear model for z
    mu, muinvvar, logmuinvvar : ndarray (n_components)
        prior means and prior inverse variances on the components

    Returns
    -------
    logfml : ndarray scalar
        log likelihood values with parameters marginalised and at best fit
    theta_map : ndarray (n_components)
        Best fit MAP parameters
    theta_cov : ndarray (n_components, n_components)
        Parameter covariance

    """
    log2pi = np.log(2.0 * np.pi)
    nt = np.shape(M_T)[-2]
    ny = np.count_nonzero(yinvvar)
    nz = np.count_nonzero(zinvvar)
    nm = np.count_nonzero(muinvvar)
    M = np.transpose(M_T)  # (n_pix_y, n_components)
    R = np.transpose(R_T)  # (n_pix_z, n_components)
    Myinv = M * yinvvar[:, None]  # (n_pix_y, n_components)
    Rzinv = R * zinvvar[:, None]  # (n_pix_z, n_components)
    Hbar = (ell**2 * np.matmul(R_T, Rzinv) + np.matmul(M_T, Myinv) +
            np.eye(nt) * muinvvar[:, None])  #  (n_components, n_components)
    etabar = (ell * np.sum(Rzinv * z[:, None], axis=0) +
              np.sum(Myinv * y[:, None], axis=0) + mu * muinvvar
              )  # (n_components)
    theta_map = np.linalg.solve(Hbar, etabar)  # (n_components)
    theta_cov = np.linalg.inv(Hbar)  # (n_components, n_components)
    logdetH = np.sum(logyinvvar) + np.sum(logzinvvar) + np.sum(
        logmuinvvar)  # scalar
    xi1 = -0.5 * (
        (ny + nz + nm) * log2pi - logdetH + np.sum(y * y * yinvvar) +
        np.sum(z * z * zinvvar) + np.sum(mu * mu * muinvvar))  # scalar
    sign, logdetHbar = np.linalg.slogdet(Hbar)
    xi2 = -0.5 * (nt * log2pi - logdetHbar + np.sum(etabar * theta_map))
    logfml = xi1 - xi2
    return logfml, theta_map, theta_cov
Example #23
# confusion_matrix = utils.copy_docstring(
#     tf.math.confusion_matrix,
#     lambda labels, predictions, num_classes=None, weights=None,
#     dtype=tf.int32, name=None: ...)

conj = utils.copy_docstring(tf.math.conj, lambda x, name=None: np.conj(x))

cos = utils.copy_docstring(tf.math.cos, lambda x, name=None: np.cos(x))

cosh = utils.copy_docstring(tf.math.cosh, lambda x, name=None: np.cosh(x))

count_nonzero = utils.copy_docstring(
    tf.math.count_nonzero,
    lambda input, axis=None, keepdims=None, dtype=tf.int64, name=None: (  # pylint: disable=g-long-lambda
        utils.numpy_dtype(dtype)(np.count_nonzero(input, axis))))

cumprod = utils.copy_docstring(tf.math.cumprod, _cumprod)

cumsum = utils.copy_docstring(tf.math.cumsum, _cumsum)

digamma = utils.copy_docstring(tf.math.digamma,
                               lambda x, name=None: scipy_special.digamma(x))

divide = utils.copy_docstring(tf.math.divide,
                              lambda x, y, name=None: np.divide(x, y))

divide_no_nan = utils.copy_docstring(
    tf.math.divide_no_nan,
    lambda x, y, name=None: np.where(  # pylint: disable=g-long-lambda
        onp.broadcast_to(np.equal(y, 0.),
Example #24
    for i in range(args.max_iter):
        # update parameters
        model.step()
        loss = model.objective()

        if not args.silent:
            print(
                "step: {:3d}, loss: {:7.4f}, ||Ax* - b||_2^2: {:6.4f}".format(
                    i + 1, loss, model.feval(model.x)))
        prox_optval.append(loss)

        if i > 1 and np.abs(prox_optval[-1] - prox_optval[-2]) < args.tol:
            break

    print("-" * 40)
    print('Parameters: gamma={}, solver={}'.format(args.gamma, args.opt))
    print('||Ax* - b||_2^2: {:6.3f}, Obj: {:6.3f}'.format(
        model.feval(model.x), prox_optval[-1]))
    if args.opt == 'ADMM':
        print("nnz of x*: ", jnp.count_nonzero(model.z))
    else:
        print("nnz of x*: ", jnp.count_nonzero(model.x))

    print('Writing the results...')
    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)
    output_file = os.path.join(args.output_dir,
                               '{}_{}.npy'.format(args.opt, args.gamma))
    jnp.save(output_file, jnp.array(prox_optval))
    print('Done')
Example #25
def loss_fn(
    output_logits,
    targets,
    valid_mask,
    num_nodes,
    captured,
    negative_example_weight=1,
    focal_loss_gamma=0.0,
):
    """Compute loss and single-batch metrics for some outputs.

    Args:
      output_logits: Binary logits produced by the model.
      targets: Model targets.
      valid_mask: Mask determining which outputs are valid.
      num_nodes: How many nodes there are in each example.
      captured: Ignored.
      negative_example_weight: Weight to assign to a negative example when
        computing the loss. Positive examples always get weight 1.
      focal_loss_gamma: Focusing parameter for the focal loss, as described in
        Lin et al. (2018). If zero, uses standard cross-entropy loss.

    Returns:
      Tuple (loss, metrics_dict).
    """
    del captured
    num_targets = jnp.count_nonzero(targets)
    # Compute cross entropy.
    unmasked_nll = model_util.binary_logit_cross_entropy(
        output_logits, targets)
    if focal_loss_gamma:
        # (1-p_correct)**gamma = (-(p-1))**gamma = (-expm1(log(p)))**gamma
        focus_term = jnp.power(-jnp.expm1(-unmasked_nll), focal_loss_gamma)
        unmasked_nll = unmasked_nll * focus_term
    # Mask the results so that they only count nodes that exist.
    masked_nll = unmasked_nll * valid_mask
    # Primary loss: Sum of nll over all nodes. We use sum because most of the
    # edges are easy negatives.
    positive_nll = jnp.sum(
        jnp.where(targets, masked_nll, jnp.zeros_like(masked_nll)))
    negative_nll = jnp.sum(
        jnp.where(targets, jnp.zeros_like(masked_nll), masked_nll))
    reweighted_nll = positive_nll + negative_example_weight * negative_nll
    binary_nll = jnp.sum(reweighted_nll)
    # Compute additional metrics to track learning progress.
    # Average NLL of target edges:
    avg_nll_per_target = positive_nll / num_targets
    # Average NLL of non-target edges:
    num_non_targets = num_nodes**2 - num_targets
    avg_nll_per_non_target = negative_nll / num_non_targets
    # Max error for any edge prediction:
    worst_nll = jnp.max(masked_nll)

    loss = binary_nll

    # Ratio of positive to negative targets. If this is equal to
    # negative_example_weight, the positive and negative examples will have the
    # same total weight.
    positive_per_negative = num_targets / num_non_targets
    # Precision and recall at 0.1 threshold
    thresholded_preds = output_logits > jax.scipy.special.logit(0.1)
    count_target_pred = jnp.count_nonzero(thresholded_preds & targets)
    count_pred = jnp.count_nonzero(thresholded_preds & valid_mask.astype(bool))
    precision = count_target_pred / count_pred
    recall = count_target_pred / num_targets
    return loss, {
        "avg_per_target":
        avg_nll_per_target,
        "avg_per_non_target":
        avg_nll_per_non_target,
        "worst":
        worst_nll,
        "positive_per_negative":
        positive_per_negative,
        "effective_p_model_given_target":
        jnp.exp(-avg_nll_per_target),
        "effective_p_model_given_nontarget":
        1 - jnp.exp(-avg_nll_per_non_target),
        "batch_clf_thresh_at_0.1/precision":
        precision,
        "batch_clf_thresh_at_0.1/recall":
        recall,
        "batch_clf_thresh_at_0.1/f1":
        2 * (precision * recall) / (precision + recall),
    }
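A hypothetical smoke test, assuming the snippet's `model_util.binary_logit_cross_entropy` helper is in scope:

logits = jnp.array([[3.0, -3.0], [-3.0, -3.0]])
targets = jnp.array([[True, False], [False, False]])
valid_mask = jnp.ones((2, 2))
loss, metrics = loss_fn(logits, targets, valid_mask, num_nodes=2,
                        captured=None)
# metrics['batch_clf_thresh_at_0.1/recall'] == 1.0 on this easy example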
Example #26
def count_nonzero(a, axis=None, keepdims=False):
  """Counts nonzero elements, unwrapping a JaxArray wrapper if given one."""
  if isinstance(a, JaxArray): a = a.value
  return jnp.count_nonzero(a, axis=axis, keepdims=keepdims)
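Usage matches jnp.count_nonzero; plain arrays pass through unwrapped:

count_nonzero(jnp.array([[1, 0], [2, 3]]), axis=0)  # -> [2, 1]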
Example #27
    def cond_fun(s: _BasicState):
        # Continue while the iteration budget remains and fewer than k
        # eigenpairs have residual error below tol.
        rerr = compute_residual_error(s.R, s.eig_vals, s.X)
        num_converged = jnp.count_nonzero(rerr < tol)
        return jnp.logical_and(s.iteration < max_iters, num_converged < k)
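A hypothetical wiring sketch, assuming a matching `body_fun` that updates the state each iteration:

final_state = jax.lax.while_loop(cond_fun, body_fun, initial_state)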