def use_shapes_from_ir(node: Node):
    """Set output shapes from the shapes restored from IR instead of inferring them.

    This function is used in place of an operation's shape inference function.

    1. Record the shapes of all connected input ports in ``node['new_input_shapes']``.
    2. Check that these shapes match ``node.old_input_shapes`` (the input shapes
       restored from IR) in both number and value.
    3. Set every output port shape from the corresponding entry of ``node.ports``.

    :param node: node whose output shapes are restored. It must already carry an
        ``old_input_shapes`` attribute (presumably filled when the IR was read -- confirm).
    :raises AssertionError: if the connected input shapes differ from the ones restored
        from IR, or if an output port with no entry in ``node.ports`` has no shape set.
    """
    # Collect the connected input ports once. Disconnected ports correspond to optional
    # inputs that are absent, so they are skipped both for the shape comparison and for
    # the output port numbering below.
    connected_input_ports = [node.in_port(idx) for idx in node.in_ports()
                             if not node.in_port(idx).disconnected()]

    node['new_input_shapes'] = [port.data.get_shape() for port in connected_input_ports]

    # The messages are kept inline in the asserts so that node.name / node.type are only
    # read when a check actually fails.
    assert len(node.new_input_shapes) == len(node.old_input_shapes), \
        'Something wrong happened while {} node with type {} copy shape inference!'.format(node.name, node.type)
    for new_input_shape, old_input_shape in zip(node.new_input_shapes, node.old_input_shapes):
        assert np.array_equal(new_input_shape, old_input_shape), \
            'Something wrong happened while {} node with type {} copy shape inference!'.format(node.name, node.type)

    # In node.ports the output ports are indexed consecutively, starting right after the
    # connected inputs. The number of connected inputs is used, not the raw input port
    # indices, so that skipped optional inputs do not shift the numbering.
    first_output_ir_idx = len(connected_input_ports)
    for offset, out_idx in enumerate(node.out_ports()):
        ir_idx = first_output_ir_idx + offset
        if ir_idx in node.ports:
            node.out_port(out_idx).data.set_shape(int64_array(node.ports[ir_idx][0]))
        else:
            # The port has no IR record, i.e. it was newly added. Its shape must already
            # have been set elsewhere.
            assert node.out_port(out_idx).data.get_shape() is not None, "Newly added port does not have set shape"
def replace_with_hsigmoid(graph: Graph, first_node: Node, last_node: Node):
    """Route a matched subgraph through a single new HSigmoid node.

    The new HSigmoid node:

    * takes its input from the producer feeding ``first_node``'s non-constant input;
    * becomes the source of ``last_node``'s output-0 connection;
    * inherits ``last_node``'s name. ``last_node`` is renamed with a ``/TBR`` suffix
      (presumably "to be removed" -- confirm).

    :param graph: graph in which the HSigmoid node is created
    :param first_node: first node of the matched pattern. Only its input 0 is examined:
        if that input comes from a Const, input 1 is treated as the data input.
    :param last_node: last node of the matched pattern, whose output is rerouted
    """
    # Pick the data input of first_node: port 1 when port 0 is fed by a Const,
    # otherwise port 0.
    port0_producer = first_node.in_port(0).get_connection().get_source().node
    data_port_idx = 1 if port0_producer.soft_get('op') == 'Const' else 0
    data_source = first_node.in_port(data_port_idx).get_source()

    original_name = last_node.soft_get('name', last_node.id)

    new_node = HSigmoid(graph, {}).create_node()
    new_node.in_port(0).connect(data_source)
    last_node.out_port(0).get_connection().set_source(new_node.out_port(0))

    rename_nodes([(last_node, original_name + '/TBR'), (new_node, original_name)])