def test_non_max_suppression_default_max_output_boxes():
    """Check NonMaxSuppression built from only the boxes and scores inputs.

    This test was originally also named ``test_non_max_suppression``. A second
    definition with that name rebound the module attribute at import time, so
    pytest never collected or ran this version. It is renamed so that both the
    default and the explicit max-output variants are exercised.
    """
    boxes_shape = [1, 1000, 4]
    scores_shape = [1, 1, 1000]
    boxes_parameter = ng.parameter(boxes_shape, name="Boxes", dtype=np.float32)
    scores_parameter = ng.parameter(scores_shape, name="Scores", dtype=np.float32)

    # Only the two mandatory inputs are passed; every remaining input uses
    # whatever default ng.non_max_suppression applies.
    node = ng.non_max_suppression(boxes_parameter, scores_parameter)

    assert node.get_type_name() == "NonMaxSuppression"
    assert node.get_output_size() == 3
    # NOTE(review): these expectations were never executed while the test was
    # shadowed; confirm that a dynamic leading dimension is still what shape
    # inference produces when max_output_boxes_per_class is left at its default.
    assert node.get_output_partial_shape(0).same_scheme(PartialShape([-1, 3]))
    assert node.get_output_partial_shape(1).same_scheme(PartialShape([-1, 3]))
    assert list(node.get_output_shape(2)) == [1]
def test_non_max_suppression():
    """Check NonMaxSuppression when a constant 1000 is passed as the third input.

    The third input is presumably max_output_boxes_per_class; this should be
    confirmed against the ng.non_max_suppression signature. The first two
    outputs are expected to have a leading dimension bounded to [0, 1000] and a
    trailing dimension of 3. The third output's shape is expected to be [1].
    """
    boxes = ng.parameter([1, 1000, 4], name="Boxes", dtype=np.float32)
    scores = ng.parameter([1, 1, 1000], name="Scores", dtype=np.float32)
    max_output_boxes = make_constant_node(1000, np.int64)

    node = ng.non_max_suppression(boxes, scores, max_output_boxes)

    assert node.get_type_name() == "NonMaxSuppression"
    assert node.get_output_size() == 3

    # Outputs 0 and 1 share the same interval-bounded expected shape.
    expected_shape = PartialShape([Dimension(0, 1000), Dimension(3)])
    for output_index in (0, 1):
        assert node.get_output_partial_shape(output_index) == expected_shape

    assert list(node.get_output_shape(2)) == [1]
def test_embedding_segments_sum_with_some_opt_inputs():
    """Check EmbeddingSegmentsSum when num_segments is given but no other optional input."""
    inputs = (
        ng.parameter([5, 2], name="emb_table", dtype=np.float32),
        ng.parameter([4], name="indices", dtype=np.int64),
        ng.parameter([4], name="segment_ids", dtype=np.int64),
        # Only 1 of the 3 optional arguments is supplied (num_segments);
        # default_index and per_sample_weights are omitted.
        ng.parameter([], name="num_segments", dtype=np.int64),
    )

    node = ng.embedding_segments_sum(*inputs)

    assert node.get_type_name() == "EmbeddingSegmentsSum"
    assert node.get_output_size() == 1
    # A dynamic leading dimension is expected, presumably because num_segments
    # is a runtime Parameter rather than a constant. The trailing dimension
    # matches the embedding width.
    assert node.get_output_partial_shape(0).same_scheme(PartialShape([-1, 2]))
    assert node.get_output_element_type(0) == Type.f32
def test_embedding_segments_sum_all_inputs():
    """Check EmbeddingSegmentsSum with every optional input supplied.

    The optional inputs are num_segments, default_index and per_sample_weights.
    """
    # The Parameters are built inline, in the same left-to-right order as the
    # positional arguments they feed.
    node = ng.embedding_segments_sum(
        ng.parameter([5, 2], name="emb_table", dtype=np.float32),
        ng.parameter([4], name="indices", dtype=np.int64),
        ng.parameter([4], name="segment_ids", dtype=np.int64),
        ng.parameter([], name="num_segments", dtype=np.int64),
        ng.parameter([], name="default_index", dtype=np.int64),
        ng.parameter([4], name="per_sample_weights", dtype=np.float32),
    )

    assert node.get_type_name() == "EmbeddingSegmentsSum"
    assert node.get_output_size() == 1
    assert node.get_output_partial_shape(0).same_scheme(PartialShape([-1, 2]))
    assert node.get_output_element_type(0) == Type.f32