def generate(params):
    """Sample label-clamped fantasies from the DBN and save them as images.

    Runs a Gibbs chain in the top-level RBM with the ten softmax units
    clamped to ``np.eye(10)`` (one row per class), pushes each retained
    chain element through ``down_pass`` and tiles the ten results on a
    2-by-5 grid before handing the image stream to ``utils.save_images``.

    Args:
        params: Parameter collection accepted by ``dbn.stack_params``.

    Returns:
        None. Images are written to a fresh temporary directory created
        under ``OUTPUT_PATH``.
    """
    dbn_params = dbn.stack_params(params)
    top_rbm = dbn_params[-1]

    # Clamp the softmax units, one for each class (10 fantasies).
    sample_v_softmax_clamped = functools.partial(
        sample_v_softmax, labels=np.eye(10))

    # Seed the chain with the clamped labels followed by uniform random
    # values for the remaining columns (W.shape[1] - 10 of them).
    # An alternative, currently unused, is to seed those columns with an
    # upward pass from real pixels:
    #   np.hstack((np.eye(10), up_pass(dbn_params, initial_pixels)))
    num_free_units = top_rbm.W.shape[1] - 10
    initial_v = np.hstack(
        (np.eye(10), np.random.random((10, num_free_units))))

    # Initialize the gibbs chain.
    gc = rbm.gibbs_chain(initial_v, top_rbm, rbm.sample_h,
                         sample_v_softmax_clamped)

    # Skip the first element of the chain and keep the next 2000.
    chain = itertools.islice(gc, 1, 2001)

    # Lazily map every kept element to a tiled image. A generator
    # expression is used instead of itertools.imap so the code runs on
    # both Python 2 and Python 3 while staying lazy.
    # NOTE(review): element [1] of a chain item is presumably the visible
    # probabilities -- confirm against rbm.gibbs_chain.
    images = (
        utils.tile(down_pass(dbn_params, sample[1]), grid_shape=(2, 5))
        for sample in chain)

    # Save to disk.
    utils.save_images(images, tempfile.mkdtemp(dir=OUTPUT_PATH))
def generate(params):
    """Sample label-clamped fantasies from the DBN and save them as images.

    Runs a Gibbs chain in the top-level RBM with the ten softmax units
    clamped to ``np.eye(10)`` (one row per class), pushes each retained
    chain element through ``down_pass`` and tiles the ten results on a
    2-by-5 grid before handing the image stream to ``utils.save_images``.

    Args:
        params: Parameter collection accepted by ``dbn.stack_params``.

    Returns:
        None. Images are written to a fresh temporary directory created
        under ``OUTPUT_PATH``.
    """
    dbn_params = dbn.stack_params(params)
    top_rbm = dbn_params[-1]

    # Clamp the softmax units, one for each class (10 fantasies).
    sample_v_softmax_clamped = functools.partial(
        sample_v_softmax, labels=np.eye(10))

    # Seed the chain with the clamped labels followed by uniform random
    # values for the remaining columns (W.shape[1] - 10 of them).
    # An alternative, currently unused, is to seed those columns with an
    # upward pass from real pixels:
    #   np.hstack((np.eye(10), up_pass(dbn_params, initial_pixels)))
    num_free_units = top_rbm.W.shape[1] - 10
    initial_v = np.hstack(
        (np.eye(10), np.random.random((10, num_free_units))))

    # Initialize the gibbs chain.
    gc = rbm.gibbs_chain(initial_v, top_rbm, rbm.sample_h,
                         sample_v_softmax_clamped)

    # Skip the first element of the chain and keep the next 2000.
    chain = itertools.islice(gc, 1, 2001)

    # Lazily map every kept element to a tiled image. A generator
    # expression is used instead of itertools.imap so the code runs on
    # both Python 2 and Python 3 while staying lazy.
    # NOTE(review): element [1] of a chain item is presumably the visible
    # probabilities -- confirm against rbm.gibbs_chain.
    images = (
        utils.tile(down_pass(dbn_params, sample[1]), grid_shape=(2, 5))
        for sample in chain)

    # Save to disk.
    utils.save_images(images, tempfile.mkdtemp(dir=OUTPUT_PATH))
def contrastive_wake_sleep(params, data, weight_decay=None, cd_k=1):
    """Compute the contrastive wake-sleep error and gradient for the DBN.

    A stochastic upward (wake) pass through the recognition parameters
    ``W_r``/``b_r`` of layers 0 and 1 feeds a Gibbs chain in the top-level
    RBM. The chain's last drawn element seeds a downward (sleep) pass
    through the generative parameters ``W_g``/``b_g``.

    Args:
        params: Parameter collection; regrouped per layer with
            ``dbn.stack_params`` and iterated directly when applying
            ``weight_decay``.
        data: Object exposing ``inputs`` and ``targets`` arrays with one
            row per case.
        weight_decay: Optional callable. When truthy,
            ``weight_decay(p)[1]`` is added to the gradient entry that
            corresponds to each ``p`` in ``params``.
        cd_k: Positive statistics come from the first element of the Gibbs
            chain; negative statistics come from the element ``cd_k``
            positions after it.

    Returns:
        A tuple ``(error, grad)``. ``error`` is the summed squared
        difference between ``inputs`` and the one-step reconstruction,
        divided by the number of cases. ``grad`` is a list ordered as
        layer 0 ``[W_r, b_r, W_g, b_g]``, layer 1 ``[W_r, b_r, W_g, b_g]``,
        followed by the top-level RBM gradients.
    """
    inputs, targets = data.inputs, data.targets
    num_cases = inputs.shape[0]

    # Turn the single tuple of parameters into something easier to
    # work with.
    dbn_params = dbn.stack_params(params)
    layer0, layer1, top_rbm = dbn_params[0], dbn_params[1], dbn_params[-1]
    grad = []

    # Wake phase: sample both hidden layers bottom-up.
    wake_hid1_states = rbm.sample_bernoulli(
        logistic(inputs.dot(layer0.W_r) + layer0.b_r))
    wake_hid2_states = rbm.sample_bernoulli(
        logistic(wake_hid1_states.dot(layer1.W_r) + layer1.b_r))

    # Contrastive divergence in the top-level RBM, whose visible units
    # are the targets concatenated with the second hidden layer.
    gc = rbm.gibbs_chain(np.hstack((targets, wake_hid2_states)),
                         top_rbm, rbm.sample_h, sample_v_softmax,
                         cd_k + 1)

    # The builtin next() works on Python 2.6+ and Python 3, unlike the
    # Python-2-only .next() method.
    pos_sample = next(gc)
    if cd_k == 1:
        neg_sample = next(gc)
    else:
        # Keep the first step for the reconstruction error, then skip
        # ahead to the final step of the chain.
        recon_sample = next(gc)
        neg_sample = next(itertools.islice(gc, cd_k - 2, None))

    # Sleep phase: drop the label columns and generate top-down.
    sleep_hid2_states = neg_sample[0][:, mnist.NUM_CLASSES:]
    sleep_hid1_states = rbm.sample_bernoulli(
        logistic(sleep_hid2_states.dot(layer1.W_g) + layer1.b_g))
    sleep_vis_probs = logistic(
        sleep_hid1_states.dot(layer0.W_g) + layer0.b_g)

    # Predictions.
    p_sleep_hid2 = logistic(sleep_hid1_states.dot(layer1.W_r) + layer1.b_r)
    p_sleep_hid1 = logistic(sleep_vis_probs.dot(layer0.W_r) + layer0.b_r)
    p_wake_vis = logistic(wake_hid1_states.dot(layer0.W_g) + layer0.b_g)
    p_wake_hid1 = logistic(wake_hid2_states.dot(layer1.W_g) + layer1.b_g)

    # Gradients (prediction minus target, averaged over cases).
    # Layer 0.
    sleep_hid1_diff = p_sleep_hid1 - sleep_hid1_states
    wake_vis_diff = p_wake_vis - inputs
    W_r_grad = sleep_vis_probs.T.dot(sleep_hid1_diff) / num_cases
    b_r_grad = np.mean(sleep_hid1_diff, 0)
    W_g_grad = wake_hid1_states.T.dot(wake_vis_diff) / num_cases
    b_g_grad = np.mean(wake_vis_diff, 0)
    grad.extend([W_r_grad, b_r_grad, W_g_grad, b_g_grad])

    # Layer 1.
    sleep_hid2_diff = p_sleep_hid2 - sleep_hid2_states
    wake_hid1_diff = p_wake_hid1 - wake_hid1_states
    W_r_grad = sleep_hid1_states.T.dot(sleep_hid2_diff) / num_cases
    b_r_grad = np.mean(sleep_hid2_diff, 0)
    W_g_grad = wake_hid2_states.T.dot(wake_hid1_diff) / num_cases
    b_g_grad = np.mean(wake_hid1_diff, 0)
    grad.extend([W_r_grad, b_r_grad, W_g_grad, b_g_grad])

    # Top-level RBM.
    pos_grad = rbm.neg_free_energy_grad(top_rbm, pos_sample)
    neg_grad = rbm.neg_free_energy_grad(top_rbm, neg_sample)
    grad.extend(map(operator.sub, neg_grad, pos_grad))

    # Weight decay.
    if weight_decay:
        weight_grad = (weight_decay(p)[1] for p in params)
        # list(...) keeps the return type a list on Python 3, where map()
        # is lazy.
        grad = list(map(operator.add, grad, weight_grad))

    # One-step reconstruction error.
    if cd_k == 1:
        recon = sleep_vis_probs
    else:
        # Perform a down pass from the first sample of the Gibbs chain in
        # order to compute the one-step reconstruction error.
        # NOTE(review): this was described as a deterministic pass, yet
        # the first hidden layer is drawn with sample_bernoulli -- confirm
        # which is intended.
        recon_hid2_probs = recon_sample[1][:, mnist.NUM_CLASSES:]
        recon_hid1_states = rbm.sample_bernoulli(
            logistic(recon_hid2_probs.dot(layer1.W_g) + layer1.b_g))
        recon = logistic(recon_hid1_states.dot(layer0.W_g) + layer0.b_g)

    error = np.sum((inputs - recon) ** 2) / num_cases

    return error, grad
def contrastive_wake_sleep(params, data, weight_decay=None, cd_k=1):
    """Compute the contrastive wake-sleep error and gradient for the DBN.

    A stochastic upward (wake) pass through the recognition parameters
    ``W_r``/``b_r`` of layers 0 and 1 feeds a Gibbs chain in the top-level
    RBM. The chain's last drawn element seeds a downward (sleep) pass
    through the generative parameters ``W_g``/``b_g``.

    Args:
        params: Parameter collection; regrouped per layer with
            ``dbn.stack_params`` and iterated directly when applying
            ``weight_decay``.
        data: Object exposing ``inputs`` and ``targets`` arrays with one
            row per case.
        weight_decay: Optional callable. When truthy,
            ``weight_decay(p)[1]`` is added to the gradient entry that
            corresponds to each ``p`` in ``params``.
        cd_k: Positive statistics come from the first element of the Gibbs
            chain; negative statistics come from the element ``cd_k``
            positions after it.

    Returns:
        A tuple ``(error, grad)``. ``error`` is the summed squared
        difference between ``inputs`` and the one-step reconstruction,
        divided by the number of cases. ``grad`` is a list ordered as
        layer 0 ``[W_r, b_r, W_g, b_g]``, layer 1 ``[W_r, b_r, W_g, b_g]``,
        followed by the top-level RBM gradients.
    """
    inputs, targets = data.inputs, data.targets
    num_cases = inputs.shape[0]

    # Turn the single tuple of parameters into something easier to
    # work with.
    dbn_params = dbn.stack_params(params)
    layer0, layer1, top_rbm = dbn_params[0], dbn_params[1], dbn_params[-1]
    grad = []

    # Wake phase: sample both hidden layers bottom-up.
    wake_hid1_states = rbm.sample_bernoulli(
        logistic(inputs.dot(layer0.W_r) + layer0.b_r))
    wake_hid2_states = rbm.sample_bernoulli(
        logistic(wake_hid1_states.dot(layer1.W_r) + layer1.b_r))

    # Contrastive divergence in the top-level RBM, whose visible units
    # are the targets concatenated with the second hidden layer.
    gc = rbm.gibbs_chain(np.hstack((targets, wake_hid2_states)),
                         top_rbm, rbm.sample_h, sample_v_softmax,
                         cd_k + 1)

    # The builtin next() works on Python 2.6+ and Python 3, unlike the
    # Python-2-only .next() method.
    pos_sample = next(gc)
    if cd_k == 1:
        neg_sample = next(gc)
    else:
        # Keep the first step for the reconstruction error, then skip
        # ahead to the final step of the chain.
        recon_sample = next(gc)
        neg_sample = next(itertools.islice(gc, cd_k - 2, None))

    # Sleep phase: drop the label columns and generate top-down.
    sleep_hid2_states = neg_sample[0][:, mnist.NUM_CLASSES:]
    sleep_hid1_states = rbm.sample_bernoulli(
        logistic(sleep_hid2_states.dot(layer1.W_g) + layer1.b_g))
    sleep_vis_probs = logistic(
        sleep_hid1_states.dot(layer0.W_g) + layer0.b_g)

    # Predictions.
    p_sleep_hid2 = logistic(sleep_hid1_states.dot(layer1.W_r) + layer1.b_r)
    p_sleep_hid1 = logistic(sleep_vis_probs.dot(layer0.W_r) + layer0.b_r)
    p_wake_vis = logistic(wake_hid1_states.dot(layer0.W_g) + layer0.b_g)
    p_wake_hid1 = logistic(wake_hid2_states.dot(layer1.W_g) + layer1.b_g)

    # Gradients (prediction minus target, averaged over cases).
    # Layer 0.
    sleep_hid1_diff = p_sleep_hid1 - sleep_hid1_states
    wake_vis_diff = p_wake_vis - inputs
    W_r_grad = sleep_vis_probs.T.dot(sleep_hid1_diff) / num_cases
    b_r_grad = np.mean(sleep_hid1_diff, 0)
    W_g_grad = wake_hid1_states.T.dot(wake_vis_diff) / num_cases
    b_g_grad = np.mean(wake_vis_diff, 0)
    grad.extend([W_r_grad, b_r_grad, W_g_grad, b_g_grad])

    # Layer 1.
    sleep_hid2_diff = p_sleep_hid2 - sleep_hid2_states
    wake_hid1_diff = p_wake_hid1 - wake_hid1_states
    W_r_grad = sleep_hid1_states.T.dot(sleep_hid2_diff) / num_cases
    b_r_grad = np.mean(sleep_hid2_diff, 0)
    W_g_grad = wake_hid2_states.T.dot(wake_hid1_diff) / num_cases
    b_g_grad = np.mean(wake_hid1_diff, 0)
    grad.extend([W_r_grad, b_r_grad, W_g_grad, b_g_grad])

    # Top-level RBM.
    pos_grad = rbm.neg_free_energy_grad(top_rbm, pos_sample)
    neg_grad = rbm.neg_free_energy_grad(top_rbm, neg_sample)
    grad.extend(map(operator.sub, neg_grad, pos_grad))

    # Weight decay.
    if weight_decay:
        weight_grad = (weight_decay(p)[1] for p in params)
        # list(...) keeps the return type a list on Python 3, where map()
        # is lazy.
        grad = list(map(operator.add, grad, weight_grad))

    # One-step reconstruction error.
    if cd_k == 1:
        recon = sleep_vis_probs
    else:
        # Perform a down pass from the first sample of the Gibbs chain in
        # order to compute the one-step reconstruction error.
        # NOTE(review): this was described as a deterministic pass, yet
        # the first hidden layer is drawn with sample_bernoulli -- confirm
        # which is intended.
        recon_hid2_probs = recon_sample[1][:, mnist.NUM_CLASSES:]
        recon_hid1_states = rbm.sample_bernoulli(
            logistic(recon_hid2_probs.dot(layer1.W_g) + layer1.b_g))
        recon = logistic(recon_hid1_states.dot(layer0.W_g) + layer0.b_g)

    error = np.sum((inputs - recon) ** 2) / num_cases

    return error, grad