def merge_infer(node: Node):
    """
    Infer the output of a Merge node from its already (partially) inferred inputs.

    Only inputs with 'is_partial_inferred' set are considered. The output shape is taken from the first
    inferred input that is also executable, or from the first inferred input when none is executable.
    The output value is propagated only when every inferred executable input has a valid value that
    ``strict_compare_tensors`` considers equal to the chosen one. Otherwise it is set to None.

    Also sets ``node['is_not_fully_inferred']``, which is True when some inputs are not inferred yet.

    :param node: the Merge node to infer
    """
    # we infer only through executable input nodes
    inferred_nodes = [n for n in node.in_nodes().values() if n['is_partial_inferred']]
    assert len(inferred_nodes) != 0
    tensor = inferred_nodes[0]

    if len(inferred_nodes) < len(node.in_nodes()):
        node['is_not_fully_inferred'] = True
    else:
        node['is_not_fully_inferred'] = False
        # All inputs are inferred, so their shapes must be mutually compatible. The former
        # np.all(<generator>) was always True because a generator object is truthy.
        # The builtin all() actually consumes the generator.
        assert all(compatible_shapes(n.shape, inferred_nodes[0].shape) for n in inferred_nodes)

    inferred_and_executable = [n for n in node.in_nodes().values()
                               if n['is_partial_inferred'] and 'executable' in n and n['executable']]
    if len(inferred_and_executable) > 0:
        tensor = inferred_and_executable[0]

        if all(tensor.has_valid('value') and n.has_valid('value') and strict_compare_tensors(tensor.value, n.value)
               for n in inferred_and_executable):
            node.out_node().value = tensor.value.copy()
        else:
            node.out_node().value = None

    # do not use set_shape(tensor.shape) here because input port shape may be different from the calculated output
    # shape and `set_shape` will raise an error that shape has changed
    node.out_node(0).shape = shape_array(tensor.shape)
def infer(node: Node):
    """
    Shape inference for an LSTM-cell-like node.

    Input 0 is the data, input 1 the hidden state and input 2 the cell state. Inputs from port 3 on are marked
    as bins. The node has 5 inputs, or 8 when 'extra_inputs' is set. Output 0 gets a copy of the hidden-state
    shape and the optional output 1 a copy of the cell-state shape. 'hidden_size' is taken from dimension 1 of
    the hidden state, or validated against it when already set.

    :raises Error: if a pre-defined 'hidden_size' differs from the hidden-state input shape
    """
    expected_inputs = 8 if node.has_and_set('extra_inputs') else 5
    assert len(node.in_nodes()) == expected_inputs
    assert len(node.out_nodes()) in [1, 2]

    hidden_state_shape = node.in_node(1).shape.copy()
    cell_state_shape = node.in_node(2).shape.copy()

    mark_input_bins(node, start_port=3)

    node.out_node(0).shape = hidden_state_shape
    if len(node.out_nodes()) == 2:
        node.out_node(1).shape = cell_state_shape

    inferred_hidden_size = hidden_state_shape[1]
    if not node.has_valid('hidden_size'):
        node['hidden_size'] = inferred_hidden_size
    elif node.hidden_size != inferred_hidden_size:
        raise Error("Input shape {} for hidden size doesn't match pre-defined hidden_size in node {}"
                    .format(node.in_node(1).shape, node.soft_get('name')))

    assert cell_state_shape[1] == inferred_hidden_size

    data_shape = node.in_node(0).shape
    assert data_shape is not None

    # the batch dimension of both states and the data must be compatible
    batches_compatible = compatible_dims(hidden_state_shape[0], cell_state_shape[0]) and \
        compatible_dims(cell_state_shape[0], data_shape[0])
    assert batches_compatible, \
        'States are not broadcast-able by batch for node {}'.format(node.soft_get('name', node.id))
def roipooling_infer(node: Node):
    """
    Sets the shape of the output node from the input blobs and the pooled size.

    The batch is taken from the second input blob (at the layout's batch dimension) and the channels from the
    first one (at the layout's features dimension). Height and width come from the 'pooled_h' / 'pooled_w'
    attributes. With 4 inputs (TensorFlow CropAndResize case), the pooled size is read from the constant value
    of input 3. Then the edges from inputs 3 and 2 are removed.

    Nothing is done (only an error is logged, where applicable) when an input shape is unknown or when the
    crop size is unknown or malformed.

    Parameters
    ----------
    node
    """
    shapes = [node.in_node(i).shape for i in range(len(node.in_nodes()))]
    if any(s is None for s in shapes):
        return

    if len(node.in_nodes()) == 4:  # TensorFlow case of CropAndResize operation
        crop_size = node.in_node(3).value
        if crop_size is None:
            log.error('The ROIPooling size is not known for node {}'.format(node.soft_get('name')))
            return
        if not isinstance(crop_size, np.ndarray) or len(crop_size) != 2:
            log.error('The ROIPooling size is should have 2 elements for node {}'.format(node.soft_get('name')))
            # Stop here. Previously execution continued and indexed an invalid crop size.
            return

        node.pooled_h = crop_size[0]
        node.pooled_w = crop_size[1]
        node.graph.remove_edge(node.in_node(3).id, node.id)
        node.graph.remove_edge(node.in_node(2).id, node.id)

    layout = node.graph.graph['layout']
    assert len(layout) == 4

    node.out_port(0).data.set_shape(shape_for_layout(layout,
                                                     batch=shapes[1][get_batch_dim(layout, 4)],
                                                     features=shapes[0][get_features_dim(layout, 4)],
                                                     height=node.pooled_h,
                                                     width=node.pooled_w))
def find_and_replace_pattern(self, graph: Graph):
    """
    Let Concat nodes reuse the outputs of other Concat nodes that concatenate a contiguous run of their inputs.

    Consider a Concat whose ordered input ids contain, as a contiguous run of length >= 2, the exact input ids
    of another Concat. The edges of that run are removed and replaced by a single edge from the other
    Concat's output data node. Longer runs are tried first, and an input already folded into one replacement
    is never folded again. After each replacement the 'in' ports of the rewired Concat are renumbered to
    0, 1, 2, ... in their previous order.
    """
    # ordered tuple of input ids -> (concat op id, concat output data id)
    concat_by_inputs = {}
    # concat op id -> {input id: already folded into a replacement edge}
    folded = {}
    for concat in graph.get_op_nodes(type='Concat'):
        input_ids = tuple(concat.in_node(port).id for port in range(len(concat.in_nodes())))
        concat_and_output = (concat.id, concat.out_node().id)
        if input_ids in concat_by_inputs:
            log.warning("Something is weird! {} and {}".format(concat.id, concat_by_inputs[input_ids]))
            continue
        concat_by_inputs[input_ids] = concat_and_output
        folded[concat.id] = dict.fromkeys(input_ids, False)

    for input_ids, (target_id, _) in concat_by_inputs.items():
        # every contiguous run of >= 2 inputs that is itself the input list of another Concat
        candidates = []
        for start in range(len(input_ids)):
            for stop in range(start + 2, len(input_ids) + 1):
                run = input_ids[start:stop]
                if run in concat_by_inputs and run != input_ids:
                    candidates.append((len(run), run))
        # longest runs first
        candidates.sort(reverse=True)

        target_flags = folded[target_id]
        for _, run in candidates:
            if any(target_flags[x] for x in run):
                continue  # part of this run has already been replaced by another Concat output
            for x in run:
                target_flags[x] = True

            edge_attrs = graph.get_edge_data(run[0], target_id)[0]
            for source_id in run:
                graph.remove_edge(source_id, target_id)

            replacement_id = concat_by_inputs[run][1]
            edge_attrs['out'] = len(Node(graph, replacement_id).out_nodes()) + 1
            graph.add_edge(replacement_id, target_id, **edge_attrs)

            # Renumber 'in' attrs so the ports are consecutive again
            target_node = Node(graph, target_id)
            for new_port, old_port in enumerate(sorted(target_node.in_nodes())):
                producer = target_node.in_nodes()[old_port]
                graph[producer.id][target_id][0]['in'] = new_port
def infer(node: Node):
    """
    Shape inference for an (ONNX-style) LSTM sequence node.

    Validates port counts and the batch/sequence/direction attributes and marks the weight inputs as bins.
    For inputs 2 and 3 whose producer carries 'zero_shapes', the value is repeated along each listed axis
    where its size differs from the data input. Then sets:
      * output 0: [seq_len, batch, hidden_size]. When 'has_num_directions' is set, a num_directions axis is
        inserted at position 1.
      * outputs 1 and 2 (created when missing): [batch, hidden_size]. When 'has_num_directions' is set, they
        are prefixed with num_directions * num_layers.
    """
    # there are limitations coming from ONNX LSTM definition and normalization rules
    assert len(node.in_nodes()) >= 3  # X, W and R
    assert len(node.in_nodes()) <= 7
    assert len(node.out_nodes()) <= 3
    assert node.batch_dim <= 1
    assert node.sequence_dim <= 1
    assert node.batch_dim != node.sequence_dim

    assert node.direction in ['forward', 'reverse', 'bidirectional']

    if node.blobs_wrb:
        mark_input_bins(node, ['W', 'R', 'B'])
    else:
        mark_input_bins(node)
    input_shape = node.in_node(0).shape
    assert len(input_shape) == 3

    for port in [2, 3]:
        if port in node.in_nodes() and len(node.in_node(port).in_nodes()) > 0 and \
                'zero_shapes' in node.in_node(port).in_node():
            for i in node.in_node(port).in_node().zero_shapes:
                if node.in_node(port).shape[i] != input_shape[i]:
                    node.in_node(port).value = np.repeat(node.in_node(port).value, input_shape[i], axis=i)
                    node.in_node(port).shape[i] = input_shape[i]

    out_shape = shape_array([input_shape[node.sequence_dim], input_shape[node.batch_dim], node.hidden_size])
    assert not node.has_num_directions or node.sequence_dim == 0, \
        'If has_num_directions == True, then node.sequence_dim should be equal 0, but it is {}'.format(
            node.sequence_dim)
    num_directions = 2 if node.direction in ['bidirectional'] else 1
    num_layers = node.num_layers
    if node.has_num_directions:
        # insert extra dimension to output shape for num_directions
        out_shape = shape_insert(out_shape, 1, np.int64(num_directions))
    node.out_node(0).shape = out_shape

    # extra outputs for hidden/cell states
    # The batch is read at batch_dim, as for out_shape above. The former hard-coded index 1 returned the
    # sequence length when batch_dim == 0.
    state_size = shape_array([input_shape[node.batch_dim], node.hidden_size])
    if node.has_num_directions:
        state_size = shape_insert(state_size, 0, num_directions * num_layers)
    for i in [1, 2]:
        if i not in node.out_nodes():
            data_node = Op._create_data_node(
                node.graph,
                name=node.node + '/ExtraOutput/' + str(i),
                attrs={'executable': True}
            )
            node.graph.add_edge(node.id, data_node.id, key=0, out=i)
            add_opoutput(node.graph, data_node.id, 0, False)
        else:
            data_node = node.out_node(i)
        data_node.shape = state_size.copy()
def test_partial_infer(self):
    """Concat gets partially inferred while none of its producers are marked as inferred."""
    graph = build_graph(nodes_attributes,
                        [('node_1', 'concat'),
                         ('node_2', 'concat'),
                         ('concat', 'node_3'),
                         ('node_3', 'op_output')],
                        {'node_3': {'kind': 'data', 'shape': None, 'infer': None},
                         'node_1': {'kind': 'data', 'shape': np.array([1, 3, 227, 227]), 'infer': None},
                         'node_2': {'kind': 'data', 'shape': np.array([1, 3, 227, 227]), 'infer': None},
                         'concat': {'kind': 'op', 'axis': 2, 'infer': concat_infer}},
                        nodes_with_edges_only=True)

    start_node = 'concat'
    partial_infer(graph, start_node)
    node = Node(graph, start_node)
    self.assertTrue(node.is_partial_inferred)
    self.assertTrue(node.out_node().is_partial_inferred)

    # walk up through the producers and check that none of them has been inferred
    while True:
        in_nodes = node.in_nodes()
        if not isinstance(in_nodes, list):
            in_nodes = list(in_nodes.values())

        next_node = None
        for n in in_nodes:
            self.assertFalse(n.has('is_partial_inferred'))
            if 'embedded_input_' not in n.id:
                next_node = n

        # Stop at graph inputs. Also stop when only embedded inputs remain;
        # the previous loop never advanced in that case and spun forever.
        if next_node is None:
            break
        node = next_node
def infer(node: Node):
    """
    Validate the extra constant inputs of the node, then delegate to the node's original infer function.

    The node must have exactly len(inputs) + len(extra_inputs) inputs (class attributes). The 'concat_axis'
    and 'split_axis' extra inputs must be constant and equal to 1. The 'shift_const' extra input must be a
    constant scalar; it is copied into the 'shift_const' attribute. The 'weights' and 'biases' inputs must
    be constant. Finally ``node.infer`` is restored to ``node.old_infer`` and called.
    """
    regular_count = len(__class__.inputs)
    assert len(node.in_nodes()) == regular_count + len(__class__.extra_inputs)

    for axis_name in ('concat_axis', 'split_axis'):
        axis_input = node.in_node(__class__.extra_inputs.index(axis_name) + regular_count)
        assert axis_input.has_valid('value')
        assert axis_input.value == 1

    shift_input = node.in_node(__class__.extra_inputs.index('shift_const') + regular_count)
    assert shift_input.has_valid('value')
    shift_value = shift_input.value
    assert shift_value.ndim == 0  # a scalar is expected
    node['shift_const'] = shift_value.copy()

    weights = node.in_node(__class__.inputs.index('weights'))
    biases = node.in_node(__class__.inputs.index('biases'))
    assert weights.has_valid('value')
    assert biases.has_valid('value')

    # restore the original infer function first so the code above is not run twice, then call it
    node.infer = node.old_infer
    node.infer(node)
def infer(node: Node):
    """
    Shape (and, for constant inputs, value) inference for Bucketize.

    The output shape equals the shape of input 0. When both the data (input 0) and the bucket boundaries
    (input 1) are constant, the output value is ``np.digitize(data, boundaries, right=with_right_bound)``
    cast to 'output_type'. Outside the "extension" opset, 'output_type' must be int32 or int64.
    """
    node_name = node.soft_get('name', node.id)
    assert node.with_right_bound is not None, "Attribute \"with_right_bound\" is not defined"
    assert len(node.in_nodes()) == 2, "Incorrect number of inputs for {} node".format(node.id)

    if node.get_opset() != "extension":
        assert node.has_valid('output_type'), \
            '`output_type` attribute is not set for Bucketize node `{}`'.format(node_name)
        assert node.output_type in [np.int64, np.int32], \
            'Bucketize `output_type` attribute must be int32 or int64, `{}` found'.format(
                np.dtype(node.output_type).name)

    node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())

    data = node.in_port(0).data.get_value()
    boundaries = node.in_port(1).data.get_value()
    # constant-fold only when both inputs are known
    if data is None or boundaries is None:
        return
    bucket_ids = np.digitize(data, boundaries, right=node.with_right_bound)
    node.out_port(0).data.set_value(mo_array(bucket_ids, dtype=node.output_type))
def infer(node: Node):
    """
    Shape inference for SparseFillEmptyRows.

    Input 2 (dense shape) must be a constant with 2 elements and input 3 (default value) must be a scalar.
    Connected outputs get these shapes:
      * port 0 (indices): [prod(dense_shape), 2]
      * port 1 (values): [prod(dense_shape)]
      * port 2 (empty row indicator): [dense_shape[0]]
    prod(dense_shape) becomes a dynamic dimension when the dense shape is not fully defined.
    """
    assert len(node.in_nodes()) == 4

    # the dense shape value must be known to infer the output shapes
    dense_shape = node.in_node(2)
    assert dense_shape.value is not None and dense_shape.value.size == 2, \
        "SparseFillEmptyRows is supported only with constant shape value"
    dense_shape_value = int64_array(dense_shape.value)

    # the default value must be a scalar
    default_value = node.in_node(3)
    assert default_value.shape is not None and len(default_value.shape) == 0, \
        "Default value for SparseFillEmptyRows must be scalar"

    if is_fully_defined(dense_shape_value):
        max_elements = np.prod(dense_shape_value)
    else:
        max_elements = dynamic_dimension_value

    if node.is_out_port_connected(0):  # output indices
        node.out_port(0).data.set_shape([max_elements, 2])
    if node.is_out_port_connected(1):  # output values
        node.out_port(1).data.set_shape([max_elements])
    if node.is_out_port_connected(2):  # empty row indicator
        node.out_port(2).data.set_shape([dense_shape_value[0]])
def test_remove_softmax_activation_input(self):
    """After CheckSoftmaxNodeInputs, the SoftmaxActivation node keeps a single input: node_1."""
    nodes = {
        'node_1': {'type': 'Identity', 'value': None, 'kind': 'op', 'op': 'Parameter'},
        'softmax': {'type': 'SoftmaxActivation', 'value': None, 'kind': 'op', 'op': 'SoftmaxActivation'},
    }
    graph = build_graph(nodes, [('node_1', 'softmax')])

    CheckSoftmaxNodeInputs().find_and_replace_pattern(graph)

    softmax = Node(graph, 'softmax')
    self.assertEqual(len(softmax.in_nodes()), 1)
    self.assertEqual(softmax.in_node(0).name, 'node_1')
def infer(node: Node):
    """
    Shape inference for CTCLoss.

    Inputs: logits (rank 3), logit lengths (rank 1), labels (rank 2), label lengths (rank 1) and an optional
    scalar blank index. Batch (dim 0) must agree across the first four inputs, and dim 1 of logits and labels
    must agree. The output shape is [batch].
    """
    node_name = node.soft_get('name', node.id)
    connected_count = sum(1 for port in node.in_ports().values() if not port.disconnected())
    assert connected_count in [4, 5], "Incorrect number of inputs for {} node".format(node_name)

    logits_shape, logit_length_shape, labels_shape, label_length_shape = \
        [node.in_port(idx).data.get_shape() for idx in range(4)]
    if len(node.in_nodes()) == 5:
        blank_index_shape = node.in_port(4).data.get_shape()
    else:
        blank_index_shape = int64_array([])

    # check the ranks of the input tensors
    expected_ranks = ((logits_shape, 3), (logit_length_shape, 1), (labels_shape, 2),
                      (label_length_shape, 1), (blank_index_shape, 0))
    assert all(len(shape) == rank for shape, rank in expected_ranks), \
        'Incorrect rank of some input tensor for {} node'.format(node_name)

    batch_size = logits_shape[0]
    assert all(compatible_dims(batch_size, other[0])
               for other in (logit_length_shape, labels_shape, label_length_shape)), \
        'Batch dimensions of input tensors must be the same for {} node'.format(node_name)
    assert compatible_dims(logits_shape[1], labels_shape[1]), \
        'Time dimensions of input tensors must be the same for {} node'.format(node_name)

    node.out_port(0).data.set_shape([batch_size])
def add_unsqueeze_for_new(graph: Graph, ss_node: Node):
    """
    Move the new-axis part of a StridedSlice into a separate Unsqueeze node.

    The StridedSlice output shape is rewritten without the dimensions flagged in 'new_axis_mask', and those
    mask bits are cleared. An Unsqueeze with constant axes (the flagged positions) is inserted after the
    StridedSlice and restores the original output shape. Nothing is done unless the node has exactly
    4 inputs and 1 output.
    """
    log.info("StridedSlice op with new axis mask '{}' has been detected".format(ss_node.id))
    if len(ss_node.in_nodes()) != 4 or len(ss_node.out_nodes()) != 1:
        return

    shape_out = ss_node.out_node().shape
    new_axis_mask = ss_node['new_axis_mask']
    dim = mo_array(range(len(new_axis_mask)))[mo_array(new_axis_mask, dtype=bool)]
    ss_shape = []
    for i in range(len(new_axis_mask)):
        if not new_axis_mask[i]:
            ss_shape.append(shape_out[i])
        else:
            new_axis_mask[i] = 0
    # Output dimensions past the end of the mask are not new axes and must be kept. Previously they were
    # dropped, so StridedSlice got a truncated shape when the mask was shorter than the output rank.
    ss_shape.extend(shape_out[len(new_axis_mask):])

    ss_node.out_port(0).data.set_shape(ss_shape)

    # insert Unsqueeze
    unsqueeze_node = Unsqueeze(graph, dict(name=ss_node.name + '/Unsqueeze_new')).create_node()
    ss_node.out_port(0).get_connection().insert_node(unsqueeze_node)
    unsqueeze_node.out_port(0).data.set_shape(shape_out)

    dims_node = Const(graph, {'name': unsqueeze_node.id + '/Indices', 'value': int64_array(dim)}).create_node()
    dims_node.out_port(0).connect(unsqueeze_node.in_port(1))
def infer(node: Node):
    """
    Check the port counts allowed for this RNN-like node, then run ``rnn_infer`` with outputs [1, 2].
    """
    # limitations that come from the ONNX LSTM definition and the normalization rules
    input_count = len(node.in_nodes())
    assert 3 <= input_count <= 7  # at least X, W and R
    assert len(node.out_nodes()) <= 3

    rnn_infer(node, [1, 2])
def infer(node: Node):
    """
    Shape inference for BlockLSTM.

    MO input edges:
        0 - x: the sequence input to the LSTM, shape (timelen, batch_size, num_inputs)
        1 - w: the weight matrix
        2 - b: the bias vector
        3 - h_prev: previous/initial hidden state
        4 - cs_prev: value of the initial cell state

    MO output edges:
        0 - output data / output hidden states concatenated over the whole time sequence
        1 - (optional) output cell states concatenated over the whole time sequence

    Both outputs get a copy of the input 0 shape.
    """
    assert len(node.in_nodes()) == 5
    assert len(node.out_nodes()) in [1, 2]

    mark_input_bins(node)

    sequence_shape = node.in_node(0).shape
    assert len(sequence_shape) == 3

    # NOTE(review): the outputs reuse the input shape, so the last dim is num_inputs rather than the hidden
    # size. Presumably a later transformation fixes this up; confirm.
    result_shape = sequence_shape.copy()
    node.out_port(0).data.set_shape(result_shape)
    if node.is_out_port_connected(1):
        node.out_port(1).data.set_shape(result_shape)
def infer(node: Node):
    """
    Shape inference for CTCGreedyDecoderSeqLen.

    Inputs: logits (rank 3), sequence lengths (rank 1) and an optional blank index (rank 1). Batch (dim 0)
    of logits and sequence lengths must agree. Connected outputs get these shapes:
    port 0 -> [batch, time], port 1 -> [batch].
    """
    node_name = node.soft_get('name', node.id)
    connected_count = len([port for port in node.in_ports().values() if not port.disconnected()])
    assert connected_count in [2, 3], "Incorrect number of inputs for {} node".format(node_name)

    logits_shape = node.in_port(0).data.get_shape()
    seq_len_shape = node.in_port(1).data.get_shape()
    if len(node.in_nodes()) == 3:
        assert len(node.in_port(2).data.get_shape()) == 1, \
            'Incorrect rank of blank_index for {} node'.format(node_name)

    # check shapes of input tensors
    assert len(logits_shape) == 3, 'Incorrect rank of logits for {} node'.format(node_name)
    assert len(seq_len_shape) == 1, 'Incorrect rank of sequence length tensor for {} node'.format(node_name)
    assert compatible_dims(logits_shape[0], seq_len_shape[0]), \
        'Batch dimensions of input tensors must be the same for {} node'.format(node_name)

    batch_size, time_size = logits_shape[0], logits_shape[1]
    if node.is_out_port_connected(0):
        node.out_port(0).data.set_shape([batch_size, time_size])
    if node.is_out_port_connected(1):
        node.out_port(1).data.set_shape([batch_size])
def restore_tensor_names(op: Node):
    """
    Attach the framework tensor names stored in ``op.ports`` to the op's output data nodes.

    ``op.ports`` maps an output port index to a tuple whose second element is the tensor name string
    (or None). The string may hold several comma-separated names; an escaped comma ('\\,') is kept inside
    a name. Each output data node gets 'fw_tensor_debug_info' = [(name, name), ...]. Port indices are shifted
    down by the number of input nodes. A Const op must have exactly one port and always maps to output 0.
    """
    comma_placeholder = '<comma_in_tensor_name>'
    for port_idx, port_info in op.ports.items():
        # op.ports is our internal attribute: a dictionary whose keys are output port numbers and whose
        # values are tuples (shape, tensor name, rt_info)
        tensor_names = port_info[1]

        # Constant operations may use the old-style output port numbering
        if op.soft_get('type') == 'Const':
            assert len(op.ports) == 1, 'Something wrong with Constant node: {}, wrong number ' \
                                       'of output ports: {}!'.format(op.soft_get('name'), len(op.ports))
            port_idx = 0
        port_idx -= len(op.in_nodes())

        if tensor_names is None:
            continue

        if tensor_names.find(',') != -1:
            # split on unescaped commas only, then restore the escaped ones as plain commas
            protected = tensor_names.replace('\\,', comma_placeholder)
            names = [name.replace(comma_placeholder, ',') for name in protected.split(',')]
        else:
            names = [tensor_names]
        op.out_node(port_idx)['fw_tensor_debug_info'] = [(name, name) for name in names]
def replace_op(self, graph: Graph, node: Node):
    """
    Replace the node with MatMul(input 0, input 1, transpose_b=True), followed by an Add of input 2 (bias)
    when the node has more than two inputs.

    :return: a one-element list with the id of the node that now produces the result
    """
    inputs = [node.in_node(0), node.in_node(1)]
    result = MatMul(graph, dict(name=node.name, transpose_b=True)).create_node(inputs)

    has_bias = len(node.in_nodes()) > 2
    if has_bias:
        result = Add(graph, dict(name=node.name + '/bias')).create_node([result, node.in_node(2)])
    return [result.id]
def infer(node: Node):
    """
    Shape and (for constant inputs) value inference for SparseSegmentSqrtN.

    Input 0 is the data tensor, input 1 the 1D indices into its first dimension and input 2 the 1D segment
    IDs, which must be sorted. For each segment id ``s``, the output row is the sum of ``data[indices[i]]``
    over all ``i`` with ``segment_ids[i] == s``, divided by the square root of the number of such ``i``.
    Rows of empty segments stay zero.

    Before the values are known, the first output dimension is set to the length of the segment IDs. When all
    inputs are constant, it becomes ``last segment id + 1`` (0 for empty segment IDs) and the value is a
    float32 array.
    """
    # check a number of input/output edges
    assert len(node.in_nodes()) == 3
    assert len(node.out_nodes()) == 1

    data_shape = node.in_port(0).data.get_shape()
    indices_shape = node.in_port(1).data.get_shape()
    segment_ids_shape = node.in_port(2).data.get_shape()

    data_value = node.in_port(0).data.get_value()
    indices_value = node.in_port(1).data.get_value()
    segment_ids_value = node.in_port(2).data.get_value()

    # check input shapes
    assert data_shape is not None, \
        "Shape for input data tensor to SparseSegmentSqrtN must be defined"
    assert indices_shape is not None and indices_shape.size == 1, \
        "SparseSegmentSqrtN supports only 1D indices tensor"
    assert segment_ids_shape is not None and segment_ids_shape.size == 1, \
        "SparseSegmentSqrtN supports only 1D segment IDs tensor"
    assert compatible_shapes(segment_ids_shape, indices_shape), \
        "Indices and segment IDs tensors must have compatible shapes"

    # compute the output shape on a copy: assigning into data_shape directly modified the input's shape
    output_shape = data_shape.copy()
    output_shape[0] = segment_ids_shape[0]
    node.out_port(0).data.set_shape(output_shape)

    # infer if all input is constant
    if data_value is None or indices_value is None or segment_ids_value is None:
        return

    # check that values in segment_ids are sorted
    for i in range(1, len(segment_ids_value)):
        assert segment_ids_value[i - 1] <= segment_ids_value[i], \
            "Values in segment IDs are not sorted"
    # guard the empty case, where segment_ids_value[-1] would raise IndexError
    num_segments = int(segment_ids_value[-1]) + 1 if len(segment_ids_value) > 0 else 0

    # check that indices are in a range [0, data_shape[0])
    assert np.all(indices_value >= 0) and np.all(indices_value < data_shape[0]), \
        "Some value in indices tensor is out of range"

    # np.int was removed in NumPy 1.24, so use an explicit integer width
    num_adds = np.zeros(num_segments, dtype=np.int64)
    output_value = np.zeros([num_segments] + data_shape[1:].tolist(), dtype=np.float32)
    output_shape = output_value.shape
    for i in range(len(segment_ids_value)):
        segment_id = int(segment_ids_value[i])
        index = int(indices_value[i])
        # plain row indexing also works for 1D data, where [row, :] raised IndexError
        output_value[segment_id] += data_value[index]
        num_adds[segment_id] += 1

    num_adds = np.sqrt(num_adds)
    for segment_id in range(num_segments):
        if num_adds[segment_id] != 0:
            output_value[segment_id] /= num_adds[segment_id]
    node.out_port(0).data.set_shape(output_shape)
    node.out_port(0).data.set_value(output_value)
def get_value_id(node: 'Node'):
    """
    Return the input port of ``node`` whose producer has a valid 'value' attribute.

    :param node: an op node (must have a valid 'op' attribute)
    :return: the port number when exactly one input has a valid value, otherwise None
    """
    assert node.has_valid('op')
    value_id = None
    for port, in_node in node.in_nodes().items():
        if in_node.has_valid('value'):
            # Compare with None explicitly. Port 0 is falsy, so `if value_id:` missed a second match
            # after port 0.
            if value_id is not None:
                return None
            value_id = port
    return value_id
def get_tensor_id(node: 'Node'):
    """
    Return the input port of ``node`` whose producer has no valid 'value' attribute (a non-constant input).

    :param node: an op node (must have a valid 'op' attribute)
    :return: the port number when exactly one input lacks a valid value, otherwise None
    """
    assert node.has_valid('op')
    tensor_id = None
    for port, in_node in node.in_nodes().items():
        if not in_node.has_valid('value'):
            # Compare with None explicitly. Port 0 is falsy, so `if tensor_id:` missed a second match
            # after port 0.
            if tensor_id is not None:
                return None
            tensor_id = port
    return tensor_id
def control_flow_infer(node: Node, is_executable: bool, mark_executability: callable):
    """
    Propagate executability from control-flow inputs to control-flow outputs.

    Each control-flow output data node is passed to ``mark_executability(id, flag)``. The flag is True when at
    least one control-flow input has 'executable' set, and False otherwise (including when there are no
    inputs). The ``is_executable`` argument is ignored because the flag is recomputed here.
    """
    inputs = node.in_nodes(control_flow=True)
    outputs = node.out_nodes(control_flow=True)

    executable = any(data.has_and_set('executable') for data in inputs.values())
    for data in outputs.values():
        mark_executability(data.id, executable)
def get_fw_tensor_debug_info(node: Node):
    """
    Walk up the chain of first-input producers until a node with 'fw_tensor_debug_info' or
    'output_sort_order' is found, and return that attribute ('output_sort_order' wins if both are present).

    The walk also stops at a node without inputs; then that node's ``soft_get('fw_tensor_debug_info')`` is
    returned. If stepping to the producer fails, a warning is logged and "dummy_node_name" is returned.
    """
    while not node.has_valid('fw_tensor_debug_info') and not node.has_valid('output_sort_order') \
            and len(node.in_nodes()):
        try:
            node = node.in_node()
        except Exception:
            # Best effort. Use soft_get so the warning itself cannot fail on a node without a 'name'
            # attribute; node.name raised in that case.
            log.warning('Was not able to determine tensor debug info for node {}'.format(
                node.soft_get('name', node.id)))
            return "dummy_node_name"
    if node.has_valid('output_sort_order'):
        return node.soft_get('output_sort_order')
    return node.soft_get('fw_tensor_debug_info')
def add_squeeze_for_shrink(graph: Graph, ss_node: Node):
    """
    Move the shrink-axis part of a StridedSlice into a separate Squeeze node.

    The StridedSlice output shape is rewritten so that every shrunk axis stays as a size-1 dimension, and all
    bits of 'shrink_axis_mask' are cleared. A Squeeze with constant axes (the positions flagged in the mask) is
    inserted after the StridedSlice and restores the original output shape. Nothing is done unless the node
    has exactly 4 inputs and 1 output.
    """
    # add Squeeze for shrink_axis_mask
    log.info("StridedSlice op with shrink mask '{}' has been detected".format(ss_node.id))

    if len(ss_node.in_nodes()) != 4 or len(ss_node.out_nodes()) != 1:
        return

    shape_out = ss_node.out_node().shape
    shrink_mask = ss_node['shrink_axis_mask']
    dim = mo_array(range(len(shrink_mask)))[mo_array(shrink_mask, dtype=bool)]

    # Don't permute reshape if channels were squeezed (checked before the mask is cleared below)
    layout = graph.graph['layout']
    dont_permute = layout == 'NCHW'
    if layout == 'NHWC' and shrink_mask[-1] == 1:
        dont_permute = True

    ss_shape = []
    mask_pos = 0
    out_pos = 0
    while out_pos < len(shape_out):
        if mask_pos < len(shrink_mask) and shrink_mask[mask_pos]:
            # shrunk axis: keep it as a unit dimension, the Squeeze removes it afterwards
            shrink_mask[mask_pos] = 0
            ss_shape.append(1)
        else:
            ss_shape.append(shape_out[out_pos])
            out_pos += 1
        mask_pos += 1
    # mask positions left after all output dimensions are consumed become unit dimensions too
    for pos in range(mask_pos, len(shrink_mask)):
        shrink_mask[pos] = 0
        ss_shape.append(1)

    ss_node.out_port(0).data.set_shape(ss_shape)

    # insert Squeeze
    squeeze_attrs = dict(name=ss_node.name + '/Squeeze_shrink',
                         nchw_layout=dont_permute,
                         correct_data_layout=dont_permute)
    squeeze_node = Squeeze(graph, squeeze_attrs).create_node()
    ss_node.out_port(0).get_connection().insert_node(squeeze_node)
    squeeze_node.out_port(0).data.set_shape(shape_out)

    dims_node = Const(graph, {'name': squeeze_node.id + '/Indices', 'value': int64_array(dim)}).create_node()
    dims_node.out_port(0).connect(squeeze_node.in_port(1))
def infer(node: Node):
    """
    Inference for a fill node whose single input holds the output shape.

    Requires 'fill_value' and 'input_as_shape'. When the shape value is fully defined, the output value is a
    float32 array of that shape filled with 'fill_value'. Otherwise only the output shape is set.
    """
    assert len(node.in_nodes()) == 1
    assert node.fill_value is not None
    assert node.input_as_shape

    target_shape = node.in_port(0).data.get_value()
    assert target_shape is not None

    if not is_fully_defined(target_shape):
        node.out_port(0).data.set_shape(target_shape)
        return
    node.out_port(0).data.set_value(np.full(target_shape, node.fill_value, np.float32))
def array_infer(node: Node):
    """
    Give every consumer of the node the shape stored in the TensorArray's 'element_shape' and no value.

    The value of input 0, converted to str, is used as the id of the TensorArray node.
    """
    assert len(node.in_nodes()) == 3

    handle = node.in_node(0)
    tensor_array = Node(node.graph, str(handle.value))
    assert tensor_array.has_valid('element_shape')

    for _, consumer_id in node.graph.out_edges(node.id):
        consumer_attrs = node.graph.node[consumer_id]
        consumer_attrs['shape'] = shape_array(tensor_array['element_shape'])
        consumer_attrs['value'] = None
def infer(node: Node):
    """
    Shape (and, for a constant data input, value) inference for Squeeze.

    Input 1 holds the axes to squeeze. Each listed dimension must be 1 or dynamic. An empty axes list
    removes every dimension equal to 1. The axes are canonicalized with ``get_canonical_axis_index`` and
    written back into input 1 when it comes from a Const.

    :raises Error: if the input shape is undefined, the axes are not constant, or a listed dimension is not 1
    """
    real_squeeze_dims = int64_array([])
    input_shape = node.in_port(0).data.get_shape()
    node_name = node.soft_get('name', node.id)
    if input_shape is None:
        raise Error('Input shape is not defined for node {}'.format(node_name))

    output_shape = input_shape.copy()
    assert len(node.in_nodes()) == 2, 'The Squeeze node {} must have 2 inputs'.format(node_name)

    squeeze_dims = node.in_port(1).data.get_value()
    if squeeze_dims is None:
        # report clearly instead of failing with AttributeError on `.ndim` below
        raise Error('The squeeze dimensions are not constant for node "{}"'.format(node_name))

    # TODO remove the following 'if' statement when IE start support 0D tensors
    if squeeze_dims.ndim == 0:
        squeeze_dims = squeeze_dims.reshape([1])

    for dim in squeeze_dims:
        if output_shape[dim] == 1 or output_shape[dim] is dynamic_dimension:
            real_squeeze_dims = np.ma.append(real_squeeze_dims, get_canonical_axis_index(output_shape, dim))
        else:
            raise Error('Trying to squeeze dimension not equal to 1 for node "{}"'.format(node_name))

    # if squeeze_dims empty then all 1s should be removed (tf specification of Squeeze op)
    if squeeze_dims.size == 0:
        for i in range(output_shape.size):
            if output_shape[i] == 1:
                real_squeeze_dims = np.ma.append(real_squeeze_dims, get_canonical_axis_index(output_shape, i))

    assert is_fully_defined(real_squeeze_dims), \
        'Squeeze dimension(s) is not defined for op "{}"'.format(node_name)
    output_shape = shape_delete(output_shape, real_squeeze_dims)
    node.out_port(0).data.set_shape(output_shape)

    # make dimensions positive to correctly translate from NHWC to NCHW layout
    if node.in_port(1).get_source().node.op == 'Const':
        node.in_port(1).data.set_value(real_squeeze_dims)

    if node.in_port(0).data.get_value() is not None:
        node.out_port(0).data.set_value(node.in_port(0).data.get_value().reshape(output_shape))

    # the squeeze_dim attribute will be converted to the second input in the end of the Middle phase
    PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis')
def find_input_port(node: Node, input_desc: list, search_node_name: str, search_node_port: int):
    """
    Return the index of the first input whose description matches the given producer name and port.

    :param node: node whose inputs are described
    :param input_desc: per input port, an iterable of (regex pattern, port) pairs; when None, the current
        number of input nodes of ``node`` is returned
    :param search_node_name: producer node name, matched with ``re.findall`` against each pattern
    :param search_node_port: producer port that must equal the described port
    :raises Exception: when no description matches
    """
    if input_desc is None:
        return len(node.in_nodes())

    for port_idx, candidates in enumerate(input_desc):
        if any(findall(pattern, search_node_name) and port == search_node_port for pattern, port in candidates):
            return port_idx

    raise Exception('Did not find input port of the node "{}" with port "{}"'.format(search_node_name,
                                                                                     search_node_port))
def array_infer(node: Node):
    """
    Give every consumer of the node the TensorArray's 'size' attribute as a constant value, with its shape.

    The value of input 0, converted to str, is used as the id of the TensorArray node.
    """
    assert len(node.in_nodes()) == 2

    handle = node.in_node(0)
    tensor_array = Node(node.graph, str(handle.value))
    assert tensor_array.has_valid('size')

    size_value = mo_array(tensor_array['size'])
    for _, consumer_id in node.graph.out_edges(node.id):
        consumer_attrs = node.graph.node[consumer_id]
        consumer_attrs['shape'] = shape_array(size_value.shape)
        consumer_attrs['value'] = size_value.copy()
def convert_const_node_value_type(const_node: Node, np_data_type):
    """
    Cast the value of a Const node to ``np_data_type`` in place and re-run its infer and type_infer.

    If the Const node has exactly one input (which must be a data node), that node's value and data_type are
    converted to the Const's resulting data_type as well.
    """
    assert const_node.type == 'Const'
    log.warning('Converting type of Const node "{}" to "{}"'.format(const_node.name, np_data_type))
    const_node.value = const_node.value.astype(np_data_type)
    const_node.data_type = np_data_type
    const_node.infer(const_node)
    const_node.type_infer(const_node)

    if len(const_node.in_nodes()) != 1:
        return
    # keep the input data node consistent with the converted Const
    source_data = const_node.in_node(0)
    assert source_data.kind == 'data'
    target_type = const_node.data_type
    source_data.value = source_data.value.astype(target_type)
    source_data.data_type = target_type
def infer(node: Node):
    """
    Shape inference for MaxPoolV2, whose window size and strides come from inputs instead of attributes.

    Stores the values of input 1 and input 2 into the 'window' and 'stride' attributes, then delegates to
    ``Pooling.pool_infer``.

    :raises Error: if the window size or the strides are not constant
    """
    node_name = node.soft_get('name', node.id)
    input_count = len(node.in_nodes())
    assert input_count == 3, 'MaxPoolV2 node {} from must have only 3 inputs: input, window size, and ' \
                             'strides but instead got {} inputs'.format(node_name, input_count)

    window = node.in_port(1).data.get_value()
    stride = node.in_port(2).data.get_value()
    node['window'] = window
    node['stride'] = stride

    if window is None:
        raise Error('The non-constant window size for MaxPoolV2 node {} is not supported'.format(node_name))
    if stride is None:
        raise Error('The non-constant strides for MaxPoolV2 node {} is not supported'.format(node_name))

    Pooling.pool_infer(node)