Beispiel #1
0
def test_topk():
    """Check the type name, output count and output shapes of a TopK node."""
    input_node = ng.parameter([6, 12, 10, 24], name="Data", dtype=np.float32)
    topk_node = ng.topk(input_node, np.int32(3), np.int32(1), "max", "value")

    assert topk_node.get_type_name() == "TopK"
    assert topk_node.get_output_size() == 2

    # Axis 1 (size 12 in the input) is expected to shrink to K == 3 in
    # each of the two outputs; every other dimension is unchanged.
    expected_shape = [6, 3, 10, 24]
    for output_index in range(2):
        assert list(topk_node.get_output_shape(output_index)) == expected_shape
def test_topk():
    """Build a TopK node over a 4-D float32 parameter and verify its metadata."""
    k_value, reduce_axis = np.int32(3), np.int32(1)
    source = ng.parameter([6, 12, 10, 24], name='Data', dtype=np.float32)
    topk = ng.topk(source, k_value, reduce_axis, 'max', 'value')

    assert topk.get_type_name() == 'TopK'
    assert topk.get_output_size() == 2

    # Outputs 0 and 1 must share one shape: the input shape with the
    # selected axis (1) reduced from 12 to k_value.
    reduced_shape = [6, 3, 10, 24]
    assert list(topk.get_output_shape(0)) == reduced_shape
    assert list(topk.get_output_shape(1)) == reduced_shape
Beispiel #3
0
def test_discrete_type_info():
    """Check the type info exposed by nodes via ``type_info`` and ``get_type_info()``.

    Two identically built TopK nodes must report equal name, version and
    parent; a Sin node must report a different name, and the names must
    order as plain strings ("TopK" sorts after "Sin").
    """
    param = ng.parameter([6, 12, 10, 24], name="Data", dtype=np.float32)
    top_count = np.int32(3)
    along_axis = np.int32(1)
    topk_a = ng.topk(param, top_count, along_axis, "max", "value")
    topk_b = ng.topk(param, top_count, along_axis, "max", "value")
    sin_node = ng.sin(0.2)

    # The attribute and the getter must both report the operation name.
    assert topk_a.type_info.name == "TopK"
    assert sin_node.type_info.name == "Sin"
    assert topk_a.get_type_info().name == "TopK"
    assert sin_node.get_type_info().name == "Sin"

    # Nodes of the same operation agree on every compared field,
    # first through the attribute, then through the getter.
    compared_fields = ("name", "version", "parent")
    for field in compared_fields:
        assert getattr(topk_a.type_info, field) == getattr(topk_b.type_info, field)
    for field in compared_fields:
        assert getattr(topk_a.get_type_info(), field) == getattr(topk_b.get_type_info(), field)

    # Names of different operations differ and support ordering comparisons.
    assert topk_a.get_type_info().name != sin_node.get_type_info().name
    assert topk_a.get_type_info().name > sin_node.get_type_info().name
    assert topk_a.get_type_info().name >= sin_node.get_type_info().name
    assert sin_node.get_type_info().name < topk_a.get_type_info().name
    assert sin_node.get_type_info().name <= topk_a.get_type_info().name
Beispiel #4
0
def test_topk():
    """Evaluate TopK with k=4 along axis 0 of a 4x3 constant and check both outputs.

    The expected data corresponds to each column sorted in descending
    order: output 0 is compared against the source row indices, output 1
    against the sorted values.
    """
    source_values = np.array(
        [[9, 2, 10], [12, 8, 4], [6, 1, 5], [3, 11, 7]], dtype=np.float32
    )
    expected_indices = np.array(
        [[1, 3, 0], [0, 1, 3], [2, 0, 2], [3, 2, 1]], dtype=np.int32
    )
    expected_values = np.array(
        [[12, 11, 10], [9, 8, 7], [6, 2, 5], [3, 1, 4]], dtype=np.float32
    )

    runtime = get_runtime()
    topk_node = ng.topk(ng.constant(source_values), 4, 0)

    # Each output of the two-output node is computed separately.
    indices_result = runtime.computation(ng.get_output_element(topk_node, 0))()
    assert np.allclose(indices_result, expected_indices)

    values_result = runtime.computation(ng.get_output_element(topk_node, 1))()
    assert np.allclose(values_result, expected_values)