def transpose_infer(node):
    """Infer the output shape (and constant value) of a Transpose node.

    The permutation is taken from ``node.order``; when ``reverse_order`` is set,
    the order is generated as the reversed list of input axes. Setting both, or
    neither, is an error that is logged and aborts inference. Inference is also
    skipped (with an error) when the input shape is undefined.
    """
    name = node.soft_get('name')
    reverse_order = node.has_valid('reverse_order') and node.reverse_order

    if node.order is None and not reverse_order:
        log.error('Cannot infer {} because order is None'.format(name))
        return

    if reverse_order and node.has_valid('order'):
        log.error('Cannot infer {} due to both order and reverse_order was set'.format(name))
        return

    input_shape = node.in_node(0).shape
    if input_shape is None:
        # Without this guard len(None) below would raise a TypeError.
        log.error('Cannot infer {} because input shape is undefined'.format(name))
        return

    if reverse_order:
        node.order = np.arange(len(input_shape))[::-1]  # Reverse order

    output_shape = np.array([input_shape[i] for i in node.order], dtype=np.int64)
    node.out_node(0).shape = output_shape
    if node.in_node().has_valid('value'):
        node.out_node().value = np.transpose(node.in_node().value, axes=node.order)
    PermuteAttrs.create_permute_attrs(node, attrs=[('order', 'input:0')])
def infer(node):
    """Infer output shapes of a TopK node.

    Exactly two inputs must be connected: the data (port 0) and the constant ``k``
    (port 1). Each connected output gets the input shape with the ``axis``
    dimension replaced by ``k``. A negative ``axis`` is normalized in place.

    Raises:
        Error: if the value of ``k`` is not known.
    """
    name = node.soft_get('name')
    connected_ports = [port for port in node.in_ports().values() if not port.disconnected()]
    assert len(connected_ports) == 2, \
        'The number of inputs to the TopK layer name "{}" must be equal to 2.'.format(name)

    k = node.in_port(1).data.get_value()
    if k is None:
        raise Error('The value defining number of output elements for layer "{}" is not defined'.format(name))

    assert node.has_valid('axis'), 'The "axis" attribute is not defined for node {}'.format(name)

    input_shape = node.in_port(0).data.get_shape()
    # Explicit check instead of an opaque TypeError from len(None) below.
    assert input_shape is not None, 'Input shape is not defined for node {}'.format(name)

    if node.axis < 0:
        node.axis = len(input_shape) + node.axis
    output_shape = input_shape.copy()
    output_shape[node.axis] = k

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])

    # setting shape for both outputs (values and indices) when they are used
    for out_idx in (0, 1):
        if not node.out_port(out_idx).disconnected():
            node.out_port(out_idx).data.set_shape(output_shape)
    # TODO implement value propagation for a constant data input
def argmax_infer(node: Node):
    """Infer the output shape of an ArgMax node.

    With two inputs (TensorFlow form) the axis is read from the constant second
    input, stored in ``node.axis``, and that edge is removed. With a valid axis the
    output is the input shape with the axis dimension set to ``top_k``. Without an
    axis the output has max(rank, 3) dims, all 1 except: dim 0 = input batch,
    dim 2 = ``top_k`` and dim 1 = 2 when ``out_max_val`` is set.
    """
    shape = node.in_node(0).shape
    if shape is None:
        return

    # there are two inputs in TensorFlow. The second input is the axis for ArgMax
    if len(node.in_nodes()) == 2:
        axis_value = node.in_node(1).value
        if axis_value is None:
            log.debug('The second argument to ArgMax is None')
            return
        node.axis = axis_value.item()
        # the axis is now an attribute, so the input edge is no longer needed
        node.graph.remove_edge(node.in_node(1).id, node.id)

    if node.has_valid('axis'):
        canonical_axis = get_canonical_axis_index(shape, node.axis)
        node.axis = canonical_axis
        out_shape = np.array(shape)
        out_shape[canonical_axis] = node.top_k
        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
    else:
        out_shape = np.ones(max(shape.size, 3), dtype=int)
        out_shape[0] = shape[0]
        out_shape[2] = node.top_k
        if node.out_max_val:
            out_shape[1] = 2

    node.out_node().shape = out_shape
def tf_split_infer(node):
    """
    Partial infer of split node similar to Split op of TF.

    Input 0 holds the scalar split dimension and input 1 the data. The data is split
    into ``node.num_split`` equal parts along that dimension. Afterwards the
    split-dim edge is removed, ``node['input_port']`` is set to 1, and ``axis`` is
    registered for permutation against input 1.
    """
    # Two inputs: [split_dim, input]
    assert len(node.in_nodes()) == 2, 'Node "{}" must have exactly two inputs'.format(node.soft_get('name'))
    split_dim = node.in_node(0).value
    if split_dim is None:
        log.error('split_dim value for node {} is None. Cannot do shape inference.'.format(node.soft_get('name')))
        return

    assert split_dim.ndim == 0, 'The split dimension for node "{}" must be a scalar.'.format(node.soft_get('name'))
    split_dim = split_dim.item()

    input_node = node.in_node(1)
    if input_node.shape is None:
        log.error('Input shape for node {} is not defined'.format(node.soft_get('name')))
        return

    log.debug('input shape for split: {}, should be split along {} dim'.format(input_node.shape, split_dim))
    split_dim_size = input_node.shape[split_dim]
    log.debug('split_dim_size type = {}'.format(type(split_dim_size)))

    if split_dim_size % node.num_split != 0:
        log.error("split_dim cannot be evenly divided by a given number of parts")
        return

    # The size is known to be divisible, so integer division is exact (float division is not for huge sizes).
    part_size = int(split_dim_size // node.num_split)
    log.debug('split_dim_size = {}, node.num_split = {}, part_size = {}'.format(
        split_dim_size, node.num_split, part_size))
    split(input_node, node, split_dim, [part_size] * node.num_split)
    node.graph.remove_edge(node.in_node(0).id, node.id)
    node['input_port'] = 1

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:1')])
def find_and_replace_pattern(self, graph: Graph):
    """Normalize every StridedSlice node and register its layout permutations.

    For each StridedSlice op node the normalizer runs first. The five mask
    attributes are then registered with ``PermuteAttrs``. The inputs at ports 1
    and 2, and port 3 when connected, are registered with ``PermuteInputs`` using
    the 'slice' permutation over 'dim_size'.
    """
    mask_names = ('begin_mask', 'end_mask', 'new_axis_mask', 'shrink_axis_mask', 'ellipsis_mask')
    slice_ports = ((1, 'input:1'), (2, 'input:2'), (3, 'input:3'))

    for node in graph.get_op_nodes(type='StridedSlice'):
        StridedSliceNormalizer.normalize_strided_slice(graph, node)

        # All masks are tied to input:0, although strictly they depend on the slice rank.
        PermuteAttrs.create_permute_attrs(node, attrs=[(mask_name, 'input:0') for mask_name in mask_names])

        # StridedSliceNormalizer inserted nodes that changed the original begin, end and strides
        # data nodes, so their permutations can only be set correctly now.
        for port_idx, port_name in slice_ports:
            if port_idx == 3 and not node.is_in_port_connected(3):
                continue  # port 3 may be disconnected
            PermuteInputs().set_input_permutation(node.in_node(port_idx), node, port_name, 'slice', 'dim_size')
def tf_squeeze_infer(node):
    """Infer a Squeeze node with explicit ``squeeze_dims``.

    Every listed dimension must be 1 and is removed from the output shape.
    ``node['dim']`` receives the squeezed shape, or ``[0, -1]`` when
    ``is_spatial_squeeze`` reports a spatial squeeze for the graph layout.
    Constant values are reshaped to the squeezed shape.

    Raises:
        Error: if a listed dimension is not equal to 1.
    """
    if node.squeeze_dims is None:
        # TODO: implement; there is no implementation now because no test
        return

    input_shape = node.in_node().shape
    if input_shape is None:
        return

    real_squeeze_dims = []
    for n in node.squeeze_dims:
        if input_shape[n] == 1:
            real_squeeze_dims.append(get_canonical_axis_index(input_shape, n))
        else:
            raise Error('Trying to squeeze dimension not equal to 1 for node "{}"'.format(node.soft_get('name')))

    output_shape = np.delete(input_shape, real_squeeze_dims)
    node.out_node().shape = output_shape

    dim = output_shape
    if is_spatial_squeeze(node.graph.graph['layout'], input_shape, output_shape):
        dim = int64_array([0, -1])
    node['dim'] = dim

    if node.in_node().value is not None:
        # Use the real squeezed shape: the [0, -1] `dim` pattern is not a valid numpy reshape target.
        node.out_node().value = np.array(np.reshape(node.in_node().value, output_shape))

    PermuteAttrs.create_permute_attrs(node, attrs=[('dim', 'output:0')])
def infer(node):
    """Infer an AttributedTile node: repeat input 0 ``tiles`` times along ``axis``.

    Only port 0 may be connected. ``axis`` and ``tiles`` come from node attributes.
    The output shape is the input shape with the ``axis`` dimension multiplied by
    ``tiles``. Constant inputs are tiled with ``np.tile``.
    """
    name = node.soft_get('name', node.id)

    connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
    assert len(connected_in_ports) == 1 and 0 in connected_in_ports, \
        "AttributedTile should have 1 connected input port, but it doesn't for node: `{}`. Ports: {}" \
        "".format(name, connected_in_ports)

    input_shape = node.in_port(0).data.get_shape()
    assert input_shape is not None, "Undefined input shape for AttributedTile node '{}'.".format(name)

    assert node.soft_get('axis', None) is not None
    assert node.soft_get('tiles', None) is not None, "Undefined `tiles` attribute of Tile node '{}'".format(name)

    # Repeat factor per dimension: 1 everywhere except the tiled axis.
    repeats = int64_array(np.ones(input_shape.size))
    repeats[node.axis] = node.tiles

    node.out_port(0).data.set_shape(input_shape * repeats)

    input_value = node.in_port(0).data.get_value()
    if input_value is not None:
        node.out_port(0).data.set_value(np.tile(input_value, repeats))

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
def tf_expand_dims_infer(node):
    """Infer an ExpandDims node: insert a size-1 dimension at the expand axis.

    The axis comes either from a second constant input, whose edge is removed once
    read, or from the ``expand_axis`` attribute. A negative scalar axis counts from
    the end of the *output* shape, so -1 appends a trailing dimension. Constant
    values are reshaped too, and the resulting shape is stored in ``node['dim']``.
    """
    input_node = node.in_nodes()[0]
    output_node = node.out_node()
    if input_node.shape is None:
        return

    # TensorFlow style with dynamic input
    if len(node.in_nodes()) > 1:
        axis_node = node.in_nodes()[1]
        axis_value = axis_node.value
        if axis_value is None:
            # Non-constant axis: previously this crashed on `.item()`.
            log.error("ExpandDims operation : axis value is not defined")
            return
        if isinstance(axis_value, np.ndarray) and axis_value.size > 1:
            log.error("ExpandDims operation : axis should be scalar")
            return
        expand_axis = np.asarray(axis_value).item()
        node.graph.remove_edge(axis_node.id, node.id)
    else:
        if not node.has_valid('expand_axis'):
            log.error("ExpandDims axis is not defined")
            return
        expand_axis = node.expand_axis

    if expand_axis is None:
        return

    if np.ndim(expand_axis) == 0 and expand_axis < 0:
        # np.insert treats -1 as "before the last element"; ExpandDims' -1 means "after the last one".
        expand_axis += len(input_node.shape) + 1

    # convert data type of the shape to int64 explicitly
    output_node.shape = np.insert(input_node.shape, expand_axis, [1]).astype(np.int64)
    if input_node.value is not None:
        output_node.value = np.array(np.reshape(input_node.value, output_node.shape))

    node['dim'] = output_node.shape
    PermuteAttrs.create_permute_attrs(node, attrs=[('dim', 'output:0')])
def infer(node: Node):
    """Infer an AttributedGather node (``np.take`` along the ``axis`` attribute).

    Data (port 0) and indices (port 1) must be connected and have known shapes.
    ``axis`` is normalized in place. When both inputs are constant, the output value
    is computed directly. Otherwise only the output shape
    ``data[:axis] + indices + data[axis + 1:]`` is set.
    """
    name = node.soft_get('name', node.id)

    connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
    assert len(connected_in_ports) == 2 and 0 in connected_in_ports and 1 in connected_in_ports, \
        "AttributedGather should have 2 connected input port, but it doesn't for node: `{}`. Ports: {}" \
        "".format(name, connected_in_ports)

    raw_axis = node.soft_get('axis', None)
    assert raw_axis is not None

    data_shape = node.in_port(0).data.get_shape()
    assert data_shape is not None
    indices_shape = node.in_port(1).data.get_shape()
    assert indices_shape is not None

    # Convert negative axis
    gather_axis = get_canonical_axis_index(data_shape, raw_axis)
    node.axis = gather_axis
    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])

    data_value = node.in_port(0).data.get_value()
    indices_value = node.in_port(1).data.get_value()
    if data_value is None or indices_value is None:
        # The trailing slice is simply empty when the gather axis is the last one.
        out_shape = np.concatenate((data_shape[:gather_axis], indices_shape, data_shape[gather_axis + 1:]))
        node.out_port(0).data.set_shape(int64_array(out_shape))
    else:
        gathered = np.take(data_value, indices_value, gather_axis)
        node.out_port(0).data.set_value(np.array(gathered, dtype=data_value.dtype))
def infer(node: Node):
    """Infer a StridedSlice node and, for NHWC graphs, prepare it for layout permutation.

    Shape/value inference is delegated to ``tf_strided_slice_infer``. When the graph
    layout is NHWC and the output value was not computed, the five mask attributes
    are registered with ``PermuteAttrs`` (each with the ``permute_masks`` callback).
    Constant inputs 1.. with more than 3 entries are rewritten with
    ``permute_array_with_ellipsis``, and every set bit of ``ellipsis_mask`` is cleared.
    """
    tf_strided_slice_infer(node)

    # Only non-constant-folded slices in NHWC graphs need layout handling.
    if node.graph.graph['layout'] == 'NHWC' and node.out_port(0).data.get_value() is None:
        PermuteAttrs.create_permute_attrs(node, attrs=[('shrink_axis_mask', 'input:0', permute_masks),
                                                       ('new_axis_mask', 'input:0', permute_masks),
                                                       ('ellipsis_mask', 'input:0', permute_masks),
                                                       ('begin_mask', 'input:0', permute_masks),
                                                       ('end_mask', 'input:0', permute_masks),
                                                       ])
        # Permute constant begin/end/strides-like inputs (ports 1..) that have more than 3 entries.
        # NOTE(review): presumably shorter inputs are left as-is because they do not reach the
        # channel dimension -- confirm against permute_array_with_ellipsis.
        for i in range(1, len(node.in_nodes())):
            if node.in_node(i).value is not None and node.in_node(i).shape[0] > 3:
                perm = PermuteAttrs.get_nhwc_to_nchw_permutation(len(node.in_node(0).shape))
                node.in_node(i).value = permute_array_with_ellipsis(node, perm, node.in_node(i).value, 0)

        # due to permutation from nhwc to nchw we will extend all masks and inputs
        # NOTE(review): the original single-line source makes it ambiguous whether these two
        # statements belong inside the NHWC branch; placed inside here -- confirm.
        idx = np.nonzero(node.ellipsis_mask)
        node.ellipsis_mask[idx] = 0
def _one_input_infer(node: Node):
    """Infer a single-input crop from ``axis`` plus ``dim`` or ``crop_begin``/``crop_end``.

    With ``dim``, the dimensions listed in ``axis`` are set to the given sizes. With
    ``crop_begin``/``crop_end``, those amounts are subtracted from the corresponding
    input dimensions. Missing or inconsistent attributes are logged and inference
    is skipped.
    """
    # Check the raw shape: np.array(None) is an object array, never None, so testing
    # after the conversion (as before) could never detect an undefined shape.
    if node.in_node().shape is None:
        log.error('input_shape is none for {} node'.format(node.name))
        return
    input_shape = np.array(node.in_node().shape)

    if not node.has_valid('axis'):
        log.error('axis attribute is missing for {} node. should be set in crop extractor'.format(node.name))
        return

    # np.array above already made a private copy, so editing it in place is safe.
    output_shape = input_shape
    if node.has_valid('dim'):
        if len(node.dim) != len(node.axis):
            log.error('number of axis should match number of dim')
            return
        output_shape[node.axis] = node.dim
    elif node.has_valid('crop_begin') and node.has_valid('crop_end'):
        if len(node.crop_begin) != len(node.axis) or len(node.crop_end) != len(node.axis):
            log.error('number of crop_begin/crop_end should match number of axis')
            return
        if isinstance(node.axis, (list, tuple)):
            for i, axis in enumerate(node.axis):
                output_shape[axis] = output_shape[axis] - node.crop_begin[i] - node.crop_end[i]
        else:
            output_shape[node.axis] = output_shape[node.axis] - node.crop_begin - node.crop_end
    else:
        log.error('Crop node {} should have either dim or crop_begin and crop_end attributes'.format(node.name))
        return

    node.out_node().shape = np.array(output_shape)
    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
def infer(node: Node):
    """Infer a 4D node whose output height/width come from the port-1 constant.

    Batch and feature sizes are taken from the input shape (port 0) at the positions
    given by the graph layout. Height and width are the first two values of the
    port-1 input. ``axes`` is registered for permutation against input 0.
    """
    layout = node.graph.graph['layout']
    assert len(layout) == 4
    # Port 1 is read below, so at least two inputs must be connected (the old check only required one).
    assert len([p for p in node.in_ports().values() if not p.disconnected()]) >= 2, \
        'Node {} must have at least 2 connected inputs'.format(node.soft_get('name', node.id))
    assert node.has_valid('mode')
    assert node.has_valid('axes')

    src_shape = node.in_port(0).data.get_shape()
    assert src_shape is not None

    dst_shape = node.in_port(1).data.get_value()
    assert dst_shape is not None

    out_height = dst_shape[0]
    out_width = dst_shape[1]

    node.out_node().shape = shape_for_layout(layout,
                                             batch=src_shape[get_batch_dim(layout, 4)],
                                             features=src_shape[get_features_dim(layout, 4)],
                                             height=out_height,
                                             width=out_width)
    PermuteAttrs.create_permute_attrs(node, attrs=[('axes', 'input:0')])
def infer(node):
    """Infer a Split or AttributedSplit node: split input 0 into equal parts along ``axis``.

    ``Split`` takes the axis from the constant port-1 input. ``AttributedSplit``
    takes it from the ``axis`` attribute. The axis dimension must be divisible by
    ``num_splits``. Each connected output gets the reduced shape and, for constant
    inputs, its slice of the value.
    """
    name = node.soft_get('name', node.id)

    op = node.soft_get('op', None)
    assert op is not None and op in ['Split', 'AttributedSplit'], \
        'Unexpected `op`={} attribute for Split-like node {}'.format(op, name)

    num_in_ports = 1 if op == 'AttributedSplit' else 2 if op == 'Split' else None
    assert num_in_ports in [1, 2], \
        'SplitBase supports AttributedSplit with 1 input and Split with 2 inputs, but it is {} for {} node {}' \
        ''.format(num_in_ports, op, name)

    connected_inputs = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
    assert len(connected_inputs) == num_in_ports and all([i in connected_inputs for i in range(num_in_ports)]), \
        "{} should have {} connected input ports, but it doesn't for node: `{}`. Ports: {}" \
        "".format(op, num_in_ports, name, connected_inputs)

    input_shape = node.in_port(0).data.get_shape()
    assert input_shape is not None, 'Input shape is unknown for node {}'.format(name)
    assert node.has_valid('num_splits'), 'Parameter `num_splits` is unknown for node {}'.format(name)
    num_splits = node.num_splits

    axis = node.in_port(1).data.get_value() if op == 'Split' else node.soft_get('axis', None)
    assert axis is not None, '{} `axis` is unknown for node {}'.format(op, name)
    # np.ndim also accepts a plain Python int attribute, which has no `.ndim`.
    assert np.ndim(axis) == 0, '{} `axis` should be scalar, but it`s not for node {}'.format(op, name)

    assert input_shape[axis] % num_splits == 0, \
        'Input shape is not evenly divided by `num_splits` of {} node {}. `input_shape`={}, `axis`={}, ' \
        '`num_splits`={}'.format(op, name, input_shape, axis, num_splits)

    out_shape = input_shape.copy()
    # Divisibility is asserted above, so integer division is exact.
    out_shape[axis] = np.int64(input_shape[axis] // num_splits)

    input_value = node.in_port(0).data.get_value()
    output_value = np.split(input_value.copy(), axis=axis, indices_or_sections=num_splits) \
        if input_value is not None else None

    for idx, port in node.out_ports().items():
        if idx in node.out_nodes():
            port.data.set_shape(out_shape)
            if output_value is not None:
                port.data.set_value(output_value[idx])

    if op == 'Split':
        PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis')
    elif op == 'AttributedSplit':
        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
def infer(node):
    """Set the output shape of a Parameter node from its mandatory ``shape`` attribute.

    The ``shape`` attribute is also registered for layout permutation against output 0.
    """
    node_name = node.soft_get('name', node.id)
    assert node.has_valid('shape'), \
        'Parameter node {} should have `shape` attribute. Please use cli options to set model input shape' \
        ''.format(node_name)

    parameter_shape = node.shape
    node.out_port(0).data.set_shape(parameter_shape)
    PermuteAttrs.create_permute_attrs(node, attrs=[('shape', 'output:0')])
def infer(node: Node):
    """Split input 0 evenly into ``node.num_split`` parts along ``node.axis``.

    Every output gets the input shape with the ``axis`` dimension divided by
    ``num_split`` (truncated to int64). ``axis`` is registered for permutation
    against input 0.
    """
    input_node = node.in_node(0)
    out_shape = copy.copy(input_node.shape)
    out_shape[node.axis] = np.int64(input_node.shape[node.axis] / node.num_split)
    for output in node.out_nodes().values():
        # Give each output its own array; previously all outputs shared one object, so an
        # in-place edit of one output's shape silently changed all the others.
        output.shape = copy.copy(out_shape)
    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
def infer(node: Node):
    """Infer the output shape (and constant value) of a Tile node.

    With two inputs the per-dimension repeat counts come from the constant input 1.
    All-ones tiles set ``axis=0``/``tiles=1``; a single non-1 entry sets
    ``axis``/``tiles`` to it. More than one non-1 entry clears ``type``, stores
    ``tile_array`` and logs a warning. The input-1 edge is then removed. With one
    input, ``axis`` and ``tiles`` must already be node attributes. The output shape
    is ``input_shape * tile_array`` and constant inputs are tiled with ``np.tile``.
    """
    shape = node.in_node().shape
    if shape is None:
        log.error("Undefined shape for the input tiles for the Tile operation '{}'.".format(node.node))
        return
    shape = np.copy(shape)

    if len(node.in_nodes()) == 2:
        tile_array = node.in_node(1).value
        if tile_array is None:
            log.error('A tile values are None for a node "{}".'.format(node.name))
            return
        if len(shape) != len(tile_array):
            log.error('Shape mismatch for a node "{}": {} vs {}.'.format(node.name, shape.shape, tile_array.shape))
            return
        # indices of dimensions that are actually repeated
        non_one_tile = np.argwhere(tile_array != 1)
        if len(non_one_tile) == 0:
            log.info('Redundant "Tile" operation "{}" with tile values for all dimensions equal to 1.'.format(
                node.name))
            node['axis'] = 0
            node['tiles'] = 1
        elif len(non_one_tile) == 1:
            node['axis'] = non_one_tile[0][0]
            node['tiles'] = tile_array[node['axis']]
        else:
            # NOTE(review): presumably clearing 'type' keeps this node from being emitted as a
            # single-axis Tile layer -- confirm against the IR emitter.
            node['type'] = None
            node['tile_array'] = tile_array
            log.warning("Tile operation with more than one dimension not equal to 1 is not supported.")
            # do not return here to allow infer shape and values for the constant propagation case
        # the repeat counts are now kept as attributes, so the input edge is dropped
        node.graph.remove_edge(node.in_node(1).id, node.id)
    elif len(node.in_nodes()) == 1:  # case when tiled dimension and count are specified in node attributes
        if not node.has_valid('axis') or not node.has_valid('tiles'):
            log.error('Mandatory attributes "axis" or "tiles" are not specified for a Tile node "{}"'.format(
                node.name))
            return
        tile_array = np.ones([len(shape)], dtype=np.int64)
        tile_array[node.axis] = node.tiles
    else:
        log.error('Unsupported number of input parameters to Tile node "{}"'.format(node.name))
        return

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])

    node.out_node().shape = shape * tile_array
    if node.in_node(0).value is not None:
        node.out_node().value = np.tile(node.in_node(0).value, tile_array)
def infer(node: Node):
    """Validate and normalize the ``axis`` of a LogSoftmax node, then copy the input shape.

    Exactly one input port must be connected. A negative axis is converted to its
    non-negative form in place and must then lie inside the input rank.
    """
    connected_ports = [port for port in node.in_ports().values() if not port.disconnected()]
    assert len(connected_ports) == 1, \
        'LogSoftmax node with id {} have more than one port connected'.format(node.id)

    if node.axis < 0:
        # count from the end of the input shape
        node.axis = node.axis + len(node.in_port(0).data.get_shape())

    input_rank = len(node.in_port(0).data.get_shape())
    assert 0 <= node.axis < input_rank, \
        'LogSoftmax node with id {} has wrong axis attribute'.format(node.id)

    copy_shape_infer(node)
    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
def infer(node: Node):
    """Infer a node that swaps input dimensions ``dim1`` and ``dim2``.

    The resulting permutation is stored in ``node['order']`` as a list. The output
    shape is the permuted input shape, and constant inputs are transposed with it.
    """
    permutation = list(range(node.in_node().shape.size))
    permutation[node.dim1], permutation[node.dim2] = permutation[node.dim2], permutation[node.dim1]
    node['order'] = permutation

    input_shape = node.in_port(0).data.get_shape().copy()
    node.out_port(0).data.set_shape(input_shape[permutation])

    input_value = node.in_port(0).data.get_value()
    if input_value is not None:
        node.out_port(0).data.set_value(np.transpose(input_value, axes=permutation))

    PermuteAttrs.create_permute_attrs(node, attrs=[('order', 'input:0')])
def tf_reshape_shape_infer(node):
    """Compute the output shape of a Reshape node.

    The target shape comes from input 1 when present, otherwise from ``node.dim``.
    A ``0`` entry copies the input dimension at the same position, and a ``-1``
    entry is replaced by the size that keeps the total element count unchanged.
    When ``dim`` is not set yet, it is filled with the computed shape.

    Returns:
        The output shape as an int64 array, or None when the input shape or the
        target shape is not known.

    Raises:
        Error: if the input and output element counts differ.
    """
    # TODO Make sure that all -1 are handled correctly
    # We cannot simply copy shape argument to the output,
    # because if -1 appears, it should be substituted by a real
    # value from input shape if input shape is completely defined.
    input_shape = node.in_node(0).shape
    if input_shape is None:
        return None

    reshape_output = node.in_node(1).value if len(node.in_nodes()) > 1 else node.dim
    if reshape_output is None:
        # Non-constant target shape: nothing can be inferred (previously crashed on iteration).
        return None

    total = 1
    for size in input_shape:
        total *= size

    # Product of all output dimensions that are known up front.
    res = 1
    for index, x in enumerate(reshape_output):
        if x == 0:
            res *= input_shape[index]
        elif x != -1:
            res *= x

    new_dim = total // res
    output_shape = []
    for index, x in enumerate(reshape_output):
        if x == 0:
            output_shape.append(input_shape[index])
        elif x == -1:
            output_shape.append(new_dim)
        else:
            output_shape.append(x)

    out_shape_total = 1
    for size in output_shape:
        assert size != -1
        out_shape_total *= size

    if total != out_shape_total:
        raise Error(
            "Number of elements in input {} and output {} of reshape node {} mismatch".format(
                input_shape, output_shape, node.name))

    PermuteAttrs.create_permute_attrs(node, attrs=[('dim', 'output:0')])

    output_shape = int64_array(output_shape)

    # In case if Reshape operation was created with two inputs and dim attr wasn't set, we set in automatically
    if not node.has_valid('dim'):
        node['dim'] = output_shape

    return output_shape
def _two_inputs_infer(node: Node):
    """Infer a crop whose target sizes come from a reference (second) input.

    From the canonical ``node.axis`` onward, every dimension is cropped to the
    reference shape, shifted by its offset. A single offset value is shared by all
    cropped dimensions. On success ``axis``, ``offset`` and ``dim`` are rewritten as
    explicit per-dimension lists. Missing data or an out-of-bounds crop is logged,
    and inference is skipped.
    """
    shapes = [node.in_node(port).shape for port in range(len(node.in_nodes()))]
    if any(s is None for s in shapes):
        log.error('Not all input shapes were defined for {} node'.format(node.name))
        return

    if not node.has_valid('axis'):
        log.error('axis attribute is missing for {} node. should be set in crop extractor'.format(node.name))
        return

    if not node.has_valid('offset'):
        log.error('offset attribute is missing for {} node. should be set in crop extractor'.format(node.name))
        return

    input_shape = np.array(shapes[0])
    start_axis = get_canonical_axis_index(input_shape, node.axis)
    node.axis = start_axis

    reference_shape = np.array(shapes[1])

    # dimensions before start_axis keep their input size
    new_shape = input_shape.copy()
    ir_axis, ir_offset, dim = [], [], []

    for i in range(input_shape.size):
        if i < start_axis:
            continue

        if len(node.offset) == 1:
            crop_offset = node.offset[0]
        elif len(node.offset) > 1:
            crop_offset = node.offset[i - start_axis]
        else:
            crop_offset = 0

        if input_shape[i] - crop_offset < reference_shape[i]:
            log.error('The crop for dimension is out of bounds in ' + node.node)
            return

        dim.append(reference_shape[i])
        ir_axis.append(i)
        ir_offset.append(crop_offset)
        new_shape[i] = reference_shape[i]

    node.axis = ir_axis
    node.offset = ir_offset
    node['dim'] = dim
    node.out_node().shape = new_shape

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
def infer(node):
    """Infer the output shape of a Pad node.

    Pads come from the ``pads`` attribute (single input) or from inputs: input 1 as
    an Nx2 array, or inputs 1 and 2 as begin/end vectors. With three or four inputs,
    ``fill_value`` is set from input 3 or defaults to 0.0. Each positive input
    dimension grows by its begin+end padding. Non-positive dimensions are kept as-is.
    Undefined pads or input shape are logged, and inference is skipped.
    """
    PermuteAttrs.create_permute_attrs(node, attrs=[('pads', 'input:0')])

    num_of_inputs = len(node.in_nodes())
    if node.has_valid('pads'):
        assert num_of_inputs == 1, "Pad operation has pads attribute and unexpected additional input " \
                                   "argument for node {}.".format(node.name)
    else:
        assert num_of_inputs >= 2, "Missing required second input argument for node {} and pads attribute " \
                                   "is missing.".format(node.name)
        node['pads'] = node.in_node(1).value
        if num_of_inputs in [3, 4]:
            pads_begin = node.in_node(1).value
            pads_end = node.in_node(2).value
            if pads_begin is not None and pads_end is not None:
                node['pads'] = np.concatenate((pads_begin.reshape(-1, 1), pads_end.reshape(-1, 1)), 1)
            else:
                # Non-constant begin/end: leave pads undefined so the check below reports it
                # instead of crashing on `None.reshape`.
                node['pads'] = None
            node['fill_value'] = node.in_node(3).value if num_of_inputs == 4 else 0.0

    padding = node.pads
    input_shape = node.in_node(0).shape
    if padding is None or input_shape is None:
        log.error('The paddings are not defined for node "{}"'.format(node.soft_get('name')))
        return

    # paddings can be defined, partially defined or undefined
    # TODO for now we only handle fully defined paddings
    # That means that intermediate tensor that delivers padding
    # should have defined value and size Nx2
    # TODO possible broadcasts are not supported
    assert (padding.ndim == 2 and padding.shape[1] == 2)

    # make sure that input has the same number of dimensions as the number of padding dimensions
    assert (padding.shape[0] == len(input_shape)), \
        "Input tensor shape {} and pads values {} do not match for Pad node {}".format(
            input_shape, padding.shape, node.name
        )

    # sum low and high padding values to calculate the shape modification vector
    shape_change = np.add.reduce(padding, 1)
    assert (shape_change.shape == input_shape.shape)

    # preserve non-positive values in the input shape, because it has a special meaning
    shape = np.array(
        [shape_change[i] + input_shape[i] if input_shape[i] > 0 else input_shape[i] for i in
         range(len(input_shape))])

    assert len(node.out_nodes()) == 1

    node.out_node().shape = shape
def infer(node: Node):
    """Infer a StridedSlice node and normalize its inputs and masks.

    After ``tf_strided_slice_infer`` computes the output shape:

    - every constant input at ports 1.. is extended with
      ``extend_mask_according_ellipsis`` using the output rank (filler 1 for port 3,
      0 otherwise);
    - the five mask attributes are extended the same way;
    - the ellipsis mask is cleared.

    For NHWC graphs whose output is not constant, the masks are registered for
    permutation. When the input rank exceeds 3, the port 1.. inputs are permuted
    with ``permute_array``.
    """
    tf_strided_slice_infer(node)

    out_shape = node.out_port(0).data.get_shape()
    assert out_shape is not None, \
        'Output shape was not calculated for node {}'.format(node.name)
    # extend inputs according to ellipsis mask and/or input_shape
    for i_port in node.in_ports().values():
        if i_port.idx == 0 or i_port.disconnected():
            continue
        old_value = i_port.data.get_value()
        # additional check for non-const input
        # error will be return in shape inference if non-const will be added
        # it is paranoid check for case if shape inference will be changed
        # Implicit literal concatenation (not `+`) so `.format` fills the whole message.
        assert old_value is not None, \
            '{} input of {} node is not constant: \'value\' attribute for edge ' \
            'contains None'.format(i_port.idx, node.name)
        # insert 0 for begin and end and 1 for stride
        new_value = int64_array(extend_mask_according_ellipsis(node.ellipsis_mask, node.shrink_axis_mask,
                                                               len(out_shape), list(old_value),
                                                               int(i_port.idx == 3)))
        # set_value additionally set_shape and propagate value to Const node
        if not np.array_equal(new_value, old_value):
            i_port.data.set_value(new_value)

    # extend masks before removing ellipsis
    for attr in ["new_axis_mask", "shrink_axis_mask", "begin_mask", "end_mask", "ellipsis_mask"]:
        node[attr] = int64_array(extend_mask_according_ellipsis(node.ellipsis_mask, node.shrink_axis_mask,
                                                                len(out_shape), list(node[attr]), 0))

    # we will extend all masks and inputs to simplify future transformations
    idx = np.nonzero(node.ellipsis_mask)
    node.ellipsis_mask[idx] = 0

    if node.graph.graph['layout'] == 'NHWC' and node.out_port(0).data.get_value() is None:
        PermuteAttrs.create_permute_attrs(node, attrs=[('shrink_axis_mask', 'input:0', permute_masks),
                                                       ('new_axis_mask', 'input:0', permute_masks),
                                                       ('ellipsis_mask', 'input:0', permute_masks),
                                                       ('begin_mask', 'input:0', permute_masks),
                                                       ('end_mask', 'input:0', permute_masks),
                                                       ])
        # permute inputs
        in_shape = node.in_port(0).get_source().data.get_shape()
        assert in_shape is not None, \
            'Input shape is unknown for 0 input of node {}'.format(node.name)
        input_rank = len(in_shape)
        if input_rank > 3:
            for i_port in node.in_ports().values():
                if i_port.idx == 0 or i_port.disconnected():
                    continue
                new_value = permute_array(node, i_port.data.get_value())
                # set_value additionally set_shape and propagate value to Const node
                i_port.data.set_value(new_value)
def infer(node):
    """Infer a node with ``seq_axis``/``batch_axis`` attributes: output shape equals input shape.

    Both axis attributes must be set and are registered for permutation against input 0.
    """
    data_shape = node.in_port(0).data.get_shape()
    assert data_shape is not None

    for axis_attr in ('seq_axis', 'batch_axis'):
        assert node.has_valid(axis_attr)
    assert len(node.out_nodes()) == 1

    node.out_port(0).data.set_shape(data_shape)

    for axis_attr in ('seq_axis', 'batch_axis'):
        PermuteAttrs.create_permute_attrs(node, attrs=[(axis_attr, 'input:0')])
def concat_infer(node):
    """Infer the output shape (and constant value) of a Concat node.

    Without an ``axis`` attribute, the axis is read from the extra input at index
    ``node.N``, which must hold a single constant value, and that edge is removed.
    All data inputs must match on every dimension except the concat axis, whose
    sizes are summed. For outputs whose rank is not 4, ``axis`` is removed from
    ``node.dim_attrs``. When all inputs are constant, the value is computed with
    ``np.concatenate``.
    """
    if not node.has('axis'):
        N = node.N
        axis_input = node.in_node(N)
        if axis_input.has_valid('value') and axis_input.value.size == 1:
            node['axis'] = axis_input.value.item()
            node.graph.remove_edge(axis_input.node, node.node)  # TODO add skip attribute instead of deleting
        else:
            return
    else:
        N = len(node.in_nodes())

    shapes = [node.in_node(i).shape for i in range(N)]
    if any(s is None for s in shapes):
        return

    shape = np.array(shapes[0])

    axis = get_canonical_axis_index(shape, node.axis)
    node.axis = axis

    # Builtin `bool`: the `np.bool` alias was deprecated in NumPy 1.20 and removed in 1.24.
    mask = np.zeros_like(shape, dtype=bool)
    mask[axis] = True  # pylint: disable=unsupported-assignment-operation
    not_mask = np.logical_not(mask)  # pylint: disable=assignment-from-no-return
    for s in shapes[1:]:
        s = int64_array(s)
        if np.all(shape[not_mask] == s[not_mask]):  # TODO handle -1 in a special way
            shape[mask] += s[mask]
        else:
            log.error('Concat input shapes do not match')
            return

    node.out_node(0).shape = shape
    if len(shape) != 4:
        # exclude it from NHWC to NCHW conversion
        if 'axis' in node.dim_attrs:
            node.dim_attrs.remove('axis')

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])

    values = [node.in_node(i).value for i in range(N)]
    if any(v is None for v in values):
        return

    node.out_node(0).value = np.concatenate(values, axis=node.axis).astype(values[0].dtype, copy=False)
    node.out_node(0).shape = np.array(node.out_node(0).value.shape, dtype=np.int64)
def tf_transpose_infer(node):
    """Infer a two-input Transpose node (data at port 0, permutation at port 1).

    Once the permutation is known, its edge is removed and it is stored in
    ``node.order``. The output shape is the permuted input shape, and constant
    inputs are transposed as well. If the permutation or the input shape is
    undefined, an error is logged and the graph is left untouched.
    """
    if len(node.in_nodes()) != 2:
        log.error("Transpose should take 2 inputs")
        return

    node_inp, node_order = node.in_node(0), node.in_node(1)
    order = node_order.value
    if order is None or node_inp.shape is None:
        # Bail out before mutating the graph: indexing a shape with None would silently
        # produce a bogus shape instead of failing.
        log.error('Cannot infer Transpose node {}: input shape or order is not defined'.format(
            node.soft_get('name')))
        return

    in_shape = np.array(node_inp.shape)
    node.graph.remove_edge(node_order.node, node.node)
    node.order = np.array(order)
    node.out_node().shape = in_shape[order]
    if node_inp.has_valid('value'):
        node.out_node().value = np.transpose(node_inp.value, axes=order)

    PermuteAttrs.create_permute_attrs(node, attrs=[('order', 'input:0')])
def reorgyolo_infer(node: Node):
    """Infer the output shape of a ReorgYolo node.

    Batch dims are copied. Channel dims are multiplied by ``stride ** 2``. Spatial
    dims are divided by ``stride`` and rounded (as in caffe). Any remaining dims are
    set to -1. Channel and spatial dims are registered for permutation.
    """
    input_shape = node.in_node(0).shape
    if input_shape is None:
        return

    stride = node.stride
    batch_dims = node.batch_dims
    channel_dims = node.channel_dims
    spatial_dims = node.spatial_dims

    output_shape = np.full_like(input_shape, -1, dtype=np.int64)
    output_shape[batch_dims] = input_shape[batch_dims]  # pylint: disable=unsupported-assignment-operation
    output_shape[channel_dims] = input_shape[channel_dims] * stride ** 2  # pylint: disable=unsupported-assignment-operation
    # Round as in caffe
    output_shape[spatial_dims] = np.round(input_shape[spatial_dims] / stride)  # pylint: disable=unsupported-assignment-operation

    node.out_node().shape = output_shape
    PermuteAttrs.create_permute_attrs(node, attrs=[('channel_dims', 'input:0'), ('spatial_dims', 'input:0')])
def infer(node):
    """Infer an Unsqueeze node: insert size-1 dimensions at ``node.unsqueeze_dims``.

    Each entry is a position in the *output* shape, and negative entries count from
    the end of the output. The positions are normalized and inserted in ascending
    order, so every insertion index refers to the partially grown shape. The result
    is also stored in ``node['dim']``, and constant inputs are reshaped to it.
    """
    value = node.in_node(0).value
    shape = node.in_node(0).shape
    if shape is None:
        return

    unsqueeze_dims = np.array(node.unsqueeze_dims).reshape(-1)
    output_rank = len(shape) + unsqueeze_dims.size
    # Raw sequential np.insert mis-handles negative dims (-1 lands before the last element)
    # and raises IndexError for unsorted dims; normalizing + sorting fixes both.
    canonical_dims = sorted(int(dim) + output_rank if dim < 0 else int(dim) for dim in unsqueeze_dims)
    for dim in canonical_dims:
        shape = np.insert(shape, dim, 1)

    node.out_node().shape = np.array(shape)
    node['dim'] = shape

    PermuteAttrs.create_permute_attrs(node, attrs=[('dim', 'output:0')])

    if value is not None:
        value = np.reshape(value, shape)
        node.out_node().value = np.array(value)
def tf_split_v_infer(node: Node):
    """
    Partial infer of split node similar to SplitV op of TF.

    With three inputs ``[input, size_splits, split_dim]`` the split parameters are
    read from the constant inputs 1 and 2, whose edges are then removed. Otherwise
    they come from the ``axis`` and ``size_splits`` attributes. Returns early when a
    single-input node lacks either attribute, or when a three-input node already has
    one of them.
    """
    num_inputs = len(node.in_nodes())
    if num_inputs == 1 and not (node.has_valid('axis') and node.has_valid('size_splits')):
        return

    if num_inputs == 3 and (node.has_valid('axis') or node.has_valid('size_splits')):
        return

    # Three inputs: [input, size_splits, split_dim]
    if num_inputs == 3:
        split_dim = node.in_node(2).value
        if split_dim is None:
            # Check before `.ndim`, which would otherwise raise AttributeError on None.
            log.error('split_dim value for node {} is None. Cannot do shape inference.'.format(
                node.soft_get('name')))
            return
        assert split_dim.ndim == 0
        split_dim = split_dim.item()
        size_splits = node.in_node(1).value
        node.graph.remove_edge(node.in_node(1).id, node.id)
        node.graph.remove_edge(node.in_node(2).id, node.id)
    else:
        split_dim = node.axis
        size_splits = node.size_splits

    if split_dim is None:
        log.error('split_dim value for node {} is None. Cannot do shape inference.'.format(
            node.soft_get('name')))
        return

    input_node = node.in_node(0)
    if input_node.shape is None or size_splits is None:
        log.error('input shape or size of splits are not defined for node {}'.format(node.soft_get('name')))
        return

    log.debug('split_dim = {}, input.shape = {}, size_splits.value = {}'.format(
        split_dim, input_node.shape, size_splits))

    split(input_node, node, split_dim, size_splits)

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
def test_from4D_to3D(self):
    input_shape = np.array([1, 2, 3, 4])
    new_shape = np.array([3, 4, 2])
    nhwc_shape = np.array([1, 3, 4, 2])

    def shape_updates():
        # Fresh list/dicts for each graph so the two graphs never share attribute containers.
        return [('input_data', {'shape': input_shape}),
                ('reshape', {'dim': new_shape}),
                ('reshape_data', {'shape': new_shape})]

    graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
                                   update_nodes_attributes=shape_updates())
    graph.graph['layout'] = 'NHWC'

    # add permute attrs to reshape
    PermuteAttrs.create_permute_attrs(Node(graph, 'reshape'), attrs=[('dim', 'output:0')])

    PermuteForReshape().find_and_replace_pattern(graph)

    graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes + self.permute_nodes,
                                       edges_with_attrs=self.edges[1:] + self.permute_edges,
                                       update_nodes_attributes=shape_updates() +
                                       [('permute_data', {'shape': nhwc_shape})])

    # check graphs equality
    flag, resp = compare_graphs(graph, graph_ref, last_node='reshape_data')
    self.assertTrue(flag, resp)

    # check the order of the inserted permutation node: from NCHW to NHWC
    expected_order = np.array([0, 2, 3, 1])
    permute_order = graph.node['reshape/Permute_']['order']
    self.assertTrue(np.all(permute_order == expected_order))
def infer_for_opset1(node: Node):
    """Infer an opset1-style node that overrides selected dimensions of input 0.

    Exactly two inputs must be connected, and ``mode`` and ``axes`` must be set.
    For the i-th entry of ``axes``, the output dimension takes the i-th value of the
    constant port-1 input. All other dimensions are copied from input 0.
    """
    connected = [port for port in node.in_ports().values() if not port.disconnected()]
    assert len(connected) == 2
    assert node.has_valid('mode')
    assert node.has_valid('axes')

    src_shape = node.in_port(0).data.get_shape()
    assert src_shape is not None

    target_sizes = node.in_port(1).data.get_value()
    assert target_sizes is not None

    output_shape = src_shape.copy()
    for position, axis in enumerate(node.axes):
        output_shape[axis] = target_sizes[position]

    node.out_port(0).data.set_shape(output_shape)
    PermuteAttrs.create_permute_attrs(node, attrs=[('axes', 'input:0')])