def prob_inf_house_size_iter(state, hh_sizes_, house_dist):
    """
    Function that computes the probability of an individual getting infected given their household size.

    Per iteration, Bayes' rule gives
        P(inf | size) = P(size | inf) * P(inf) / P(size)
                      = (#infected of that size / #infected) * (#infected / pop) / house_dist
                      = #infected of that size / (pop * house_dist)
    and the result is averaged over iterations.

    @param state : A Device Array that encodes the state of each individual in the population at the end of each iteration of the simulation
    @type : Device Array of shape (# of iterations, population size)
    @param hh_sizes_ : An array which keeps track of the size of each individual's household
    @type : Array of length = population size
    @param house_dist : Distribution of household sizes. Entry j is matched to the j-th smallest
        distinct household size present in hh_sizes_.
    @type : List or 1D array
    @return : Returns the probability of infection given household size and the mean probability of infection
    @type : Tuple
    """
    hh_sizes = np.asarray(hh_sizes_)
    # Accept a plain list as documented (1/list would raise TypeError).
    house_dist = np.asarray(house_dist)
    pop = len(state[0])
    # Count against every distinct size in the whole population (ascending), not
    # only the sizes seen among the infected, so the counts stay aligned with
    # house_dist even when some household size has no infections.
    sizes = np.unique(hh_sizes)
    prob_rows = []
    inf_fracs = []
    for iteration_state in state:
        if_inf = np.where(iteration_state > 0)[0]
        hh_inf = hh_sizes[if_inf]
        counts = (hh_inf[:, None] == sizes[None, :]).sum(axis=0)
        # Simplified Bayes form; avoids 0/0 when nobody is infected.
        prob_rows.append(counts / (pop * house_dist))
        inf_fracs.append(len(if_inf) / pop)
    # Returns the probability of infection given household size
    return np.average(np.asarray(prob_rows), axis=0), np.average(np.asarray(inf_fracs))
def prob_inf_workplace_open(indx_active, state):
    """
    Function that computes the probability of infection for an individual who is still working during intervention.

    Per iteration, P(inf | working) = P(working | inf) * P(inf) / P(working), which
    simplifies to (#infected workers) / (#workers). The per-iteration values are averaged.

    @param indx_active : Numpy array with indices of individuals still working during intervention
    @type : 1D array
    @param state : A Device Array that encodes the state of each individual in the population at the end of each iteration of the simulation
    @type : Device Array of shape (# of iterations, population size)
    @return : Returns the probability of infection for individuals working during intervention and the population average, averaged over the number of iterations
    @type : Tuple
    """
    pop = len(state[0])
    n_active = len(indx_active)
    work_probs = []
    inf_fracs = []
    for iteration_state in state:
        # Get indices of infected people
        if_inf = np.where(iteration_state > 0)[0]
        # Calculate the conditional probability directly as k / n. The unsimplified
        # Bayes product evaluates 0/0 (NaN) for iterations with no infections.
        n_inf_active = np2.count_nonzero(np2.isin(indx_active, if_inf))
        work_probs.append(n_inf_active / n_active)
        inf_fracs.append(len(if_inf) / pop)
    return np.average(np.asarray(work_probs)), np.average(np.asarray(inf_fracs))
def init_state(cluster_id):
    """Build the initial State for the clustering loop from a per-point cluster assignment.

    Relies on names from the enclosing scope that are not visible here: mask, a_k,
    points, num_S, K, D, log_factor_k, log_ellipsoid_volume, State and vmap.
    """
    # Per-label point count: points whose cluster_id equals each entry of a_k AND
    # that are selected by mask (a_k[:, None] broadcasts labels against points).
    num_k = jnp.sum(mask & (cluster_id == a_k[:, None]), axis=-1)
    # Per-cluster centroid.
    # NOTE(review): these weights (k == cluster_id) do not include `mask`, unlike
    # num_k above — confirm this is intended.
    mu_k = vmap(lambda k: jnp.average(
        points, axis=0, weights=k == cluster_id))(a_k)
    # Pseudo-inverse of each cluster's weighted covariance (outer products of the
    # centered points averaged over the cluster members).
    C_k = vmap(lambda k, mu_k: jnp.linalg.pinv(
        jnp.average((points - mu_k)[:, :, None] * (points - mu_k)[:, None, :], axis=0,
                    weights=k == cluster_id)))(a_k, mu_k)
    # Sum of log eigenvalues of C_k, i.e. the log-determinant of the
    # pseudo-inverse covariance. Uses general eigvals and keeps the real part.
    logdetC_k = vmap(
        lambda C_k: jnp.sum(jnp.log(jnp.linalg.eigvals(C_k).real)))(C_k)
    # C_k scaled by each cluster's point count.
    precision_k = C_k * num_k[:, None, None]
    # Shape presumably (K, N): log of (point - mu_k)^T precision_k (point - mu_k)
    # for every cluster / point pair.
    log_maha_k = vmap(lambda mu_k, precision_k: jnp.log(
        vmap(lambda point: (point - mu_k) @ precision_k @ (point - mu_k))(points)))(mu_k, precision_k)
    log_f_k = log_factor_k(cluster_id, log_maha_k, num_k, logdetC_k)
    log_VE_k = vmap(log_ellipsoid_volume)(logdetC_k, num_k, log_f_k)
    # Log of each cluster's fraction of the num_S points.
    log_VS_k = jnp.log(num_k) - jnp.log(num_S)
    # `done` starts True when there are fewer than K * (D + 1) points.
    return State(i=jnp.asarray(0), done=num_S < K * (D + 1), cluster_id=cluster_id, C_k=C_k,
                 logdetC_k=logdetC_k, mu_k=mu_k, log_maha_k=log_maha_k, num_k=num_k,
                 log_VE_k=log_VE_k, log_VS_k=log_VS_k, min_loss=jnp.asarray(jnp.inf),
                 delay=jnp.asarray(0))
def inertial_restraint(conf, params, box, lamb, a_idxs, b_idxs, masses, k):
    """Penalty for misalignment between the principal axes of two atom groups.

    Each group is centered on its mass-weighted centroid. Its inertia tensor is then
    diagonalized with np.linalg.eigh. For every matched pair of eigenvector columns
    (a, b) the penalty (1 - |a . b|)^2 is accumulated, and the total is scaled by k.
    `params`, `box` and `lamb` are accepted but not used.
    """
    def _principal_axes(idxs):
        group_conf = conf[idxs]
        group_masses = masses[idxs]
        centered = group_conf - np.average(group_conf, axis=0, weights=group_masses)
        _, evecs = np.linalg.eigh(inertia_tensor(centered, group_masses))
        return evecs

    a_evec = _principal_axes(a_idxs)
    b_evec = _principal_axes(b_idxs)

    # eigh stores eigenvectors in columns, hence the transposes.
    penalties = []
    for col_a, col_b in zip(a_evec.T, b_evec.T):
        mismatch = 1 - np.abs(np.dot(col_a, col_b))
        penalties.append(mismatch * mismatch)
    return np.sum(penalties) * k
def split_background(background_double_exp):
    """Average interleaved background frames separately for two exposures.

    Frames at even indices (0, 2, ...) are averaged together, as are frames at odd
    indices (1, 3, ...). The two averages are stacked as [even_avg, odd_avg].
    """
    per_exposure = [
        np.average(background_double_exp[start::2], axis=0) for start in (0, 1)
    ]
    return np.array(per_exposure)
def get_peaks_iter(soln, tvec, int=0, Tint=0, loCI=5, upCI=95):
    """
    Print summary statistics over multiple simulation runs, with or without an intervention.

    The printed statistics are the final recovered/deaths/remaining infections, the peak
    prevalence of I1-I3, the timing of those peaks and the time of extinction.

    soln: 3D array indexed [iteration, timepoint, variable]. Variables 1-4 are summed as
          active infections, 2-4 are I1-I3, 5 is deaths and 6 is recovered.
    tvec: 1D vector of equally spaced timepoints (only tvec[1] - tvec[0] is used)
    int: Optional, 1 or 0 for whether or not there was an intervention. Defaults to 0.
         (The name shadows the builtin but is kept for caller compatibility.)
    Tint: Optional, timepoint (days) at which intervention was started. When int != 0,
          reported times are relative to Tint.
    loCI,upCI: Optional, lower and upper percentiles for confidence intervals. Defaults to a 90% interval.

    Returns None; all results are printed.
    """
    delta_t = tvec[1] - tvec[0]
    time_int = 0 if int == 0 else Tint

    def _ci(values):
        # Lower/upper percentile bounds across iterations.
        return np.percentile(values, loCI), np.percentile(values, upCI)

    all_cases = soln[:, :, 1] + soln[:, :, 2] + soln[:, :, 3] + soln[:, :, 4]

    # Final values (reported as percentages)
    finals = (
        ('Final recovered', soln[:, -1, 6]),
        ('Final deaths', soln[:, -1, 5]),
        ('Remaining infections', all_cases[:, -1]),
    )
    for label, values in finals:
        lo, hi = _ci(values)
        print('{}: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(
            label, 100 * np.average(values), 100 * lo, 100 * hi))

    compartments = ((2, 'I1'), (3, 'I2'), (4, 'I3'))

    # Peak prevalence
    for col, name in compartments:
        peaks = np.amax(soln[:, :, col], axis=1)
        lo, hi = _ci(peaks)
        print('Peak {}: {:4.2f}% [{:4.2f}, {:4.2f}]'.format(
            name, 100 * np.average(peaks), 100 * lo, 100 * hi))

    # Timing of peaks
    for col, name in compartments:
        tpeak = np.argmax(soln[:, :, col], axis=1) * delta_t - time_int
        lo, hi = _ci(tpeak)
        print('Time of peak {}: avg {:4.2f} days, median {:4.2f} days [{:4.2f}, {:4.2f}]'.format(
            name, np.average(tpeak), np.median(tpeak), lo, hi))

    # Time when all the infections go extinct
    time_all_extinct = np.array(get_extinction_time(all_cases, 0)) * delta_t - time_int
    lo, hi = _ci(time_all_extinct)
    print('Time of extinction of all infections post intervention: {:4.2f} days [{:4.2f}, {:4.2f}]'.format(
        np.average(time_all_extinct), lo, hi))
    return
def prob_inf_working_hh_member(indx_active, state, house_indices, household_sizes):
    """
    Function that computes the probability of infection for non-working individuals, split by household.

    Two probabilities are computed: one for non-workers living with at least one household
    member still working during intervention, and one for non-workers living with no working
    household member.

    @param indx_active : Numpy array with indices of individuals still working during intervention
    @type : 1D array
    @param state : A Device Array that encodes the state of each individual in the population at the end of each iteration of the simulation
    @type : Device Array of shape (# of iterations, population size)
    @param house_indices : Numpy array that keeps track of the house an individual belongs to
    @type : 1D array
    @param household_sizes : Numpy array that keeps track of the size of each individual's household (currently unused)
    @type : 1D array
    @return : Returns the probability of infection for non-workers living with working household members, probability for non-workers living with no working household member, and the population average, averaged over the number of iterations. A probability whose group is empty is NaN.
    @type : Tuple
    """
    pop = len(state[0])

    # Who works and where they live does not change between iterations: compute once.
    # Houses of people who are still working
    house_working = np2.unique(house_indices[indx_active])
    # Indices of all people who aren't working and whether they live with a worker
    not_working = np2.setdiff1d(np2.arange(0, pop, 1), indx_active)
    lives_with_worker = np2.isin(house_indices[not_working], house_working)
    n_with_worker = np2.count_nonzero(lives_with_worker)
    n_without_worker = np2.count_nonzero(~lives_with_worker)

    probs_with, probs_without, inf_fracs = [], [], []
    for iteration_state in state:
        # Get indices of infected people
        if_inf = np.where(iteration_state > 0)[0]
        # Infected people who aren't working, and whether their house has a worker
        if_inf_not_working = np2.setdiff1d(if_inf, indx_active)
        inf_with_worker = np2.isin(house_indices[if_inf_not_working], house_working)
        k_with = np2.count_nonzero(inf_with_worker)
        k_without = np2.count_nonzero(~inf_with_worker)
        # Bayes: (k/inf) * (inf/pop) / (n/pop) == k/n. The direct form avoids a 0/0
        # NaN when nobody is infected; an empty group yields NaN instead of raising.
        probs_with.append(k_with / n_with_worker if n_with_worker else np2.nan)
        probs_without.append(k_without / n_without_worker if n_without_worker else np2.nan)
        # Population average probability of infection
        inf_fracs.append(len(if_inf) / pop)
    return (np.average(np.asarray(probs_with)),
            np.average(np.asarray(probs_without)),
            np.average(np.asarray(inf_fracs)))
def analytic_restraint_force(conf, params, box, lamb, a_idxs, b_idxs, masses, k):
    """Analytic gradient of the inertial (principal-axis alignment) restraint w.r.t. conf.

    The energy is U = k * sum_i (1 - |a_i . b_i|)^2 over matched eigenvector columns
    a_i, b_i of the two groups' inertia tensors, which are diagonalized with dsyevv3.
    The derivative is chained through the eigenvectors (grad_eigh) and then the inertia
    tensor (grad_inertia_tensor) back to the centered coordinates.

    NOTE(review): despite the name, the returned array is dU/dx (the energy gradient);
    negate it if a physical force is wanted.
    `params`, `box` and `lamb` are accepted but not used.

    Returns:
        Array shaped like `conf`, zero outside a_idxs and b_idxs.
    """
    a_conf = conf[a_idxs]
    b_conf = conf[b_idxs]
    a_masses = masses[a_idxs]
    b_masses = masses[b_idxs]
    a_com_conf = a_conf - np.average(a_conf, axis=0, weights=a_masses)
    b_com_conf = b_conf - np.average(b_conf, axis=0, weights=b_masses)
    a_tensor = inertia_tensor(a_com_conf, a_masses)
    b_tensor = inertia_tensor(b_com_conf, b_masses)

    # eigenvalues are needed for the eigenvector derivatives below
    a_eval, a_evec = dsyevv3(a_tensor)
    b_eval, b_evec = dsyevv3(b_tensor)

    # dU/da = -2 k (1 - |a.b|) sign(a.b) b, and symmetrically for b.
    dl_daevec_T = []
    dl_dbevec_T = []
    for a, b in zip(a_evec.T, b_evec.T):
        dot = np.dot(a, b)
        delta = 1 - np.abs(dot)
        prefactor = -np.sign(dot) * 2 * delta * k
        dl_daevec_T.append(prefactor * b)
        dl_dbevec_T.append(prefactor * a)

    # Rows were built per eigenvector; transpose back to column-eigenvector layout.
    dl_daevec = np.transpose(np.array(dl_daevec_T))
    dl_dbevec = np.transpose(np.array(dl_dbevec_T))

    dl_datensor = grad_eigh(a_eval, a_evec, dl_daevec)
    dl_dbtensor = grad_eigh(b_eval, b_evec, dl_dbevec)

    dl_da_com_conf = grad_inertia_tensor(a_com_conf, a_masses, dl_datensor)
    dl_db_com_conf = grad_inertia_tensor(b_com_conf, b_masses, dl_dbtensor)

    du_dx = onp.zeros_like(conf)
    du_dx[a_idxs] += dl_da_com_conf
    du_dx[b_idxs] += dl_db_com_conf

    # conservative forces are not affected by center of mass changes.
    # the vjp w.r.t. to the center of mass yields zeros since sum dx=0, dy=0 and dz=0
    return du_dx
def get_mu_and_C(k):
    """Weighted centroid and pseudo-inverse covariance of cluster k.

    Members are the points with cluster_id == k that are also selected by mask; these
    names, plus points, num_k and D, come from the enclosing scope. The centroid is
    zeroed for an empty cluster. The matrix is zeroed when the cluster has fewer than
    D + 1 points.
    """
    member = (cluster_id == k) & mask
    centroid = jnp.average(points, weights=member, axis=0)
    offsets = points - centroid
    outer_products = offsets[:, :, None] * offsets[:, None, :]
    covariance = jnp.average(outer_products, weights=member, axis=0)
    precision = jnp.linalg.pinv(covariance)

    count = num_k[k]
    centroid = jnp.where(count == 0, 0., centroid)
    precision = jnp.where(count < D + 1, 0., precision)
    return centroid, precision
def internal(t, pos, vel):
    """Build the per-atom snapshot and a one-row summary for time t.

    The snapshot concatenates a column filled with t, the module-level `idx` columns and
    `pos`. The summary row holds, in order:
      - t * 1e12 (presumably seconds to picoseconds; confirm);
      - potential and kinetic energy from calcEnergy, each scaled by
        mass_unit * (dist_unit / time_unit)**2 * 1e21;
      - the mean of distance() over ccPairs and over chPairs.
    """
    time_col = np.full((atoms, 1), t)
    snapshot = np.concatenate((time_col, idx, pos), axis = 1)
    potential, kinetic = calcEnergy(pos, vel)
    velocity_unit_sq = (dist_unit / time_unit) ** 2

    def to_output_energy(energy):
        # Same left-to-right multiplication order as a single inline expression.
        return energy * mass_unit * velocity_unit_sq * 1e21

    summary_row = [
        t * 1e12,
        to_output_energy(potential),
        to_output_energy(kinetic),
        np.average(distance(pos[ccPairs])),
        np.average(distance(pos[chPairs])),
    ]
    return snapshot, np.array([summary_row])
def centroid_restraint(conf, params, box, lamb, masses, group_a_idxs, group_b_idxs, kb, b0):
    """Harmonic restraint on the distance between two mass-weighted centroids.

    U = kb * (|c_a - c_b| - b0)^2, where c_a and c_b are the mass-weighted mean positions
    of the atoms in group_a_idxs and group_b_idxs. `params`, `box` and `lamb` are unused.
    """
    def _centroid(idxs):
        return np.average(conf[idxs], axis=0, weights=masses[idxs])

    separation = _centroid(group_a_idxs) - _centroid(group_b_idxs)
    distance = np.sqrt(np.sum(separation * separation))
    stretch = distance - b0
    return kb * stretch * stretch
def test_against_tf_ctc_loss(self):
    """Compare ctc_objectives.ctc_loss and its gradient with the TF reference on random inputs."""
    batchsize, timesteps, labelsteps, nclasses = 8, 150, 25, 400

    # Random inputs. The draw order (randn first, then randint) is fixed.
    logits = np.random.randn(batchsize, timesteps, nclasses)
    logprobs = jax.nn.log_softmax(logits)
    logprob_paddings = np.zeros((batchsize, timesteps))
    labels = np.random.randint(
        1, nclasses, size=(batchsize, labelsteps)).astype(np.int32)
    label_paddings = np.zeros((batchsize, labelsteps))
    inputs = [logprobs, logprob_paddings, labels, label_paddings]

    # Per-sequence losses must agree.
    jax_per_seq, _ = ctc_objectives.ctc_loss(*inputs)
    tf_per_seq = tf_ctc_loss(*inputs)
    self.assertAllClose(jax_per_seq.squeeze(), tf_per_seq.squeeze())

    def average_tf_ctc_loss(*args):
        return jnp.average(tf_ctc_loss(*args))

    # Gradients of the batch-averaged losses must agree as well.
    # `average_ctc_loss` is defined outside this method.
    jax_dlogits = jax.grad(average_ctc_loss)(*inputs)
    tf_dlogits = jax.grad(average_tf_ctc_loss)(*inputs)
    # Relative error check is disabled as numerical errors explode when a
    # probability computed from the input logits is close to zero.
    self.assertAllClose(jax_dlogits, tf_dlogits, rtol=0.0, atol=1e-4)
def __init__(self, means, covs, weights):
    """Initialize a Gaussian mixture.

    Arguments:
        means, covs are np arrays or lists of length k, with entries of shape (d,) and
        (d, d) respectively. (e.g. covs can be array of shape (k, d, d))
        weights: per-component mixture weights (length k). They are used for the moment
        computation and as the Categorical mixing probabilities.
    """
    means, covs, weights = self._check_and_reshape_args(
        means, covs, weights)
    self.d = len(means[0])
    self.expectations = self.compute_expectations(means, covs, weights)
    self.mean = self.expectations[0]

    # Mixture covariance from the second moment:
    #   Cov(X) = E[XX^T] - mu mu^T = sum_i w_i (Cov(X_i) + mu_i mu_i^T) - mu mu^T
    # (np.average normalizes by the sum of the weights).
    component_second_moments = covs + np.einsum("ki,kj->kij", means, means)  # (k, d, d)
    mixture_second_moment = np.average(component_second_moments, weights=weights, axis=0)
    self.cov = mixture_second_moment - np.outer(self.mean, self.mean)

    self.threadkey = random.PRNGKey(0)
    self.means, self.covs, self.weights = means, covs, weights
    self.num_components = len(weights)
    self.sample_metrics = dict()
    self.initialize_metric_names()

    mixing = tfd.Categorical(probs=self.weights)
    components = tfd.MultivariateNormalFullCovariance(
        loc=self.means, covariance_matrix=self.covs)
    self.tfp_dist = tfd.MixtureSameFamily(
        mixture_distribution=mixing,
        components_distribution=components)
def get_mean_sig_bounds(arr, dim, weights, sig_multiple=1.):
    """Weighted mean, weighted standard deviation and +/- bounds along axis `dim`.

    Args:
        arr: input array.
        dim: axis to reduce over.
        weights: weights as accepted by np.average (same shape as arr, or 1-D with
            length arr.shape[dim]); None means equal weights.
        sig_multiple: number of standard deviations for the bounds.

    Returns:
        (mean, sigma, mean + sig_multiple * sigma, mean - sig_multiple * sigma). Note
        that `mean` has all singleton axes squeezed.
    """
    mean_arr = np.expand_dims(np.average(arr, dim, weights), dim)
    # Weighted variance sum(w * (x - mu)^2) / sum(w), with the same weighting
    # convention (and weight broadcasting along `dim`) as the mean above. The
    # previous mean(w * (x - mu)^2) was off by a factor of sum(w)/N and
    # broadcast 1-D weights along the last axis regardless of `dim`.
    sig_arr = np.sqrt(np.average((arr - mean_arr)**2, dim, weights))
    mean_arr = mean_arr.squeeze()
    sig_upper_arr = mean_arr + sig_arr * sig_multiple
    sig_lower_arr = mean_arr - sig_arr * sig_multiple
    return mean_arr, sig_arr, sig_upper_arr, sig_lower_arr
def filter_bblocks(data):
    """Remove a stripe background and an average offset from `data`, then threshold it.

    Steps:
      1. Take plane 11 of the last axis, clip it to +/- the background threshold, smooth
         it with conv2d and the kernel `gg` (module-level names), and subtract it.
      2. Subtract the average over rows 1..9 (first axis) of the residual clipped to
         [0, 2 * threshold].
      3. Zero every value not strictly above the filter strength.
    Relies on module-level nrcols, nbmux, conv2d and gg; the 192 / 12 sizes are hard-coded.
    """
    filter_strength = 3
    bkgthr = filter_strength  # background threshold

    # Vertical stripes: clip and smooth.
    stripes = np.reshape(data[:, :, 11], (nrcols, 192))
    smoothed = conv2d(np.clip(stripes, -bkgthr, bkgthr), gg)
    data_out = data - np.reshape(smoothed, (nrcols, nbmux, 1))

    # Average residual offset over rows 1..9.
    clipped_rows = np.clip(data_out[1:10, :, :], 0, 2 * bkgthr)
    offset = np.reshape(np.average(clipped_rows, axis=0), (1, 192, 12))
    data_out -= offset

    # Keep only values strictly above the threshold.
    data_out *= data_out > filter_strength
    return data_out
def _RuLSIF(self, x, y, alpha, s_sigma, s_lambda):
    """Fit relative unconstrained least-squares importance fitting (RuLSIF) weights.

    Kernel width and regularization:
      - if s_sigma and s_lambda each hold exactly one value, those values are used directly;
      - otherwise they are chosen by self._optimize_sigma_lambda.

    The kernel weights solve (H + lambda * I) theta = h with
        H = (1 - alpha) * Phi_y^T Phi_y / n_y + alpha * Phi_x^T Phi_x / n_x
        h = mean of Phi_x over its rows
    and negative entries are clamped to 0. The fitted quantities are stored on the instance.
    """
    if len(s_sigma) == 1 and len(s_lambda) == 1:
        sigma = s_sigma[0]
        lambda_ = s_lambda[0]
    else:
        optimized_params = self._optimize_sigma_lambda(
            x, y, alpha, s_sigma, s_lambda)
        sigma = optimized_params['sigma']
        lambda_ = optimized_params['lambda']

    phi_x = self.__kernel(r=x, sigma=sigma)
    phi_y = self.__kernel(r=y, sigma=sigma)
    H = (1. - alpha) * (np.dot(phi_y.T, phi_y) / self.__y_num_row) + alpha * (
        np.dot(phi_x.T, phi_x) / self.__x_num_row)  # Phi* Phi
    h = np.average(phi_x, axis=0).T
    weights = np.linalg.solve(H + lambda_ * np.identity(self.__kernel_num), h).ravel()
    # Clamp negative weights to zero. jax.ops.index_update was deprecated and removed
    # from JAX; np.where works for NumPy and jax.numpy arrays and keeps the float dtype.
    weights = np.where(weights < 0, 0, weights)

    self.__alpha = alpha
    self.__weights = weights
    self.__lambda = lambda_
    self.__sigma = sigma
    self.__phi_x = phi_x
    self.__phi_y = phi_y
def centroid_restraint(conf, lamb, masses, lamb_flag, lamb_offset, group_a_idxs, group_b_idxs, kb, b0):
    """Lambda-scaled harmonic restraint between two mass-weighted centroids.

    U = (lamb * lamb_flag + lamb_offset) * kb * (|c_a - c_b| - b0)^2, where c_a and c_b
    are the mass-weighted mean positions of the two index groups.
    """
    group_a = conf[group_a_idxs]
    group_b = conf[group_b_idxs]
    separation = (np.average(group_a, axis=0, weights=masses[group_a_idxs])
                  - np.average(group_b, axis=0, weights=masses[group_b_idxs]))
    stretch = np.sqrt(np.sum(separation * separation)) - b0
    scale = lamb * lamb_flag + lamb_offset
    return scale * kb * stretch * stretch
def likelihood(args):
    """Negative weighted log-likelihood -sum(wt * (log(d) - log(mean(m)))).

    `args` is split into three equal parts (f0m, f0w, const). d is alladd(MOD(...)) on
    the data arrays and m is alladd(MOD(...)) on the MC arrays. MOD, alladd, wt and the
    data_* / mc_* arrays are module-level names.
    """
    f0m, f0w, const = np.split(args, 3)
    data_amp = MOD(f0m, f0w, const, data_phif0, data_phi, data_f)
    mc_amp = MOD(f0m, f0w, const, mc_phif0, mc_phi, mc_f)
    data_intensity = alladd(data_amp)
    mc_norm = np.average(alladd(mc_amp))
    log_ratio = np.log(data_intensity) - np.log(mc_norm)
    return -np.sum(wt * log_ratio)
def return_weighted_average(action_trajectories: jnp.ndarray,
                            cum_reward: jnp.ndarray,
                            kappa: float) -> jnp.ndarray:
    r"""Calculates return-weighted average over all trajectories.

    This will calculate the return-weighted average over a set of trajectories as
    defined on l.17 of Alg. 2 in the MBOP paper:
    [https://arxiv.org/abs/2008.05556].

    Each trajectory is weighted by exp(kappa * R_i). To avoid overflow (and the resulting
    NaNs) in the exponential, max(R) is subtracted from every return first. This factor
    cancels in the normalized average, so the result is unchanged.

    Args:
      action_trajectories: (n_trajectories, horizon, action_dim) tensor of action
        trajectories, corresponds to `A` in Alg. 2.
      cum_reward: (n_trajectories,) vector of corresponding cumulative rewards (returns)
        for each trajectory; an (n_trajectories, 1) column is also accepted.
        Corresponds to `\mathcal{R}` in Alg. 2.
      kappa: `\kappa` constant, changes the 'peakiness' of the exponential averaging.

    Returns:
      Single action trajectory corresponding to the return-weighted average of the
        trajectories.
    """
    # Subtract maximum reward to avoid overflow/NaNs in exp:
    cum_reward = cum_reward - cum_reward.max()
    # Flatten to (n_trajectories,) so jnp.average applies one weight per trajectory
    # along axis 0. Unlike squeeze, reshape keeps the axis when n_trajectories == 1.
    exp_cum_reward = jnp.exp(kappa * jnp.reshape(cum_reward, (-1,)))
    return jnp.average(action_trajectories, weights=exp_cum_reward, axis=0)
def shiva_rule(neighbors):
    """Cellular-automaton update for the center cell of `neighbors`.

    The channel after the dominant (largest mean) channel of the neighborhood, taken
    modulo 3, becomes a one-hot color scaled to 255. The center cell is blended toward
    that color with the module-level `blend_coefficient`, floored, and cast to uint8.
    Presumably neighbors is a stack of 3-channel pixels; confirm against callers.
    """
    center = neighbors[len(neighbors) // 2]
    mean_color = np.average(neighbors, axis=0)
    next_channel = (np.argmax(mean_color, axis=0) + 1) % 3
    arise = jax.nn.one_hot(np.array([next_channel], dtype="uint8"), 3)[0] * 255
    blend_weights = [1.0 - blend_coefficient, blend_coefficient]
    mixed = np.average(np.stack([center, arise]), axis=0, weights=blend_weights)
    return np.array(np.floor(mixed), dtype="uint8")
def estimate_accuracy(X, y, params, rng, num_iterations=1):
    """Fraction of posterior-predictive 'obs' samples that equal y.

    Samples come from sample_multi_posterior_predictive using the module-level
    `model` and `guide`, both conditioned on (X,), with the given params.
    """
    model_args = (X,)
    guide_args = (X,)
    draws = sample_multi_posterior_predictive(
        rng, num_iterations, model, model_args, guide, guide_args, params)
    matches = draws['obs'] == y
    return jnp.average(matches)
def estimate_accuracy_fixed_params(X, y, w, intercept, rng, num_iterations=1):
    """Fraction of prior-predictive 'obs' samples (with w and intercept fixed) that equal y."""
    fixed_sites = {'w': w, 'intercept': intercept}
    draws = sample_multi_prior_predictive(rng, num_iterations, model, (X, ), fixed_sites)
    matches = draws['obs'] == y
    return jnp.average(matches)
def simplified_u(a_conf, b_conf, a_masses, b_masses):
    """Unscaled principal-axis misalignment penalty between two point groups.

    Each group is centered on its mass-weighted mean. Its inertia tensor is then
    diagonalized with np.linalg.eigh, and sum_i (1 - |a_i . b_i|)^2 is returned over
    matching eigenvector columns.
    """
    axes = []
    for group_conf, group_masses in ((a_conf, a_masses), (b_conf, b_masses)):
        centered = group_conf - np.average(group_conf, axis=0, weights=group_masses)
        _, evecs = np.linalg.eigh(inertia_tensor(centered, group_masses))
        axes.append(evecs)
    a_evec, b_evec = axes

    # Eigenvectors are stored column-wise.
    deltas = [1 - np.abs(np.dot(u, v)) for u, v in zip(a_evec.T, b_evec.T)]
    return np.sum([d * d for d in deltas])
def weight(args):
    """Per-event weight d / mean(m).

    `args` is split into three equal parts (f0m, f0w, const). d is alladd(MOD(...)) on
    the data arrays and m is alladd(MOD(...)) on the MC arrays. MOD, alladd and the
    data_* / mc_* arrays are module-level names.
    """
    f0m, f0w, const = np.split(args, 3)
    data_amp = MOD(f0m, f0w, const, data_phif0, data_phi, data_f)
    mc_amp = MOD(f0m, f0w, const, mc_phif0, mc_phi, mc_f)
    data_intensity = alladd(data_amp)
    mc_norm = np.average(alladd(mc_amp))
    return data_intensity / mc_norm
def likelihood(self, args):
    """Negative weighted log-likelihood -sum(self.wt * (log(d) - log(mean(m)))).

    `args` is split into three equal parts (f0m, f0w, const). d is self.alladd(self.MOD(...))
    on the instance's data arrays and m is the same on its MC arrays.
    """
    f0m, f0w, const = np.split(args, 3)
    model = self.MOD
    data_amp = model(f0m, f0w, const, self.data_phif0, self.data_phi, self.data_f)
    mc_amp = model(f0m, f0w, const, self.mc_phif0, self.mc_phi, self.mc_f)
    data_intensity = self.alladd(data_amp)
    mc_norm = np.average(self.alladd(mc_amp))
    log_ratio = np.log(data_intensity) - np.log(mc_norm)
    return -np.sum(self.wt * log_ratio)
def shortcuts_ch_shrink(x, out_features, method):
    """Match the number of channels in the shortcuts in the 1st conv1x1 layer.

    Shrinks the last (channel) axis of `x` from in_features to out_features by averaging
    groups of g = in_features // out_features channels. All leading axes (e.g. N, H, W
    for a 4-D NHWC input) are preserved, so inputs of any rank >= 1 are accepted.

    Args:
      x: array whose last axis holds the channels.
      out_features: number of output channels; in_features must be a whole multiple of it.
      method: one of
        'consecutive' - average adjacent channels [0..g-1], [g..2g-1], ...;
        'every_n'     - average channels out_features apart (c, c+out, c+2*out, ...);
        'none'        - return zeros (no shortcut).

    Raises:
      ValueError: if in_features is not a whole multiple of out_features, or if
        method is unknown.
    """
    in_features = x.shape[-1]
    num_ch_avg = in_features // out_features
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if out_features * num_ch_avg != in_features:
        raise ValueError('in_features needs to be a whole multiple of out_features')
    leading_dims = x.shape[:-1]
    if method == 'consecutive':
        x = jnp.reshape(x, leading_dims + (out_features, num_ch_avg))
        return jnp.average(x, axis=-1)
    elif method == 'every_n':
        x = jnp.reshape(x, leading_dims + (num_ch_avg, out_features))
        return jnp.average(x, axis=-2)
    elif method == 'none':
        # return all zeros to represent no shortcut
        return jnp.zeros(leading_dims + (out_features,), dtype=x.dtype)
    else:
        raise ValueError('Unsupported channel shrinking shortcut function type.')
def plot_iter_daily_shade(soln_inc, n, ymax=1, scale=1, int=0, Tint=1, loCI=5, upCI=95, plotThis=False, plotName="test"):
    """
    Plot daily incidence from multiple simulation runs, with or without an intervention.

    The figure has two panels side by side, a linear y-axis and a log y-axis. Each shows
    the mean across iterations as lines and the [loCI, upCI] percentile band as a shaded
    region, for all 7 variables.

    soln_inc: 3D array indexed [iteration, day, variable] with 7 variables (S, E, I1, I2, I3, D, R).
              Days are plotted as 1..number of timepoints.
    n: total population size (sets the lower y-limit scale/n of the log panel)
    ymax : highest value on y axis, relative to "scale" value (e.g. 0.5 makes ymax=0.5 or 50% for scale=1 or N)
    scale: amount to multiply all frequency values by (e.g. "1" keeps as frequency, "N" turns to absolute values)
    int: Optional, 1 or 0 for whether or not there was an intervention (drawn as a dashed line at Tint). Defaults to 0.
         (The name shadows the builtin but is kept for caller compatibility.)
    Tint: Optional, timepoint (days) at which intervention was started
    loCI,upCI: Optional, lower and upper percentiles for the shaded band. Defaults to a 90% interval.
    plotThis: True or False, whether the plot is saved as <plotName>.pdf
    plotName: string, name of the plot to be saved
    """
    tvec = np.arange(1, np.shape(soln_inc)[1] + 1)
    soln_avg = np.average(soln_inc, axis=0)
    soln_loCI = np.percentile(soln_inc, loCI, axis=0)
    soln_upCI = np.percentile(soln_inc, upCI, axis=0)
    labels = ['S', 'E', 'I1', 'I2', 'I3', 'D', 'R']

    def _draw_panel(position, y_low):
        # One panel: averages, percentile bands, optional intervention marker.
        plt.subplot(position)
        plt.plot(tvec, soln_avg * scale)
        plt.legend(labels, frameon=False, framealpha=0.0, bbox_to_anchor=(1.04, 1), loc="upper left")
        # Restart the color cycle so each band matches its line.
        plt.gca().set_prop_cycle(None)
        for i in range(len(labels)):
            plt.fill_between(tvec, soln_loCI[:, i] * scale, soln_upCI[:, i] * scale, alpha=0.3)
        if int == 1:
            plt.plot([Tint, Tint], [y_low, ymax * scale], 'k--')
        plt.ylim([y_low, ymax * scale])
        plt.xlabel("Time (days)")
        plt.ylabel("Daily incidence")

    plt.figure(figsize=(2 * 6.4, 4.0))
    # linear scale
    _draw_panel(121, 0)
    # log scale; the lower limit is one individual (scale / n)
    _draw_panel(122, scale / n)
    plt.semilogy()
    plt.tight_layout()
    if plotThis:
        plt.savefig(plotName + '.pdf', bbox_inches='tight')
    plt.show()
def get_log_L(weights):
    """Log-likelihood of the weighted subset of `points` under one fitted Gaussian.

    The mean and covariance are the `weights`-weighted mean and covariance of `points`,
    a name from the enclosing scope. A term log(n_i / n) is added, where n_i = sum(weights)
    and `n` also comes from the enclosing scope (presumably the total point count; confirm).
    A NaN result is mapped to -inf.

    NOTE(review): `jnp.where(weights, maha, 0.)` treats weights as a boolean mask, so
    weights is presumably a 0/1 membership vector — confirm against callers.
    """
    mu = jnp.average(points, weights=weights, axis=0)
    dx = points - mu
    # Weighted covariance from outer products of the centered points.
    Cov = jnp.average(dx[:, :, None] * dx[:, None, :], weights=weights, axis=0)
    # May be -inf/NaN for a singular or numerically indefinite Cov; NaN is mapped to -inf below.
    logdetCov = jnp.log(jnp.linalg.det(Cov))
    C = jnp.linalg.pinv(Cov)
    # logdetCov = 0.
    # C = jnp.eye(mu.size)
    # Squared Mahalanobis distance of every point from mu.
    maha = vmap(lambda dx: dx @ C @ dx)(dx)
    n_i = jnp.sum(weights)
    # Sum of Gaussian log-densities over the selected points, plus log(n_i) - log(n).
    log_L_1 = -0.5 * jnp.sum(jnp.where(weights, maha, 0.)) \
        - 0.5 * n_i * mu.size * jnp.log(2. * jnp.pi) - 0.5 * n_i * logdetCov \
        + jnp.log(n_i) - jnp.log(n)
    # log_L_1 = -0.5 * jnp.sum(jnp.where(weights, maha, 0.)) \
    #           - 0.5 * n_i * mu.size * jnp.log(2. * jnp.pi) - 0.5 * n_i * logdetCov
    return jnp.where(jnp.isnan(log_L_1), -jnp.inf, log_L_1)
def j_score_init(stds, rng2):
    """Score an initialization scale `stds` by how close activations and gradients are to N(0, 1).

    The terms are accumulated into KL and the function returns log10 of the total:
      - the KL divergence from N(0, 1) of the network outputs (columns 2: of
        raw_lagrangian_eom) on `n` random 4-dimensional inputs;
      - the same KL term for every 2-D parameter gradient of the summed outputs;
      - a 0.1-weighted penalty on stds[i] relative to sqrt(6) / sqrt(fan_0 + fan_1), which
        resembles the Glorot/Xavier uniform bound.
    custom_init, n, raw_lagrangian_eom, learned_dynamics, partial, vmap and grad come from
    the enclosing scope.
    """
    new_params = custom_init(stds, rng2)
    rand_input = jax.random.normal(rng2, [n, 4])
    # NOTE(review): rng2 is not used again in this function after this increment.
    rng2 += 1
    outputs = jax.vmap(
        partial(raw_lagrangian_eom, learned_dynamics(new_params)))(rand_input)[:, 2:]

    #KL-divergence to mu=0, std=1:
    # per dimension, KL(N(mu, std) || N(0, 1)) = (mu^2 + std^2 - 1) / 2 - log(std)
    mu = jnp.average(outputs, axis=0)
    std = jnp.std(outputs, axis=0)
    KL = jnp.sum((mu**2 + std**2 - 1) / 2.0 - jnp.log(std))

    def total_output(p):
        # Scalar whose gradient w.r.t. the parameters is scored below.
        return vmap(partial(raw_lagrangian_eom, learned_dynamics(p)))(rand_input).sum()

    d_params = grad(total_output)(new_params)
    # i indexes stds and advances once per 2-D gradient; 1-D entries do not consume an index.
    i = 0
    for l1 in d_params:
        if (len(l1)) == 0:
            continue
        # NOTE(review): new_l1 is never used.
        new_l1 = []
        for l2 in l1:
            if len(l2.shape) == 1:
                continue
            mu = jnp.average(l2)
            std = jnp.std(l2)
            KL += (mu**2 + std**2 - 1) / 2.0 - jnp.log(std)
            #HACK
            desired_gaussian = jnp.sqrt(6) / jnp.sqrt(l2.shape[0] + l2.shape[1])
            scaled_std = stds[i] / desired_gaussian
            #Avoid extremely large values
            KL += 0.1 * (scaled_std**2 / 2.0 - jnp.log(scaled_std))
            i += 1
    return jnp.log10(KL)
def weight(self, args):
    """Per-event weight d / mean(m).

    `args` is split into three equal parts (f0m, f0w, const). d is self.alladd(self.MOD(...))
    on the instance's data arrays and m is the same on its MC arrays.
    """
    f0m, f0w, const = np.split(args, 3)
    model = self.MOD
    data_amp = model(f0m, f0w, const, self.data_phif0, self.data_phi, self.data_f)
    mc_amp = model(f0m, f0w, const, self.mc_phif0, self.mc_phi, self.mc_f)
    data_intensity = self.alladd(data_amp)
    mc_norm = np.average(self.alladd(mc_amp))
    return data_intensity / mc_norm