Example #1
def prob_inf_house_size_iter(state, hh_sizes_, house_dist):
  """ Function that computes the probability of an individual getting infected given their household size.
  @param state : A Device Array that encodes the state of each individual in the population at the end of each iteration of the simulation;
                 an individual counts as infected when its state is > 0
  @type : Device Array of shape (# of iterations, population size)
  @param hh_sizes_ : An array which keeps track of the size of each individual's household
  @type : Array of length = population size
  @param house_dist : Distribution of household sizes, with one entry per distinct household size occurring in
                      hh_sizes_, in increasing order of size
  @type : List or 1D array
  @return : Returns the probability of infection given household size (one entry per distinct household size)
            and the mean probability of infection, both averaged over the iterations
  @type : Tuple
  """
  hh_sizes = np.asarray(hh_sizes_)
  # Accept a plain list as documented: `1/list` would raise a TypeError.
  house_dist = np.asarray(house_dist)
  pop = len(state[0])
  # Distinct household sizes present in the population (sorted). Counting infections
  # against this fixed set keeps entry j aligned with house_dist[j] even when nobody
  # of some household size is infected in a given iteration; np.unique over the
  # infected alone would drop that size and misalign (or fail to broadcast) the result.
  sizes = np.unique(hh_sizes)

  prob_hh_size = []
  mean_inf_prob = []
  # First compute the probability of the household size given that the person was infected and then use Bayes rule
  for i in range(len(state)):
    if_inf = np.where(state[i] > 0)[0]
    inf_size = len(if_inf)
    hh_inf = hh_sizes[if_inf]
    # Number of infected individuals for each household size in `sizes`.
    inf_counts = np.sum(hh_inf[:, None] == sizes[None, :], axis=0)
    # Bayes rule: P(inf | size) = P(size | inf) * P(inf) / P(size)
    #           = (inf_counts / inf_size) * (inf_size / pop) / house_dist
    #           = inf_counts / (pop * house_dist)
    # The simplified form gives 0 (not 0/0) for an iteration without infections.
    prob_hh_size.append(inf_counts / (pop * house_dist))
    mean_inf_prob.append(inf_size / pop)

  # Returns the probability of infection given household size
  return np.average(np.stack(prob_hh_size), axis=0), np.average(np.asarray(mean_inf_prob))
Example #2
def prob_inf_workplace_open(indx_active, state):
  """ Function that computes the probability of infection for an individual who is still working during intervention.
  @param indx_active : Numpy array with indices of individuals still working during intervention
  @type : 1D array
  @param state : A Device Array that encodes the state of each individual in the population at the end of each iteration of the simulation;
                 an individual counts as infected when its state is > 0
  @type : Device Array of shape (# of iterations, population size)
  @return : Returns the probability of infection for individuals working during intervention and the population average, averaged over the number of iterations
  @type : Tuple
  @raise ZeroDivisionError : if indx_active is empty
  """
  n_active = len(indx_active)
  pop = len(state[0])
  prob_inf_work = []
  mean_inf_prob = []

  for i in range(len(state)):

    # Get indices of infected people
    if_inf = np.where(state[i] > 0)[0]
    inf_size = len(if_inf)
    # Conditional probability via Bayes rule:
    #   P(inf | working) = P(working | inf) * P(inf) / P(working)
    #                    = (n_working_inf / inf_size) * (inf_size / pop) / (n_active / pop)
    #                    = n_working_inf / n_active
    # The simplified form avoids 0/0 = nan when nobody is infected.
    n_working_inf = int(np2.isin(indx_active, if_inf).sum())
    prob_inf_work.append(n_working_inf / n_active)
    mean_inf_prob.append(inf_size / pop)

  return np.average(np.asarray(prob_inf_work)), np.average(np.asarray(mean_inf_prob))
Example #3
    # Build the initial State for a clustering from a per-point cluster assignment.
    # NOTE(review): this snippet is truncated — the State(...) call at the end is
    # never closed. mask, a_k, points, num_S, K, D, State, log_factor_k and
    # log_ellipsoid_volume come from an enclosing scope that is not visible here.
    def init_state(cluster_id):
        # Number of masked points assigned to each cluster label in a_k.
        num_k = jnp.sum(mask & (cluster_id == a_k[:, None]), axis=-1)
        # Per-cluster mean of the points carrying that label.
        # NOTE(review): these weights ignore `mask`, unlike num_k above — confirm intended.
        mu_k = vmap(lambda k: jnp.average(
            points, axis=0, weights=k == cluster_id))(a_k)
        # Pseudo-inverse of each cluster's weighted scatter matrix.
        # NOTE(review): this jnp.average has no `axis=` while the weights are 1-D and
        # the data is 3-D; numpy/jax average require `axis` when the shapes differ,
        # so this presumably needs axis=0 — TODO confirm.
        C_k = vmap(lambda k, mu_k: jnp.linalg.pinv(
            jnp.average((points - mu_k)[:, :, None] *
                        (points - mu_k)[:, None, :],
                        weights=k == cluster_id)))(a_k, mu_k)
        # Log-determinant of each C_k via the sum of log eigenvalues (real parts).
        logdetC_k = vmap(
            lambda C_k: jnp.sum(jnp.log(jnp.linalg.eigvals(C_k).real)))(C_k)
        # C_k scaled by the per-cluster point count.
        precision_k = C_k * num_k[:, None, None]
        # K, N: log Mahalanobis distance of every point to every cluster.
        log_maha_k = vmap(lambda mu_k, precision_k: jnp.log(
            vmap(lambda point: (point - mu_k) @ precision_k @ (point - mu_k))
            (points)))(mu_k, precision_k)
        log_f_k = log_factor_k(cluster_id, log_maha_k, num_k, logdetC_k)
        log_VE_k = vmap(log_ellipsoid_volume)(logdetC_k, num_k, log_f_k)

        # Log of the fraction of the num_S points that belong to each cluster.
        log_VS_k = jnp.log(num_k) - jnp.log(num_S)
        return State(i=jnp.asarray(0),
                     done=num_S < K * (D + 1),
Example #4
def inertial_restraint(conf, params, box, lamb, a_idxs, b_idxs, masses, k):
    """Restraint energy that aligns the principal inertial axes of two atom groups.

    Each group is centred on its mass-weighted centroid, its inertia tensor is
    diagonalised with ``np.linalg.eigh``, and the eigenvectors of the two
    groups are compared pairwise in the order eigh returns them. The energy is
    ``k * sum_i (1 - |a_i . b_i|) ** 2``.

    ``params``, ``box`` and ``lamb`` are accepted for interface compatibility
    and are not used.
    """

    def _principal_axes(idxs):
        # Centre the group on its centre of mass, then diagonalise its inertia tensor.
        coords = conf[idxs]
        group_masses = masses[idxs]
        centred = coords - np.average(coords, axis=0, weights=group_masses)
        _, eigenvectors = np.linalg.eigh(inertia_tensor(centred, group_masses))
        return eigenvectors

    def _misalignment(u, v):
        # Zero when the two axes are parallel or anti-parallel.
        gap = 1 - np.abs(np.dot(u, v))
        return gap * gap

    axes_a = _principal_axes(a_idxs)
    axes_b = _principal_axes(b_idxs)

    # eigh stores eigenvectors as columns, so walk the transposes row by row.
    penalties = [_misalignment(u, v) for u, v in zip(axes_a.T, axes_b.T)]
    return np.sum(penalties) * k
Example #5
def split_background(background_double_exp):
    """Average interleaved background frames separately for each exposure.

    Frames at even positions along the first axis are averaged together, as are
    frames at odd positions. The two averages are returned stacked as
    ``[even_average, odd_average]``.
    """
    averages = [np.average(background_double_exp[offset::2], axis=0)
                for offset in (0, 1)]
    return np.array(averages)
Example #6
# Print summary statistics over multiple simulation runs: final values,
# peak prevalence, timing of peaks and time to extinction.
# NOTE(review): this snippet is badly truncated. The description below is not
# enclosed in a docstring, the `if int==0:` branch has no body, and several
# print(...) calls are never closed. `peaks`, `tpeak`, `all_cases`, `delta_t`,
# `time_int` and `get_extinction_time` are not defined anywhere visible.
# NOTE(review): the parameter name `int` shadows the builtin int().
# NOTE(review): the description mentions `ymax` and `scale`, which are not parameters.
def get_peaks_iter(soln,tvec,int=0,Tint=0,loCI=5,upCI=95):

  calculates the peak prevalence for a multiple runs, with or without an intervention
  soln: 3D array of values for each iteration for each variable at each timepoint
  tvec: 1D vector of timepoints
  ymax : highest value on y axis, relative to "scale" value (e.g. 0.5 makes ymax=0.5 or 50% for scale=1 or N)
  scale: amount to multiple all frequency values by (e.g. "1" keeps as frequency, "N" turns to absolute values)
  int: Optional, 1 or 0 for whether or not there was an intervention. Defaults to 0
  Tint: Optional, timepoint (days) at which intervention was started
  loCI,upCI: Optional, upper and lower percentiles for confidence intervals. Defaults to 90% interval


  if int==0:


  # Final values
  # soln[:,-1,j] is variable j at the last timepoint of every run; values are
  # printed as percentages (x100) with the [loCI, upCI] percentile range.
  print('Final recovered: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(
      100 * np.average(soln[:,-1,6]), 100*np.percentile(soln[:,-1,6],loCI), 100*np.percentile(soln[:,-1,6],upCI)))
  print('Final deaths: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(
      100 * np.average(soln[:,-1,5]), 100*np.percentile(soln[:,-1,5],loCI), 100*np.percentile(soln[:,-1,5],upCI)))
  print('Remaining infections: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(

  # Peak prevalence
  # NOTE(review): the I1, I2 and I3 lines below all summarise the same `peaks`
  # array; presumably a per-class peak array was computed in the missing code.
  print('Peak I1: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(
      100 * np.average(peaks),100 * np.percentile(peaks,loCI),100 * np.percentile(peaks,upCI)))
  print('Peak I2: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(
      100 * np.average(peaks),100 * np.percentile(peaks,loCI),100 * np.percentile(peaks,upCI)))
  print('Peak I3: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(
      100 * np.average(peaks),100 * np.percentile(peaks,loCI),100 * np.percentile(peaks,upCI)))
  # Timing of peaks
  print('Time of peak I1: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format(
      np.average(tpeak),np.median(tpeak), np.percentile(tpeak,loCI),np.percentile(tpeak,upCI)))
  print('Time of peak I2: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format(
  print('Time of peak I3: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format(
  # Time when all the infections go extinct
  # (extinction step index scaled by delta_t, measured relative to time_int)
  time_all_extinct = np.array(get_extinction_time(all_cases,0))*delta_t-time_int

  print('Time of extinction of all infections post intervention: {:4.2f} days  [{:4.2f}, {:4.2f}]'.format(
Example #7
def prob_inf_working_hh_member(indx_active, state, house_indices, household_sizes):
    """ Function that computes the probability of infection for individuals living with a household member working during intervention along with the probability of infection for individuals who aren't working and
    living with a working household member.
    @param indx_active : Numpy array with indices of individuals still working during intervention
    @type : 1D array
    @param state : A Device Array that encodes the state of each individual in the population at the end of each iteration of the simulation;
                   an individual counts as infected when its state is > 0
    @type : Device Array of shape (# of iterations, population size)
    @param house_indices : Numpy array that keeps track of the house an individual belongs to
    @type : 1D array
    @param household_sizes : Numpy array that keeps track of the size of each individual's household (currently unused)
    @type : 1D array
    @return : Returns the probability of infection for non-workers living with at least one working household member, probability for non-workers living with no working household member, and the population average, averaged over the number of iterations
    @type : Tuple
    """
    pop = len(state[0])

    # Who works and where the workers live does not change between iterations,
    # so these quantities are computed once instead of inside the loop.
    # Houses of people who are still working
    house_working = np2.unique(house_indices[indx_active])
    # Indices of all people who aren't working and their house index
    not_working = np2.setdiff1d(np2.arange(0, pop, 1), indx_active)
    lives_with_worker = np2.isin(house_indices[not_working], house_working)
    # Number of non-workers living with at least one / with no working household member
    n_with_worker = lives_with_worker.sum()
    n_without_worker = (~lives_with_worker).sum()

    prob_inf = []
    prob_inf_not_working = []
    mean_inf_prob = []

    for i in range(len(state)):

        # Get indices of infected people
        if_inf = np.where(state[i] > 0)[0]
        inf_size = len(if_inf)

        # Infected people who aren't working, and whether they live with a worker
        if_inf_not_working = np2.setdiff1d(if_inf, indx_active)
        inf_with_worker = np2.isin(house_indices[if_inf_not_working], house_working)

        # Bayes rule: P(inf | A) = P(A | inf) * P(inf) / P(A) with P(A) = n_A / pop
        # reduces to (# infected in A) / n_A; the simplified form avoids 0/0 = nan
        # when nobody is infected.
        # Probability of infection given at least one hh member was working during intervention
        prob_inf.append(inf_with_worker.sum() / n_with_worker)
        # Probability of infection given no hh member was working during intervention
        prob_inf_not_working.append((~inf_with_worker).sum() / n_without_worker)

        # Population average probability of infection
        mean_inf_prob.append(inf_size / pop)

    return (np.average(np.asarray(prob_inf)),
            np.average(np.asarray(prob_inf_not_working)),
            np.average(np.asarray(mean_inf_prob)))
Example #8
# Hand-written coordinate gradient of the inertial (principal-axis) restraint.
# The per-eigenvector prefactor below matches the derivative of
#   U = k * sum_i (1 - |a_i . b_i|)^2
# where a_i / b_i are eigenvectors of the two groups' inertia tensors.
# NOTE(review): the parameter list is truncated in this snippet (it ends after
# `masses,`); `k`, used below, is presumably the missing last parameter. TODO confirm.
def analytic_restraint_force(conf, params, box, lamb, a_idxs, b_idxs, masses,

    a_conf = conf[a_idxs]
    b_conf = conf[b_idxs]

    a_masses = masses[a_idxs]
    b_masses = masses[b_idxs]

    # Coordinates relative to each group's mass-weighted centroid.
    a_com_conf = a_conf - np.average(a_conf, axis=0, weights=a_masses)
    b_com_conf = b_conf - np.average(b_conf, axis=0, weights=b_masses)

    a_tensor = inertia_tensor(a_com_conf, a_masses)
    b_tensor = inertia_tensor(b_com_conf, b_masses)

    # a_eval, a_evec = np.linalg.eigh(a_tensor)
    # b_eval, b_evec = np.linalg.eigh(b_tensor)

    # eigenvalues are needed for derivatives
    a_eval, a_evec = dsyevv3(a_tensor)
    b_eval, b_evec = dsyevv3(b_tensor)

    # NOTE(review): `loss` is computed here but never used below.
    loss = []
    for a, b in zip(a_evec.T, b_evec.T):
        delta = 1 - np.abs(np.dot(a, b))
        loss.append(delta * delta)

    # dU/da_i = -sign(a_i . b_i) * 2 * (1 - |a_i . b_i|) * k * b_i, and symmetrically
    # for b_i. The lists hold one row per eigenvector (the loops iterate over .T).
    dl_daevec_T = []
    dl_dbevec_T = []
    for a, b in zip(a_evec.T, b_evec.T):
        delta = 1 - np.abs(np.dot(a, b))
        prefactor = -np.sign(np.dot(a, b)) * 2 * delta * k
        dl_daevec_T.append(prefactor * b)
        dl_dbevec_T.append(prefactor * a)

    # Transpose back so each eigenvector's gradient is a column, like the evec inputs.
    dl_daevec = np.transpose(np.array(dl_daevec_T))
    dl_dbevec = np.transpose(np.array(dl_dbevec_T))

    # Back-propagate through the eigendecomposition, then through the inertia tensor.
    dl_datensor = grad_eigh(a_eval, a_evec, np.array(dl_daevec))
    dl_dbtensor = grad_eigh(b_eval, b_evec, np.array(dl_dbevec))

    dl_da_com_conf = grad_inertia_tensor(a_com_conf, a_masses, dl_datensor)
    dl_db_com_conf = grad_inertia_tensor(b_com_conf, b_masses, dl_dbtensor)

    # Mutable buffer (onp, presumably plain numpy) so the per-group gradients can
    # be scattered in place.
    du_dx = onp.zeros_like(conf)

    du_dx[a_idxs] += dl_da_com_conf
    du_dx[b_idxs] += dl_db_com_conf

    # NOTE(review): debug print; writes the full gradient to stdout on every call.
    print("ref du_dx", du_dx)
    # conservative forces are not affected by center of mass changes.
    # the vjp w.r.t. to the center of mass yields zeros since sum dx=0, dy=0 and dz=0
    # return dl_da_com_conf, dl_db_com_conf
    return du_dx
Example #9
 # Mean and pseudo-inverse covariance of the masked points assigned to cluster k.
 # NOTE(review): the Cov = jnp.average(... call is truncated in this snippet
 # (its remaining arguments are missing). cluster_id, mask, points, num_k and D
 # come from an enclosing scope that is not visible.
 def get_mu_and_C(k):
     # Select points labelled k that also pass the mask.
     weights = (cluster_id == k) & mask
     mu = jnp.average(points, weights=weights, axis=0)
     dist = points - mu
     Cov = jnp.average(dist[:, :, None] * dist[:, None, :],
     C = jnp.linalg.pinv(Cov)
     # Zero out the mean for empty clusters, and the inverse covariance for clusters
     # with fewer than D + 1 points (too few for a full-rank covariance).
     mu = jnp.where(num_k[k] == 0, 0., mu)
     C = jnp.where(num_k[k] < D + 1, 0., C)
     return mu, C
Example #10
 # Per-step record: time and potential/kinetic energies converted by unit factors.
 # NOTE(review): this snippet is truncated (the returned tuple is never closed).
 # atoms, idx, calcEnergy, mass_unit, dist_unit and time_unit come from an enclosing
 # scope. The factors 1e12 / 1e21 look like unit conversions — TODO confirm the units.
 def internal(t, pos, vel):
   tM = np.full((atoms, 1), t)
   # NOTE(review): `arr` is not used in the visible part of this function.
   arr = np.concatenate((tM, idx, pos), axis = 1)
   (potential, kinetic) = calcEnergy(pos, vel)
   return (
       t * 1e12,
       potential * mass_unit * (dist_unit / time_unit) ** 2 * 1e21,
       kinetic * mass_unit * (dist_unit / time_unit) ** 2 * 1e21,
Example #11
def centroid_restraint(conf, params, box, lamb, masses, group_a_idxs,
                       group_b_idxs, kb, b0):
    """Harmonic restraint on the distance between two mass-weighted centroids.

    Returns ``kb * (|c_a - c_b| - b0) ** 2``, where ``c_a`` and ``c_b`` are the
    mass-weighted centroids of the two index groups. ``params``, ``box`` and
    ``lamb`` are accepted for interface compatibility and are not used.
    """

    def _centroid(idxs):
        return np.average(conf[idxs], axis=0, weights=masses[idxs])

    separation = _centroid(group_a_idxs) - _centroid(group_b_idxs)
    distance = np.sqrt(np.sum(separation * separation))
    displacement = distance - b0
    return kb * displacement * displacement
Example #12
  def test_against_tf_ctc_loss(self):
    """Compares ctc_loss values and gradients against the TF reference."""
    batchsize = 8
    timesteps = 150
    labelsteps = 25
    nclasses = 400
    logits = np.random.randn(batchsize, timesteps, nclasses)
    logprobs = jax.nn.log_softmax(logits)
    logprob_paddings = np.zeros((batchsize, timesteps))
    labels = np.random.randint(
        1, nclasses, size=(batchsize, labelsteps)).astype(np.int32)
    label_paddings = np.zeros((batchsize, labelsteps))

    inputs = [logprobs, logprob_paddings, labels, label_paddings]

    jax_per_seq, unused_aux_vars = ctc_objectives.ctc_loss(*inputs)
    tf_per_seq = tf_ctc_loss(*inputs)
    self.assertAllClose(jax_per_seq.squeeze(), tf_per_seq.squeeze())

    # Scalar objectives for jax.grad: ctc_loss returns (per_seq_loss, aux_vars),
    # so only its first element is averaged, mirroring the TF version.
    # (average_ctc_loss was previously used without being defined -> NameError.)
    average_ctc_loss = lambda *args: jnp.average(
        ctc_objectives.ctc_loss(*args)[0])
    average_tf_ctc_loss = lambda *args: jnp.average(tf_ctc_loss(*args))
    # jax.grad differentiates w.r.t. the first argument (logprobs).
    jax_dloss = jax.grad(average_ctc_loss)
    tf_dloss = jax.grad(average_tf_ctc_loss)

    jax_dlogits = jax_dloss(*inputs)
    tf_dlogits = tf_dloss(*inputs)
    # Relative error check is disabled as numerical errors explodes when a
    # probability computed from the input logits is close to zero.
    self.assertAllClose(jax_dlogits, tf_dlogits, rtol=0.0, atol=1e-4)
 # Build a Gaussian mixture from per-component means, covariances and weights,
 # and precompute its overall mean and covariance.
 # NOTE(review): this snippet is truncated — the description lines below are not
 # inside a docstring, and the tfd.MixtureSameFamily(...) call has unbalanced
 # parentheses and is missing its leading arguments.
 def __init__(self, means, covs, weights):
     means, covs are np arrays or lists of length k, with entries of shape
     (d,) and (d, d) respectively. (e.g. covs can be array of shape (k, d, d))
     means, covs, weights = self._check_and_reshape_args(
         means, covs, weights)
     self.d = len(means[0])
     self.expectations = self.compute_expectations(means, covs, weights)
     self.mean = self.expectations[0]
     # recall Cov(X) = E[XX^T] - mu mu^T =
     # sum_over_components(Cov(Xi) + mui mui^T) - mu mu^T
     # (the component sum is weighted: np.average below normalises `weights`)
     mumut = np.einsum("ki,kj->kij", means, means)  # shape (k, d, d)
     self.cov = np.average(covs + mumut, weights=weights, axis=0) \
         - np.outer(self.mean, self.mean)
     # Fixed PRNG key, so sampling is deterministic unless the key is changed elsewhere.
     self.threadkey = random.PRNGKey(0)
     self.means = means
     self.covs = covs
     self.weights = weights
     self.num_components = len(weights)
     self.sample_metrics = dict()
     self.tfp_dist = tfd.MixtureSameFamily(
             loc=self.means, covariance_matrix=self.covs))
Example #14
def get_mean_sig_bounds(arr, dim, weights, sig_multiple=1.):
    """Weighted mean, weighted standard deviation and +/- sigma bounds of `arr`.

    Args:
        arr: input array.
        dim: axis along which the statistics are computed.
        weights: weights as accepted by ``np.average``: a 1-D array of length
            ``arr.shape[dim]``, an array with the same shape as ``arr``, or
            ``None`` for equal weights. They do not need to be normalised.
        sig_multiple: number of standard deviations between the mean and each bound.

    Returns:
        ``(mean, sigma, mean + sig_multiple * sigma, mean - sig_multiple * sigma)``,
        each with axis ``dim`` reduced.
    """
    mean_arr = np.average(arr, axis=dim, weights=weights)
    # Weighted variance = weighted average of squared deviations, using the same
    # weights and normalisation as the mean. (A plain mean of weights * sq_dev is
    # only correct when mean(weights) == 1, and it misbroadcasts 1-D weights when
    # `dim` is not the last axis.) expand_dims lets the mean broadcast back
    # against `arr` along the reduced axis.
    sq_dev = (arr - np.expand_dims(mean_arr, dim)) ** 2
    sig_arr = np.sqrt(np.average(sq_dev, axis=dim, weights=weights))
    sig_upper_arr = mean_arr + sig_arr * sig_multiple
    sig_lower_arr = mean_arr - sig_arr * sig_multiple
    return mean_arr, sig_arr, sig_upper_arr, sig_lower_arr
Example #15
def filter_bblocks(data):
    """Remove stripe-like background from `data` and zero out weak values.

    Steps, as implemented:
      1. Take index 11 of the third axis, reshape it to (nrcols, 192), clip it
         to +/- the threshold and smooth it with ``conv2d`` using kernel ``gg``;
         subtract the result, reshaped to (nrcols, nbmux, 1), from ``data``.
      2. Subtract the average over rows 1..9 of the residual clipped to
         [0, 2 * threshold], reshaped to (1, 192, 12).
      3. Keep only values strictly above the threshold; everything else becomes 0.

    ``nrcols``, ``nbmux``, ``gg`` and ``conv2d`` are names defined outside this function.
    """
    # Used both as the background clip level and as the final cut.
    threshold = 3

    # Vertical stripes: one slice reshaped to an image, clipped, then smoothed.
    stripes = np.reshape(data[:, :, 11], (nrcols, 192))
    smoothed = conv2d(np.clip(stripes, -threshold, threshold), gg)
    result = data - np.reshape(smoothed, (nrcols, nbmux, 1))

    # Offset averaged over rows 1..9 after clipping to [0, 2 * threshold].
    row_offset = np.average(np.clip(result[1:10, :, :], 0, 2 * threshold), axis=0)
    result -= np.reshape(row_offset, (1, 192, 12))

    # Zero everything that is not strictly above the threshold.
    result *= result > threshold
    return result
Example #16
    # Fit relative unconstrained least-squares importance weights (RuLSIF): build
    # kernel features for x and y, solve a regularised linear system for the
    # weights, clamp negative weights to 0, and store the fit on the instance.
    # NOTE(review): this snippet is truncated. As written, the single-candidate
    # branch immediately overwrites sigma/lambda_ with the optimiser result, which
    # suggests an `else:` line is missing. The np.linalg.solve(...) call is never
    # closed; its right-hand side is presumably `h`.
    def _RuLSIF(self, x, y, alpha, s_sigma, s_lambda):
        if len(s_sigma) == 1 and len(s_lambda) == 1:
            sigma = s_sigma[0]
            lambda_ = s_lambda[0]
            optimized_params = self._optimize_sigma_lambda(
                x, y, alpha, s_sigma, s_lambda)
            sigma = optimized_params['sigma']
            lambda_ = optimized_params['lambda']

        phi_x = self.__kernel(r=x, sigma=sigma)
        phi_y = self.__kernel(r=y, sigma=sigma)
        # H = (1 - alpha) * Phi_y^T Phi_y / n_y + alpha * Phi_x^T Phi_x / n_x
        H = (1. -
             alpha) * (np.dot(phi_y.T, phi_y) / self.__y_num_row) + alpha * (
                 np.dot(phi_x.T, phi_x) / self.__x_num_row)  # Phi* Phi
        # Mean kernel feature vector over the x samples.
        h = np.average(phi_x, axis=0).T
        weights = np.linalg.solve(H + lambda_ * np.identity(self.__kernel_num),
        #  weights[weights < 0] = 0.
        # NOTE(review): jax.ops.index_update has been removed from recent JAX
        # releases; the equivalent is weights.at[weights < 0].set(0).
        weights = jax.ops.index_update(weights, weights < 0, 0)  # G2[G2<0]=0

        self.__alpha = alpha
        self.__weights = weights
        self.__lambda = lambda_
        self.__sigma = sigma
        self.__phi_x = phi_x
        self.__phi_y = phi_y
Example #17
def centroid_restraint(conf, lamb, masses, lamb_flag, lamb_offset,
                       group_a_idxs, group_b_idxs, kb, b0):
    """Lambda-scaled harmonic restraint between two mass-weighted centroids.

    Returns ``(lamb * lamb_flag + lamb_offset) * kb * (|c_a - c_b| - b0) ** 2``,
    where ``c_a`` and ``c_b`` are the mass-weighted centroids of the two index groups.
    """

    def _centroid(idxs):
        return np.average(conf[idxs], axis=0, weights=masses[idxs])

    separation = _centroid(group_a_idxs) - _centroid(group_b_idxs)
    distance = np.sqrt(np.sum(separation * separation))
    displacement = distance - b0
    scale = lamb * lamb_flag + lamb_offset
    return scale * kb * displacement * displacement
Example #18
def likelihood(args):
    """Negative weighted log-likelihood of the data under the MOD model.

    ``args`` is split into three equal parts ``(f0m, f0w, const)``. The model is
    evaluated on the data sample and on the MC sample, and the averaged MC value
    (after ``alladd``) normalises the data value. Returns
    ``-sum(wt * (log(data_value) - log(mc_normalisation)))``.
    ``MOD``, ``alladd``, ``wt`` and the ``data_*`` / ``mc_*`` arrays are names
    defined outside this function.
    """
    f0m, f0w, const = np.split(args, 3)
    data_terms = MOD(f0m, f0w, const, data_phif0, data_phi, data_f)
    mc_terms = MOD(f0m, f0w, const, mc_phif0, mc_phi, mc_f)
    data_value = alladd(data_terms)
    mc_normalisation = np.average(alladd(mc_terms))
    log_ratio = np.log(data_value) - np.log(mc_normalisation)
    return -np.sum(wt * log_ratio)
Example #19
def return_weighted_average(action_trajectories: jnp.ndarray,
                            cum_reward: jnp.ndarray,
                            kappa: float) -> jnp.ndarray:
    r"""Calculates return-weighted average over all trajectories.

    This will calculate the return-weighted average over a set of trajectories as
    defined on l.17 of Alg. 2 in the MBOP paper:

        A = sum_n exp(kappa * R_n) * A_n / sum_n exp(kappa * R_n)

    The maximum of `cum_reward` is subtracted before exponentiating, so exp()
    cannot overflow to inf/NaN. The shift cancels in the normalised average.

    Args:
      action_trajectories: (n_trajectories, horizon, action_dim) tensor of action
        trajectories, corresponds to `A` in Alg. 2.
      cum_reward: (n_trajectories,) or (n_trajectories, 1) array of corresponding
        cumulative rewards (returns) for each trajectory. Corresponds to
        `\mathcal{R}` in Alg. 2.
      kappa: `\kappa` constant, changes the 'peakiness' of the exponential
        averaging.

    Returns:
      Single action trajectory corresponding to the return-weighted average of the
      trajectories, of shape (horizon, action_dim).
    """
    # Subtract the maximum reward to avoid overflow (and hence NaNs) in exp.
    cum_reward = cum_reward - cum_reward.max()
    # Flatten to a 1-D weight vector of length n_trajectories so jnp.average can
    # apply it along axis 0. Unlike squeeze(), this stays 1-D when there is a
    # single trajectory; a 0-d weight would be rejected by jnp.average.
    exp_cum_reward = jnp.exp(kappa * jnp.reshape(cum_reward, (-1,)))
    return jnp.average(action_trajectories, weights=exp_cum_reward, axis=0)
Example #20
    # Cellular-automaton update for one cell: pick the colour channel after the
    # neighbourhood's dominant one (cyclically over 3 channels) and blend it into
    # the centre cell.
    # NOTE(review): this snippet is truncated — np.floor( is followed by the
    # arguments of a weighted average (np.stack(...), weights=...), which suggests a
    # missing `np.average(` line, and the call is never closed. `blend_coefficient`
    # comes from an enclosing scope.
    def shiva_rule(neighbors):
        center = neighbors[len(neighbors) // 2]

        # Index of the channel with the largest neighbourhood average, shifted by one (mod 3).
        color_add = (np.argmax(np.average(neighbors, axis=0), axis=0) + 1) % 3

        # One-hot colour for that channel, scaled to 0/255.
        arise = jax.nn.one_hot(np.array([color_add], dtype="uint8"), 3)[0] * 255

        updated_center = np.floor(
                np.stack([center, arise]),
                weights=[1.0 - blend_coefficient, blend_coefficient],

        return np.array(updated_center, dtype="uint8")
Example #21
# Fraction of posterior-predictive samples whose 'obs' value equals y.
# NOTE(review): this snippet is truncated — the sample_multi_posterior_predictive(...)
# call is never closed. model and guide come from an enclosing scope.
def estimate_accuracy(X, y, params, rng, num_iterations=1):

    samples = sample_multi_posterior_predictive(
        rng, num_iterations, model, (X,), guide, (X,), params

    return jnp.average(samples['obs'] == y)
Example #22
# Fraction of prior-predictive samples (with w and intercept fixed) whose 'obs' equals y.
# NOTE(review): this snippet is truncated — the dict literal holding 'w' and
# 'intercept' is missing its braces, and the call is never closed.
def estimate_accuracy_fixed_params(X, y, w, intercept, rng, num_iterations=1):
    samples = sample_multi_prior_predictive(rng, num_iterations, model, (X, ),
                                                'w': w,
                                                'intercept': intercept
    return jnp.average(samples['obs'] == y)
Example #23
def simplified_u(a_conf, b_conf, a_masses, b_masses):
    """Unscaled principal-axis misalignment between two groups of points.

    Each group is centred on its mass-weighted centroid and its inertia tensor
    is diagonalised with ``np.linalg.eigh``. Returns
    ``sum_i (1 - |a_i . b_i|) ** 2`` over the paired eigenvectors.
    """

    def _principal_axes(coords, weights):
        centred = coords - np.average(coords, axis=0, weights=weights)
        _, eigenvectors = np.linalg.eigh(inertia_tensor(centred, weights))
        return eigenvectors

    def _misalignment(u, v):
        gap = 1 - np.abs(np.dot(u, v))
        return gap * gap

    axes_a = _principal_axes(a_conf, a_masses)
    axes_b = _principal_axes(b_conf, b_masses)
    # eigh returns eigenvectors as columns, so iterate over the transposes.
    return np.sum([_misalignment(u, v) for u, v in zip(axes_a.T, axes_b.T)])
Example #24
def weight(args):
    """Per-event weights: data-sample model value over the mean MC model value.

    ``args`` is split into three equal parts ``(f0m, f0w, const)``, and the
    model is evaluated on both the data and MC samples. ``MOD``, ``alladd`` and
    the ``data_*`` / ``mc_*`` arrays are names defined outside this function.
    """
    f0m, f0w, const = np.split(args, 3)
    data_terms = MOD(f0m, f0w, const, data_phif0, data_phi, data_f)
    mc_terms = MOD(f0m, f0w, const, mc_phif0, mc_phi, mc_f)
    numerator = alladd(data_terms)
    normalisation = np.average(alladd(mc_terms))
    return numerator / normalisation
Example #25
 # Negative weighted log-likelihood: -sum(wt * (log(data_value) - log(mean MC value))).
 # NOTE(review): this snippet is truncated — both self.MOD(...) calls are missing
 # their last argument and closing parenthesis.
 def likelihood(self, args):
     f0m, f0w, const = np.split(args, 3)
     d_phif0 = self.MOD(f0m, f0w, const, self.data_phif0, self.data_phi,
     m_phif0 = self.MOD(f0m, f0w, const, self.mc_phif0, self.mc_phi,
     d_tmp = self.alladd(d_phif0)
     # MC normalisation: mean of the summed MC model values.
     m_tmp = np.average(self.alladd(m_phif0))
     return -np.sum(self.wt * (np.log(d_tmp) - np.log(m_tmp)))
Example #26
def shortcuts_ch_shrink(x, out_features, method):
  """Match the number of channels in the shortcuts in the 1st conv1x1 layer.

  The last axis of `x` (in_features channels) is reduced to `out_features`
  channels by averaging groups of in_features // out_features channels.

  Args:
    x: 4-D array; the first three axes are kept and the last holds the channels.
    out_features: number of output channels; in_features must be a whole
      multiple of it.
    method: 'consecutive' averages adjacent channels (channel c contributes to
      output c // (in_features // out_features)); 'every_n' averages channels
      spaced out_features apart (channel c contributes to output
      c % out_features); 'none' returns zeros (no shortcut).

  Returns:
    Array of shape x.shape[:3] + (out_features,).

  Raises:
    ValueError: if `method` is not one of the supported values.
  """
  in_features = x.shape[-1]
  num_ch_avg = in_features // out_features
  assert out_features * num_ch_avg == in_features, (
      'in_features needs to be a whole multiple of out_features')
  dim_nwh = x.shape[0:3]
  if method == 'consecutive':
    x = jnp.reshape(x, dim_nwh + (out_features, num_ch_avg))
    return jnp.average(x, axis=4)
  elif method == 'every_n':
    x = jnp.reshape(x, dim_nwh + (num_ch_avg, out_features))
    return jnp.average(x, axis=3)
  elif method == 'none':
    # return all zeros to represent no shortcut
    return jnp.zeros(dim_nwh + (out_features,), dtype=x.dtype)
  else:
    # Previously unreachable (it sat after the 'none' return), so unknown
    # methods silently returned None.
    raise ValueError('Unsupported channel shrinking shortcut function type.')
Example #27
# Plot daily incidence from multiple simulation runs: mean curves plus shaded
# percentile ranges per compartment, on a linear and a log scale.
# NOTE(review): this snippet is heavily truncated. The description below is not
# enclosed in a docstring, the `for` / `if` statements have no bodies, and the
# final `if plotThis==True:` is cut off.
# NOTE(review): the description says "95% CI" while the defaults loCI=5 / upCI=95
# give a 90% interval. It also mentions `tvec`, which is not a parameter, and
# "cumulative prevalence", while the y-axis label reads "Daily incidence".
# NOTE(review): the parameter name `int` shadows the builtin int().
def plot_iter_daily_shade(soln_inc,n,ymax=1,scale=1,int=0,Tint=1,loCI=5,upCI=95,plotThis=False,plotName="test"):

  plots the output (cumulative prevalence) from a multiple simulation, with or without an intervention. Shows mean and 95% CI
  soln_inc: 3D array of values for each iteration for each variable at each timepoint
  tvec: 1D vector of timepoints
  n: total population size
  ymax : highest value on y axis, relative to "scale" value (e.g. 0.5 makes ymax=0.5 or 50% for scale=1 or N)
  scale: amount to multiple all frequency values by (e.g. "1" keeps as frequency, "N" turns to absolute values)
  int: Optional, 1 or 0 for whether or not there was an intervention. Defaults to 0
  Tint: Optional, timepoint (days) at which intervention was started
  loCI,upCI: Optional, upper and lower percentiles for confidence intervals. Defaults to 90% interval
  plotThis: True or False, whether a plot will be saved as pdf 
  plotName: string, name of the plot to be saved



  # linear scale
  # add averages
  plt.figure(figsize=(2*6.4, 4.0))
  plt.legend(['S', 'E', 'I1', 'I2', 'I3', 'D', 'R'],frameon=False,framealpha=0.0,bbox_to_anchor=(1.04,1), loc="upper left")
  # add ranges
  # (one shaded range per compartment; the seven compartments match the legend labels)
  for i in range(0,7):
  if int==1:
  plt.xlabel("Time (days)")
  plt.ylabel("Daily incidence")

  # log scale
  # add averages
  plt.legend(['S', 'E', 'I1', 'I2', 'I3', 'D', 'R'],frameon=False,framealpha=0.0,bbox_to_anchor=(1.04,1), loc="upper left")
  # add ranges
  for i in range(0,7):
  if int==1:
  plt.xlabel("Time (days)")
  plt.ylabel("Daily incidence")
  if plotThis==True:
Example #28
 # Gaussian log-likelihood of the weighted subset of `points`, fitted with its own
 # weighted mean and covariance, plus a log(n_i / n) occupancy term; NaN maps to -inf.
 # NOTE(review): this snippet is truncated — the Cov = jnp.average(... call is
 # missing its remaining arguments. points and n come from an enclosing scope.
 def get_log_L(weights):
     mu = jnp.average(points, weights=weights, axis=0)
     dx = points - mu
     Cov = jnp.average(dx[:, :, None] * dx[:, None, :],
     logdetCov = jnp.log(jnp.linalg.det(Cov))
     C = jnp.linalg.pinv(Cov)
     # logdetCov = 0.
     # C = jnp.eye(mu.size)
     # Squared Mahalanobis distance of every point to mu.
     maha = vmap(lambda dx: dx @ C @ dx)(dx)
     # Effective number of points in the subset (sum of the weights).
     n_i = jnp.sum(weights)
     # -1/2 sum(maha) - n_i*d/2 log(2 pi) - n_i/2 log det Cov + log(n_i / n),
     # where only points with non-zero weight contribute to the Mahalanobis sum.
     log_L_1 = -0.5 * jnp.sum(jnp.where(weights, maha, 0.)) \
               - 0.5 * n_i * mu.size * jnp.log(2. * jnp.pi) - 0.5 * n_i * logdetCov \
               + jnp.log(n_i) - jnp.log(n)
     # log_L_1 = -0.5 * jnp.sum(jnp.where(weights, maha, 0.)) \
     #           - 0.5 * n_i * mu.size * jnp.log(2. * jnp.pi) - 0.5 * n_i * logdetCov
     # Degenerate fits (e.g. log of a non-positive determinant) yield NaN; report them as -inf.
     return jnp.where(jnp.isnan(log_L_1), -jnp.inf, log_L_1)
Example #29
    # Score an initialisation (scaled by `stds`) by how far the network outputs and
    # the parameter gradients are from unit Gaussians. The KL terms are summed and
    # log10 of the total is returned.
    # NOTE(review): this snippet is truncated — the jax.vmap( call, the
    # total_output body and the desired_gaussian expression are incomplete.
    # custom_init, learned_dynamics, raw_lagrangian_eom and n come from an enclosing scope.
    def j_score_init(stds, rng2):

        new_params = custom_init(stds, rng2)

        rand_input = jax.random.normal(rng2, [n, 4])
        # NOTE(review): adding 1 to a PRNG key is not the documented way to derive
        # a new key (jax.random.split is), and rng2 is not used again in the
        # visible code.
        rng2 += 1

        outputs = jax.vmap(
                    learned_dynamics(new_params)))(rand_input)[:, 2:]

        #KL-divergence to mu=0, std=1:
        # per dimension, KL(N(mu, std^2) || N(0, 1)) = (mu^2 + std^2 - 1)/2 - log(std)
        mu = jnp.average(outputs, axis=0)
        std = jnp.std(outputs, axis=0)

        KL = jnp.sum((mu**2 + std**2 - 1) / 2.0 - jnp.log(std))

        def total_output(p):
            return vmap(partial(raw_lagrangian_eom,

        d_params = grad(total_output)(new_params)

        # Add a KL penalty for every non-1-D gradient array (1-D arrays are skipped),
        # plus a term on stds[i] relative to desired_gaussian.
        i = 0
        for l1 in d_params:
            if (len(l1)) == 0: continue
            # NOTE(review): new_l1 is never used.
            new_l1 = []
            for l2 in l1:
                if len(l2.shape) == 1: continue

                mu = jnp.average(l2)
                std = jnp.std(l2)
                KL += (mu**2 + std**2 - 1) / 2.0 - jnp.log(std)

                desired_gaussian = jnp.sqrt(6) / jnp.sqrt(l2.shape[0] +
                scaled_std = stds[i] / desired_gaussian
                #Avoid extremely large values
                KL += 0.1 * (scaled_std**2 / 2.0 - jnp.log(scaled_std))
                i += 1

        return jnp.log10(KL)
Example #30
 # Per-event weights: data-sample model value divided by the mean MC model value.
 # NOTE(review): this snippet is truncated — both self.MOD(...) calls are missing
 # their last argument and closing parenthesis.
 def weight(self, args):
     f0m, f0w, const = np.split(args, 3)
     d_phif0 = self.MOD(f0m, f0w, const, self.data_phif0, self.data_phi,
     m_phif0 = self.MOD(f0m, f0w, const, self.mc_phif0, self.mc_phi,
     d_tmp = self.alladd(d_phif0)
     m_tmp = np.average(self.alladd(m_phif0))
     return d_tmp / m_tmp