Пример #1
0
    def replace_identityN(node: Node):
        """Split an IdentityN node into one Identity node per fully connected port index.

        For each index whose input port AND output port are both connected, a new
        ``Identity`` node named ``'<name>/<idx>_port'`` is created with
        ``data_type`` taken from ``node.data_types[idx]``. The input connection
        is moved onto the new node's input 0, and the output connection is
        re-sourced from the new node's output 0. Afterwards every port of the
        original node is disconnected.

        :param node: the IdentityN node; must carry a valid ``data_types`` attribute.
        :raises AssertionError: if ``data_types`` is missing, or is too short for
            a processed port index.
        """
        owner_graph = node.graph
        node_name = node.soft_get('name', node.id)

        assert node.has_valid('data_types'), \
            'IdentityN {} has no `data_types` attribute'.format(node_name)
        port_dtypes = node.data_types

        for port_idx, in_port in node.in_ports().items():
            fully_connected = node.is_in_port_connected(port_idx) and node.is_out_port_connected(port_idx)
            if not fully_connected:
                # See the ATTENTION section in the enclosing description: half-connected
                # ports are not split; they are only disconnected by the cleanup below.
                continue

            assert port_idx < len(port_dtypes), \
                'IdentityN {} has inconsistent `data_types` attribute {}'.format(node_name, port_dtypes)

            per_port_identity = Identity(owner_graph, {
                'name': '{}/{}_port'.format(node_name, port_idx),
                'data_type': port_dtypes[port_idx],
            }).create_node()

            # Move the producer side first, then hand the consumers to the new node.
            in_port.get_connection().set_destination(per_port_identity.in_port(0))
            out_connection = node.out_port(port_idx).get_connection()
            out_connection.set_source(per_port_identity.out_port(0))

        # See the ATTENTION section in the enclosing description: detach everything
        # still attached to the original IdentityN node.
        for leftover_in in node.in_ports().values():
            leftover_in.disconnect()
        for leftover_out in node.out_ports().values():
            leftover_out.disconnect()
Пример #2
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace the matched ``op`` node with a single ``Identity`` node.

        The new Identity is named after the matched node (its ``name`` attribute,
        falling back to its id) and takes over the matched node's input-0
        connection. Every output port of the matched node is then re-sourced from
        the Identity's output port 0, so all former consumers read the same tensor.

        :param graph: the graph being transformed.
        :param match: pattern-match dict; ``match['op']`` is the node to replace.
        """
        node = match['op']

        identity = Identity(graph, {'name': node.soft_get('name', node.id)}).create_node()
        node.in_port(0).get_connection().set_destination(identity.in_port(0))

        # The port index is irrelevant here: every output is fed from the same
        # Identity output, so only the port objects are iterated.
        for port in node.out_ports().values():
            port.get_connection().set_source(identity.out_port(0))
Пример #3
0
 def extract(cls, node):
     """Convert an ONNX Dropout node into an Identity (inference mode only).

     :param node: the Dropout node to convert.
     :raises Error: if the node has more than one output node, or if its
         ``is_test`` attribute is falsy (training mode).
     :return: ``cls.enabled``.
     """
     # Some Dropout flavors don't have the is_test attribute; a missing value is read as 1.
     test_mode = onnx_attr(node, 'is_test', 'i', 1)

     consumer_count = len(node.out_nodes())
     if consumer_count > 1:
         raise Error('Dropout node {} has more than one consumer. Unsupported.', node.name)

     if not test_mode:
         raise Error('Dropout node {} has is_test: 0. This means training mode which is not supported.', node.name)

     Identity.update_node_stat(node)
     return cls.enabled
Пример #4
0
    def extract(cls, node):
        """Walk past this component's serialized fields and convert the node into an Identity.

        The tokens ``<Dim>``, ``<BlockDim>`` and ``<TimePeriod>`` (each followed
        by an int32) and ``<DropoutProportion>`` (followed by a float) are located
        in ``node.parameters`` in that order and their values are read. None of
        the values are needed, because the node is converted to a plain Identity.

        :param node: the node whose ``parameters`` stream holds the serialized fields.
        :return: ``cls.enabled``.
        """
        pb = node.parameters

        # Read in the same order as the serialized layout; the values are discarded.
        # NOTE(review): a trailing <Continuous> token, if present, is not consumed —
        # confirm this is intended.
        for token, read_value in (
                (b'<Dim>', read_binary_integer32_token),
                (b'<BlockDim>', read_binary_integer32_token),
                (b'<TimePeriod>', read_binary_integer32_token),
                (b'<DropoutProportion>', read_binary_float_token),
        ):
            collect_until_token(pb, token)
            read_value(pb)

        Identity.update_node_stat(node, {})

        return cls.enabled
Пример #5
0
 def extract(cls, node):
     """Convert ``node`` into an Identity operation and return the extractor's ``enabled`` flag."""
     Identity.update_node_stat(node)
     is_enabled = cls.enabled
     return is_enabled
    def replace_pattern(graph: Graph, match: dict):
        """Replace a matched loop time-counter pattern with a TensorIteratorCondition node.

        Reads the counter's initial value (``init_1_data``) and increment
        (``add_1_y_data``), both asserted to have a known value and converted to
        ``int``, and creates a ``TensorIteratorCondition`` node whose ``iter``
        attribute records them. The new node takes ``Strided_slice_data`` as
        input and reuses the existing ``loop_cond_data`` and ``Identity_1_data``
        data nodes as outputs.

        Consider the consumers of ``Identity_1`` output port 0 (the "time" value).
        If any consumer is not a ``TensorIteratorInput``/``TensorIteratorOutput``
        node and is not ``add_1``, the counter chain is preserved. In that case
        ``Identity_1`` is replaced by a new Identity with the same name, fed from
        ``Switch_1`` output port 1, and that new Identity feeds those consumers.
        All matched nodes not listed as safe are then removed from the graph.

        :param graph: the graph being transformed.
        :param match: pattern-match dict mapping pattern node names to graph nodes.
        """
        log.debug('================== SimpleConditionFind ===============')
        # init_1
        init_1 = match['init_1_data'].value
        assert init_1 is not None
        init_1 = int(init_1)

        # step_1
        assert match['add_1_y_data'].value is not None
        step_1 = int(match['add_1_y_data'].value)

        # Drop any value stored on the loop-condition data node (presumably so it is
        # not treated as a constant once it becomes the condition node's output).
        match['loop_cond_data'].value = None

        # compute destination (or consumer) ports for time node
        identity_node_name = match['Identity_1'].soft_get(
            'name', match['Identity_1'].id)
        time_dsts = match['Identity_1'].out_port(0).get_destinations()

        # Create condition node and delete all useless nodes from condition pattern
        condition_attrs = dict(iter=dict(init=init_1, step=step_1),
                               name=match['loop_cond'].name +
                               '/TensorIteratorCondition_')
        condition = TensorIteratorCondition(graph, attrs=condition_attrs)
        condition.create_node_with_data(
            inputs=[match['Strided_slice_data']],
            data_nodes=[match['loop_cond_data'], match['Identity_1_data']])

        # Matched nodes that must survive the final cleanup.
        safe_nodes = [
            'loop_cond_data', 'Identity_1_data', 'Strided_slice',
            'Strided_slice_data'
        ]

        # check if time node has other consumers different from increment node,
        # input slicing and output concatenation nodes
        other_time_consumers = False
        for time_consumer in time_dsts:
            if time_consumer.node.soft_get('op') not in ['TensorIteratorInput', 'TensorIteratorOutput'] and \
                    time_consumer.node.id != match['add_1'].id:
                other_time_consumers = True
                break
        if other_time_consumers:
            # save time related nodes since they have other consumers different from
            # input slicing and output concatenation nodes
            safe_nodes += [
                'init_1', 'init_1_data', 'Enter_1', 'Enter_1_data', 'Merge_1',
                'Merge_1_data', 'Switch_1', 'Switch_1_data', 'add_1',
                'add_1_y', 'add_1_y_data', 'add_1_data', 'NextIteration_1'
            ]
            switch_node = match['Switch_1']
            new_identity_node = Identity(
                graph, dict(name=identity_node_name)).create_node()
            switch_node.out_port(1).connect(new_identity_node.in_port(0))

            # make the graph consistent to avoid multiple producers by the same input port
            graph.remove_nodes_from([match['Identity_1'].id])
            rename_nodes([(new_identity_node, identity_node_name)])

            # Re-source every non-TensorIterator time consumer (add_1 included,
            # unlike in the check above) from the new Identity node.
            for time_consumer in time_dsts:
                if time_consumer.node.soft_get('op') not in [
                        'TensorIteratorInput', 'TensorIteratorOutput'
                ]:
                    time_consumer.get_connection().set_source(
                        new_identity_node.out_port(0))

        # Delete useless nodes
        nodes_for_remove = []
        for node in match.keys():
            if node not in safe_nodes:
                nodes_for_remove.append(match[node].id)
        graph.remove_nodes_from(nodes_for_remove)
Пример #7
0
 def extract(cls, node: Node):
     """Update ``node`` with Identity's attributes plus ``op='StopGradient'``; return ``cls.enabled``."""
     stop_gradient_attrs = {'op': 'StopGradient'}
     Identity.update_node_stat(node, stop_gradient_attrs)
     return cls.enabled
Пример #8
0
 def extract(cls, node: Node):
     """Convert ``node`` into an Identity whose ``data_type`` comes from the TF ``T`` attribute.

     :param node: a node wrapping a TF protobuf (``node.pb``) with a ``T`` type attribute.
     :return: ``cls.enabled``.
     """
     element_type = tf_dtype_extractor(node.pb.attr["T"].type)
     identity_attrs = {'data_type': element_type}
     Identity.update_node_stat(node, identity_attrs)
     return cls.enabled