def transpose_infer(node):
    """Infer output shape (and constant value) of a Transpose node.

    The permutation is taken from ``order``; when ``reverse_order`` is set and
    truthy, it is generated as the reversed axis list of the input instead.
    Missing order, or both ``order`` and ``reverse_order`` set, is logged as an
    error and inference is abandoned.
    """
    has_reverse_flag = node.has_valid('reverse_order')

    # Nothing to permute with: no explicit order and no request to reverse.
    if node.order is None and (not has_reverse_flag or node.reverse_order == False):  # noqa: E712
        log.error('Cannot infer {} because order is None'.format(node.soft_get('name')))
        return

    # The two ways of specifying the permutation are mutually exclusive.
    if has_reverse_flag and node.reverse_order and node.has_valid('order'):
        log.error('Cannot infer {} due to both order and reverse_order was set'.format(node.soft_get('name')))
        return

    input_shape = node.in_node(0).shape
    if has_reverse_flag and node.reverse_order:
        node.order = np.arange(len(input_shape))[::-1]  # reverse every axis

    node.out_node(0).shape = np.array([input_shape[axis] for axis in node.order], dtype=np.int64)

    # Constant-fold when the input value is known.
    if node.in_node().has_valid('value'):
        node.out_node().value = np.transpose(node.in_node().value, axes=node.order)

    PermuteAttrs.create_permute_attrs(node, attrs=[('order', 'input:0')])
def infer(node):
    """Shape/value inference for AttributedTile.

    The single input on port 0 is repeated ``tiles`` times along ``axis``;
    all other dimensions are kept. The value is propagated with ``np.tile``
    when the input is constant.
    """
    node_name = node.soft_get('name', node.id)
    connected = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
    assert len(connected) == 1 and 0 in connected, \
        "AttributedTile should have 1 connected input port, but it doesn't for node: `{}`. Ports: {}" \
        "".format(node_name, connected)

    input_shape = node.in_port(0).data.get_shape()
    assert input_shape is not None, "Undefined input shape for AttributedTile node '{}'.".format(node_name)
    assert node.soft_get('axis', None) is not None
    assert node.soft_get('tiles', None) is not None, \
        "Undefined `tiles` attribute of Tile node '{}'".format(node_name)

    # One repeat count per dimension: 1 everywhere except the tiled axis.
    repeats = int64_array(np.ones(input_shape.size))
    repeats[node.axis] = node.tiles
    node.out_port(0).data.set_shape(input_shape * repeats)

    input_value = node.in_port(0).data.get_value()
    if input_value is not None:
        node.out_port(0).data.set_value(np.tile(input_value, repeats))

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
def test_permute_begin_end_ellipsis(self):
    # Constant-path case: begin/end masks are first extended around the ellipsis to rank 4,
    # then permuted with the NHWC -> NCHW permutation, and the result is checked.
    edges = [('input', 'data_1'), ('data_1', 'strided_slice')]
    for operand in ('begin', 'end', 'stride'):
        edges.append((operand, operand + '_data'))
        edges.append((operand + '_data', 'strided_slice'))
    edges.append(('strided_slice', 'data_2'))

    graph = build_graph(nodes_attributes, edges,
                        {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
                         'begin': {'value': [0, 1], 'shape': [2]},
                         'end': {'value': [1, 0], 'shape': [2]},
                         'stride': {'value': [1, 2], 'shape': [2]},
                         'strided_slice': {'begin_mask': np.array([0, 0]), 'end_mask': np.array([1, 0]),
                                           'new_axis_mask': np.array([0]), 'shrink_axis_mask': [0],
                                           'ellipsis_mask': np.array([1, 0])},
                         'data_2': {'shape': np.array([1, 2, 3, 4]), 'value': None},
                         })
    slice_node = Node(graph, 'strided_slice')

    cases = (('begin_mask', [0, 0, 0, 0]), ('end_mask', [1, 0, 0, 0]))
    for mask_name, expected in cases:
        extended = extend_mask_according_ellipsis(slice_node['ellipsis_mask'], slice_node['shrink_axis_mask'], 4,
                                                  list(slice_node[mask_name]), 0)
        slice_node[mask_name] = int64_array(extended)
        permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]), mask_name)
        self.assertTrue(np.array_equal(getattr(slice_node, mask_name), np.array(expected)))
def find_and_replace_pattern(self, graph: Graph):
    """Attach an NHWC -> NCHW permutation to all edges of eligible data nodes.

    A data node is skipped when it is flagged ``nchw_layout`` or when any of its
    incoming/outgoing edges already carries a ``permutation`` attribute.
    """
    for data_node in graph.get_data_nodes():
        if data_node.has_and_set('nchw_layout'):
            continue

        # Permutation sized for this data node's rank.
        permutation = PermuteAttrs().get_nhwc_to_nchw_permutation(len(data_node.shape))

        edge_data = data_node.graph.get_edge_data
        permuted_inbound = any('permutation' in edge_data(producer.id, data_node.id)[0]
                               for producer in data_node.in_nodes())
        permuted_outbound = any('permutation' in edge_data(data_node.id, consumer.id)[0]
                                for consumer in data_node.out_nodes())
        if permuted_inbound or permuted_outbound:
            continue

        for producer in data_node.in_nodes():
            PermuteAttrs.set_permutation(producer, data_node, permutation)
        for consumer in data_node.out_nodes():
            PermuteAttrs.set_permutation(data_node, consumer, permutation)
def _one_input_infer(node: Node):
    """Infer the output shape of a single-input Crop node.

    ``axis`` selects the cropped dimensions. The new sizes come either from
    ``dim`` (explicit sizes) or from ``crop_begin``/``crop_end`` (amounts removed
    at each end of each axis). Any inconsistency is logged and inference is
    abandoned without setting an output shape.
    """
    raw_input_shape = node.in_node().shape
    # Check before wrapping: np.array(None) is a 0-d object array (never None),
    # so testing after the conversion could not detect a missing shape.
    if raw_input_shape is None:
        log.error('input_shape is none for {} node'.format(node.name))
        return
    input_shape = np.array(raw_input_shape)

    if not node.has_valid('axis'):
        log.error('axis attribute is missing for {} node. should be set in crop extractor'.format(node.name))
        return

    # np.array above already made a private copy, so editing it leaves the input data node untouched.
    output_shape = input_shape
    if node.has_valid('dim'):
        if len(node.dim) != len(node.axis):
            log.error('number of axis should match number of dim')
            return
        output_shape[node.axis] = node.dim
    elif node.has_valid('crop_begin') and node.has_valid('crop_end'):
        if len(node.crop_begin) != len(node.axis) or len(node.crop_end) != len(node.axis):
            log.error('number of crop_begin/crop_end should match number of axis')
            return
        if isinstance(node.axis, (list, tuple)):
            for axis, begin, end in zip(node.axis, node.crop_begin, node.crop_end):
                output_shape[axis] = output_shape[axis] - begin - end
        else:
            # array-like axis: crop all selected dimensions in one vectorized step
            output_shape[node.axis] = output_shape[node.axis] - node.crop_begin - node.crop_end
    else:
        log.error('Crop node {} should have either dim or crop_begin and crop_end attributes'.format(node.name))
        return

    node.out_node().shape = np.array(output_shape)
    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
def caffe_inner_product(node):
    """Shape inference for a Caffe InnerProduct node.

    The output is ``[batch, out-size]``, where the input is flattened from
    axis 1. When ``out-size`` is not set, it is derived from the weights' element
    count. The weights (input 1) are reshaped to ``[out-size, input_channels]``,
    and they are transposed first when stored as ``[input_channels, out-size]``
    and ``transpose_weights`` is True.
    """
    input_shape = node.in_node(0).shape
    if input_shape is None:
        return

    weights = node.in_node(1)
    batch_size = input_shape[0]
    in_features = np.prod(input_shape[1:])

    if not node.has_valid('out-size'):
        node['out-size'] = (np.prod(weights.shape) / in_features).astype(np.int64)
    out_features = node['out-size']

    expected_weights_shape = np.array([out_features, in_features], dtype=np.int64)

    # Weights laid out as [in, out] are transposed to [out, in] when requested.
    stored_as_io = np.array_equal(weights.shape, expected_weights_shape[::-1])
    if stored_as_io and node.soft_get('transpose_weights') is True:
        weights.value = np.transpose(weights.value)

    node.out_node().shape = np.array([batch_size, out_features], dtype=np.int64)

    # Push the canonical 2-D shape back onto the weights and reshape their value in place.
    weights.shape = np.array(expected_weights_shape)
    weights.value.shape = weights.shape

    mark_input_bins(node)
    assign_dims_to_weights(weights, None, 1, 0, 2)
    PermuteAttrs.set_permutation(weights, node, None)
def infer(node):
    """Shape inference for TopK.

    Port 0 is the data, port 1 holds the constant ``k``. Both outputs (values
    and indices) get the input shape with dimension ``axis`` replaced by ``k``.
    A negative ``axis`` is normalized in place.
    """
    connected_count = sum(1 for port in node.in_ports().values() if not port.disconnected())
    assert connected_count == 2, 'The number of inputs to the TopK layer name "{}" must be equal to 2.' \
                                 ''.format(node.soft_get('name'))

    k = node.in_port(1).data.get_value()
    if k is None:
        raise Error('The value defining number of output elements for layer "{}" is not defined'
                    ''.format(node.soft_get('name')))

    assert node.has_valid('axis'), 'The "axis" attribute is not defined for node {}'.format(node.name)

    input_shape = node.in_port(0).data.get_shape()
    if node.axis < 0:
        node.axis = len(input_shape) + node.axis

    output_shape = input_shape.copy()
    output_shape[node.axis] = k

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])

    # Values and indices outputs share the same shape.
    for out_idx in (0, 1):
        if not node.out_port(out_idx).disconnected():
            node.out_port(out_idx).data.set_shape(output_shape)

    # TODO implement value propagation for a constant input on port 0
def argmax_infer(node: Node):
    """Shape inference for ArgMax.

    With an ``axis`` attribute, the output keeps the input shape except that
    ``axis`` becomes ``top_k``. Without one, the output has at least 3 dims:
    ``[N, 1 or 2 (out_max_val), top_k, 1, ...]``.
    """
    shape = node.in_node(0).shape
    if shape is None:
        return

    # TensorFlow passes the axis as a second input: fold it into the attribute and drop the edge.
    if len(node.in_nodes()) == 2:
        axis_value = node.in_node(1).value
        if axis_value is None:
            log.debug('The second argument to ArgMax is None')
            return
        node.axis = axis_value.item()
        node.graph.remove_edge(node.in_node(1).id, node.id)

    output_rank = max(shape.size, 3)

    if node.has_valid('axis'):
        axis = get_canonical_axis_index(shape, node.axis)
        node.axis = axis
        out_shape = np.array(shape)
        out_shape[axis] = node.top_k
        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
    else:
        out_shape = np.ones(output_rank, dtype=int)
        out_shape[0] = shape[0]
        out_shape[2] = node.top_k
        if node.out_max_val:
            out_shape[1] = 2

    node.out_node().shape = out_shape
def tf_split_infer(node):
    """
    Partial infer of split node similar to Split op of TF.

    Input 0 is the scalar split dimension and input 1 is the tensor to split
    into ``node.num_split`` equal parts. After inference, the split-dimension
    edge is removed and ``input_port`` is set to 1.
    """
    # Two inputs: [split_dim, input]
    assert len(node.in_nodes()) == 2, 'Node "{}" must have exactly two inputs'.format(node.soft_get('name'))
    split_dim = node.in_node(0).value
    if split_dim is None:
        # The placeholder was previously never filled because .format() was missing.
        log.error('split_dim value for node {} is None. Cannot do shape inference.'.format(node.soft_get('name')))
        return

    assert split_dim.ndim == 0, 'The split dimension for node "{}" must be a scalar.'.format(node.soft_get('name'))
    split_dim = split_dim.item()
    input_data = node.in_node(1)

    if input_data.shape is None:
        log.error('Input shape for node {} is not defined'.format(node.soft_get('name')))
        return

    log.debug('input shape for split: {}, should be split along {} dim'.format(input_data.shape, split_dim))
    split_dim_size = input_data.shape[split_dim]
    log.debug('split_dim_size type = {}'.format(type(split_dim_size)))

    if split_dim_size % node.num_split != 0:
        log.error("split_dim cannot be evenly divided by a given number of parts")
        return

    # split_dim is a plain Python scalar here (converted via .item() above)
    log.debug('split_dim_size = {}, node.num_split = {}, div = {}, typeof div = {}'.format(
        split_dim_size, node.num_split, split_dim_size / node.num_split, type(split_dim_size / node.num_split)))
    # Exact integer division: divisibility was checked above.
    part_size = int(split_dim_size // node.num_split)
    split(input_data, node, split_dim, [part_size] * node.num_split)
    node.graph.remove_edge(node.in_node(0).id, node.id)
    node['input_port'] = 1

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:1')])
def conv_flatten_concat_action(graph: Graph, match: dict):
    """Insert a Permute (NCHW -> NHWC order) between a convolution output and a rank-changing Reshape.

    Skipped when the Reshape keeps the rank, is already flagged ``nchw_layout``,
    or feeds a single FullyConnected whose weights can be repacked instead.
    """
    assert graph.graph['layout'] == 'NHWC'
    reshape_node = match['reshape']
    reshape_data_node = match['reshape_data']
    conv_name = match['conv'].name
    conv_data_node = match['conv_data']

    # Only reshapes that change the number of dimensions need the fix-up.
    rank_unchanged = len(reshape_data_node.shape) == len(conv_data_node.shape)
    if rank_unchanged or reshape_node.has_and_set('nchw_layout'):
        return

    if len(reshape_data_node.out_nodes()) == 1:
        consumer = reshape_data_node.out_node()
        if consumer.has_valid('type') and consumer.type == 'FullyConnected' and \
                can_repack_fully_connected_weights_nhwc_to_nchw(consumer):
            log.info('There is a FullyConnected layer after the node "{}" which weights will be repacked. So there is no '
                     'need to insert Permute'.format(reshape_node.soft_get('name')))
            return

    graph.remove_edge(conv_data_node.id, reshape_node.id)

    nchw_to_nhwc_order = PermuteAttrs.get_nchw_to_nhwc_permutation(len(conv_data_node.shape)).perm
    permuted_data = Permute(graph, {'order': nchw_to_nhwc_order}).create_node_with_data(
        [conv_data_node], dict(name=conv_name + '/Permute_'))
    graph.create_edge(permuted_data, reshape_node)

    # Disable attribute permutation for the Reshape and keep it in its original layout.
    PermuteAttrs.set_permutation(reshape_node, reshape_data_node, None)
    reshape_node['nchw_layout'] = True
def tf_squeeze_infer(node):
    """Shape/value inference for TF Squeeze with explicit ``squeeze_dims``.

    Each listed dimension must have size 1 (otherwise Error is raised). The
    output shape is the input shape with those dims removed. ``dim`` is set to
    that shape, or to ``[0, -1]`` when ``is_spatial_squeeze`` holds for the
    graph layout.
    """
    if node.squeeze_dims is None:
        # TODO: implement; there is no implementation now because no test
        return

    input_shape = node.in_node().shape
    if input_shape is None:
        return

    # NOTE: legacy logic; the canonical index is computed against a copy of the input shape.
    output_shape = input_shape.copy()
    dims_to_drop = []
    for squeeze_axis in node.squeeze_dims:
        if output_shape[squeeze_axis] != 1:
            raise Error('Trying to squeeze dimension not equal to 1 for node "{}"'.format(node.soft_get('name')))
        dims_to_drop.append(get_canonical_axis_index(output_shape, squeeze_axis))

    output_shape = np.delete(output_shape, dims_to_drop)
    node.out_node().shape = output_shape

    if is_spatial_squeeze(node.graph.graph['layout'], input_shape, output_shape):
        output_shape = int64_array([0, -1])
    node['dim'] = output_shape

    input_value = node.in_node().value
    if input_value is not None:
        node.out_node().value = np.array(np.reshape(input_value, output_shape))

    PermuteAttrs.create_permute_attrs(node, attrs=[('dim', 'output:0')])
def tf_expand_dims_infer(node):
    """Shape/value inference for ExpandDims.

    The axis comes from a second input (TensorFlow style; its edge is then
    removed) or from the ``expand_axis`` attribute. A size-1 dimension is
    inserted at that axis, and ``dim`` is set to the resulting int64 shape.
    """
    input_node = node.in_nodes()[0]
    output_node = node.out_node()
    if input_node.shape is None:
        return

    if len(node.in_nodes()) > 1:
        # TensorFlow style with dynamic input
        axis_node = node.in_nodes()[1]
        axis_value = axis_node.value
        if isinstance(axis_value, np.ndarray) and axis_value.size > 1:
            log.error("ExpandDims operation : axis should be scalar")
            return
        expand_axis = axis_value.item()
        node.graph.remove_edge(axis_node.id, node.id)
    elif not node.has_valid('expand_axis'):
        log.error("ExpandDims axis is not defined")
        return
    else:
        expand_axis = node.expand_axis

    if expand_axis is None:
        return

    # np.insert may yield a non-int64 dtype, so force the shape to int64 explicitly
    expanded_shape = np.insert(input_node.shape, expand_axis, [1]).astype(np.int64)
    output_node.shape = expanded_shape

    if input_node.value is not None:
        output_node.value = np.array(np.reshape(input_node.value, expanded_shape))

    node['dim'] = expanded_shape
    PermuteAttrs.create_permute_attrs(node, attrs=[('dim', 'output:0')])
def apply_nhwc_to_nchw_permutation(graph: Graph):
    """Attach NHWC -> NCHW permutations to edges of data nodes that have none yet.

    Does nothing for NCHW graphs. A data node is skipped when flagged
    ``nchw_layout`` or when any of its in/out edges already has a
    ``permutation`` attribute.
    """
    if graph.graph['layout'] == 'NCHW':
        return

    for data_node in graph.get_data_nodes():
        if data_node.has_and_set('nchw_layout'):
            continue

        # Permutation for N dims, where N = len(data_node.shape)
        permutation = PermuteAttrs().get_nhwc_to_nchw_permutation(len(data_node.shape))

        # All edges touching the node: incoming first, then outgoing.
        edges = [(src, data_node) for src in data_node.in_nodes()] + \
                [(data_node, dst) for dst in data_node.out_nodes()]

        if any('permutation' in data_node.graph.get_edge_data(src.id, dst.id)[0] for src, dst in edges):
            continue

        for src, dst in edges:
            PermuteAttrs.set_permutation(src, dst, permutation)
def infer(node: Node):
    """Shape/value inference for AttributedGather.

    Gathers ``indices`` (port 1) from ``data`` (port 0) along ``axis``. A negative
    axis is normalized in place. The output shape is
    ``data[:axis] + indices.shape + data[axis+1:]``; when both inputs are
    constant, the value is computed with ``np.take``.
    """
    node_name = node.soft_get('name', node.id)
    connected = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
    assert len(connected) == 2 and 0 in connected and 1 in connected, \
        "AttributedGather should have 2 connected input port, but it doesn't for node: `{}`. Ports: {}" \
        "".format(node_name, connected)

    requested_axis = node.soft_get('axis', None)
    assert requested_axis is not None

    data_shape = node.in_port(0).data.get_shape()
    assert data_shape is not None
    indices_shape = node.in_port(1).data.get_shape()
    assert indices_shape is not None

    axis = get_canonical_axis_index(data_shape, requested_axis)
    node.axis = axis

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])

    data_value = node.in_port(0).data.get_value()
    indices_value = node.in_port(1).data.get_value()
    if data_value is not None and indices_value is not None:
        gathered = np.take(data_value, indices_value, axis)
        node.out_port(0).data.set_value(np.array(gathered, dtype=data_value.dtype))
    else:
        out_shape = np.concatenate((data_shape[:axis], indices_shape))
        if axis < len(data_shape) - 1:
            out_shape = np.concatenate((out_shape, data_shape[axis + 1:]))
        node.out_port(0).data.set_shape(int64_array(out_shape))
def infer(node: Node):
    """StridedSlice inference followed by NHWC-specific permutation bookkeeping.

    In an NHWC graph whose output is not constant, mask permutations are
    registered, constant inputs 1..N with more than 3 elements are permuted
    NHWC -> NCHW (with ellipsis handling), and the ellipsis mask is cleared.
    """
    tf_strided_slice_infer(node)

    if node.graph.graph['layout'] != 'NHWC' or node.out_port(0).data.get_value() is not None:
        return

    PermuteAttrs.create_permute_attrs(node, attrs=[
        ('shrink_axis_mask', 'input:0', permute_masks),
        ('new_axis_mask', 'input:0', permute_masks),
        ('ellipsis_mask', 'input:0', permute_masks),
        ('begin_mask', 'input:0', permute_masks),
        ('end_mask', 'input:0', permute_masks),
    ])

    for port_idx in range(1, len(node.in_nodes())):
        port_data = node.in_node(port_idx)
        if port_data.value is None or port_data.shape[0] <= 3:
            continue
        nhwc_to_nchw = PermuteAttrs.get_nhwc_to_nchw_permutation(len(node.in_node(0).shape))
        port_data.value = permute_array_with_ellipsis(node, nhwc_to_nchw, port_data.value, 0)

    # due to permutation from nhwc to nchw we will extend all masks and inputs
    set_positions = np.nonzero(node.ellipsis_mask)
    node.ellipsis_mask[set_positions] = 0
def infer(node: Node):
    """Shape inference for a 4-D resize whose target (height, width) comes from input 1.

    Batch and feature dimensions are copied from the input shape according to
    the graph layout; spatial sizes are taken from the constant on port 1.
    """
    layout = node.graph.graph['layout']
    assert len(layout) == 4
    # NOTE(review): this only checks that at least one input is connected, although port 1 is read below.
    assert any(not port.disconnected() for port in node.in_ports().values())
    assert node.has_valid('mode')
    assert node.has_valid('axes')

    src_shape = node.in_port(0).data.get_shape()
    assert src_shape is not None

    target_hw = node.in_port(1).data.get_value()
    assert target_hw is not None

    node.out_node().shape = shape_for_layout(layout,
                                             batch=src_shape[get_batch_dim(layout, 4)],
                                             features=src_shape[get_features_dim(layout, 4)],
                                             height=target_hw[0],
                                             width=target_hw[1])

    PermuteAttrs.create_permute_attrs(node, attrs=[('axes', 'input:0')])
def find_and_replace_pattern(self, graph: Graph):
    """Normalize every StridedSlice, then register mask and input permutations for it."""
    # begin_mask's permutation indeed depends on slice_rank, but input:0 is used for all masks
    mask_names = ('begin_mask', 'end_mask', 'new_axis_mask', 'shrink_axis_mask', 'ellipsis_mask')

    for node in graph.get_op_nodes(type='StridedSlice'):
        StridedSliceNormalizer.normalize_strided_slice(graph, node)
        PermuteAttrs.create_permute_attrs(node, attrs=[(mask, 'input:0') for mask in mask_names])

        # StridedSliceNormalizer inserted nodes that changed the original begin, end and strides data
        # nodes, so the permutations can only be set correctly now. Port 3 (strides) is optional.
        for port_idx in (1, 2, 3):
            if port_idx == 3 and not node.is_in_port_connected(3):
                continue
            PermuteInputs().set_input_permutation(node.in_node(port_idx), node, 'input:{}'.format(port_idx),
                                                  'slice', 'dim_size')
def infer(node):
    """Shape/value inference shared by Split (2 inputs) and AttributedSplit (1 input).

    The input on port 0 is split into ``num_splits`` equal parts along ``axis``.
    For Split, the axis comes from the constant on port 1; for AttributedSplit,
    it comes from the ``axis`` attribute. Every connected output gets the reduced
    shape and, for a constant input, its slice of the value.
    """
    name = node.soft_get('name', node.id)

    op = node.soft_get('op', None)
    assert op is not None and op in ['Split', 'AttributedSplit'], \
        'Unexpected `op`={} attribute for Split-like node {}'.format(op, name)

    num_in_ports = 1 if op == 'AttributedSplit' else 2 if op == 'Split' else None
    assert num_in_ports in [1, 2], \
        'SplitBase supports AttributedSplit with 1 input and Split with 2 inputs, but it is {} for {} node {}' \
        ''.format(num_in_ports, op, name)

    connected_inputs = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
    assert len(connected_inputs) == num_in_ports and all(i in connected_inputs for i in range(num_in_ports)), \
        "{} should have {} connected input ports, but it doesn't for node: `{}`. Ports: {}" \
        "".format(op, num_in_ports, name, connected_inputs)

    input_shape = node.in_port(0).data.get_shape()
    assert input_shape is not None, 'Input shape is unknown for node {}'.format(name)
    assert node.has_valid('num_splits'), 'Parameter `num_splits` is unknown for node {}'.format(name)
    num_splits = node.num_splits

    axis = node.in_port(1).data.get_value() if op == 'Split' else node.soft_get('axis', None)
    assert axis is not None, '{} `axis` is unknown for node {}'.format(op, name)
    # np.ndim works for both numpy arrays (Split's constant input) and plain Python ints
    # (AttributedSplit's attribute); `axis.ndim` raised AttributeError for the latter.
    assert np.ndim(axis) == 0, '{} `axis` should be scalar, but it`s not for node {}'.format(op, name)

    assert input_shape[axis] % num_splits == 0, \
        'Input shape is not evenly divided by `num_splits` of {} node {}. `input_shape`={}, `axis`={}, ' \
        '`num_splits`={}'.format(op, name, input_shape, axis, num_splits)

    out_shape = input_shape.copy()
    # Exact: divisibility is asserted above.
    out_shape[axis] = np.int64(input_shape[axis] // num_splits)

    input_value = node.in_port(0).data.get_value()
    output_value = np.split(input_value.copy(), axis=axis, indices_or_sections=num_splits) \
        if input_value is not None else None

    for idx, port in node.out_ports().items():
        if idx in node.out_nodes():
            port.data.set_shape(out_shape)
            if output_value is not None:
                port.data.set_value(output_value[idx])

    if op == 'Split':
        PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis')
    elif op == 'AttributedSplit':
        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
def find_and_replace_pattern(self, graph: Graph):
    """Surround ``reinterp_shape`` nodes with Transposes so they keep operating in NHWC.

    Only runs for NHWC graphs. There are two passes over nodes with
    ``reinterp_shape=True``:
      1. If input 0 has rank >= 4 and is not marked as already in the correct
         layout, insert a Transpose (NCHW->NHWC order) before the node.
      2. If output 0 has rank >= 4 and is not marked as already in the correct
         layout, insert a Transpose (NHWC->NCHW order) after the node.
    The inserted Transposes get ``need_shape_inference = False``, and the
    relevant ports are marked as being in the correct layout so that their
    shapes are not permuted later.
    """
    if graph.graph['layout'] != 'NHWC':
        # we check it here because this transformation is called explicitly from the pipeline
        return

    # reshape from 4D-5D -> ND. Insert Transpose(NC(D)HW->N(D)HWC) before Reshape
    for reinterp_shape_node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
        reinterp_shape_node = Node(graph, reinterp_shape_node_id)
        assert 0 in reinterp_shape_node.in_nodes(), 'Node {} does not have 0 input. \n{}'.format(
            reinterp_shape_node_id, graph.dump_graph_for_graphviz())
        input_shape = reinterp_shape_node.in_node(0).shape
        if not is_input_data_in_correct_layout(reinterp_shape_node, 0) and len(input_shape) >= 4:
            # Const holding the NCHW -> NHWC order; it feeds port 1 of the new Transpose.
            order_const = Const(graph, {'value': PermuteAttrs().get_nchw_to_nhwc_permutation(len(input_shape)).perm
                                        }).create_node()
            permute_node = Transpose(graph,
                                     {'name': reinterp_shape_node.in_port(0).get_source().node.name + '/Transpose'
                                      }).create_node()
            # Splice the Transpose into the existing input connection of the reinterp_shape node.
            reinterp_shape_node.in_port(0).get_connection().insert_node(permute_node)
            order_const.out_port(0).connect(permute_node.in_port(1))
            order_const.infer(order_const)

            # do not infer the Transpose node because it should have input data node in NCHW layout (but currently
            # it is NHWC because data node attributes has not been permuted yet) and produce output in NHWC layout
            # (which is true at this moment)
            permute_node['need_shape_inference'] = False
            # mark the Transpose output data node having correct layout so it's shape will not be permuted
            mark_output_as_in_correct_layout(permute_node, 0)

            # keep the reinterp_shape_node in NHWC layout
            mark_input_as_in_correct_layout(reinterp_shape_node, 0)
            mark_input_as_in_correct_layout(reinterp_shape_node, 1)

    # reshape from ND -> 4D-5D. Insert Transpose(N(D)HWC->NC(D)HW) after Reshape
    for reinterp_shape_node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
        reinterp_shape_node = Node(graph, reinterp_shape_node_id)
        assert 0 in reinterp_shape_node.out_nodes(), 'Node {} does not have 0 output. \n{}'.format(
            reinterp_shape_node_id, graph.dump_graph_for_graphviz())
        output_shape = reinterp_shape_node.out_node(0).shape
        if not is_output_data_in_correct_layout(reinterp_shape_node, 0) and len(output_shape) >= 4:
            # Const holding the NHWC -> NCHW order; unlike the first pass it is not inferred here.
            order_const = Const(graph, {
                'value': PermuteAttrs().get_nhwc_to_nchw_permutation(len(output_shape)).perm}).create_node()
            permute_node = Transpose(graph, {'name': reinterp_shape_node.id + '/Transpose'}).create_node()
            # Splice the Transpose into the existing output connection of the reinterp_shape node.
            reinterp_shape_node.out_port(0).get_connection().insert_node(permute_node)
            order_const.out_port(0).connect(permute_node.in_port(1))

            # the Reshape and Transpose operations should work in original (NHWC layout) so the Transpose
            # will convert it to the NCHW
            mark_input_as_in_correct_layout(permute_node, 0)
            mark_input_as_in_correct_layout(permute_node, 1)
            # do not set Transpose output data node 'correct_data_layout' attribute so the data node shape will be
            # permuted

            # keep the reinterp_shape_node in NHWC layout
            mark_output_as_in_correct_layout(reinterp_shape_node, 0)
            mark_input_as_in_correct_layout(reinterp_shape_node, 1)

            # do not re-infer the Transpose node because it output data node should be in NHWC layout to make the
            # rest of the graph consistent
            permute_node['need_shape_inference'] = False
def infer(node):
    """Set the Parameter's output shape from its mandatory ``shape`` attribute."""
    node_name = node.soft_get('name', node.id)
    assert node.has_valid('shape'), \
        'Parameter node {} should have `shape` attribute. Please use cli options to set model input shape' \
        ''.format(node_name)

    node.out_port(0).data.set_shape(node.shape)

    # The shape attribute describes output 0, so it is permuted together with that port's layout.
    PermuteAttrs.create_permute_attrs(node, attrs=[('shape', 'output:0')])
def infer(node: Node):
    """Shape/value inference for Tile.

    With 2 inputs, input 1 holds per-dimension repeat counts. They are reduced
    to ``axis``/``tiles`` attributes when at most one count differs from 1;
    otherwise the node is marked ``type=None``, the full array is kept in
    ``tile_array``, and a warning is logged. The repeats edge is then removed.
    With 1 input, ``axis`` and ``tiles`` must already be set. Errors are logged
    and inference is abandoned.
    """
    input_shape = node.in_node().shape
    if input_shape is None:
        log.error("Undefined shape for the input tiles for the Tile operation '{}'.".format(node.node))
        return
    shape = np.copy(input_shape)

    input_count = len(node.in_nodes())
    if input_count == 2:
        repeats = node.in_node(1).value
        if repeats is None:
            log.error('A tile values are None for a node "{}".'.format(node.name))
            return
        if len(shape) != len(repeats):
            log.error('Shape mismatch for a node "{}": {} vs {}.'.format(node.name, shape.shape, repeats.shape))
            return

        tiled_positions = np.argwhere(repeats != 1)
        tiled_count = len(tiled_positions)
        if tiled_count == 0:
            log.info('Redundant "Tile" operation "{}" with tile values for all dimensions equal to 1.'.format(
                node.name))
            node['axis'] = 0
            node['tiles'] = 1
        elif tiled_count == 1:
            tiled_axis = tiled_positions[0][0]
            node['axis'] = tiled_axis
            node['tiles'] = repeats[tiled_axis]
        else:
            node['type'] = None
            node['tile_array'] = repeats
            log.warning("Tile operation with more than one dimension not equal to 1 is not supported.")
            # no early return: shape and value are still inferred for the constant propagation case
        node.graph.remove_edge(node.in_node(1).id, node.id)
    elif input_count == 1:
        # tiled dimension and count are given as node attributes
        if not node.has_valid('axis') or not node.has_valid('tiles'):
            log.error('Mandatory attributes "axis" or "tiles" are not specified for a Tile node "{}"'.format(
                node.name))
            return
        repeats = np.ones([len(shape)], dtype=np.int64)
        repeats[node.axis] = node.tiles
    else:
        log.error('Unsupported number of input parameters to Tile node "{}"'.format(node.name))
        return

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])

    node.out_node().shape = shape * repeats
    input_value = node.in_node(0).value
    if input_value is not None:
        node.out_node().value = np.tile(input_value, repeats)
def infer(node: Node):
    """Set each output's shape to the input shape with ``axis`` divided by ``num_split``.

    Every output data node receives its own copy of the shape array. Previously
    one array object was shared by all outputs, so an in-place edit of one
    output's shape would silently change the others as well.
    """
    input_node = node.in_node(0)
    out_shape = copy.copy(input_node.shape)
    out_shape[node.axis] = np.int64(input_node.shape[node.axis] / node.num_split)

    for output in node.out_nodes().values():
        output.shape = copy.copy(out_shape)

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
def infer(node: Node):
    """LogSoftmax keeps its input shape; only the (normalized) ``axis`` is validated."""
    connected = [port for port in node.in_ports().values() if not port.disconnected()]
    assert len(connected) == 1, \
        'LogSoftmax node with id {} have more than one port connected'.format(node.id)

    # Normalize a negative axis against the input rank.
    if node.axis < 0:
        node.axis = len(node.in_port(0).data.get_shape()) + node.axis

    assert 0 <= node.axis < len(node.in_port(0).data.get_shape()), \
        'LogSoftmax node with id {} has wrong axis attribute'.format(node.id)

    copy_shape_infer(node)
    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
def infer(node: Node):
    """Swap dimensions ``dim1`` and ``dim2``: build the transpose ``order`` and infer shape/value from it."""
    axes_order = list(range(node.in_node().shape.size))
    node['order'] = axes_order
    # Identity permutation with the two requested axes exchanged (no-op when dim1 == dim2).
    axes_order[node.dim1], axes_order[node.dim2] = axes_order[node.dim2], axes_order[node.dim1]

    input_shape = node.in_port(0).data.get_shape().copy()
    node.out_port(0).data.set_shape(input_shape[axes_order])

    input_value = node.in_port(0).data.get_value()
    if input_value is not None:
        node.out_port(0).data.set_value(np.transpose(input_value, axes=axes_order))

    PermuteAttrs.create_permute_attrs(node, attrs=[('order', 'input:0')])
def tf_reshape_shape_infer(node):
    """Compute the output shape of a TF-style Reshape.

    The target comes from the constant on input 1 or from the ``dim``
    attribute. A ``0`` entry copies the input dimension at the same index, and
    a ``-1`` entry is replaced by whatever size keeps the element count.
    Returns None when the input shape is unknown. Raises Error when the element
    counts differ. Sets ``dim`` when it was not already valid.
    """
    # TODO Make sure that all -1 are handled correctly
    input_shape = node.in_node(0).shape
    if input_shape is None:
        return None

    if len(node.in_nodes()) > 1:
        reshape_output = node.in_node(1).value
    else:
        reshape_output = node.dim

    total = 1
    for input_dim in input_shape:
        total *= input_dim

    # Product of every target dimension except -1 (0 means "copy the input dimension").
    known_product = 1
    for position, requested in enumerate(reshape_output):
        if requested == 0:
            known_product *= input_shape[position]
        elif requested != -1:
            known_product *= requested
    inferred_dim = total // known_product

    output_shape = [input_shape[position] if requested == 0 else inferred_dim if requested == -1 else requested
                    for position, requested in enumerate(reshape_output)]

    out_shape_total = 1
    for output_dim in output_shape:
        assert output_dim != -1
        out_shape_total *= output_dim

    if total != out_shape_total:
        raise Error(
            "Number of elements in input {} and output {} of reshape node {} mismatch".format(input_shape,
                                                                                               output_shape,
                                                                                               node.name))

    PermuteAttrs.create_permute_attrs(node, attrs=[('dim', 'output:0')])

    output_shape = int64_array(output_shape)

    # A Reshape created with two inputs may lack the dim attribute; fill it in.
    if not node.has_valid('dim'):
        node['dim'] = output_shape
    return output_shape
def _two_inputs_infer(node: Node):
    """Infer a Crop node that crops input 0 to the shape of reference input 1.

    From the canonical ``axis`` onwards, every dimension takes the reference
    size, shifted by ``offset`` (a single value for all axes, or one per axis).
    Afterwards, ``axis``/``offset``/``dim`` are rewritten as per-dimension lists.
    Problems are logged and inference is abandoned.
    """
    input_count = len(node.in_nodes())
    shapes = [node.in_node(port).shape for port in range(input_count)]
    if any(s is None for s in shapes):
        log.error('Not all input shapes were defined for {} node'.format(node.name))
        return
    if not node.has_valid('axis'):
        log.error('axis attribute is missing for {} node. should be set in crop extractor'.format(node.name))
        return
    if not node.has_valid('offset'):
        log.error('offset attribute is missing for {} node. should be set in crop extractor'.format(node.name))
        return

    input_shape = np.array(shapes[0])
    start_axis = get_canonical_axis_index(input_shape, node.axis)
    node.axis = start_axis

    reference_shape = np.array(shapes[1])
    new_shape = input_shape.copy()  # dimensions before start_axis keep their input size

    cropped_axes, cropped_offsets, cropped_dims = [], [], []
    for position in range(input_shape.size):
        if position < start_axis:
            continue

        offsets_count = len(node.offset)
        if offsets_count == 1:
            crop_offset = node.offset[0]
        elif offsets_count > 1:
            crop_offset = node.offset[position - start_axis]
        else:
            crop_offset = 0

        if input_shape[position] - crop_offset < reference_shape[position]:
            log.error('The crop for dimension is out of bounds in ' + node.node)
            return

        cropped_dims.append(reference_shape[position])
        cropped_axes.append(position)
        cropped_offsets.append(crop_offset)
        new_shape[position] = reference_shape[position]

    node.axis = cropped_axes
    node.offset = cropped_offsets
    node['dim'] = cropped_dims
    node.out_node().shape = new_shape
    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
def infer(node):
    """Shape inference for Pad.

    Paddings come from the ``pads`` attribute (single input only), from input 1
    as an Nx2 array, or from inputs 1 and 2 as separate begin/end vectors (with
    an optional fill value on input 3). Each positive input dimension grows by
    its begin+end padding; non-positive dimensions are kept as they are.
    """
    PermuteAttrs.create_permute_attrs(node, attrs=[('pads', 'input:0')])

    num_of_inputs = len(node.in_nodes())
    if node.has_valid('pads'):
        assert num_of_inputs == 1, "Pad operation has pads attribute and unexpected additional input " \
                                   "argument for node {}.".format(node.name)
    else:
        assert num_of_inputs >= 2, "Missing required second input argument for node {} and pads attribute " \
                                   "is missing.".format(node.name)
        node['pads'] = node.in_node(1).value
        if num_of_inputs in [3, 4]:
            # Separate begin/end vectors are stacked column-wise into an Nx2 array.
            begins = node.in_node(1).value.reshape(-1, 1)
            ends = node.in_node(2).value.reshape(-1, 1)
            node['pads'] = np.concatenate((begins, ends), 1)
            node['fill_value'] = node.in_node(3).value if num_of_inputs == 4 else 0.0

    padding = node.pads
    input_shape = node.in_node(0).shape
    if padding is None or input_shape is None:
        log.error('The paddings are not defined for node "{}"'.format(node.soft_get('name')))
        return

    # Only fully defined Nx2 paddings are handled; broadcasts are not supported.
    assert (padding.ndim == 2 and padding.shape[1] == 2)
    # one padding row per input dimension
    assert (padding.shape[0] == len(input_shape)), \
        "Input tensor shape {} and pads values {} do not match for Pad node {}".format(
            input_shape, padding.shape, node.name
        )

    # begin + end padding per dimension
    shape_change = np.add.reduce(padding, 1)
    assert (shape_change.shape == input_shape.shape)

    # non-positive dimensions have a special meaning and are preserved unchanged
    shape = np.array([shape_change[i] + dim if dim > 0 else dim for i, dim in enumerate(input_shape)])

    assert len(node.out_nodes()) == 1
    node.out_node().shape = shape
def infer(node: Node):
    """StridedSlice inference that also normalizes masks and inputs.

    After ``tf_strided_slice_infer``, the begin/end/stride inputs and all five
    masks are extended with ``extend_mask_according_ellipsis`` to the output
    rank, and the ellipsis mask is cleared. In an NHWC graph whose output is not
    constant, mask permutations are registered and, for inputs of rank > 3,
    the constant begin/end/stride values are permuted with ``permute_array``.
    """
    tf_strided_slice_infer(node)

    out_shape = node.out_port(0).data.get_shape()
    assert out_shape is not None, \
        'Output shape was not calculated for node {}'.format(node.name)
    # extend inputs according to ellipsis mask and/or input_shape
    for i_port in node.in_ports().values():
        if i_port.idx == 0 or i_port.disconnected():
            continue
        old_value = i_port.data.get_value()
        # Paranoid check for non-constant inputs: shape inference already rejects them,
        # this only guards against future changes there.
        # The message is parenthesized so .format() applies to the whole text; with the former
        # `'...' + '...'.format(...)` only the last literal was formatted and '{}' leaked verbatim.
        assert old_value is not None, \
            ('{} input of {} node is not constant: \'value\' attribute for edge '
             'contains None').format(i_port.idx, node.name)
        # insert 0 for begin and end and 1 for stride
        new_value = int64_array(extend_mask_according_ellipsis(node.ellipsis_mask, node.shrink_axis_mask,
                                                               len(out_shape), list(old_value),
                                                               int(i_port.idx == 3)))
        # set_value additionally set_shape and propagate value to Const node
        if not np.array_equal(new_value, old_value):
            i_port.data.set_value(new_value)

    # extend masks before removing ellipsis (order matters: later masks see the already-extended shrink mask)
    for attr in ["new_axis_mask", "shrink_axis_mask", "begin_mask", "end_mask", "ellipsis_mask"]:
        node[attr] = int64_array(extend_mask_according_ellipsis(node.ellipsis_mask, node.shrink_axis_mask,
                                                                len(out_shape), list(node[attr]), 0))

    # we will extend all masks and inputs to simplify future transformations
    idx = np.nonzero(node.ellipsis_mask)
    node.ellipsis_mask[idx] = 0

    if node.graph.graph['layout'] == 'NHWC' and node.out_port(0).data.get_value() is None:
        PermuteAttrs.create_permute_attrs(node, attrs=[('shrink_axis_mask', 'input:0', permute_masks),
                                                       ('new_axis_mask', 'input:0', permute_masks),
                                                       ('ellipsis_mask', 'input:0', permute_masks),
                                                       ('begin_mask', 'input:0', permute_masks),
                                                       ('end_mask', 'input:0', permute_masks),
                                                       ])
        # permute inputs
        in_shape = node.in_port(0).get_source().data.get_shape()
        assert in_shape is not None, \
            'Input shape is unknown for 0 input of node {}'.format(node.name)
        input_rank = len(in_shape)
        if input_rank > 3:
            for i_port in node.in_ports().values():
                if i_port.idx == 0 or i_port.disconnected():
                    continue
                new_value = permute_array(node, i_port.data.get_value())
                # set_value additionally set_shape and propagate value to Const node
                i_port.data.set_value(new_value)
def replace_pattern(self, graph: Graph, match: dict):
    """Rewrite a matched Mean (reduce-mean over constant axes) into an AvgPool in place.

    The window spans the full input extent on the reduced axes and is 1
    elsewhere, with unit stride and zero padding. The axis input edge is
    removed. When ``keep_dims`` is False, the pool writes to a new data node
    that keeps the reduced dims as size 1, and a Reshape restores the original
    output shape.
    Does nothing when the axis value or the input shape is unknown.
    """
    if match['axis'].value is None or match['input'].shape is None:
        return
    dims = len(match['input'].shape)
    ones = np.ones(dims, dtype=np.int64)
    axis = np.array(match['axis'].value)
    # Promote a scalar axis to a 1-element array so it can be used as an index vector below.
    axis = axis if axis.ndim != 0 else np.array([axis], dtype=np.int64)

    # Raw attribute dict of the Mean node; it is mutated in place into an AvgPool.
    mean = graph.node[match['mean'].node]
    mean['stride'] = np.array(ones)
    # TODO: need to check axis with real layout
    spatial_dims = np.array(axis)
    mean['spatial_dims'] = spatial_dims
    mean['pad'] = np.zeros((dims, 2), np.int64)
    mean['pad_spatial_shape'] = np.array(mean['pad'][spatial_dims])
    # Window covers the whole input extent on the reduced axes, 1 elsewhere.
    window = np.array(ones)
    window[spatial_dims] = match['input'].shape[spatial_dims]
    mean['window'] = window
    mean['TF_op'] = mean['op']
    mean['op'] = 'AvgPool'
    mean['pool_method'] = 'avg'
    mean['rounding_type'] = 'ceil'
    mean['exclude_pad'] = 'true'
    mean['kernel_spatial'] = window[spatial_dims]
    # The axes are now encoded in the pooling attributes, so the axis input is dropped.
    graph.remove_edge(match['axis'].node, match['mean'].node)
    mean['permute_attrs'] = PermuteAttrs().update_attrs(attrs=[('pad', 'input:0'),
                                                                ('stride', 'input:0'),
                                                                ('window', 'input:0'),
                                                                ('spatial_dims', 'input:0')])

    if match['mean'].keep_dims == False:
        output = match['mean'].out_node()
        pool_node = match['mean']

        # Keep dims for AvgPool
        # NOTE(review): inserting one-by-one assumes spatial_dims are non-negative and ascending — confirm.
        shape = np.array(output.shape)
        for idx in spatial_dims:
            shape = np.insert(shape, idx, 1)

        graph.remove_edge(pool_node.id, output.id)
        # Create new data for pool with all dims
        pool_data = Op.create_data_node(graph, pool_node, {'shape': np.array(shape)})
        # Create and connect reshape node
        reshape_op = Reshape(graph, {'dim': np.array(output.shape)})
        reshape_node = reshape_op.create_node([pool_data], dict(name='Reshape_',
                                                                permute_attrs=PermuteAttrs().update_attrs(
                                                                    attrs=[('dim', 'output:0')])))
        # The Reshape takes over the original (reduced) output data node.
        graph.create_edge(reshape_node, output)
def infer(node):
    """ReverseSequence keeps its input shape; it only requires both axis attributes and one output."""
    data_shape = node.in_port(0).data.get_shape()
    assert data_shape is not None

    for required_attr in ('seq_axis', 'batch_axis'):
        assert node.has_valid(required_attr)
    assert len(node.out_nodes()) == 1

    node.out_port(0).data.set_shape(data_shape)

    # Both axis attributes index input 0 and follow its layout permutation.
    for axis_attr in ('seq_axis', 'batch_axis'):
        PermuteAttrs.create_permute_attrs(node, attrs=[(axis_attr, 'input:0')])