Пример #1
0
    def use_shapes_from_ir(node: Node):
        """Set all output shapes of ``node`` to the shapes restored from IR.

        Used instead of the operation's own shape inference function. Before
        copying output shapes, it checks that the input shapes computed during
        shape inference (stored on the node as ``new_input_shapes``) equal the
        input shapes restored from IR (``node.old_input_shapes``).

        :param node: node restored from IR; its ``old_input_shapes`` and
            ``ports`` attributes are expected to be populated already.
        :raises AssertionError: if the number of connected inputs differs from
            ``len(node.old_input_shapes)``, or if any connected input shape
            differs from the corresponding restored shape.
        """
        # Filter the connected ports once and reuse the result both for the
        # shape check and for the output-port numbering below. Disconnected
        # ports are skipped to handle optional inputs.
        connected_input_ports = [node.in_port(n) for n in node.in_ports()
                                 if not node.in_port(n).disconnected()]

        node['new_input_shapes'] = [port.data.get_shape() for port in connected_input_ports]
        assert len(node.new_input_shapes) == len(node.old_input_shapes), \
            'Something wrong happened while {} node with type {} copy shape inference!'.format(node.name, node.type)
        for new_input_shape, old_input_shape in zip(node.new_input_shapes, node.old_input_shapes):
            assert np.array_equal(new_input_shape, old_input_shape), \
                'Something wrong happened while {} node with type {} copy shape inference!'.format(node.name, node.type)

        # This relies on the entries of ``node.ports`` numbering the connected
        # inputs first, so the output entries start right after them. Counting
        # only the connected inputs avoids numbering errors when optional inputs
        # are absent. ``node.ports[idx][0]`` is the shape restored from IR.
        first_output_idx = len(connected_input_ports)
        for ports_idx, out_idx in enumerate(node.out_ports(), start=first_output_idx):
            node.out_port(out_idx).data.set_shape(int64_array(node.ports[ports_idx][0]))
Пример #2
0
def replace_with_hsigmoid(graph: Graph, first_node: Node, last_node: Node):
    """Insert an HSigmoid node that replaces the span from ``first_node`` to ``last_node``.

    The new HSigmoid node is fed by the producer of one of ``first_node``'s inputs:
    port 1 when the producer on port 0 is a ``Const`` op, otherwise port 0. The
    HSigmoid node then takes over every consumer of ``last_node`` output port 0.
    Finally, ``last_node`` is renamed with a ``/TBR`` suffix (presumably "to be
    removed"), and the HSigmoid node receives ``last_node``'s original name.

    :param graph: graph in which the HSigmoid node is created.
    :param first_node: first node of the span being replaced.
    :param last_node: last node of the span being replaced.
    """
    # Pick the input of first_node that carries the data rather than the constant.
    port0_producer_op = first_node.in_port(0).get_connection().get_source().node.soft_get('op')
    data_port_idx = 1 if port0_producer_op == 'Const' else 0
    original_name = last_node.soft_get('name', last_node.id)

    hsigmoid = HSigmoid(graph, {}).create_node()
    data_source = first_node.in_port(data_port_idx).get_source()
    hsigmoid.in_port(0).connect(data_source)

    # Move all consumers of last_node's output onto the HSigmoid output.
    last_node_output = last_node.out_port(0).get_connection()
    last_node_output.set_source(hsigmoid.out_port(0))

    rename_nodes([(last_node, original_name + '/TBR'), (hsigmoid, original_name)])
Пример #3
0
 def split_offset(offset_node: Node):
     """Split ``offset_node`` into itself plus a paired MemoryOffset node.

     Creates a MemoryOffset node named ``offset_node.pair_name`` whose own
     ``pair_name`` is ``offset_node.id``. It copies ``element_size``, ``t`` and
     ``has_default`` from ``offset_node``, and both nodes are marked
     ``splitted``. Every consumer of ``offset_node`` output port 0 is moved to
     the paired node. ``offset_node`` is then connected to a new Result node
     named ``<offset_node.id>_output``.

     :param offset_node: node to split (presumably a MemoryOffset node; confirm
         against callers).
     """
     graph = offset_node.graph
     paired_attrs = {
         'name': offset_node.pair_name,
         'splitted': True,
         'pair_name': offset_node.id,
         'element_size': offset_node['element_size'],
         't': offset_node.t,
         'has_default': offset_node.has_default,
     }
     paired_node = MemoryOffset(graph, paired_attrs).create_node()
     offset_node['splitted'] = True

     # Re-route the existing consumers first, so the Result attached below
     # ends up as what offset_node feeds.
     outgoing = offset_node.out_port(0).get_connection()
     outgoing.set_source(paired_node.out_port(0))

     output_node = Result(graph, {'name': offset_node.id + '_output'}).create_node()
     offset_node.out_port(0).connect(output_node.in_port(0))
Пример #4
0
 def extend(op: Node):
     """Post-process attributes of ``op``.

     Sets ``remove_values_output`` to True when output port 0 is disconnected.
     When ``index_element_type`` holds a valid value, it is replaced with the
     result of ``destination_type_to_np_data_type`` (presumably a numpy data
     type; confirm against that helper).
     """
     values_output_unused = op.out_port(0).disconnected()
     if values_output_unused:
         op['remove_values_output'] = True

     if not op.has_valid('index_element_type'):
         return
     op['index_element_type'] = destination_type_to_np_data_type(op.index_element_type)