Пример #1
0
def test_set_symbolic_shape():
    """Check that ``set_symbolic_shape`` changes what ``var.shape`` yields.

    A traced ``relu(a * 2)`` function is dumped into an in-memory buffer and
    loaded back with ``Net.load``.  For the first input var of the loaded
    network the test then asserts that:

    * with the option set to ``True``, ``var.shape`` is a ``VarNode``;
    * with the option set to ``False``, ``var.shape`` equals
      ``var.partial_shape``.

    The value returned by the first ``set_symbolic_shape`` call is treated as
    the previous setting and is written back in a ``finally`` clause, so a
    failing assertion does not leave the changed option behind for other
    tests.
    """

    a = Tensor([1.0, 2.0])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(a):
        return F.relu(a * 2)

    # The traced function is invoked once before being dumped.
    # NOTE(review): presumably dump() needs a recorded run -- confirm.
    fwd(a)
    orig_model = io.BytesIO()
    fwd.dump(
        orig_model,
        arg_names=["a"],
        output_names=["o"],
        optimize_for_inference=False,
    )
    # Rewind the buffer so Net.load reads from the beginning.
    orig_model.seek(0)
    net = Net.load(orig_model)
    var_a = net.input_vars[0]

    saved_symbolic_shape = set_symbolic_shape(True)
    try:
        assert isinstance(var_a.shape, VarNode)
        set_symbolic_shape(False)
        assert var_a.shape == var_a.partial_shape
    finally:
        # Always put the option back, even when an assertion above fails.
        set_symbolic_shape(saved_symbolic_shape)
Пример #2
0
def test_squeeze(is_varnode):
    """Compare ``F.squeeze`` against ``np.squeeze`` for several ``axis`` values.

    Args:
        is_varnode: when truthy, the input is built against a fresh
            ``Network`` and the symbolic-shape option is switched to
            ``False`` for the duration of the test; otherwise ``network`` is
            ``None``.  (Presumably supplied by a pytest fixture or
            parametrization -- confirm in the test module.)

    The option is restored in a ``finally`` clause so that a failing
    comparison does not leak the modified setting into other tests.
    """
    if is_varnode:
        network = Network()
        saved_symbolic_shape = set_symbolic_shape(False)
    else:
        network = None

    try:
        # Shape (1, 2, 3, 1): the first and last axes have length 1.
        x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
        xx = make_tensor(x, network)

        # None, a positive index, a negative index, and a tuple of axes.
        for axis in [None, 3, -4, (3, -4)]:
            y = np.squeeze(x, axis)
            yy = F.squeeze(xx, axis)
            np.testing.assert_equal(y, yy.numpy())
    finally:
        if is_varnode:
            set_symbolic_shape(saved_symbolic_shape)
Пример #3
0
def test_split_basic(is_varnode):
    """Check ``F.split`` results against ``np.split`` and its error handling.

    Args:
        is_varnode: when truthy, the input is built against a fresh
            ``Network`` and the symbolic-shape option is switched to
            ``False`` for the duration of the test; otherwise ``network`` is
            ``None``.  (Presumably supplied by a pytest fixture or
            parametrization -- confirm in the test module.)

    The test splits a ``(2, 3, 4, 5)`` array along axis 3 both by section
    count (``2``) and by explicit index (``[3]``), and compares the first two
    pieces with ``np.split(data, [3, 5], axis=3)``.  It then verifies that
    two invalid requests raise ``ValueError``.

    The symbolic-shape option is restored in a ``finally`` clause so that a
    failing assertion does not leak the modified setting into other tests.
    """
    if is_varnode:
        network = Network()
        saved_symbolic_shape = set_symbolic_shape(False)
    else:
        network = None

    try:
        data = np.random.random((2, 3, 4, 5)).astype(np.float32)
        inp = make_tensor(data, network)

        mge_out0 = F.split(inp, 2, axis=3)
        mge_out1 = F.split(inp, [3], axis=3)

        np_out = np.split(data, [3, 5], axis=3)

        assert len(mge_out0) == 2
        assert len(mge_out1) == 2

        np.testing.assert_equal(mge_out0[0].numpy(), np_out[0])
        np.testing.assert_equal(mge_out1[0].numpy(), np_out[0])

        np.testing.assert_equal(mge_out0[1].numpy(), np_out[1])
        np.testing.assert_equal(mge_out1[1].numpy(), np_out[1])

        # Only the call that is expected to raise lives inside ``try``; the
        # ``else`` branch fails the test when no exception was raised.
        try:
            F.split(inp, 4)
        except ValueError:
            pass
        else:
            raise AssertionError("F.split(inp, 4) did not raise ValueError")

        try:
            F.split(inp, [3, 2, 5], axis=3)
        except ValueError as e:
            # The message is compared verbatim, including its spelling.
            # NOTE(review): presumably mirrors the library's own message --
            # do not "correct" it here without changing the library.
            assert str(e) == "Invalid nsplits_or_secions: [3, 2, 5]"
        else:
            raise AssertionError(
                "F.split(inp, [3, 2, 5], axis=3) did not raise ValueError")
    finally:
        if is_varnode:
            set_symbolic_shape(saved_symbolic_shape)
Пример #4
0
def test_reshape_shape_inference(is_varnode):
    """Check the output shape of ``reshape`` for several input/target pairs.

    Args:
        is_varnode: when truthy, tensors are built against a fresh
            ``Network`` and the symbolic-shape option is switched to
            ``False`` for the duration of the test; otherwise ``network`` is
            ``None``.  (Presumably supplied by a pytest fixture or
            parametrization -- confirm in the test module.)

    Every combination below is passed to ``opr_test`` with a comparison
    function that only looks at shapes; the expected shape is ``(2, 2)`` in
    all cases.

    The symbolic-shape option is restored in a ``finally`` clause so that a
    failing check does not leak the modified setting into other tests.
    """
    if is_varnode:
        network = Network()
        saved_symbolic_shape = set_symbolic_shape(False)
    else:
        network = None

    try:
        x_shape_known = make_tensor([1, 2, 3, 4], network)
        # The broadcast target shape is itself a computed value (a sum of
        # ones), hence the "unknown" in the name.
        x_shape_unknown = F.broadcast_to(
            make_tensor([1.0], network),
            shape=make_tensor([1, 1, 1, 1], network).sum())
        # Target shape assembled from tensors rather than Python ints.
        tshp_unknown = astensor1d(
            (make_tensor([2], network), make_tensor([2], network)),
            x_shape_known)
        tshp_known = astensor1d((2, 2), x_shape_known)
        # Target shape with one dimension given as -1.
        tshp_known_unspec = astensor1d((2, -1), x_shape_known)

        def check_shape(output, target):
            """Assert that ``output.shape`` matches ``target.shape``."""
            source = output.shape
            # The shape may itself be a tensor; convert it before comparing.
            if isinstance(source, tensor):
                source = source.numpy()
            np.testing.assert_equal(source, target.shape)

        def func(x, target_shape):
            """Reshape ``x`` to ``target_shape``."""
            return x.reshape(target_shape)

        # (input tensor, target shape) pairs, in the order they are tested.
        input_pairs = [
            (x_shape_known, tshp_unknown),
            (x_shape_unknown, tshp_unknown),
            (x_shape_known, tshp_known),
            (x_shape_known, tshp_known_unspec),
            (x_shape_unknown, tshp_known),
            (x_shape_unknown, tshp_known_unspec),
        ]
        cases = [
            {
                "input": [x, target_shape],
                "output": [
                    np.zeros((2, 2)),
                ],
            }
            for x, target_shape in input_pairs
        ]
        opr_test(cases,
                 func,
                 compare_fn=check_shape,
                 test_trace=True,
                 network=network)
    finally:
        if is_varnode:
            set_symbolic_shape(saved_symbolic_shape)