Пример #1
0
def test_normalize():
    """Check ``F.normalize`` against a NumPy L-p normalization reference."""
    cases = [
        {"input": np.random.random((2, 3, 12, 12)).astype(np.float32)}
        for _ in range(2)
    ]

    def np_normalize(x, p=2, axis=None, eps=1e-12):
        """Reference: divide ``x`` by its L-p norm, clamped below by ``eps``."""
        # The L-p norm is defined on absolute values; without ``abs`` an odd
        # ``p`` (e.g. p=1) yields a wrong norm for negative inputs.
        powered = np.abs(x) ** p
        if axis is None:
            norm = np.sum(powered) ** (1.0 / p)
        else:
            norm = np.sum(powered, axis=axis, keepdims=True) ** (1.0 / p)
        # Clamp so all-zero slices divide by ``eps`` instead of yielding NaN.
        return x / np.clip(norm, a_min=eps, a_max=np.inf)

    # TODO: the whole-tensor (axis=None) checks below are disabled; re-enable
    # or remove them.
    # # Test L-2 norm along all dimensions
    # opr_test(cases, F.normalize, ref_fn=np_normalize)

    # # Test L-1 norm along all dimensions
    # opr_test(cases, partial(F.normalize, p=1), ref_fn=partial(np_normalize, p=1))

    # Test L-2 norm along the second dimension
    opr_test(cases, partial(F.normalize, axis=1), ref_fn=partial(np_normalize, axis=1))

    # Test some norm == 0: zero one complete row along axis 3 in each case.
    cases[0]["input"][0, 0, 0, :] = 0
    cases[1]["input"][0, 0, 0, :] = 0
    opr_test(cases, partial(F.normalize, axis=3), ref_fn=partial(np_normalize, axis=3))
Пример #2
0
def test_broadcast():
    """Check ``F.broadcast_to`` output shapes and its rejection of bad targets."""
    input1_shape = (20, 30)
    output1_shape = (30, 20, 30)
    data1 = np.random.random(input1_shape).astype(np.float32)

    input2_shape = (10, 1)
    output2_shape = (20, 10, 20)
    data2 = np.random.random(input2_shape).astype(np.float32)

    def compare_fn(x, y):
        # ``y`` is the complete expected shape; compare every dimension.
        assert tuple(x.shape) == tuple(y)

    # A bare tuple "output" is unpacked by opr_test into one expected value per
    # result. The old ``x.shape[0] == y`` check relied on that, so it only ever
    # verified the leading dimension. Wrapping the shape in a list makes it a
    # single expected value.
    cases = [
        {"input": [data1, output1_shape], "output": [output1_shape]},
        {"input": [data2, output2_shape], "output": [output2_shape]},
    ]
    opr_test(cases, F.broadcast_to, compare_fn=compare_fn)

    x = F.ones((2, 1, 3))
    # Incompatible non-singleton dimension.
    with pytest.raises(RuntimeError):
        F.broadcast_to(x, (2, 3, 4))

    # Leading dimension 2 cannot become 4.
    with pytest.raises(RuntimeError):
        F.broadcast_to(x, (4, 1, 3))

    # Target has fewer dimensions than the input.
    with pytest.raises(RuntimeError):
        F.broadcast_to(x, (1, 3))
Пример #3
0
def common_test_reduce(opr, ref_opr):
    """Run reduction ``opr`` against the NumPy reference ``ref_opr``.

    Two float32 inputs of shapes (5, 6, 7) and (2, 9, 12) are used.

    - ``F.argmin`` / ``F.argmax``: the reference result is cast to int32 and
      checked on the default axis and on axes 0..2.
    - Any other reduction: checked on the default axis and on every axis in
      [-3, 3), once without ``keepdims`` and once with ``keepdims=True``.
    """
    cases = [
        {"input": np.random.random(shape).astype(np.float32)}
        for shape in ((5, 6, 7), (2, 9, 12))
    ]

    if opr in (F.argmin, F.argmax):
        # Index reductions: default axis first.
        opr_test(cases, opr, ref_fn=lambda x: ref_opr(x).astype(np.int32))
        for ax in range(3):
            # opr_test runs immediately, so the closure sees the current ``ax``.
            opr_test(
                cases,
                opr,
                ref_fn=lambda x: ref_opr(x, axis=ax).astype(np.int32),
                axis=ax,
            )
        return

    # Value reductions: default axis first.
    opr_test(cases, opr, ref_fn=ref_opr)
    for ax in range(-3, 3):
        # keepdims left at the operator's default
        opr_test(cases, opr, ref_fn=lambda x: ref_opr(x, axis=ax), axis=ax)
        # keepdims explicitly enabled
        opr_test(
            cases,
            opr,
            ref_fn=lambda x: ref_opr(x, axis=ax, keepdims=True),
            axis=ax,
            keepdims=True,
        )
Пример #4
0
def test_linspace():
    """Compare ``F.linspace`` with ``np.linspace`` on ascending and descending ranges."""

    def reference(start, end, num):
        return np.linspace(start, end, num, dtype=np.float32)

    ascending = [{"input": [1, 9, 9]}, {"input": [3, 10, 8]}]
    descending = [{"input": [9, 1, 9]}, {"input": [10, 3, 8]}]
    for group in (ascending, descending):
        opr_test(group, F.linspace, ref_fn=reference)
Пример #5
0
def test_matinv():
    """Compare ``F.matinv`` with ``np.linalg.inv`` on a single and a batched input."""
    single_shape = (5, 5)
    batched_shape = (3, 9, 9)
    single = np.random.random(single_shape).astype("float32")
    batched = np.random.random(batched_shape).astype("float32")

    # Adding n * I (entries are in [0, 1)) makes each matrix strictly
    # diagonally dominant, for numerical stability of the inverse.
    n_single = single_shape[0]
    single += (np.eye(n_single) * n_single).astype("float32")
    n_batched = batched_shape[1]
    batched += np.broadcast_to(
        (np.eye(n_batched) * n_batched).astype("float32"), batched_shape
    )

    def compare_fn(x, y):
        np.testing.assert_allclose(x.numpy(), y, rtol=1e-5)

    cases = [{"input": single}, {"input": batched}]
    opr_test(cases, F.matinv, compare_fn=compare_fn, ref_fn=np.linalg.inv)
Пример #6
0
def test_round():
    """Compare ``F.round`` with ``np.round`` on two 1-D float32 inputs in [0, 1)."""
    cases = [
        {"input": np.random.random((length,)).astype(np.float32)}
        for length in (15, 25)
    ]
    opr_test(cases, F.round, ref_fn=np.round)
Пример #7
0
def test_sqrt():
    """Compare ``F.sqrt`` with ``np.sqrt`` on two 1-D float32 inputs in [0, 1)."""
    inputs = [np.random.random((n,)).astype(np.float32) for n in (15, 25)]
    opr_test([{"input": arr} for arr in inputs], F.sqrt, ref_fn=np.sqrt)
Пример #8
0
def test_stack():
    """Compare ``F.stack`` of two tensors with ``np.stack`` along axes 0..2."""
    first, second, third = (
        np.random.random((3, 2, 2)).astype("float32") for _ in range(3)
    )
    cases = [{"input": [first, second]}, {"input": [first, third]}]

    for axis in range(3):
        # opr_test runs right away, so both closures see the current ``axis``.
        def run(a, b):
            return F.stack([a, b], axis=axis)

        opr_test(cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=axis))
Пример #9
0
def test_sort():
    """Check ``F.sort`` yields NumPy's sorted values and int32 ``argsort`` indices."""
    inputs = [
        np.random.random(shape).astype(np.float32) for shape in ((10, 3), (12, 2))
    ]
    cases = []
    for arr in inputs:
        expected = [np.sort(arr), np.argsort(arr).astype(np.int32)]
        cases.append({"input": arr, "output": expected})
    opr_test(cases, F.sort)
Пример #10
0
def test_concat():
    """Compare ``F.concat`` of two tensors with ``np.concatenate`` (default axis)."""
    first, second, third = (
        np.random.random((rows, 2, 3)).astype("float32") for rows in (5, 6, 7)
    )

    def run(lhs, rhs):
        return F.concat([lhs, rhs])

    cases = [{"input": [first, second]}, {"input": [first, third]}]
    opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]))
Пример #11
0
def test_round(is_varnode):
    """Compare ``F.round`` with ``np.round``, optionally inside a ``Network`` graph."""
    network = Network() if is_varnode else None

    cases = [
        {"input": np.random.random((length,)).astype(np.float32)}
        for length in (15, 25)
    ]
    opr_test(cases, F.round, ref_fn=np.round, network=network)
Пример #12
0
def test_binary_cross_entropy():
    """Check ``F.nn.binary_cross_entropy`` against precomputed expected losses.

    Each case is run twice with the same expected loss. The first run uses the
    default call on the raw predictions, which are presumably treated as
    logits. The second run passes ``sigmoid``-activated predictions with
    ``with_logits=False``.
    """

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def compare_fn(x, y):
        np.testing.assert_allclose(x.numpy(), y, atol=5e-4)

    def seeded_pair(shape):
        # Re-seed before every case so the draws (and thus the hard-coded
        # expected losses) are deterministic.
        np.random.seed(123)
        pred = np.random.uniform(size=shape).astype(np.float32)
        target = np.random.uniform(size=shape).astype(np.float32)
        return pred, target

    data1, label1 = seeded_pair((2, 2))
    expect1 = np.array(0.6361, dtype=np.float32)
    data2, label2 = seeded_pair((2, 3))
    expect2 = np.array(0.6750, dtype=np.float32)

    logit_cases = [
        {"input": [data1, label1], "output": expect1},
        {"input": [data2, label2], "output": expect2},
    ]
    opr_test(logit_cases, F.nn.binary_cross_entropy, compare_fn=compare_fn)

    prob_cases = [
        {"input": [sigmoid(data1), label1], "output": expect1},
        {"input": [sigmoid(data2), label2], "output": expect2},
    ]
    opr_test(
        prob_cases,
        partial(F.nn.binary_cross_entropy, with_logits=False),
        compare_fn=compare_fn,
    )
Пример #13
0
def test_tile(shape, reps, is_varnode):
    """Compare ``F.tile`` with ``np.tile`` for a parametrized shape and ``reps``."""
    network = Network() if is_varnode else None

    def tiled(x):
        return F.tile(inp=x, reps=reps)

    data = np.random.randn(*shape).astype("float32")
    opr_test(
        [{"input": data}],
        tiled,
        ref_fn=lambda x: np.tile(x, reps),
        network=network,
    )
Пример #14
0
def test_broadcast(is_varnode):
    """Check ``F.broadcast_to`` output shapes and its rejection of bad targets.

    Optionally runs inside a ``Network`` graph when ``is_varnode`` is set.
    """
    if is_varnode:
        network = Network()
    else:
        network = None

    input1_shape = (20, 30)
    output1_shape = (30, 20, 30)
    data1 = np.random.random(input1_shape).astype(np.float32)

    input2_shape = (10, 1)
    output2_shape = (20, 10, 20)
    data2 = np.random.random(input2_shape).astype(np.float32)

    input3_shape = (10, 10)
    output3_shape = (10, 10)
    data3 = np.random.random(input3_shape).astype(np.float32)

    def compare_fn(x, y):
        # ``y`` is the complete expected shape; compare every dimension.
        assert x._tuple_shape == tuple(y)

    # A bare tuple "output" is unpacked by opr_test into one expected value per
    # result. The old ``x._tuple_shape[0] == y`` check relied on that, so only
    # the leading dimension was verified. Wrapping the shape in a list makes it
    # a single expected value.
    cases = [
        {"input": [data1, output1_shape], "output": [output1_shape]},
        {"input": [data2, output2_shape], "output": [output2_shape]},
        {"input": [data3, output3_shape], "output": [output3_shape]},
    ]
    opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network)

    x = F.ones((2, 1, 3))
    # Incompatible non-singleton dimension.
    with pytest.raises(RuntimeError):
        F.broadcast_to(x, (2, 3, 4))

    # Leading dimension 2 cannot become 4.
    with pytest.raises(RuntimeError):
        F.broadcast_to(x, (4, 1, 3))

    # Target has fewer dimensions than the input.
    with pytest.raises(RuntimeError):
        F.broadcast_to(x, (1, 3))
Пример #15
0
def test_matmul():
    """Compare ``F.matmul`` with ``np.matmul`` on plain and batched operands."""
    shape1 = 3
    shape2 = 3
    shape3 = (3, 5)
    shape4 = (5, 6)
    data1 = np.random.random(shape1).astype("float32")
    data2 = np.random.random(shape2).astype("float32")
    data3 = np.random.random(shape3).astype("float32")
    data4 = np.random.random(shape4).astype("float32")

    cases = [
        {"input": [data1, data2]},
        {"input": [data2, data3]},
        {"input": [data3, data4]},
    ]
    opr_test(cases, F.matmul, ref_fn=np.matmul)

    batch_size = 10
    shape1 = (batch_size, 2, 3)
    shape2 = (batch_size, 3, 4)
    shape3 = (batch_size, 10, 4, 5)
    data1 = np.random.random(shape1).astype("float32")
    data2 = np.random.random(shape2).astype("float32")
    data3 = np.random.random(shape3).astype("float32")

    # Compare whole batched results against NumPy's broadcasting matmul.
    # The previous per-batch loop had two problems. Its compare_fn evaluated
    # ``x.numpy()[i, ...] == y`` without asserting, so nothing was checked.
    # Its reference ``np.matmul(x[i], y[i])`` also pairs x[i] with y[i, j],
    # whereas broadcasting (10, 3, 4) @ (10, 10, 4, 5) pairs x[j] with y[i, j].
    cases = [{"input": [data1, data2]}, {"input": [data2, data3]}]
    opr_test(cases, F.matmul, ref_fn=np.matmul)
Пример #16
0
def test_diag(is_varnode):
    """Compare ``F.diag`` with ``np.diag`` for offsets k in [-2, 2].

    Covers square, wide and tall 2-D inputs plus a 1-D input.
    """
    network = Network() if is_varnode else None

    cases = [
        {"input": [np.random.random(shp).astype("float32")]}
        for shp in [(10, 10), (6, 9), (8, 7), (8, )]
    ]

    for offset in range(-2, 3):
        # opr_test runs right away, so both closures see the current ``offset``.
        def run(data):
            return F.diag(data, k=offset)

        opr_test(cases,
                 run,
                 ref_fn=lambda x: np.diag(x, offset),
                 network=network)
Пример #17
0
def test_stack(is_varnode):
    """Compare ``F.stack`` of two tensors with ``np.stack`` along axes 0..2."""
    network = Network() if is_varnode else None

    first, second, third = (
        np.random.random((3, 2, 2)).astype("float32") for _ in range(3)
    )
    cases = [{"input": [first, second]}, {"input": [first, third]}]

    for axis in range(3):
        # opr_test runs right away, so both closures see the current ``axis``.
        def run(a, b):
            return F.stack([a, b], axis=axis)

        opr_test(cases,
                 run,
                 ref_fn=lambda x, y: np.stack([x, y], axis=axis),
                 network=network)
Пример #18
0
def test_concat(is_varnode):
    """Compare ``F.concat`` of two tensors with ``np.concatenate`` (default axis)."""
    network = Network() if is_varnode else None

    first, second, third = (
        np.random.random((rows, 2, 3)).astype("float32") for rows in (5, 6, 7)
    )

    def run(lhs, rhs):
        return F.concat([lhs, rhs])

    cases = [{"input": [first, second]}, {"input": [first, third]}]
    opr_test(cases,
             run,
             ref_fn=lambda x, y: np.concatenate([x, y]),
             network=network)
Пример #19
0
def test_roll(shape, shifts, axis, is_varnode):
    """Compare ``F.roll`` with ``np.roll`` for parametrized shape/shifts/axis."""
    network = Network() if is_varnode else None

    data = np.random.randn(*shape).astype("float32")

    def rolled(x):
        return F.roll(x, shifts, axis)

    opr_test([{"input": data}],
             rolled,
             ref_fn=lambda x: np.roll(x, shifts, axis),
             network=network)
Пример #20
0
def test_where():
    """Compare ``F.where`` with ``np.where``.

    Cases include inf/NaN in the selected-from tensor, all-true and all-false
    masks, and 0-d inputs.
    """
    triples = [
        (
            np.array([[1, 0], [0, 1]], dtype=np.bool_),
            np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32),
            np.array([[5, 6], [7, 8]], dtype=np.float32),
        ),
        (
            np.array([[1, 0, 1], [1, 0, 0], [1, 1, 0]], dtype=np.bool_),
            np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]],
                     dtype=np.float32),
            np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32),
        ),
        (
            np.array([1, 1, 1], dtype=np.bool_),
            np.array([1, 3, 2], dtype=np.float32),
            np.array([5, 6, 9], dtype=np.float32),
        ),
        (
            np.array([0, 0, 0], dtype=np.bool_),
            np.array([1, 3, 2], dtype=np.float32),
            np.array([5, 6, 9], dtype=np.float32),
        ),
        (
            np.array(1, dtype=np.bool_),
            np.array(1, dtype=np.float32),
            np.array(0, dtype=np.float32),
        ),
    ]
    cases = [{"input": list(triple)} for triple in triples]
    opr_test(cases, F.where, ref_fn=np.where, test_trace=True)
Пример #21
0
def test_matinv():
    """Compare ``F.matinv`` with ``np.linalg.inv`` on a single and a batched input."""
    shape1 = (5, 5)
    shape2 = (3, 9, 9)
    data1 = np.random.random(shape1).astype("float32")
    data2 = np.random.random(shape2).astype("float32")
    # Purely random matrices can be nearly singular, which makes the rtol=1e-5
    # comparison flaky. Adding n * I (entries are in [0, 1)) makes each matrix
    # strictly diagonally dominant, for numerical stability.
    data1 += (np.eye(shape1[0]) * shape1[0]).astype("float32")
    data2 += np.broadcast_to((np.eye(shape2[1]) * shape2[1]).astype("float32"),
                             shape2)

    cases = [
        {
            "input": data1
        },
        {
            "input": data2
        },
    ]

    opr_test(
        cases,
        F.matinv,
        compare_fn=lambda x, y: np.testing.assert_allclose(
            x.numpy(), y, rtol=1e-5),
        ref_fn=np.linalg.inv,
    )
Пример #22
0
def test_hinge_loss():
    """Compare ``F.nn.hinge_loss`` (L1 and L2 norms) with a NumPy reference.

    Reference: ``max(0, 1 - pred * label)``, optionally squared for L2, summed
    over axis 1 and averaged over the batch.
    """
    np.random.seed(123)

    def make_cases(norm_fn):
        cases = []
        for shape in [(2, 2), (2, 3)]:
            # Predictions span [-2, 2) so some margins ``1 - pred * label`` are
            # negative and the clamp at zero is actually exercised.
            data = np.random.uniform(-2.0, 2.0, size=shape).astype(np.float32)
            # randint's upper bound is exclusive: (0, 2) draws {0, 1}, which
            # maps to labels {-1, +1}. The old (0, 1) always produced -1.
            label = 2 * np.random.randint(0, 2, size=shape).astype(np.float32) - 1
            # np.clip(a, a_min, a_max); the old call had the arguments swapped
            # and never clamped anything.
            margin = np.clip(1 - data * label, 0, np.inf)
            expect = norm_fn(margin).sum(axis=1).mean()
            cases.append({"input": [data, label], "output": expect})
        return cases

    # case with L1 norm
    opr_test(make_cases(lambda m: m), F.nn.hinge_loss)

    # cases with L2 norm
    def hinge_loss_with_l2_norm(pred, label):
        return F.nn.hinge_loss(pred, label, "L2")

    opr_test(make_cases(np.square), hinge_loss_with_l2_norm)
Пример #23
0
def test_repeat(shape, repeats, axis, is_varnode):
    """Compare ``F.repeat`` with ``np.repeat`` for parametrized shape/repeats/axis."""
    network = Network() if is_varnode else None

    def repeat_func(inp):
        return F.repeat(inp=inp, repeats=repeats, axis=axis)

    if shape == ():
        # 0-d input: a scalar array built from a Python float.
        data = np.array(1.23)
    else:
        data = np.random.randn(*shape).astype("float32")

    opr_test(
        [{"input": data}],
        repeat_func,
        ref_fn=lambda inp: np.repeat(inp, repeats, axis),
        network=network,
    )
Пример #24
0
def test_linspace(is_varnode):
    """Compare ``F.linspace`` with ``np.linspace``.

    Covers ascending and descending ranges, and arguments mixing Python
    scalars with tensors.
    """
    network = Network() if is_varnode else None

    def reference(start, end, num):
        return np.linspace(start, end, num, dtype=np.float32)

    scalar_groups = (
        [{"input": [1, 9, 9]}, {"input": [3, 10, 8]}],
        [{"input": [9, 1, 9]}, {"input": [10, 3, 8]}],
    )
    for group in scalar_groups:
        opr_test(group, F.linspace, ref_fn=reference, network=network)

    # Some arguments are tensors, so the reference ignores its arguments and
    # always describes linspace(1, 9, 9).
    tensor_cases = [
        {"input": [1, make_tensor(9, network), 9]},
        {"input": [make_tensor(1, network), 9, make_tensor(9, network)]},
    ]
    opr_test(
        tensor_cases,
        F.linspace,
        ref_fn=lambda start, end, num: np.linspace(1, 9, 9, dtype=np.float32),
        network=network,
    )
Пример #25
0
def test_arange(is_varnode):
    """Compare ``F.arange`` with ``np.arange`` (float32).

    Covers ascending integer steps, descending steps and fractional steps.
    """
    network = Network() if is_varnode else None

    def reference(start, end, step):
        return np.arange(start, end, step, dtype=np.float32)

    groups = (
        [{"input": [1, 9, 1]}, {"input": [2, 10, 2]}],
        [{"input": [9, 1, -1]}, {"input": [10, 2, -2]}],
        [{"input": [9.3, 1.2, -0.5]}, {"input": [10.3, 2.1, -1.7]}],
    )
    for cases in groups:
        opr_test(cases, F.arange, ref_fn=reference, network=network)
Пример #26
0
def test_arange():
    """Compare ``F.arange`` with ``np.arange`` (float32).

    Covers ascending integer steps, descending steps and fractional steps.
    """

    def reference(start, end, step):
        return np.arange(start, end, step, dtype=np.float32)

    ascending = [{"input": [1, 9, 1]}, {"input": [2, 10, 2]}]
    descending = [{"input": [9, 1, -1]}, {"input": [10, 2, -2]}]
    fractional = [{"input": [9.3, 1.2, -0.5]}, {"input": [10.3, 2.1, -1.7]}]
    for cases in (ascending, descending, fractional):
        opr_test(cases, F.arange, ref_fn=reference)
Пример #27
0
def test_matmul():
    """Compare ``F.matmul`` with ``np.matmul`` across operand ranks and transposes."""
    # Rank-1 and rank-2 operands; each case multiplies neighbouring entries.
    vec_a = np.random.random(3).astype("float32")
    vec_b = np.random.random(3).astype("float32")
    mat_35 = np.random.random((3, 5)).astype("float32")
    mat_56 = np.random.random((5, 6)).astype("float32")
    plain = [vec_a, vec_b, mat_35, mat_56]
    opr_test(
        [{"input": [lhs, rhs]} for lhs, rhs in zip(plain, plain[1:])],
        F.matmul,
        ref_fn=np.matmul,
    )

    # Batched operands (NumPy broadcasting semantics), again neighbour pairs.
    batch_size = 10
    shapes = [
        (2, ),
        (batch_size, 2, 3),
        (batch_size, 3, 4),
        (batch_size, 10, 4, 2),
        (batch_size, 10, 2, 4),
    ]
    v2, b23, b34, b42, b24 = (
        np.random.random(shape).astype("float32") for shape in shapes
    )
    batched = [v2, b23, b34, b42, b24]
    opr_test(
        [{"input": [lhs, rhs]} for lhs, rhs in zip(batched, batched[1:])],
        F.matmul,
        ref_fn=np.matmul,
    )

    # transpose_b: the reference swaps the last two axes of the right operand.
    opr_test(
        [{"input": [v2, b42]}],
        F.matmul,
        ref_fn=lambda x, y: np.matmul(x, np.swapaxes(y, -1, -2)),
        transpose_b=True,
    )

    # transpose_a and transpose_b together.
    opr_test(
        [{"input": [b34, b23]}],
        F.matmul,
        ref_fn=lambda x, y: np.matmul(np.swapaxes(x, -1, -2),
                                      np.swapaxes(y, -1, -2)),
        transpose_a=True,
        transpose_b=True,
    )
Пример #28
0
def test_flatten():
    """Compare ``F.flatten`` (various ``start_axis``/``end_axis``) with NumPy reshapes.

    The expected outputs are the reshaped input data and are checked with the
    default value comparison. The previous shape-only ``compare_fn`` received
    one element of the expected shape tuple at a time. It therefore only
    verified the leading output dimension, which left ``start_axis`` and
    ``end_axis`` effectively untested.
    """
    data0_shape = (2, 3, 4, 5)
    data1_shape = (4, 5, 6, 7)
    data0 = np.random.random(data0_shape).astype(np.float32)
    data1 = np.random.random(data1_shape).astype(np.float32)

    # Flatten everything.
    cases = [
        {"input": data0, "output": data0.flatten()},
        {"input": data1, "output": data1.flatten()},
    ]
    opr_test(cases, F.flatten)

    # Keep axis 0, flatten the rest.
    cases = [
        {"input": data0, "output": data0.reshape(2, -1)},
        {"input": data1, "output": data1.reshape(4, -1)},
    ]
    opr_test(cases, F.flatten, start_axis=1)

    # Keep axes 0-1, flatten the rest.
    cases = [
        {"input": data0, "output": data0.reshape(2, 3, -1)},
        {"input": data1, "output": data1.reshape(4, 5, -1)},
    ]
    opr_test(cases, F.flatten, start_axis=2)

    # Flatten only axes 1..2.
    cases = [
        {"input": data0, "output": data0.reshape(2, -1, 5)},
        {"input": data1, "output": data1.reshape(4, -1, 7)},
    ]
    opr_test(cases, F.flatten, start_axis=1, end_axis=2)
Пример #29
0
def test_reshape_shape_inference(is_varnode):
    """Check ``reshape`` produces a (2, 2) result for known/unknown shape combos.

    The input is either a tensor of known shape (4,) or one produced by
    ``F.broadcast_to`` with a shape derived from a tensor value. The target
    shape is either built from tensors, built from Python ints, or built from
    Python ints with a ``-1`` placeholder.
    """
    if is_varnode:
        network = Network()
        # Symbolic shapes are switched off for the graph run; the previous
        # global setting is restored in the ``finally`` clause.
        saved_symbolic_shape = set_symbolic_shape(False)
    else:
        network = None

    try:
        x_shape_known = make_tensor([1, 2, 3, 4], network)
        x_shape_unknown = F.broadcast_to(make_tensor([1.0], network),
                                         shape=make_tensor([1, 1, 1, 1],
                                                           network).sum())
        tshp_unknown = astensor1d(
            (make_tensor([2], network), make_tensor([2], network)),
            x_shape_known)
        tshp_known = astensor1d((2, 2), x_shape_known)
        tshp_known_unspec = astensor1d((2, -1), x_shape_known)

        def check_shape(output, target):
            source = output.shape
            # ``shape`` may come back as a tensor; normalise to NumPy.
            if isinstance(source, tensor):
                source = source.numpy()
            np.testing.assert_equal(source, target.shape)

        def func(x, target_shape):
            return x.reshape(target_shape)

        combos = [
            (x_shape_known, tshp_unknown),
            (x_shape_unknown, tshp_unknown),
            (x_shape_known, tshp_known),
            (x_shape_known, tshp_known_unspec),
            (x_shape_unknown, tshp_known),
            (x_shape_unknown, tshp_known_unspec),
        ]
        cases = [
            {"input": [x, tshp], "output": [np.zeros((2, 2))]}
            for x, tshp in combos
        ]
        opr_test(cases,
                 func,
                 compare_fn=check_shape,
                 test_trace=True,
                 network=network)
    finally:
        # Restore the global setting even when an assertion above fails, so
        # later tests do not inherit symbolic_shape=False.
        if is_varnode:
            set_symbolic_shape(saved_symbolic_shape)
Пример #30
0
def test_flatten(is_varnode):
    """Compare ``F.flatten`` (various ``start_axis``/``end_axis``) with NumPy reshapes."""
    network = Network() if is_varnode else None

    data0 = np.random.random((2, 3, 4, 5)).astype(np.float32)
    data1 = np.random.random((4, 5, 6, 7)).astype(np.float32)

    # Each entry: (extra F.flatten kwargs, equivalent NumPy reshape of input).
    variants = [
        ({}, lambda a: a.flatten()),
        ({"start_axis": 1}, lambda a: a.reshape(a.shape[0], -1)),
        ({"start_axis": 2}, lambda a: a.reshape(a.shape[0], a.shape[1], -1)),
        ({"start_axis": 1, "end_axis": 2},
         lambda a: a.reshape(a.shape[0], -1, a.shape[-1])),
    ]
    for kwargs, expected_of in variants:
        cases = [{"input": d, "output": expected_of(d)} for d in (data0, data1)]
        opr_test(cases, F.flatten, network=network, **kwargs)