def Lx_func(*args):
    """Return ``fisher.compute_Lx(neg_energies, self.params, args)``.

    ``args`` is forwarded untouched as the third argument of
    ``fisher.compute_Lx``. When ``self.flags['minresQLP']`` is truthy the
    result is paired with an empty dict; otherwise it is returned alone.
    """
    # NOTE(review): the empty dict is presumably the "updates" slot expected
    # by the minresQLP solver -- confirm against the caller.
    product = fisher.compute_Lx(neg_energies, self.params, args)
    return (product, {}) if self.flags['minresQLP'] else product
def test_compute_Lx():
    """Check natural.compute_Lx and fisher.compute_Lx against a dense product.

    The reference is ``numpy.dot(L, vals['x'])``, cut into the W, V, a, b
    and c blocks.  Each implementation is compiled with ``theano.function``
    and its five outputs must match those blocks to 3 decimals.

    Both timers wrap only the call of the compiled function, so the two
    printed "elapsed" figures measure the same thing (graph compilation is
    excluded in both cases).

    Relies on ``L``, ``vals``, ``symb``, ``N0``, ``N1`` and ``N2`` being
    defined outside this function.
    """
    ## baseline result ##
    Lx = numpy.dot(L, vals['x'])
    # Running offsets of each block inside the flat product vector.
    w_end = N0 * N1
    v_end = w_end + N1 * N2
    a_end = v_end + N0
    b_end = a_end + N1
    Lx_w = Lx[:w_end].reshape(N0, N1)
    Lx_v = Lx[w_end:v_end].reshape(N1, N2)
    Lx_a = Lx[v_end:a_end]
    Lx_b = Lx[a_end:b_end]
    Lx_c = Lx[-N2:]
    expected = [Lx_w, Lx_v, Lx_a, Lx_b, Lx_c]

    # natural.compute_Lx implementation
    # One key list drives both the symbolic inputs and the numeric values so
    # that the two can never get out of order.
    natural_keys = ['v', 'g', 'h', 'x_W', 'x_V', 'x_a', 'x_b', 'x_c']
    symb_inputs = [symb[key] for key in natural_keys]
    natural_Lx = natural.compute_Lx(*symb_inputs)
    f = theano.function(symb_inputs, natural_Lx)
    t1 = time.time()
    rvals = f(*[vals[key] for key in natural_keys])
    # A single pre-formatted string prints the same text as the Python 2
    # statement ``print label, value`` and is also valid on Python 3.
    print('%s %s' % ('natural.compute_Lx elapsed: ', time.time() - t1))
    for idx, reference in enumerate(expected):
        numpy.testing.assert_almost_equal(reference, rvals[idx], decimal=3)

    # fisher.compute_Lx implementation
    energies = (
        - T.sum(T.dot(symb['v'], symb['W']) * symb['g'], axis=1)
        - T.sum(T.dot(symb['g'], symb['V']) * symb['h'], axis=1)
        - T.dot(symb['v'], symb['a'])
        - T.dot(symb['g'], symb['b'])
        - T.dot(symb['h'], symb['c'])
    )
    fisher_keys = ['v', 'g', 'h',
                   'W', 'V', 'a', 'b', 'c',
                   'x_W', 'x_V', 'x_a', 'x_b', 'x_c']
    symb_params = [symb[key] for key in fisher_keys[3:8]]
    symb_x = [symb[key] for key in fisher_keys[8:]]
    LLx = fisher.compute_Lx(energies, symb_params, symb_x)
    f_inputs = [symb[key] for key in fisher_keys[:3]] + symb_params + symb_x
    f = theano.function(f_inputs, LLx)
    t1 = time.time()
    rvals = f(*[vals[key] for key in fisher_keys])

    ### compare both implementation ###
    print('%s %s' % ('fisher.compute_Lx elapsed: ', time.time() - t1))
    for idx, reference in enumerate(expected):
        numpy.testing.assert_almost_equal(reference, rvals[idx], decimal=3)
def test_compute_Lx():
    """Compare two symbolic L*x implementations with a dense numpy product.

    ``numpy.dot(L, vals['x'])`` is split into its W, V, a, b and c blocks
    and used as the reference.  The outputs of ``natural.compute_Lx`` and
    ``fisher.compute_Lx`` (each compiled through ``theano.function``) must
    agree with those blocks to 3 decimals.

    The printed "elapsed" values both cover only the evaluation of the
    compiled function; compilation happens before each timer starts so the
    two numbers can be compared directly.

    Uses ``L``, ``vals``, ``symb``, ``N0``, ``N1`` and ``N2`` from the
    enclosing scope.
    """

    def check_blocks(outputs):
        # Compare the five output blocks with the dense reference blocks.
        numpy.testing.assert_almost_equal(Lx_w, outputs[0], decimal=3)
        numpy.testing.assert_almost_equal(Lx_v, outputs[1], decimal=3)
        numpy.testing.assert_almost_equal(Lx_a, outputs[2], decimal=3)
        numpy.testing.assert_almost_equal(Lx_b, outputs[3], decimal=3)
        numpy.testing.assert_almost_equal(Lx_c, outputs[4], decimal=3)

    ## baseline result ##
    dense_Lx = numpy.dot(L, vals['x'])
    size_w = N0 * N1
    size_v = N1 * N2
    Lx_w = dense_Lx[:size_w].reshape(N0, N1)
    Lx_v = dense_Lx[size_w:size_w + size_v].reshape(N1, N2)
    Lx_a = dense_Lx[size_w + size_v:size_w + size_v + N0]
    Lx_b = dense_Lx[size_w + size_v + N0:size_w + size_v + N0 + N1]
    Lx_c = dense_Lx[-N2:]

    # natural.compute_Lx implementation
    symb_inputs = [
        symb['v'], symb['g'], symb['h'],
        symb['x_W'], symb['x_V'], symb['x_a'], symb['x_b'], symb['x_c'],
    ]
    f = theano.function(symb_inputs, natural.compute_Lx(*symb_inputs))
    # Start the clock after compilation so only the evaluation is timed.
    t1 = time.time()
    rvals = f(vals['v'], vals['g'], vals['h'],
              vals['x_W'], vals['x_V'], vals['x_a'], vals['x_b'], vals['x_c'])
    elapsed = time.time() - t1
    # Pre-formatting keeps the text identical to the Python 2 statement
    # ``print label, value`` while remaining valid Python 3.
    print('%s %s' % ('natural.compute_Lx elapsed: ', elapsed))
    check_blocks(rvals)

    # fisher.compute_Lx implementation
    energies = (
        - T.sum(T.dot(symb['v'], symb['W']) * symb['g'], axis=1)
        - T.sum(T.dot(symb['g'], symb['V']) * symb['h'], axis=1)
        - T.dot(symb['v'], symb['a'])
        - T.dot(symb['g'], symb['b'])
        - T.dot(symb['h'], symb['c'])
    )
    symb_params = [symb['W'], symb['V'], symb['a'], symb['b'], symb['c']]
    symb_x = [symb['x_W'], symb['x_V'], symb['x_a'], symb['x_b'], symb['x_c']]
    LLx = fisher.compute_Lx(energies, symb_params, symb_x)
    f_inputs = [symb['v'], symb['g'], symb['h']] + symb_params + symb_x
    f = theano.function(f_inputs, LLx)
    t1 = time.time()
    rvals = f(vals['v'], vals['g'], vals['h'],
              vals['W'], vals['V'], vals['a'], vals['b'], vals['c'],
              vals['x_W'], vals['x_V'], vals['x_a'], vals['x_b'], vals['x_c'])
    elapsed = time.time() - t1

    ### compare both implementation ###
    print('%s %s' % ('fisher.compute_Lx elapsed: ', elapsed))
    check_blocks(rvals)
def test_runtime():
    """Time fisher.compute_Lx against natural.generic_compute_Lx.

    Both graphs are compiled with ``theano.function`` first; each timer then
    wraps a single call of the compiled function and the elapsed seconds
    are printed.  Finally the outputs of the two functions are compared
    pairwise to 2 decimals.

    Uses ``symb`` and ``vals`` from the enclosing scope.
    """
    ### theano implementation ###
    energies = (
        - T.sum(T.dot(symb['v'], symb['W']) * symb['g'], axis=1)
        - T.sum(T.dot(symb['g'], symb['V']) * symb['h'], axis=1)
        - T.dot(symb['v'], symb['a'])
        - T.dot(symb['g'], symb['b'])
        - T.dot(symb['h'], symb['c'])
    )

    sample_keys = ['v', 'g', 'h']
    param_keys = ['W', 'V', 'a', 'b', 'c']
    weight_keys = ['x_W', 'x_V']
    bias_keys = ['x_a', 'x_b', 'x_c']

    # Fisher Implementation
    symb_params = [symb[key] for key in param_keys]
    symb_x = [symb[key] for key in weight_keys + bias_keys]
    fisher_inputs = [symb[key] for key in sample_keys] + symb_params + symb_x
    fisher_graph = fisher.compute_Lx(energies, symb_params, symb_x)
    fisher_func = theano.function(fisher_inputs, fisher_graph)

    samples = [symb[key] for key in sample_keys]
    symb_weights = [symb[key] for key in weight_keys]
    symb_biases = [symb[key] for key in bias_keys]
    natural_inputs = ([symb[key] for key in sample_keys]
                      + symb_weights + symb_biases)
    natural_graph = natural.generic_compute_Lx(
        samples, symb_weights, symb_biases)
    natural_func = theano.function(natural_inputs, natural_graph)

    fisher_keys = sample_keys + param_keys + weight_keys + bias_keys
    start = time.time()
    fisher_rval = fisher_func(*[vals[key] for key in fisher_keys])
    # Same text as the Python 2 statement ``print label, value``.
    print('%s %s' % ('Fisher runtime (s): ', time.time() - start))

    natural_keys = sample_keys + weight_keys + bias_keys
    start = time.time()
    nat_rval = natural_func(*[vals[key] for key in natural_keys])
    print('%s %s' % ('Natural runtime (s): ', time.time() - start))

    ### make sure the two return the same thing ###
    for fisher_out, natural_out in zip(fisher_rval, nat_rval):
        numpy.testing.assert_almost_equal(fisher_out, natural_out, decimal=2)
def Lx_func(*args):
    """Evaluate ``fisher.compute_Lx`` on ``neg_energies`` and ``self.params``.

    The positional ``args`` tuple is handed over as the third argument.
    Returns the bare result, or ``(result, {})`` when
    ``self.flags['minresQLP']`` is set.
    """
    result = fisher.compute_Lx(neg_energies, self.params, args)
    if not self.flags['minresQLP']:
        return result
    # NOTE(review): the trailing empty dict looks like a placeholder the
    # minresQLP code path expects next to the result -- confirm with caller.
    return result, {}
def Lx_func(*args):
    """Return ``fisher.compute_Lx(energies, <symbolic params>, args)``.

    The symbolic parameters are looked up in ``symb`` using the entries of
    ``params`` as keys, in order; ``args`` is forwarded unchanged as the
    third argument.
    """
    looked_up = [symb[key] for key in params]
    return fisher.compute_Lx(energies, looked_up, args)