def test_get_indices_add():
    """Each sum below has exactly one free index ``i``.

    Indices contracted inside a term (e.g. ``j`` in ``A[i, j]*x[j]``) are not
    reported, and the second element of the result is expected to be ``{}``.
    """
    x, y, A = map(IndexedBase, 'xyA')
    i, j, k = map(Idx, 'ijk')
    expected = ({i}, {})
    sums = [
        x[i] + 2*y[i],
        y[i] + 2*A[i, j]*x[j],
        y[i] + 2*(x[i] + A[i, j]*x[j]),
        y[i] + x[i]*(A[j, j] + 1),
        y[i] + x[i]*x[j]*(y[j] + A[j, k]*x[k]),
    ]
    for expr in sums:
        assert get_indices(expr) == expected
def test_get_indices_Pow():
    """Free indices of powers whose base and/or exponent are indexed."""
    x, y, A = map(IndexedBase, 'xyA')
    i, j, k = map(Idx, 'ijk')
    cases = [
        (Pow(x[i], y[j]), {i, j}),
        (Pow(x[i, k], y[j, k]), {i, j, k}),
        (Pow(A[i, k], y[k] + A[k, j]*x[j]), {i, k}),
    ]
    for expr, free in cases:
        assert get_indices(expr) == (free, {})

    # A numeric base raised to an indexed power must behave like exp().
    assert get_indices(Pow(2, x[i])) == get_indices(exp(x[i]))

    # test of a design decision, this may change:
    # squaring x[i] is NOT treated as contracting i, so i stays free.
    assert get_indices(Pow(x[i], 2)) == ({i}, {})
def test_ufunc_support():
    """Undefined functions applied to indexed objects propagate/contract indices."""
    f, g = map(Function, 'fg')
    x, y = map(IndexedBase, 'xy')
    i, j = map(Idx, 'ij')
    a = symbols('a')

    index_cases = [
        (f(x[i]), {i}),
        (f(x[i], y[j]), {i, j}),
        (f(y[i])*g(x[i]), set()),
        (f(a, x[i]), {i}),
        (f(a, y[i], x[j])*g(x[i]), {j}),
        (g(f(x[i])), {i}),
    ]
    for expr, free in index_cases:
        assert get_indices(expr) == (free, {})

    # No contraction: the whole expression sits under the ``None`` key.
    assert get_contraction_structure(f(x[i])) == {None: {f(x[i])}}

    # Products where ``i`` is repeated across factors are keyed by ``(i,)``.
    contracted_over_i = (
        f(y[i])*g(x[i]),
        f(y[i])*g(f(x[i])),
        f(x[j], y[i])*g(x[i]),
    )
    for expr in contracted_over_i:
        assert get_contraction_structure(expr) == {(i,): {expr}}
def _get_all_indices(expr: Expr) -> Set[Idx]:
    """Return only the index set from ``get_indices(expr)``.

    ``get_indices`` returns a 2-tuple whose first element is the set of
    indices. The second element is a dict; it is discarded here. It is not
    guaranteed to be empty, which is why it is not named as if it were.
    """
    indices, _ = get_indices(expr)
    return indices
def test_trivial_indices():
    """Expressions built only from plain symbols carry no indices."""
    x, y = symbols('x y')
    for expr in (x, x*y, x + y, x**y):
        assert get_indices(expr) == (set(), {})
def test_scalar_broadcast():
    """A fully contracted term acts as a scalar and may be added to x[i]."""
    x, y = map(IndexedBase, 'xy')
    i, j = map(Idx, 'ij')
    for scalar_term in (y[i, i], y[j, j]):
        assert get_indices(x[i] + scalar_term) == ({i}, {})
def test_get_indices_exceptions():
    """Adding terms with different free indices raises IndexConformanceException."""
    x, y = map(IndexedBase, 'xy')
    i, j = map(Idx, 'ij')

    def index_mismatched_sum():
        # Build the sum inside the deferred call, exactly like the original lambda.
        return get_indices(x[i] + y[j])

    raises(IndexConformanceException, index_mismatched_sum)
def test_get_indices_mul():
    """A product of factors with distinct indices keeps all of them free."""
    x, y = map(IndexedBase, 'xy')
    i, j = map(Idx, 'ij')
    for product in (x[j]*y[i], x[i]*y[j]):
        assert get_indices(product) == ({i, j}, {})
def test_get_indices_Idx():
    """Bare Idx objects count as indices; a repeated one is contracted away."""
    f = Function('f')
    i, j = map(Idx, 'ij')
    expectations = [
        (f(i)*j, {i, j}),
        (f(j, i), {j, i}),
        (f(i)*i, set()),
    ]
    for expr, free in expectations:
        assert get_indices(expr) == (free, {})
def test_get_indices_Indexed():
    """A single Indexed object reports each of its (distinct) indices."""
    x = IndexedBase('x')
    i, j = map(Idx, 'ij')
    expectations = [
        (x[i, j], {i, j}),
        (x[j, i], {j, i}),
    ]
    for expr, free in expectations:
        assert get_indices(expr) == (free, {})