def exact_samples(rbm, num, batch_units=10, show_progress=False):
    """Draw `num` exact joint samples of (visible, hidden) states from `rbm`.

    The score table from `get_scores` is normalised (in log space) into a
    probability table `p`.  A row is sampled from the row marginal, then a
    column from that row's conditional; the hidden state is the sampled
    prefix row concatenated with the sampled postfix row.  Visible units are
    then drawn as Bernoulli variables with probabilities
    `logistic(rbm.vis_inputs(hid))`.

    Args:
        rbm: model exposing `nhid` and `vis_inputs`.
        num: number of samples to draw.
        batch_units: number of hidden units enumerated along the column axis
            of the score table.
        show_progress: accepted for interface compatibility; not used here.

    Returns:
        `binary_rbms.RBMState` built from the sampled visible and hidden
        arrays (converted to garrays).
    """
    scores = get_scores(rbm, batch_units=batch_units).as_numpy_array()
    # Subtract the log of the total mass so that exp(scores) sums to one.
    scores -= np.logaddexp.reduce(scores.ravel())
    p = np.exp(scores)

    prefix_len = rbm.nhid - batch_units
    prefixes = combinations_array(prefix_len).as_numpy_array()
    postfixes = combinations_array(batch_units).as_numpy_array()

    p_row = p.sum(1)
    p_row /= p_row.sum()
    cond_p_col = p / p_row[:, nax]
    # keep np.random.multinomial from choking because the sum is greater than 1
    cond_p_col *= (1. - 1e-8)

    with misc.gnumpy_conversion_check('allow'):
        rows = np.random.multinomial(1, p_row, size=num).argmax(1)
        # One multinomial draw per sampled row, in order, so the sequence of
        # random-number calls matches the row order.
        cols = np.array([np.random.multinomial(1, cond_p_col[row, :]).argmax()
                         for row in rows])
        hid = np.hstack([prefixes[rows, :], postfixes[cols, :]])
        vis = np.random.binomial(1, gnp.logistic(rbm.vis_inputs(hid)))

    return binary_rbms.RBMState(gnp.garray(vis), gnp.garray(hid))
def check_fisher_information_consistent():
    """Check the bias-only Fisher matrix against the full one.

    The leading (NVIS + NHID) x (NVIS + NHID) block of
    `exact_fisher_information` must match `exact_fisher_information_biases`
    for the same randomly generated RBM.
    """
    num_biases = NVIS + NHID
    with misc.gnumpy_conversion_check('allow'):
        model = random_rbm()
        fisher_biases = tractable.exact_fisher_information_biases(
            model, batch_units=BATCH_UNITS)
        fisher_full = tractable.exact_fisher_information(
            model, batch_units=BATCH_UNITS)
        top_left = fisher_full[:num_biases, :num_biases]
        assert_close(fisher_biases, top_left)
def check_fisher_information_indep():
    """Compare the exact Fisher matrix of a base-rate RBM with closed forms.

    The expected values below are the covariances of the statistics
    v_i, h_j and v_i * h_j when all units are independent Bernoulli
    variables with means `logistic(vbias)` and `logistic(hbias)`.
    """
    with misc.gnumpy_conversion_check('allow'):
        model = random_base_rate_rbm()
        E_v = gnp.logistic(model.vbias)
        E_h = gnp.logistic(model.hbias)

        G = tractable.exact_fisher_information(model, batch_units=BATCH_UNITS)
        assert_close(G, G.T, 'G not symmetric')

        # Offset of the weight (vis-hid product) block inside G.
        w_start = NVIS + NHID
        vis_rows = G[:NVIS, w_start:]
        hid_rows = G[NVIS:w_start, w_start:]
        weight_rows = G[w_start:, w_start:]
        G_vis_vishid = vis_rows.reshape((NVIS, NVIS, NHID))
        G_hid_vishid = hid_rows.reshape((NHID, NVIS, NHID))
        G_vishid_vishid = weight_rows.reshape((NVIS, NHID, NVIS, NHID))

        # Covariances between a visible bias statistic and a product statistic.
        assert_close(G_vis_vishid[0, 0, 1],
                     E_v[0] * (1. - E_v[0]) * E_h[1])
        assert_close(G_vis_vishid[0, 1, 2], 0.)

        # Covariances between a hidden bias statistic and a product statistic.
        assert_close(G_hid_vishid[0, 1, 0],
                     E_h[0] * (1. - E_h[0]) * E_v[1])
        assert_close(G_hid_vishid[0, 1, 2], 0.)

        # Covariances between pairs of product statistics.
        assert_close(G_vishid_vishid[0, 1, 0, 1],
                     E_v[0] * E_h[1] * (1. - E_v[0] * E_h[1]))
        assert_close(G_vishid_vishid[0, 1, 0, 2],
                     E_v[0] * (1. - E_v[0]) * E_h[1] * E_h[2])
        assert_close(G_vishid_vishid[0, 2, 1, 2],
                     E_h[2] * (1. - E_h[2]) * E_v[0] * E_v[1])
        assert_close(G_vishid_vishid[0, 1, 2, 3], 0.)
def check_against_exact():
    """Check `fang.Statistics` regression weights against the exact solution.

    Exact first and second moments are computed from a random RBM, packed
    into unary and pairwise statistics, and the resulting regression weights
    and unary covariance are compared with the maximum-likelihood weights and
    with the exact Fisher matrix.
    """
    with misc.gnumpy_conversion_check('allow'):
        model = test_tractable.random_rbm(NVIS, NHID)
        G, s = tractable.exact_fisher_information(
            model, return_mean=True, batch_units=2)
        rw = fisher.RegressionWeights.from_maximum_likelihood(G, NVIS, NHID)

        G = gnp.garray(G)
        s = gnp.garray(s)
        # Uncentred second moments: covariance plus outer product of the mean.
        S = G + np.outer(s, s)

        num_unary = NVIS + NHID
        m_unary = s[:num_unary]
        S_unary = S[:num_unary, :num_unary]

        m_pair = gnp.zeros((NVIS, NHID, 3))
        S_pair = gnp.zeros((NVIS, NHID, 3, 3))
        for v in range(NVIS):
            for h in range(NHID):
                # Positions of v, h and the v*h product in the parameter
                # vector (weights are laid out row-major over (vis, hid)).
                triple = np.array([v, NVIS + h, num_unary + NHID * v + h])
                m_pair[v, h, :] = s[triple]
                S_pair[v, h, :] = S[triple[:, nax], triple[nax, :]]

        stats = fang.Statistics(m_unary, S_unary, m_pair, S_pair)
        beta, sigma_sq = stats.compute_regression_weights()
        assert np.allclose(beta, rw.beta)
        assert np.allclose(sigma_sq, rw.sigma_sq)

        Sigma = stats.unary_covariance()
        max_err = np.max(np.abs(Sigma - G[:num_unary, :num_unary]))
        assert max_err < 1e-6
def check_fisher_information_consistent():
    """Verify that the bias block of the full Fisher matrix is consistent.

    Builds a random RBM, computes both the bias-only and the full exact
    Fisher information, and asserts that the top-left
    (NVIS + NHID)-square block of the full matrix equals the bias-only one.
    """
    with misc.gnumpy_conversion_check('allow'):
        rbm = random_rbm()
        bias_only = tractable.exact_fisher_information_biases(
            rbm, batch_units=BATCH_UNITS)
        full = tractable.exact_fisher_information(
            rbm, batch_units=BATCH_UNITS)
        bias_block = slice(None, NVIS + NHID)
        assert_close(bias_only, full[bias_block, bias_block])
def test_symmetric():
    """Check that the second-moment statistics are symmetric.

    `S_unary` must equal its transpose, and `S_pair` must be unchanged when
    its last two axes are swapped.
    """
    vis_act = gnp.garray(np.random.uniform(size=(N, NVIS)))
    hid_act = gnp.garray(np.random.uniform(size=(N, NHID)))
    stats = fang.Statistics.from_activations(vis_act, hid_act)

    with misc.gnumpy_conversion_check('allow'):
        assert np.allclose(stats.S_unary, stats.S_unary.T)
        swapped = stats.S_pair.as_numpy_array().swapaxes(2, 3)
        assert np.allclose(stats.S_pair, swapped)
def check_partition_function():
    """Compare `exact_partition_function` with brute-force enumeration.

    Every joint binary configuration of the NVIS visible and NHID hidden
    units is enumerated and `rbm.energy` is accumulated with `logaddexp`;
    the result must match `tractable.exact_partition_function`.

    NOTE(review): the sum is over exp(rbm.energy), not exp(-energy);
    presumably `rbm.energy` uses a negated-energy convention -- confirm.
    """
    with misc.gnumpy_conversion_check('allow'):
        rbm = random_rbm()
        # np.inf rather than the np.infty alias, which NumPy 2.0 removed.
        total = -np.inf

        for vis_ in itertools.product([0, 1], repeat=NVIS):
            vis = gnp.garray(vis_)
            for hid_ in itertools.product([0, 1], repeat=NHID):
                hid = gnp.garray(hid_)
                total = np.logaddexp(
                    total, rbm.energy(vis[nax, :], hid[nax, :])[0])

        assert np.allclose(
            tractable.exact_partition_function(rbm, batch_units=BATCH_UNITS),
            total)
def check_get_scores():
    """Check each entry of `get_scores` against `rbm.free_energy_hid`.

    Entry (i, j) of the score table must equal the hidden free energy of the
    configuration formed by concatenating prefix i with suffix j.
    """
    with misc.gnumpy_conversion_check('allow'):
        model = random_rbm()
        table = tractable.get_scores(model, batch_units=BATCH_UNITS)
        head_configs = tractable.combinations_array(NHID - BATCH_UNITS)
        tail_configs = tractable.combinations_array(BATCH_UNITS)

        for row, head in enumerate(head_configs):
            for col, tail in enumerate(tail_configs):
                hidden = np.concatenate([head, tail])
                expected = model.free_energy_hid(hidden[nax, :])[0]
                assert np.allclose(table[row, col], expected)
def check_statistics(num_samples=1000): v = gnp.garray(np.random.uniform(size=(N, NVIS))) h = gnp.garray(np.random.uniform(size=(N, NHID))) stats = fang.Statistics.from_activations(v, h) with misc.gnumpy_conversion_check('allow'): g = np.zeros((num_samples, 5)) for i in range(num_samples): idx = np.random.randint(N) curr_v = np.random.binomial(1, v[idx, :]) curr_h = np.random.binomial(1, h[idx, :]) g[i, :] = np.array([ curr_v[0], curr_v[1], curr_h[0], curr_h[1], curr_v[0] * curr_h[1] ]) print 'm_unary v[0]', misc.check_expectation(stats.m_unary[0], g[:, 0]) print 'm_unary h[0]', misc.check_expectation(stats.m_unary[NVIS], g[:, 2]) print 'S_unary v[0] v[0]', misc.check_expectation(stats.S_unary[0, 0], g[:, 0]) print 'S_unary v[0] v[1]', misc.check_expectation(stats.S_unary[0, 1], g[:, 0] * g[:, 1]) print 'S_unary v[0] h[0]', misc.check_expectation(stats.S_unary[0, NVIS], g[:, 0] * g[:, 2]) print 'S_unary h[0] h[0]', misc.check_expectation(stats.S_unary[NVIS, NVIS], g[:, 2]) print 'S_unary h[0] h[1]', misc.check_expectation(stats.S_unary[NVIS, NVIS + 1], g[:, 2] * g[:, 3]) print print 'm_pair v[0]', misc.check_expectation(stats.m_pair[0, 1, 0], g[:, 0]) print 'm_pair h[1]', misc.check_expectation(stats.m_pair[0, 1, 1], g[:, 3]) print 'm_pair v[0] h[1]', misc.check_expectation(stats.m_pair[0, 1, 2], g[:, 4]) print 'S_pair v[0] v[0]', misc.check_expectation(stats.S_pair[0, 1, 0, 0], g[:, 0]) print 'S_pair v[0] h[1]', misc.check_expectation(stats.S_pair[0, 1, 0, 1], g[:, 0] * g[:, 3]) print 'S_pair v[0] vh[0, 1]', misc.check_expectation(stats.S_pair[0, 1, 0, 2], g[:, 0] * g[:, 4]) print 'S_pair h[1] h[1]', misc.check_expectation(stats.S_pair[0, 1, 1, 1], g[:, 3]) print 'S_pair h[1] vh[0, 1]', misc.check_expectation(stats.S_pair[0, 1, 1, 2], g[:, 3] * g[:, 4]) print 'S_pair vh[0, 1] vh[0, 1]', misc.check_expectation(stats.S_pair[0, 1, 2, 2], g[:, 4])
def check_partition_function():
    """Check `exact_partition_function` by summing over all joint states.

    Loops over every binary visible/hidden configuration, accumulates
    `rbm.energy` in log space with `np.logaddexp`, and asserts that the
    total agrees with `tractable.exact_partition_function`.

    NOTE(review): `rbm.energy` is accumulated without negation; this
    presumably reflects the sign convention of `rbm.energy` -- confirm.
    """
    with misc.gnumpy_conversion_check('allow'):
        rbm = random_rbm()
        # Log of zero; np.inf is used because the np.infty alias no longer
        # exists in NumPy 2.0.
        total = -np.inf

        for vis_ in itertools.product([0, 1], repeat=NVIS):
            vis = gnp.garray(vis_)
            for hid_ in itertools.product([0, 1], repeat=NHID):
                hid = gnp.garray(hid_)
                energy = rbm.energy(vis[nax, :], hid[nax, :])[0]
                total = np.logaddexp(total, energy)

        expected = tractable.exact_partition_function(
            rbm, batch_units=BATCH_UNITS)
        assert np.allclose(expected, total)
def correlation_fraction(g, s, nvis, nhid):
    """Return the share of `g`'s weight part not explained by its bias parts.

    `g` and `s` are split as [visible (nvis), hidden (nhid), weights
    (nvis * nhid, reshaped to (nvis, nhid))].  The first-order explanation
    of the weight block is `outer(da, E[h]) + outer(E[v], db)`, where the
    expectations are read from `s`.  The residual `dW - first_order_expl`
    is the correlation part.

    Returns:
        dcorr_norm / (dcorr_norm + first_order_norm), where both norms are
        sums of squares and the first-order norm also includes `da` and `db`.
    """
    with misc.gnumpy_conversion_check('allow'):
        expect_vis = s[:nvis]
        expect_hid = s[nvis:nvis + nhid]

        da = g[:nvis]
        db = g[nvis:nvis + nhid]
        dW = g[nvis + nhid:].reshape((nvis, nhid))

        first_order_expl = (gnp.outer(da, expect_hid) +
                            gnp.outer(expect_vis, db))
        first_order_norm = (gnp.sum(da ** 2) + gnp.sum(db ** 2) +
                            gnp.sum(first_order_expl ** 2))

        dcorr = dW - first_order_expl
        dcorr_norm = gnp.sum(dcorr ** 2)

        return dcorr_norm / (dcorr_norm + first_order_norm)
def correlation_fraction(g, s, nvis, nhid):
    """Compute how much of direction `g` is pure correlation.

    The vector `g` is split into a visible part `da`, a hidden part `db`
    and a weight part `dW` of shape (nvis, nhid); `s` supplies the expected
    visible and hidden values in its first nvis + nhid entries.  The part
    of `dW` predicted from the bias parts is
    `outer(da, expect_hid) + outer(expect_vis, db)`; whatever is left over
    is the correlation part.

    Returns:
        The squared norm of the correlation part divided by the sum of that
        and the first-order squared norm (`da`, `db` and the predicted
        weight block).
    """
    with misc.gnumpy_conversion_check('allow'):
        expect_vis = s[:nvis]
        expect_hid = s[nvis:nvis + nhid]

        da = g[:nvis]
        db = g[nvis:nvis + nhid]
        dW = g[nvis + nhid:].reshape((nvis, nhid))

        first_order_expl = gnp.outer(da, expect_hid) + gnp.outer(
            expect_vis, db)
        first_order_norm = gnp.sum(da**2) + gnp.sum(db**2) + gnp.sum(
            first_order_expl**2)

        # The total squared norm of g was computed here previously but never
        # used, so it is no longer evaluated.
        dcorr = dW - first_order_expl
        dcorr_norm = gnp.sum(dcorr**2)

        return dcorr_norm / (dcorr_norm + first_order_norm)
def check_fisher_information_biases_indep():
    """Compare the bias-only Fisher matrix of a base-rate RBM with closed forms.

    With independent units the diagonal entries are Bernoulli variances
    E * (1 - E) and all off-diagonal entries checked here are zero.
    """
    with misc.gnumpy_conversion_check('allow'):
        model = random_base_rate_rbm()
        E_v = gnp.logistic(model.vbias)
        E_h = gnp.logistic(model.hbias)

        G = tractable.exact_fisher_information_biases(
            model, batch_units=BATCH_UNITS)
        assert_close(G, G.T, 'G not symmetric')

        vis_block = slice(None, NVIS)
        hid_block = slice(NVIS, None)
        G_vis_vis = G[vis_block, vis_block]
        G_vis_hid = G[vis_block, hid_block]
        G_hid_hid = G[hid_block, hid_block]

        assert_close(G_vis_vis[0, 0], E_v[0] * (1. - E_v[0]))
        assert_close(G_vis_vis[0, 1], 0.)
        assert_close(G_vis_hid[0, 0], 0.)
        assert_close(G_hid_hid[0, 0], E_h[0] * (1. - E_h[0]))
        assert_close(G_hid_hid[0, 1], 0.)
def plot_eigenspectrum(G, s, nvis, nhid):
    """Plot the eigenvalues of `G` and the correlation fraction of eigenvectors.

    The top panel shows the eigenvalues in descending order on log-log
    axes.  The bottom panel shows `fisher.correlation_fraction` for a
    log-spaced subset of eigenvectors on a log x-axis.
    """
    with misc.gnumpy_conversion_check('allow'):
        dim = G.shape[0]

        eigvals, eigvecs = scipy.linalg.eigh(G)
        # Flip the eigenpairs so the leading index holds the last value
        # returned by eigh.
        eigvals = eigvals[::-1]
        eigvecs = eigvecs[:, ::-1]

        # Up to 500 distinct, log-spaced, zero-based eigenvector indices.
        log_grid = np.logspace(0., np.log10(dim - 1), 500)
        pts = np.unique(np.floor(log_grid).astype(int)) - 1
        cf = [fisher.correlation_fraction(eigvecs[:, k], s, nvis, nhid)
              for k in pts]

        pylab.figure()

        pylab.subplot(2, 1, 1)
        pylab.loglog(range(1, dim + 1), eigvals, 'b-', lw=2.)
        pylab.xticks([])
        pylab.yticks(fontsize='large')

        pylab.subplot(2, 1, 2)
        pylab.semilogx(pts + 1, cf, 'r-', lw=2.)
        pylab.xticks(fontsize='x-large')
        pylab.yticks(fontsize='large')
def check_fisher_information_indep():
    """Check the full exact Fisher matrix of a base-rate RBM analytically.

    For a base-rate RBM the expected entries are the covariances of the
    statistics v_i, h_j and v_i * h_j under independent Bernoulli units with
    means `logistic(vbias)` and `logistic(hbias)`.
    """
    with misc.gnumpy_conversion_check('allow'):
        rbm = random_base_rate_rbm()
        E_v = gnp.logistic(rbm.vbias)
        E_h = gnp.logistic(rbm.hbias)

        G = tractable.exact_fisher_information(rbm, batch_units=BATCH_UNITS)
        assert_close(G, G.T, 'G not symmetric')

        # Parameter layout: visible biases, hidden biases, then weights.
        vis_slc = slice(None, NVIS)
        hid_slc = slice(NVIS, NVIS + NHID)
        vishid_slc = slice(NVIS + NHID, None)

        G_vis_vishid = G[vis_slc, vishid_slc].reshape((NVIS, NVIS, NHID))
        G_hid_vishid = G[hid_slc, vishid_slc].reshape((NHID, NVIS, NHID))
        G_vishid_vishid = G[vishid_slc, vishid_slc].reshape(
            (NVIS, NHID, NVIS, NHID))

        assert_close(G_vis_vishid[0, 0, 1],
                     E_v[0] * (1. - E_v[0]) * E_h[1])
        assert_close(G_vis_vishid[0, 1, 2], 0.)

        assert_close(G_hid_vishid[0, 1, 0],
                     E_h[0] * (1. - E_h[0]) * E_v[1])
        assert_close(G_hid_vishid[0, 1, 2], 0.)

        assert_close(G_vishid_vishid[0, 1, 0, 1],
                     E_v[0] * E_h[1] * (1. - E_v[0] * E_h[1]))
        assert_close(G_vishid_vishid[0, 1, 0, 2],
                     E_v[0] * (1. - E_v[0]) * E_h[1] * E_h[2])
        assert_close(G_vishid_vishid[0, 2, 1, 2],
                     E_h[2] * (1. - E_h[2]) * E_v[0] * E_v[1])
        assert_close(G_vishid_vishid[0, 1, 2, 3], 0.)
def check_moments():
    """Check `exact_moments` against enumeration over hidden configurations.

    For each hidden configuration, its probability is taken as
    exp(free_energy_hid - log partition function) and used to weight the
    conditional visible expectations, the hidden state and their outer
    product.  The accumulated moments must match `tractable.exact_moments`.
    """
    with misc.gnumpy_conversion_check('allow'):
        model = random_rbm()
        log_z = tractable.exact_partition_function(
            model, batch_units=BATCH_UNITS)

        expect_vis = gnp.zeros(model.nvis)
        expect_hid = gnp.zeros(model.nhid)
        expect_prod = gnp.zeros((model.nvis, model.nhid))

        for config in itertools.product([0, 1], repeat=NHID):
            hid = gnp.garray(config)
            cond_vis = model.vis_expectations(hid)
            weight = np.exp(model.free_energy_hid(hid[nax, :])[0] - log_z)
            expect_vis += weight * cond_vis
            expect_hid += weight * hid
            expect_prod += weight * gnp.outer(cond_vis, hid)

        moments = tractable.exact_moments(model, batch_units=BATCH_UNITS)
        assert np.allclose(expect_vis, moments.expect_vis)
        assert np.allclose(expect_hid, moments.expect_hid)
        assert np.allclose(expect_prod, moments.expect_prod)
def exact_fisher_information(rbm, batch_units=10, show_progress=False, vis_shape=None, downsample=1,
                             return_mean=False):
    """Compute the exact Fisher information matrix of an RBM.

    The result is the covariance matrix of the statistics
    [v (nvis), h (nhid), v h^T (nvis * nhid, row-major)], accumulated over
    the batches of hidden configurations yielded by `iter_configurations`.
    Visible units enter through their conditional means
    `logistic(rbm.vis_inputs(hid))`; wherever the same visible unit appears
    twice in a second moment, the conditional variance is added on the
    corresponding diagonal.

    Args:
        rbm: model exposing `nvis`, `nhid` and `vis_inputs`.
        batch_units: number of hidden units enumerated per batch; each batch
            is assumed to contain 2 ** batch_units configurations.
        show_progress: forwarded to `iter_configurations`.
        vis_shape: 2-D grid shape of the visible units, used only when
            `downsample != 1`.  Defaults to (28, 28), the value that was
            previously hard-coded.  Its product must equal `rbm.nvis`.
        downsample: stride at which visible units are kept along both grid
            axes; 1 keeps every visible unit.
        return_mean: if true, also return the mean statistics vector.

    Returns:
        `G`, a (num_params, num_params) numpy array, or `(G, s)` when
        `return_mean` is true, where `s` is the concatenation of E[v], E[h]
        and the flattened E[v h^T].

    Raises:
        ValueError: from numpy's reshape if `vis_shape` does not match
            `rbm.nvis`.
    """
    batch_size = 2 ** batch_units

    if downsample == 1:
        vis_idxs = np.arange(rbm.nvis)
    else:
        # Honour the caller's grid shape; fall back to the historical 28x28.
        grid_shape = (28, 28) if vis_shape is None else tuple(vis_shape)
        temp = np.arange(rbm.nvis).reshape(grid_shape)
        mask = np.zeros(grid_shape, dtype=bool)
        mask[::downsample, ::downsample] = 1
        vis_idxs = temp[mask]
    nvis = vis_idxs.size
    nhid = rbm.nhid

    num_params = nvis + nhid + nvis * nhid

    # First moments.
    E_vis = np.zeros(nvis)
    E_hid = np.zeros(nhid)
    E_vishid = np.zeros((nvis, nhid))

    # Second moments, one accumulator per block of the final matrix.
    E_vis_vis = np.zeros((nvis, nvis))
    E_vis_hid = np.zeros((nvis, nhid))
    E_vis_vishid = np.zeros((nvis, nvis, nhid))
    E_hid_hid = np.zeros((nhid, nhid))
    E_hid_vishid = np.zeros((nhid, nvis, nhid))
    E_vishid_vishid = np.zeros((nvis, nhid, nvis, nhid))

    diag_idxs = np.arange(nvis)

    for hid, p in iter_configurations(rbm, batch_units=batch_units, show_progress=show_progress):
        with misc.gnumpy_conversion_check('allow'):
            cond_vis = gnp.logistic(rbm.vis_inputs(hid))
            cond_vis = gnp.garray(cond_vis.as_numpy_array()[:, vis_idxs])
            vishid = (cond_vis[:, :, nax] * hid[:, nax, :]).reshape((batch_size, nvis * nhid))
            var_vis = cond_vis * (1. - cond_vis)

            # Weighted transposes shared by several accumulators below.
            vis_weighted = cond_vis.T * p
            hid_weighted = hid.T * p

            E_vis += gnp.dot(p, cond_vis)
            E_hid += gnp.dot(p, hid)

            # E[v h^T] feeds both the mean vector and the vis/hid block.
            vis_hid_moment = gnp.dot(vis_weighted, hid)
            E_vishid += vis_hid_moment
            E_vis_hid += vis_hid_moment

            E_vis_vis += gnp.dot(vis_weighted, cond_vis)
            diag_term = gnp.dot(p, var_vis)
            E_vis_vis += gnp.garray(np.diag(diag_term.as_numpy_array()))

            E_hid_hid += gnp.dot(hid_weighted, hid)

            E_vis_vishid += gnp.dot(vis_weighted, vishid).reshape((nvis, nvis, nhid))
            diag_term = gnp.dot(var_vis.T * p, hid)
            E_vis_vishid[diag_idxs, diag_idxs, :] += diag_term

            E_hid_vishid += gnp.dot(hid_weighted, vishid).reshape((nhid, nvis, nhid))

            E_vishid_vishid += gnp.dot(vishid.T * p, vishid).reshape((nvis, nhid, nvis, nhid))
            diag_term = (var_vis[:, :, nax, nax] * hid[:, nax, :, nax] * hid[:, nax, nax, :]
                         * p[:, nax, nax, nax]).sum(0)
            # The two index arrays are separated by a slice, so the selected
            # view has shape (nvis, nhid, nhid), matching diag_term.
            E_vishid_vishid[diag_idxs, :, diag_idxs, :] += diag_term

    G = np.zeros((num_params, num_params))
    vis_slc = slice(0, nvis)
    hid_slc = slice(nvis, nvis + nhid)
    vishid_slc = slice(nvis + nhid, None)

    vis_vishid_flat = E_vis_vishid.reshape((nvis, nvis * nhid))
    hid_vishid_flat = E_hid_vishid.reshape((nhid, nvis * nhid))

    G[vis_slc, vis_slc] = E_vis_vis
    G[vis_slc, hid_slc] = E_vis_hid
    G[vis_slc, vishid_slc] = vis_vishid_flat
    G[hid_slc, vis_slc] = E_vis_hid.T
    G[hid_slc, hid_slc] = E_hid_hid
    G[hid_slc, vishid_slc] = hid_vishid_flat
    G[vishid_slc, vis_slc] = vis_vishid_flat.T
    G[vishid_slc, hid_slc] = hid_vishid_flat.T
    G[vishid_slc, vishid_slc] = E_vishid_vishid.reshape((nvis * nhid, nvis * nhid))

    # Subtract the outer product of the means to get the covariance.
    s = np.concatenate([E_vis, E_hid, E_vishid.ravel()])
    G -= np.outer(s, s)

    if return_mean:
        return G, s
    else:
        return G