def test_broadcast(self, data, target_shape, axes_mapping=None, mode='numpy', ref_out=None, test_raising=False):
    if ref_out is not None:
        input = valued_const_with_data('data', int64_array(data))
    else:
        input = shaped_data('data', int64_array(data))

    nodes = {
        **input,
        **valued_const_with_data('target_shape', int64_array(target_shape)),
        **regular_op_with_empty_data('broadcast', {'op': 'Broadcast', 'mode': mode}),
    }

    edges = [('data', 'broadcast'),
             ('target_shape', 'broadcast'),
             ('broadcast', 'broadcast_d')]

    if axes_mapping is not None:
        nodes.update(**valued_const_with_data('axes_mapping', int64_array(axes_mapping)))
        edges.append(('axes_mapping', 'broadcast'))
    graph = build_graph(nodes, edges)

    broadcast_node = Node(graph, 'broadcast')
    if test_raising:
        self.assertRaises(AssertionError, Broadcast.infer, broadcast_node)
        return

    Broadcast.infer(broadcast_node)
    if ref_out is not None:
        self.assertTrue(np.array_equal(broadcast_node.out_node().value, np.array(ref_out)))
    else:
        self.assertTrue(np.array_equal(broadcast_node.out_node().shape, np.array(target_shape)))
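A minimal, standalone numpy sketch (not part of the test suite) of the semantics the test above exercises: 'numpy' mode follows np.broadcast_to rules, while 'explicit' mode first places each input dimension at the output axis given by axes_mapping and then broadcasts the remaining axes. The helper name reference_broadcast is hypothetical.

import numpy as np

def reference_broadcast(data, target_shape, axes_mapping=None, mode='numpy'):
    data = np.asarray(data)
    if mode == 'numpy':
        return np.broadcast_to(data, target_shape)
    if mode == 'explicit':
        assert axes_mapping is not None and len(axes_mapping) == data.ndim
        # put each input dimension at the axis requested by axes_mapping, size 1 elsewhere
        expanded_shape = [1] * len(target_shape)
        for in_axis, out_axis in enumerate(axes_mapping):
            expanded_shape[out_axis] = data.shape[in_axis]
        return np.broadcast_to(data.reshape(expanded_shape), target_shape)
    raise ValueError('unsupported mode: ' + mode)

assert reference_broadcast([1, 2, 3], [3, 3]).shape == (3, 3)
assert reference_broadcast([1, 2, 3], [3, 2], axes_mapping=[0], mode='explicit').shape == (3, 2)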
def find_and_replace_pattern(self, graph: Graph):
    for fill_node in graph.get_op_nodes(op='Fill'):
        name = fill_node.soft_get('name', fill_node.id)

        broadcast_node = Broadcast(graph, {'name': name + '/Broadcast'}).create_node()
        fill_node.in_port(0).get_connection().set_destination(broadcast_node.in_port(1))
        fill_node.in_port(1).get_connection().set_destination(broadcast_node.in_port(0))
        fill_node.out_port(0).get_connection().set_source(broadcast_node.out_port(0))

    for fill_node in graph.get_op_nodes(op='ConstantFill'):
        name = fill_node.soft_get('name', fill_node.id)

        assert fill_node.has_valid('fill_value')
        assert fill_node.has_and_set('input_as_shape')

        const = Const(graph, {'value': np.array(fill_node.fill_value), 'name': name + '/value'}).create_node()
        broadcast_node = Broadcast(graph, {'name': name + '/Broadcast'}).create_node()
        fill_node.in_port(0).get_connection().set_destination(broadcast_node.in_port(1))
        const.out_port(0).connect(broadcast_node.in_port(0))
        fill_node.out_port(0).get_connection().set_source(broadcast_node.out_port(0))
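Hedged illustration of why the rewrite above is valid: Fill(dims, value) and ConstantFill(dims, fill_value) both produce a tensor of shape dims filled with a scalar, which is exactly what the inserted Broadcast computes in 'numpy' mode. The example values are arbitrary.

import numpy as np

dims = np.array([2, 3], dtype=np.int64)
value = np.float32(5)
fill_result = np.full(dims, value)               # what Fill / ConstantFill would produce
broadcast_result = np.broadcast_to(value, dims)  # what the inserted Broadcast produces
assert np.array_equal(fill_result, broadcast_result)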
def find_and_replace_pattern(self, graph: Graph):
    for fill_node in graph.get_op_nodes(op='Fill'):
        broadcast_node = Broadcast(graph, {}).create_node()
        fill_node.in_port(0).get_connection().set_destination(broadcast_node.in_port(1))
        fill_node.in_port(1).get_connection().set_destination(broadcast_node.in_port(0))
        fill_node.out_port(0).get_connection().set_source(broadcast_node.out_port(0))
def replace_op(self, graph: Graph, node: Node):
    node_name = node.soft_get('name', node.id)

    # broadcast default value to required shape
    broadcast_node = Broadcast(graph, {'name': node_name + '/Broadcast_'}).create_node()
    node.in_port(1).get_connection().set_destination(broadcast_node.in_port(1))
    if not node.in_port(3).disconnected():
        # TODO: remove casting once we start to support I64 model input
        # cast the default value to I32 due to a limitation in I64 input support,
        # so that the input parameter and the default value have the same I32 type required by ScatterNDUpdate
        cast_default_value = Cast(graph, {'name': node_name + '/CastDefaultValue', 'dst_type': np.int32}).create_node()
        node.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
        broadcast_node.in_port(0).connect(cast_default_value.out_port(0))
    else:
        broadcast_node.in_port(0).connect(Const(graph, {'name': broadcast_node.name + '/FillValue_',
                                                        'value': np.float32(0)}
                                                ).create_node().out_port(0))

    # update broadcasted tensor with required values at required locations
    scatternd_node = ScatterNDUpdate(graph, {'name': node_name + '/ScatterNDUpdate_'}).create_node()
    scatternd_node.in_port(0).connect(broadcast_node.out_port(0))
    node.in_port(0).get_connection().set_destination(scatternd_node.in_port(1))
    node.in_port(2).get_connection().set_destination(scatternd_node.in_port(2))

    rename_nodes([(node, node_name + "/AbandonedName"), (scatternd_node, node_name)])

    return [scatternd_node.id]
def find_and_replace_pattern(self, graph: Graph):
    for const_of_shape_node in graph.get_op_nodes(op='ConstantOfShape'):
        broadcast_node = Broadcast(graph, {'name': const_of_shape_node.name + '/Broadcast'}).create_node()
        const_of_shape_node.in_port(0).get_connection().set_destination(broadcast_node.in_port(1))
        broadcast_node.in_port(0).connect(Const(graph, {'name': broadcast_node.name + '/FillValue',
                                                        'value': const_of_shape_node.fill_value}
                                                ).create_node().out_port(0))
        const_of_shape_node.out_port(0).get_connection().set_source(broadcast_node.out_port(0))
def replace_pattern(graph: Graph, match: dict):
    node = match['op']
    shapes = [in_node.shape for _, in_node in node.in_nodes().items()]
    out_shape = node.out_node().shape
    broadcast_name = node.name + '/Broadcast/'

    for i, shape in enumerate(shapes):
        if not np.array_equal(shape, out_shape):
            # Add Broadcast op for this input
            # Need to create additional Const op for shape
            new_shape = Const(graph, {'name': broadcast_name + 'Shape',
                                      'value': out_shape.copy()}).create_node()
            broadcast_axis = Const(graph, {'name': broadcast_name + 'Axis',
                                           'value': np.array(range(len(out_shape)), dtype=np.int64)}
                                   ).create_node()
            broadcast = Broadcast(graph, {'name': broadcast_name}).create_node()
            node.in_port(i).get_connection().set_destination(broadcast.in_port(0))
            broadcast.in_port(1).connect(new_shape.out_port(0))
            broadcast.in_port(2).connect(broadcast_axis.out_port(0))
            broadcast.out_port(0).connect(node.in_port(i))
def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision=np.float):
    # create init_graph connected to ReadValue
    graph = input_out_port.node.graph
    input_name = input_out_port.node.name

    shape_of_input = Shape(graph, {'name': 'shape/' + input_name}).create_node()
    shape_of_input.in_port(0).connect(input_out_port)

    dim_for_get_batch = Const(graph, {'name': 'dim/crop_batch/' + shape_of_input.name,
                                      'value': int64_array([1]), 'shape': int64_array([1])}).create_node()
    get_batch = Crop(graph, {'name': 'crop_batch/' + shape_of_input.name,
                             'axis': int64_array([0]), 'offset': int64_array([0])}).create_node()
    get_batch.in_port(0).connect(shape_of_input.out_port(0))
    get_batch.in_port(1).connect(dim_for_get_batch.out_port(0))

    mem_shape_2nd_dim = Const(graph, {'name': 'gifo_r_weights_shape/' + input_name,
                                      'value': int64_array([second_dim]), 'shape': int64_array([1])}).create_node()
    mem_shape = Concat(graph, {'name': 'gather_memory_shape/' + input_name,
                               'axis': 0, 'in_ports_count': 2}).create_node()
    mem_shape.in_port(0).connect(get_batch.out_port(0))
    mem_shape.in_port(1).connect(mem_shape_2nd_dim.out_port(0))

    fill_value = Const(graph, {'name': 'fill_value/' + input_name,
                               'value': np.array([0.0], precision), 'shape': int64_array([1])}).create_node()
    init_value_prev_lstm_output = Broadcast(graph, {'name': 'init_value/' + input_name}).create_node()
    init_value_prev_lstm_output.in_port(0).connect(fill_value.out_port(0))
    init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))

    return init_value_prev_lstm_output
def append_variances(priors_scale_node: Node, variance: list):
    graph = priors_scale_node.graph
    name = priors_scale_node.name

    sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
    priors_scale_node.out_port(0).connect(sp_shape.in_port(0))

    begin = Const(graph, {'value': np.array([-2])}).create_node()
    end = Const(graph, {'value': np.array([-1])}).create_node()
    stride = Const(graph, {'value': np.array([1])}).create_node()
    shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim', 'begin_mask': np.array([1]),
                                                 'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
                                                 'shrink_axis_mask': np.array([0]),
                                                 'ellipsis_mask': np.array([0])}).create_node()

    sp_shape.out_port(0).connect(shape_part_for_tiling.in_port(0))
    begin.out_port(0).connect(shape_part_for_tiling.in_port(1))
    end.out_port(0).connect(shape_part_for_tiling.in_port(2))
    stride.out_port(0).connect(shape_part_for_tiling.in_port(3))

    concat_value = Const(graph, {'value': np.array([4])}).create_node()
    shape_concat = Concat(graph, {'name': name + '/shape_for_tiling', 'in_ports_count': 2,
                                  'axis': np.array(0)}).create_node()
    shape_part_for_tiling.out_port(0).connect(shape_concat.in_port(0))
    concat_value.out_port(0).connect(shape_concat.in_port(1))

    variance = Const(graph, {'name': name + '/variance', 'value': np.array(variance)}).create_node()
    tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
    variance.out_port(0).connect(tile.in_port(0))
    shape_concat.out_port(0).connect(tile.in_port(1))

    reshape_dim = Const(graph, {'value': int64_array([-1, 4])}).create_node()
    sp_reshape = Reshape(graph, {'name': name + '/reshape'}).create_node()
    sp_reshape.in_port(0).connect(priors_scale_node.out_port(0))
    sp_reshape.in_port(1).connect(reshape_dim.out_port(0))

    concat = Concat(graph, {'name': name + '/priors_concat', 'axis': np.array(0),
                            'in_ports_count': 2}).create_node()
    sp_reshape.out_port(0).connect(concat.in_port(0))
    tile.out_port(0).connect(concat.in_port(1))

    output_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
    output_node = Reshape(graph, {'name': name + '/3D_priors_wth_variances'}).create_node()
    concat.out_port(0).connect(output_node.in_port(0))
    output_dims.out_port(0).connect(output_node.in_port(1))

    return output_node
def replace_op(self, graph: Graph, node: Node):
    node_name = node.soft_get('name', node.id)

    # broadcast default value to required shape
    broadcast_node = Broadcast(graph, {'name': node_name + '/Broadcast_'}).create_node()
    node.in_port(1).get_connection().set_destination(broadcast_node.in_port(1))
    if not node.in_port(3).disconnected():
        node.in_port(3).get_connection().set_destination(broadcast_node.in_port(0))
    else:
        broadcast_node.in_port(0).connect(Const(graph, {'name': broadcast_node.name + '/FillValue_',
                                                        'value': np.float32(0)}
                                                ).create_node().out_port(0))

    # update broadcasted tensor with required values at required locations
    scatternd_node = ScatterNDUpdate(graph, {'name': node_name + '/ScatterNDUpdate_'}).create_node()
    scatternd_node.in_port(0).connect(broadcast_node.out_port(0))
    node.in_port(0).get_connection().set_destination(scatternd_node.in_port(1))
    node.in_port(2).get_connection().set_destination(scatternd_node.in_port(2))

    rename_nodes([(node, node_name + "/AbandonedName"), (scatternd_node, node_name)])

    return [scatternd_node.id]
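A rough numpy model (illustrative only, assuming a 2-D dense shape and a scalar default value) of the SparseToDense decomposition above: the dense tensor is first filled with the default value (Broadcast) and then the sparse values are written at their indices (ScatterNDUpdate).

import numpy as np

indices = np.array([[0, 1], [2, 0]])   # SparseToDense input 0
dense_shape = np.array([3, 3])         # SparseToDense input 1
values = np.array([7.0, 9.0])          # SparseToDense input 2
default_value = np.float32(0)          # SparseToDense input 3 (or the inserted Const)

dense = np.broadcast_to(default_value, dense_shape).copy()  # Broadcast
dense[tuple(indices.T)] = values                            # ScatterNDUpdate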
def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
    reshape_classes_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                            dict(name='do_reshape_classes'),
                                                            match.single_input_node(1)[0])

    initial_priors_node = match.single_input_node(2)[0]
    priors_name = initial_priors_node.soft_get('name', initial_priors_node.id)
    # the model calculates identical prior boxes for each batch, so we take the first slice of them
    begin = Const(graph, {'value': np.array([0, 0, 0], dtype=np.int32)}).create_node()
    end = Const(graph, {'value': np.array([1, 0, 0], dtype=np.int32)}).create_node()
    stride = Const(graph, {'value': np.array([1, 1, 1], dtype=np.int32)}).create_node()

    priors_node = StridedSlice(graph, {'name': priors_name + '/0_batch_slice',
                                       'begin_mask': np.array([1, 1, 1], dtype=np.int32),
                                       'end_mask': np.array([1, 0, 0], dtype=np.int32),
                                       'new_axis_mask': np.array([0], dtype=np.int32),
                                       'shrink_axis_mask': np.array([0], dtype=np.int32),
                                       'ellipsis_mask': np.array([0], dtype=np.int32)}).create_node()

    initial_priors_node.out_port(0).connect(priors_node.in_port(0))
    begin.out_port(0).connect(priors_node.in_port(1))
    end.out_port(0).connect(priors_node.in_port(2))
    stride.out_port(0).connect(priors_node.in_port(3))

    placeholders = graph.get_op_nodes(type='Parameter')
    assert len(placeholders) == 1, "{} replacer requires model to have one Placeholder, but current model has " \
                                   "{} placeholders".format(self.replacement_id, len(placeholders))
    placeholder = placeholders[0]

    # scale prior boxes to the [0, 1] interval
    node_with_scales_for_prior_boxes = self.placeholder_scales(placeholder)
    priors_scale_node = Mul(graph, {'name': 'scale_priors'}).create_node()

    broadcast = Broadcast(graph, {'name': 'scales_broadcast'}).create_node()
    shape_of_priors = Shape(graph, {'name': 'priors_shape'}).create_node()
    priors_node.out_port(0).connect(shape_of_priors.in_port(0))
    broadcast.in_port(1).connect(shape_of_priors.out_port(0))
    broadcast.in_port(0).connect(node_with_scales_for_prior_boxes.out_port(0))

    priors_scale_node.in_port(0).connect(priors_node.out_port(0))
    priors_scale_node.in_port(1).connect(broadcast.out_port(0))

    try:
        variance = match.custom_replacement_desc.custom_attributes['variance']
    except KeyError:
        raise Error('There is no variance attribute in {} replacement config file `custom_attributes`'
                    ''.format(self.replacement_id))

    priors = self.append_variances(priors_scale_node, variance)

    # calculate prior boxes widths and heights
    split_node = create_op_with_const_inputs(graph, VariadicSplit,
                                             {1: int64_array(2), 2: int64_array([1, 1, 1, 1])},
                                             {'out_ports_count': 4}, priors_scale_node)

    priors_width_node = Sub(graph, dict(name=split_node.name + '/sub_2-0_')
                            ).create_node([(split_node, 2), (split_node, 0)])
    priors_height_node = Sub(graph, dict(name=split_node.name + '/sub_3-1_')
                             ).create_node([(split_node, 3), (split_node, 1)])

    # concat widths and heights into a single tensor and multiply with the box coordinates regression values
    # WA with 3 Concats instead of 1 for keeping the model reshapable
    # concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1,
    #                                           'in_ports_count': 4}).create_node(
    #     [priors_width_node, priors_height_node, priors_width_node, priors_height_node])
    concat_1 = Concat(graph, {'name': 'concat_width_height', 'axis': -1, 'in_ports_count': 2}
                      ).create_node([priors_width_node, priors_height_node])
    concat_2 = Concat(graph, {'name': 'concat_width_height_width', 'axis': -1, 'in_ports_count': 2}
                      ).create_node([concat_1, priors_width_node])
    concat_width_height_node = Concat(graph, {'name': 'concat_priors_width_height', 'axis': -1,
                                              'in_ports_count': 2}).create_node([concat_2, priors_height_node])

    applied_width_height_regressions_node = Mul(graph, {'name': 'final_regressions'}).create_node(
        [concat_width_height_node, match.single_input_node(0)[0]])

    # reshape to a 2D tensor as the Inference Engine Detection Output layer expects
    reshape_regression_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                               dict(name='reshape_regression'),
                                                               applied_width_height_regressions_node)

    detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)
    # get the NMS threshold from the original network
    iou_threshold = None
    nms_nodes = graph.get_op_nodes(op='NonMaxSuppression')
    if len(nms_nodes) > 0:
        # it is highly unlikely that NMS has different parameters for different classes;
        # moreover, DetectionOutput accepts only scalar values for iou_threshold (nms_threshold)
        iou_threshold = nms_nodes[0].in_node(3).value
    if iou_threshold is None:
        raise Error('During {} `iou_threshold` was not retrieved from RetinaNet graph'.format(self.replacement_id))

    detection_output_node = detection_output_op.create_node(
        [reshape_regression_node, reshape_classes_node, priors],
        dict(name=detection_output_op.attrs['type'], nms_threshold=iou_threshold, clip_after_nms=1, normalized=1,
             variance_encoded_in_target=0, background_label_id=1000))

    return {'detection_output_node': detection_output_node}
def mxrepeat_decomposition(node: Node):
    graph = node.graph
    name = node.soft_get('name', node.id)

    rename_node(node, name + '/to_be_removed')

    # Unsqueeze
    input_rank = Rank(graph, {'name': name + '/Rank'}).create_node()
    node.in_port(0).get_source().connect(input_rank.in_port(0))

    axis = get_canonical_axis_index_node(input_rank, node.axis)
    unsqueeze_axis = create_op_node_with_second_input(
        graph, Add, int64_array([1]), {'name': name + '/Unsqueeze/Axis'}, input_node=axis)

    unsqueeze = Unsqueeze(graph, {'name': name + '/Unsqueeze'}).create_node()
    unsqueeze.in_port(1).connect(unsqueeze_axis.out_port(0))

    # Tile (1, 1, ..., repeats, ..., 1)
    # we generate the tile array according to the following table:
    #   parts:      |      first      |  repeats  |   second    |
    #   i:          | 0, 1, ..., axis,| axis + 1, | ..., rank+1 |
    #   tile_array: | 1, 1, ...,  1  ,| repeats , | ...,  1     |
    one = Const(graph, {'name': name + '/Broadcast/One', 'value': int64_array([1])}).create_node()
    first_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_first_part'}).create_node()
    first_ones.in_port(0).connect(one.out_port(0))
    first_ones.in_port(1).connect(unsqueeze_axis.out_port(0))

    repeats = Const(graph, {'name': name + '/repeats', 'value': int64_array([node.repeats])}).create_node()

    second_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_second_part'}).create_node()
    second_part_broadcast_shape = Sub(graph, {'name': name + '/Broadcast/Shape/second_part'}).create_node()
    second_part_broadcast_shape.in_port(0).connect(input_rank.out_port(0))
    second_part_broadcast_shape.in_port(1).connect(unsqueeze_axis.out_port(0))
    second_ones.in_port(0).connect(one.out_port(0))
    second_ones.in_port(1).connect(second_part_broadcast_shape.out_port(0))

    tile_repeats = new_shape_node_from_shape_nodes([first_ones, repeats, second_ones])
    tile = Tile(graph, {'name': name + '/Tile'}).create_node()
    tile.in_port(1).connect(tile_repeats.out_port(0))

    # Reshape (input_shape[:axis], input_shape[axis] * repeats, input_shape[axis+1:])
    # we generate the reshape dim array according to the following table:
    #   parts:     |    first   |             rep              |  second   |
    #   i:         | 0, 1, ... ,|            axis,             | ..., rank |
    #   dim_array: | inp_sh[i] ,| input_shape[axis] * repeats ,| inp_sh[i] |
    input_shape = Shape(graph, {'name': name + '/Shape'}).create_node()
    node.in_port(0).get_source().connect(input_shape.in_port(0))

    first_input_shape_part = get_shape_values_by_range_idxs(
        input_shape, input_rank, begin=0, end=node.axis, include_begin=True, include_end=False)

    original_axis_dim = create_op_with_const_inputs(
        graph, Gather, {2: int64_array(0)}, {'name': name + '/OriginalDim'}, input_node=input_shape)
    original_axis_dim.in_port(1).connect(axis.out_port(0))

    repeated_dimension = Mul(graph, {'name': name + '/RepeatedDim'}).create_node()
    repeated_dimension.in_port(0).connect(original_axis_dim.out_port(0))
    repeated_dimension.in_port(1).connect(repeats.out_port(0))

    second_input_shape_part = get_shape_values_by_range_idxs(
        input_shape, input_rank, begin=node.axis, end=-1, include_begin=False, include_end=True)

    output_shape = new_shape_node_from_shape_nodes(
        [first_input_shape_part, repeated_dimension, second_input_shape_part])

    reshape = Reshape(graph, {'name': name}).create_node()
    rename_node(reshape, name)
    reshape.in_port(1).connect(output_shape.out_port(0))

    # Final connections
    node.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
    tile.in_port(0).connect(unsqueeze.out_port(0))
    reshape.in_port(0).connect(tile.out_port(0))
    node.out_port(0).get_connection().set_source(reshape.out_port(0))
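A small numpy check (illustrative only) of the Unsqueeze -> Tile -> Reshape scheme implemented above: repeating along axis equals inserting a new axis right after it, tiling that new axis by repeats, and collapsing the pair back into a single dimension.

import numpy as np

x = np.arange(2 * 3).reshape(2, 3)
axis, repeats = 1, 2

unsqueezed = np.expand_dims(x, axis + 1)   # Unsqueeze at axis + 1
tile_array = [1] * unsqueezed.ndim
tile_array[axis + 1] = repeats             # (1, ..., repeats, ..., 1)
tiled = np.tile(unsqueezed, tile_array)    # Tile
new_shape = list(x.shape)
new_shape[axis] *= repeats                 # input_shape[axis] * repeats
reshaped = np.reshape(tiled, new_shape)    # Reshape

assert np.array_equal(reshaped, np.repeat(x, repeats, axis=axis))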
def extract(cls, node):
    Broadcast.update_node_stat(node, {'mode': 'bidirectional'})
    return cls.enabled
def extract(node):
    Broadcast.update_node_stat(node)
    return __class__.enabled
def extract(cls, node: Node):
    Broadcast.update_node_stat(node, attrs={'mode': 'numpy'})
    return cls.enabled
def find_and_replace_pattern(self, graph: Graph):
    for embedding_segments_mean in graph.get_op_nodes(op='EmbeddingSegmentsMean'):
        embedding_segments_mean_name = embedding_segments_mean.soft_get('name', embedding_segments_mean.id)
        embedding_table_input = embedding_segments_mean.in_port(0)
        segment_ids_input = embedding_segments_mean.in_port(2)
        num_segments_input = embedding_segments_mean.in_port(3)

        # TODO: support EmbeddingSegmentsMean with a specified weights vector.
        # this case has not appeared in models so far, so the EmbeddingSegmentsOperation fusion
        # transformations do not handle it either
        if embedding_segments_mean.is_in_port_connected(5):
            return

        # 1. compute the indices membership matrix, i.e. which indices belong to which object
        # the shape of this matrix is [num_segments, num_indices]
        non_norm_range_1_to_num_segments = create_op_with_const_inputs(
            graph, Range, {0: int64_array(0), 2: int64_array(1)},
            {'name': embedding_segments_mean_name + '/Range1ToNumSegments', 'output_type': np.int64})
        num_segments_input.get_connection().add_destination(non_norm_range_1_to_num_segments.in_port(1))

        range_1_to_num_segments = ConvertLike(
            graph, {'name': embedding_segments_mean_name + '/Range1ToNumSegmentsNorm'}).create_node()
        range_1_to_num_segments.in_port(0).connect(non_norm_range_1_to_num_segments.out_port(0))
        num_segments_input.get_connection().add_destination(range_1_to_num_segments.in_port(1))

        unsqueeze_range_1_to_num_segments = create_op_with_const_inputs(
            graph, Unsqueeze, {1: int64_array(1)},
            {'name': embedding_segments_mean_name + '/Range1ToNumSegmentsUnsqueeze'})
        unsqueeze_range_1_to_num_segments.in_port(0).connect(range_1_to_num_segments.out_port(0))

        unsqueeze_segment_ids = create_op_with_const_inputs(
            graph, Unsqueeze, {1: int64_array(0)},
            {'name': embedding_segments_mean_name + '/SegmentIdsUnsqueeze'})
        segment_ids_input.get_connection().add_destination(unsqueeze_segment_ids.in_port(0))

        boolean_membership_matrix = Equal(
            graph, {'name': embedding_segments_mean_name + '/BooleanMembershipMatrix'}).create_node()
        boolean_membership_matrix.in_port(0).connect(unsqueeze_range_1_to_num_segments.out_port(0))
        boolean_membership_matrix.in_port(1).connect(unsqueeze_segment_ids.out_port(0))

        shape_of_membership_matrix = Shape(
            graph, {'name': embedding_segments_mean_name + '/ShapeOfMembershipMatrix'}
        ).create_node([boolean_membership_matrix])

        one_scalar_constant = Const(
            graph, {'name': embedding_segments_mean_name + '/OneScalar', 'value': int64_array([1])}).create_node()
        one_constant = Broadcast(
            graph, {'name': embedding_segments_mean_name + '/One'}
        ).create_node([one_scalar_constant, shape_of_membership_matrix])
        zero_constant = Const(
            graph, {'name': embedding_segments_mean_name + '/Zero', 'value': int64_array(0)}).create_node()
        membership_matrix = Select(
            graph, {'name': embedding_segments_mean_name + '/MembershipMatrix', 'auto_broadcast': 'numpy'}
        ).create_node([boolean_membership_matrix, one_constant, zero_constant])

        # 2. compute the number of indices that belong to each object from the batch;
        # these are the normalization coefficients
        num_indices_per_object = create_op_with_const_inputs(
            graph, ReduceSum, {1: int64_array(1)},
            {'name': embedding_segments_mean_name + '/NumIndicesPerObject'})
        num_indices_per_object.in_port(0).connect(membership_matrix.out_port(0))

        # 3. replace a zero coefficient (zero number of indices belonging to an object) with one,
        # because for such an object the single default embedding vector is used
        where_zero_number = Equal(
            graph, {'name': embedding_segments_mean_name + '/WhereZeroIndicesNumber'}
        ).create_node([num_indices_per_object, zero_constant])
        normalized_num_indices_per_object = Select(
            graph, {'name': embedding_segments_mean_name + '/NormNumIndicesPerObject', 'auto_broadcast': 'numpy'}
        ).create_node([where_zero_number, one_scalar_constant, num_indices_per_object])

        # 4. cast normalized_num_indices_per_object to the same type as the embedding vector table
        norm_coefficients = ConvertLike(
            graph, {'name': embedding_segments_mean_name + '/NormCoefficients'}).create_node()
        norm_coefficients.in_port(0).connect(normalized_num_indices_per_object.out_port(0))
        embedding_table_input.get_connection().add_destination(norm_coefficients.in_port(1))

        # 5. replace EmbeddingSegmentsMean with EmbeddingSegmentsSum
        embedding_segments_sum = EmbeddingSegmentsSum(
            graph, {'name': embedding_segments_mean_name + '/EmbeddingSegmentsSum'}).create_node()
        for in_port in embedding_segments_mean.in_ports():
            if embedding_segments_mean.is_in_port_connected(in_port):
                embedding_segments_mean.in_port(in_port).get_connection().set_destination(
                    embedding_segments_sum.in_port(in_port))

        # 6. normalize the EmbeddingSegmentsSum results by the computed coefficients
        result_node = Div(
            graph, {'name': embedding_segments_mean_name + '/Div'}
        ).create_node([embedding_segments_sum, norm_coefficients])
        embedding_segments_mean.out_port(0).get_connection().set_source(result_node.out_port(0))

        rename_nodes([(embedding_segments_mean, embedding_segments_mean_name + '/AbandonedName'),
                      (result_node, embedding_segments_mean_name)])
        graph.remove_nodes_from([embedding_segments_mean.id])
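An illustrative numpy version of steps 1-6 above (assumed small example values, not the Model Optimizer API): a [num_segments, num_indices] membership matrix gives per-segment counts, zero counts are replaced by one, and the segment sums are divided by those counts to obtain segment means.

import numpy as np

emb_table = np.random.rand(10, 4).astype(np.float32)
indices = np.array([0, 3, 3, 7])
segment_ids = np.array([0, 0, 2, 2])
num_segments = 3

membership = (np.arange(num_segments)[:, None] == segment_ids[None, :]).astype(np.int64)  # step 1
num_indices_per_object = membership.sum(axis=1, keepdims=True)                            # step 2
norm = np.where(num_indices_per_object == 0, 1, num_indices_per_object)                   # step 3
norm = norm.astype(emb_table.dtype)                                                       # step 4
segment_sums = membership @ emb_table[indices]                                            # step 5 (EmbeddingSegmentsSum)
segment_means = segment_sums / norm                                                       # step 6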
def find_and_replace_pattern(self, graph: Graph):
    reverse_nodes = graph.get_op_nodes(op='Reverse')
    for reverse in reverse_nodes:
        reverse_name = reverse.soft_get('name', reverse.id)

        assert reverse.in_port(1).disconnected()
        assert reverse.has_valid('axis')

        in_shape_rank = len(reverse.in_port(0).data.get_shape())
        # 1. Add new dimension as batch for rank = 1 to have batch != seq_axis
        if in_shape_rank == 1:
            unsq_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                         {'name': reverse_name + "/Unsqueeze"})
            reverse.in_port(0).get_source().connect(unsq_node.in_port(0))
            new_in = unsq_node.out_port(0)
            batch_axis = 0
            seq_axis = 1
        else:
            new_in = reverse.in_port(0).get_source()
            seq_axis = reverse['axis']
            batch_axis = 0 if seq_axis != 0 else 1

        # 2. For ReverseSequence the 1-port input is seq_lengths => create this input node as
        # shape[seq_axis] broadcasted to shape[batch_axis]
        # in ---> ShapeOf ----> Gather(seq_axis) ----> Broadcast----->
        #            |                                     |
        #            | -------> Gather(batch_axis)---------|
        shape_node = Shape(graph, {'name': reverse_name + "/Shape"}).create_node()
        new_in.connect(shape_node.in_port(0))
        seq_axis_node = node_to_get_shape_value_of_indices(shape_node, [seq_axis])
        batch_node = node_to_get_shape_value_of_indices(shape_node, [batch_axis])
        broadcast_node = Broadcast(graph, {'name': reverse_name + "/Broadcast"}).create_node()
        broadcast_node.in_port(0).connect(seq_axis_node.out_port(0))
        broadcast_node.in_port(1).connect(batch_node.out_port(0))

        # 3. Create new ReverseSequence node and reconnect all inputs/outputs to it
        rename_node(reverse, reverse_name + '/to_delete')
        reverse_sequence = ReverseSequence(graph, {'name': reverse_name, 'seq_axis': seq_axis,
                                                   'batch_axis': batch_axis}).create_node()
        reverse_sequence.in_port(0).connect(new_in)
        reverse_sequence.in_port(1).connect(broadcast_node.out_port(0))

        # 4. remove added dimension for rank = 1
        if in_shape_rank == 1:
            rename_node(reverse_sequence, reverse_name + '/ReverseSequence')
            squeeze_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                            {'name': reverse_name})
            squeeze_node.in_port(0).connect(reverse_sequence.out_port(0))
            reverse.out_port(0).get_connection().set_source(squeeze_node.out_port(0))
        else:
            reverse.out_port(0).get_connection().set_source(reverse_sequence.out_port(0))

    # 5. Delete old Reverse nodes
    graph.remove_nodes_from([reverse.id for reverse in reverse_nodes])
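A quick numpy sanity check (illustrative, not MO code; the helper reverse_sequence is hypothetical) of the idea behind the transformation above: Reverse along seq_axis equals ReverseSequence whose seq_lengths input is shape[seq_axis] broadcast to shape[batch_axis] elements, i.e. every sequence in the batch is reversed over its full length.

import numpy as np

def reverse_sequence(data, seq_lengths, seq_axis, batch_axis):
    # reverse the first seq_lengths[b] elements of each batch item along seq_axis
    out = data.copy()
    for b, length in enumerate(seq_lengths):
        dst = [slice(None)] * data.ndim
        src = [slice(None)] * data.ndim
        dst[batch_axis] = b
        src[batch_axis] = b
        dst[seq_axis] = slice(0, length)
        src[seq_axis] = slice(length - 1, None, -1)
        out[tuple(dst)] = data[tuple(src)]
    return out

x = np.arange(2 * 3).reshape(2, 3)
seq_axis, batch_axis = 1, 0
seq_lengths = np.broadcast_to(x.shape[seq_axis], (x.shape[batch_axis],))  # the inserted Broadcast
assert np.array_equal(reverse_sequence(x, seq_lengths, seq_axis, batch_axis), np.flip(x, axis=seq_axis))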
def extract(cls, node):
    Broadcast.update_node_stat(node)
    return cls.enabled