def add_unsqueeze_for_new(graph: Graph, ss_node: Node):
    """
    Replaces the new_axis_mask semantics of a StridedSlice node with an explicit Unsqueeze.

    The new-axis bits of ``ss_node`` are cleared, its output shape is set to the shape without the new
    axes, and an Unsqueeze (whose axes come from a Const with the new-axis indices) is inserted after it
    so the final output shape stays equal to the original one.

    :param graph: graph to operate on
    :param ss_node: StridedSlice node that has the 'new_axis_mask' attribute
    """
    log.info("StridedSlice op with new axis mask '{}' has been detected".format(ss_node.id))
    # only the form with 4 inputs and a single output data node is handled
    if len(ss_node.in_nodes()) != 4 or len(ss_node.out_nodes()) != 1:
        return

    shape_out = ss_node.out_node().shape
    # the attribute is fetched by reference, so clearing bits below updates the node attribute itself
    new_axis_mask = ss_node['new_axis_mask']
    dim = mo_array(range(len(new_axis_mask)))[mo_array(new_axis_mask, dtype=bool)]

    ss_shape = []
    for i in range(len(new_axis_mask)):
        if not new_axis_mask[i]:
            ss_shape.append(shape_out[i])
        else:
            # this axis is produced by the Unsqueeze inserted below instead of by the StridedSlice
            new_axis_mask[i] = 0
    # the mask may be shorter than the output rank: the trailing dimensions are not new axes and must be
    # kept, otherwise the Unsqueeze output rank would not match shape_out (add_squeeze_for_shrink treats a
    # short mask the same way)
    ss_shape.extend(shape_out[len(new_axis_mask):])

    ss_node.out_port(0).data.set_shape(ss_shape)

    # insert Unsqueeze
    unsqueeze_node = Unsqueeze(graph, dict(name=ss_node.name + '/Unsqueeze_new')).create_node()
    ss_node.out_port(0).get_connection().insert_node(unsqueeze_node)
    unsqueeze_node.out_port(0).data.set_shape(shape_out)

    dims_node = Const(graph, {'name': unsqueeze_node.id + '/Indices', 'value': int64_array(dim)}).create_node()
    dims_node.out_port(0).connect(unsqueeze_node.in_port(1))
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Moves the 'start'/'stop'/'step' attributes of the matched node into Const inputs 0/1/2.

    If the node has a 'repeat' attribute greater than 1, an MXRepeat (axis 0, ``repeats=node.repeat``) is
    inserted after the node and takes over all of its consumers.
    Nodes missing any of 'start', 'stop' or 'step' are left untouched.
    """
    node = match['op']
    if not all(node.has_valid(attr) for attr in ('start', 'stop', 'step')):
        return

    # NOTE(review): the intended Const names are stored in 'symbol_dict' rather than 'name'; kept unchanged
    start_value = Const(graph, dict(value=node.start,
                                    symbol_dict={'name': node.id + '/const_start'})).create_node()
    limit_value = Const(graph, dict(value=node.stop,
                                    symbol_dict={'name': node.id + '/const_limit'})).create_node()
    delta_value = Const(graph, dict(value=node.step,
                                    symbol_dict={'name': node.id + '/const_delta'})).create_node()
    node.in_port(0).get_connection().set_source(start_value.out_port(0))
    node.in_port(1).get_connection().set_source(limit_value.out_port(0))
    node.in_port(2).get_connection().set_source(delta_value.out_port(0))

    if node.has_valid('repeat') and node.repeat > 1:
        rep = MXRepeat(graph, dict(name=node.id + '/mxrepeat', axis=0, repeats=node.repeat)).create_node()
        # move every consumer of the node to MXRepeat; get_destination() only supports exactly one consumer
        node.out_port(0).get_connection().set_source(rep.out_port(0))
        rep.in_port(0).connect(node.out_port(0))
def replace_op(self, graph: Graph, node: Node):
    """
    Replaces the matched node with Transpose([1, 0]) -> Gather(axis=0) -> Transpose([1, 0]).

    The Gather indexes are read from ``node.parameters`` and shifted by -1 before use (presumably they are
    stored 1-based in the model file). For a 2D input this selects/reorders columns of the input.

    :param graph: graph to operate on
    :param node: node to replace; its ``parameters`` stream holds the index count followed by the indexes
    :return: list with the id of the replacing node (the output Transpose)
    """
    pb = node.parameters
    weights_size = read_binary_integer32_token(pb)
    weights = read_blob(pb, weights_size, dtype=np.int32) - 1

    node_name = node.soft_get('name', node.id)
    const_attrs = {
        'name': node_name + '/indexes',
        'value': np.array(weights),
        'shape': [weights_size],
        'data_type': np.int32
    }
    indexes_node = Const(graph).create_node(attrs=const_attrs)

    # one [1, 0] order Const feeds both Transposes
    perm_in_1 = Const(graph, {'value': int64_array([1, 0]), 'name': node_name + '/order'}).create_node()
    # the input is connected explicitly through in_port(0).connect() below, so it is not also passed to
    # create_node(): doing both connected port 0 twice
    perm1_node = Transpose(graph, {'name': node_name + '/input_permute'}).create_node()
    perm1_node.in_port(0).connect(node.in_port(0).get_source())
    perm1_node.in_port(1).connect(perm_in_1.out_port(0))

    gather_node = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)}, {'name': node_name + '/gather'})
    gather_node.in_port(0).connect(perm1_node.out_port(0))
    gather_node.in_port(1).connect(indexes_node.out_port(0))

    perm2_node = Transpose(graph, {'name': node_name + '/output_permute'}).create_node()
    perm2_node.in_port(0).connect(gather_node.out_port(0))
    perm2_node.in_port(1).connect(perm_in_1.out_port(0))

    return [perm2_node.id]
def resolve_shared_inputs(node: Node, port_ids_to_duplicate: List[int]):
    """
    Duplicates shared constants that are consumed by more than one node.
    If constant is consumed by several ports of one node - no duplication gets done

    For every port id in ``port_ids_to_duplicate``, each *other* node consuming the same source as
    ``node.in_port(port_id)`` gets its own Const copy of the value, so ``node`` stays the only consumer of
    the original producer. Ports whose value is unknown are skipped (and logged at debug level).

    :param node: node whose inputs must not be shared with other nodes
    :param port_ids_to_duplicate: input port indices of ``node`` to process
    """
    graph = node.graph

    for port_id in port_ids_to_duplicate:
        # consumer node -> list of its input port indices fed by the same source
        dst_port_map = defaultdict(list)
        for dst in node.in_port(port_id).get_source().get_connection().get_destinations():
            dst_port_map[dst.node].append(dst.idx)
        # `node` itself keeps the original source
        dst_port_map.pop(node, None)

        value = node.in_port(port_id).data.get_value()
        if value is None:
            log.debug('Can not duplicate due no data for in_port {} of node {}'.format(
                port_id, node.soft_get('name', node.id)))
            continue

        # a separate loop variable: reusing `node` here would redirect the following port_id iterations
        # to the last consumer instead of the function argument
        for consumer, idxs in dst_port_map.items():
            const = Const(graph, {
                'value': mo_array(value),
                'name': consumer.soft_get('name', consumer.id) + '/duplicated_'}).create_node()
            for idx in idxs:
                consumer.in_port(idx).disconnect()
                const.out_port(0).connect(consumer.in_port(idx))
            const.infer(const)
def find_and_replace_pattern(self, graph: Graph):
    """
    Rewrites every Roll node without an 'axes' input (port 2) into a Roll over a flattened tensor.

    The resulting sub-graph is ``Reshape(x, [-1]) -> Roll(axes=[0]) -> Reshape(ShapeOf(x))``. The final
    Reshape takes over the original node name and the Roll is renamed to '<name>/roll'. Roll nodes that
    already have axes connected are skipped.
    """
    for roll_node in graph.get_op_nodes(op='Roll'):
        if not roll_node.in_port(2).disconnected():
            # axes already provided: skip this node but keep processing the remaining Roll nodes
            continue
        node_name = roll_node.soft_get('name', roll_node.id)

        # reshape to 1d tensor
        reshape_to_1d = create_op_node_with_second_input(graph, Reshape, int64_array([-1]),
                                                         {'name': node_name + '/reshape'})
        roll_node.in_port(0).get_connection().insert_node(reshape_to_1d)

        # add zero const as axes input to roll
        const_zero = Const(graph, {'value': int64_array([0]), 'name': node_name + '/axes'}).create_node()
        const_zero.out_port(0).connect(roll_node.in_port(2))

        # reshape to original shape
        shape_of = Shape(graph, {'name': node_name + '/shape_of'}).create_node()
        reshape_to_1d.in_port(0).get_connection().add_destination(shape_of.in_port(0))
        reshape_to_orig_shape = Reshape(graph, {}).create_node()
        rename_nodes([(roll_node, node_name + '/roll'), (reshape_to_orig_shape, node_name)])
        shape_of.out_port(0).connect(reshape_to_orig_shape.in_port(1))
        roll_node.out_port(0).get_connection().insert_node(reshape_to_orig_shape)
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
    """
    Moves the reduction axes of a Reduce node from its 'axis' attribute to input port 1.

    Only nodes with exactly one connected input are processed.
    * 'axis' holds a value (not None): it becomes a Const on port 1 and the attribute is deleted.
    * 'axis' is missing or None: the axes are computed in the graph as Range(0, Rank(input), 1),
      i.e. the reduction runs over every dimension of the input.
    """
    reduce_node = match['reduce']
    if sum(1 for port in reduce_node.in_ports().values() if not port.disconnected()) != 1:
        return

    name = reduce_node.soft_get('name', reduce_node.id)

    if not reduce_node.has_valid('axis'):
        # default: reduce over all the dimensions of the input tensor
        axes_range = create_op_with_const_inputs(graph, Range, {0: int64_array(0), 2: int64_array(1)},
                                                 dict(name=name + '/axes'))
        input_rank = Rank(graph, dict(name=name + '/range_end')).create_node()
        reduce_node.in_port(0).get_connection().get_source().connect(input_rank.in_port(0))
        input_rank.out_port(0).connect(axes_range.in_port(1))

        reduce_node.add_input_port(1, skip_if_exist=True)
        axes_range.out_port(0).connect(reduce_node.in_port(1))
        return

    axis_const = Const(graph, {'name': name + '/axis', 'value': reduce_node.axis}).create_node()
    reduce_node.add_input_port(1, skip_if_exist=True)
    axis_const.out_port(0).connect(reduce_node.in_port(1))
    del graph.node[reduce_node.id]['axis']
def replace_pattern(self, graph: Graph, match: [str, Node]):
    """
    Gives a single-input reverse-order Transpose an explicit order on input port 1.

    The order is the reversed range of the input rank (e.g. [3, 2, 1, 0] for a 4D input). It is stored
    in a Const named '<node name>/Order', and the node's 'reverse_order' flag is cleared.
    """
    node = match['transpose']
    assert len(node.in_nodes()) == 1
    # explicit int64: np.arange's default integer type is platform dependent (e.g. int32 on Windows)
    order = np.arange(len(node.in_port(0).data.get_shape()), dtype=np.int64)[::-1]
    const = Const(graph, {'value': order, 'name': node.soft_get('name', node.id) + '/Order'}).create_node()
    node.add_input_port(1, skip_if_exist=True)
    const.out_port(0).connect(node.in_port(1))
    node['reverse_order'] = False
def replace_pattern(graph: Graph, match: dict):
    """
    Replaces a matched node and its pair node (``node.pair_name``) with a ReadValue/Assign state.

    Only negative offsets ``node.t`` are supported. The state is a ReadValue initialised with zeros of
    shape [in_shape[0], in_shape[1] * |t|], where ``in_shape`` is the shape of the tensor entering the
    pair.
    * |t| > 1: the new state is Concat(Crop(state, offset=in_shape[1], dim=in_shape[1]*(|t|-1)), input)
      along axis 1, written back by an Assign. The output is the first in_shape[1] columns of the
      previous state.
    * |t| == 1: the input is assigned directly and the previous state is the output.
    The Assign output goes to a new Result. The original node, its pair, and the op node consuming the
    input-side node's output are removed.
    NOTE(review): this presumably yields the input delayed by |t| steps (one state update per inference)
    -- confirm.
    """
    node = match['op']
    pair_node = Node(graph, node.pair_name)

    # NOTE(review): the condition rejects t >= 0 although the message says "t > 0"
    if node.t >= 0:
        raise Error('Does not support IfDefined with t > 0')

    # whichever of the two nodes has a connected input is the "write" side, the other one is the "read" side
    if node.in_port(0).get_source() is not None:
        input_port = node.in_port(0).get_source()
        op_output_id = node.out_port(0).get_destination().node.id
        out_port = pair_node.out_port(0)
        node_name = node.name
        pair_name = pair_node.name
    else:
        input_port = pair_node.in_port(0).get_source()
        op_output_id = pair_node.out_port(0).get_destination().node.id
        out_port = node.out_port(0)
        node_name = pair_node.name
        pair_name = node.name

    in_shape = input_port.data.get_shape()
    node_t = abs(node.t)

    # state initial value: |t| copies of the input width, filled with zeros
    init_value_memory_out = Const(graph, {'name': 'init_value_' + pair_name,
                                          'value': np.zeros(int64_array([in_shape[0], in_shape[1]*node_t]), dtype=np.float32),
                                          'shape': int64_array([in_shape[0], in_shape[1]*node_t])}).create_node()
    # ReadValue and Assign share variable_id node_name + pair_name
    memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name+pair_name}).create_node()
    init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

    if node_t > 1:
        # new state = old state without its first in_shape[1] columns, followed by the current input
        crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': mo_array([in_shape[1]*(node_t-1)]),
                                   'offset': mo_array([in_shape[1]]), 'axis': mo_array([1])}).create_node()
        memory_out.out_port(0).connect(crop_concat.in_port(0))
        concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
        concat.add_sequence_of_ports('in', range(2))
        crop_concat.out_port(0).connect(concat.in_port(0))
        concat.in_port(1).connect(input_port)

        memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
        concat.out_port(0).connect(memory_in.in_port(0))
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))

        # output = the first in_shape[1] columns of the old state
        crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': mo_array([in_shape[1]]),
                                'offset': mo_array([0]), 'axis': mo_array([1])}).create_node()
        memory_out.out_port(0).connect(crop_out.in_port(0))
        out_port.get_connection().set_source(crop_out.out_port(0))
    else:
        # |t| == 1: the state is just the input; the output is the previously stored state
        memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
        memory_in.in_port(0).connect(input_port)
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))
        out_port.get_connection().set_source(memory_out.out_port(0))

    graph.remove_node(op_output_id)
    graph.remove_node(node.id)
    graph.remove_node(pair_node.id)
def create_ss_interval_border(graph: Graph, slice_border_port: Port, shape: np.ndarray, axes: np.ndarray, node_name: str):
    """
    Builds the StridedSlice "begin"/"end" input out of Slice "starts"/"ends" values.

    The border values are clamped to the int32 range and cast to int64. They are then scattered by a
    Concat (axis 0) so that the value for ``axes[j]`` lands at Concat input ``axes[j]``. Every remaining
    position up to ``len(shape)`` receives a placeholder 0.

    :param graph: graph to operate on
    :param slice_border_port: output port producing the Slice "starts" or "ends" values
    :param shape: input shape of the Slice (only its length is used)
    :param axes: axes the "starts"/"ends" values apply to, in the same order as those values
    :param node_name: Slice node name, used as the prefix of every created node name
    :return: Concat node producing the StridedSlice "begin"/"end" values
    """
    # int64 min/max sentinels do not fit into int32, which StridedSlice supports, so squeeze them into the int32 range
    int32_info = np.iinfo(np.int32)
    clamped = create_op_with_const_inputs(graph, Clamp,
                                          port_value_dict={1: int32_info.min, 2: int32_info.max},
                                          op_attrs=dict(name=node_name + '/Clamp'))
    clamped.in_port(0).connect(slice_border_port)

    # one common data type with the int64 Consts created below, otherwise Concat would get mixed types
    as_i64 = Cast(graph, dict(name=node_name + '/CastToI64', dst_type=np.int64)).create_node()
    as_i64.in_port(0).connect(clamped.out_port(0))

    border = Concat(graph, dict(name=node_name + '/Concat', axis=0)).create_node()

    # "axes" may be unsorted: pick each value separately and route it to the Concat input of its axis
    for position, axis in enumerate(axes):
        border.add_input_port(axis)
        picked = create_op_with_const_inputs(graph, Gather,
                                             port_value_dict={1: int64_array([position]), 2: int64_array(0)},
                                             op_attrs={'name': node_name + '/Gather'})
        as_i64.out_port(0).connect(picked.in_port(0))
        picked.out_port(0).connect(border.in_port(axis))

    for axis in range(len(shape)):
        if border.is_in_port_connected(axis):
            continue
        border.add_input_port(axis)
        # placeholder; StridedSlice ignores it because of the begin_mask/end_mask
        filler = Const(graph, dict(name=node_name + '/Const', value=int64_array([0]))).create_node()
        filler.out_port(0).connect(border.in_port(axis))

    return border
def input_as_const(node: Node, attrs: dict, port: int, bin: str, value: np.ndarray):
    """
    Feeds ``value`` into input ``port`` of ``node`` through a newly created Const node.

    ``attrs`` are merged into the Const attributes after 'value' (so they take precedence). The input
    edge gets the ``bin`` attribute set, and 'bin' is registered in the port's ``in_attrs``.
    """
    const_node = Const(node.graph, {'value': value, **attrs}).create_node()
    node.add_input_port(port, skip_if_exist=True)
    const_node.out_port(0).connect(node.in_port(port))

    target_port = node.in_port(port)
    target_port.bin = bin
    target_port.in_attrs.append('bin')
def replace_sub_graph(graph: Graph, match: dict):
    """
    Turns the node's 'embedded_inputs' into real Const inputs.

    Each ``(port index, attribute name, edge attrs)`` entry creates a Const from the named attribute and
    connects it to that port. The edge is tagged with ``edge attrs['bin']``, and the source attribute is
    deleted. Finally 'embedded_inputs' itself is removed.
    """
    op_node = match['op']
    for port_idx, attr_name, edge_attrs in op_node['embedded_inputs']:
        value_const = Const(graph, dict(value=op_node[attr_name])).create_node()
        op_node.add_input_port(port_idx, skip_if_exist=True)
        value_const.out_port(0).connect(op_node.in_port(port_idx))

        in_port = op_node.in_port(port_idx)
        in_port.bin = edge_attrs['bin']
        in_port.in_attrs.append('bin')

        del op_node[attr_name]
    del op_node['embedded_inputs']
def create_op_node_with_second_input(graph: Graph, op: callable, second_input_value: np.array, op_attrs=None,
                                     input_node=None):
    """
    Creates an ``op`` node whose input port 1 is fed by a Const holding ``second_input_value``.

    :param graph: graph to create the nodes in
    :param op: operation class, instantiated as ``op(graph, op_attrs)``
    :param second_input_value: value of the Const connected to input port 1 (named '<node name>/value')
    :param op_attrs: attributes for the new operation node
    :param input_node: optional node whose output port 0 is connected to input port 0
    :return: the created operation node
    """
    new_node = op(graph, op_attrs).create_node()
    if input_node is not None:
        input_node.out_port(0).connect(new_node.in_port(0))

    value_const = Const(graph, {'name': new_node.name + '/value', 'value': second_input_value}).create_node()
    value_const.out_port(0).connect(new_node.in_port(1))

    # outside the 'front' stage the new Const is inferred right away
    if graph.stage != 'front':
        value_const.infer(value_const)
    return new_node
def __insert_mul_node_with_coeff(node: Node, port: int, coeff: float):
    """Inserts ``Mul(input, [coeff])`` in front of input ``port`` of ``node``; does nothing when ``coeff == 1``."""
    if coeff == 1:
        return

    graph = node.graph
    multiplier = Mul(graph, {'name': node.id + '/coeff_mul'}).create_node()
    coeff_const = Const(graph, {'name': node.id + '/coeff', 'value': mo_array([coeff])}).create_node()
    node.in_port(port).get_connection().insert_node(multiplier)
    coeff_const.out_port(0).connect(multiplier.in_port(1))
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Replaces the matched node with a Gather.

    Inputs 0 and 1 are moved to the Gather as is. The 'axis' attribute (required) becomes a Const on
    Gather input 2, and all consumers are moved to the Gather, which reuses the original name.
    """
    source_op = match['op']
    op_name = source_op.soft_get('name', source_op.id)
    assert source_op.has_valid('axis')

    axis_const = Const(graph, {'name': op_name + '/axis', 'value': int64_array(source_op.axis)}).create_node()
    gather = Gather(graph, {'name': op_name}).create_node()

    for port_idx in (0, 1):
        source_op.in_port(port_idx).get_connection().set_destination(gather.in_port(port_idx))
    axis_const.out_port(0).connect(gather.in_port(2))
    source_op.out_port(0).get_connection().set_source(gather.out_port(0))
def replace_op(self, graph: Graph, node: Node):
    """Replaces the matched op with ``Pow(x, -1.)`` (reciprocal); returns the id of the new Pow node in a list."""
    prefix = node.name
    minus_one = Const(graph, dict(value=mo_array(-1.), name=prefix + '/reciprocal_pow_const_')).create_node()
    power = Pow(graph, {'name': prefix + '/reciprocal_pow_'}).create_node()

    node.in_port(0).get_connection().set_destination(power.in_port(0))
    minus_one.out_port(0).connect(power.in_port(1))
    return [power.id]
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
    """
    Converts the 'dim' attribute of a single-input Reshape into a Const on input port 1.

    Reshapes with a number of connected inputs other than one are left untouched.

    :raises Error: if the node has a single connected input and no valid (non-None) 'dim' attribute
    """
    node = match['reshape']
    connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
    if len(connected_in_ports) != 1:
        return

    # has_valid also rejects dim=None, which would otherwise produce a Const without a value
    if not node.has_valid('dim'):
        # report the node name, not its op type, so the failing node can be found
        raise Error('The `dim` attribute for node {} is not set'.format(node.soft_get('name', node.id)))

    const = Const(graph, {'value': node.dim}).create_node()
    node.add_input_port(1, skip_if_exist=True)
    const.out_port(0).connect(node.in_port(1))
    del node['dim']
def unroll_ellipsis_for_inputs(graph: Graph, node: Node, ellipsis_start: int, num_insertions: int):
    """
    Inserts ``num_insertions`` filler values at position ``ellipsis_start`` of the begin/end/strides inputs.

    Fillers are 0 for begin/end and 1 for strides. The strides input (port 3) is processed only when it
    is connected. For a non-zero ``ellipsis_start`` each input is split by VariadicSplit into
    [:ellipsis_start] and the rest, and the fillers are concatenated in between. Otherwise the fillers are
    simply prepended.

    :param graph: graph to operate on
    :param node: node whose inputs 1..3 are extended
    :param ellipsis_start: position where the fillers are inserted
    :param num_insertions: number of filler values
    """
    node_name = node.soft_get('name', node.id)
    split_needed = ellipsis_start != 0

    for port_idx, input_name in ((1, 'begin'), (2, 'end'), (3, 'strides')):
        if port_idx == 3 and not node.is_in_port_connected(3):
            continue  # strides are not connected, so there is nothing to extend

        filler = np.ones(num_insertions) if input_name == 'strides' else np.zeros(num_insertions)
        filler_node = Const(graph, {
            'name': node_name + '/const_to_unroll_{}_ellipsis'.format(input_name),
            'value': int64_array(filler)}).create_node()

        concat = Concat(graph, {'axis': 0,
                                'name': node_name + '/concat_{}'.format(input_name),
                                'in_ports_count': 3 if split_needed else 2}).create_node()

        if split_needed:
            split = create_op_with_const_inputs(graph, VariadicSplit,
                                                {1: int64_array(0), 2: int64_array([ellipsis_start, -1])},
                                                {'name': node_name + '/split_for_{}_ellipsis'.format(input_name),
                                                 'out_ports_count': 2})
            node.in_port(port_idx).get_connection().set_destination(split.in_port(0))

            concat.in_port(0).connect(split.out_port(0))
            concat.in_port(1).connect(filler_node.out_port(0))
            concat.in_port(2).connect(split.out_port(1))
        else:
            concat.in_port(0).connect(filler_node.out_port(0))
            node.in_port(port_idx).get_connection().set_destination(concat.in_port(1))

        concat.out_port(0).get_connection().set_destination(node.in_port(port_idx))
def add_squeeze_for_shrink(graph: Graph, ss_node: Node):
    """
    Replaces the shrink_axis_mask semantics of a StridedSlice node with an explicit Squeeze.

    The shrink bits of ``ss_node`` are cleared and its output shape is set to the un-squeezed shape
    (1 at every shrunk position). A Squeeze is inserted after it, with axes from a Const holding the
    shrunk indices, so the final output shape equals the original one. Trailing mask entries beyond the
    output dims are also cleared and contribute a 1 each.

    :param graph: graph to operate on (its 'layout' graph attribute is read)
    :param ss_node: StridedSlice node that has the 'shrink_axis_mask' attribute
    """
    # add Squeeze for shrink_axis_mask
    log.info("StridedSlice op with shrink mask '{}' has been detected".format(ss_node.id))

    # only the form with 4 inputs and a single output data node is handled
    if len(ss_node.in_nodes()) != 4 or len(ss_node.out_nodes()) != 1:
        return

    shape_out = ss_node.out_node().shape
    # indices of the shrunk axes -> Squeeze axes
    dim = mo_array(range(len(ss_node['shrink_axis_mask'])))[mo_array(ss_node['shrink_axis_mask'], dtype=bool)]
    ss_shape = []
    i = 0  # position in shrink_axis_mask
    k = 0  # position in shape_out

    # Don't permute reshape if channels were squeezed
    dont_permute = graph.graph['layout'] == 'NCHW'
    if graph.graph['layout'] == 'NHWC' and ss_node['shrink_axis_mask'][-1] == 1:
        dont_permute = True

    # walk the mask and the output shape together: non-shrunk positions take the next output dim,
    # shrunk positions become a 1 (and their mask bit is cleared)
    while k < len(shape_out):
        if i >= len(ss_node['shrink_axis_mask']) or not ss_node['shrink_axis_mask'][i]:
            ss_shape.append(shape_out[k])
            k = k + 1
        else:
            ss_node['shrink_axis_mask'][i] = 0
            ss_shape.append(1)
        i = i + 1

    # mask entries left after all output dims are consumed: clear them and add a 1 each
    while i < len(ss_node['shrink_axis_mask']):
        ss_node['shrink_axis_mask'][i] = 0
        ss_shape.append(1)
        i = i + 1

    ss_node.out_port(0).data.set_shape(ss_shape)

    # insert Squeeze
    squeeze_node = Squeeze(graph, dict(name=ss_node.name + '/Squeeze_shrink',
                                       nchw_layout=dont_permute,
                                       correct_data_layout=dont_permute)).create_node()
    ss_node.out_port(0).get_connection().insert_node(squeeze_node)
    squeeze_node.out_port(0).data.set_shape(shape_out)

    dims_node = Const(graph, {'name': squeeze_node.id + '/Indices', 'value': int64_array(dim)}).create_node()
    dims_node.out_port(0).connect(squeeze_node.in_port(1))
def replace_pattern(self, graph: Graph, match: [str, Node]):
    """
    Converts a single-input node with an 'order' attribute into a Transpose with an order input.

    The 'order' value is moved into a Const ('<node name>/Order') on a newly added input port 1. The node
    is then updated to a Transpose that needs shape inference, and 'order' is deleted.
    """
    node = match['op']
    assert len(node.in_ports()) == 1
    assert node.has_and_set('order')

    permutation = node.order
    node.add_input_port(1)
    order_const = Const(graph, {'value': permutation,
                                'name': node.soft_get('name', node.id) + '/Order'}).create_node()
    order_const.out_port(0).connect(node.in_port(1))

    Transpose.update_node_stat(node, {'need_shape_inference': True})
    del node['order']
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
    """
    Moves the 'order' attribute of a single-input Transpose into a Const on input port 1.

    * 'order' holds a value: Const on port 1, attribute removed.
    * 'order' is present but None: the node must have 'reverse_order' set (asserted); nothing changes.
    * otherwise: Error, because the permutation cannot be deduced.
    Transposes with a number of connected inputs other than one are not touched.
    """
    transpose = match['transpose']
    connected_count = sum(1 for port in transpose.in_ports().values() if not port.disconnected())
    if connected_count != 1:
        return

    if transpose.has_valid('order'):
        order_const = Const(graph, {'value': transpose.order}).create_node()
        transpose.add_input_port(1, skip_if_exist=True)
        order_const.out_port(0).connect(transpose.in_port(1))
        del graph.node[transpose.id]['order']
        return

    if transpose.has('order') and transpose.order is None:
        assert transpose.has_and_set('reverse_order')
        return

    raise Error('Can not deduce transpose `order` for {}: only one in_port and no `order` parameter.'
                ''.format(transpose.soft_get('name', transpose.id)))
def create_fake_quantize_node(graph: Graph, name, data_type=np.float32):
    """
    Creates a FakeQuantize node (levels=0, value propagation stopped) with four zero-valued range inputs.

    Inputs 1..4 (input_low, input_high, output_low, output_high) are scalar Consts of ``data_type`` equal
    to 0.0. Each is inferred after all of them are connected. Input 0 is left unconnected.

    :return: the created FakeQuantize node
    """
    fq = FakeQuantize(graph, {'name': name, 'levels': 0, 'stop_value_propagation': True}).create_node()

    # created in port order: input_low, input_high, output_low, output_high
    range_consts = [Const(graph, {'value': np.array(0.0, dtype=data_type)}).create_node() for _ in range(4)]

    for port_idx, range_const in enumerate(range_consts, start=1):
        range_const.out_port(0).connect(fq.in_port(port_idx))
    for range_const in range_consts:
        range_const.infer(range_const)

    return fq
def create_op_with_const_inputs(graph: Graph, op: callable, port_value_dict: Dict[int, np.array], op_attrs=None,
                                input_node=None):
    """
    Creates an ``op`` node and feeds the given input ports with Const nodes.

    :param graph: graph to create the nodes in
    :param op: operation class, instantiated as ``op(graph, op_attrs)``
    :param port_value_dict: input port index -> value. Each port is added if missing and fed by a Const
        named '<node name>_input_port_<idx>/value'
    :param op_attrs: attributes for the new operation node
    :param input_node: optional node whose output port 0 is connected to input port 0
    :return: the created operation node
    """
    created = op(graph, op_attrs).create_node()
    if input_node is not None:
        input_node.out_port(0).connect(created.in_port(0))

    # outside the 'front' stage every new Const is inferred right away
    infer_now = graph.stage != 'front'
    for port_idx, port_value in port_value_dict.items():
        created.add_input_port(port_idx, skip_if_exist=True)
        const_name = created.name + '_input_port_' + str(port_idx) + '/value'
        value_const = Const(graph, {'name': const_name, 'value': port_value}).create_node()
        value_const.out_port(0).connect(created.in_port(port_idx))
        if infer_now:
            value_const.infer(value_const)

    return created
def swap_pad_and_unsqueeze(self, pad: Node, unsqueeze: Node):
    """
    Moves ``unsqueeze`` in front of ``pad``, so the data flows source -> Unsqueeze -> Pad -> consumers.

    The inputs on ports 1 and 2 of ``pad`` (presumably pads_begin / pads_end -- confirm) are replaced by
    new Consts. These are the old values with a 0 inserted at the Unsqueeze axis, so the new dimension is
    not padded. Both nodes are marked with 'override_output_shape' because their output shapes change.

    NOTE(review): ``unsqueeze_axis.item()`` assumes the Unsqueeze axes input holds exactly one value.
    """
    # insert additional items to the pads in the position specified by the Unsqueeze axis
    unsqueeze_axis = unsqueeze.in_port(1).data.get_value()
    for port_id in [1, 2]:
        current_value = pad.in_port(port_id).get_connection().data.get_value()
        new_value_node = Const(pad.graph, {'name': pad.soft_get('name', pad.id) + '/value_{}'.format(port_id),
                                           'value': shape_insert(current_value, unsqueeze_axis.item(), 0),
                                           'override_output_shape': True}).create_node()
        pad.in_port(port_id).disconnect()
        pad.in_port(port_id).connect(new_value_node.out_port(0))

    # swap Pad and Unsqueeze layers; the statement order matters:
    # 1. detach Unsqueeze from its current producer (presumably the Pad output)
    # 2. move Pad's producer to the Unsqueeze input
    # 3. move Unsqueeze's consumers to the Pad output
    # 4. feed Pad from Unsqueeze
    unsqueeze.in_port(0).disconnect()
    pad.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
    unsqueeze.out_port(0).get_connection().set_source(pad.out_port(0))
    unsqueeze.out_port(0).connect(pad.in_port(0))

    # output shapes of Pad and Unsqueeze changed so need to recalculate them
    pad['override_output_shape'] = True
    unsqueeze['override_output_shape'] = True
def replace_pattern(self, graph: Graph, match: dict):
    """
    Replaces BiasAdd with Add.

    For NCHW BiasAdd, the 1D bias is first unsqueezed at every axis except the channel one (e.g. [C] ->
    [1, C, 1, 1] for a 4D input), so it broadcasts over the channel dimension. For any other data format
    the plain Add is kept.
    """
    bias_add = match['BiasAdd']

    # Replace BiasAdd by Add operation
    add = Add(graph, {'name': bias_add.id + '/Add'}).create_node()
    for port_idx in (0, 1):
        bias_add.in_port(port_idx).get_connection().set_destination(add.in_port(port_idx))
    bias_add.out_port(0).get_connection().set_source(add.out_port(0))

    if bias_add.data_format != 'NCHW':
        return

    input_rank = len(add.in_port(0).data.get_shape())
    assert len(add.in_port(1).data.get_shape()) == 1

    # every axis except the channel one gets a new dimension of size 1
    channel_dim = get_features_dim('NCHW', input_rank)
    unsqueeze_dims = np.delete(np.arange(input_rank), channel_dim, 0)

    unsqueeze = Unsqueeze(graph, {'name': add.id + '/BiasUnsqueeze'}).create_node()
    dims_const = Const(graph, {'name': add.id + '/Dims', 'value': unsqueeze_dims}).create_node()

    # Reconnecting nodes
    unsqueeze.in_port(1).connect(dims_const.out_port(0))
    unsqueeze['override_output_shape'] = True
    add.in_port(1).get_connection().insert_node(unsqueeze)
def replace_pattern(self, graph: Graph, match: dict):
    """
    Adds all-ones weights to an L2Normalization node that has fewer than two inputs.

    Inference Engine's Normalize layer requires weights, which MXNet models do not always provide. The
    weights are a float32 vector of ones whose length equals dimension 1 of the first input's shape. They
    are connected to input port 1, which is tagged with bin='weights'.

    Parameters
    ----------
    graph : Graph
        Graph with loaded model.
    match : dict
        Matched pattern nodes; 'l2_normalization' is the node to process.
    """
    l2_norm = match['l2_normalization']
    if len(l2_norm.in_nodes()) >= 2:
        return  # weights are already connected

    channels = l2_norm.in_node(0).shape[1]
    ones = np.full([channels], 1.0, dtype=np.float32)
    weights = Const(graph, dict(name=l2_norm['name'] + '_weights', value=ones)).create_node()

    l2_norm.add_input_port(1)
    l2_norm.in_port(1).connect(weights.out_port(0))
    l2_norm.in_port(1).bin = 'weights'
def create_bias_node(graph: Graph, src_node):
    """
    Inserts ``Add(src_node output, zeros)`` between ``src_node`` and all of its consumers.

    The zero bias has shape [1, C, 1, ...], where C is dimension 1 of ``src_node``'s output shape. Its
    dtype is taken from the node's weights when those have a defined data type, and is float32
    otherwise. The first output node of the bias Const is flagged with 'Insert_Convert_operation_after'.
    """
    logger.debug('Creating new bias for {}'.format(src_node.name))
    # remember the consumers before the output gets disconnected
    consumers = list(src_node.out_port(0).get_destinations())

    # Create Add and constant with zero bias
    out_shape = src_node.out_port(0).data.get_shape()
    bias_shape = [1] * len(out_shape)
    bias_shape[1] = out_shape[1]

    weights = get_weights_for_node(src_node)
    bias_dtype = np.float32
    if weights and weights.out_port(0).is_data_type_defined():
        bias_dtype = weights.out_port(0).get_data_type()

    zero_bias = Const(graph, {'value': np.zeros(bias_shape, dtype=bias_dtype),
                              'shape': bias_shape,
                              'need_shape_inference': True}).create_node()
    add = Add(graph, {'name': src_node.name + '/add_', 'need_shape_inference': True}).create_node()

    # Connect Const to Add node
    add.in_port(1).connect(zero_bias.out_port(0))

    # Reconnect src_node -> output to src_node -> Add -> output
    src_node.out_port(0).disconnect()
    src_node.out_port(0).get_connection().set_destination(add.in_port(0))
    for consumer in consumers:
        add.out_port(0).connect(consumer)

    zero_bias.out_node(0)['Insert_Convert_operation_after'] = True
def replace_sub_graph(graph: Graph, match: dict):
    """
    Rewrites a matched Reshape into a Flatten-like ``Reshape(x, [0, -1])`` (special_zero=True).

    The pattern is accepted only if:
    * 'const' is -1,
    * 'pack' has exactly two inputs,
    * the StridedSlice inputs 1..3 (begin/end/strides) are the constants 0, 1, 1 (presumably taking the
      batch dimension).
    Otherwise the reason is logged at debug level and the graph is left unchanged.
    """
    strided_slice_node = match['strided_slice']
    const_node = match['const']
    reshape_node = match['reshape']
    pack_node = match['pack']

    if not const_node.has_valid('value') or not is_value_is_constant(const_node.value, -1):
        log.debug('The pattern does not correspond to flatten. The second reshape dimension is not -1. It is {}'.
                  format(const_node.soft_get('value')))
        return
    if len(pack_node.in_nodes()) != 2:
        log.debug('The pattern does not correspond to flatten. The "Pack" operation produces tensor with {} items '
                  'but should produce just 2.'.format(len(pack_node.in_nodes())))
        return

    expected_values = [0, 1, 1]  # expected values to a StridedSlice to get the batch size
    for ind, expected_value in enumerate(expected_values):
        input_node = strided_slice_node.in_node(ind + 1)
        if not input_node.has_valid('value') or not is_value_is_constant(input_node.value, expected_value):
            # report the StridedSlice input port and the value of the offending input, not of the slice node itself
            log.debug('The pattern does not correspond to flatten because of the input with index {}. The value is '
                      '"{}".'.format(ind + 1, input_node.soft_get('value')))
            return

    reshape_node.in_port(1).disconnect()
    reshape_const_node = Const(graph, {'value': int64_array([0, -1]),
                                       'name': reshape_node.soft_get('name', reshape_node.id) + '/shape'}).create_node()
    reshape_node.in_port(1).connect(reshape_const_node.out_port(0))
    reshape_node['special_zero'] = True
    log.debug('The node "{}" is actually a Flatten node'.format(reshape_node.soft_get('name')))
def replace_op(self, graph: Graph, node: Node):
    """
    Decomposes the matched op into ``Log(Add(1, x))``.

    The constant 1 uses the node's 'data_type' when set (float32 otherwise). The resulting Log takes over
    the original node name and the original node is renamed to '<name>/delete'.

    :return: list with the id of the Log node that replaces ``node``
    """
    node_name = node.soft_get('name', node.id)
    const_dtype = node.data_type if node.has_valid('data_type') else np.float32

    # a scalar keeps the input rank under broadcasting; a [1]-shaped constant would turn a rank-0 input
    # into a 1D tensor
    const = Const(graph, {'value': mo_array(1, dtype=const_dtype)}).create_node()
    add = Add(graph, {'name': node_name + '/Add_'}).create_node()
    # not named `log`, which would shadow the module-level logger
    log_node = Log(graph, {'name': node_name + '/Log_'}).create_node()

    # Connect nodes: input -> Add -> Log
    const.out_port(0).connect(add.in_port(0))
    node.in_port(0).get_connection().set_destination(add.in_port(1))
    add.out_port(0).connect(log_node.in_port(0))
    rename_nodes([(node, node_name + '/delete'), (log_node, node_name)])

    # The "explicit" version of the return value is: [(out_node.id, 0)])
    return [log_node.id]
def transform_map_fn_output_concatenation(external_match: dict, internal_match: dict):
    """
    Transforms TensorFlow 2 output concatenation into use of axis attribute for output port of Loop node

    In the body, TensorListSetItem is replaced by Unsqueeze(axis 0), and the matching Result output gets
    axis=0 so the Loop concatenates per-iteration results. TensorListStack is bypassed and TensorListReserve
    is disconnected. Its second input becomes the Loop trip count, and the execution condition is replaced
    by a constant True.

    :param external_match: a match used for handling a part of the main graph responsible for output concatenation
    :param internal_match: a match used for handling a part of the body graph responsible for output concatenation
    """
    loop_node = external_match['while']
    stack_node = external_match['stack']
    list_reserve_node = external_match['reserve']
    body_graph = loop_node['body']

    tensor_list_set_item_node = internal_match['concatenation']
    tensor_list_set_item_node_name = tensor_list_set_item_node.soft_get('name', tensor_list_set_item_node.id)
    list_result_node = internal_match['concatenation_result']

    # replace TensorListSetItem with Unsqueeze and use axis attribute for corresponding Result node
    # to concatenate results from different iterations
    unsqueeze_list_element = create_op_with_const_inputs(body_graph, Unsqueeze, {1: int64_array(0)},
                                                         {'name': 'TensorListSetItemUnsqueeze'})
    tensor_list_set_item_node.in_port(2).get_connection().set_destination(unsqueeze_list_element.in_port(0))
    tensor_list_set_item_node.out_port(0).get_connection().set_source(unsqueeze_list_element.out_port(0))
    rename_nodes([(tensor_list_set_item_node, tensor_list_set_item_node_name + '/AbandonedName'),
                  (unsqueeze_list_element, tensor_list_set_item_node_name)])
    list_result_node_layer_id = list_result_node.internal_layer_id
    Loop.update_port_map_value_ext(loop_node.output_port_map, 'internal_layer_id', list_result_node_layer_id,
                                   'axis', 0)

    # remove TensorListStack to by-pass the node since the result from the Loop node is already concatenated
    stack_node.out_port(0).get_connection().set_source(stack_node.in_port(0).get_connection().get_source())

    # disconnect ListReserve node because it is no longer needed for Loop
    list_reserve_node.out_port(0).disconnect()

    # connect a number of iterations with trip count that can be received from the second input of ListReserve
    # create a constant network with True value for execution_condition so that IE can ignore execution condition
    # and perform trip_counts iterations. This approach with known trip count value allows to avoid dynamism.
    loop_node.in_port(1).disconnect()
    list_reserve_node.in_port(1).get_source().connect(loop_node.in_port(1))
    for record in loop_node.output_port_map:
        if 'purpose' in record and record['purpose'] == 'execution_condition':
            exec_cond_layer_id = record['internal_layer_id']
            exec_cond_node = Loop.get_body_node_by_internal_id(loop_node, exec_cond_layer_id)
            # np.bool_ instead of the np.bool alias, which was removed in NumPy 1.24 (AttributeError there)
            const_true = Const(body_graph, {'value': np.array(True, dtype=np.bool_)}).create_node()
            exec_cond_node.in_port(0).get_connection().set_source(const_true.out_port(0))
def find_and_replace_pattern(self, graph: Graph):
    """
    Converts every global Pooling node into a Reduce operation (keep_dims=True).

    The Reduce class is chosen by ``self.pool_method_to_reduce_type[pool_method]``. The reduction axes
    are computed in the graph as Range(2, Rank(input), 1), i.e. every dimension from index 2 onward. NHWC
    layout is not supported (asserted).
    """
    global_poolings = graph.get_op_nodes(type='Pooling', global_pool=True)
    if not global_poolings:
        return

    layout = graph.graph['layout']
    assert layout != 'NHWC', 'Global pooling transformation depends on layout (NHWC not enabled)'

    for pooling in global_poolings:
        name = pooling.soft_get('name', pooling.id)
        assert pooling.has_valid('pool_method'), 'Global Pooling {} has no `pool_method` attribute'.format(name)
        method = pooling['pool_method']
        assert method in self.pool_method_to_reduce_type, \
            'Unexpected Global Pooling method `{}` for node `{}`'.format(method, name)
        reduce_cls = self.pool_method_to_reduce_type[method]

        # the Reduce takes over the consumers and the data source of the pooling
        reduce_node = reduce_cls(graph, {'name': name + '/reduce', 'keep_dims': True}).create_node()
        pooling.out_port(0).get_connection().set_source(reduce_node.out_port(0))
        data_source = pooling.in_port(0).get_connection().get_source()
        reduce_node.in_port(0).get_connection().set_source(data_source)

        # reduction axes: Range(2, Rank(input), 1)
        range_start = Const(graph, {'value': int64_array(2)}).create_node()
        input_rank = Rank(graph, {'name': name + '/input_rank'}).create_node()
        range_step = Const(graph, {'value': int64_array(1)}).create_node()
        axes = Range(graph, {'name': name + '/global_pooling_reduce_axis'}).create_node()

        axes.in_port(0).connect(range_start.out_port(0))
        data_source.connect(input_rank.in_port(0))
        axes.in_port(1).connect(input_rank.out_port(0))
        axes.in_port(2).connect(range_step.out_port(0))
        axes.out_port(0).connect(reduce_node.in_port(1))

        log.debug('Global {} pooling was converted to reduce: `{}`'.format(method, name))