def test_local_reshape_dimshuffle():
    """The local_reshape_dimshuffle rewrite should fire on a
    reshape-of-dimshuffle graph, leaving non-DimShuffle nodes behind."""
    opt = out2in(local_reshape_dimshuffle)
    mat = tensor.matrix('x')
    shuffled = mat.dimshuffle('x', 0, 'x', 1)
    reshaped = tensor.reshape(shuffled, (1, mat.shape[0] * mat.shape[1], 1))
    fgraph = FunctionGraph([mat], [reshaped])
    opt(fgraph)
    assert any(not isinstance(node, DimShuffle) for node in fgraph.toposort())
def test_local_alloc_dimshuffle():
    """The local_alloc_dimshuffle rewrite should fire on an
    alloc-of-dimshuffle graph, leaving non-DimShuffle nodes behind."""
    opt = out2in(local_alloc_dimshuffle)
    vec = tensor.vector('x')
    m = tensor.iscalar('m')
    shuffled = vec.dimshuffle('x', 0)
    allocated = tensor.alloc(shuffled, m, 1, vec.shape[0])
    fgraph = FunctionGraph([vec, m], [allocated])
    opt(fgraph)
    assert any(not isinstance(node, DimShuffle) for node in fgraph.toposort())
def test_local_dimshuffle_subtensor():
    """The local_dimshuffle_subtensor rewrite should fire on a
    dimshuffle-of-subtensor graph, leaving non-DimShuffle nodes behind."""
    opt = out2in(local_dimshuffle_subtensor)
    raw = tensor.tensor4('x')
    inp = tensor.patternbroadcast(raw, (False, True, False, False))
    step = tensor.iscalar('i')
    out = inp[:, :, 10:30, ::step].dimshuffle(0, 2, 3)
    fgraph = FunctionGraph([inp, step], [out])
    opt(fgraph)
    assert any(not isinstance(node, DimShuffle) for node in fgraph.toposort())
def test_local_dimshuffle_alloc():
    """The local_dimshuffle_alloc rewrite should fire on a
    dimshuffle-of-alloc graph, and the rewritten graph must still
    produce a 4-d result when executed."""
    opt = out2in(local_dimshuffle_alloc)
    vec = tensor.vector('x')
    out = tensor.alloc(vec, 3, 2).dimshuffle('x', 'x', 0, 1)
    fgraph = FunctionGraph([vec], [out])
    opt(fgraph)
    linker = theano.gof.PerformLinker()
    linker.accept(fgraph)
    fn = linker.make_function()
    # The rewrite must not change the rank of the computed value.
    assert fn([3, 4]).ndim == 4
    assert any(not isinstance(node, DimShuffle) for node in fgraph.toposort())
def test_local_dimshuffle_subtensor():
    """Exercise local_dimshuffle_subtensor on several graph shapes:
    a plain dimshuffle-of-subtensor, dimshuffles that drop dimensions
    the subtensor does not index, and a known corner case."""
    opt = out2in(local_dimshuffle_subtensor)

    def has_non_dimshuffle(nodes):
        # After the rewrite, at least one node should not be a DimShuffle.
        return any(not isinstance(node, DimShuffle) for node in nodes)

    inp = tensor.dtensor4('x')
    inp = tensor.patternbroadcast(inp, (False, True, False, False))
    step = tensor.iscalar('i')
    out = inp[:, :, 10:30, ::step].dimshuffle(0, 2, 3)
    fgraph = FunctionGraph([inp, step], [out])
    opt(fgraph)
    assert has_non_dimshuffle(fgraph.toposort())

    # Test dimshuffle remove dimensions the subtensor don't "see".
    inp = tensor.tensor(broadcastable=(False, True, False), dtype='float64')
    out = inp[step].dimshuffle(1)
    fgraph = FunctionGraph([inp, step], [out])
    opt(fgraph)
    assert has_non_dimshuffle(fgraph.toposort())

    # Test dimshuffle remove dimensions the subtensor don't "see" but
    # have in between dimensions.
    inp = tensor.tensor(broadcastable=(False, True, False, True), dtype='float64')
    out = inp[step].dimshuffle(1)
    fn = theano.function([inp, step], out)
    assert has_non_dimshuffle(fn.maker.fgraph.toposort())
    assert fn(np.random.rand(5, 1, 4, 1), 2).shape == (4,)

    # Test a corner case that had Theano return a bug.
    inp = tensor.dtensor4('x')
    inp = tensor.patternbroadcast(inp, (False, True, False, False))
    result = inp[:, :, 0:3, ::-1].dimshuffle(0, 2, 3).eval(
        {inp: np.ones((5, 1, 6, 7))}
    )
    assert result.shape == (5, 3, 7)
def test_local_dimshuffle_subtensor():
    """Exercise local_dimshuffle_subtensor on a dimshuffle-of-subtensor
    graph and on dimshuffles that drop dimensions the subtensor does
    not index."""
    opt = out2in(local_dimshuffle_subtensor)

    def has_non_dimshuffle(nodes):
        # After the rewrite, at least one node should not be a DimShuffle.
        return any(not isinstance(node, DimShuffle) for node in nodes)

    inp = tensor.dtensor4("x")
    inp = tensor.patternbroadcast(inp, (False, True, False, False))
    idx = tensor.iscalar("i")
    out = inp[:, :, 10:30, ::idx].dimshuffle(0, 2, 3)
    fgraph = FunctionGraph([inp, idx], [out])
    opt(fgraph)
    assert has_non_dimshuffle(fgraph.toposort())

    # Test dimshuffle remove dimensions the subtensor don't "see".
    inp = tensor.tensor(broadcastable=(False, True, False), dtype="float64")
    out = inp[idx].dimshuffle(1)
    fgraph = FunctionGraph([inp, idx], [out])
    opt(fgraph)
    assert has_non_dimshuffle(fgraph.toposort())

    # Test dimshuffle remove dimensions the subtensor don't "see" but
    # have in between dimensions.
    inp = tensor.tensor(broadcastable=(False, True, False, True), dtype="float64")
    out = inp[idx].dimshuffle(1)
    fn = theano.function([inp, idx], out)
    assert has_non_dimshuffle(fn.maker.fgraph.toposort())
    assert fn(numpy.random.rand(5, 1, 4, 1), 2).shape == (4,)
def test_constant_cache_error(self):
    """Building a FunctionGraph with clone=False over a cached
    constant must raise CachedConstantError."""
    const = theano.tensor.constant(1)
    assert const.cached
    with pytest.raises(CachedConstantError):
        FunctionGraph([], [const + 1], clone=False)
def test_straightforward(self):
    """(x + y) * (x / y) compiled through the DualLinker evaluates
    correctly on concrete scalars."""
    x, y, z = inputs()
    expr = mul(add(x, y), div_proxy(x, y))
    fgraph = FunctionGraph([x, y], [expr])
    compute = gof.DualLinker().accept(fgraph).make_function()
    # (1 + 2) * (1 / 2) == 1.5
    assert compute(1.0, 2.0) == 1.5
def jax_funcify_Scan(op):
    """Convert a Theano `Scan` `Op` into a JAX-compatible function.

    The inner graph of the `Scan` is JAX-ified once, and the returned
    `scan` closure maps the `Scan`'s flat outer-input list onto
    `jax.lax.scan`'s (carry, sequences) calling convention.
    """
    # JAX-ify the scan's inner graph once, up front; `jax_tt_inner_func`
    # is a sequence of one compiled function per inner output.
    inner_fg = FunctionGraph(op.inputs, op.outputs)
    jax_tt_inner_func = jax_funcify(inner_fg)

    def scan(*outer_inputs):
        """Run the scan given the `Scan` `Op`'s flat outer inputs."""
        scan_args = ScanArgs(
            list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info
        )
        # `outer_inputs` is a list with the following composite form:
        # [n_steps]
        # + outer_in_seqs
        # + outer_in_mit_mot
        # + outer_in_mit_sot
        # + outer_in_sit_sot
        # + outer_in_shared
        # + outer_in_nit_sot
        # + outer_in_non_seqs
        n_steps = scan_args.n_steps
        seqs = scan_args.outer_in_seqs

        # TODO: mit_mots
        mit_mot_in_slices = []

        # For each mit_sot output, slice off the initial tap values from
        # the outer input so they can seed the carry.
        mit_sot_in_slices = []
        for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
            neg_taps = [abs(t) for t in tap if t < 0]
            pos_taps = [abs(t) for t in tap if t > 0]
            max_neg = max(neg_taps) if neg_taps else 0
            max_pos = max(pos_taps) if pos_taps else 0
            init_slice = seq[: max_neg + max_pos]
            mit_sot_in_slices.append(init_slice)

        # sit_sot outputs carry a single previous value: the first entry
        # of each outer input.
        sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]

        # The carry threads taps, shared values, and non-sequences
        # through every step of `jax.lax.scan`.
        init_carry = (
            mit_mot_in_slices,
            mit_sot_in_slices,
            sit_sot_in_slices,
            scan_args.outer_in_shared,
            scan_args.outer_in_non_seqs,
        )

        def jax_args_to_inner_scan(op, carry, x):
            """Flatten (carry, x) into the inner function's argument list."""
            # `carry` contains all inner-output taps, non_seqs, and shared
            # terms
            (
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            ) = carry

            # `x` contains the in_seqs
            inner_in_seqs = x

            # `inner_scan_inputs` is a list with the following composite form:
            # inner_in_seqs
            # + sum(inner_in_mit_mot, [])
            # + sum(inner_in_mit_sot, [])
            # + inner_in_sit_sot
            # + inner_in_shared
            # + inner_in_non_seqs
            # Expand each mit_sot buffer into the individual tap values
            # the inner function expects.
            inner_in_mit_sot_flatten = []
            for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
                inner_in_mit_sot_flatten.extend(array[jnp.array(index)])

            inner_scan_inputs = sum(
                [
                    inner_in_seqs,
                    inner_in_mit_mot,
                    inner_in_mit_sot_flatten,
                    inner_in_sit_sot,
                    inner_in_shared,
                    inner_in_non_seqs,
                ],
                [],
            )

            return inner_scan_inputs

        def inner_scan_outs_to_jax_outs(
            op,
            old_carry,
            inner_scan_outs,
        ):
            """Build the next-step carry from this step's inner outputs."""
            (
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            ) = old_carry

            def update_mit_sot(mit_sot, new_val):
                # Shift the tap window: drop the oldest value, append the new one.
                return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0)

            inner_out_mit_sot = [
                update_mit_sot(mit_sot, new_val)
                for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
            ]

            # This should contain all inner-output taps, non_seqs, and shared
            # terms
            # NOTE(review): when both mit_sot and sit_sot outputs exist,
            # `inner_scan_outs` appears to be reused for both — confirm the
            # intended pairing of outputs to taps.
            if not inner_in_sit_sot:
                inner_out_sit_sot = []
            else:
                inner_out_sit_sot = inner_scan_outs
            new_carry = (
                inner_in_mit_mot,
                inner_out_mit_sot,
                inner_out_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            )

            return new_carry

        def jax_inner_func(carry, x):
            # One scan step: flatten args, run every inner output function,
            # then rebuild the carry.
            inner_args = jax_args_to_inner_scan(op, carry, x)
            inner_scan_outs = [fn(*inner_args) for fn in jax_tt_inner_func]
            new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
            return new_carry, inner_scan_outs

        _, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)

        # We need to prepend the initial values so that the JAX output will
        # match the raw `Scan` `Op` output and, thus, work with a downstream
        # `Subtensor` `Op` introduced by the `scan` helper function.
        def append_scan_out(scan_in_part, scan_out_part):
            return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)

        if scan_args.outer_in_mit_sot:
            scan_out_final = [
                append_scan_out(init, out)
                for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
            ]
        elif scan_args.outer_in_sit_sot:
            scan_out_final = [
                append_scan_out(init, out)
                for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
            ]
        # NOTE(review): if there are neither mit_sot nor sit_sot outer
        # inputs, `scan_out_final` is never bound and the next line raises
        # NameError — presumably every supported scan has at least one;
        # confirm against callers.

        if len(scan_out_final) == 1:
            scan_out_final = scan_out_final[0]

        return scan_out_final

    return scan
def test_neq(self):
    """neq(x, y) compiled through the DualLinker must agree with
    Python's != on a few scalar pairs."""
    x, y, z = inputs()
    graph = FunctionGraph([x, y], [neq(x, y)])
    compute = gof.DualLinker().accept(graph).make_function()
    cases = ((3., 9), (3, 0.9), (3, 3))
    for lhs, rhs in cases:
        self.assertTrue(compute(lhs, rhs) == (lhs != rhs))
def test_xor(self):
    """x ^ y on int variables must match Python's bitwise xor on all
    four 1-bit input combinations."""
    x, y, z = ints('xyz')
    graph = FunctionGraph([x, y], [x ^ y])
    compute = gof.DualLinker().accept(graph).make_function()
    for lhs, rhs in ((0, 1), (0, 0), (1, 0), (1, 1)):
        self.assertTrue(compute(lhs, rhs) == (lhs ^ rhs), (lhs, rhs))
def jax_funcify_Scan(op):
    """Convert a Theano `Scan` `Op` into a JAX-compatible function.

    WORK IN PROGRESS: the inner argument/output plumbing is not finished —
    both helper closures raise ``NotImplementedError`` before returning, so
    the returned ``scan`` function cannot yet be executed.
    """
    # JAX-ify the scan's inner graph once, up front.
    inner_fg = FunctionGraph(op.inputs, op.outputs)
    jax_tt_inner_func = jax_funcify(inner_fg)

    def scan(*outer_inputs):
        """Run the scan given the `Scan` `Op`'s flat outer inputs."""
        scan_args = ScanArgs(
            outer_inputs, [None] * op.n_outs, op.inputs, op.outputs, op.info
        )
        # `outer_inputs` is a list with the following composite form:
        # [n_steps]
        # + outer_in_seqs
        # + outer_in_mit_mot
        # + outer_in_mit_sot
        # + outer_in_sit_sot
        # + outer_in_shared
        # + outer_in_nit_sot
        # + outer_in_non_seqs
        n_steps = scan_args.n_steps
        seqs = scan_args.outer_in_seqs
        n_non_seqs = len(scan_args.outer_in_non_seqs)

        # TODO: sit_sots
        # For each mit_sot output, slice off the initial tap values from
        # the outer input so they can seed the carry.
        mit_sot_in_slices = []
        for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
            neg_taps = [abs(t) for t in tap if t < 0]
            pos_taps = [abs(t) for t in tap if t > 0]
            max_neg = max(neg_taps) if neg_taps else 0
            max_pos = max(pos_taps) if pos_taps else 0
            init_slice = seq[: max_neg + max_pos]
            mit_sot_in_slices.append(init_slice)

        # NOTE(review): this 2-element carry does not match the 5-tuple
        # unpacked in `jax_args_to_inner_scan` below — presumably to be
        # reconciled when the TODOs are finished.
        init_carry = [mit_sot_in_slices, scan_args.outer_in_non_seqs]

        def jax_args_to_inner_scan(op, carry, x):
            """Flatten (carry, x) into the inner function's argument list.

            Unfinished: raises ``NotImplementedError`` before returning.
            """
            # `carry` contains all inner-output taps, non_seqs, and shared
            # terms
            (
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            ) = carry

            # `x` contains the in_seqs
            inner_in_seqs = x

            # `inner_scan_inputs` is a list with the following composite form:
            # inner_in_seqs
            # + sum(inner_in_mit_mot, [])
            # + sum(inner_in_mit_sot, [])
            # + inner_in_sit_sot
            # + inner_in_shared
            # + inner_in_non_seqs
            inner_scan_inputs = [
                inner_in_seqs,
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_non_seqs,
            ]

            raise NotImplementedError()
            return inner_scan_inputs

        def inner_scan_outs_to_jax_outs(
            op,
            old_carry,
            inner_scan_outs,
        ):
            """Split inner outputs into the next carry and the per-step outputs.

            Unfinished: raises ``NotImplementedError`` before returning.
            """
            # `inner_scan_outs` is a list with the following
            # composite form:
            # outer_out_mit_mot
            # + outer_out_mit_sot
            # + outer_out_sit_sot
            # + outer_out_nit_sot
            # + outer_out_shared
            # + cond
            (
                outer_out_mit_mot,
                outer_out_mit_sot,
                outer_out_sit_sot,
                outer_out_nit_sot,
                outer_out_shared,
                cond,
            ) = inner_scan_outs

            # NOTE(review): `[:-n_non_seqs]` drops the LAST n_non_seqs
            # entries; given the name, `[-n_non_seqs:]` may be intended —
            # confirm when this is completed.
            outer_out_non_seqs = old_carry[:-n_non_seqs]

            # This should contain all inner-output taps, non_seqs, and shared
            # terms
            carry = [
                outer_out_mit_mot,
                outer_out_mit_sot,
                outer_out_sit_sot,
                outer_out_shared,
                outer_out_non_seqs,
            ]

            # This should contain all inner-outputs that produce
            # outer-outputs
            y = []

            raise NotImplementedError()
            return (carry, y)

        def jax_inner_func(carry, x):
            # One scan step: flatten args, run the inner function, rebuild
            # the carry.
            inner_args = jax_args_to_inner_scan(op, carry, x)
            inner_scan_outs = jax_tt_inner_func(*inner_args)
            # NOTE(review): `inner_scan_outs_to_jax_outs` takes three
            # arguments (op, old_carry, inner_scan_outs) but only two are
            # passed here — unreachable today because the helpers raise
            # first, but must be fixed before this can run.
            new_carry, y = inner_scan_outs_to_jax_outs(op, inner_scan_outs)
            return new_carry, y

        return jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)

    return scan
def test_mul_add_div_proxy():
    """(x + y) * (x / y) compiled through the DualLinker evaluates
    correctly on concrete floats."""
    x, y, z = floats("xyz")
    expr = mul(add(x, y), div_proxy(x, y))
    fgraph = FunctionGraph([x, y], [expr])
    compute = gof.DualLinker().accept(fgraph).make_function()
    # (1 + 2) * (1 / 2) == 1.5
    assert compute(1.0, 2.0) == 1.5
def test_or(self):
    """x | y on int variables must match Python's bitwise or on all
    four 1-bit input combinations."""
    x, y, z = ints("xyz")
    graph = FunctionGraph([x, y], [x | y])
    compute = gof.DualLinker().accept(graph).make_function()
    for lhs, rhs in ((0, 1), (0, 0), (1, 0), (1, 1)):
        assert compute(lhs, rhs) == (lhs | rhs), (lhs, rhs)
def test_neq(self):
    """neq(x, y) on float variables must agree with Python's != on a
    few scalar pairs."""
    x, y, z = floats("xyz")
    graph = FunctionGraph([x, y], [neq(x, y)])
    compute = gof.DualLinker().accept(graph).make_function()
    for lhs, rhs in ((3.0, 9), (3, 0.9), (3, 3)):
        assert compute(lhs, rhs) == (lhs != rhs)
def test_clone(self):
    """Building a FunctionGraph with the default clone behavior must
    accept a cached constant without raising."""
    const = theano.tensor.constant(1)
    assert const.cached
    FunctionGraph([], [const + 1])