def check_affine(f, *nu_inputs):
    """Check ``f`` evaluated through cgt against ``f`` evaluated directly.

    ``f`` is called twice: once on symbolic stand-ins for ``nu_inputs``
    (built with ``tensor_like`` and named ``x0``, ``x1``, ...) to obtain a
    cgt expression, and once on ``nu_inputs`` themselves to obtain the
    reference output. The compiled expression and its compiled gradients
    are compared with ``np.testing.assert_allclose`` against the reference
    output and against the gradients returned by ``gradients_affine``.

    The node counts of the simplified result graph and of the simplified
    gradient graph are stored in ``PROB2RESULT[f.__name__]`` under the keys
    ``"fn"`` and ``"grad"``. When the module-level ``DISPLAY`` flag is
    truthy, the graphs and the before/after node counts are printed.

    Args:
        f: Callable exposing ``__name__`` that accepts both the numeric
            inputs and their symbolic counterparts.
        *nu_inputs: Numeric inputs; each must expose ``dtype`` and ``ndim``.

    Raises:
        AssertionError: If outputs or gradients differ beyond the relative
            tolerance chosen for the current cgt precision.
        KeyError: If ``cgt.get_precision()`` is neither ``"single"`` nor
            ``"double"``.
    """
    types = ",".join(["{%s,%s}" % (x.dtype, x.ndim) for x in nu_inputs])
    cgt.utils.colorprint(
        cgt.utils.Color.YELLOW, "Testing %s(%s)\n" % (f.__name__, types))

    # A real list (not map()) because sy_inputs is iterated several times
    # below; a lazy Python 3 map object would be empty after the first pass.
    sy_inputs = [tensor_like(x) for x in nu_inputs]
    for (i, sy) in enumerate(sy_inputs):
        sy.name = "x%i" % i

    sy_result = f(*sy_inputs)

    def maybeprint(msg):
        # Single-argument print(...) behaves the same on Python 2 and 3.
        if DISPLAY:
            print(msg)

    def maybeprint_tree(title, nodes):
        # Print a heading followed by the expression tree, only in DISPLAY mode.
        if DISPLAY:
            print(title)
            cgt.print_tree(nodes)

    maybeprint_tree("Function:", [sy_result])

    f_cgt = cgt.function(sy_inputs, sy_result)
    sy_grads = cgt.grad(sy_result, sy_inputs)
    gradf_cgt = cgt.function(sy_inputs, sy_grads)

    sy_result_simple = core.simplify([sy_result])
    sy_grads_simple = core.simplify(sy_grads)

    maybeprint_tree("Gradient:", sy_grads)
    maybeprint_tree("Gradient after simplification:", sy_grads_simple)

    out_true = f(*nu_inputs)
    out_cgt = f_cgt(*nu_inputs)

    # gradients_affine is defined elsewhere; h is presumably its
    # finite-difference step -- confirm against that helper. Functions whose
    # name contains "max" get the smaller value.
    h = 1e-4 if "max" in f.__name__ else 1e-1
    grads_true = gradients_affine(f_cgt, nu_inputs, h=h)
    grads_cgt = gradf_cgt(*nu_inputs)

    # Looser relative tolerance when cgt runs in single precision.
    rtol = {"single": 1e-3, "double": 1e-5}[cgt.get_precision()]
    np.testing.assert_allclose(out_cgt, out_true, rtol=rtol)
    for (g_cgt, g_true) in zip(grads_cgt, grads_true):
        np.testing.assert_allclose(g_cgt, g_true, rtol=rtol)

    result_count = cgt.count_nodes(sy_result_simple)
    grad_count = cgt.count_nodes(sy_grads_simple)
    maybeprint("Result before: %i. after: %i"
               % (cgt.count_nodes([sy_result]), result_count))
    maybeprint("Grad before: %i. after: %i"
               % (cgt.count_nodes(sy_grads), grad_count))

    PROB2RESULT[f.__name__] = {"fn": result_count, "grad": grad_count}
# Benchmark fragment: build an unrolled GRU graph, simplify it and its
# gradients, print the node counts and record the elapsed time.
# NOTE(review): `horizon`, `dim_x`, `mem_size`, `tstart`, `elapsed` and
# `horizons` are defined outside this fragment. Presumably everything up to
# `elapsed.append(...)` runs once per horizon inside an enclosing loop whose
# header is not visible here -- confirm against the surrounding code.
X_tnk = cgt.tensor3("X")
cell = gru.GRUCell([dim_x], mem_size)

# NOTE(review): both dimensions come from axes 0 and 1 of X_tnk; the `_nk`
# suffix hints that (X_tnk.shape[1], mem_size) may have been intended --
# confirm before relying on this graph numerically.
Minit_nk = cgt.zeros((X_tnk.shape[0], X_tnk.shape[1]), cgt.floatX)
M = Minit_nk
# Chain `horizon` applications of the cell, feeding X_tnk[t] at step t.
for t in range(horizon):
    M = cell(M, X_tnk[t])

# cgt.print_tree(M)
# Single-argument print(...) with %-formatting emits the same text on
# Python 2 and Python 3.
print("simplifying...")
M_simp = cgt.simplify([M])
print("done")
# cgt.print_tree(M_simp)
# NOTE(review): count_nodes receives a bare node here but a list below --
# confirm it accepts both forms.
print("fn before: %s" % (cgt.count_nodes(M),))
print("fn after: %s" % (cgt.count_nodes(M_simp),))

gs = cgt.grad(cgt.sum(M), cell.params())
print("grad before %s" % (cgt.count_nodes(gs),))
g_simp = cgt.simplify(gs)
print("grad after %s" % (cgt.count_nodes(g_simp),))

# M = cgt.simplify(M)
elapsed.append(time() - tstart)

# Plot elapsed time against horizon.
import matplotlib.pyplot as plt
plt.plot(horizons, elapsed, 'x-')
plt.show()
# Benchmark fragment: build an unrolled GRU graph, simplify it and its
# gradients, print the node counts and record the elapsed time.
# NOTE(review): `horizon`, `dim_x`, `mem_size`, `tstart`, `elapsed` and
# `horizons` are defined outside this fragment. Presumably everything up to
# `elapsed.append(...)` runs once per horizon inside an enclosing loop whose
# header is not visible here -- confirm against the surrounding code.
X_tnk = cgt.tensor3("X")
cell = gru.GRUCell([dim_x], mem_size)

# NOTE(review): both dimensions come from axes 0 and 1 of X_tnk; the `_nk`
# suffix hints that (X_tnk.shape[1], mem_size) may have been intended --
# confirm before relying on this graph numerically.
Minit_nk = cgt.zeros((X_tnk.shape[0], X_tnk.shape[1]), cgt.floatX)
M = Minit_nk
# Chain `horizon` applications of the cell, feeding X_tnk[t] at step t.
for t in range(horizon):
    M = cell(M, X_tnk[t])

# cgt.print_tree(M)
# Single-argument print(...) with %-formatting emits the same text on
# Python 2 and Python 3.
print("simplifying...")
M_simp = cgt.simplify([M])
print("done")
# cgt.print_tree(M_simp)
# NOTE(review): count_nodes receives a bare node here but a list below --
# confirm it accepts both forms.
print("fn before: %s" % (cgt.count_nodes(M),))
print("fn after: %s" % (cgt.count_nodes(M_simp),))

gs = cgt.grad(cgt.sum(M), cell.params())
print("grad before %s" % (cgt.count_nodes(gs),))
g_simp = cgt.simplify(gs)
print("grad after %s" % (cgt.count_nodes(g_simp),))

# M = cgt.simplify(M)
elapsed.append(time() - tstart)

# Plot elapsed time against horizon.
import matplotlib.pyplot as plt
plt.plot(horizons, elapsed, 'x-')
plt.show()