def sample_pos_neg(self, sample_shape=(1, ), eps=None):
    """Draw coupled (negative, base, positive) samples of (precision, mean).

    The three precision draws come from ``self.wishart.sample_pos_neg``.
    One set of mean draws from ``self.normal.sample`` is shared by all three
    outputs, so only the precision part differs between them.

    Args:
        sample_shape: Shape forwarded to both component samplers.
        eps: Perturbation size passed (as a float) to
            ``self.wishart.sample_pos_neg``. May be ``None`` (use the default
            ``2 * self.wishart.W.shape[0]``), a scalar, or a dict carrying the
            value under the key ``'W_df'`` (missing key -> the same default).

    Returns:
        Tuple ``(samples_neg, samples, samples_pos)``. Each item is
        ``tl2lt((P_x, mus))`` (``tl2lt`` is a helper defined elsewhere).
    """
    # The default used to be a dict that was then passed straight to float(),
    # which always raised TypeError. Normalise every accepted form to a scalar.
    if eps is None:
        eps = {}
    if isinstance(eps, dict):
        eps = eps['W_df'] if 'W_df' in eps else 2 * self.wishart.W.shape[0]
    P_neg, P, P_pos = self.wishart.sample_pos_neg(sample_shape, float(eps))
    # Draw the means once so all three outputs differ only in the precision.
    mus = self.normal.sample(sample_shape)
    samples_neg = tl2lt((P_neg, mus))
    samples = tl2lt((P, mus))
    samples_pos = tl2lt((P_pos, mus))
    return samples_neg, samples, samples_pos
def sample_mult_eps(self, sample_shape=(1, ), eps=None):
    """Draw base samples plus one-at-a-time perturbations per component.

    ``eps`` is a dict with keys ``'G_df'`` (passed to
    ``self.gamma.sample_pos_neg``) and ``'W_df'`` (passed to
    ``self.wishart.sample_pos_neg``). When ``None``, the default is
    ``{'G_df': 1, 'W_df': 2 * self.wishart.W.shape[0]}``.

    Returns:
        ``(samples_neg, samples, samples_pos)``. ``samples`` is
        ``tl2lt((nu, P))``. Each of the other two is a dict where key
        ``'W_df'`` swaps in the perturbed Wishart draw and key ``'G_df'``
        swaps in the perturbed Gamma draw, with the other component kept at
        its base value.
    """
    if eps is None:
        eps = {'G_df': 1, 'W_df': 2 * self.wishart.W.shape[0]}
    gamma_triplet = self.gamma.sample_pos_neg(sample_shape, eps['G_df'])
    wishart_triplet = self.wishart.sample_pos_neg(sample_shape, eps['W_df'])
    nu_neg, nu, nu_pos = gamma_triplet
    P_neg, P, P_pos = wishart_triplet

    def perturbed(nu_shifted, P_shifted):
        # Perturb exactly one component per key; the other stays at base.
        return {
            'W_df': tl2lt((nu, P_shifted)),
            'G_df': tl2lt((nu_shifted, P)),
        }

    lower = perturbed(nu_neg, P_neg)
    upper = perturbed(nu_pos, P_pos)
    return lower, tl2lt((nu, P)), upper
def sample_pos_neg(self, sample_shape=(1, ), eps=1):
    """Draw coupled (negative, base, positive) samples of (X, lam).

    ``lam`` is built from a shared gamma draw. ``lam_neg`` has shape
    ``df - eps/2``. Adding independent ``Gamma(eps/2)`` increments with the
    same scale ``1/rate`` gives ``lam`` (shape ``df``) and then ``lam_pos``
    (shape ``df + eps/2``). ``X = mu + randn(n, d) @ S`` is drawn once and
    shared by all three outputs. numpy rejects a negative gamma shape, so
    ``eps / 2`` must not exceed ``df``.

    Only ``sample_shape[0]`` is used, as the number of draws ``n``.
    ``self.df`` and ``self.rate`` contribute their first element. NOTE(review):
    presumably these are 1-element tensors; confirm.

    Returns:
        Tuple of three ``tl2lt((cast(X), cast(lam_x)))`` values, for
        ``lam_neg``, ``lam`` and ``lam_pos`` in that order.
    """
    count = sample_shape[0]
    shape_param = self.df.detach().numpy()[0]
    loc = self.mu.detach().numpy().reshape(1, self.d)
    scale_mat = self.S.detach().numpy()
    gamma_scale = 1 / self.rate.detach().numpy()[0]
    half_eps = eps / 2
    draw_size = (count, 1)

    # Gamma additivity: each increment raises the shape by eps/2.
    lam_neg = np.random.gamma(shape_param - half_eps, gamma_scale, draw_size)
    lam = lam_neg + np.random.gamma(half_eps, gamma_scale, draw_size)
    lam_pos = lam + np.random.gamma(half_eps, gamma_scale, draw_size)

    X = loc + np.random.randn(count, self.d).dot(scale_mat)
    return tuple(tl2lt((cast(X), cast(rate_draw)))
                 for rate_draw in (lam_neg, lam, lam_pos))
def sample(self, sample_shape=(1, )):
    """Draw ``sample_shape[0]`` independent (X, lam) samples.

    ``lam`` is drawn from a gamma with shape ``self.df[0]`` and scale
    ``1 / self.rate[0]``. ``X = mu + randn(n, d) @ S`` does not depend on
    ``lam``. Both arrays are returned as float64 torch tensors, packed
    with ``tl2lt``.
    """
    def to_tensor(arr):
        return torch.from_numpy(arr.astype(np.float64))

    count = sample_shape[0]
    shape_param = self.df.detach().numpy()[0]
    loc = self.mu.detach().numpy().reshape(1, self.d)
    scale_mat = self.S.detach().numpy()
    gamma_scale = 1 / self.rate.detach().numpy()[0]

    # Keep the RNG order: gamma first, then the Gaussian noise.
    lam = np.random.gamma(shape_param, gamma_scale, (count, 1))
    noise = np.random.randn(count, self.d)
    X = loc + noise.dot(scale_mat)
    return tl2lt((to_tensor(X), to_tensor(lam)))
def sample(self, sample_shape=(1, )):
    """Draw independent precision (Wishart) and mean (Normal) samples.

    Returns ``tl2lt((P_samples, mu_samples))``. The Wishart is sampled
    before the Normal.
    """
    components = (self.wishart, self.normal)
    return tl2lt(tuple(dist.sample(sample_shape) for dist in components))
def sample(self, sample_shape=(1, )):
    """Draw independent Gamma, Wishart and Normal samples.

    Returns ``tl2lt((nu_samples, P_samples, mu_samples))``. The components
    are sampled in that order.
    """
    components = (self.gamma, self.wishart, self.normal)
    return tl2lt(tuple(dist.sample(sample_shape) for dist in components))