def test_infer5(self):
    """Check that GatherND.infer on the ``inputs5`` graph leaves ``output5`` on the output node."""
    test_graph = build_graph(nodes_attributes, edges, inputs5)
    GatherND.infer(Node(test_graph, 'gathernd_node'))

    # Read back the value stored on the 'output' node after inference.
    actual_value = test_graph.node['output']['value']
    values_match = np.array_equal(output5, actual_value)
    self.assertTrue(
        values_match,
        'values do not match expected: {} and given: {}'.format(output5, actual_value),
    )
def test_infer9_opset_5(self):
    """Check GatherND.infer with ``batch_dims=2`` on ``inputs8`` against ``output8`` reshaped to [6, 3]."""
    # NOTE(review): this writes into the shared module-level nodes_attributes
    # dict; presumably a setUp (not visible here) resets it between tests -- confirm.
    gathernd_attrs = nodes_attributes['gathernd_node']
    gathernd_attrs['batch_dims'] = 2
    test_graph = build_graph(nodes_attributes, edges, inputs8)
    GatherND.infer(Node(test_graph, 'gathernd_node'))

    actual_value = test_graph.node['output']['value']
    # Presumably opset5 semantics collapse the batch dimensions into a single
    # leading axis, which is why the reference is reshaped -- TODO confirm.
    expected_value = output8.reshape([6, 3])
    values_match = np.array_equal(expected_value, actual_value)
    self.assertTrue(
        values_match,
        'values do not match expected: {} and given: {}'.format(expected_value, actual_value),
    )
def test_partial_infer_gather_slice(self):
    """Check that GatherND.infer on ``inputs1`` produces an output shape of [3, 30]."""
    test_graph = build_graph(nodes_attributes, edges, inputs1)
    GatherND.infer(Node(test_graph, 'gathernd_node'))

    expected_shape = int64_array([3, 30])
    # Only the shape is checked here; the value is not compared.
    actual_shape = test_graph.node['output']['shape']
    shapes_match = np.array_equal(expected_shape, actual_shape)
    self.assertTrue(
        shapes_match,
        'values do not match expected: {} and given: {}'.format(expected_shape, actual_shape),
    )
def test_partial_infer_gather_slice_batch_dims2_dynamic3(self):
    """Check that GatherND.infer with ``batch_dims=2`` on ``inputs11`` yields shape [dynamic, 3, 5, 9]."""
    # NOTE(review): this writes into the shared module-level nodes_attributes
    # dict; presumably a setUp (not visible here) resets it between tests -- confirm.
    gathernd_attrs = nodes_attributes['gathernd_node']
    gathernd_attrs['batch_dims'] = 2
    test_graph = build_graph(nodes_attributes, edges, inputs11)
    GatherND.infer(Node(test_graph, 'gathernd_node'))

    # The leading dimension of the reference is dynamic, so a strict
    # comparison is used instead of np.array_equal.
    expected_shape = shape_array([dynamic_dimension_value, 3, 5, 9])
    actual_shape = test_graph.node['output']['shape']
    shapes_match = strict_compare_tensors(expected_shape, actual_shape)
    self.assertTrue(
        shapes_match,
        'values do not match expected: {} and given: {}'.format(expected_shape, actual_shape),
    )
def test_partial_infer_gather_slice_batch_dims3_opset8(self):
    """Check that opset8 GatherND.infer with ``batch_dims=3`` on ``inputs3`` yields shape [1, 64, 64, 1]."""
    # NOTE(review): both 'batch_dims' and 'version' are written into the shared
    # module-level nodes_attributes dict; presumably a setUp (not visible here)
    # restores them between tests -- confirm.
    gathernd_attrs = nodes_attributes['gathernd_node']
    gathernd_attrs['batch_dims'] = 3
    gathernd_attrs['version'] = 'opset8'
    test_graph = build_graph(nodes_attributes, edges, inputs3)
    GatherND.infer(Node(test_graph, 'gathernd_node'))

    expected_shape = int64_array([1, 64, 64, 1])
    actual_shape = test_graph.node['output']['shape']
    shapes_match = np.array_equal(expected_shape, actual_shape)
    self.assertTrue(
        shapes_match,
        'values do not match expected: {} and given: {}'.format(expected_shape, actual_shape),
    )