Example #1
def q_learning_loss(q_value_vec, target_q_value_vec, action, reward, done):
    td_target = reward + gamma * jnp.amax(target_q_value_vec) * (1. - done)
    td_error = jax.lax.stop_gradient(td_target) - q_value_vec[action]
    return jnp.square(td_error)
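A minimal usage sketch (not from the original source): gamma is assumed to be a module-level discount factor, and the two Q-value vectors come from the online and target networks evaluated on a single transition.

import jax
import jax.numpy as jnp

gamma = 0.99  # assumed discount factor, defined alongside q_learning_loss

q_values = jnp.array([1.2, 0.4, -0.3])        # online network, current state
target_q_values = jnp.array([1.0, 0.8, 0.1])  # target network, next state
loss = q_learning_loss(q_values, target_q_values, action=0, reward=1.0, done=0.0)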
Example #2
def softnorm(v):
    return jnp.amax([jnp.linalg.norm(v), SOFTNORMTHRESH])
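A short sketch of why the threshold helps (SOFTNORMTHRESH is an assumed constant, not shown in the source): softnorm never returns a value below the threshold, so dividing by it is safe even for near-zero vectors.

import jax.numpy as jnp

SOFTNORMTHRESH = 1e-6  # assumed value

v = jnp.array([1e-9, 0.0, 0.0])
unit_v = v / softnorm(v)  # well defined even though the true norm is ~1e-9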
Example #3
def stable_scaled_log_softmax(x, tau, axis=-1):
    max_x = jnp.amax(x, axis=axis, keepdims=True)
    y = x - max_x
    tau_lse = max_x + tau * jnp.log(
        jnp.sum(jnp.exp(y / tau), axis=axis, keepdims=True))
    return x - tau_lse
Example #4
def stable_softmax(x, tau, axis=-1):
    max_x = jnp.amax(x, axis=axis, keepdims=True)
    y = x - max_x
    return jax.nn.softmax(y / tau, axis=axis)
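Quick numerical check (a sketch, not part of the original source): for moderate inputs the two stabilized functions above should agree with the direct formulations tau * log_softmax(x / tau) and softmax(x / tau).

import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
tau = 0.5
assert jnp.allclose(stable_scaled_log_softmax(x, tau),
                    tau * jax.nn.log_softmax(x / tau), atol=1e-5)
assert jnp.allclose(stable_softmax(x, tau),
                    jax.nn.softmax(x / tau), atol=1e-5)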
Example #5
def apply_fun(params, inputs, **kwargs):
    out_dim = inputs.shape[-1] // filter
    r = np.reshape(inputs, inputs.shape[:-1] + (out_dim, filter))
    return np.amax(r, axis=-1)
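Sketch of the maxout pooling this closure implements (assumptions: np is jax.numpy and filter is the pool size captured from the enclosing layer constructor):

import jax.numpy as jnp

inputs = jnp.arange(12.).reshape(2, 6)  # batch of 2, feature dim 6
pool = 3                                # assumed pool size ("filter")
r = jnp.reshape(inputs, inputs.shape[:-1] + (6 // pool, pool))
out = jnp.amax(r, axis=-1)              # shape (2, 2): max over each group of 3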
Example #6
def propagate_masks(
    mask,
    param_names = WEIGHT_PARAM_NAMES
):
  """Accounts for implicitly pruned neurons in a model's weight masks.

  When neurons are randomly ablated in one layer, they can effectively ablate
  neurons in the next layer if, as a result, all of a neuron's incoming weights
  are zero. This method accounts for this by propagating mask information
  forward through the entire model.

  Args:
    mask: Model masks to check, in same pytree structure as Model.params.
    param_names: List of param keys in mask to consider.

  Returns:
   A refined model mask with weights that are effectively ablated in the
   original mask set to zero.
  """

  flat_mask = flax.traverse_util.flatten_dict(mask)
  mask_layer_list = list(flat_mask.values())
  mask_layer_keys = list(flat_mask.keys())

  mask_layer_param_names = [layer_param[-1] for layer_param in mask_layer_keys]

  for param_name in param_names:
    # Find which of the param arrays correspond to leaf nodes with this name.
    param_indices = [
        i for i, names in enumerate(mask_layer_param_names)
        if param_name in names
    ]

    for i in range(1, len(param_indices)):
      last_weight_mask = mask_layer_list[param_indices[i - 1]]
      weight_mask = mask_layer_list[param_indices[i]]

      if last_weight_mask is None or weight_mask is None:
        continue

      last_weight_mask_reshaped = jnp.reshape(last_weight_mask,
                                              (-1, last_weight_mask.shape[-1]))

      # Neurons with any outgoing weights from previous layer.
      alive_incoming = jnp.sum(last_weight_mask_reshaped, axis=0) != 0

      # Combine effective mask of previous layer with neuron's current mask.
      if len(weight_mask.shape) > 2:
        # Convolutional layer, only consider channel-wise masks, if any spatial
        # weight is non-zero that channel is considered non-masked.
        spatial_dim = len(weight_mask.shape) - 2
        new_weight_mask = alive_incoming[:, jnp.newaxis] * jnp.amax(
            weight_mask, axis=tuple(range(spatial_dim)))
        new_weight_mask = jnp.tile(new_weight_mask,
                                   weight_mask.shape[:-2] + (1, 1))
      else:
        # Check for case of dense following convolution, i.e. spatial input into
        # dense, to prevent b/156135283. Must use convolution for these layers.
        if len(last_weight_mask.shape) > 2:
          raise ValueError(
              'propagate_masks requires knowledge of the spatial '
              'dimensions of the previous layer. Use a functionally equivalent '
              'conv. layer in place of a dense layer in a model with a mixed '
              'conv/dense setting.')
        new_weight_mask = alive_incoming[:, jnp.newaxis] * weight_mask

      mask_layer_list[param_indices[i]] = jnp.reshape(
          new_weight_mask, mask_layer_list[param_indices[i]].shape)

  return flax.traverse_util.unflatten_dict(
      dict(zip(mask_layer_keys, mask_layer_list)))
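Minimal illustration of the propagation step (a sketch, independent of the flax pytree plumbing above): a neuron whose incoming mask is all zeros in one layer cannot contribute, so its outgoing row in the next layer's mask is zeroed as well.

import jax.numpy as jnp

mask_l1 = jnp.array([[1., 0.],
                     [1., 0.]])  # second neuron of layer 1 is fully pruned
mask_l2 = jnp.ones((2, 3))       # layer 2 initially keeps all weights

alive_incoming = jnp.sum(mask_l1, axis=0) != 0         # [True, False]
refined_l2 = alive_incoming[:, jnp.newaxis] * mask_l2  # second row becomes zero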
Example #7
def logsumexp(x, axis=0, keepdims=False):
    # TODO: remove when https://github.com/google/jax/pull/2260 merged upstream
    x_max = lax.stop_gradient(np.amax(x, axis=axis, keepdims=True))
    y = np.log(np.sum(np.exp(x - x_max), axis=axis, keepdims=True)) + x_max
    return y if keepdims else y.squeeze(axis=axis)
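Sketch of a quick check (not from the original source; the snippet above assumes import jax.numpy as np and from jax import lax): the custom logsumexp should match jax.scipy.special.logsumexp, since stop_gradient does not change forward values.

import jax.numpy as np
from jax import lax
from jax.scipy.special import logsumexp as jsp_logsumexp

x = np.array([[0.5, 1.5, -2.0], [3.0, 0.0, 1.0]])
assert np.allclose(logsumexp(x, axis=1), jsp_logsumexp(x, axis=1))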
Example #8
    dX = np.matmul(
        K_trte.T,
        np.linalg.solve(
            np.transpose(L),
            np.linalg.solve(L,
                            X_train_f.flatten('F') - X_tr.flatten('F'))))
    X_ode1 = X_ode_i[i, :, ind[0]] / model.max_X[0]
    X_ode2 = X_ode_i[i, :, ind[1]] / model.max_X[1]
    X_ode3 = X_ode_i[i, :, ind[2]] / model.max_X[2]
    X_ode = np.concatenate((X_ode1, X_ode2, X_ode3), axis=0)

    mu = X_ode.flatten('F') + dX
    K = K_te - np.matmul(
        K_trte.T, np.linalg.solve(np.transpose(L), np.linalg.solve(L, K_trte)))
    pred = onp.random.multivariate_normal(mu, K)
    if not math.isnan(np.amax(np.abs(pred))):
        Npred_GP_f += 1
        X_pred_GP.append(pred.reshape((D, Nt_test)).T)
        Y_PCA_GP.append(
            np.matmul(
                np.matmul(
                    dat_max * np.array(model.max_X) * pred.reshape(
                        (D, Nt_test)).T, np.diag(np.sqrt(v_pca[:, 0]))),
                u_pca))

X_pred_GP = np.array(X_pred_GP)
mean_prediction_GP, std_prediction_GP = np.mean(X_pred_GP,
                                                axis=0), np.std(X_pred_GP,
                                                                axis=0)
lower_GP = mean_prediction_GP - 2.0 * std_prediction_GP
upper_GP = mean_prediction_GP + 2.0 * std_prediction_GP
Example #9
def loopy_belief_propagation(tests, groups,
                             base_infection_rate,
                             sensitivity, specificity,
                             min_iterations, max_iterations,
                             atol):
  """LBP approach to compute approximate marginal of posterior distribution.

  Outputs marginal approximation of posterior distribution using all tests'
  history and test setup parameters.

  Args:
    tests : np.ndarray<bool>[n_groups] results stored as a vector of booleans
    groups : np.ndarray<bool>[n_groups, n_patients] matrix of groups
    base_infection_rate : np.ndarray<float> [1,] or [n_patients,] infection rate
    sensitivity : np.ndarray<float> [?,] of sensitivity per group size
    specificity : np.ndarray<float> [?,] of specificity per group size
    min_iterations: int, min number of belief propagation iterations
    max_iterations: int, max number of belief propagation iterations
    atol: float, elementwise tolerance for the difference between two
      consecutive iterations.

  Returns:
    The vector of marginal probabilities for all n_patients after convergence
    (or max_iterations), together with the largest elementwise difference
    between this marginal and the one obtained after one more LBP iteration.
  """
  n_groups, n_patients = groups.shape
  if np.size(groups) == 0:
    if np.size(base_infection_rate) == 1:  # only one rate
      marginal = base_infection_rate * np.ones(n_patients)
      return marginal, 0
    elif np.size(base_infection_rate) == n_patients:
      return base_infection_rate, 0
    else:
      raise ValueError("Improper size for vector of base infection rates")

  mu = -jax.scipy.special.logit(base_infection_rate)

  groups_size = np.sum(groups, axis=1)
  sensitivity = utils.select_from_sizes(sensitivity, groups_size)
  specificity = utils.select_from_sizes(specificity, groups_size)
  gamma0 = np.log(sensitivity + specificity - 1) - np.log(1 - sensitivity)
  gamma1 = np.log(sensitivity + specificity - 1) - np.log(sensitivity)
  gamma = tests * gamma1 + (1 - tests) * gamma0
  test_sign = 1 - 2 * tests[:, np.newaxis]

  # Initialization
  alphabeta = np.zeros((2, n_groups, n_patients))
  alpha_beta_iteration = [alphabeta, 0]

  # return marginal from alphabeta
  def marginal_from_alphabeta(alphabeta):
    beta_bar = np.sum(alphabeta[1, :, :], axis=0)
    return jax.scipy.special.expit(-beta_bar - mu)

  # lbp loop
  def lbp_loop(_, alphabeta):
    alpha = alphabeta[0, :, :]
    beta = alphabeta[1, :, :]

    # update alpha
    beta_bar = np.sum(beta, axis=0)
    alpha = jax.nn.log_sigmoid(beta_bar - beta + mu)
    alpha *= groups

    # update beta
    alpha_bar = np.sum(alpha, axis=1, keepdims=True)
    beta = np.log1p(test_sign *
                    np.exp(-alpha + alpha_bar + gamma[:, np.newaxis]))
    beta *= groups
    return np.stack((alpha, beta), axis=0)

  def cond_fun(alpha_beta_iteration):
    alphabeta, iteration = alpha_beta_iteration
    marginal = marginal_from_alphabeta(alphabeta)
    marginal_plus_one_iteration = marginal_from_alphabeta(
        lbp_loop(0, alphabeta))
    converged = np.allclose(marginal, marginal_plus_one_iteration, atol=atol)
    return (not converged) and (iteration < max_iterations)

  def body_fun(alpha_beta_iteration):
    alphabeta, iteration = alpha_beta_iteration
    alphabeta = jax.lax.fori_loop(0, min_iterations, lbp_loop, alphabeta)
    iteration += min_iterations
    return [alphabeta, iteration]

  # Run LBP while loop
  while cond_fun(alpha_beta_iteration):
    alpha_beta_iteration = body_fun(alpha_beta_iteration)

  alphabeta, _ = alpha_beta_iteration

  # Compute two consecutive marginals
  marginal = marginal_from_alphabeta(alphabeta)
  marginal_plus_one_iteration = marginal_from_alphabeta(lbp_loop(0, alphabeta))

  return marginal, np.amax(np.abs(marginal - marginal_plus_one_iteration))
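Small sanity check of the parameterization used above (a sketch, not from the source): with mu = -logit(p), expit(-mu) recovers the base infection rate p, which is what marginal_from_alphabeta returns when all messages are zero.

import jax
import jax.numpy as jnp

p = jnp.array([0.01, 0.05])
mu = -jax.scipy.special.logit(p)
assert jnp.allclose(jax.scipy.special.expit(-mu), p)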
Example #10
def _amax(x, dim, keepdims=False):
    return np.amax(x, axis=dim, keepdims=keepdims)
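Minimal usage sketch (assuming np is jax.numpy): the wrapper only renames the numpy-style axis argument to dim.

import jax.numpy as np

x = np.arange(6.).reshape(2, 3)
_amax(x, dim=1, keepdims=True)  # shape (2, 1)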
Example #11
def calculate_td_error(q_value_vec, target_q_value_vec, action, reward):
    td_target = reward + gamma * jnp.amax(target_q_value_vec)
    td_error = td_target - q_value_vec[action]
    return jnp.abs(td_error)
Example #12
def test_soft_coulomb(self, center):
    grids = jnp.linspace(-10, 10, 201)
    soft_coulomb_interaction = utils.soft_coulomb(grids - center)
    self.assertAlmostEqual(float(jnp.amax(soft_coulomb_interaction)), 1)
    self.assertAlmostEqual(
        float(grids[jnp.argmax(soft_coulomb_interaction)]), center)
Example #13
def test_gaussian(self, center, sigma, expected_max_value):
    gaussian = utils.gaussian(grids=jnp.linspace(-10, 10, 201),
                              center=center,
                              sigma=sigma)
    self.assertAlmostEqual(float(jnp.sum(gaussian) * 0.1), 1, places=5)
    self.assertAlmostEqual(float(jnp.amax(gaussian)), expected_max_value)
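Side note on the expected peak (an assumption about the tested utility, based on the normalization check above): if utils.gaussian returns a density-normalized Gaussian, its maximum is 1 / (sigma * sqrt(2 * pi)), which is presumably what expected_max_value encodes.

import jax.numpy as jnp

sigma = 2.0
expected_peak = 1. / (sigma * jnp.sqrt(2. * jnp.pi))  # ~0.1995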
Example #14
    def calculate_emt(R: Array, box: Array, **kwargs) -> Array:
        """Calculate the elastic modulus tensor.

    energy_fn(R) corresponds to the state around which we are expanding
      
    Args:
      R: array of shape (N,dimension) of particle positions. This does not
        generalize to arbitrary dimensions and is only implemented for
          dimension == 2
          dimension == 3
      box: A box specifying the shape of the simulation volume. Used to infer
        the volume of the unit cell.
    
    Return: C or the tuple (C,converged)
      where C is the Elastic modulus tensor as an array of shape (dimension,
      dimension,dimension,dimension) that respects the major and minor 
      symmetries, and converged is a boolean flag (see above).

    """
        if not (R.shape[-1] == 2 or R.shape[-1] == 3):
            raise AssertionError('Only implemented for 2d and 3d systems.')

        if R.dtype is not jnp.dtype('float64'):
            logging.warning('Elastic modulus calculations can sometimes lose '
                            'precision when not using 64-bit precision.')

        dim = R.shape[-1]

        def setup_energy_fn_general(strain_tensor):
            I = jnp.eye(dim, dtype=R.dtype)

            @jit
            def energy_fn_general(R, gamma):
                perturbation = I + gamma * strain_tensor
                return energy_fn(R, perturbation=perturbation, **kwargs)

            return energy_fn_general

        def get_affine_response(strain_tensor):
            energy_fn_general = setup_energy_fn_general(strain_tensor)
            d2U_dRdgamma = jacfwd(jacrev(energy_fn_general, argnums=0),
                                  argnums=1)(R, 0.)
            d2U_dgamma2 = jacfwd(jacrev(energy_fn_general, argnums=1),
                                 argnums=1)(R, 0.)
            return d2U_dRdgamma, d2U_dgamma2

        strain_tensors = _get_strain_tensor_list(dim, R.dtype)
        d2U_dRdgamma_all, d2U_dgamma2_all = vmap(get_affine_response)(
            strain_tensors)

        #Solve the system of equations.
        energy_fn_Ronly = partial(energy_fn, **kwargs)

        def hvp(f, primals, tangents):
            return jvp(grad(f), primals, tangents)[1]

        def hvp_specific_with_tether(v):
            return hvp(energy_fn_Ronly, (R, ), (v, )) + tether_strength * v

        non_affine_response_all = jsp.sparse.linalg.cg(
            vmap(hvp_specific_with_tether), d2U_dRdgamma_all, tol=cg_tol)[0]
        #The above line should be functionally equivalent to:
        #H0=hessian(energy_fn)(R, box=box, **kwargs).reshape(R.size,R.size) \
        #    + tether_strength * jnp.identity(R.size)
        #non_affine_response_all = jnp.transpose(jnp.linalg.solve(
        #   H0,
        #   jnp.transpose(d2U_dRdgamma_all))
        #   )

        residual = jnp.linalg.norm(
            vmap(hvp_specific_with_tether)(non_affine_response_all) -
            d2U_dRdgamma_all)
        converged = residual / jnp.linalg.norm(d2U_dRdgamma_all) < cg_tol

        response_all = d2U_dgamma2_all - jnp.einsum(
            "nij,nij->n", d2U_dRdgamma_all, non_affine_response_all)

        vol_0 = quantity.volume(dim, box)
        response_all = response_all / vol_0
        C = _convert_responses_to_elastic_constants(response_all)

        # JAX does not allow proper runtime error handling in jitted function.
        # Instead, if the user requests a gradient check and the check fails,
        # we convert C into jnp.nan's. While this doesn't raise an exception,
        # it at least is very "loud".
        if gradient_check is not None:
            maxgrad = jnp.amax(jnp.abs(grad(energy_fn)(R, **kwargs)))
            C = lax.cond(maxgrad > gradient_check, lambda _: jnp.nan * C,
                         lambda _: C, None)

        if check_convergence:
            return C, converged
        else:
            return C
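Standalone sketch of the Hessian-vector-product pattern used in hvp above: jvp of grad gives H @ v without materializing the full Hessian, which is what lets the non-affine response be solved with conjugate gradients.

import jax.numpy as jnp
from jax import grad, jvp

def f(x):
    return jnp.sum(x ** 4)

x = jnp.array([1., 2., 3.])
v = jnp.array([1., 0., 0.])
hv = jvp(grad(f), (x,), (v,))[1]  # equals hessian(f)(x) @ v, here [12., 0., 0.]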
Example #15
def _relax_primitive(index: int, out_bounds: bound_propagation.Bound,
                     primitive: jax.core.Primitive, *args,
                     **kwargs) -> RelaxVariable:
    """Generates the relaxation for a given primitive op.

  Args:
    index: Integer identifying the computation node.
    out_bounds: Concrete bounds on the outputs of the primitive op.
    primitive: jax primitive
    *args: Arguments of the primitive, wrapped as RelaxVariables
    **kwargs: Keyword Arguments of the primitive.
  Returns:
    `RelaxVariable` that contains the output of this primitive for the
    relaxation, with all the constraints linking the output to inputs.
  """
    # Create variable for output of this primitive
    out_variable = RelaxVariable(index, out_bounds)
    # Create constraints linking output and input of primitive
    constraints = []
    if primitive in _passthrough_list:
        invar = args[0]
        constraints = [
            RelaxActivationConstraint(out_variable, invar,
                                      jnp.ones_like(invar.lower),
                                      jnp.zeros_like(invar.lower), 0)
        ]
    elif primitive in _affine_primitives_list:
        results = _get_linear(primitive, out_bounds.lower[0, ...], *args,
                              **kwargs)
        for i, (bias, coeffs) in enumerate(results):
            # Coefficients of the input variable(s).
            vars_and_coeffs = [(arg, coeff)
                               for arg, coeff in zip(args, coeffs)
                               if isinstance(arg, RelaxVariable)]
            # Equate with the output variable, by using a coefficient of -1.
            out_coeff = (np.array([i], dtype=np.int64), np.array([-1.]))
            vars_and_coeffs.append((out_variable, out_coeff))
            constraints.append(LinearConstraint(vars_and_coeffs, bias, 0))
    elif primitive in _activation_list:
        if len(args) == 2:
            # Generate relu relaxation
            # Relu is implemented as max(x, 0) and both are treated as arguments
            # We find the one that is a RelaxVariable (the other would just be
            # a JAX literal)
            if isinstance(args[0], RelaxVariable):
                invar = args[0]
                if jnp.amax(jnp.abs(args[1])) != 0.:
                    raise NotImplementedError(
                        'Unsupported activation function')
            elif isinstance(args[1], RelaxVariable):
                invar = args[1]
                if jnp.amax(jnp.abs(args[0])) != 0.:
                    raise NotImplementedError(
                        'Unsupported activation function')
            else:
                raise NotImplementedError(
                    'Activations with multiple arguments not '
                    'supported.')
            slope, bias = _get_relu_relax(invar.lower, invar.upper)
            constraints += [
                RelaxActivationConstraint(out_variable, invar,
                                          jnp.zeros_like(invar.lower),
                                          jnp.zeros_like(invar.lower),
                                          1),  # relu(x) >= 0
                RelaxActivationConstraint(out_variable, invar,
                                          jnp.ones_like(invar.lower),
                                          jnp.zeros_like(invar.lower),
                                          1),  # relu(x) >= x
                RelaxActivationConstraint(out_variable, invar, slope, bias, -1)
            ]  # upper chord of triangle relax
        else:
            raise NotImplementedError('Activations with multiple arguments not '
                                      'supported.')
    out_variable.set_constraints(constraints)
    return out_variable
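For reference, a sketch of what _get_relu_relax presumably computes (an assumption, since that helper is not shown): the standard triangle relaxation of relu on [l, u] with l < 0 < u uses the upper chord y = slope * x + bias through the points (l, 0) and (u, u).

import jax.numpy as jnp

def triangle_relu_relax(lower, upper):
    # Upper chord of the relu triangle relaxation, assuming lower < 0 < upper.
    slope = upper / (upper - lower)
    bias = -lower * upper / (upper - lower)
    return slope, bias

slope, bias = triangle_relu_relax(jnp.array(-1.), jnp.array(2.))  # 2/3, 2/3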