示例#1
0
    def test_serialize_old_api_map_result(self):
        """Check serialization of ``old_api_map_order`` rt_info on a Result node.

        Builds a ``placeholder -> result`` graph, stores an ``OldAPIMapOrder``
        under the key ``('old_api_map_order', 0)`` of the Result node's rt_info,
        sets its result transpose order to ``[0, 3, 1, 2]`` and verifies the XML
        produced by ``serialize_runtime_info``.
        """
        graph = build_graph(
            {
                **regular_op('placeholder', {
                    'type': 'Parameter',
                    'rt_info': RTInfo()
                }),
                **regular_op('result', {
                    'type': 'Result',
                    'rt_info': RTInfo()
                })
            }, [('placeholder', 'result')], {},
            nodes_with_edges_only=True)
        result_node = Node(graph, 'result')
        order_key = ('old_api_map_order', 0)
        result_node.rt_info.info[order_key] = OldAPIMapOrder()
        result_node.rt_info.info[order_key].old_api_transpose_result([0, 3, 1, 2])

        net = Element('net')
        serialize_runtime_info(result_node, net)
        # Decode the serialized bytes instead of matching against their
        # ``b'...'`` repr: str(bytes) triggers BytesWarning under ``python -b``
        # and ties the assertions to repr quoting rules.
        serialize_res = tostring(net).decode()
        # assertIn reports the searched text on failure, unlike assertTrue(x in y).
        self.assertIn('name="old_api_map_order"', serialize_res)
        self.assertIn('version="0"', serialize_res)
        self.assertIn('value="0,3,1,2"', serialize_res)
        self.assertTrue(serialize_res.startswith('<net><rt_info>'), serialize_res)
        self.assertTrue(serialize_res.endswith('</rt_info></net>'), serialize_res)
示例#2
0
    def test_serialize_old_api_map_parameter(self):
        """Check serialization of old-API rt_info attributes on a Parameter node.

        First pass: the node carries both ``old_api_map_order`` (parameter
        transpose order ``[0, 2, 3, 1]``) and ``old_api_map_element_type``
        (legacy type ``np.float32``); both must appear in the serialized XML.
        Second pass: the order attribute is removed and the element type is
        replaced with ``np.float16``; the XML must then report ``f16``.
        """
        graph = build_graph({**regular_op('placeholder', {'type': 'Parameter', 'rt_info': RTInfo()}),
                             **result('result')},
                            [('placeholder', 'result')], {}, nodes_with_edges_only=True)
        param_node = Node(graph, 'placeholder')
        order_key = ('old_api_map_order', 0)
        type_key = ('old_api_map_element_type', 0)
        param_node.rt_info.info[order_key] = OldAPIMapOrder()
        param_node.rt_info.info[order_key].old_api_transpose_parameter([0, 2, 3, 1])
        param_node.rt_info.info[type_key] = OldAPIMapElementType()
        param_node.rt_info.info[type_key].set_legacy_type(np.float32)

        net = Element('net')
        serialize_runtime_info(param_node, net)
        # Decode the serialized bytes instead of matching against their
        # ``b'...'`` repr: str(bytes) triggers BytesWarning under ``python -b``
        # and ties the assertions to repr quoting rules.
        serialize_res = tostring(net).decode()
        # assertIn reports the searched text on failure, unlike assertTrue(x in y).
        self.assertIn('name="old_api_map_order"', serialize_res)
        self.assertIn('name="old_api_map_element_type"', serialize_res)
        self.assertIn('version="0"', serialize_res)
        self.assertIn('value="0,2,3,1"', serialize_res)
        self.assertIn('value="f32"', serialize_res)
        self.assertTrue(serialize_res.startswith('<net><rt_info>'), serialize_res)
        self.assertTrue(serialize_res.endswith('</rt_info></net>'), serialize_res)

        # Second pass: only the element type attribute remains, now set to float16.
        del param_node.rt_info.info[order_key]
        param_node.rt_info.info[type_key] = OldAPIMapElementType()
        param_node.rt_info.info[type_key].set_legacy_type(np.float16)

        net = Element('net')
        serialize_runtime_info(param_node, net)
        serialize_res = tostring(net).decode()
        self.assertIn('name="old_api_map_element_type"', serialize_res)
        self.assertIn('version="0"', serialize_res)
        self.assertIn('value="f16"', serialize_res)
        self.assertTrue(serialize_res.startswith('<net><rt_info>'), serialize_res)
        self.assertTrue(serialize_res.endswith('</rt_info></net>'), serialize_res)
 def test_get_fw_index(self):
     """Check ``InsertReverseChannels.get_fw_index`` against an old API map.

     The node ``placeholder1`` gets an ``OldAPIMapOrder`` (version 0) whose
     parameter transpose order is ``[0, 2, 3, 1]``; the test then verifies
     the value returned for several indices, including a negative one, and
     that the returned object is exactly an ``int``.
     """
     graph = build_graph(nodes, [*connect('placeholder1', 'result')])
     node = Node(graph, 'placeholder1')
     old_api_map = OldAPIMapOrder(version=0)
     order_key = ('old_api_map_order', old_api_map.get_version())
     node.rt_info.info[order_key] = old_api_map
     node.rt_info.info[order_key].old_api_transpose_parameter([0, 2, 3, 1])

     # Queried index -> expected get_fw_index result.
     expected_fw_index = {0: 0, 1: 3, 2: 1, 3: 2, -2: 1}
     for index, expected in expected_fw_index.items():
         # subTest + assertEqual report every mismatching index together with
         # the actual value, instead of a bare "False is not true".
         with self.subTest(index=index):
             self.assertEqual(
                 InsertReverseChannels.get_fw_index(node, index), expected)

     # Exact type identity (same strictness as ``type(...) == int``):
     # subclasses of int are rejected.
     self.assertIs(type(InsertReverseChannels.get_fw_index(node, 0)), int)
    def test_insert_old_api_map(self):
        """Run ``InsertReverseChannels`` on a graph whose input has an old API map.

        ``placeholder1`` carries an ``OldAPIMapOrder`` with parameter transpose
        order ``[0, 2, 3, 1]``. After the transformation the graph is compared
        with a reference in which ``reverse_channels`` sits between
        ``placeholder1`` and the first input of ``mul``.
        """
        source_nodes = get_nodes([1, 10, 10, 3])
        source_edges = [
            *connect('placeholder1', '0:mul'),
            *connect('placeholder2', '1:mul'),
            *connect('mul', 'result'),
        ]
        graph = build_graph(source_nodes, source_edges, nodes_with_edges_only=True,
                            cli=Namespace(reverse_input_channels=True))

        node = Node(graph, 'placeholder1')
        old_api_map = OldAPIMapOrder(version=0)
        order_key = ('old_api_map_order', old_api_map.get_version())
        rt_attributes = node.rt_info.info
        rt_attributes[order_key] = old_api_map
        rt_attributes[order_key].old_api_transpose_parameter([0, 2, 3, 1])

        InsertReverseChannels().find_and_replace_pattern(graph)

        reference_nodes = get_nodes([1, 10, 10, 3], 3)
        reference_edges = [
            *connect('placeholder1', 'reverse_channels'),
            *connect('reverse_channels', '0:mul'),
            *connect('placeholder2', '1:mul'),
            *connect('mul', 'result'),
        ]
        graph_ref = build_graph(reference_nodes, reference_edges)

        # Give the reference input the same rt_info object so that attribute
        # comparison does not differ on it.
        node2 = Node(graph_ref, 'placeholder1')
        node2.rt_info = node.rt_info

        is_equal, message = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
        self.assertTrue(is_equal, message)
    def add_old_api_map_order_into_rt_info(op: Node):
        """Make sure ``op`` has an ``OldAPIMapOrder`` (version 0) rt_info entry.

        A fresh ``OldAPIMapOrder`` is stored in ``op.rt_info.info`` under the
        key ``(name, version)`` unless an entry with that key already exists;
        an existing entry is left untouched.

        :param op: node that must already have the ``rt_info`` attribute
                   (checked with an assertion).
        :return: the ``(attribute name, version)`` tuple identifying the entry.
        """
        # rt info update
        assert op.has('rt_info'), \
            'Unable to preserve runtime information for node with name={}'.format(op)

        candidate = OldAPIMapOrder(version=0)
        rt_key = (candidate.get_name(), candidate.get_version())
        rt_attributes = op.rt_info.info
        if rt_key not in rt_attributes:
            rt_attributes[rt_key] = candidate
        return rt_key
示例#6
0
    def __read_old_api_map_order(attr, layer_type):
        """Build an ``OldAPIMapOrder`` rt_info entry from a serialized attribute.

        :param attr: object exposing an ``attrib`` mapping with a ``'version'``
                     entry (parsed as int) and a ``'value'`` entry holding
                     comma-separated integers.
        :param layer_type: ``'Parameter'`` or ``'Result'``; selects which
                           transpose order of the map receives the parsed list.
        :return: single-item dict ``{('old_api_map_order', version): map}``.
        :raises AttributeError: if ``layer_type`` is neither ``'Parameter'``
                                nor ``'Result'``.
        """
        xml_attributes = attr.attrib
        version = int(xml_attributes['version'])
        order = [int(token) for token in xml_attributes['value'].split(',')]
        old_api_map = OldAPIMapOrder(version=version)

        # Select the setter matching the layer kind, then apply the order once.
        if layer_type == 'Parameter':
            apply_order = old_api_map.old_api_transpose_parameter
        elif layer_type == 'Result':
            apply_order = old_api_map.old_api_transpose_result
        else:
            raise AttributeError(
                "Cannot read old_api_map for layer of type: {}".format(layer_type))
        apply_order(order)

        return {('old_api_map_order', version): old_api_map}