def test_not(self):
    """Compile a bitwise NOT graph and compare it with Python's ``~``.

    The graph is built twice, once through ``invert(x)`` and once through
    the ``~x`` operator, each time on a fresh set of variables from
    ``ints``.  The second graph input is accepted but does not take part
    in the output.
    """
    bit_pairs = ((0, 1), (0, 0), (1, 0), (1, 1))

    for build_output in (invert, lambda var: ~var):
        x, y, _ = ints("xyz")
        graph = FunctionGraph([x, y], [build_output(x)])
        fn = gof.DualLinker().accept(graph).make_function()
        for a, b in bit_pairs:
            assert fn(a, b) == ~a, (a,)
def test_and(self):
    """Compile a bitwise AND graph and compare it with Python's ``&``.

    The graph is built twice, once through ``and_(x, y)`` and once through
    the ``x & y`` operator, each time on a fresh set of variables from
    ``ints``.
    """
    bit_pairs = ((0, 1), (0, 0), (1, 0), (1, 1))

    for build_output in (and_, lambda lhs, rhs: lhs & rhs):
        x, y, _ = ints("xyz")
        graph = FunctionGraph([x, y], [build_output(x, y)])
        fn = gof.DualLinker().accept(graph).make_function()
        for a, b in bit_pairs:
            assert fn(a, b) == (a & b), (a, b)
def test_local_reshape_dimshuffle():
    """Apply ``local_reshape_dimshuffle`` and check no DimShuffle is left.

    The graph reshapes ``x.dimshuffle("x", 0, "x", 1)``; after the rewrite
    is applied through ``out2in`` the test expects the graph to contain no
    ``DimShuffle`` node at all.
    """
    reshape_dimshuffle = out2in(local_reshape_dimshuffle)

    x = tensor.matrix("x")
    y = x.dimshuffle("x", 0, "x", 1)
    out = tensor.reshape(y, (1, x.shape[0] * x.shape[1], 1))

    g = FunctionGraph([x], [out])
    reshape_dimshuffle(g)

    topo = g.toposort()
    # `toposort()` yields Apply nodes, so the Op class must be tested on
    # `node.op`.  The previous check ran `isinstance` on the nodes
    # themselves and therefore passed for any non-empty graph.
    assert not any(isinstance(node.op, DimShuffle) for node in topo)
def test_local_alloc_dimshuffle():
    """Apply ``local_alloc_dimshuffle`` and check no DimShuffle is left.

    The graph allocates from ``x.dimshuffle("x", 0)``; after the rewrite is
    applied through ``out2in`` the test expects the graph to contain no
    ``DimShuffle`` node at all.
    """
    alloc_dimshuffle = out2in(local_alloc_dimshuffle)

    x = tensor.vector("x")
    m = tensor.iscalar("m")

    y = x.dimshuffle("x", 0)
    out = tensor.alloc(y, m, 1, x.shape[0])

    g = FunctionGraph([x, m], [out])
    alloc_dimshuffle(g)

    topo = g.toposort()
    # `toposort()` yields Apply nodes, so the Op class must be tested on
    # `node.op`.  The previous check ran `isinstance` on the nodes
    # themselves and therefore passed for any non-empty graph.
    assert not any(isinstance(node.op, DimShuffle) for node in topo)
def test_straightforward(self):
    """Evaluate a single-output Composite at ``(1.0, 2.0)``.

    The wrapped expression is ``mul(add(x, y), div_proxy(x, y))`` and the
    compiled function is expected to return ``1.5``.
    """
    x, y, _ = floats("xyz")
    inner_expr = mul(add(x, y), div_proxy(x, y))
    composite_node = Composite([x, y], [inner_expr]).make_node(x, y)
    graph = FunctionGraph([x, y], [composite_node.out])
    fn = gof.DualLinker().accept(graph).make_function()
    assert fn(1.0, 2.0) == 1.5
def test_local_dimshuffle_alloc():
    """Apply ``local_dimshuffle_alloc`` and check the rewritten graph.

    The graph is ``alloc(x, 3, 2).dimshuffle("x", "x", 0, 1)``.  After the
    rewrite the graph must still evaluate to a 4-d result and is expected
    to contain no ``DimShuffle`` node.
    """
    dimshuffle_alloc = out2in(local_dimshuffle_alloc)

    x = tensor.vector("x")
    out = tensor.alloc(x, 3, 2).dimshuffle("x", "x", 0, 1)

    g = FunctionGraph([x], [out])
    dimshuffle_alloc(g)

    # The rewritten graph must still produce the four output dimensions.
    linker = aesara.gof.PerformLinker()
    linker.accept(g)
    f = linker.make_function()
    assert f([3, 4]).ndim == 4

    topo = g.toposort()
    # `toposort()` yields Apply nodes, so the Op class must be tested on
    # `node.op`.  The previous check ran `isinstance` on the nodes
    # themselves and therefore passed for any non-empty graph.
    assert not any(isinstance(node.op, DimShuffle) for node in topo)
def test_with_constants(self):
    """Check a Composite whose inner expression embeds the constant 70.0.

    The generated C code must mention ``70.0`` and the compiled function
    must return ``36.0`` at ``(1.0, 2.0)``.
    """
    x, y, _ = floats("xyz")
    inner_expr = mul(add(70.0, y), div_proxy(x, y))
    composite_node = Composite([x, y], [inner_expr]).make_node(x, y)

    generated_c = composite_node.op.c_code(
        composite_node, "dummy", ["x", "y"], ["z"], {"id": 0}
    )
    assert "70.0" in generated_c

    graph = FunctionGraph([x, y], [composite_node.out])
    fn = gof.DualLinker().accept(graph).make_function()
    assert fn(1.0, 2.0) == 36.0
def test_many_outputs(self):
    """Evaluate a three-output Composite at ``(1.0, 2.0, 3.0)``.

    The outputs are ``x + y + z``, ``x + y * z`` and ``x / y``; the
    compiled function is expected to return ``[6.0, 7.0, 0.5]``.
    """
    x, y, z = floats("xyz")
    inner_exprs = [
        x + y + z,
        x + y * z,
        x / y,
    ]
    composite_node = Composite([x, y, z], inner_exprs).make_node(x, y, z)
    graph = FunctionGraph([x, y, z], composite_node.outputs)
    fn = gof.DualLinker().accept(graph).make_function()
    assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
def test_composite_printing(self):
    """Check ``str()`` of a graph that holds an eight-output Composite.

    The graph is also run through ``DualLinker`` before its string form is
    compared with the expected text.
    """
    x, y, z = floats("xyz")
    inner_exprs = [
        x + y + z,
        x + y * z,
        x / y,
        x // 5,
        -x,
        x - y,
        x**y + (-z),
        x % 3,
    ]
    composite_node = Composite([x, y, z], inner_exprs).make_node(x, y, z)
    graph = FunctionGraph([x, y, z], composite_node.outputs)

    # The compiled function itself is not used; only the build must succeed.
    gof.DualLinker().accept(graph).make_function()

    expected = (
        "[*1 -> Composite{((i0 + i1) + i2),"
        " (i0 + (i1 * i2)), (i0 / i1), "
        "(i0 // Constant{5}), "
        "(-i0), (i0 - i1), ((i0 ** i1) + (-i2)),"
        " (i0 % Constant{3})}(x, y, z), "
        "*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7]"
    )
    assert str(graph) == expected
def test_local_dimshuffle_subtensor():
    """Apply ``local_dimshuffle_subtensor`` to several subtensor graphs.

    Each graph applies a ``dimshuffle`` to the result of an indexing
    operation.  After the rewrite (or a full ``aesara.function``
    compilation) the graph is expected to contain no ``DimShuffle`` node.
    """
    dimshuffle_subtensor = out2in(local_dimshuffle_subtensor)

    x = tensor.dtensor4("x")
    x = tensor.patternbroadcast(x, (False, True, False, False))
    i = tensor.iscalar("i")

    out = x[:, :, 10:30, ::i].dimshuffle(0, 2, 3)

    g = FunctionGraph([x, i], [out])
    dimshuffle_subtensor(g)

    topo = g.toposort()
    # `toposort()` yields Apply nodes, so the Op class must be tested on
    # `node.op`.  The previous check ran `isinstance` on the nodes
    # themselves and therefore passed for any non-empty graph.
    assert not any(isinstance(node.op, DimShuffle) for node in topo)

    # The dimshuffle removes a dimension the subtensor does not "see".
    x = tensor.tensor(broadcastable=(False, True, False), dtype="float64")
    out = x[i].dimshuffle(1)

    g = FunctionGraph([x, i], [out])
    dimshuffle_subtensor(g)

    topo = g.toposort()
    assert not any(isinstance(node.op, DimShuffle) for node in topo)

    # The dimshuffle removes dimensions the subtensor does not "see", with
    # a kept dimension in between them.
    x = tensor.tensor(broadcastable=(False, True, False, True), dtype="float64")
    out = x[i].dimshuffle(1)

    f = aesara.function([x, i], out)

    topo = f.maker.fgraph.toposort()
    assert not any(isinstance(node.op, DimShuffle) for node in topo)
    assert f(np.random.rand(5, 1, 4, 1), 2).shape == (4,)

    # Regression check for a corner case that used to be handled wrongly.
    x = tensor.dtensor4("x")
    x = tensor.patternbroadcast(x, (False, True, False, False))
    result = x[:, :, 0:3, ::-1].dimshuffle(0, 2, 3).eval({x: np.ones((5, 1, 6, 7))})
    assert result.shape == (5, 3, 7)
def jax_funcify_Scan(op):
    """Build a JAX callable for a ``Scan`` op (work in progress).

    The inner graph of ``op`` is converted with ``jax_funcify`` and the
    returned ``scan`` function drives it through ``jax.lax.scan``.  The two
    helpers that translate between the ``jax.lax.scan`` carry and the inner
    function's inputs/outputs are not finished and raise
    ``NotImplementedError``.

    Parameters
    ----------
    op
        The scan op; its ``inputs``, ``outputs``, ``n_outs`` and ``info``
        attributes are read.

    Returns
    -------
    callable
        ``scan(*outer_inputs)``, which returns the result of
        ``jax.lax.scan``.
    """
    inner_fg = FunctionGraph(op.inputs, op.outputs)
    jax_tt_inner_func = jax_funcify(inner_fg)

    def scan(*outer_inputs):
        scan_args = ScanArgs(
            outer_inputs, [None] * op.n_outs, op.inputs, op.outputs, op.info
        )

        # `outer_inputs` is a list with the following composite form:
        # [n_steps]
        # + outer_in_seqs
        # + outer_in_mit_mot
        # + outer_in_mit_sot
        # + outer_in_sit_sot
        # + outer_in_shared
        # + outer_in_nit_sot
        # + outer_in_non_seqs
        n_steps = scan_args.n_steps
        seqs = scan_args.outer_in_seqs
        n_non_seqs = len(scan_args.outer_in_non_seqs)

        # TODO: sit_sots
        # For every mit-sot, keep the leading `max_neg + max_pos` entries of
        # its outer input as the initial slice.
        mit_sot_in_slices = []
        for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
            neg_taps = [abs(t) for t in tap if t < 0]
            pos_taps = [abs(t) for t in tap if t > 0]
            max_neg = max(neg_taps) if neg_taps else 0
            max_pos = max(pos_taps) if pos_taps else 0
            init_slice = seq[: max_neg + max_pos]
            mit_sot_in_slices.append(init_slice)

        # NOTE(review): this carry has two groups, while
        # `jax_args_to_inner_scan` unpacks five; the mit_mot, sit_sot and
        # shared groups still have to be added here.
        init_carry = [mit_sot_in_slices, scan_args.outer_in_non_seqs]

        def jax_args_to_inner_scan(op, carry, x):
            """Turn a ``(carry, x)`` pair into the inner function's inputs."""
            # `carry` contains all inner-output taps, non_seqs, and shared
            # terms
            (
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            ) = carry

            # `x` contains the in_seqs
            inner_in_seqs = x

            # `inner_scan_inputs` is a list with the following composite form:
            # inner_in_seqs
            # + sum(inner_in_mit_mot, [])
            # + sum(inner_in_mit_sot, [])
            # + inner_in_sit_sot
            # + inner_in_shared
            # + inner_in_non_seqs
            # NOTE(review): `inner_in_shared` is listed above but is not
            # included below, and the groups are not flattened yet.
            inner_scan_inputs = [
                inner_in_seqs,
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_non_seqs,
            ]

            raise NotImplementedError()
            # Intended return value once the conversion is implemented.
            return inner_scan_inputs

        def inner_scan_outs_to_jax_outs(
            op,
            old_carry,
            inner_scan_outs,
        ):
            """Turn the inner function's outputs into ``(new_carry, y)``."""
            # `inner_scan_outs` is a list with the following
            # composite form:
            # outer_out_mit_mot
            # + outer_out_mit_sot
            # + outer_out_sit_sot
            # + outer_out_nit_sot
            # + outer_out_shared
            # + cond
            (
                outer_out_mit_mot,
                outer_out_mit_sot,
                outer_out_sit_sot,
                outer_out_nit_sot,
                outer_out_shared,
                cond,
            ) = inner_scan_outs

            # NOTE(review): this keeps everything *except* the last
            # `n_non_seqs` entries of the old carry; the name suggests the
            # non-sequence entries themselves were intended -- confirm.
            outer_out_non_seqs = old_carry[:-n_non_seqs]

            # This should contain all inner-output taps, non_seqs, and shared
            # terms
            carry = [
                outer_out_mit_mot,
                outer_out_mit_sot,
                outer_out_sit_sot,
                outer_out_shared,
                outer_out_non_seqs,
            ]

            # This should contain all inner-outputs that produce
            # outer-outputs
            y = []

            raise NotImplementedError()
            # Intended return value once the conversion is implemented.
            return (carry, y)

        def jax_inner_func(carry, x):
            """Run one scan step in the form ``jax.lax.scan`` expects."""
            inner_args = jax_args_to_inner_scan(op, carry, x)
            inner_scan_outs = jax_tt_inner_func(*inner_args)
            # The helper takes `(op, old_carry, inner_scan_outs)`; the
            # carry was previously left out of this call.
            new_carry, y = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
            return new_carry, y

        return jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)

    return scan
def test_mul_add_div_proxy():
    """Compile ``mul(add(x, y), div_proxy(x, y))`` and evaluate it.

    At ``(1.0, 2.0)`` the compiled function is expected to return ``1.5``.
    """
    x, y, _ = floats("xyz")
    graph = FunctionGraph([x, y], [mul(add(x, y), div_proxy(x, y))])
    fn = gof.DualLinker().accept(graph).make_function()
    assert fn(1.0, 2.0) == 1.5
def test_xor(self):
    """Compile ``x ^ y`` and compare it with Python's ``^`` on 0/1 pairs."""
    x, y, _ = ints("xyz")
    graph = FunctionGraph([x, y], [x ^ y])
    fn = gof.DualLinker().accept(graph).make_function()

    bit_pairs = ((0, 1), (0, 0), (1, 0), (1, 1))
    for a, b in bit_pairs:
        assert fn(a, b) == (a ^ b), (a, b)
def test_neq(self):
    """Compile ``neq(x, y)`` and compare it with Python's ``!=``."""
    x, y, _ = floats("xyz")
    graph = FunctionGraph([x, y], [neq(x, y)])
    fn = gof.DualLinker().accept(graph).make_function()

    samples = ((3.0, 9), (3, 0.9), (3, 3))
    for lhs, rhs in samples:
        assert fn(lhs, rhs) == (lhs != rhs)