def test_print_add():
    """Print the output of an ``Add`` op and verify the captured text.

    The network rebinds ``x`` to ``x + y`` before printing, so the line
    labelled ``input_x:`` shows the sum (7) and ``input_y:`` shows 4.
    """
    class Print_Add(Cell):
        def __init__(self):
            super().__init__()
            self.print = P.Print()
            self.add = P.Add()

        def construct(self, x, y):
            # `x` is overwritten with the sum before printing.
            x = self.add(x, y)
            self.print("input_x:", x, "input_y:", y)
            return x

    expected = {
        'input_x:\nTensor(shape=[], dtype=Int32, value=7)\n'
        'input_y:\nTensor(shape=[], dtype=Int32, value=4)'
    }

    captured = Capture()
    with capture(captured):
        lhs = Tensor(3, dtype=ms.int32)
        rhs = Tensor(4, dtype=ms.int32)
        total = Tensor(7, dtype=ms.int32)
        network = Print_Add()
        result = network(lhs, rhs)
        # NOTE(review): presumably lets the Print op's output reach stdout
        # before the capture context closes — confirm.
        time.sleep(0.1)
        np.testing.assert_array_equal(result.asnumpy(), total.asnumpy())

    check_output(captured.output, expected)
def test_print_for():
    """Print inside a ``for`` loop and verify every iteration is logged.

    ``y`` starts as ``x + y`` (2 + 4 = 6) and is incremented three times,
    so the loop prints ``y`` as 7, 8 and 9. The network returns 9.
    """
    class Print_For(Cell):
        def __init__(self):
            super().__init__()
            self.print = P.Print()

        def construct(self, x, y):
            y = x + y
            self.print("input_x before:", x, "input_y before:", y)
            for _ in range(3):
                # One print per iteration, after the increment.
                y = y + 1
                self.print("input_x after:", x, "input_y after:", y)
            return y

    cap = Capture()
    with capture(cap):
        input_x = Tensor(2, dtype=ms.int32)
        input_y = Tensor(4, dtype=ms.int32)
        expect = Tensor(9, dtype=ms.int32)
        net = Print_For()
        out = net(input_x, input_y)
        # NOTE(review): presumably lets the Print op's output reach stdout
        # before the capture context closes — confirm.
        time.sleep(0.1)
        np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())

    # An ordered list (not a set) keeps the expected lines in print order.
    # That gives check_output a deterministic iteration order, independent
    # of string-hash randomization.
    patterns = [
        'input_x before:\nTensor(shape=[], dtype=Int32, value=2)\n'
        'input_y before:\nTensor(shape=[], dtype=Int32, value=6)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=7)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=8)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=9)',
    ]
    check_output(cap.output, patterns)
def test_print_if():
    """Print inside an ``if`` branch and verify both prints appear.

    With ``x = 3`` and ``y = 4`` the branch is taken. The second print runs
    before the increment, so both lines show ``x = 3``. The network
    returns 4.
    """
    class Print_If(Cell):
        def __init__(self):
            super().__init__()
            self.print = P.Print()

        def construct(self, x, y):
            self.print("input_x before:", x, "input_y before:", y)
            if x < y:
                # Printed before `x` is incremented.
                self.print("input_x after:", x, "input_y after:", y)
                x = x + 1
            return x

    cap = Capture()
    with capture(cap):
        input_x = Tensor(3, dtype=ms.int32)
        input_y = Tensor(4, dtype=ms.int32)
        expect = Tensor(4, dtype=ms.int32)
        net = Print_If()
        out = net(input_x, input_y)
        # NOTE(review): presumably lets the Print op's output reach stdout
        # before the capture context closes — confirm.
        time.sleep(0.1)
        np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())

    # An ordered list (not a set) keeps the expected lines in print order
    # and gives check_output a deterministic iteration order.
    patterns = [
        'input_x before:\nTensor(shape=[], dtype=Int32, value=3)\n'
        'input_y before:\nTensor(shape=[], dtype=Int32, value=4)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=3)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=4)',
    ]
    check_output(cap.output, patterns)
def test_print_assign_add():
    """Print a Parameter before and after assigning to it inside ``construct``.

    ``para`` is initialised to 1, printed, assigned ``x`` (3), and printed
    again. The network then returns ``para + y`` (3 + 4 = 7).
    """
    class Print_Assign_Add(Cell):
        def __init__(self):
            super().__init__()
            self.print = P.Print()
            self.add = P.Add()
            self.para = Parameter(Tensor(1, dtype=ms.int32), name='para')

        def construct(self, x, y):
            self.print("before:", self.para)
            # Assigning to the Parameter inside construct; the next print
            # is expected to observe the new value (3).
            self.para = x
            self.print("after:", self.para)
            x = self.add(self.para, y)
            return x

    cap = Capture()
    with capture(cap):
        input_x = Tensor(3, dtype=ms.int32)
        input_y = Tensor(4, dtype=ms.int32)
        expect = Tensor(7, dtype=ms.int32)
        net = Print_Assign_Add()
        out = net(input_x, input_y)
        # NOTE(review): presumably lets the Print op's output reach stdout
        # before the capture context closes — confirm.
        time.sleep(0.1)
        np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())

    # An ordered list (not a set) keeps the expected lines in print order
    # and gives check_output a deterministic iteration order.
    patterns = [
        'before:\nTensor(shape=[], dtype=Int32, value=1)',
        'after:\nTensor(shape=[], dtype=Int32, value=3)',
    ]
    check_output(cap.output, patterns)
def test_print_assign_while():
    """Print inside a ``while`` loop that assigns to a Parameter.

    Starting from ``x = 1``, ``y = 4`` and ``para = 0``, each iteration sets
    ``para = x`` and then ``x = para + 1``. The loop therefore runs three
    times, printing ``(x, para)`` as (2, 1), (3, 2) and (4, 3). The network
    returns 4.
    """
    class Print_Assign_While(Cell):
        def __init__(self):
            super().__init__()
            self.print = P.Print()
            self.para = Parameter(Tensor(0, dtype=ms.int32), name='para')

        def construct(self, x, y):
            self.print("input_x before:", x, "input_y before:", y,
                       "para before:", self.para)
            while x < y:
                # Parameter assignment inside the loop body, then `x` is
                # derived from the freshly assigned value.
                self.para = x
                x = self.para + 1
                self.print("input_x after:", x, "input_y after:", y,
                           "para after:", self.para)
            return x

    cap = Capture()
    with capture(cap):
        input_x = Tensor(1, dtype=ms.int32)
        input_y = Tensor(4, dtype=ms.int32)
        expect = Tensor(4, dtype=ms.int32)
        net = Print_Assign_While()
        out = net(input_x, input_y)
        # NOTE(review): presumably lets the Print op's output reach stdout
        # before the capture context closes — confirm.
        time.sleep(0.1)
        np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())

    # An ordered list (not a set) keeps the expected lines in print order
    # (pre-loop line, then one per iteration). That gives check_output a
    # deterministic iteration order.
    patterns = [
        'input_x before:\nTensor(shape=[], dtype=Int32, value=1)\n'
        'input_y before:\nTensor(shape=[], dtype=Int32, value=4)\n'
        'para before:\nTensor(shape=[], dtype=Int32, value=0)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=4)\n'
        'para after:\nTensor(shape=[], dtype=Int32, value=1)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=3)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=4)\n'
        'para after:\nTensor(shape=[], dtype=Int32, value=2)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=4)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=4)\n'
        'para after:\nTensor(shape=[], dtype=Int32, value=3)',
    ]
    check_output(cap.output, patterns)
def test_print():
    """Print two scalar tensors and verify the captured text.

    The network prints its inputs unchanged (3 and 4) and returns ``x``.
    The return value is not checked here.
    """
    class Print(Cell):
        def __init__(self):
            super().__init__()
            self.print = P.Print()

        def construct(self, x, y):
            # Inputs are printed as-is; no arithmetic happens first.
            self.print("input_x:", x, "input_y:", y)
            return x

    expected = {
        'input_x:\nTensor(shape=[], dtype=Int32, value=3)\n'
        'input_y:\nTensor(shape=[], dtype=Int32, value=4)'
    }

    captured = Capture()
    with capture(captured):
        first = Tensor(3, dtype=ms.int32)
        second = Tensor(4, dtype=ms.int32)
        network = Print()
        network(first, second)
        # NOTE(review): presumably lets the Print op's output reach stdout
        # before the capture context closes — confirm.
        time.sleep(0.1)

    check_output(captured.output, expected)