示例#1
0
    def test_infer5(self):
        """Run GatherND.infer on a graph built from ``inputs5`` and compare the 'output' value with ``output5``."""
        graph = build_graph(nodes_attributes, edges, inputs5)
        gathernd_node = Node(graph, 'gathernd_node')
        GatherND.infer(gathernd_node)

        # get the result
        res_output_value = graph.node['output']['value']

        # The failure message must print the reference that is actually compared
        # (output5); it previously formatted output4, which made failures misleading.
        self.assertTrue(np.array_equal(output5, res_output_value),
                        'values do not match expected: {} and given: {}'.format(output5, res_output_value))
示例#2
0
    def test_partial_infer_gather_slice(self):
        """Run GatherND.infer on the ``inputs2`` graph and expect the 'output' shape to equal [3, 30]."""
        net = build_graph(nodes_attributes, edges, inputs2)
        GatherND.infer(Node(net, 'gathernd_node'))

        # reference shape versus the shape stored on the 'output' node after inference
        expected_shape = int64_array([3, 30])
        actual_shape = net.node['output']['shape']

        shapes_match = np.array_equal(expected_shape, actual_shape)
        failure_msg = 'values do not match expected: {} and given: {}'.format(expected_shape, actual_shape)
        self.assertTrue(shapes_match, failure_msg)
示例#3
0
    def test_infer9_opset_5(self):
        """Run GatherND.infer with ``batch_dims`` = 2 on the ``inputs8`` graph.

        The inferred 'output' value is compared with ``output8`` reshaped to [6, 3].
        """
        # NOTE(review): nodes_attributes is defined outside this method, so this
        # assignment presumably stays visible to other tests - confirm that is intended.
        nodes_attributes['gathernd_node']['batch_dims'] = 2
        net = build_graph(nodes_attributes, edges, inputs8)
        GatherND.infer(Node(net, 'gathernd_node'))

        # value stored on the 'output' node after inference
        actual_value = net.node['output']['value']

        expected_value = output8.reshape([6, 3])
        values_match = np.array_equal(expected_value, actual_value)
        failure_msg = 'values do not match expected: {} and given: {}'.format(expected_value, actual_value)
        self.assertTrue(values_match, failure_msg)
示例#4
0
    def test_partial_infer_gather_slice_batch_dims2_dynamic3(self):
        """Run GatherND.infer with ``batch_dims`` = 2 on the ``inputs11`` graph.

        The inferred 'output' shape must strictly match
        [dynamic_dimension_value, 3, 5, 9].
        """
        # NOTE(review): nodes_attributes is defined outside this method, so this
        # assignment presumably stays visible to other tests - confirm that is intended.
        nodes_attributes['gathernd_node']['batch_dims'] = 2
        net = build_graph(nodes_attributes, edges, inputs11)
        GatherND.infer(Node(net, 'gathernd_node'))

        # reference shape versus the shape stored on the 'output' node after inference
        expected_shape = shape_array([dynamic_dimension_value, 3, 5, 9])
        actual_shape = net.node['output']['shape']

        shapes_match = strict_compare_tensors(expected_shape, actual_shape)
        failure_msg = 'values do not match expected: {} and given: {}'.format(expected_shape, actual_shape)
        self.assertTrue(shapes_match, failure_msg)
示例#5
0
    def test_partial_infer_gather_slice_batch_dims3_opset8(self):
        """Run GatherND.infer with ``batch_dims`` = 3 and ``version`` = 'opset8' on the ``inputs3`` graph.

        The inferred 'output' shape is expected to equal [1, 64, 64, 1].
        """
        # NOTE(review): nodes_attributes is defined outside this method, so these
        # assignments presumably stay visible to other tests - confirm that is intended.
        gathernd_attrs = nodes_attributes['gathernd_node']
        gathernd_attrs['batch_dims'] = 3
        gathernd_attrs['version'] = 'opset8'

        net = build_graph(nodes_attributes, edges, inputs3)
        GatherND.infer(Node(net, 'gathernd_node'))

        # reference shape versus the shape stored on the 'output' node after inference
        expected_shape = int64_array([1, 64, 64, 1])
        actual_shape = net.node['output']['shape']

        shapes_match = np.array_equal(expected_shape, actual_shape)
        failure_msg = 'values do not match expected: {} and given: {}'.format(expected_shape, actual_shape)
        self.assertTrue(shapes_match, failure_msg)
示例#6
0
 def extract(cls, node):
     """Read ``batch_dims`` from ``node`` via ``onnx_attr`` (default 0) and store it with GatherND.update_node_stat.

     Returns ``cls.enabled``.
     """
     batch_dims = onnx_attr(node, 'batch_dims', 'i', default=0)
     GatherND.update_node_stat(node, {'batch_dims': batch_dims})
     return cls.enabled
示例#7
0
 def extract(cls, node):
     """Store a fixed ``batch_dims`` of 0 on ``node`` through GatherND.update_node_stat.

     Returns ``cls.enabled``.
     """
     GatherND.update_node_stat(node, dict(batch_dims=0))
     return cls.enabled