コード例 #1
0
    def test_infer5(self):
        """Run GatherND.infer on the ``inputs5`` graph and compare with ``output5``.

        The value stored on the ``output`` node after inference must be
        element-wise equal to the ``output5`` reference.
        """
        graph = build_graph(nodes_attributes, edges, inputs5)
        GatherND.infer(Node(graph, 'gathernd_node'))

        # read back what inference recorded on the 'output' node
        actual_value = graph.node['output']['value']

        values_match = np.array_equal(output5, actual_value)
        mismatch_message = (
            'values do not match expected: {} and given: {}'.format(
                output5, actual_value))
        self.assertTrue(values_match, mismatch_message)
コード例 #2
0
    def test_infer9_opset_5(self):
        """Run GatherND.infer with ``batch_dims`` set to 2 on the ``inputs8`` graph.

        The inferred value on the ``output`` node is compared against
        ``output8`` reshaped to ``[6, 3]``.
        """
        # the attribute is written on the shared fixture before the graph is built
        nodes_attributes['gathernd_node']['batch_dims'] = 2
        graph = build_graph(nodes_attributes, edges, inputs8)
        GatherND.infer(Node(graph, 'gathernd_node'))

        # read back what inference recorded on the 'output' node
        actual_value = graph.node['output']['value']

        # reference values, reshaped to [6, 3] for the comparison
        expected_value = output8.reshape([6, 3])

        values_match = np.array_equal(expected_value, actual_value)
        mismatch_message = (
            'values do not match expected: {} and given: {}'.format(
                expected_value, actual_value))
        self.assertTrue(values_match, mismatch_message)
コード例 #3
0
    def test_partial_infer_gather_slice(self):
        """Run GatherND.infer on the ``inputs1`` graph and check the output shape.

        The shape stored on the ``output`` node must equal ``[3, 30]``.
        """
        graph = build_graph(nodes_attributes, edges, inputs1)
        GatherND.infer(Node(graph, 'gathernd_node'))

        # reference shape
        expected_shape = int64_array([3, 30])

        # shape recorded on the 'output' node by inference
        actual_shape = graph.node['output']['shape']

        shapes_match = np.array_equal(expected_shape, actual_shape)
        mismatch_message = (
            'values do not match expected: {} and given: {}'.format(
                expected_shape, actual_shape))
        self.assertTrue(shapes_match, mismatch_message)
コード例 #4
0
    def test_partial_infer_gather_slice_batch_dims2_dynamic3(self):
        """Run GatherND.infer with ``batch_dims`` set to 2 on the ``inputs11`` graph.

        The shape stored on the ``output`` node must match
        ``[dynamic_dimension_value, 3, 5, 9]`` under ``strict_compare_tensors``.
        """
        # the attribute is written on the shared fixture before the graph is built
        nodes_attributes['gathernd_node']['batch_dims'] = 2
        graph = build_graph(nodes_attributes, edges, inputs11)
        GatherND.infer(Node(graph, 'gathernd_node'))

        # reference shape; the leading dimension is the dynamic placeholder
        expected_shape = shape_array([dynamic_dimension_value, 3, 5, 9])

        # shape recorded on the 'output' node by inference
        actual_shape = graph.node['output']['shape']

        shapes_match = strict_compare_tensors(expected_shape, actual_shape)
        mismatch_message = (
            'values do not match expected: {} and given: {}'.format(
                expected_shape, actual_shape))
        self.assertTrue(shapes_match, mismatch_message)
コード例 #5
0
    def test_partial_infer_gather_slice_batch_dims3_opset8(self):
        """Run GatherND.infer with ``batch_dims`` 3 and ``version`` 'opset8'.

        The graph is built from ``inputs3``; the shape stored on the
        ``output`` node must equal ``[1, 64, 64, 1]``.
        """
        # both attributes are written on the shared fixture before the graph is built
        nodes_attributes['gathernd_node']['batch_dims'] = 3
        nodes_attributes['gathernd_node']['version'] = 'opset8'
        graph = build_graph(nodes_attributes, edges, inputs3)
        GatherND.infer(Node(graph, 'gathernd_node'))

        # reference shape
        expected_shape = int64_array([1, 64, 64, 1])

        # shape recorded on the 'output' node by inference
        actual_shape = graph.node['output']['shape']

        shapes_match = np.array_equal(expected_shape, actual_shape)
        mismatch_message = (
            'values do not match expected: {} and given: {}'.format(
                expected_shape, actual_shape))
        self.assertTrue(shapes_match, mismatch_message)