def test_serialize_old_api_map_result(self):
    """Serialize a Result node's ``old_api_map_order`` rt_info and check the XML.

    Builds a Parameter -> Result graph and attaches an ``OldAPIMapOrder`` to the
    Result node with result transpose order ``[0, 3, 1, 2]``. It then serializes
    the node's runtime info under a fresh ``<net>`` element. The emitted XML must
    contain the attribute name, its version and the comma-joined order, wrapped
    in ``<net><rt_info>...</rt_info></net>``.
    """
    graph = build_graph({**regular_op('placeholder', {'type': 'Parameter', 'rt_info': RTInfo()}),
                         **regular_op('result', {'type': 'Result', 'rt_info': RTInfo()})},
                        [('placeholder', 'result')], {},
                        nodes_with_edges_only=True)
    result_node = Node(graph, 'result')
    result_node.rt_info.info[('old_api_map_order', 0)] = OldAPIMapOrder()
    result_node.rt_info.info[('old_api_map_order', 0)].old_api_transpose_result([0, 3, 1, 2])

    net = Element('net')
    serialize_runtime_info(result_node, net)
    # tostring() returns bytes here, so str() keeps the b'...' wrapper; that is
    # why the boundary checks below expect a leading "b'" and a trailing "'".
    serialize_res = str(tostring(net))

    # assertIn (rather than assertTrue(x in y)) reports the XML on failure.
    self.assertIn("name=\"old_api_map_order\"", serialize_res)
    self.assertIn("version=\"0\"", serialize_res)
    self.assertIn("value=\"0,3,1,2\"", serialize_res)
    self.assertTrue(serialize_res.startswith("b'<net><rt_info>"))
    self.assertTrue(serialize_res.endswith("</rt_info></net>'"))
def test_serialize_old_api_map_parameter(self):
    """Serialize a Parameter node's old-API rt_info entries and check the XML.

    First serializes a Parameter carrying both ``old_api_map_order`` (order
    ``[0, 2, 3, 1]``) and ``old_api_map_element_type`` (``np.float32``). It then
    removes the order entry, replaces the element type with ``np.float16`` and
    serializes again. The second output must contain only the element-type
    entry.
    """
    graph = build_graph({**regular_op('placeholder', {'type': 'Parameter', 'rt_info': RTInfo()}),
                         **result('result')},
                        [('placeholder', 'result')], {},
                        nodes_with_edges_only=True)
    param_node = Node(graph, 'placeholder')
    param_node.rt_info.info[('old_api_map_order', 0)] = OldAPIMapOrder()
    param_node.rt_info.info[('old_api_map_order', 0)].old_api_transpose_parameter([0, 2, 3, 1])
    param_node.rt_info.info[('old_api_map_element_type', 0)] = OldAPIMapElementType()
    param_node.rt_info.info[('old_api_map_element_type', 0)].set_legacy_type(np.float32)

    def serialize(node):
        # Serialize the node's rt_info under a fresh <net> element. str() of the
        # bytes returned by tostring() keeps the b'...' wrapper.
        net = Element('net')
        serialize_runtime_info(node, net)
        return str(tostring(net))

    serialize_res = serialize(param_node)
    self.assertIn("name=\"old_api_map_order\"", serialize_res)
    self.assertIn("name=\"old_api_map_element_type\"", serialize_res)
    self.assertIn("version=\"0\"", serialize_res)
    self.assertIn("value=\"0,2,3,1\"", serialize_res)
    self.assertIn("value=\"f32\"", serialize_res)
    self.assertTrue(serialize_res.startswith("b'<net><rt_info>"))
    self.assertTrue(serialize_res.endswith("</rt_info></net>'"))

    del param_node.rt_info.info[('old_api_map_order', 0)]
    param_node.rt_info.info[('old_api_map_element_type', 0)] = OldAPIMapElementType()
    param_node.rt_info.info[('old_api_map_element_type', 0)].set_legacy_type(np.float16)

    serialize_res = serialize(param_node)
    self.assertIn("name=\"old_api_map_element_type\"", serialize_res)
    # The order entry was deleted above, so it must no longer be emitted.
    self.assertNotIn("name=\"old_api_map_order\"", serialize_res)
    self.assertIn("version=\"0\"", serialize_res)
    self.assertIn("value=\"f16\"", serialize_res)
    self.assertTrue(serialize_res.startswith("b'<net><rt_info>"))
    self.assertTrue(serialize_res.endswith("</rt_info></net>'"))
def test_get_fw_index(self):
    """Check ``InsertReverseChannels.get_fw_index`` through an old_api_map_order.

    The Parameter gets an ``OldAPIMapOrder`` with parameter transpose
    ``[0, 2, 3, 1]``. Each queried index, including a negative one, must map to
    the expected value, and the result must be exactly of type ``int``.
    """
    graph = build_graph(nodes, [*connect('placeholder1', 'result')])
    node = Node(graph, 'placeholder1')
    old_api_map = OldAPIMapOrder(version=0)
    node.rt_info.info[('old_api_map_order', old_api_map.get_version())] = old_api_map
    node.rt_info.info[('old_api_map_order', old_api_map.get_version())].old_api_transpose_parameter(
        [0, 2, 3, 1])

    # (queried index, expected get_fw_index result); -2 exercises negative indexing.
    cases = ((0, 0), (1, 3), (2, 1), (3, 2), (-2, 1))
    for index, expected in cases:
        with self.subTest(index=index):
            self.assertEqual(InsertReverseChannels.get_fw_index(node, index), expected)

    # Exact type check (not isinstance): e.g. a numpy integer must not pass.
    self.assertIs(type(InsertReverseChannels.get_fw_index(node, 0)), int)
def test_insert_old_api_map(self):
    """InsertReverseChannels must keep a Parameter's old_api_map_order intact.

    The Parameter gets an ``OldAPIMapOrder`` with transpose ``[0, 2, 3, 1]``
    before the transformation runs. The resulting graph is compared, with op
    attributes, against a reference graph that has a ``reverse_channels`` node
    after the Parameter.
    """
    graph = build_graph(
        get_nodes([1, 10, 10, 3]),
        [
            *connect('placeholder1', '0:mul'),
            *connect('placeholder2', '1:mul'),
            *connect('mul', 'result'),
        ],
        nodes_with_edges_only=True,
        cli=Namespace(reverse_input_channels=True),
    )

    param = Node(graph, 'placeholder1')
    order_map = OldAPIMapOrder(version=0)
    rt_key = ('old_api_map_order', order_map.get_version())
    param.rt_info.info[rt_key] = order_map
    param.rt_info.info[rt_key].old_api_transpose_parameter([0, 2, 3, 1])

    InsertReverseChannels().find_and_replace_pattern(graph)

    graph_ref = build_graph(
        get_nodes([1, 10, 10, 3], 3),
        [
            *connect('placeholder1', 'reverse_channels'),
            *connect('reverse_channels', '0:mul'),
            *connect('placeholder2', '1:mul'),
            *connect('mul', 'result'),
        ],
    )
    # Give the reference Parameter the same rt_info object; presumably so the
    # check_op_attrs comparison does not differ on rt_info — confirm.
    ref_param = Node(graph_ref, 'placeholder1')
    ref_param.rt_info = param.rt_info

    flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
def add_old_api_map_order_into_rt_info(op: Node):
    """Make sure ``op`` has an ``OldAPIMapOrder`` (version 0) entry in its rt_info.

    An existing entry under the same (name, version) key is left untouched.

    :param op: node whose ``rt_info`` is updated; it must have an ``rt_info``
        attribute (checked with an assert).
    :return: the ``(attribute name, version)`` key of the entry.
    """
    # rt info update
    # NOTE(review): the message formats the node object itself, not its name.
    assert op.has('rt_info'), 'Unable to preserve runtime information for node with name={}'.format(op)

    order_map = OldAPIMapOrder(version=0)
    rt_key = (order_map.get_name(), order_map.get_version())
    if rt_key not in op.rt_info.info:
        op.rt_info.info[rt_key] = order_map
    return rt_key
def __read_old_api_map_order(attr, layer_type):
    """Rebuild an ``OldAPIMapOrder`` from a serialized rt_info attribute element.

    :param attr: element whose ``attrib`` holds ``version`` (int) and ``value``
        (comma-separated integer order, e.g. ``"0,2,3,1"``).
    :param layer_type: ``'Parameter'`` or ``'Result'``. It selects whether the
        order is applied as a parameter or a result transpose.
    :return: ``{('old_api_map_order', version): OldAPIMapOrder}``.
    :raises AttributeError: if ``layer_type`` is neither Parameter nor Result.
    """
    version = int(attr.attrib['version'])
    raw_order = attr.attrib['value'].strip()
    # ''.split(',') yields [''] and int('') would raise. An empty value
    # (presumably a comma-join of an empty order) is read back as [].
    order = [int(axis) for axis in raw_order.split(',')] if raw_order else []

    old_api_map = OldAPIMapOrder(version=version)
    if layer_type == 'Parameter':
        old_api_map.old_api_transpose_parameter(order)
    elif layer_type == 'Result':
        old_api_map.old_api_transpose_result(order)
    else:
        raise AttributeError("Cannot read old_api_map for layer of type: {}".format(layer_type))

    return {('old_api_map_order', version): old_api_map}