def inpaint_bottom_conditional(h_upper, p_upper, ll_function, q_upper, x, mask, oversample, ninner):
    """Fill in the unobserved entries of ``x`` by importance resampling.

    ``x``, ``mask`` and ``h_upper`` are each passed through
    ``replicate_batch(..., oversample)``. A proposal is drawn with
    ``p_upper.sample`` conditioned on the replicated ``h_upper``. The
    candidate is completed as ``mask * x + (1 - mask) * proposal``, so entries
    where ``mask`` is 1 keep the observed value. Each completed candidate gets
    the log-weight

        (log_ql + log q_upper(h_upper | x) - log p_upper(x | h_upper)) / 2

    where ``log_ql`` is the second output of ``ll_function(x, ninner)``
    (the original code calls it q(x)). The log-weights are normalised over
    axis 0, exponentiated, and passed to ``subsample`` to pick one index.

    NOTE(review): ``log p_upper`` is evaluated on the whole completed ``x``,
    observed entries included, not only on the proposed entries. Confirm
    that this is the intended proposal correction.

    Returns
    -------
    The completed candidate row(s) at the index returned by ``subsample``.
    """
    tiled_x = replicate_batch(x, oversample)
    tiled_mask = replicate_batch(mask, oversample)
    tiled_h = replicate_batch(h_upper, oversample)

    # Keep observed entries where mask is 1; use the proposal elsewhere.
    proposal, _unused = p_upper.sample(tiled_h)
    completed = tiled_mask * tiled_x + (1 - tiled_mask) * proposal
    log_p_x = p_upper.log_prob(completed, tiled_h)

    # Recognition-side terms for every completed candidate.
    _unused, log_q_x = ll_function(completed, ninner)
    log_q_h = q_upper.log_prob(tiled_h, completed)

    # Self-normalised importance weights over the candidates (axis 0).
    log_weights = (log_q_x + log_q_h - log_p_x) / 2
    log_weights = log_weights - logsumexp(log_weights, axis=0)
    chosen = subsample(tensor.exp(log_weights), 1)
    return completed[chosen, :]
def sample_bottom_conditional(h_upper, p_upper, ll_function, q_upper, oversample, ninner):
    """Draw the bottom layer ``x`` given the layer above by importance resampling.

    ``h_upper`` is replicated ``oversample`` times with ``replicate_batch``.
    Candidates ``x`` are drawn from ``p_upper.sample``, whose second return
    value is used as the proposal log-density ``log_p``. Each candidate gets
    the log-weight

        (log_ql + log q_upper(h_upper | x) - log_p) / 2

    which equals ``(log_ql + log_qu + log_p) / 2 - log_p``: half the sum of
    all three terms, minus the proposal log-density. Here ``log_ql`` is the
    second output of ``ll_function(x, ninner)``. The weights are normalised
    over axis 0 and passed to ``subsample`` to pick ``nsamples`` (= 1) index.

    Returns
    -------
    The candidate row(s) of ``x`` at the chosen index.
    """
    nsamples = 1

    # Propose candidates from the generative conditional p_upper(x | h_upper).
    h_upper = replicate_batch(h_upper, oversample)
    x, log_p = p_upper.sample(h_upper)

    # Evaluate q(x) and q(h_upper | x) for each candidate.
    _, log_ql = ll_function(x, ninner)
    log_qu = q_upper.log_prob(h_upper, x)

    # Importance log-weights (see docstring), then self-normalise over axis 0.
    log_w = (log_ql + log_qu - log_p) / 2
    w_norm = logsumexp(log_w, axis=0)
    log_w = log_w - w_norm
    w = tensor.exp(log_w)

    idx = subsample(w, nsamples)
    return x[idx, :]
def sample_top_conditional(h_lower, p_top, q_lower, oversample):
    """Resample the top layer given the layer below it.

    Two groups of candidates are generated:

      * ``p_top.sample(oversample)``: draws from the top-level model;
      * ``q_lower.sample(h_lower)``: draws from the recognition conditional,
        with ``h_lower`` replicated ``oversample`` times.

    For each candidate, the unnormalised target log-density is
    ``(log p_top + log q_lower) / 2``. The proposal log-density is
    ``logsumexp2(log p_top, log q_lower)``. Assuming ``logsumexp2(a, b)`` is
    ``log(exp(a) + exp(b))``, this is an equal-weight mixture density with its
    constant ``-log 2`` omitted. That constant cancels when the weights are
    normalised. ``subsample`` picks one index from the normalised weights.

    Returns
    -------
    The candidate row(s) at the chosen index.
    """
    tiled_lower = replicate_batch(h_lower, oversample)

    # Candidates proposed by p_top, also scored under q_lower.
    prior_draws, lp_prior = p_top.sample(oversample)
    lq_prior = q_lower.log_prob(prior_draws, tiled_lower)

    # Candidates proposed by q_lower, also scored under p_top.
    recog_draws, lq_recog = q_lower.sample(tiled_lower)
    lp_recog = p_top.log_prob(recog_draws)

    candidates = tensor.concatenate([prior_draws, recog_draws], axis=0)
    log_target = tensor.concatenate(
        [(lp_prior + lq_prior) / 2, (lp_recog + lq_recog) / 2], axis=0)
    log_mixture = tensor.concatenate(
        [logsumexp2(lp_prior, lq_prior), logsumexp2(lp_recog, lq_recog)], axis=0)

    # Self-normalised importance weights over all candidates (axis 0).
    log_weights = log_target - log_mixture
    log_weights = log_weights - logsumexp(log_weights, axis=0)
    chosen = subsample(tensor.exp(log_weights), 1)
    return candidates[chosen, :]
def sample_conditional(h_upper, h_lower, p_upper, p_lower, q_upper, q_lower, oversample):
    """Resample an intermediate layer ``h`` given the layers around it.

    ``h_upper`` and ``h_lower`` are each replicated ``oversample`` times with
    ``replicate_batch``. Two groups of candidates for ``h`` are proposed:
    draws from ``p_upper.sample(h_upper)`` and draws from
    ``q_lower.sample(h_lower)``. For every candidate the code evaluates

        log p_upper(h | h_upper),  log p_lower(h_lower | h),
        log q_upper(h_upper | h),  log q_lower(h | h_lower)

    and uses half their sum as the unnormalised target log-density. The
    proposal log-density is ``logsumexp2(log p_upper, log q_lower)``. This is
    an equal-weight mixture density with its constant ``-log 2`` omitted,
    which cancels when the weights are normalised. ``subsample`` then picks
    ``nsamples`` (= 1) index from the normalised weights.

    Returns
    -------
    The candidate row(s) of ``h`` at the chosen index. Only the sample is
    returned; its log-density is not.
    """
    nsamples = 1

    h_upper = replicate_batch(h_upper, oversample)
    h_lower = replicate_batch(h_lower, oversample)

    # First group: proposals from p_upper(h | h_upper), scored under all terms.
    h1, log_1pu = p_upper.sample(h_upper)
    log_1pl = p_lower.log_prob(h_lower, h1)
    log_1qu = q_upper.log_prob(h_upper, h1)
    log_1ql = q_lower.log_prob(h1, h_lower)
    log_1ps = (log_1pu + log_1pl + log_1ql + log_1qu) / 2
    log_1 = logsumexp2(log_1pu, log_1ql)

    # Second group: proposals from q_lower(h | h_lower), scored under all terms.
    h2, log_2ql = q_lower.sample(h_lower)
    log_2qu = q_upper.log_prob(h_upper, h2)
    log_2pl = p_lower.log_prob(h_lower, h2)
    log_2pu = p_upper.log_prob(h2, h_upper)
    log_2ps = (log_2pu + log_2pl + log_2ql + log_2qu) / 2
    log_2 = logsumexp2(log_2pu, log_2ql)

    h_proposals = tensor.concatenate([h1, h2], axis=0)
    # Mixture proposal density; the constant -log(2) is dropped because it
    # cancels in the normalisation below.
    log_proposals = tensor.concatenate([log_1, log_2], axis=0)
    log_ps = tensor.concatenate([log_1ps, log_2ps], axis=0)

    # Self-normalised importance weights over all candidates (axis 0).
    log_w = log_ps - log_proposals
    w_norm = logsumexp(log_w, axis=0)
    log_w = log_w - w_norm
    w = tensor.exp(log_w)

    idx = subsample(w, nsamples)
    return h_proposals[idx, :]
# Follow the first parent link upwards until reaching a brick with no parents.
brick = m.get_top_bricks()[0]
while len(brick.parents) > 0:
    brick = brick.parents[0]
# The code below expects the root brick to be one of these model classes.
assert isinstance(brick, (ReweightedWakeSleep, BiHM, GMM, VAE))

#----------------------------------------------------------------------
logger.info("Compiling function...")

# Not used within this visible chunk; presumably consumed further on.
n_layers = len(brick.p_layers)

n_samples = tensor.iscalar('n_samples')
x = tensor.matrix('features')

batch_size = x.shape[0]

# Replicate the input batch n_samples times so that sample_q draws
# n_samples latent samples per input (presumably; see replicate_batch).
x_ = replicate_batch(x, n_samples)
samples, log_p, log_q = brick.sample_q(x_)

# Reshape and sum
# NOTE(review): the axis=1 reductions below assume unflatten_values yields
# per-layer values shaped (batch_size, n_samples) — confirm.
samples = unflatten_values(samples, batch_size, n_samples)
log_p = unflatten_values(log_p, batch_size, n_samples)
log_q = unflatten_values(log_q, batch_size, n_samples)

# Importance weights for q proposal for p
log_p_all = sum(log_p)   # This is the python sum over a list
log_q_all = sum(log_q)   # This is the python sum over a list

# log_px = log of the mean over the n_samples draws of exp(log p - log q),
# i.e. an importance-sampling estimate of log p(x) with q as the proposal.
log_pq = (log_p_all - log_q_all)
log_px = logsumexp(log_pq, axis=1) - tensor.log(n_samples)

# Reverse log-ratio; not used within this visible chunk.
log_qp = (log_q_all - log_p_all)
# Follow the first parent link upwards until reaching a brick with no parents.
brick = m.get_top_bricks()[0]
while len(brick.parents) > 0:
    brick = brick.parents[0]
# The code below expects the root brick to be one of these model classes.
assert isinstance(brick, (ReweightedWakeSleep, BiHM, GMM, VAE))

# True when the root brick is a BiHM; not used within this visible chunk.
has_ps = isinstance(brick, BiHM)

#----------------------------------------------------------------------
logger.info("Compiling function...")

# NOTE(review): batch_size is fixed at 1 rather than taken from x.shape[0],
# so this presumably expects `features` to hold a single row — confirm.
batch_size = 1
n_samples = tensor.iscalar('n_samples')
x = tensor.matrix('features')

# Replicate the input n_samples times so that sample_q draws n_samples
# latent samples for it (presumably; see replicate_batch).
x_ = replicate_batch(x, n_samples)
samples, log_p, log_q = brick.sample_q(x_)

# Reshape and sum
samples = unflatten_values(samples, batch_size, n_samples)
log_p = unflatten_values(log_p, batch_size, n_samples)
log_q = unflatten_values(log_q, batch_size, n_samples)

# Importance weights for q proposal for p
log_p_all = sum(log_p)   # This is the python sum over a list
log_q_all = sum(log_q)   # This is the python sum over a list

# Normalise the log-ratios along axis 1 so that exp(log_wp) sums to 1 per
# row. Subtracting log(n_samples) shifts every entry and w_norm by the same
# constant, so it cancels in log_wp.
log_pq = (log_p_all-log_q_all)-tensor.log(n_samples)
w_norm = logsumexp(log_pq, axis=1)
log_wp = log_pq-tensor.shape_padright(w_norm)
wp = tensor.exp(log_wp)