def pr_loop_conv_cregression(key, logcdf_conditionals, logpdf_joints, x,
                             x_test, rho, rho_x, n, T):
    """Run T predictive-resampling steps with ``fori_loop``, tracking diffs.

    Draws T uniforms and T Bayesian-bootstrap covariate rows from ``x``, then
    applies ``pr_1step_conv_cregression`` T times to a 12-element carry tuple.

    Args:
        key: JAX PRNG key.
        logcdf_conditionals: initial log conditional CDFs. Placed in the carry
            twice (positions 0 and 8).
        logpdf_joints: initial log joint densities. Placed in the carry twice
            (positions 1 and 9).
        x: observed covariates; rows are resampled along axis 0.
        x_test: test covariates, passed through to the step function.
        rho, rho_x: passed through to the step function.
        n: ignored. It is overwritten with ``jnp.shape(x)[0]``.
        T: number of forward resampling steps.

    Returns:
        Tuple ``(logcdf_conditionals, logpdf_joints, pdiff, cdiff)`` taken from
        carry positions 8, 9, 10 and 11 of the final loop state.
    """
    # Uniform random numbers, one per resampling step.
    key, subkey = random.split(key)
    a_rand = random.uniform(subkey, shape=(T, 1))

    # Bayesian bootstrap: Dirichlet(1, ..., 1) weights over the observed rows,
    # then T row indices drawn with those weights.
    key, subkey = random.split(key)
    n = jnp.shape(x)[0]  # NOTE(review): overrides the caller-supplied `n`
    w = random.dirichlet(subkey, jnp.ones(n))
    # The freshly split `key` (not `subkey`) feeds `choice`; kept as-is so
    # the random stream is unchanged.
    key, subkey = random.split(key)
    ind_new = random.choice(key, a=jnp.arange(n), p=w, shape=(1, T))[0]
    x_new = x[ind_new]

    # Append the T resampled covariate rows after the observed rows.
    x_samp = jnp.concatenate((x, x_new), axis=0)

    # Trackers of length T; presumably filled in per step by
    # pr_1step_conv_cregression (TODO confirm).
    pdiff = jnp.zeros(T)
    cdiff = jnp.zeros(T)

    inputs = (logcdf_conditionals, logpdf_joints, x_samp, x_test, rho, rho_x,
              n, a_rand, logcdf_conditionals, logpdf_joints, pdiff, cdiff)

    outputs = fori_loop(0, T, pr_1step_conv_cregression, inputs)
    # The carry holds two copies of each density array. Only the copies at
    # positions 8 and 9 are returned (the previous duplicate-name unpacking
    # had the same effect, because later targets overwrite earlier ones).
    # NOTE(review): confirm which copy the step function updates.
    (_, _, _, _, _, _, _, _,
     logcdf_conditionals, logpdf_joints, pdiff, cdiff) = outputs

    return logcdf_conditionals, logpdf_joints, pdiff, cdiff
def predictive_resample_single_loop_cregression(key, logcdf_conditionals,
                                                logpdf_joints, x, x_test, rho,
                                                rho_x, n, T):
    """Run one forward predictive-resampling pass for conditional regression.

    T uniforms are drawn, and T covariate rows are resampled from ``x`` with
    Bayesian-bootstrap weights. The result is fed to
    ``mvcr.update_ptest_single_scan`` over step indices ``n .. n + T - 1``.

    Note: the ``n`` argument is ignored; it is replaced by ``jnp.shape(x)[0]``.

    Returns:
        The second and third elements of the scan's output carry, returned as
        ``(logcdf_conditionals, logpdf_joints)``.
    """
    key, unif_key = random.split(key)
    a_rand = random.uniform(unif_key, shape=(T, 1))

    key, bb_key = random.split(key)
    n = jnp.shape(x)[0]
    weights = random.dirichlet(bb_key, jnp.ones(n))
    key, _unused = random.split(key)
    picked = random.choice(key, a=jnp.arange(n), p=weights, shape=(1, T))[0]
    x_samp = jnp.concatenate((x, x[picked]), axis=0)

    # Observed rows get zero placeholders; the T appended rows carry a_rand.
    vT = jnp.concatenate((jnp.zeros((n, 1)), a_rand), axis=0)

    carry = (vT, logcdf_conditionals, logpdf_joints, x_samp, x_test, rho,
             rho_x)
    steps = jnp.arange(n, n + T)
    final_carry, steps = mvcr.update_ptest_single_scan(carry, steps)
    _, new_logcdf, new_logpdf, *_rest = final_carry

    return new_logcdf, new_logpdf
Example #3
0
  def testDirichlet(self, alpha, dtype):
    """Dirichlet draws (eager and jitted) sum to 1 and have Beta marginals."""
    key = random.PRNGKey(0)
    num_draws = 10000

    def draw(rng, concentration):
      return random.dirichlet(rng, concentration, (num_draws,), dtype)

    draw_jit = api.jit(draw)

    eager_samples = draw(key, alpha)
    jitted_samples = draw_jit(key, alpha)

    total = sum(alpha)
    ones = onp.ones(num_draws, dtype=dtype)
    for samples in (eager_samples, jitted_samples):
      self.assertAllClose(samples.sum(-1), ones, check_dtypes=True)
      # Component i of Dirichlet(alpha) is Beta(alpha_i, sum(alpha) - alpha_i).
      for idx, a_i in enumerate(alpha):
        marginal = scipy.stats.beta(a_i, total - a_i)
        self._CheckKolmogorovSmirnovCDF(samples[..., idx], marginal.cdf)
def forward_sample_y_samp(key, logpmf_ytest, logpmf_yn, y, x, x_test, rho,
                          rho_x, T):
    """Forward-sample T new (x, y) pairs by predictive resampling.

    T covariate rows are drawn from ``x`` with Bayesian-bootstrap
    (Dirichlet(1, ..., 1)) weights. ``y`` is extended with T zero rows, and
    ``update_pn_scan_forward`` is run over step indices ``n .. n + T - 1``.

    Args:
        key: JAX PRNG key.
        logpmf_ytest, logpmf_yn: initial log-pmf arrays carried through the scan.
        y: observed responses; reshaped to a column.
        x: observed covariates; must have the same number of rows as ``y``.
        x_test, rho, rho_x: passed through to the scan.
        T: number of forward samples.

    Returns:
        ``(logpmf_ytest, logpmf_yn, y_samp, x_samp, pdiff)``. ``x_samp`` is
        ``x`` followed by the T resampled rows. The other four come from the
        first four elements of the final scan carry.

    Raises:
        ValueError: if ``y`` and ``x`` have different numbers of rows. Without
            this check, ``vT``/``y_samp`` (sized from ``y``) and
            ``x_samp``/``pdiff``/the step indices (sized from ``x``) would be
            silently misaligned.
    """
    n = jnp.shape(y)[0]
    n_x = jnp.shape(x)[0]
    if n_x != n:
        raise ValueError(
            f"y and x must have the same number of rows, got {n} and {n_x}")

    # Uniform random numbers; the n leading zeros are placeholders for the
    # observed positions. jnp.append flattens, so vT has shape (n + T,).
    key, subkey = random.split(key)
    a_rand = random.uniform(subkey, shape=(T, 1))
    vT = jnp.append(jnp.zeros(n), a_rand)
    # Observed y as a column, followed by T zero rows for the new samples.
    y_samp = jnp.concatenate((y.reshape(-1, 1), jnp.zeros((T, 1))), axis=0)

    # Bayesian bootstrap draw of T covariate rows.
    key, subkey = random.split(key)
    w = random.dirichlet(subkey, jnp.ones(n))  # single set of dirichlet weights
    # The freshly split `key` (not `subkey`) feeds `choice`; kept as-is so
    # the random stream is unchanged.
    key, subkey = random.split(key)
    ind_new = random.choice(key, a=jnp.arange(n), p=w, shape=(1, T))[0]
    x_samp = jnp.concatenate((x, x[ind_new]), axis=0)

    # Change tracker, zero-initialised with shape (n + T, n).
    pdiff = jnp.zeros((n + T, n))

    # Run the forward scan.
    carry = (logpmf_ytest, logpmf_yn, y_samp, pdiff, y, x, x_test, rho, rho_x,
             ind_new, vT, logpmf_yn)
    rng = jnp.arange(n, n + T)
    carry, rng = update_pn_scan_forward(carry, rng)
    logpmf_ytest, logpmf_yn, y_samp, pdiff, *_ = carry
    return logpmf_ytest, logpmf_yn, y_samp, x_samp, pdiff
Example #5
0
 def sample(self, rng_key, sample_shape=()):
     """Draw Dirichlet samples with concentration ``self.alpha``.

     ``random.dirichlet``'s ``shape`` argument is the *batch* shape; the event
     axis (the last axis of ``alpha``) is appended by ``random.dirichlet``
     itself. Passing ``event_shape`` as well would add a spurious extra
     trailing dimension.

     Args:
         rng_key: JAX PRNG key.
         sample_shape: leading sample dimensions (tuple-like).

     Returns:
         Array of shape ``sample_shape + batch_shape + alpha.shape[-1:]``.
         Assumes ``self.event_shape == self.alpha.shape[-1:]``; TODO confirm
         against the class definition.
     """
     shape = tuple(sample_shape) + tuple(self.batch_shape)
     return random.dirichlet(rng_key, self.alpha, shape)