def pr_loop_conv_cregression(key, logcdf_conditionals, logpdf_joints, x, x_test, rho, rho_x, n, T):
    """Run T predictive-resampling steps and return the final state plus difference traces.

    Draws T uniforms and T covariate rows, then iterates
    ``pr_1step_conv_cregression`` over ``range(T)`` with ``fori_loop``. The
    covariates are resampled from ``x`` using one set of Dirichlet (Bayesian
    bootstrap) weights.

    Args:
        key: JAX PRNG key.
        logcdf_conditionals, logpdf_joints: starting values. They are placed in
            the loop carry twice: slots 0-1 and slots 8-9. Presumably the first
            pair is the running state and the second an untouched reference
            used for ``pdiff``/``cdiff`` — confirm against
            ``pr_1step_conv_cregression``.
        x: observed covariates; rows are resampled with replacement.
        x_test, rho, rho_x: forwarded unchanged into the loop carry.
        n: ignored; overwritten by ``jnp.shape(x)[0]`` and kept only for
            interface compatibility.
        T: number of forward steps.

    Returns:
        ``(logcdf_conditionals, logpdf_joints, pdiff, cdiff)`` from the final
        loop carry (slots 0, 1, 10 and 11).
    """
    # One uniform random number per forward step.
    key, subkey = random.split(key)
    a_rand = random.uniform(subkey, shape=(T, 1))

    # Bayesian bootstrap over the observed rows: a single set of Dirichlet
    # weights, then T row indices sampled with those weights.
    key, subkey = random.split(key)
    n = jnp.shape(x)[0]
    w = random.dirichlet(subkey, jnp.ones(n))
    key, subkey = random.split(key)
    # NOTE: the index draw uses `key`, not `subkey`. This is kept as-is so the
    # sampled indices do not change; `key` is not reused after this point.
    ind_new = random.choice(key, a=jnp.arange(n), p=w, shape=(1, T))[0]
    x_new = x[ind_new]

    # Observed covariates followed by the T resampled rows.
    x_samp = jnp.concatenate((x, x_new), axis=0)

    # Per-step difference traces; presumably filled in by the step function.
    pdiff = jnp.zeros(T)
    cdiff = jnp.zeros(T)

    inputs = (
        logcdf_conditionals, logpdf_joints, x_samp, x_test, rho, rho_x, n, a_rand,
        logcdf_conditionals, logpdf_joints, pdiff, cdiff,
    )
    outputs = fori_loop(0, T, pr_1step_conv_cregression, inputs)

    # Unpack by explicit position. The carry holds (logcdf_conditionals,
    # logpdf_joints) twice. Unpacking both pairs into the same names would let
    # the *second* pair (slots 8-9) silently win, so take slots 0-1 directly.
    logcdf_conditionals, logpdf_joints = outputs[0], outputs[1]
    pdiff, cdiff = outputs[10], outputs[11]
    return logcdf_conditionals, logpdf_joints, pdiff, cdiff
def predictive_resample_single_loop_cregression(key, logcdf_conditionals, logpdf_joints, x, x_test, rho, rho_x, n, T):
    """Forward-resample T steps and return the updated test-point log CDF / log density.

    A padded uniform vector ``vT`` and an augmented covariate matrix
    ``x_samp`` are built and handed to ``mvcr.update_ptest_single_scan``
    together with step indices ``n .. n+T-1``.

    Args:
        key: JAX PRNG key.
        logcdf_conditionals, logpdf_joints: current values at the test points.
        x: observed covariates; rows are resampled with Bayesian-bootstrap weights.
        x_test, rho, rho_x: forwarded unchanged to the scan.
        n: ignored; recomputed as ``jnp.shape(x)[0]``.
        T: number of forward steps.

    Returns:
        ``(logcdf_conditionals, logpdf_joints)``: elements 1 and 2 of the
        scan's output carry.
    """
    # Uniform random numbers, one per forward step.
    key, key_unif = random.split(key)
    unif = random.uniform(key_unif, shape=(T, 1))

    # Bayesian bootstrap: one Dirichlet weight vector over the observed rows,
    # then T resampled row indices.
    key, key_dir = random.split(key)
    n = jnp.shape(x)[0]
    bb_weights = random.dirichlet(key_dir, jnp.ones(n))
    key, _ = random.split(key)
    idx = random.choice(key, a=jnp.arange(n), p=bb_weights, shape=(1, T))[0]
    # NOTE: an experimental KDE variant (adding 0.5 * standard-normal jitter to
    # the resampled rows) was tried here and is disabled.

    # Prefix n zero rows so vT lines up row-for-row with x_samp.
    padded_unif = jnp.concatenate((jnp.zeros((n, 1)), unif), axis=0)
    x_samp = jnp.concatenate((x, x[idx]), axis=0)

    scan_state = (padded_unif, logcdf_conditionals, logpdf_joints, x_samp, x_test, rho, rho_x)
    step_idx = jnp.arange(n, n + T)
    final_state, step_idx = mvcr.update_ptest_single_scan(scan_state, step_idx)
    _, new_logcdf, new_logpdf, *_ = final_state
    return new_logcdf, new_logpdf
def testDirichlet(self, alpha, dtype):
    """Check ``random.dirichlet`` draws, both eager and jitted, with a fixed key.

    Every draw must sum to one along the last axis. Component ``i`` is
    compared against the Beta(alpha_i, sum(alpha) - alpha_i) CDF using the
    Kolmogorov-Smirnov helper.
    """
    num_samples = 10000
    rng = random.PRNGKey(0)

    def draw(k, concentration):
        return random.dirichlet(k, concentration, (num_samples,), dtype)

    jitted_draw = api.jit(draw)
    eager_samples = draw(rng, alpha)
    jitted_samples = jitted_draw(rng, alpha)

    # Loop-invariant: total concentration for the Beta marginals.
    total = sum(alpha)
    for samples in (eager_samples, jitted_samples):
        self.assertAllClose(samples.sum(-1), onp.ones(num_samples, dtype=dtype), check_dtypes=True)
        for i, a_i in enumerate(alpha):
            marginal = scipy.stats.beta(a_i, total - a_i)
            self._CheckKolmogorovSmirnovCDF(samples[..., i], marginal.cdf)
def forward_sample_y_samp(key, logpmf_ytest, logpmf_yn, y, x, x_test, rho, rho_x, T):
    """Run T forward-sampling steps and return the updated pmfs, samples and diff tracker.

    Builds a padded uniform vector, a ``y`` buffer with T empty slots, and a
    Bayesian-bootstrap-resampled covariate matrix. These are handed to
    ``update_pn_scan_forward`` with step indices ``n_x .. n_x+T-1``.

    Args:
        key: JAX PRNG key.
        logpmf_ytest, logpmf_yn: current log-pmf values; ``logpmf_yn`` is also
            placed as the last element of the carry.
        y: observed responses; reshaped to a column.
        x: observed covariates; rows are resampled.
        x_test, rho, rho_x: forwarded unchanged into the carry.
        T: number of forward steps.

    Returns:
        ``(logpmf_ytest, logpmf_yn, y_samp, x_samp, pdiff)``. ``x_samp`` is built
        here; the other four are the first four elements of the final carry.
    """
    # Number of observed responses. Only vT's zero padding uses this.
    n_y = jnp.shape(y)[0]

    key, key_unif = random.split(key)
    a_rand = random.uniform(key_unif, shape=(T, 1))
    # jnp.append flattens its inputs, so vT is 1-D with length n_y + T.
    vT = jnp.append(jnp.zeros(n_y), a_rand)
    # Observed y as a column, followed by T zero placeholder rows.
    y_samp = jnp.concatenate((y.reshape(-1, 1), jnp.zeros((T, 1))), axis=0)

    # Bayesian bootstrap over covariate rows: one Dirichlet weight vector,
    # then T resampled row indices.
    key, key_dir = random.split(key)
    n_x = jnp.shape(x)[0]
    bb_weights = random.dirichlet(key_dir, jnp.ones(n_x))
    key, _ = random.split(key)
    ind_new = random.choice(key, a=jnp.arange(n_x), p=bb_weights, shape=(1, T))[0]
    x_samp = jnp.concatenate((x, x[ind_new]), axis=0)
    # NOTE: an experimental KDE variant (Gaussian jitter on the resampled rows)
    # was tried here and is disabled.

    pdiff = jnp.zeros((n_x + T, n_x))

    carry = (logpmf_ytest, logpmf_yn, y_samp, pdiff, y, x, x_test, rho, rho_x, ind_new, vT, logpmf_yn)
    step_idx = jnp.arange(n_x, n_x + T)
    carry, step_idx = update_pn_scan_forward(carry, step_idx)
    new_ytest, new_yn, new_y_samp, new_pdiff, *_ = carry
    return new_ytest, new_yn, new_y_samp, x_samp, new_pdiff
def sample(self, rng_key, sample_shape=()):
    """Draw Dirichlet samples with concentration ``self.alpha``.

    Args:
        rng_key: JAX PRNG key.
        sample_shape: tuple of leading sample dimensions (default ``()``).

    Returns:
        Array of shape ``sample_shape + self.batch_shape + alpha.shape[-1:]``.
        Assuming ``self.alpha`` has shape ``batch_shape + event_shape`` (TODO
        confirm against the class), this is
        ``sample_shape + batch_shape + event_shape``.
    """
    # random.dirichlet's `shape` is the *batch* prefix of the result. It
    # excludes the last (simplex) axis, which is appended from alpha.shape[-1:].
    # Also passing event_shape would duplicate the simplex axis (e.g. (k, k)
    # instead of (k,)) or fail the broadcast check against alpha.shape[:-1].
    shape = sample_shape + self.batch_shape
    return random.dirichlet(rng_key, self.alpha, shape)