def test_infer_constant_input_sum(self):
    """Check SparseSegmentSum.infer output shape and value on the second graph.

    Builds the graph from ``nodes_attributes2``/``edges2``/``inputs2``, runs
    shape/value inference on ``sparse_segment_sum_node``, and compares both
    the inferred ``shape`` and the computed ``value`` of the
    ``output_segments`` node against hand-written references.
    """
    graph = build_graph(nodes_attributes2, edges2, inputs2)
    sparse_segment_sum_node = Node(graph, 'sparse_segment_sum_node')
    SparseSegmentSum.infer(sparse_segment_sum_node)

    # prepare reference results
    ref_output_segments_shape = int64_array([2, 4])
    # `np.float` was only a deprecated alias of the builtin float (float64) and
    # was removed in NumPy 1.24; spell the dtype explicitly.
    ref_output_segments_value = np.array([[0, 0, 0, 0], [5, 6, 7, 8]], dtype=np.float64)

    # get resulted shape and value
    res_output_segments_shape = graph.node['output_segments']['shape']
    res_output_segments_value = graph.node['output_segments']['value']

    self.assertTrue(
        np.array_equal(ref_output_segments_shape, res_output_segments_shape),
        'Shapes do not match expected: {} and given: {}'.format(
            ref_output_segments_shape, res_output_segments_shape))
    # This assertion compares values, so the failure message says so.
    self.assertTrue(
        np.array_equal(ref_output_segments_value, res_output_segments_value),
        'Values do not match expected: {} and given: {}'.format(
            ref_output_segments_value, res_output_segments_value))
def test_partial_infer(self):
    """Verify the shape SparseSegmentSum.infer assigns to ``output_segments``.

    Uses the graph described by ``nodes_attributes1``/``edges1``/``inputs1``
    and checks only the inferred output shape (no value is compared).
    """
    graph = build_graph(nodes_attributes1, edges1, inputs1)
    SparseSegmentSum.infer(Node(graph, 'sparse_segment_sum_node'))

    expected_shape = int64_array([40, 4, 5])
    actual_shape = graph.node['output_segments']['shape']

    mismatch_msg = 'Shapes do not match expected: {} and given: {}'.format(expected_shape, actual_shape)
    self.assertTrue(np.array_equal(expected_shape, actual_shape), mismatch_msg)