Exemplo n.º 1
0
def test_non_max_suppression():
    """Build NonMaxSuppression from boxes and scores only and check its outputs.

    Verifies the node type name, that it exposes three outputs, that outputs
    0 and 1 match the ``[-1, 3]`` partial-shape scheme, and that output 2 has
    static shape ``[1]``.
    """
    boxes = ng.parameter([1, 1000, 4], name="Boxes", dtype=np.float32)
    scores = ng.parameter([1, 1, 1000], name="Scores", dtype=np.float32)

    nms = ng.non_max_suppression(boxes, scores)

    assert nms.get_type_name() == "NonMaxSuppression"
    assert nms.get_output_size() == 3
    # Outputs 0 and 1 are checked against the same dynamic-rows scheme.
    for port in (0, 1):
        assert nms.get_output_partial_shape(port).same_scheme(PartialShape([-1, 3]))
    assert list(nms.get_output_shape(2)) == [1]
Exemplo n.º 2
0
def test_non_max_suppression():
    """Build NonMaxSuppression with a constant third input and check its outputs.

    Verifies the node type name, that it exposes three outputs, that outputs
    0 and 1 have a first dimension bounded to ``[0, 1000]`` and a second
    dimension of 3, and that output 2 has static shape ``[1]``.
    """
    boxes = ng.parameter([1, 1000, 4], name="Boxes", dtype=np.float32)
    scores = ng.parameter([1, 1, 1000], name="Scores", dtype=np.float32)
    # NOTE(review): the third positional input is presumably the per-class
    # output box limit; confirm against the ng.non_max_suppression signature.
    box_limit = make_constant_node(1000, np.int64)

    nms = ng.non_max_suppression(boxes, scores, box_limit)

    assert nms.get_type_name() == "NonMaxSuppression"
    assert nms.get_output_size() == 3
    # Outputs 0 and 1 are compared against the same bounded partial shape.
    for port in (0, 1):
        actual = nms.get_output_partial_shape(port)
        assert actual == PartialShape([Dimension(0, 1000), Dimension(3)])
    assert list(nms.get_output_shape(2)) == [1]
Exemplo n.º 3
0
def test_embedding_segments_sum_with_some_opt_inputs():
    """Build EmbeddingSegmentsSum with four inputs and check its single output.

    Verifies the node type name, that it has one output, that the output
    matches the ``[-1, 2]`` partial-shape scheme, and that its element type
    is f32.
    """
    table = ng.parameter([5, 2], name="emb_table", dtype=np.float32)
    idx = ng.parameter([4], name="indices", dtype=np.int64)
    seg_ids = ng.parameter([4], name="segment_ids", dtype=np.int64)
    n_segments = ng.parameter([], name="num_segments", dtype=np.int64)

    # Only one of the three optional inputs is supplied here.
    ess = ng.embedding_segments_sum(table, idx, seg_ids, n_segments)

    assert ess.get_type_name() == "EmbeddingSegmentsSum"
    assert ess.get_output_size() == 1
    out_shape = ess.get_output_partial_shape(0)
    assert out_shape.same_scheme(PartialShape([-1, 2]))
    assert ess.get_output_element_type(0) == Type.f32
Exemplo n.º 4
0
def test_embedding_segments_sum_all_inputs():
    """Build EmbeddingSegmentsSum with all six inputs and check its single output.

    Verifies the node type name, that it has one output, that the output
    matches the ``[-1, 2]`` partial-shape scheme, and that its element type
    is f32.
    """
    # Parameters are created in the positional order the op is called with.
    op_inputs = [
        ng.parameter([5, 2], name="emb_table", dtype=np.float32),
        ng.parameter([4], name="indices", dtype=np.int64),
        ng.parameter([4], name="segment_ids", dtype=np.int64),
        ng.parameter([], name="num_segments", dtype=np.int64),
        ng.parameter([], name="default_index", dtype=np.int64),
        ng.parameter([4], name="per_sample_weights", dtype=np.float32),
    ]

    ess = ng.embedding_segments_sum(*op_inputs)

    assert ess.get_type_name() == "EmbeddingSegmentsSum"
    assert ess.get_output_size() == 1
    out_shape = ess.get_output_partial_shape(0)
    assert out_shape.same_scheme(PartialShape([-1, 2]))
    assert ess.get_output_element_type(0) == Type.f32