def unroll_ellipsis_for_inputs(graph: Graph, node: Node, ellipsis_start: int, num_insertions: int):
    node_name = node.soft_get('name', node.id)

    for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
        if i == 3 and not node.is_in_port_connected(3):
            continue  # no need to extend strides if they are not connected

        blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions)
        blank_values_node = Const(graph, {'name': node_name + '/const_to_unroll_{}_ellipsis'.format(input_name),
                                          'value': int64_array(blank_values_arr)}).create_node()

        concat_in_ports_count = 3 if ellipsis_start != 0 else 2
        concat = Concat(graph, {'axis': 0, 'name': node_name + '/concat_{}'.format(input_name),
                                'in_ports_count': concat_in_ports_count}).create_node()

        if ellipsis_start != 0:
            split = create_op_with_const_inputs(graph, VariadicSplit,
                                                {1: int64_array(0), 2: int64_array([ellipsis_start, -1])},
                                                {'name': node_name + '/split_for_{}_ellipsis'.format(input_name),
                                                 'out_ports_count': 2})
            node.in_port(i).get_connection().set_destination(split.in_port(0))

            concat.in_port(0).connect(split.out_port(0))
            concat.in_port(1).connect(blank_values_node.out_port(0))
            concat.in_port(2).connect(split.out_port(1))
        else:
            concat.in_port(0).connect(blank_values_node.out_port(0))
            node.in_port(i).get_connection().set_destination(concat.in_port(1))

        concat.out_port(0).get_connection().set_destination(node.in_port(i))
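# Illustrative sketch (a hypothetical helper, not part of the pass above): assuming an
# ellipsis that starts at index 1 and expands into two axes, `begin`/`end` receive zeros
# and `strides` receives ones at the ellipsis position, which is exactly what the
# Const/VariadicSplit/Concat wiring computes on the graph level. Plain NumPy stand-in:
def _sketch_unroll_ellipsis():
    import numpy as np
    begin, strides = np.array([0, 5]), np.array([1, 2])
    ellipsis_start, num_insertions = 1, 2
    blank_begin = np.zeros(num_insertions, dtype=np.int64)   # zeros for begin/end
    blank_strides = np.ones(num_insertions, dtype=np.int64)  # ones for strides
    new_begin = np.concatenate([begin[:ellipsis_start], blank_begin, begin[ellipsis_start:]])
    new_strides = np.concatenate([strides[:ellipsis_start], blank_strides, strides[ellipsis_start:]])
    assert list(new_begin) == [0, 0, 0, 5] and list(new_strides) == [1, 1, 1, 2]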
def extend_inputs(node: Node, num_insertions: int):
    graph = node.graph
    node_name = node.soft_get('name', node.id)

    for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
        if i == 3 and not node.is_in_port_connected(3):
            continue  # no need to extend strides if they are not connected

        blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions)
        blank_values_node = Const(graph, {'name': node_name + '/extend_{}_const'.format(input_name),
                                          'value': int64_array(blank_values_arr)}).create_node()

        if node.in_port(i).get_source().node.soft_get('type') == 'Concat':
            # concat already exists
            concat = node.in_port(i).get_source().node
            last_in_port = max(concat.in_ports().keys())
            assert not concat.in_port(last_in_port).disconnected(), \
                'The last in_port of Concat node {} should be connected'.format(concat.soft_get('name', node.id))

            concat.add_input_port(last_in_port + 1)
            concat.in_port(last_in_port + 1).connect(blank_values_node.out_port(0))
        else:
            # have to create concat
            concat = Concat(graph, {'axis': 0, 'name': node_name + '/concat_{}'.format(input_name),
                                    'in_ports_count': 2}).create_node()
            node.in_port(i).get_connection().set_destination(concat.in_port(0))
            concat.in_port(1).connect(blank_values_node.out_port(0))
            concat.out_port(0).get_connection().set_destination(node.in_port(i))
def replace_pattern(graph: Graph, match: dict):
    node = match['op']
    pair_node = Node(graph, node.pair_name)

    if node.t >= 0:
        raise Error('Does not support IfDefined with t > 0')

    if node.in_port(0).get_source() is not None:
        input_port = node.in_port(0).get_source()
        op_output_id = node.out_port(0).get_destination().node.id
        out_port = pair_node.out_port(0)
        node_name = node.name
        pair_name = pair_node.name
    else:
        input_port = pair_node.in_port(0).get_source()
        op_output_id = pair_node.out_port(0).get_destination().node.id
        out_port = node.out_port(0)
        node_name = pair_node.name
        pair_name = node.name

    in_shape = input_port.data.get_shape()
    node_t = abs(node.t)

    init_value_memory_out = create_zero_value_with_batch_from_input(input_port, in_shape[1] * node_t)
    memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name + pair_name}).create_node()
    init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

    if node_t > 1:
        crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': np.array([in_shape[1] * (node_t - 1)]),
                                   'offset': np.array([in_shape[1]]), 'axis': np.array([1])}).create_node()
        memory_out.out_port(0).connect(crop_concat.in_port(0))

        concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
        concat.add_sequence_of_ports('in', range(2))
        crop_concat.out_port(0).connect(concat.in_port(0))
        concat.in_port(1).connect(input_port)

        memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
        concat.out_port(0).connect(memory_in.in_port(0))
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))

        crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': np.array([in_shape[1]]),
                                'offset': np.array([0]), 'axis': np.array([1])}).create_node()
        memory_out.out_port(0).connect(crop_out.in_port(0))
        out_port.get_connection().set_source(crop_out.out_port(0))
    else:
        memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
        memory_in.in_port(0).connect(input_port)
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))
        out_port.get_connection().set_source(memory_out.out_port(0))

    graph.remove_node(op_output_id)
    graph.remove_node(node.id)
    graph.remove_node(pair_node.id)
def replace_tdnn(self, graph: Graph, tdnn_node: Node):
    tdnn_name = tdnn_node.soft_get('name', tdnn_node.id)

    concat_node = Concat(graph, {'axis': 1}).create_node()
    rename_nodes([(tdnn_node, tdnn_name + '/to_be_removed'), (concat_node, tdnn_name)])

    for offset_ind, t in enumerate(tdnn_node['time_offsets']):
        concat_node.add_input_port(offset_ind)
        if t != 0:
            memory_name = tdnn_name + '/MemoryOffset/' + str(abs(t))
            memoryoffset_node = MemoryOffset(graph, {'name': memory_name, 't': t,
                                                     'pair_name': memory_name + '_out',
                                                     'has_default': False, 'splitted': False}).create_node()

            tdnn_node.in_port(0).get_source().connect(memoryoffset_node.in_port(0))
            memoryoffset_node.out_port(0).connect(concat_node.in_port(offset_ind))
        else:
            # 0 time delay is not allowed in IE, it's meaningless
            # if time offset is 0 then connect input of tdnncomponent directly to Concat without memoryoffset
            tdnn_node.in_port(0).get_source().connect(concat_node.in_port(offset_ind))

    weights = tdnn_node['weights']
    fc_inputs = {1: weights}

    bias_term = False
    if tdnn_node.has_valid('biases'):
        assert len(tdnn_node['biases']) == weights.shape[0]
        fc_inputs.update({2: tdnn_node['biases']})
        bias_term = True

    fc_node = create_op_with_const_inputs(graph, FullyConnected, fc_inputs,
                                          {'name': tdnn_name + '/FC', 'out-size': weights.shape[0],
                                           'transpose_weights': True, 'bias_term': bias_term})

    concat_node.out_port(0).connect(fc_node.in_port(0))
    tdnn_node.in_port(0).disconnect()
    tdnn_node.out_port(0).get_connection().set_source(fc_node.out_port(0))
def replace_pattern(graph: Graph, match: dict):
    node = match['tile']
    name = node.soft_get('name', node.id)

    input_shape = node.in_port(0).data.get_shape()
    assert input_shape is not None
    tiles = node.in_port(1).data.get_value()
    assert tiles is not None, "Undefined `repeats` (1st port input value) of Tile node '{}'".format(name)

    if input_shape.size == tiles.size:
        return

    if input_shape.size < tiles.size:
        # align the input rank with the repeats by unsqueezing leading axes
        unsqueeze = create_op_node_with_second_input(graph, Unsqueeze,
                                                     int64_array(list(range(tiles.size - input_shape.size))),
                                                     {'name': name + '/input_alignment',
                                                      'override_output_shape': True})
        node.in_port(0).get_source().connect(unsqueeze.in_port(0))
        node.in_port(0).get_connection().set_source(unsqueeze.out_port(0))
    else:
        # align the repeats with the input rank by prepending ones
        const = Const(graph, {'name': name + '/tile_alignment_const',
                              'value': np.ones([input_shape.size - tiles.size], dtype=np.int64)}).create_node()
        concat = Concat(graph, {'axis': 0, 'override_output_shape': True}).create_node()
        concat.add_input_port(0)
        concat.add_input_port(1)

        node.in_port(1).get_source().connect(concat.in_port(1))
        node.in_port(1).disconnect()
        concat.in_port(0).connect(const.out_port(0))
        node.in_port(1).connect(concat.out_port(0))
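# Sketch of the rank-alignment rule this pass implements, in plain NumPy (a hypothetical
# helper for illustration): np.tile itself left-pads whichever of (input rank, repeats) is
# shorter with ones, which is what the Unsqueeze / Concat-with-ones branches reproduce.
def _sketch_tile_alignment():
    import numpy as np
    x = np.arange(6).reshape(2, 3)
    # repeats shorter than input rank: padded with leading ones (the else-branch)
    assert np.tile(x, [2]).shape == np.tile(x, [1, 2]).shape
    # input rank shorter than repeats: input unsqueezed with leading axes (the if-branch)
    assert np.tile(x, [2, 1, 1]).shape == np.tile(x[np.newaxis], [2, 1, 1]).shape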
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    concat_node = match['concat']
    concat_node['axis'] = 1
    concat_name = concat_node.soft_get('name', concat_node.id)

    concat_reshape = create_op_node_with_second_input(graph, Reshape, int64_array([1, 2, -1]),
                                                      op_attrs=dict(name=concat_name + '/Reshape'))
    split_node = create_op_node_with_second_input(graph, Split, int64_array(1),
                                                  op_attrs=dict(name=concat_name + '/Split', num_splits=2),
                                                  input_node=concat_reshape)
    split_node_reshape = create_op_node_with_second_input(graph, Reshape, int64_array([-1, 4]),
                                                          op_attrs=dict(name=split_node.name + '/Reshape'))
    split_node.out_port(0).connect(split_node_reshape.in_port(0))
    value = create_op_node_with_second_input(graph, Split, int64_array(1),
                                             op_attrs=dict(name=split_node_reshape.name + '/Split', num_splits=4),
                                             input_node=split_node_reshape)

    xmin, xmax = calculate_prior_box_value(value, value_to_div=value.out_port(2), value_to_add=value.out_port(0))
    ymin, ymax = calculate_prior_box_value(value, value_to_div=value.out_port(3), value_to_add=value.out_port(1))

    concat_slice_value = Concat(graph, dict(name=value.name + '/Concat', in_ports_count=4, axis=1)).create_node()
    for ind, node in enumerate([xmin, ymin, xmax, ymax]):
        concat_slice_value.in_port(ind).connect(node.out_port(0))

    reshape_concat_values = create_op_node_with_second_input(graph, Reshape, int64_array([1, 1, -1]),
                                                             op_attrs=dict(name=concat_slice_value.name + '/Reshape'),
                                                             input_node=concat_slice_value)
    concat = Concat(graph, dict(name=reshape_concat_values.name + '/Concat', in_ports_count=2, axis=1)).create_node()
    concat.in_port(0).connect(reshape_concat_values.out_port(0))
    concat.in_port(1).connect(split_node.out_port(1))

    match['detection_output'].in_port(2).get_connection().set_source(concat.out_port(0))
    concat_node.out_port(0).get_connection().set_destination(concat_reshape.in_port(0))
def replace_with_split_concat(node):
    graph = node.graph
    name = node.soft_get('name', node.id)
    axis = node.axis
    order = node.order

    split = create_op_with_const_inputs(graph, Split, {1: int64_array(axis)},
                                        {'name': name + '/Split', 'num_splits': order.size})
    concat = Concat(graph, {'name': name + '/Concat', 'axis': axis, 'in_ports_count': order.size}).create_node()

    for out_port_idx, in_port_idx in enumerate(order):
        split.out_port(out_port_idx).connect(concat.in_port(in_port_idx))

    node.out_port(0).get_connection().set_source(concat.out_port(0))
    node.in_port(0).get_connection().set_destination(split.in_port(0))

    graph.remove_node(node.id)
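# Sketch of the Split->Concat rewiring semantics above, with assumed example values (a
# hypothetical helper, not part of the pass): chunk i of the split input is routed to
# Concat input port order[i], i.e. it lands at position order[i] in the output.
def _sketch_split_concat_reorder():
    import numpy as np
    x, order = np.array([10, 20, 30, 40]), np.array([2, 0, 3, 1])
    parts = np.split(x, order.size)
    out = [None] * order.size
    for out_port_idx, in_port_idx in enumerate(order):
        out[in_port_idx] = parts[out_port_idx]
    assert list(np.concatenate(out)) == [20, 40, 10, 30]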
def replace_sub_graph(self, graph: Graph, match: dict):
    node = match['mxreshape']

    input_index = 0
    reshape_index = 0
    shape_node = Shape(graph, dict(name=node.id + '/ShapeMXReshape')).create_node()
    shape_node.in_port(0).connect(node.in_port(0).get_source())
    output_dims_nodes = []
    for d in node.dim:
        if reshape_index < len(node.dim):
            input_index, reshape_index, output_dims_nodes = self.resolve(
                input_index, reshape_index, node.dim, shape_node, output_dims_nodes)

    concat_node = Concat(shape_node.graph, dict(name=shape_node.id + '/ConcatMXReshape_',
                                                axis=0, in_ports_count=len(output_dims_nodes))).create_node()

    for in_port_index, dim_node in enumerate(output_dims_nodes):
        concat_node.in_port(in_port_index).connect(dim_node.out_port(0))

    reshape_node = Reshape(graph, dict(name=node.id + '/Reshape_')).create_node()
    reshape_node.in_port(1).connect(concat_node.out_port(0))

    node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
    node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
def append_variances(priors_scale_node: Node, variance: list):
    graph = priors_scale_node.graph
    name = priors_scale_node.name

    sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
    priors_scale_node.out_port(0).connect(sp_shape.in_port(0))

    begin = Const(graph, {'value': np.array([-2])}).create_node()
    end = Const(graph, {'value': np.array([-1])}).create_node()
    stride = Const(graph, {'value': np.array([1])}).create_node()
    shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim', 'begin_mask': np.array([1]),
                                                 'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
                                                 'shrink_axis_mask': np.array([0]),
                                                 'ellipsis_mask': np.array([0])}).create_node()
    sp_shape.out_port(0).connect(shape_part_for_tiling.in_port(0))
    begin.out_port(0).connect(shape_part_for_tiling.in_port(1))
    end.out_port(0).connect(shape_part_for_tiling.in_port(2))
    stride.out_port(0).connect(shape_part_for_tiling.in_port(3))

    concat_value = Const(graph, {'value': np.array([4])}).create_node()
    shape_concat = Concat(graph, {'name': name + '/shape_for_tiling', 'in_ports_count': 2,
                                  'axis': np.array(0)}).create_node()
    shape_part_for_tiling.out_port(0).connect(shape_concat.in_port(0))
    concat_value.out_port(0).connect(shape_concat.in_port(1))

    variance = Const(graph, {'name': name + '/variance', 'value': np.array(variance)}).create_node()
    tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
    variance.out_port(0).connect(tile.in_port(0))
    shape_concat.out_port(0).connect(tile.in_port(1))

    reshape_dim = Const(graph, {'value': int64_array([-1, 4])}).create_node()
    sp_reshape = Reshape(graph, {'name': name + '/reshape'}).create_node()
    sp_reshape.in_port(0).connect(priors_scale_node.out_port(0))
    sp_reshape.in_port(1).connect(reshape_dim.out_port(0))

    concat = Concat(graph, {'name': name + '/priors_concat', 'axis': np.array(0),
                            'in_ports_count': 2}).create_node()
    sp_reshape.out_port(0).connect(concat.in_port(0))
    tile.out_port(0).connect(concat.in_port(1))

    output_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
    output_node = Reshape(graph, {'name': name + '/3D_priors_wth_variances'}).create_node()
    concat.out_port(0).connect(output_node.in_port(0))
    output_dims.out_port(0).connect(output_node.in_port(1))

    return output_node
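# Shape-level sketch of what append_variances builds, assuming N = 2 prior boxes of 4
# coordinates and a 4-element variance (a hypothetical helper for illustration): priors
# flattened to [N, 4], variances broadcast to the same [N, 4], stacked along axis 0 and
# reshaped to the [1, 2, N*4] layout expected by DetectionOutput.
def _sketch_append_variances():
    import numpy as np
    priors = np.random.rand(2, 4)                 # N = 2 prior boxes
    variance = np.array([0.1, 0.1, 0.2, 0.2])
    tiled = np.broadcast_to(variance, priors.shape)
    out = np.concatenate([priors, tiled], axis=0).reshape(1, 2, -1)
    assert out.shape == (1, 2, 8)                 # row 0: priors, row 1: variances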
def replace_pattern(self, graph: Graph, match: dict):
    node = match['node']
    node_name = node.soft_get('name', node.id)

    connected_ports = [port for port in node.in_ports().values() if not port.disconnected()]
    if len(connected_ports) == 2:
        axis = node.in_port(1).data.get_value()
    else:
        axis = node.axis

    assert axis is not None, 'The "axis" should be defined for node "{}"'.format(node_name)
    assert node.has_and_set('output_type'), 'The data type is not set for node "{}"'.format(node_name)

    topk_mode = 'max' if node.op == 'ArgMax' else 'min'
    topk_node = TopK(graph, {'axis': axis, 'mode': topk_mode, 'sort': 'index',
                             'remove_values_output': node.has_and_set('remove_values_output'),
                             'index_element_type': node.output_type}).create_node()
    node.in_port(0).get_connection().set_destination(topk_node.in_port(0))

    if node.has_and_set('out_max_val'):  # in this mode the ArgMax produces tuples (max_ind, max_value)
        concat_node = Concat(graph, {'axis': 1, 'name': node.name + '/Concat'}).create_node()
        concat_node.add_input_port(0, skip_if_exist=True)
        concat_node.add_input_port(1, skip_if_exist=True)
        topk_node.out_port(0).connect(concat_node.in_port(1))  # indices
        topk_node.out_port(1).connect(concat_node.in_port(0))  # values
        if not node.out_port(0).disconnected():
            node.out_port(0).get_connection().set_source(concat_node.out_port(0))
    else:
        if not node.out_port(0).disconnected():
            node.out_port(0).get_connection().set_source(topk_node.out_port(1))

    topk_node.in_port(1).connect(Const(graph, {'name': node.soft_get('name') + '/TopK',
                                               'value': node.top_k}).create_node().out_port(0))

    graph.remove_nodes_from([node.id, node.out_node(0).id])
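# Sketch of the ArgMax -> TopK(k=1) equivalence this pass relies on, in NumPy terms (a
# hypothetical helper with assumed values): the index output of top-1 along an axis
# matches argmax along that axis.
def _sketch_argmax_as_top1():
    import numpy as np
    x = np.array([[0.1, 0.9, 0.3], [0.7, 0.2, 0.5]])
    top1_idx = np.argsort(-x, axis=1)[:, 0]       # TopK indices with k=1, mode='max'
    assert np.array_equal(top1_idx, np.argmax(x, axis=1))  # -> [1, 0]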
def fuse_reduces(first_reduce, second_reduce):
    first_reduce_name = first_reduce.soft_get('name', first_reduce.id)
    second_reduce_name = second_reduce.soft_get('name', second_reduce.id)
    reduce_type = first_reduce.type

    assert first_reduce.type == second_reduce.type

    if len(first_reduce.out_port(0).get_destinations()) != 1:
        # data dependency
        return

    if first_reduce.keep_dims != second_reduce.keep_dims:
        return

    first_axes = first_reduce.in_port(1).data.get_value()
    second_axes = second_reduce.in_port(1).data.get_value()
    if first_axes is None or second_axes is None:
        # dynamic axes merging is not supported
        return

    if not first_reduce.keep_dims:
        if not np.all(first_axes > second_axes):
            # indexing of upper reduce input dimensions changed
            return

    graph = second_reduce.graph

    new_axes = Concat(graph, {'name': second_reduce_name + '/Axes', 'axis': int64_array(0),
                              'in_ports_count': 2, 'override_output_shape': True}).create_node()
    new_axes.in_port(0).connect(first_reduce.in_port(1).get_source())
    new_axes.in_port(1).connect(second_reduce.in_port(1).get_source())

    first_reduce.in_port(0).get_source().node['need_shape_inference'] = True
    first_reduce.in_port(0).get_source().node['override_output_shape'] = True

    second_reduce.in_port(1).get_connection().set_source(new_axes.out_port(0))

    first_reduce.out_port(0).get_connection().set_source(first_reduce.in_port(0).get_connection().get_source())
    first_reduce.in_port(1).disconnect()
    graph.remove_node(first_reduce.id)

    log.debug('{0} nodes {1} and {2} were fused to a single {2} node with updated axes input'
              ''.format(reduce_type, first_reduce_name, second_reduce_name))
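# Numeric sketch of why concatenating the axes inputs is sound, assuming keep_dims=True
# reduces (the keep_dims=False case additionally needs the axis-shift check above). A
# hypothetical helper, not part of the pass:
def _sketch_fused_reduce():
    import numpy as np
    x = np.random.rand(2, 3, 4)
    two_step = x.sum(axis=1, keepdims=True).sum(axis=2, keepdims=True)
    one_step = x.sum(axis=(1, 2), keepdims=True)
    assert np.allclose(two_step, one_step)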
def replace_pattern(graph: Graph, match: dict):
    node = match['pool']

    if node.pool_step is None:
        node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])

    # create Reshape before pooling
    # shape = [in_shape[0], pool_stride, 1, in_shape[1]/pool_stride]
    shape = Shape(graph, {}).create_node()
    shape.in_port(0).connect(node.in_port(0).get_source())

    split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])},
                                        {'out_ports_count': 2}, shape)
    node_pool_stride = Const(graph, {'value': int64_array([node.pool_stride])}).create_node()
    pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]))
    pow_node.in_port(0).connect(node_pool_stride.out_port(0))

    mul = Mul(graph, {}).create_node()
    mul.in_port(0).connect(split.out_port(1))
    mul.in_port(1).connect(pow_node.out_port(0))

    const_1 = Const(graph, {'value': int64_array([1])}).create_node()

    concat = Concat(graph, {'in_ports_count': 4, 'axis': 0}).create_node()
    concat.in_port(0).connect(split.out_port(0))
    concat.in_port(3).connect(mul.out_port(0))
    concat.in_port(2).connect(const_1.out_port(0))
    concat.in_port(1).connect(node_pool_stride.out_port(0))

    reshape_in = Reshape(graph, {'name': '/Reshape/' + node.name}).create_node()
    reshape_in.in_port(1).connect(concat.out_port(0))

    # create Reshape after pooling
    reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                   {'name': node.name + '/Reshape/'})

    # connect input_reshape_node
    source = node.in_port(0).get_source()
    node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
    reshape_in.in_port(0).connect(source)
    # connect output_reshape_node
    node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
    node.out_port(0).connect(reshape_out.in_port(0))
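# Sketch of the shape arithmetic assembled by the Shape/VariadicSplit/Pow/Mul/Concat
# subgraph above, with assumed example numbers (a hypothetical helper): the Pow(-1) + Mul
# pair implements division of the flattened tail dimension by pool_stride.
def _sketch_pool_reshape_shape():
    in_shape, pool_stride = [8, 96], 4
    # concat port order: [split.out(0), pool_stride, 1, in_shape[1] * pool_stride**-1]
    target = [in_shape[0], pool_stride, 1, in_shape[1] * pool_stride ** -1]
    assert target == [8, 4, 1, 24.0]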
def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision=float):
    # note: the builtin float replaces np.float, a removed NumPy alias for the same type
    # create init_graph connected to ReadValue
    graph = input_out_port.node.graph
    input_name = input_out_port.node.name

    shape_of_input = Shape(graph, {'name': 'shape/' + input_name}).create_node()
    shape_of_input.in_port(0).connect(input_out_port)

    dim_for_get_batch = Const(graph, {'name': 'dim/crop_batch/' + shape_of_input.name,
                                      'value': int64_array([1]), 'shape': int64_array([1])}).create_node()
    get_batch = Crop(graph, {'name': 'crop_batch/' + shape_of_input.name,
                             'axis': int64_array([0]), 'offset': int64_array([0])}).create_node()
    get_batch.in_port(0).connect(shape_of_input.out_port(0))
    get_batch.in_port(1).connect(dim_for_get_batch.out_port(0))

    mem_shape_2nd_dim = Const(graph, {'name': 'gifo_r_weights_shape/' + input_name,
                                      'value': int64_array([second_dim]), 'shape': int64_array([1])}).create_node()
    mem_shape = Concat(graph, {'name': 'gather_memory_shape/' + input_name,
                               'axis': 0, 'in_ports_count': 2}).create_node()
    mem_shape.in_port(0).connect(get_batch.out_port(0))
    mem_shape.in_port(1).connect(mem_shape_2nd_dim.out_port(0))

    fill_value = Const(graph, {'name': 'fill_value/' + input_name,
                               'value': np.array([0.0], precision), 'shape': int64_array([1])}).create_node()
    init_value_prev_lstm_output = Broadcast(graph, {'name': 'init_value/' + input_name}).create_node()
    init_value_prev_lstm_output.in_port(0).connect(fill_value.out_port(0))
    init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))

    return init_value_prev_lstm_output
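# Sketch of the value this helper materializes, assuming an input with batch size 8 and
# second_dim = 16 (a hypothetical helper for illustration): a [batch, second_dim] tensor
# of zeros whose batch dimension tracks the input at runtime.
def _sketch_zero_state():
    import numpy as np
    input_tensor = np.random.rand(8, 40)
    batch = input_tensor.shape[0]                 # Shape + Crop in the graph above
    state = np.broadcast_to(np.array([0.0], dtype=np.float32), (batch, 16))  # Broadcast
    assert state.shape == (8, 16) and not state.any()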
def replace_pattern(self, graph: Graph, match: dict):
    concat_node = match['concat']
    sources_of_ports = [concat_node.in_port(i).get_connection().get_source() for i in concat_node.in_ports()]
    # If 'concat' is ConcatV2 layer from TF, then this layer initially had input 'axis' as the last input.
    # But then this input was deleted and the attribute 'axis' was added. Hence, the last port source can
    # be None in such case.
    sources_of_ports = [s for s in sources_of_ports if s is not None]

    input_nodes = [s.node for s in sources_of_ports]
    if not all(n.has_valid('type') for n in input_nodes):
        return

    saved_ports = []
    disconnected_ports = []

    for port_num, node in enumerate(input_nodes):
        if node.soft_get('type') == 'Const' and len(node.shape) > 1 and any(i == 0 for i in node.shape):
            disconnected_ports.append(port_num)
        else:
            saved_ports.append(port_num)

    if not saved_ports or not disconnected_ports:
        return

    if len(saved_ports) == 1:
        before_concat = concat_node.in_port(saved_ports[0]).get_connection().get_source()
        concat_node.out_port(0).get_connection().set_source(before_concat)
        return

    new_concat_attrs = concat_node.attrs().copy()
    new_concat_attrs['name'] = concat_node.name + '/Concat_'
    new_concat_attrs['in_ports_count'] = len(saved_ports)
    new_concat_node = Concat(graph, attrs=new_concat_attrs).create_node()

    for new_port_num, old_port_num in enumerate(saved_ports):
        concat_node.in_port(old_port_num).get_connection().set_destination(new_concat_node.in_port(new_port_num))

    for p in disconnected_ports:
        concat_node.in_port(p).disconnect()

    concat_node.out_port(0).get_connection().set_source(new_concat_node.out_port(0))
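# Sketch of why zero-sized Const inputs can be dropped from Concat (a hypothetical helper
# with assumed values): concatenating an empty tensor is an identity, so only the "saved"
# ports affect the result.
def _sketch_drop_empty_concat_inputs():
    import numpy as np
    a, empty, b = np.ones((2, 4)), np.zeros((0, 4)), np.full((3, 4), 2.0)
    assert np.array_equal(np.concatenate([a, empty, b]), np.concatenate([a, b]))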
def replace_pattern(self, graph: Graph, match: dict):
    merge = match['merge']
    power = Pow(graph, {'name': merge.name + '/reciprocal_', 'type': 'PNORM'}).create_node()
    const1 = Const(graph, {'value': -1.0, 'name': merge.name + '/negate_const'}).create_node()
    merge.in_port(0).get_connection().set_destination(power.in_port(0))
    const1.out_port(0).connect(power.in_port(1))

    concat_node = Concat(graph, {'axis': 0, 'name': merge.name + '/Concat_',
                                 'override_output_shape': True}).create_node()
    const3 = Const(graph, {'name': merge.name + '/const_reduce', 'value': 0}).create_node()

    for ii, idx in enumerate(range(merge.significant, merge.to_significant + 1, 1)):
        const_node = Const(graph, {'value': float_array(math.pow(10.0, idx)),
                                   'name': merge.name + '/Const_' + ii.__str__()}).create_node()

        mul_node = Mul(graph, {'name': merge.name + '/Mul_' + ii.__str__()}).create_node()
        const_node.out_port(0).connect(mul_node.in_port(0))

        power.out_port(0).connect(mul_node.in_port(1))  # connect to the graph node
        mul_node2 = Mul(graph, {'name': merge.name + '/Mul_Div_' + ii.__str__()}).create_node()

        const_node2 = Const(graph, {'value': float_array(math.pow(10.0, -1 * idx)),
                                    'name': merge.name + '/Const_Pow_' + ii.__str__()}).create_node()
        cast_node = Cast(graph, {'name': merge.name + '/Cast_' + idx.__str__(),
                                 'dst_type': np.float32}).create_node()
        mul_node.out_port(0).connect(cast_node.in_port(0))
        const_node2.out_port(0).connect(mul_node2.in_port(1))
        cast_node.out_port(0).connect(mul_node2.in_port(0))
        concat_node.add_input_port(ii, skip_if_exist=True)
        concat_node.in_port(ii).get_connection().set_source(mul_node2.out_port(0))

    reducesum_node = ReduceMean(graph, {'name': merge.id + '/_pnorm_reduced_sum',
                                        'keep_dims': False, 'in_ports_count': 2,
                                        'need_shape_inference': None, 'infer': reduce_infer}).create_node()
    const3.out_port(0).connect(reducesum_node.in_port(1))
    reducesum_node.in_port(0).get_connection().set_source(concat_node.out_port(0))

    reshape = Reshape(graph, {'name': merge.name + '/Reshape_Node'}).create_node()
    reshape_dim = Const(graph, {'value': np.array([1, 5]), 'name': merge.id + '/Reshape_Dim'}).create_node()
    reducesum_node.out_port(0).connect(reshape.in_port(0))
    reshape.in_port(1).connect(reshape_dim.out_port(0))
    merge.out_port(0).get_connection().set_source(reshape.out_port(0))
def replace_pattern(graph: Graph, match: dict):
    node = match['op']

    in_shape = node.in_port(0).data.get_shape().copy()
    memory_element = in_shape[1] - node.const_dim
    memory_size = memory_element * len(node.context)

    memory_pair_id = unique_id('id')
    # Memory(in)
    input_memory = Memory(graph, {'name': 'prev_splice_memory', 'id': memory_pair_id,
                                  'index': 1, 'size': 2, 'shape': int64_array([memory_size])}).create_node()
    # Memory(in)  \
    #             Crop
    # Input(temp) /
    crop = Crop(graph, {'name': 'Splice_Crop', 'axis': int64_array([1]),
                        'offset': int64_array([memory_element]),
                        'dim': int64_array([memory_size - memory_element])}).create_node()
    crop.in_port(0).connect(input_memory.out_port(0))

    # Crop  \
    #        Concat
    # Input /
    concat_node = Concat(graph, {'name': 'Splice_Concat', 'in_ports_count': 2, 'axis': 1}).create_node()
    concat_node.in_port(0).connect(crop.out_port(0))

    # Concat -> Memory(out)
    mem_out = Memory(graph, {'name': 'out_splice_memory', 'id': memory_pair_id,
                             'index': 0, 'size': 2, 'shape': int64_array([memory_size])}).create_node()
    mem_out.in_port(0).connect(concat_node.out_port(0))
    Result(graph).create_node().in_port(0).connect(mem_out.out_port(0))

    if node.const_dim != 0:
        memory_element_constdim = node.const_dim
        memory_size_constdim = memory_element_constdim * len(node.context)

        split = create_op_with_const_inputs(graph, VariadicSplit,
                                            {1: int64_array(1),
                                             2: int64_array([memory_element, memory_element_constdim])},
                                            {'name': node.id + '_split_const', 'out_ports_count': 2})
        split.out_port(0).connect(concat_node.in_port(1))

        # create separate splice construction for const_dim
        memory_pair_id = unique_id('memory_for_const_dim')
        input_memory_const_dim = Memory(graph, {'name': 'const_dim_in_memory', 'id': memory_pair_id,
                                                'index': 1, 'size': 2,
                                                'shape': int64_array([memory_size_constdim])}).create_node()
        crop_const_dim = Crop(graph, {'name': 'const_dim_crop', 'axis': int64_array([1]),
                                      'offset': int64_array([memory_element_constdim]),
                                      'dim': int64_array([memory_size_constdim -
                                                          memory_element_constdim])}).create_node()
        crop_const_dim.in_port(0).connect(input_memory_const_dim.out_port(0))

        concat_node_const_dim = Concat(graph, {'name': 'const_dim_concat',
                                               'in_ports_count': 2, 'axis': 1}).create_node()
        concat_node_const_dim.in_port(0).connect(crop_const_dim.out_port(0))

        mem_out_const_dim = Memory(graph, {'name': 'const_dim_out_memory', 'id': memory_pair_id,
                                           'index': 0, 'size': 2,
                                           'shape': int64_array([memory_size_constdim])}).create_node()
        mem_out_const_dim.in_port(0).connect(concat_node_const_dim.out_port(0))
        Result(graph).create_node().in_port(0).connect(mem_out_const_dim.out_port(0))

        # connect splice to Split as begin and Concat as the end
        split.out_port(1).connect(concat_node_const_dim.in_port(1))
        crop_first = Crop(graph, {'name': 'const_dim_crop_first', 'axis': int64_array([1]),
                                  'offset': int64_array([0]),
                                  'dim': int64_array([memory_element_constdim])}).create_node()
        crop_first.in_port(0).connect(concat_node_const_dim.out_port(0))

        concat_const = Concat(graph, {'name': node.id + '_concat_const', 'axis': 1,
                                      'in_ports_count': 2}).create_node()
        concat_const.in_port(1).connect(crop_first.out_port(0))
        concat_const.in_port(0).connect(concat_node.out_port(0))

        node.in_port(0).get_connection().set_destination(split.in_port(0))
        node.out_port(0).get_connection().set_source(concat_const.out_port(0))
    else:
        node.in_port(0).get_connection().set_destination(concat_node.in_port(1))
        node.out_port(0).get_connection().set_source(concat_node.out_port(0))

    # to avoid re-inference of shape and touching in next replacements
    graph.remove_node(node.id)
def replace_pattern(graph: Graph, match: dict):
    node = match['op']

    if node.name == 'iteration_number_out':
        return

    # calculate length of context when state of inference becomes meaningful
    inputs = []
    for n in graph.get_op_nodes(**{'op': 'Parameter'}):
        inputs.append(n)

    in_nodes = []
    for inp in inputs:
        for ins in inp.out_port(0).get_destinations():
            in_nodes.append(ins.node.name)

    context_len = 1
    try:
        subgraph = invert_sub_graph_between_nodes(graph, [node.in_port(0).get_source().node.name], in_nodes)
    except Error:
        return

    for n in subgraph:
        n_node = Node(graph, n)
        if n_node.kind == 'op' and n_node.op == 'Splice':
            context_len += len(n_node.context) - 1

    if context_len == 1:
        return

    in_node_port = node.in_port(0).get_source()
    in_node_shape = node.in_port(0).data.get_shape()
    node.in_port(0).disconnect()

    # add Select before saving state to avoid saving garbage
    select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
    zero_else = Const(graph, {'name': 'zero_else', 'value': np.zeros(in_node_shape)}).create_node()
    select_node.in_port(1).connect(in_node_port)
    select_node.in_port(2).connect(zero_else.out_port(0))

    # check if we have already appropriate iteration counter
    existing_counters = find_pattern_matches(graph,
                                             nodes=[('mem_in', dict(op='ReadValue')),
                                                    ('mem_in_data', dict(shape=int64_array([context_len]))),
                                                    ('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
                                                                         offset=int64_array([1]),
                                                                         dim=int64_array([context_len - 1]))),
                                                    ('crop_mem_in_data', dict()),
                                                    ('concat', dict(op='Concat', axis=1)),
                                                    ('concat_data', dict()),
                                                    ('const_1', dict(op='Const')),
                                                    ('const_1_data', dict()),
                                                    ('mem_out', dict(op='Assign')),
                                                    ('crop_out', dict(op='Crop', axis=int64_array([1]),
                                                                      offset=int64_array([0]),
                                                                      dim=int64_array([1]))),
                                                    ('crop_out_data', dict()),
                                                    ('select', dict(op='Select'))],
                                             edges=[('mem_in', 'mem_in_data'),
                                                    ('mem_in_data', 'crop_mem_in'),
                                                    ('crop_mem_in', 'crop_mem_in_data'),
                                                    ('crop_mem_in_data', 'concat', {'in': 0}),
                                                    ('const_1', 'const_1_data'),
                                                    ('const_1_data', 'concat', {'in': 1}),
                                                    ('concat', 'concat_data'),
                                                    ('concat_data', 'mem_out'),
                                                    ('concat_data', 'crop_out'),
                                                    ('crop_out', 'crop_out_data'),
                                                    ('crop_out_data', 'select')])
    counter_match = next(existing_counters, None)
    if counter_match is not None:
        ones = Node(graph, inverse_dict(counter_match)['const_1'])
        input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
    else:
        init_value_mem_out = create_zero_value_with_batch_from_input(in_node_port, context_len, np.int32)
        mem_out = ReadValue(graph, {'name': 'iteration_number',
                                    'variable_id': 'iteration_' + node.name}).create_node()
        mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
        cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
                                 'offset': int64_array([1]), 'dim': int64_array([context_len - 1])}).create_node()
        cut_first.in_port(0).connect(mem_out.out_port(0))
        ones = Const(graph, {'name': 'ones', 'value': np.ones([1, 1], dtype=np.int32)}).create_node()
        concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
        concat.in_port(0).connect(cut_first.out_port(0))
        concat.in_port(1).connect(ones.out_port(0))
        mem_in = Assign(graph, {'name': 'iteration_number_out',
                                'variable_id': 'iteration_' + node.name}).create_node()
        mem_in.in_port(0).connect(concat.out_port(0))
        res = Result(graph, {}).create_node()
        mem_in.out_port(0).connect(res.in_port(0))
        cut_last = Crop(graph, {'name': 'cut_last', 'axis': int64_array([1]),
                                'offset': int64_array([0]), 'dim': int64_array([1])}).create_node()
        cut_last.in_port(0).connect(concat.out_port(0))
        input_port = cut_last.out_port(0)

    # Check if data from memory is 1
    # if it is True, we have correct data and should proceed with saving it to memory
    # else we have not gathered context and have garbage here, shouldn't change initial state of memory
    cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
    cast_in.in_port(0).connect(ones.out_port(0))
    cast_in.in_port(1).connect(input_port)
    select_node.in_port(0).connect(cast_in.out_port(0))
    select_node.out_port(0).connect(node.in_port(0))
    select_node.out_port(0).data.set_shape(in_node_shape)
def replace_sub_graph(self, graph: Graph, match: dict):
    # obtain references to necessary nodes and their names
    fill = match['fill']
    dims = match['dims']
    strided_slice = match['strided_slice']
    strided_slice_1 = match['strided_slice_1']
    ctc_greedy_decoder = match['decoder']
    cast = match['cast']
    sparse_to_dense = match['sparse_to_dense']

    strided_slice_name = strided_slice.soft_get('name', strided_slice.id)
    strided_slice_1_name = strided_slice_1.soft_get('name', strided_slice_1.id)
    ctc_greedy_decoder_name = ctc_greedy_decoder.soft_get('name', ctc_greedy_decoder.id)
    sparse_to_dense_name = sparse_to_dense.soft_get('name', sparse_to_dense.id)

    # unsqueeze scalar values with batch size and time dimension
    unsqueeze_batch_size = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
                                                       {'name': strided_slice_name + '/Unsqueeze'})
    dims.in_port(0).get_connection().set_destination(unsqueeze_batch_size.in_port(0))
    unsqueeze_time_size = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
                                                      {'name': strided_slice_1_name + '/Unsqueeze'})
    fill.in_port(1).get_connection().set_destination(unsqueeze_time_size.in_port(0))

    # compute a sequence mask shape [T, N] required for CTCGreedyDecoder
    seq_mask_shape = Concat(graph, {'axis': 0, 'in_ports_count': 2,
                                    'name': ctc_greedy_decoder_name + '/SequenceMaskShape'}).create_node()
    seq_mask_shape.in_port(0).connect(unsqueeze_time_size.out_port(0))
    seq_mask_shape.in_port(1).connect(unsqueeze_batch_size.out_port(0))

    # compute a sequence mask
    sequence_mask = create_op_with_const_inputs(graph, Broadcast, {0: np.array([1.0], dtype=float)},
                                                {'mode': 'numpy',
                                                 'name': ctc_greedy_decoder_name + '/SequenceMask'})
    sequence_mask.in_port(1).connect(seq_mask_shape.out_port(0))

    # create CTCGreedyDecoder with the sequence mask instead of sequence length
    ctc_greedy_decoder.in_port(1).disconnect()
    ctc_greedy_decoder.in_port(1).connect(sequence_mask.out_port(0))

    # remove fill and pack nodes since they are now in unconnected component
    graph.remove_nodes_from([fill.id, dims.id])

    # transform opset CTCGreedyDecoder output to TensorFlow's one that has a shape [N, T]
    # opset CTCGreedyDecoder has an output with a shape [N, T, 1, 1]
    squeeze_dec_seq = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([2, 3])},
                                                  {'name': sparse_to_dense_name})
    squeeze_dec_seq.in_port(0).connect(ctc_greedy_decoder.out_port(0))
    cast_to_int = Cast(graph, {'name': sparse_to_dense_name + '/CastToInt',
                               'dst_type': np.int32}).create_node()
    cast_to_int.in_port(0).connect(squeeze_dec_seq.out_port(0))

    # preserve output name from original graph
    rename_nodes([(sparse_to_dense, sparse_to_dense_name + '/AbandonedName'),
                  (cast_to_int, sparse_to_dense_name)])

    # set output of the new sub-graph as a source for SparseToDense consumer
    sparse_to_dense.out_port(0).get_connection().set_source(cast_to_int.out_port(0))

    # cleanup a graph
    graph.remove_nodes_from([cast.id, sparse_to_dense.id])
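# Sketch of the sequence-mask construction above, with assumed T = 3 and N = 2 (a
# hypothetical helper): broadcasting [1.0] over the concatenated [T, N] shape yields an
# all-ones mask for CTCGreedyDecoder.
def _sketch_sequence_mask():
    import numpy as np
    t, n = np.array([3]), np.array([2])
    mask = np.broadcast_to(np.array([1.0], dtype=np.float32), tuple(np.concatenate([t, n])))
    assert mask.shape == (3, 2) and mask.all()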
def insert_select(graph: Graph, node: Node):
    context_len = node.frame_time + 1

    if context_len == 1:
        return

    in_node_port = node.in_port(0).get_source()
    in_node_shape = node.in_port(0).data.get_shape()
    node.in_port(0).disconnect()

    # add Select before saving state to avoid saving garbage
    select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
    zero_else = create_const_with_batch_from_input(in_node_port, in_node_shape[1])
    select_node.in_port(1).connect(in_node_port)
    select_node.in_port(2).connect(zero_else.out_port(0))

    # check if we have already appropriate iteration counter
    existing_counters = find_pattern_matches(graph,
                                             nodes=[('mem_in', dict(op='ReadValue')),
                                                    ('mem_in_data', dict(shape=int64_array([context_len]))),
                                                    ('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
                                                                         offset=int64_array([1]),
                                                                         dim=int64_array([context_len - 1]))),
                                                    ('crop_mem_in_data', dict()),
                                                    ('concat', dict(op='Concat', axis=1)),
                                                    ('concat_data', dict()),
                                                    ('const_1', dict(op='Const')),
                                                    ('const_1_data', dict()),
                                                    ('mem_out', dict(op='Assign')),
                                                    ('crop_out', dict(op='Crop', axis=int64_array([1]),
                                                                      offset=int64_array([0]),
                                                                      dim=int64_array([1]))),
                                                    ('crop_out_data', dict()),
                                                    ('select', dict(op='Select'))],
                                             edges=[('mem_in', 'mem_in_data'),
                                                    ('mem_in_data', 'crop_mem_in'),
                                                    ('crop_mem_in', 'crop_mem_in_data'),
                                                    ('crop_mem_in_data', 'concat', {'in': 0}),
                                                    ('const_1', 'const_1_data'),
                                                    ('const_1_data', 'concat', {'in': 1}),
                                                    ('concat', 'concat_data'),
                                                    ('concat_data', 'mem_out'),
                                                    ('concat_data', 'crop_out'),
                                                    ('crop_out', 'crop_out_data'),
                                                    ('crop_out_data', 'select')])
    counter_match = next(existing_counters, None)
    if counter_match is not None:
        ones = Node(graph, inverse_dict(counter_match)['const_1'])
        input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
    else:
        init_value_mem_out = create_const_with_batch_from_input(in_node_port, context_len, precision=np.int32)
        mem_out = ReadValue(graph, {'name': 'iteration_number',
                                    'variable_id': 'iteration_' + node.name}).create_node()
        mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
        cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
                                 'offset': int64_array([1]), 'dim': int64_array([context_len - 1])}).create_node()
        cut_first.in_port(0).connect(mem_out.out_port(0))
        ones = create_const_with_batch_from_input(in_node_port, 1, 1, np.int32)
        concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
        concat.in_port(0).connect(cut_first.out_port(0))
        concat.in_port(1).connect(ones.out_port(0))
        mem_in = Assign(graph, {'name': 'iteration_number_out',
                                'variable_id': 'iteration_' + node.name}).create_node()
        mem_in.in_port(0).connect(concat.out_port(0))
        res = Result(graph, {}).create_node()
        mem_in.out_port(0).connect(res.in_port(0))
        cut_last = Crop(graph, {'name': 'cut_last', 'axis': int64_array([1]),
                                'offset': int64_array([0]), 'dim': int64_array([1])}).create_node()
        cut_last.in_port(0).connect(concat.out_port(0))
        input_port = cut_last.out_port(0)

    # Check if data from memory is 1
    # if it is True, we have correct data and should proceed with saving it to memory
    # else we have not gathered context and have garbage here, shouldn't change initial state of memory
    cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
    cast_in.in_port(0).connect(ones.out_port(0))
    cast_in.in_port(1).connect(input_port)
    select_node.in_port(0).connect(cast_in.out_port(0))
    select_node.out_port(0).connect(node.in_port(0))
    select_node.out_port(0).data.set_shape(in_node_shape)
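# Behavioural sketch of the iteration counter built in insert_select above, assuming
# context_len = 3 (a hypothetical helper): the memory holds the last context_len step
# markers; each step it is shifted left (Crop 'cut_first') and a 1 is appended (Concat),
# and the Select gate reads the first element (Crop 'cut_last'), which becomes 1 only
# after context_len steps have been observed.
def _sketch_iteration_counter(context_len=3, steps=5):
    mem = [0] * context_len                 # ReadValue initialised with zeros
    gate_history = []
    for _ in range(steps):
        mem = mem[1:] + [1]                 # cut_first + concat_ones
        gate_history.append(mem[0])         # cut_last -> Select condition
    assert gate_history == [0, 0, 1, 1, 1]  # garbage suppressed for the first 2 steps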
def replace_pattern(graph: Graph, match: dict):
    node = match['op']

    if node.name == 'iteration_number_out':
        return

    # calculate length of context when state of inference becomes meaningful
    inputs = []
    for n in graph.get_op_nodes(**{'op': 'Parameter'}):
        inputs.append(n)

    in_nodes = []
    for inp in inputs:
        for ins in inp.out_port(0).get_destinations():
            in_nodes.append(ins.node.name)

    context_len = 1
    try:
        subgraph = invert_sub_graph_between_nodes(graph, [node.in_port(0).get_source().node.name], in_nodes)
    except Error:
        return

    for n in subgraph:
        n_node = Node(graph, n)
        if n_node.kind == 'op' and n_node.op == 'Splice':
            context_len += len(n_node.context) - 1

    if context_len == 1:
        return

    in_node_port = node.in_port(0).get_source()
    in_node_shape = node.in_port(0).data.get_shape()
    node.in_port(0).disconnect()

    # add Select before saving state to avoid saving garbage
    select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
    zero_else = Const(graph, {'name': 'zero_else', 'value': np.zeros(in_node_shape)}).create_node()
    select_node.in_port(1).connect(in_node_port)
    select_node.in_port(2).connect(zero_else.out_port(0))

    # check if we have already appropriate iteration counter
    existing_counters = find_pattern_matches(graph,
                                             nodes=[('mem_in', dict(op='Memory', index=1,
                                                                    shape=int64_array([context_len]))),
                                                    ('mem_in_data', dict()),
                                                    ('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
                                                                         offset=int64_array([1]),
                                                                         dim=int64_array([context_len - 1]))),
                                                    ('crop_mem_in_data', dict()),
                                                    ('concat', dict(op='Concat', axis=1)),
                                                    ('concat_data', dict()),
                                                    ('const_1', dict(op='Const')),
                                                    ('const_1_data', dict()),
                                                    ('mem_out', dict(op='Memory', index=0,
                                                                     shape=int64_array([context_len]))),
                                                    ('crop_out', dict(op='Crop', axis=int64_array([1]),
                                                                      offset=int64_array([0]),
                                                                      dim=int64_array([1]))),
                                                    ('crop_out_data', dict()),
                                                    ('select', dict(op='Select'))],
                                             edges=[('mem_in', 'mem_in_data'),
                                                    ('mem_in_data', 'crop_mem_in'),
                                                    ('crop_mem_in', 'crop_mem_in_data'),
                                                    ('crop_mem_in_data', 'concat', {'in': 0}),
                                                    ('const_1', 'const_1_data'),
                                                    ('const_1_data', 'concat', {'in': 1}),
                                                    ('concat', 'concat_data'),
                                                    ('concat_data', 'mem_out'),
                                                    ('concat_data', 'crop_out'),
                                                    ('crop_out', 'crop_out_data'),
                                                    ('crop_out_data', 'select')])
    counter_match = next(existing_counters, None)
    if counter_match is not None:
        input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
    else:
        mem_out = Memory(graph, {'name': 'iteration_number', 'size': 2, 'index': 1,
                                 'id': 'iteration_' + node.name, 'shape': int64_array([context_len]),
                                 'dst_type': np.int32}).create_node()
        cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
                                 'offset': int64_array([1]), 'dim': int64_array([context_len - 1])}).create_node()
        cut_first.in_port(0).connect(mem_out.out_port(0))
        ones = Const(graph, {'name': 'ones', 'value': np.ones([1, 1], dtype=np.int32)}).create_node()
        concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
        concat.in_port(0).connect(cut_first.out_port(0))
        concat.in_port(1).connect(ones.out_port(0))
        mem_in = Memory(graph, {'name': 'iteration_number_out', 'size': 2, 'index': 0,
                                'id': 'iteration_' + node.name,
                                'shape': int64_array([context_len])}).create_node()
        mem_in.in_port(0).connect(concat.out_port(0))
        res = Result(graph, {}).create_node()
        mem_in.out_port(0).connect(res.in_port(0))
        cut_last = Crop(graph, {'name': 'cut_last', 'axis': int64_array([1]),
                                'offset': int64_array([0]), 'dim': int64_array([1])}).create_node()
        cut_last.in_port(0).connect(concat.out_port(0))
        input_port = cut_last.out_port(0)

    select_node.in_port(0).connect(input_port)
    select_node.out_port(0).connect(node.in_port(0))
    select_node.out_port(0).data.set_shape(in_node_shape)
def replace_pattern(graph: Graph, match: dict):
    node = match['op']
    pair_node = Node(graph, node.pair_name)

    if node.t >= 0:
        raise Error('Does not support IfDefined with t > 0')

    if node.in_port(0).get_source() is not None:
        input_port = node.in_port(0).get_source()
        op_output_id = node.out_port(0).get_destination().node.id
        out_port = pair_node.out_port(0)
        node_name = node.name
        pair_name = pair_node.name
    else:
        input_port = pair_node.in_port(0).get_source()
        op_output_id = pair_node.out_port(0).get_destination().node.id
        out_port = node.out_port(0)
        node_name = pair_node.name
        pair_name = node.name

    in_shape = input_port.data.get_shape()
    node_t = abs(node.t)

    init_value_memory_out = create_zero_value_with_batch_from_input(input_port, in_shape[1] * node_t)
    memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name + pair_name}).create_node()
    init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

    if node_t > 1:
        crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': np.array([in_shape[1] * (node_t - 1)]),
                                   'offset': np.array([in_shape[1]]), 'axis': np.array([1])}).create_node()
        memory_out.out_port(0).connect(crop_concat.in_port(0))

        concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
        concat.add_sequence_of_ports('in', range(2))
        crop_concat.out_port(0).connect(concat.in_port(0))
        concat.in_port(1).connect(input_port)

        memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
        concat.out_port(0).connect(memory_in.in_port(0))
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))

        crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': np.array([in_shape[1]]),
                                'offset': np.array([0]), 'axis': np.array([1])}).create_node()
        memory_out.out_port(0).connect(crop_out.in_port(0))
        out_port.get_connection().set_source(crop_out.out_port(0))
    else:
        memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
        memory_in.in_port(0).connect(input_port)
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))
        out_port.get_connection().set_source(memory_out.out_port(0))

    if not graph.graph['cmd_params'].static_shape:
        log.error("Model can not be translated in a reshape-able way.\n"
                  "Model Optimizer key static_shape was turned on to prevent related errors.\n"
                  "There will be no success changing input shapes of the model with the help of "
                  "InferenceEngine reshape method", extra={'is_warning': True})
        graph.graph['cmd_params'].static_shape = True

    graph.remove_node(op_output_id)
    graph.remove_node(node.id)
    graph.remove_node(pair_node.id)
def replace_pattern(self, graph: Graph, match: dict):
    const = 0.99
    merge = match['merge']
    digits = significant_digits()
    pnorm = Power(graph, {'name': merge.name + '/reciprocal_', 'type': 'PNORM',
                          'significant': digits[0], 'to_significant': digits[1],
                          'scale': 1, 'shift': 0, 'power': get_power_attr()}).create_node()
    merge.in_port(0).get_connection().set_destination(pnorm.in_port(0))

    in_shape = pnorm.in_port(0).data.get_shape()
    in_shape = list(in_shape)
    in_shape.insert(0, 1)

    reshape1 = Reshape(graph, {'name': merge.name + '/Reshape_Node1'}).create_node()
    reshape_dim1 = Const(graph, {'value': np.array(in_shape), 'name': merge.id + '/Reshape_Dim1'}).create_node()
    pnorm.out_port(0).connect(reshape1.in_port(0))
    reshape1.in_port(1).connect(reshape_dim1.out_port(0))

    concat_node = Concat(graph, {'axis': 0, 'name': merge.name + '/Concat_',
                                 'override_output_shape': True}).create_node()
    const3 = Const(graph, {'name': merge.name + '/const_reduce', 'value': 0}).create_node()

    for ii, idx in enumerate(range(pnorm.significant, pnorm.to_significant + 1, 1)):
        const_node = Const(graph, {'value': float_array(math.pow(const, idx)),
                                   'name': merge.name + '/Const_' + ii.__str__()}).create_node()

        mul_node = Mul(graph, {'name': merge.name + '/Mul_' + ii.__str__()}).create_node()
        const_node.out_port(0).connect(mul_node.in_port(0))

        reshape1.out_port(0).connect(mul_node.in_port(1))  # connect to the graph node
        mul_node2 = Mul(graph, {'name': merge.name + '/Mul_Div_' + ii.__str__()}).create_node()

        const_node2 = Const(graph, {'value': float_array(math.pow(const, -1 * idx)),
                                    'name': merge.name + '/Const_Pow_' + ii.__str__()}).create_node()
        cast_node = ExpOp(graph, {'name': merge.name + '/Exp_' + idx.__str__()}).create_node()
        mul_node.out_port(0).connect(cast_node.in_port(0))
        const_node2.out_port(0).connect(mul_node2.in_port(1))
        cast_node.out_port(0).connect(mul_node2.in_port(0))
        concat_node.add_input_port(ii, skip_if_exist=True)
        concat_node.in_port(ii).get_connection().set_source(mul_node2.out_port(0))

    in_shape = pnorm.in_port(0).data.get_shape()
    in_shape = list(in_shape)

    reducesum_node = ReduceMean(graph, {'name': merge.id + '/_pnorm_reduced_sum',
                                        'keep_dims': True, 'in_ports_count': 2,
                                        'shape': in_shape, 'axis': 0,
                                        'need_shape_inference': None, 'infer': reduce_infer}).create_node()
    const3.out_port(0).connect(reducesum_node.in_port(1))
    reducesum_node.in_port(0).get_connection().set_source(concat_node.out_port(0))

    reshape = Reshape(graph, {'name': merge.name + '/Reshape_Node'}).create_node()
    reshape_dim = Const(graph, {'value': np.array(in_shape), 'name': merge.id + '/Reshape_Dim'}).create_node()
    reducesum_node.out_port(0).connect(reshape.in_port(0))
    reshape.in_port(1).connect(reshape_dim.out_port(0))
    merge.out_port(0).get_connection().set_source(reshape.out_port(0))
def find_and_replace_pattern(self, graph: Graph):
    for node in graph.get_op_nodes(op='ATen', operator='embedding_bag'):
        assert node.soft_get('mode') == 0, 'ATen::embedding_bag has unsupported mode, only "sum" ' \
                                           'mode is supported for node {}.'.format(node.id)
        node_name = node.soft_get('name', node.id)
        rename_node(node, node_name + '/TBR')
        is_packed = False
        if len(node.in_ports()) < 3 or node.in_port(2).disconnected():
            is_packed = True
            embedding_bag = EmbeddingBagPackedSum(graph, {'name': node_name}).create_node()
        else:
            embedding_bag = EmbeddingBagOffsetsSum(graph, {'name': node_name}).create_node()
            node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
        rename_node(embedding_bag, node_name)
        node.in_port(0).get_connection().set_destination(embedding_bag.in_port(0))
        node.in_port(1).get_connection().set_destination(embedding_bag.in_port(1))
        node.out_port(0).get_connection().set_source(embedding_bag.out_port(0))
        if len(node.in_ports()) == 4 and not node.in_port(3).disconnected():
            if is_packed:
                node.in_port(3).get_connection().set_destination(embedding_bag.in_port(2))
            else:
                # connect per_sample_weights
                node.in_port(3).get_connection().set_destination(embedding_bag.in_port(4))

                weights_shape_node = Shape(graph, {'name': node_name + '/WeightsShape'}).create_node()
                weights_rank_node = Rank(graph, {'name': node_name + '/WeightsRank'}).create_node()
                last_dim_node = get_canonical_axis_index_node(weights_rank_node, -1)
                weights_last_dim = get_shape_values_by_indices_node(weights_shape_node, last_dim_node)
                weights_first_dim = node_to_get_shape_value_of_indices(weights_shape_node, [0])

                zero_col_node = create_op_with_const_inputs(graph, Broadcast, {0: int64_array([0])},
                                                            {'name': node_name + '/Broadcast'})
                zero_col_node.in_port(1).connect(weights_last_dim.out_port(0))
                default_embeddings_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
                                                                      {'name': node_name + '/Unsqueeze'})
                default_embeddings_node.in_port(0).connect(zero_col_node.out_port(0))

                # expand embedding table with zeros
                weights_concat = Concat(graph, {'axis': 0, 'in_ports_count': 2,
                                                'name': node_name + '/Concat'}).create_node()
                embedding_bag.in_port(0).get_connection().set_destination(weights_concat.in_port(0))
                weights_concat.in_port(0).get_connection().add_destination(weights_shape_node.in_port(0))
                weights_concat.in_port(0).get_connection().add_destination(weights_rank_node.in_port(0))
                weights_concat.in_port(1).connect(default_embeddings_node.out_port(0))
                weights_concat.out_port(0).connect(embedding_bag.in_port(0))

                # point default index to expanded part of embedding table
                weights_first_dim.out_port(0).connect(embedding_bag.in_port(3))
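# Sketch of the per_sample_weights workaround above, with assumed toy sizes (a
# hypothetical helper): a zero row is appended to the embedding table and its index
# serves as the default, so gathering the default index contributes nothing to the
# weighted sum.
def _sketch_padded_embedding_table():
    import numpy as np
    weights = np.random.rand(5, 3)
    padded = np.concatenate([weights, np.zeros((1, weights.shape[1]))], axis=0)
    default_index = weights.shape[0]        # first dim of the original table
    assert not padded[default_index].any()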
def replace_timeheightconv(self, graph: Graph, node: Node):
    req_time_offsets = node.soft_get('time_offsets')
    offsets = node.soft_get("offsets", [[]])
    all_time_offsets = list(set(offsets[:, 0]))
    all_time_offsets.sort()
    in_name = node.soft_get('name', node.id)
    rename_node(node, in_name + '/to_delete')

    # create memoryoffsets for context gathering
    # we need Concat if there is more than one time offset
    concat = Concat(graph, attrs={'name': in_name + '/Concat',
                                  'in_ports_count': len(all_time_offsets)}).create_node()
    i = 0
    for t in all_time_offsets:
        # if time offset included in required_time_offsets we don't need default value
        has_default = t not in req_time_offsets
        memoff = MemoryOffset(graph, attrs={'name': in_name + '/MemoryOffset_' + str(i),
                                            't': t, 'has_default': has_default, 'splitted': False,
                                            'pair_name': in_name + '/MemoryOffset_pair_' + str(i)}).create_node()
        concat.in_port(i).connect(memoff.out_port(0))
        memoff.in_port(0).connect(node.in_port(0).get_source())
        i = i + 1

    stride = node.soft_get("height_subsample", 1)

    kernel = int64_array([0, 0])
    kernel[0] = len(set(offsets[:, 0]))
    kernel[1] = len(set(offsets[:, 1]))

    pad_h = int64_array([0, 0])
    pad_h[0] = -min(offsets[:, 1]) if min(offsets[:, 1]) < 0 else 0
    pad_h[1] = stride * node.height_out - (node.height_in - max([max(offsets[:, 1]), 0]))

    dilation_t = (max(offsets[:, 0]) - min(offsets[:, 0])) / (kernel[0] - 1) if kernel[0] > 1 else 1
    dilation_h = (max(offsets[:, 1]) - min(offsets[:, 1])) / (kernel[1] - 1) if kernel[1] > 1 else 1

    conv_attrs = {
        'name': in_name,
        'output': node['out_channels'],
        'height_in': node.height_in,
        'bias_term': None,
        'pad': int64_array([[0, 0], [0, 0], [0, 0], pad_h]),
        'pad_spatial_shape': int64_array([[0, 0], pad_h]),
        'dilation': int64_array([1, 1, dilation_t, dilation_h]),
        'kernel': int64_array([node.out_channels, node.in_channels, kernel[0], kernel[1]]),
        'stride': int64_array([1, 1, 1, stride]),
        'kernel_spatial': kernel,
        'input_feature_channel': 1,
        'output_feature_channel': 0,
        'channel_dims': int64_array([1]),
        'spatial_dims': int64_array([2, 3]),
        'batch_dims': int64_array([0]),
        'kernel_spatial_idx': int64_array([2, 3]),
        'group': 1,
        'reshape_kernel': True,
        'bias_addable': True,
    }
    conv = Convolution(graph, attrs=conv_attrs).create_node()
    conv.in_port(0).connect(concat.out_port(0))
    conv.in_port(1).connect(node.in_port(1).get_source())

    # change layout for weights from OHWI to OIHW
    # in future should be replaced by common Permute mechanics
    weights = conv.in_port(1).get_source().node.value
    weights = weights.reshape(int64_array([node.out_channels, -1, node.in_channels]))
    weights = weights.transpose(int64_array([0, 2, 1]))
    weights = weights.flatten()
    conv.in_port(1).get_source().node.value = weights

    conv.in_port(2).connect(node.in_port(2).get_source())
    node.out_port(0).get_connection().set_source(conv.out_port(0))
    graph.remove_node(node.id)
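# Sketch of the kernel/dilation derivation above, assuming a toy offsets table of
# (time, height) pairs (a hypothetical helper): the kernel counts distinct offsets per
# axis, and dilation is the offset span divided by (kernel - 1).
def _sketch_kernel_dilation():
    import numpy as np
    offsets = np.array([[-1, 0], [-1, 2], [1, 0], [1, 2]])
    kernel = [len(set(offsets[:, 0])), len(set(offsets[:, 1]))]
    dilation_t = (offsets[:, 0].max() - offsets[:, 0].min()) / (kernel[0] - 1)
    dilation_h = (offsets[:, 1].max() - offsets[:, 1].min()) / (kernel[1] - 1)
    assert kernel == [2, 2] and dilation_t == 2 and dilation_h == 2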