예제 #1
0
 def test_name_looks_like_port_number(self):
     """Resolve '0:2' in a graph whose node names are the bare digits '0', '1' and '2'.

     The graph is Parameter '0' -> Convolution '1' -> ReLU '2'. The query is
     expected to yield input port 0 of the ReLU node ('relu_id').
     """
     def make_op(node_type, op, name):
         # Attribute dict for a single node of kind 'op'.
         return {'type': node_type, 'kind': 'op', 'op': op, 'name': name}

     graph = build_graph(
         {
             'input_id': make_op('Parameter', 'Parameter', '0'),
             'conv_id': make_op('Convolution', 'NotPlaceholder', '1'),
             'relu_id': make_op('ReLU', 'NotPlaceholder', '2'),
         },
         [
             ('input_id', 'conv_id'),
             ('conv_id', 'relu_id'),
         ],
     )
     node_id, direction, port = get_node_id_with_ports(graph, '0:2')
     self.assertEqual((node_id, direction, port), ('relu_id', 'in', 0))
예제 #2
0
    def find_and_replace_pattern(self, graph: Graph):
        """Apply the user-supplied mean/scale values to the matching Parameter nodes of ``graph``.

        The values come from ``graph.graph['cmd_params'].mean_scale_values``. When that is not a
        dict (no input names were given), its entries are paired positionally with the graph's
        Parameter nodes. Each requested name is resolved with ``get_node_id_with_ports``.

        If the resolved node is not itself a Parameter (for example, the user cut the original
        input off), the code falls back to a Parameter whose ``initial_node_name`` equals the
        resolved node id. When that Parameter's id ends in ``_<int>``, the integer must also equal
        the requested port. ``AddMeanScaleValues.apply_scale`` and
        ``AddMeanScaleValues.apply_mean_value`` are then applied to the chosen node.

        :param graph: graph to modify in place.
        :raises Error: if the number of unnamed mean/scale entries differs from the number of
            Parameter nodes, if an output port is specified, or if a requested input cannot be
            resolved to a Parameter node.
        """
        values = graph.graph['cmd_params'].mean_scale_values
        input_nodes = graph.get_op_nodes(op='Parameter')

        if not isinstance(values, dict):
            # No input names were specified: pair the values with the Parameter nodes by position.
            if len(values) != len(input_nodes):
                raise Error('Numbers of inputs and mean/scale values do not match. ' + refer_to_faq_msg(61))

            data = np.copy(values)
            values = {
                node.soft_get('name', node.id): {'mean': data[idx][0], 'scale': data[idx][1]}
                for idx, node in enumerate(input_nodes)
            }

        for node_name, node_mean_scale_values in values.items():
            node_name = get_node_name_with_port_from_input_value(node_name)
            try:
                node_id, direction, port = get_node_id_with_ports(graph, node_name, skip_if_no_port=False)
            except Error as err:
                log.warning('node_name {} is not found in graph'.format(node_name))
                # Continuing here would reach Node(graph, None) with an unbound/stale ``port``;
                # report the missing input explicitly instead.
                raise Error('Input with name {} wasn\'t found!'.format(node_name) +
                            refer_to_faq_msg(83)) from err

            # A real check (not ``assert``): it validates user input and must survive ``python -O``.
            if direction == 'out':
                raise Error('Only input port can be specified for mean/scale application')

            if Node(graph, node_id) not in input_nodes:
                # if the user cut off the input of the network, then the input node name specified in
                # --scale_values or --mean_values doesn't correspond to a real input node generated by
                # Model Optimizer. But the information about the initial input node name is stored in the
                # Placeholder's attribute 'initial_node_name'
                new_node_id = None
                for placeholder in input_nodes:
                    try:
                        # assumes node ids are str, as in the graphs built by MO - TODO confirm
                        placeholder_port = int(placeholder.id.split("_")[-1])
                    except ValueError:
                        log.debug('Can not get the port number from the node {}'.format(placeholder.id))
                        log.debug('Port will be defined as None')
                        # The port of THIS placeholder is unknown, so it matches any requested port.
                        # Do not overwrite the requested ``port``: that leaked into later placeholders.
                        placeholder_port = None
                    if placeholder.has('initial_node_name') and placeholder.initial_node_name == node_id and (
                            port is None or placeholder_port is None or placeholder_port == port):
                        new_node_id = placeholder.id
                        break
                if new_node_id is None:
                    raise Error('Input with name {} wasn\'t found!'.format(node_name) +
                                refer_to_faq_msg(83))
                node_id = new_node_id

            input_node = Node(graph, node_id)
            AddMeanScaleValues.apply_scale(graph, input_node, node_mean_scale_values)
            AddMeanScaleValues.apply_mean_value(graph, input_node, node_mean_scale_values)
예제 #3
0
 def test_no_port1(self):
     """'1input1' (no ':'-separated port) resolves to 'conv_id' with direction 'port' and no port."""
     node_id, direction, port = get_node_id_with_ports(
         self.graph, '1input1')
     self.assertEqual(node_id, 'conv_id')
     self.assertEqual(direction, 'port')
     # assertIsNone checks identity with None and gives a clearer failure message.
     self.assertIsNone(port)
예제 #4
0
 def test_in_port2(self):
     """'0:relu:0' is expected to resolve to input port 0 of 'squeeze_id'."""
     expected = ('squeeze_id', 'in', 0)
     node_id, direction, port = get_node_id_with_ports(self.graph, '0:relu:0')
     self.assertEqual((node_id, direction, port), expected)
예제 #5
0
 def test_in_port1(self):
     """'0:1input1' is expected to resolve to input port 0 of 'conv_id'."""
     expected = ('conv_id', 'in', 0)
     node_id, direction, port = get_node_id_with_ports(self.graph, '0:1input1')
     self.assertEqual((node_id, direction, port), expected)
예제 #6
0
 def test_out_port(self):
     """'1input1:0:0' is expected to resolve to output port 0 of 'input_id'."""
     expected = ('input_id', 'out', 0)
     node_id, direction, port = get_node_id_with_ports(self.graph, '1input1:0:0')
     self.assertEqual((node_id, direction, port), expected)