def transform(self, node):
    """Rewrite an `Apply` node using the first result of the kanren relation.

    Returns `False` when `node` is not a `tt.Apply`, when
    `self.node_filter` rejects it, or when `self.results_filter` yields a
    falsy result.  Otherwise returns either a dict of replacements or the
    value produced by `self.adjust_outputs`.
    """
    if not isinstance(node, tt.Apply) or self.node_filter(node):
        return False

    graph_in = node.default_output()

    with variables(*self.relation_lvars):
        out_lv = var()
        # Only a single result is requested from the relation.
        found = run(1, out_lv, (self.kanren_relation, graph_in, out_lv))

    selected = self.results_filter(found)

    if not selected:
        return False

    # Turn the meta objects and tuple-form expressions into Theano
    # objects.
    is_dict_form = isinstance(selected, tuple) and selected[0] == dict
    if not is_dict_form:
        return self.adjust_outputs(node, reify_meta(selected))

    # We got a dictionary of replacements.
    replacements = {}
    for meta_key, meta_val in evalt(selected).items():
        replacements[meta_key.obj] = reify_meta(meta_val)

    assert all(k in node.fgraph.variables for k in replacements)

    return replacements
def test_objects():
    """Check `eq_assoccomm` on `Add` objects once `Add` is declared
    commutative and associative."""
    for relation in (commutative, associative):
        fact(relation, Add)

    # The same non-empty-result check is performed twice.
    for _ in range(2):
        assert tuple(goaleval(eq_assoccomm(add(1, 2, 3), add(3, 1, 2)))({}))

    x = var('x')

    trailing_goal = goaleval(eq_assoccomm(add(1, 2, 3), add(1, 2, x)))
    first_state = tuple(trailing_goal({}))[0]
    assert reify(x, first_state) == 3

    leading_goal = goaleval(eq_assoccomm(add(1, 2, 3), add(x, 2, 1)))
    assert reify(x, next(leading_goal({}))) == 3

    v = add(1, 2, 3)
    with variables(v):
        # Inside this context `v` itself is treated as a logic variable.
        x = add(5, 6)
        whole_goal = goaleval(eq_assoccomm(v, x))
        assert reify(v, next(whole_goal({}))) == x
def test_metatize():
    """Exercise `metatize` on Theano variables, containers, logic
    variables, arrays, unsupported objects and `Op` classes."""
    vec_tt = tt.vector('vec')
    assert metatize(vec_tt).base == type(vec_tt)

    # A list stays a list and an iterator stays an iterator; in both cases
    # every element becomes a `MetaSymbol`.
    for make_input, container_type in (
        (lambda: [1, 2, 3], list),
        (lambda: iter([1, 2, 3]), Iterator),
    ):
        converted = metatize(make_input())
        assert isinstance(converted, container_type)
        assert all(isinstance(m, MetaSymbol) for m in converted)

    lvar_res = metatize(var())
    assert isvar(lvar_res)

    with variables(vec_tt):
        # While declared as a logic variable, the Theano vector passes
        # through unchanged.
        lvar_res = metatize(vec_tt)
        assert lvar_res == vec_tt
        assert isvar(lvar_res)

    array_res = metatize(np.r_[1, 2, 3])
    assert isinstance(array_res, MetaSymbol)

    class TestClass(object):
        pass

    with pytest.raises(Exception):
        metatize(TestClass())

    class TestOp(tt.gof.Op):
        pass

    meta_op_cls = metatize(TestOp)
    assert issubclass(meta_op_cls, MetaOp)

    op_instance = TestOp()
    meta_op = meta_op_cls(obj=op_instance)
    assert isinstance(meta_op, MetaSymbol)
    assert meta_op.obj == op_instance
    assert meta_op.base == TestOp
def transform(self, node):
    """Rewrite an `Apply` node using the results of the kanren relation.

    Parameters
    ----------
    node
        The graph node to rewrite.  Anything that is not a `tt.Apply`, or
        that `self.node_filter` rejects, is left alone.

    Returns
    -------
    `False` when no rewrite applies (wrong node type, filtered node, or a
    falsy result from `self.results_filter`).  Otherwise either a dict
    mapping existing graph variables to their replacements, or the value
    produced by `self.adjust_outputs` for a single replacement variable.

    Raises
    ------
    ValueError
        If the chosen result is neither a list/dict of replacements nor a
        `tt.Variable`.
    """
    if not isinstance(node, tt.Apply):
        return False

    if self.node_filter(node):
        return False

    # Fall back to the full output list when `default_output` is not
    # available on the node.
    try:
        input_expr = node.default_output()
    except AttributeError:
        input_expr = node.outputs

    with variables(*self.relation_lvars):
        q = var()
        kanren_results = run(None, q, self.kanren_relation(input_expr, q))

    chosen_res = self.results_filter(kanren_results)

    if not chosen_res:
        return False

    if isinstance(chosen_res, ExpressionTuple):
        chosen_res = eval_and_reify_meta(chosen_res)

    # Normalize a dict of replacements to a list of key/value pairs so
    # that both forms share the branch below.
    if isinstance(chosen_res, dict):
        chosen_res = list(chosen_res.items())

    if isinstance(chosen_res, list):
        # We got a dictionary of replacements
        new_node = {
            eval_and_reify_meta(k): eval_and_reify_meta(v)
            for k, v in chosen_res
        }

        assert all(k in node.fgraph.variables for k in new_node.keys())
    elif isinstance(chosen_res, tt.Variable):
        # Attempt to automatically format the output for multi-output
        # `Apply` nodes.
        new_node = self.adjust_outputs(node, eval_and_reify_meta(chosen_res))
    else:
        # The original message was a plain string literal, so the
        # `{chosen_res}` placeholder was never filled in.
        raise ValueError(
            "Unsupported FunctionGraph replacement variable type: "
            "{chosen_res}".format(chosen_res=chosen_res)
        )

    return new_node
def test_objects():
    """Verify associative/commutative matching of `Add` objects via
    `eq_assoccomm`."""
    fact(commutative, Add)
    fact(associative, Add)

    def _states(lhs, rhs):
        # Evaluate the goal against an empty substitution.
        return goaleval(eq_assoccomm(lhs, rhs))({})

    # A reordered sum must produce at least one result; checked twice.
    assert tuple(_states(add(1, 2, 3), add(3, 1, 2)))
    assert tuple(_states(add(1, 2, 3), add(3, 1, 2)))

    x = var('x')

    all_states = tuple(_states(add(1, 2, 3), add(1, 2, x)))
    assert reify(x, all_states[0]) == 3

    first_state = next(_states(add(1, 2, 3), add(x, 2, 1)))
    assert reify(x, first_state) == 3

    v = add(1, 2, 3)
    with variables(v):
        # `v` is treated as a logic variable inside this context.
        x = add(5, 6)
        assert reify(v, next(_states(v, x))) == x
def test_unification():
    """Check `unify`/`reify` between meta expressions, logic variables and
    Theano graphs."""
    x, y, a, b = tt.dvectors('xyab')
    x_s = tt.scalar('x_s')
    y_s = tt.scalar('y_s')
    c_tt = tt.constant(1, 'c')
    d_tt = tt.constant(2, 'd')
    # x_l = tt.vector('x_l')
    # y_l = tt.vector('y_l')
    # z_l = tt.vector('z_l')

    x_l = var('x_l')
    y_l = var('y_l')
    z_l = var('z_l')

    assert a == reify(x_l, {x_l: a}).reify()

    affine_mt = mt.add(1, mt.mul(2, x_l))
    affine_reified = reify(affine_mt, {x_l: a})
    assert graph_equal(affine_reified.reify(), 1 + 2 * a)

    z = tt.add(b, a)
    assert {x_l: z} == unify(x_l, z)
    assert b == unify(mt.add(x_l, a), mt.add(b, a))[x_l].reify()

    inv_subs = unify(mt.inv(mt.add(x_l, a)), mt.inv(mt.add(b, y_l)))
    assert inv_subs[x_l].reify() == b
    assert inv_subs[y_l].reify() == a

    # TODO: This produces a `DimShuffle` so that the scalar constant `1`
    # will match the dimensions of the vector `b`.  That `DimShuffle` isn't
    # handled by the logic variable form.
    # assert unify(mt.add(x_l, 1), mt.add(b_l, 1))[x] == b

    with variables(x):
        assert unify(x + 1, b + 1)[x].reify() == b

    assert unify(mt.add(x_l, a), mt.add(b, a))[x_l].reify() == b

    with variables(x):
        assert unify(x, b)[x] == b
        assert unify([x], [b])[x] == b
        assert unify((x, ), (b, ))[x] == b
        assert unify(x + 1, b + 1)[x].reify() == b
        assert unify(x + a, b + a)[x].reify() == b

    with variables(x):
        assert unify(a + b, a + x)[x].reify() == b

    mt_expr_add = mt.add(x_l, y_l)

    def _roundtrip(tt_expr):
        # Unify the meta pattern with a Theano graph, then rebuild the
        # graph from the resulting substitution.
        return reify(mt_expr_add, unify(mt_expr_add, tt_expr)).reify()

    # The parameters are, in turn: vectors, scalars, constants.
    for lhs, rhs in ((x, y), (x_s, y_s), (c_tt, d_tt)):
        tt_expr_add = tt.add(lhs, rhs)
        assert graph_equal(tt_expr_add, _roundtrip(tt_expr_add))