Example #1
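The snippets below are excerpts from popart.ir (PopART's experimental Python IR) test files. They assume a common set of imports along these lines (a sketch; exact module paths depend on the PopART version), plus test-local helpers such as contains_op_of_type, Linear and _TENSOR_SHAPE that are not shown here:

import numpy as np
import pytest

import popart._internal.ir as _ir
import popart.ir as pir
import popart.ir.ops as ops
from popart.ir import dtypes
from popart.ir.tensor import Constant
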
def build_model_with_dot_checkpoints(ir: pir.Ir) -> None:
    """Build a model with two dot_checkpoints.

    Args:
        ir (pir.Ir): The IR to write the model into.
    """
    main = ir.main_graph()

    with main:
        a_h2d = pir.h2d_stream(_TENSOR_SHAPE, pir.float32, name="a_stream")
        a = ops.host_load(a_h2d, "a")

        b = pir.variable(np.random.rand(*_TENSOR_SHAPE).astype(np.float32),
                         name="b")
        c = ops.add(a, b)
        ir.dot_checkpoint("Foo")

        d = pir.variable(np.random.rand(*_TENSOR_SHAPE).astype(np.float32),
                         name="d")
        e = ops.mul(c, d)
        ir.dot_checkpoint("Bar")

        f = ops.gelu(e)

        f_d2h = pir.d2h_stream(_TENSOR_SHAPE, pir.float32, name="f_stream")
        ops.host_store(f_d2h, f)
Example #2
    def test_repeat_error(self, repeat_count: int):
        """Test an error is thrown with incorrect repeat_count

        Args:
            repeat_count (int): Number of times to repeat.
        """
        ir = pir.Ir()
        main = ir.main_graph()
        with main:
            h2d = pir.h2d_stream((2, 16), pir.dtypes.float32)
            x = ops.host_load(h2d, "x")

            W = pir.variable(np.random.normal(0, 0.1, (16, 16)), name="W")
            b = pir.variable(np.zeros(16), name="b")

            linear = Linear()
            linear_graph = ir.create_graph(linear, x, out_features=16)

            with pytest.raises(ValueError) as e_info:
                y = ops.repeat(linear_graph,
                               repeat_count,
                               x,
                               subgraph_in_to_parent_in={
                                   linear.W: W,
                                   linear.b: b
                               })
            assert e_info.value.args[0].startswith(
                "Repeat trip count for repeat of")
Example #3
def test_tensor_id_conflict():
    ir = pir.Ir()
    main = ir.main_graph()
    with main:
        name0 = pir.variable(1, name="tensor").id
        name1 = pir.variable(1, name="tensor").id
        name2 = pir.constant(1, name="tensor").id
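    # The first tensor keeps the requested name; later name clashes are
    # resolved automatically, so all three ids end up unique.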
    assert name0 == "tensor"
    ids = [name0, name1, name2]
    assert len(ids) == len(set(ids))
Example #4
    def test_cmp(self):
        ir = pir.Ir()
        main = ir.main_graph()

        with main:
            a = pir.variable(1)
            b = pir.variable(1)
            assert a != b  # test __eq__
            assert len(set([a, b])) == 2  # test __hash__
            str(a)  # test __repr__
Example #5
    def test_dunder(self, available_memory_proportion, serialise_mode,
                    serialise_factor, output_type, partials_type):
        ir = pir.Ir()
        # The parametrized arguments are not used in this test.
        del available_memory_proportion, serialise_mode, serialise_factor, output_type, partials_type
        with ir.main_graph():
            a = pir.variable(np.random.rand(4, 4))
            b = pir.variable(np.random.rand(4, 4))
            c = a @ b
        assert len(ir.main_graph().get_tensors()) == 3
        assert len(ir.main_graph().get_variables()) == 2
        assert contains_op_of_type("MatMul", _ir.op.MatMulOp, ir.main_graph())
Example #6
    def test_fn(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            a = pir.variable(True, pir.bool)
            b = pir.variable(False, pir.bool)
            c = ops.logical_and(a, b)
        assert len(g.get_tensors()) == 3
        assert len(g.get_variables()) == 2
        assert contains_op_of_type("And", _ir.op.AndOp, g)
Example #7
    def test_ignore_index(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            a = pir.variable(np.zeros((2, 10)), pir.float32)
            b = pir.variable(np.zeros(2), pir.float32)
            c = ops.nll_loss_with_softmax_grad(a, b, ignore_index=5)
        assert len(g.get_tensors()) == 5
        assert contains_op_of_type("NlllWithSoftmaxGradDirect",
                                   _ir.op.NlllWithSoftmaxGradDirectOp, g)
Example #8
    def test_needs_casting(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            a = pir.variable(1, pir.int32)
            b = pir.variable(0, pir.int32)
            c = ops.logical_and(a, b)
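        # The int32 inputs are first cast to bool, adding two Cast outputs
        # (hence 5 tensors in total rather than 3).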
        assert len(g.get_tensors()) == 5
        assert len(g.get_variables()) == 2
        assert contains_op_of_type("And", _ir.op.AndOp, g)
Example #9
    def test_fn(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            a = pir.variable(1)
            b = pir.variable(2)
            c = ops.sub(a, b)
        assert len(g.get_tensors()) == 3
        assert len(g.get_variables()) == 2
        assert contains_op_of_type("Sub", _ir.op.SubtractOp, g)
Example #10
    def test_dunder(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            a = pir.variable(1)
            b = pir.variable(2)
            c = a - b
        assert len(g.get_tensors()) == 3
        assert len(g.get_variables()) == 2
        assert contains_op_of_type("Sub", _ir.op.SubtractOp, g)
Example #11
    def test_dunder(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            a = pir.variable(1)
            b = pir.variable(2)
            c = a * b
        assert len(g.get_tensors()) == 3
        assert len(g.get_variables()) == 2
        assert contains_op_of_type("Mul", _ir.op.MulOp, g)
Example #12
    def test_fn(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            x = pir.variable(0)
            seed = pir.variable(np.array([32, 32]), dtype=dtypes.uint32)
            c = ops.dropout(x, seed, 0.3)
        assert len(g.get_tensors()) == 3
        assert len(g.get_variables()) == 2
        assert contains_op_of_type("Dropout", _ir.op.DropoutOp, g)
Example #13
    def test_adamax_updater_invalid(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            m = pir.variable(1, name='m')
            v = pir.variable(2, name='v')
            with pytest.raises(ValueError) as excinfo:
                updater = ops.var_updates.adamax_updater(m, v)
            message = str(excinfo.value)
            assert "AdaMax requires time_step not None." in message
Example #14
def test_tensor_id_conflict_between_Ir():
    ir1 = pir.Ir()
    with ir1.main_graph():
        t1 = pir.variable(1, dtype=pir.float32, name="tensor")

    ir2 = pir.Ir()
    with ir2.main_graph():
        t2 = pir.variable(1, dtype=pir.float32, name="tensor")

    assert 2 == len(set([t1, t2]))  # test __hash__
    assert t1 != t2  # test __eq__
Example #15
    def test_dampened_add_tensor(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            a = pir.variable(1)
            b = pir.constant(2)
            factor = pir.variable(0.9)
            c = ops.var_updates.accumulate_(a, b, factor)
        assert contains_op_of_type("Accumulate", _ir.op.AccumulateOp, g)
        op = g._pb_graph.getOps()[0]
        assert op.getAccumulationType() == _ir.AccumulationType.DampenedAdd
Example #16
    def test_layer_norm(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            x = pir.variable(np.ones((2, 4)))
            weight = pir.variable(np.ones(4))
            bias = pir.variable(np.zeros(4))
            y = ops.layer_norm(x, weight, bias)
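        # layer_norm lowers to GroupNormalization, which has three outputs
        # (y, mean, inverse std dev): 3 inputs + 3 outputs = 6 tensors.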
        assert len(g.get_tensors()) == 6
        assert len(g.get_variables()) == 3
        assert contains_op_of_type("GroupNormalization", _ir.op.GroupNormOp, g)
Example #17
    def test_mean(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            a = pir.variable(1)
            b = pir.constant(2)
            step = pir.variable(0)
            c = ops.var_updates.accumulate_mean_(a, b, step)
        assert contains_op_of_type("Accumulate", _ir.op.AccumulateOp, g)
        op = g._pb_graph.getOps()[0]
        assert op.getAccumulationType() == _ir.AccumulationType.Mean
Example #18
    def test_fn(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            t = pir.variable([[1, 2], [3, 4]])
            indices = pir.variable([[0, 1], [1, 0]], dtype=dtypes.int32)
            c = ops.gather(t, indices)

        assert len(g.get_tensors()) == 3
        assert len(g.get_variables()) == 2
        assert contains_op_of_type("Gather", _ir.op.GatherOp, g)
Example #19
    def test_adamax_updater(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            m = pir.variable(1, name='m')
            v = pir.variable(2, name='v')
            t = pir.variable(1, name='t')

            updater = ops.var_updates.adamax_updater(m, v, time_step=t)
        assert len(g.get_tensors()) == 4
        assert contains_op_of_type("AdamUpdater", _ir.op.AdamUpdaterOp, g)
Example #20
    def test_lamb_updater_no_bias_no_wd(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            m = pir.variable(1, name='m')
            v = pir.variable(2, name='v')
            updater = ops.var_updates.lamb_updater(m, v)

        assert len(g.get_tensors()) == 3
        assert contains_op_of_type("AdamUpdater", _ir.op.AdamUpdaterOp, g)
        op = g._pb_graph.getOps()[0]
        assert op.isOptimizerOp()
Example #21
    def test__ensure_tensor(self):
        """Test the `_ensure_tensor()` method."""
        ir = pir.Ir()
        main = ir.main_graph()

        with main:
            a = pir.variable(1)
            b = pir.variable(2)
            c = a._ensure_tensor(b)
            d = a._ensure_tensor(3)
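            # b is already a Tensor, so it is returned unchanged; the plain
            # Python 3 is wrapped in a Constant matching a's dtype.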

            assert c == b
            assert isinstance(d, Constant)
            assert d.dtype == a.dtype
Example #22
    def test_adam_updater_bias_invalid(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            m = pir.variable(1, name='m')
            v = pir.variable(2, name='v')
            t = pir.variable(2, name='t')
            b1 = 0.9
            with pytest.raises(ValueError) as excinfo:
                updater = ops.var_updates.adam_updater(m,
                                                       v,
                                                       time_step=t,
                                                       beta1=b1)
            message = str(excinfo.value)
        assert "Bias correction requires both beta1 and beta2 not None." in message
Example #23
    def test_adam_wd_updater(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            w = pir.variable(1, name='w')
            m = pir.variable(1, name='m')
            v = pir.variable(2, name='v')
            wd = pir.constant(0.2, name='wd')

            updater = ops.var_updates.adam_updater(m,
                                                   v,
                                                   weight=w,
                                                   weight_decay=wd)
        assert len(g.get_tensors()) == 5
        assert contains_op_of_type("AdamUpdater", _ir.op.AdamUpdaterOp, g)
Example #24
    def test_adam_wd_updater_invalid(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            m = pir.variable(1, name='m')
            v = pir.variable(2, name='v')
            t = pir.variable(1, name='t')
            wd = pir.constant(0.2, name='wd')
            with pytest.raises(ValueError) as excinfo:
                updater = ops.var_updates.adam_updater(m,
                                                       v,
                                                       time_step=t,
                                                       weight_decay=wd)
            message = str(excinfo.value)
        assert "Weight decay requires weight to be not None." in message
Example #25
    def test_adam_bias_wd_updater(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            m = pir.variable(1, name='m')
            v = pir.variable(2, name='v')
            w = pir.variable(1, name='w')
            t = pir.variable(2, name='t')
            wd = pir.constant(0.2, name='wd')
            b1 = 0.9
            b2 = 0.99
            updater = ops.var_updates.adam_updater(m, v, w, t, wd, b1, b2)

        assert len(g.get_tensors()) == 6
        assert contains_op_of_type("AdamUpdater", _ir.op.AdamUpdaterOp, g)
Example #26
    def test_fn(self):
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            t = pir.variable(np.random.rand(3, 5, 7))
            index = pir.variable(np.array((1, 2)))
            axes = [0, 2]
            sizes = [1, 3]
            no_overlap = True
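            # Slice axes 0 and 2 down to sizes 1 and 3; axis 1 keeps its full extent.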
            c = ops.dynamic_slice(t, index, axes, sizes, no_overlap)

        assert c.shape == (sizes[0], t.shape[1], sizes[1])
        assert len(g.get_tensors()) == 3
        assert contains_op_of_type("DynamicSlice",
                                   _ir.op.dynamic.DynamicSliceOp, g)
Example #27
    def test_get_ir(self):
        ir = pir.Ir()
        main = ir.main_graph()

        with main:
            a = pir.variable(1)
            assert a.ir() == ir
Example #28
def test_hook():
    ir = pir.Ir()
    g = ir.main_graph()

    called = False

    def hook(_):
        nonlocal called
        called = True

    handle = g.register_op_created_hook(hook)

    with g:
        x = pir.variable(1)
        x = x + 1

    assert called
    called = False

    # Creating this graph will create
    # an AddOp on the new graph.
    # Ensure this does not trigger the hook.
    sg = ir.create_graph(lambda y: y + 1, x)
    assert not called

    g.remove_op_created_hook(handle)
    with g:
        x = x + 1
    assert not called
Example #29
    def test_by_ref(self):
        ir = pir.Ir()

        def foo(x: pir.TensorByRef, y: pir.Tensor):
            return ops.var_updates.accumulate_(x, y)

        with ir.main_graph():
            v1 = pir.variable(1)
            v2 = pir.variable(2)

            g = ir.create_graph(foo, v1, v2)
            info = ops.call_with_info(g, v1, v2)
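            # accumulate_ modifies its first argument in place, and only x
            # was annotated TensorByRef.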

        assert len(g._by_ref_inputs) == 1
        assert info._op.modifiesIndex(0)
        assert not info._op.modifiesIndex(1)
Example #30
def test_create_graph():
    ir = pir.Ir()

    def foo(x: pir.TensorByRef, y: pir.Tensor, c: int):
        return (x * c) + y

    with ir.main_graph():
        v1 = pir.variable(1)
        v2 = pir.variable(2)

        g = ir.create_graph(foo, v1, v2, 5)
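        # Only the TensorByRef-annotated parameter x becomes a by-ref input
        # of the subgraph.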

    assert len(g._by_ref_inputs) == 1
    x = g.get_input_tensors()[0]
    assert x == g._by_ref_inputs.pop()
    assert x.name == "x"