示例#1
0
def test_asciidag():
    """Check the ASCII rendering of a small stack/matmul expression graph.

    Builds ``stack @ stack.T``, where ``stack`` combines three expressions
    derived from a single placeholder, renders it with colors disabled and
    compares the text against a fixed reference drawing.
    """
    from pytato import get_ascii_graph

    size = pt.make_size_param("n")
    base = pt.make_placeholder(name="array", shape=size, dtype=np.float64)
    stacked = pt.stack([base, 2*base, base + 6])
    expr = stacked @ stacked.T

    # Raw string: the drawing contains backslashes that must stay literal.
    expected = r"""* Inputs
*-.   Placeholder
|\ \
* | | IndexLambda
| |/
|/|
| * IndexLambda
|/
*   Stack
|\
* | AxisPermutation
|/
* Einsum
* Outputs
"""

    rendered = get_ascii_graph(expr, use_color=False)

    assert rendered == expected
示例#2
0
def test_stack(ctx_factory, input_dims):
    """Check ``pt.stack`` of two random arrays for each axis 0..input_dims.

    :arg ctx_factory: called without arguments to obtain the context that
        the command queue is created on.
    :arg input_dims: number of leading entries of ``(2, 2, 2)`` used as the
        shape of both inputs.

    The comparison itself is delegated to ``assert_allclose_to_numpy``.
    """
    from numpy.random import default_rng

    queue = cl.CommandQueue(ctx_factory())

    in_shape = (2, 2, 2)[:input_dims]
    rng = default_rng()
    x_in = rng.random(size=in_shape)
    y_in = rng.random(size=in_shape)

    namespace = pt.Namespace()
    x, y = (pt.make_data_wrapper(namespace, host) for host in (x_in, y_in))

    # One more axis than the inputs have: the new axis may also go last.
    for axis in range(input_dims + 1):
        stacked = pt.stack((x, y), axis=axis)
        assert_allclose_to_numpy(stacked, queue)
示例#3
0
def test_toposortmapper():
    """Check the node-type sequence ``TopoSortMapper`` records for a graph.

    The graph is ``stack @ stack.T`` with ``stack`` built from a single
    placeholder; the first seven entries of ``topological_order`` are
    checked against the expected node classes, position by position.
    """
    from pytato.array import (AxisPermutation, IndexLambda,
                              Placeholder, Einsum, SizeParam, Stack)

    size = pt.make_size_param("n")
    base = pt.make_placeholder(name="array", shape=size, dtype=np.float64)
    stacked = pt.stack([base, 2*base, base + 6])
    expr = stacked @ stacked.T

    mapper = pt.transform.TopoSortMapper()
    mapper(expr)

    expected_types = (
        SizeParam,
        Placeholder,
        IndexLambda,
        IndexLambda,
        Stack,
        AxisPermutation,
        Einsum,
    )
    # Index explicitly so that a too-short order fails instead of passing.
    for position, node_type in enumerate(expected_types):
        assert isinstance(mapper.topological_order[position], node_type)
示例#4
0
def main():
    """Build a small expression graph and write out its visualizations.

    Shows the graph as ASCII, writes its DOT source to ``GRAPH_DOT`` and,
    when a ``dot`` executable is found on the path, converts that file to
    SVG at ``GRAPH_SVG``.
    """
    size = pt.make_size_param("n")
    base = pt.make_placeholder(name="array", shape=size, dtype=np.float64)
    stacked = pt.stack([base, 2 * base, base + 6])
    expr = stacked @ stacked.T

    pt.show_ascii_graph(expr)

    # Generate the DOT text first so the file is only touched on success.
    dot_source = pt.get_dot_graph(expr)
    with open(GRAPH_DOT, "w") as dot_file:
        dot_file.write(dot_source)
    logger.info("wrote '%s'", GRAPH_DOT)

    dot_executable = shutil.which("dot")
    if dot_executable is None:
        logger.info("'dot' executable not found; cannot convert to SVG")
        return

    command = [dot_executable, "-Tsvg", GRAPH_DOT, "-o", GRAPH_SVG]
    subprocess.run(command, check=True)
    logger.info("wrote '%s'", GRAPH_SVG)
示例#5
0
def test_stack(ctx_factory, input_dims):
    """Compare generated code for ``pt.stack`` with ``np.stack``.

    :arg ctx_factory: called without arguments to obtain the context that
        the command queue is created on.
    :arg input_dims: number of leading entries of ``(2, 2, 2)`` used as the
        shape of both inputs.

    For each axis in ``0..input_dims`` the stacked expression is turned
    into a program, executed, and compared elementwise for exact equality.
    """
    from numpy.random import default_rng

    cl_ctx = ctx_factory()
    queue = cl.CommandQueue(cl_ctx)

    rng = default_rng()
    in_shape = (2, 2, 2)[:input_dims]
    x_in = rng.random(size=in_shape)
    y_in = rng.random(size=in_shape)

    namespace = pt.Namespace()
    x = pt.make_data_wrapper(namespace, x_in)
    y = pt.make_data_wrapper(namespace, y_in)

    for axis in range(input_dims + 1):
        stacked = pt.stack((x, y), axis=axis)
        prog = pt.generate_loopy(stacked, target=pt.PyOpenCLTarget(queue))

        # The program returns a pair; its second item holds one output.
        _, (out,) = prog()
        expected = np.stack((x_in, y_in), axis=axis)
        assert (out == expected).all()
示例#6
0
def main():
    """Run a stacked expression partitioned and unpartitioned and compare.

    The expression is partitioned using ``get_partition_id`` bound to the
    topological order of the graph, code is generated and executed per
    partition, and the result is checked against the output of code
    generated for the whole expression at once.
    """
    from functools import partial
    from pytato.visualization import get_dot_graph_from_partition

    x_in = np.random.randn(2, 2)
    x = pt.make_data_wrapper(x_in)
    stacked = pt.stack([x @ x.T, 2 * x, 42 + x])
    y = stacked + 55

    sorter = TopoSortMapper()
    sorter(y)

    # Find the partitions
    part_id_of = partial(get_partition_id, sorter.topological_order)
    outputs = pt.DictOfNamedArrays({"out": y})
    partition = find_partition(outputs, part_id_of)

    # Show the partitions
    get_dot_graph_from_partition(partition)

    # Execute the partitions
    ctx = cl.create_some_context()
    queue = cl.CommandQueue(ctx)

    programs = generate_code_for_partition(partition)
    context = execute_partition(partition, programs, queue)
    partitioned_results = [context[name] for name in outputs.keys()]

    # Execute the unpartitioned code for comparison
    whole_program = pt.generate_loopy(y)
    _, (reference,) = whole_program(queue)

    np.testing.assert_allclose([reference], partitioned_results)

    print("Partitioning test succeeded.")
示例#7
0
 def face_swap(self, vec):
     """Stack rolled copies of the two columns of *vec* along axis 1.

     The first stacked entry is ``vec[:, 1]`` rolled by +1, the second is
     ``vec[:, 0]`` rolled by -1.
     """
     shifted_col1 = pt.roll(vec[:, 1], +1)
     shifted_col0 = pt.roll(vec[:, 0], -1)
     return pt.stack((shifted_col1, shifted_col0), axis=1)
示例#8
0
def test_stack_input_validation():
    """Check which inputs ``pt.stack`` accepts and which it rejects.

    Accepted: three equal-shaped arrays (result shape is checked) and a
    single array stacked along axis 0 or 1.  Rejected with
    :class:`ValueError`: an empty sequence, arrays of differing shapes, and
    an axis of 3 for two 2-dimensional inputs.
    """
    namespace = pt.Namespace()

    # Use np.float64 explicitly: the ``np.float`` alias (which meant the
    # builtin ``float``, i.e. float64 as a dtype) was deprecated in
    # NumPy 1.20 and removed in 1.24.
    x = pt.make_placeholder(namespace,
                            name="x",
                            shape=(10, 10),
                            dtype=np.float64)
    y = pt.make_placeholder(namespace, name="y", shape=(1, 10),
                            dtype=np.float64)

    assert pt.stack((x, x, x), axis=0).shape == (3, 10, 10)

    # A single input is fine; these only need to not raise.
    pt.stack((x, ), axis=0)
    pt.stack((x, ), axis=1)

    with pytest.raises(ValueError):
        pt.stack(())

    with pytest.raises(ValueError):
        pt.stack((x, y))

    with pytest.raises(ValueError):
        pt.stack((x, x), axis=3)
示例#9
0
def test_array_dot_repr():
    """Compare ``repr`` of several array expressions with reference text.

    Spaces and newlines are removed from both sides before comparing, so
    the layout of the reference strings below does not matter.
    """
    x = pt.make_placeholder("x", (10, 4), np.int64)
    y = pt.make_placeholder("y", (10, 4), np.int64)

    def _assert_stripped_repr(ary: pt.Array, expected_repr: str):
        def _strip_layout(text):
            # Drop only spaces and newlines; every other character counts.
            return "".join(ch for ch in text if ch not in (" ", "\n"))

        actual = _strip_layout(repr(ary))
        expected = _strip_layout(expected_repr)
        assert actual == expected

    weighted_sum = 3*x + 4*y
    _assert_stripped_repr(
        weighted_sum,
        """
IndexLambda(
    expr=Sum((Subscript(Variable('_in0'),
                        (Variable('_0'), Variable('_1'))),
              Subscript(Variable('_in1'),
                        (Variable('_0'), Variable('_1'))))),
    shape=(10, 4),
    dtype='int64',
    bindings={'_in0': IndexLambda(expr=Product((3, Subscript(Variable('_in1'),
                                                             (Variable('_0'),
                                                              Variable('_1'))))),
                                  shape=(10, 4),
                                  dtype='int64',
                                  bindings={'_in1': Placeholder(shape=(10, 4),
                                                                dtype='int64',
                                                                name='x')}),
              '_in1': IndexLambda(expr=Product((4, Subscript(Variable('_in1'),
                                                             (Variable('_0'),
                                                              Variable('_1'))))),
                                  shape=(10, 4),
                                  dtype='int64',
                                  bindings={'_in1': Placeholder(shape=(10, 4),
                                                                dtype='int64',
                                                                name='y')})})""")

    rolled = pt.roll(x.reshape(2, 20).reshape(-1), 3)
    _assert_stripped_repr(
        rolled,
        """
Roll(
    array=Reshape(array=Reshape(array=Placeholder(shape=(10, 4),
                                                  dtype='int64',
                                                  name='x'),
                                newshape=(2, 20),
                                order='C'),
                  newshape=(40),
                  order='C'),
    shift=3, axis=0)""")

    masked = y * pt.not_equal(x, 3)
    _assert_stripped_repr(
        masked,
        """
IndexLambda(
    expr=Product((Subscript(Variable('_in0'),
                            (Variable('_0'), Variable('_1'))),
                  Subscript(Variable('_in1'),
                            (Variable('_0'), Variable('_1'))))),
    shape=(10, 4),
    dtype='int64',
    bindings={'_in0': Placeholder(shape=(10, 4), dtype='int64', name='y'),
              '_in1': IndexLambda(
                  expr=Comparison(Subscript(Variable('_in0'),
                                            (Variable('_0'), Variable('_1'))),
                                  '!=',
                                  3),
                  shape=(10, 4),
                  dtype=<class 'numpy.bool_'>,
                  bindings={'_in0': Placeholder(shape=(10, 4),
                                                dtype='int64',
                                                name='x')})})""")

    indexed = x[y[:, 2:3], x[2, :]]
    _assert_stripped_repr(
        indexed,
        """
AdvancedIndexInContiguousAxes(
    array=Placeholder(shape=(10, 4), dtype='int64', name='x'),
    indices=(BasicIndex(array=Placeholder(shape=(10, 4),
                                          dtype='int64',
                                          name='y'),
                        indices=(NormalizedSlice(start=0, stop=10, step=1),
                                 NormalizedSlice(start=2, stop=3, step=1))),
             BasicIndex(array=Placeholder(shape=(10, 4),
                                          dtype='int64',
                                          name='x'),
                        indices=(2, NormalizedSlice(start=0, stop=4, step=1)))))""")

    stacked = pt.stack([x[y[:, 2:3], x[2, :]].T, y[x[:, 2:3], y[2, :]].T])
    _assert_stripped_repr(
        stacked,
        """
Stack(
    arrays=(
        AxisPermutation(
            array=AdvancedIndexInContiguousAxes(
                array=Placeholder(shape=(10, 4),
                                  dtype='int64',
                                  name='x'),
                indices=(BasicIndex(array=(...),
                                    indices=(NormalizedSlice(start=0,
                                                             stop=10,
                                                             step=1),
                                             NormalizedSlice(start=2,
                                                             stop=3,
                                                             step=1))),
                         BasicIndex(array=(...),
                                    indices=(2,
                                             NormalizedSlice(start=0,
                                                             stop=4,
                                                             step=1))))),
            axis_permutation=(1, 0)),
        AxisPermutation(array=AdvancedIndexInContiguousAxes(
            array=Placeholder(shape=(10,
                                     4),
                              dtype='int64',
                              name='y'),
            indices=(BasicIndex(array=(...),
                                indices=(NormalizedSlice(start=0,
                                                         stop=10,
                                                         step=1),
                                         NormalizedSlice(start=2,
                                                         stop=3,
                                                         step=1))),
                     BasicIndex(array=(...),
                                indices=(2,
                                         NormalizedSlice(start=0,
                                                         stop=4,
                                                         step=1))))),
                        axis_permutation=(1, 0))), axis=0)
    """)