def generate(params):
    """Sample label-clamped fantasies from the DBN and save them as images.

    Ten Gibbs chains are run in parallel in the top-level RBM, with the
    softmax (label) units of chain ``i`` clamped to class ``i``.  The first
    element yielded by the chain is skipped; for each of the next 2000
    elements, item ``[1]`` is pushed through ``down_pass``, tiled into a
    2x5 grid and handed to ``utils.save_images``.

    Args:
        params: Flat parameter collection accepted by ``dbn.stack_params``.

    Returns:
        None.  Images are written into a fresh temporary directory created
        under ``OUTPUT_PATH``.
    """
    num_classes = 10
    num_frames = 2000

    dbn_params = dbn.stack_params(params)
    top_rbm = dbn_params[-1]

    # Clamp the softmax units, one for each class.
    sample_v_softmax_clamped = functools.partial(sample_v_softmax,
                                                 labels=np.eye(num_classes))

    # Start every chain from its one-hot label followed by uniform random
    # values for the remaining ``W.shape[1] - num_classes`` columns.
    # (An alternative start is an upward pass from real pixels via
    # ``up_pass(dbn_params, initial_pixels)``.)
    initial_v = np.hstack((
        np.eye(num_classes),
        np.random.random((num_classes, top_rbm.W.shape[1] - num_classes))))

    # Initialize the gibbs chain.
    chain = rbm.gibbs_chain(initial_v, top_rbm, rbm.sample_h,
                            sample_v_softmax_clamped)

    tile_2_by_5 = functools.partial(utils.tile, grid_shape=(2, 5))

    # Skip the first chain element, then take the following ``num_frames``.
    samples = itertools.islice(chain, 1, 1 + num_frames)

    # A generator expression keeps the pipeline lazy without relying on
    # ``itertools.imap``, which only exists on Python 2.
    # NOTE(review): item [1] of each chain element is presumably the
    # visible-unit value to project down -- confirm against rbm.gibbs_chain.
    frames = (tile_2_by_5(down_pass(dbn_params, sample[1]))
              for sample in samples)

    # Save to disk.
    utils.save_images(frames, tempfile.mkdtemp(dir=OUTPUT_PATH))
def generate(params):
    """Sample label-clamped fantasies from the DBN and save them as images.

    Ten Gibbs chains are run in parallel in the top-level RBM, with the
    softmax (label) units of chain ``i`` clamped to class ``i``.  The first
    element yielded by the chain is skipped; for each of the next 2000
    elements, item ``[1]`` is pushed through ``down_pass``, tiled into a
    2x5 grid and handed to ``utils.save_images``.

    Args:
        params: Flat parameter collection accepted by ``dbn.stack_params``.

    Returns:
        None.  Images are written into a fresh temporary directory created
        under ``OUTPUT_PATH``.
    """
    num_classes = 10
    num_frames = 2000

    dbn_params = dbn.stack_params(params)
    top_rbm = dbn_params[-1]

    # Clamp the softmax units, one for each class.
    sample_v_softmax_clamped = functools.partial(sample_v_softmax,
                                                 labels=np.eye(num_classes))

    # Start every chain from its one-hot label followed by uniform random
    # values for the remaining ``W.shape[1] - num_classes`` columns.
    # (An alternative start is an upward pass from real pixels via
    # ``up_pass(dbn_params, initial_pixels)``.)
    initial_v = np.hstack((
        np.eye(num_classes),
        np.random.random((num_classes, top_rbm.W.shape[1] - num_classes))))

    # Initialize the gibbs chain.
    chain = rbm.gibbs_chain(initial_v,
                            top_rbm,
                            rbm.sample_h,
                            sample_v_softmax_clamped)

    tile_2_by_5 = functools.partial(utils.tile, grid_shape=(2, 5))

    # Skip the first chain element, then take the following ``num_frames``.
    samples = itertools.islice(chain, 1, 1 + num_frames)

    # A generator expression keeps the pipeline lazy without relying on
    # ``itertools.imap``, which only exists on Python 2.
    # NOTE(review): item [1] of each chain element is presumably the
    # visible-unit value to project down -- confirm against rbm.gibbs_chain.
    frames = (tile_2_by_5(down_pass(dbn_params, sample[1]))
              for sample in samples)

    # Save to disk.
    utils.save_images(frames, tempfile.mkdtemp(dir=OUTPUT_PATH))
# Example #3
# 0
def contrastive_wake_sleep(params, data, weight_decay=None, cd_k=1):
    """Compute contrastive wake-sleep gradients for a two-layer DBN + top RBM.

    The wake phase samples both hidden layers bottom-up with the recognition
    weights (``W_r``/``b_r``).  A Gibbs chain is then started in the
    top-level RBM from the targets concatenated with the second hidden
    layer.  The sleep phase samples top-down with the generative weights
    (``W_g``/``b_g``) from the chain's negative sample.

    Args:
        params: Flat parameter collection accepted by ``dbn.stack_params``.
        data: Object exposing ``inputs`` and ``targets`` arrays with one row
            per case.
        weight_decay: Optional callable applied to every entry of ``params``;
            item ``[1]`` of its result is added to the matching gradient.
        cd_k: The negative sample is the chain element ``cd_k`` positions
            after the positive (first) one.

    Returns:
        A tuple ``(error, grad)``.  ``error`` is the summed squared
        difference between ``inputs`` and the reconstruction, divided by
        the number of cases.  ``grad`` is a list ordered as layer 0
        ``[W_r, b_r, W_g, b_g]``, layer 1 ``[W_r, b_r, W_g, b_g]``, then the
        top-level RBM gradients.
    """
    inputs, targets = data.inputs, data.targets
    num_cases = inputs.shape[0]

    # Turn the single tuple of parameters into something easier to
    # work with.
    dbn_params = dbn.stack_params(params)
    layer0, layer1, top_rbm = dbn_params[0], dbn_params[1], dbn_params[-1]
    grad = []

    # Wake phase.
    wake_hid1_states = rbm.sample_bernoulli(
        logistic(inputs.dot(layer0.W_r) + layer0.b_r))
    wake_hid2_states = rbm.sample_bernoulli(
        logistic(wake_hid1_states.dot(layer1.W_r) + layer1.b_r))

    # Contrastive divergence.
    # NOTE(review): the last argument is presumably the chain length --
    # confirm against rbm.gibbs_chain.
    chain = rbm.gibbs_chain(np.hstack((targets, wake_hid2_states)),
                            top_rbm,
                            rbm.sample_h,
                            sample_v_softmax,
                            cd_k + 1)

    # The ``next()`` builtin works on Python 2.6+ and Python 3, unlike the
    # Python-2-only ``.next()`` method.
    pos_sample = next(chain)
    if cd_k == 1:
        neg_sample = next(chain)
    else:
        # Keep the first step for the reconstruction error, then skip ahead
        # so that the negative sample is ``cd_k`` elements after the
        # positive one.
        recon_sample = next(chain)
        neg_sample = next(itertools.islice(chain, cd_k - 2, None))

    # Sleep phase.
    sleep_hid2_states = neg_sample[0][:, mnist.NUM_CLASSES:]
    sleep_hid1_states = rbm.sample_bernoulli(
        logistic(sleep_hid2_states.dot(layer1.W_g) + layer1.b_g))
    sleep_vis_probs = logistic(
        sleep_hid1_states.dot(layer0.W_g) + layer0.b_g)

    # Predictions.
    p_sleep_hid2 = logistic(sleep_hid1_states.dot(layer1.W_r) + layer1.b_r)
    p_sleep_hid1 = logistic(sleep_vis_probs.dot(layer0.W_r) + layer0.b_r)
    p_wake_vis = logistic(wake_hid1_states.dot(layer0.W_g) + layer0.b_g)
    p_wake_hid1 = logistic(wake_hid2_states.dot(layer1.W_g) + layer1.b_g)

    # Gradients.
    # Layer 0.
    sleep_hid1_err = p_sleep_hid1 - sleep_hid1_states
    wake_vis_err = p_wake_vis - inputs
    W_r_grad = sleep_vis_probs.T.dot(sleep_hid1_err) / num_cases
    b_r_grad = np.mean(sleep_hid1_err, 0)
    W_g_grad = wake_hid1_states.T.dot(wake_vis_err) / num_cases
    b_g_grad = np.mean(wake_vis_err, 0)
    grad.extend([W_r_grad, b_r_grad, W_g_grad, b_g_grad])

    # Layer 1.
    sleep_hid2_err = p_sleep_hid2 - sleep_hid2_states
    wake_hid1_err = p_wake_hid1 - wake_hid1_states
    W_r_grad = sleep_hid1_states.T.dot(sleep_hid2_err) / num_cases
    b_r_grad = np.mean(sleep_hid2_err, 0)
    W_g_grad = wake_hid2_states.T.dot(wake_hid1_err) / num_cases
    b_g_grad = np.mean(wake_hid1_err, 0)
    grad.extend([W_r_grad, b_r_grad, W_g_grad, b_g_grad])

    # Top-level RBM.
    pos_grad = rbm.neg_free_energy_grad(top_rbm, pos_sample)
    neg_grad = rbm.neg_free_energy_grad(top_rbm, neg_sample)
    grad.extend(map(operator.sub, neg_grad, pos_grad))

    # Weight decay.
    if weight_decay:
        weight_grad = (weight_decay(p)[1] for p in params)
        # Materialise the result so callers always receive a list, also on
        # Python 3 where ``map`` returns a one-shot iterator.
        grad = list(map(operator.add, grad, weight_grad))

    # One-step reconstruction error.
    if cd_k == 1:
        recon = sleep_vis_probs
    else:
        # Perform a down pass from the first sample of the Gibbs chain in
        # order to compute the one-step reconstruction error.
        # NOTE(review): the first hidden layer is sampled here even though
        # the variable is named "probs" and the pass was described as
        # deterministic -- confirm which behavior is intended.
        recon_hid2_probs = recon_sample[1][:, mnist.NUM_CLASSES:]
        recon_hid1_probs = rbm.sample_bernoulli(
            logistic(recon_hid2_probs.dot(layer1.W_g) + layer1.b_g))
        recon = logistic(recon_hid1_probs.dot(layer0.W_g) + layer0.b_g)

    error = np.sum((inputs - recon) ** 2) / num_cases

    return error, grad
# Example #4
# 0
def contrastive_wake_sleep(params, data, weight_decay=None, cd_k=1):
    """Compute contrastive wake-sleep gradients for a two-layer DBN + top RBM.

    The wake phase samples both hidden layers bottom-up with the recognition
    weights (``W_r``/``b_r``).  A Gibbs chain is then started in the
    top-level RBM from the targets concatenated with the second hidden
    layer.  The sleep phase samples top-down with the generative weights
    (``W_g``/``b_g``) from the chain's negative sample.

    Args:
        params: Flat parameter collection accepted by ``dbn.stack_params``.
        data: Object exposing ``inputs`` and ``targets`` arrays with one row
            per case.
        weight_decay: Optional callable applied to every entry of ``params``;
            item ``[1]`` of its result is added to the matching gradient.
        cd_k: The negative sample is the chain element ``cd_k`` positions
            after the positive (first) one.

    Returns:
        A tuple ``(error, grad)``.  ``error`` is the summed squared
        difference between ``inputs`` and the reconstruction, divided by
        the number of cases.  ``grad`` is a list ordered as layer 0
        ``[W_r, b_r, W_g, b_g]``, layer 1 ``[W_r, b_r, W_g, b_g]``, then the
        top-level RBM gradients.
    """
    inputs, targets = data.inputs, data.targets
    num_cases = inputs.shape[0]

    # Turn the single tuple of parameters into something easier to
    # work with.
    dbn_params = dbn.stack_params(params)
    layer0, layer1, top_rbm = dbn_params[0], dbn_params[1], dbn_params[-1]
    grad = []

    # Wake phase.
    wake_hid1_states = rbm.sample_bernoulli(
        logistic(inputs.dot(layer0.W_r) + layer0.b_r))
    wake_hid2_states = rbm.sample_bernoulli(
        logistic(wake_hid1_states.dot(layer1.W_r) + layer1.b_r))

    # Contrastive divergence.
    # NOTE(review): the last argument is presumably the chain length --
    # confirm against rbm.gibbs_chain.
    chain = rbm.gibbs_chain(np.hstack((targets, wake_hid2_states)),
                            top_rbm, rbm.sample_h, sample_v_softmax,
                            cd_k + 1)

    # The ``next()`` builtin works on Python 2.6+ and Python 3, unlike the
    # Python-2-only ``.next()`` method.
    pos_sample = next(chain)
    if cd_k == 1:
        neg_sample = next(chain)
    else:
        # Keep the first step for the reconstruction error, then skip ahead
        # so that the negative sample is ``cd_k`` elements after the
        # positive one.
        recon_sample = next(chain)
        neg_sample = next(itertools.islice(chain, cd_k - 2, None))

    # Sleep phase.
    sleep_hid2_states = neg_sample[0][:, mnist.NUM_CLASSES:]
    sleep_hid1_states = rbm.sample_bernoulli(
        logistic(sleep_hid2_states.dot(layer1.W_g) + layer1.b_g))
    sleep_vis_probs = logistic(
        sleep_hid1_states.dot(layer0.W_g) + layer0.b_g)

    # Predictions.
    p_sleep_hid2 = logistic(sleep_hid1_states.dot(layer1.W_r) + layer1.b_r)
    p_sleep_hid1 = logistic(sleep_vis_probs.dot(layer0.W_r) + layer0.b_r)
    p_wake_vis = logistic(wake_hid1_states.dot(layer0.W_g) + layer0.b_g)
    p_wake_hid1 = logistic(wake_hid2_states.dot(layer1.W_g) + layer1.b_g)

    # Gradients.
    # Layer 0.
    sleep_hid1_err = p_sleep_hid1 - sleep_hid1_states
    wake_vis_err = p_wake_vis - inputs
    W_r_grad = sleep_vis_probs.T.dot(sleep_hid1_err) / num_cases
    b_r_grad = np.mean(sleep_hid1_err, 0)
    W_g_grad = wake_hid1_states.T.dot(wake_vis_err) / num_cases
    b_g_grad = np.mean(wake_vis_err, 0)
    grad.extend([W_r_grad, b_r_grad, W_g_grad, b_g_grad])

    # Layer 1.
    sleep_hid2_err = p_sleep_hid2 - sleep_hid2_states
    wake_hid1_err = p_wake_hid1 - wake_hid1_states
    W_r_grad = sleep_hid1_states.T.dot(sleep_hid2_err) / num_cases
    b_r_grad = np.mean(sleep_hid2_err, 0)
    W_g_grad = wake_hid2_states.T.dot(wake_hid1_err) / num_cases
    b_g_grad = np.mean(wake_hid1_err, 0)
    grad.extend([W_r_grad, b_r_grad, W_g_grad, b_g_grad])

    # Top-level RBM.
    pos_grad = rbm.neg_free_energy_grad(top_rbm, pos_sample)
    neg_grad = rbm.neg_free_energy_grad(top_rbm, neg_sample)
    grad.extend(map(operator.sub, neg_grad, pos_grad))

    # Weight decay.
    if weight_decay:
        weight_grad = (weight_decay(p)[1] for p in params)
        # Materialise the result so callers always receive a list, also on
        # Python 3 where ``map`` returns a one-shot iterator.
        grad = list(map(operator.add, grad, weight_grad))

    # One-step reconstruction error.
    if cd_k == 1:
        recon = sleep_vis_probs
    else:
        # Perform a down pass from the first sample of the Gibbs chain in
        # order to compute the one-step reconstruction error.
        # NOTE(review): the first hidden layer is sampled here even though
        # the variable is named "probs" and the pass was described as
        # deterministic -- confirm which behavior is intended.
        recon_hid2_probs = recon_sample[1][:, mnist.NUM_CLASSES:]
        recon_hid1_probs = rbm.sample_bernoulli(
            logistic(recon_hid2_probs.dot(layer1.W_g) + layer1.b_g))
        recon = logistic(recon_hid1_probs.dot(layer0.W_g) + layer0.b_g)

    error = np.sum((inputs - recon)**2) / num_cases

    return error, grad