Exemplo n.º 1
0
def test_node_output():
    """Verify the per-output accessors of a Split node.

    A 6-element constant is split along axis 0 into three equal parts. The
    test checks that the node exposes three outputs whose indices, element
    types, static shapes and partial shapes match expectations, and that
    ``output(i)`` reports index ``i``.
    """
    source = np.array([0, 1, 2, 3, 4, 5])
    num_splits = 3
    # Length of each equal chunk along the split axis.
    chunk_len = len(source) // num_splits

    data_node = ng.constant(source, dtype=np.int32)
    split_axis = ng.constant(0, dtype=np.int64)
    node = ng.split(data_node, split_axis, num_splits)

    outputs = node.outputs()
    assert len(outputs) == num_splits

    indices = [out.get_index() for out in outputs]
    assert indices == [0, 1, 2]

    # Each list is compared against a single expected value; np.equal
    # broadcasts that expectation across every output before .all().
    element_types = [out.get_element_type() for out in outputs]
    assert np.equal(element_types, data_node.get_element_type()).all()

    shapes = [out.get_shape() for out in outputs]
    assert np.equal(shapes, Shape([chunk_len])).all()

    partial_shapes = [out.get_partial_shape() for out in outputs]
    assert np.equal(partial_shapes, PartialShape([chunk_len])).all()

    # Indexed access must report the same positional indices.
    by_position = [node.output(i) for i in range(num_splits)]
    assert [out.get_index() for out in by_position] == [0, 1, 2]
Exemplo n.º 2
0
def test_split():
    """Run a Split node through the test runtime and check its outputs.

    The int32 vector ``[0, 1, 2, 3, 4, 5]`` is split into three equal parts
    along axis 0, which must yield ``[0, 1]``, ``[2, 3]`` and ``[4, 5]``.
    """
    runtime = get_runtime()
    input_tensor = ng.constant(np.array([0, 1, 2, 3, 4, 5], dtype=np.int32))
    axis = ng.constant(0, dtype=np.int64)
    splits = 3

    split_node = ng.split(input_tensor, axis, splits)
    computation = runtime.computation(split_node)
    split_results = computation()
    expected_results = np.array([[0, 1], [2, 3], [4, 5]], dtype=np.int32)
    # Integer results must match exactly. np.array_equal also fails on a
    # shape mismatch, which np.allclose would accept via broadcasting.
    assert np.array_equal(split_results, expected_results)
Exemplo n.º 3
0
def test_mutiple_outputs():
    """Feed only the first output of a multi-output Split node into ReLU.

    A 4x4 float32 parameter is split into two 4x2 halves along axis 1. The
    result of ``relu(split.output(0))`` must equal the first half of the
    input with negative values clamped to zero.
    """
    input_shape = [4, 4]
    # Build the input as float32 to match the parameter's element type,
    # rather than passing numpy's default integer dtype.
    input_data = np.arange(-8, 8, dtype=np.float32).reshape(input_shape)

    # np.split returns views into input_data. Clamping that view in place
    # would also zero the negatives in input_data, so ReLU would never see
    # a negative value. np.maximum allocates a fresh array and leaves
    # input_data intact.
    expected_output = np.maximum(np.split(input_data, 2, axis=1)[0], 0)

    test_param = ng.parameter(input_shape, dtype=np.float32, name="A")
    split = ng.split(test_param, axis=1, num_splits=2)
    split_first_output = split.output(0)
    relu = ng.relu(split_first_output)

    runtime = get_runtime()
    computation = runtime.computation(relu, test_param)
    output = computation(input_data)

    # np.equal (not array_equal) tolerates the runtime wrapping the single
    # result in a list, since it broadcasts against expected_output.
    assert np.equal(output, expected_output).all()