def test_local_reshape_dimshuffle():
    """Check that ``local_reshape_dimshuffle`` removes the DimShuffle feeding a Reshape.

    The DimShuffle below only inserts broadcastable (``"x"``) axes and keeps
    the original axis order. After the rewrite runs, no DimShuffle should
    remain in the graph.
    """
    reshape_dimshuffle = out2in(local_reshape_dimshuffle)

    x = matrix("x")

    y = x.dimshuffle("x", 0, "x", 1)
    out = reshape(y, (1, x.shape[0] * x.shape[1], 1))

    g = FunctionGraph([x], [out])
    reshape_dimshuffle(g)

    topo = g.toposort()
    # `toposort` yields Apply nodes, so the Op must be inspected through
    # `node.op`. Checking the node itself against `DimShuffle` is always False.
    assert not any(isinstance(node.op, DimShuffle) for node in topo)
def test_local_alloc_dimshuffle():
    """Check that ``local_alloc_dimshuffle`` removes a DimShuffle used as Alloc's value.

    ``x.dimshuffle("x", 0)`` only prepends a broadcastable axis. The rewrite
    should feed ``x`` directly to the Alloc, so that no DimShuffle is left in
    the graph.
    """
    alloc_dimshuffle = out2in(local_alloc_dimshuffle)

    x = vector("x")
    m = iscalar("m")

    y = x.dimshuffle("x", 0)
    out = aet.alloc(y, m, 1, x.shape[0])

    g = FunctionGraph([x, m], [out])
    alloc_dimshuffle(g)

    topo = g.toposort()
    # `toposort` yields Apply nodes; inspect their Ops via `node.op`.
    assert not any(isinstance(node.op, DimShuffle) for node in topo)
def test_local_dimshuffle_subtensor():
    """Check that ``local_dimshuffle_subtensor`` folds dimension-dropping DimShuffles.

    Each case builds a ``Subtensor -> DimShuffle`` graph in which the
    DimShuffle only drops broadcastable axes. After the rewrite runs, no
    DimShuffle should remain. The compiled and evaluated cases also check
    that the results have the expected shapes.
    """
    dimshuffle_subtensor = out2in(local_dimshuffle_subtensor)

    x = dtensor4("x")
    x = aet.patternbroadcast(x, (False, True, False, False))
    i = iscalar("i")
    # Axis 1 is broadcastable, so the DimShuffle dropping it can be expressed
    # by the Subtensor alone.
    out = x[:, :, 10:30, ::i].dimshuffle(0, 2, 3)

    g = FunctionGraph([x, i], [out])
    dimshuffle_subtensor(g)

    topo = g.toposort()
    # `toposort` yields Apply nodes, so the Op must be inspected through
    # `node.op`. Checking the node itself against `DimShuffle` is always False.
    assert not any(isinstance(node.op, DimShuffle) for node in topo)

    # Test dimshuffle remove dimensions the subtensor don't "see".
    x = tensor(broadcastable=(False, True, False), dtype="float64")
    out = x[i].dimshuffle(1)

    g = FunctionGraph([x, i], [out])
    dimshuffle_subtensor(g)

    topo = g.toposort()
    assert not any(isinstance(node.op, DimShuffle) for node in topo)

    # Test dimshuffle remove dimensions the subtensor don't "see" but
    # have in between dimensions.
    x = tensor(broadcastable=(False, True, False, True), dtype="float64")
    out = x[i].dimshuffle(1)

    # Apply the rewrite explicitly so the structural check does not depend
    # on which rewrites the configured compilation mode enables.
    g = FunctionGraph([x, i], [out])
    dimshuffle_subtensor(g)

    topo = g.toposort()
    assert not any(isinstance(node.op, DimShuffle) for node in topo)

    f = aesara.function([x, i], out)
    assert f(np.random.rand(5, 1, 4, 1), 2).shape == (4,)

    # Test a corner case that had Aesara return a bug.
    x = dtensor4("x")
    x = aet.patternbroadcast(x, (False, True, False, False))

    result = x[:, :, 0:3, ::-1].dimshuffle(0, 2, 3).eval({x: np.ones((5, 1, 6, 7))})
    assert result.shape == (5, 3, 7)
def test_local_dimshuffle_alloc():
    """Check that ``local_dimshuffle_alloc`` folds a left-padding DimShuffle into the Alloc.

    ``alloc(x, 3, 2).dimshuffle("x", "x", 0, 1)`` only prepends two
    broadcastable axes. After the rewrite, the graph should still compute a
    ``(1, 1, 3, 2)`` result and should contain no DimShuffle.
    """
    dimshuffle_alloc = out2in(local_dimshuffle_alloc)

    x = vector("x")

    out = aet.alloc(x, 3, 2).dimshuffle("x", "x", 0, 1)

    g = FunctionGraph([x], [out])
    dimshuffle_alloc(g)

    linker = PerformLinker()
    linker.accept(g)
    f = linker.make_function()

    result = f([3, 4])
    assert result.ndim == 4
    # Alloc broadcasts the length-2 input to (3, 2); the two leading
    # broadcastable axes give the full (1, 1, 3, 2) shape.
    assert result.shape == (1, 1, 3, 2)

    topo = g.toposort()
    # `toposort` yields Apply nodes; inspect their Ops via `node.op`.
    assert not any(isinstance(node.op, DimShuffle) for node in topo)