def replace_op(self, graph: Graph, node: Node):
    # Add new nodes
    mvn = MVN(graph, {
        'eps': node.epsilon,
        'name': node.name + '/Ins_Norm/MVN_',
    }).create_node()
    mul = Mul(graph, {'axis': 1, 'name': node.name + '/Ins_Norm/mul_'}).create_node()
    add = Add(graph, {'axis': 1, 'name': node.name + '/Ins_Norm/add_'}).create_node()

    # Connect nodes
    node.in_port(0).get_connection().set_destination(mvn.in_port(0))
    node.in_port(1).get_connection().set_destination(mul.in_port(1))
    node.in_port(2).get_connection().set_destination(add.in_port(1))

    mvn.out_port(0).connect(mul.in_port(0))
    mul.out_port(0).connect(add.in_port(0))

    return [add.id]
def _fused_batch_norm_decomposition(graph: Graph, tinput: Port, toutput: Port, gamma: Port, beta: Port,
                                    mean: np.ndarray, variance: np.ndarray, can_be_fused=True):
    """
    This is a common function for TF, Caffe and MXNet.
    It creates a Mul->Add->Mul->Add sub-graph.
    """
    batch_norm_name = tinput.get_connection().get_destination().node.name

    # Create first Mul & Add operations
    mul1_node = Mul(graph, dict(name=batch_norm_name + "/mean", can_be_fused=can_be_fused)).create_node()
    add1_node = Add(graph, dict(name=batch_norm_name + "/variance", can_be_fused=can_be_fused)).create_node()

    const_mul1_node = Const(graph, dict(name="data_mul_", value=np.array(mean))).create_node()
    const_add1_node = Const(graph, dict(name="data_add_", value=np.array(variance))).create_node()

    # Broadcast const from scalar
    # We can broadcast only when const.value is scalar
    if gamma.data.get_shape()[0] != gamma.data.get_value().shape[0]:
        value = gamma.data.get_value()
        # np.ndarray.resize() returns None, so chaining .fill() on its result raises;
        # build the broadcast array explicitly instead
        value = np.full(gamma.data.get_shape(), value[0], dtype=value.dtype)
        gamma.data.set_value(value)

    # Create second Mul & Add
    mul2_node = Mul(graph, dict(name=batch_norm_name + "/gamma", can_be_fused=can_be_fused)).create_node()
    add2_node = Add(graph, dict(name=batch_norm_name + "/beta", can_be_fused=can_be_fused)).create_node()

    # Connect edges Mul1->Add1->Mul2->Add2
    tinput.get_connection().set_destination(mul1_node.in_port(0))
    mul1_node.in_port(1).get_connection().set_source(const_mul1_node.out_port(0))

    add1_node.in_port(0).get_connection().set_source(mul1_node.out_port(0))
    add1_node.in_port(1).get_connection().set_source(const_add1_node.out_port(0))

    mul2_node.in_port(0).get_connection().set_source(add1_node.out_port(0))
    gamma.get_connection().set_destination(mul2_node.in_port(1))

    add2_node.in_port(0).get_connection().set_source(mul2_node.out_port(0))
    beta.get_connection().set_destination(add2_node.in_port(1))

    toutput.get_connection().set_source(add2_node.out_port(0))
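
# Not part of the pass above: a minimal NumPy sketch of the identity the Mul->Add->Mul->Add
# decomposition relies on, assuming the caller has already folded mean/variance into a
# multiplier and an offset (scale = 1/sqrt(variance + eps), shift = -mean * scale), so that
# BatchNorm(x) == (x*scale + shift)*gamma + beta. The variable names here are illustrative.
import numpy as np

x = np.random.rand(2, 3).astype(np.float32)
gamma, beta = np.float32(1.5), np.float32(0.25)
mean, variance, eps = np.float32(0.4), np.float32(0.9), np.float32(1e-5)

reference = gamma * (x - mean) / np.sqrt(variance + eps) + beta

scale = 1.0 / np.sqrt(variance + eps)   # folded multiplier (first Mul)
shift = -mean * scale                   # folded offset (first Add)
decomposed = (x * scale + shift) * gamma + beta

assert np.allclose(reference, decomposed, atol=1e-6)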
def replace_sub_graph(self, graph: Graph, match: dict):
    resize_node = match['resize']
    if match['mul_1'].in_node(1).value != match['mul_2'].in_node(1).value or \
            match['mul_1'].in_node(1).value != match['mul_3'].in_node(1).value:
        log.info('Pattern matched around resize op {} has different scale values.'.format(resize_node.name))
        return

    interpolate_node = Interpolate(graph, {'name': resize_node.name + '/Interpolate',
                                           'mode': resize_node.mode,
                                           'axes': int64_array([2, 3, 4])}).create_node()

    scale = match['mul_1'].in_node(1).value
    scale_value = int64_array([scale, scale, scale])
    scale_const = Const(graph, {'value': scale_value, 'name': resize_node.name + '/Scale'}).create_node()

    interpolated_shape = Mul(graph, {'name': resize_node.name + '/OutputShape'}).create_node()
    match['slice'].out_port(0).connect(interpolated_shape.in_port(0))
    scale_const.out_port(0).connect(interpolated_shape.in_port(1))

    resize_node.in_port(0).get_connection().set_destination(interpolate_node.in_port(0))
    interpolated_shape.out_port(0).connect(interpolate_node.in_port(1))
    resize_node.out_port(0).get_connection().set_source(interpolate_node.out_port(0))
def replace_interpolate_pattern(graph: Graph, match: dict):
    split = match['split']
    scale = np.array([get_split_scale(split)], dtype=np.float32)
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    split_node_name = split.name

    axis_node = Const(graph, {'name': split_node_name + '/axis', 'value': int64_array([axis])}).create_node()

    shape_node = Shape(graph, dict(name=split_node_name + '/Shape')).create_node()
    scales_node = Const(graph, dict(name=split_node_name + '/scales', value=scale)).create_node()
    mul_node = Mul(graph, dict(name=split_node_name + '/Mul')).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    strided_slice_node = create_op_with_const_inputs(graph, StridedSlice,
                                                     {1: int64_array([axis]), 2: int64_array([axis + 1])},
                                                     {'name': split_node_name + '/StridedSlice',
                                                      'begin_mask': int64_array([1]),
                                                      'end_mask': int64_array([1]),
                                                      'new_axis_mask': int64_array([0]),
                                                      'shrink_axis_mask': int64_array([0]),
                                                      'ellipsis_mask': int64_array([0])})
    shape_node.out_port(0).connect(strided_slice_node.in_port(0))

    cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()
    strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
    cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(graph, dict(name=split_node_name + '/Interpolate',
                                          mode='nearest', antialias=0,
                                          pads_begin=int64_array([0]), pads_end=int64_array([0]),
                                          coordinate_transformation_mode='half_pixel',
                                          nearest_mode='round_prefer_floor', cube_coeff=-0.75,
                                          version='opset4', shape_calculation_mode='scales',
                                          in_ports_count=4, maybe_part_of_sequence=True)).create_node()

    floor_node = Floor(graph, {'name': split_node_name + '/Floor'}).create_node()
    cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()
    mul_node.out_port(0).connect(floor_node.in_port(0))
    floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))

    cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
    scales_node.out_port(0).connect(interp_node.in_port(2))
    axis_node.out_port(0).connect(interp_node.in_port(3))

    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))

    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
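
# Illustrative NumPy check (not part of the pass) of the output-size computation built above
# for the opset4 Interpolate: the target size along `axis` is floor(dim * scale), where the
# dimension is sliced out of the data shape and cast to float first.
import numpy as np

data_shape = np.array([1, 3, 10, 10], dtype=np.int64)
axis, scale = 2, np.float32(2.0)

sliced_dim = data_shape[axis:axis + 1].astype(np.float32)   # StridedSlice + Cast
new_size = np.floor(sliced_dim * scale).astype(np.int64)    # Mul + Floor + Cast

assert np.array_equal(new_size, [20])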
def div_to_mul_replacement(div: Node):
    # This transformation is also executed for V10 IR on the middle phase despite graph_condition,
    # so prevent Div replacement on shape-calculating sub-graphs
    if div.in_port(0).data.get_value() is not None and div.in_port(1).data.get_value() is not None:
        return

    graph = div.graph
    name = div.soft_get('name', div.id)

    # keep the Mul name the same as the Div's -- because of the mathematical equality of output tensors
    rename_node(node=div, name=name + '/to_be_removed')

    # reconnect Div inputs/outputs to Mul
    mul = Mul(graph, {'name': name}).create_node()
    rename_node(mul, name)
    div.in_port(0).get_connection().set_destination(mul.in_port(0))
    div.in_port(1).get_connection().set_destination(mul.in_port(1))
    div.out_port(0).get_connection().set_source(mul.out_port(0))

    # restore mathematical equivalence to the Div operation: Div(A, B) = Mul(A, Pow(B, -1))
    reciprocal = create_op_with_const_inputs(graph, Pow, {1: np.float64(-1)},
                                             {'name': name + '/reciprocal_'})
    mul.in_port(1).get_connection().insert_node(reciprocal)
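
# Quick NumPy sanity check (illustrative only) of the identity this replacement uses:
# Div(A, B) == Mul(A, Pow(B, -1)) for non-zero B.
import numpy as np

a = np.array([6.0, 9.0, -4.0])
b = np.array([2.0, 3.0, 8.0])
assert np.allclose(a / b, a * np.power(b, -1.0))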
def replace_op(self, graph: Graph, node: Node):
    # Add new nodes
    const = Const(graph, dict(value=np.array(-1, dtype=np.int32))).create_node()
    negate = Mul(graph, {'name': node.name + '/negate_'}).create_node()
    add = Add(graph, {'name': node.name + '/add_'}).create_node()

    # Connect nodes
    node.in_port(1).get_connection().set_destination(negate.in_port(0))
    const.out_port(0).connect(negate.in_port(1))
    node.in_port(0).get_connection().set_destination(add.in_port(1))
    negate.out_port(0).connect(add.in_port(0))

    # The "explicit" version of the return value is: [(out_node.id, 0)]
    return [add.id]
def replace_pattern(graph: Graph, match: dict):
    node = match['normalize']
    assert node.in_port(0).data.get_shape().size in [2, 3, 4]
    assert node.has_valid('across_spatial')
    assert node.has_valid('channel_shared')
    assert node.has_valid('eps')

    if 'bin' in node.in_edge(1):
        del node.in_edge(1)['bin']

    weights = node.in_port(1).data.get_value()
    # check the value before it is used below
    assert weights is not None
    if node.channel_shared or all(weights == weights[0]):
        node.in_port(1).data.set_value(np.array([weights[0]]))

    mul = Mul(graph, {'name': node.name + '/Normalize_weights_multiplication'}).create_node()

    node.out_port(0).get_connection().set_source(mul.out_port(0))
    node.out_port(0).connect(mul.in_port(0))

    node.in_port(1).get_connection().get_source().connect(mul.in_port(1))
    node.in_port(1).disconnect()

    node['type'] = 'NormalizeL2'
    node['eps_mode'] = 'add'
    node['force_precision_in_ports'] = {1: 'int64'}

    axes_val = np.array([1]) if not node.across_spatial else \
        np.arange(start=1, stop=node.in_port(0).data.get_shape().size)
    axes = Const(graph, {'value': axes_val}).create_node()
    node.in_port(1).connect(axes.out_port(0))

    del node['across_spatial']
    del node['channel_shared']
def replace_sub_graph(self, graph: Graph, match: dict):
    op = match['op']
    out_port = op.in_port(0).get_source()

    if op.soft_get('scale', 1) != 1:
        const = Const(graph, {'value': np.array(op.scale)}).create_node()
        mul = Mul(graph, {'name': op.name + '/mul_'}).create_node()
        const.out_port(0).connect(mul.in_port(1))
        out_port.connect(mul.in_port(0))
        out_port = mul.out_port(0)

    if op.soft_get('shift', 0) != 0:
        const = Const(graph, {'value': np.array(op.shift)}).create_node()
        add = Add(graph, {'name': op.name + '/add_'}).create_node()
        const.out_port(0).connect(add.in_port(1))
        out_port.connect(add.in_port(0))
        out_port = add.out_port(0)

    if op.soft_get('power', 1) != 1:
        const = Const(graph, {'value': np.array(op.power)}).create_node()
        power = Pow(graph, {'name': op.name + '/pow_'}).create_node()  # avoid shadowing built-in pow
        const.out_port(0).connect(power.in_port(1))
        out_port.connect(power.in_port(0))
        out_port = power.out_port(0)

    op.out_port(0).get_connection().set_source(out_port)
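
# Illustrative NumPy check of the scale/shift/power decomposition above, built in the graph
# as an optional Mul -> Add -> Pow chain: y = (scale * x + shift) ** power.
import numpy as np

x = np.array([1.0, 2.0, 3.0])
scale, shift, power = 2.0, 1.0, 3.0
assert np.allclose(np.power(x * scale + shift, power), [27.0, 125.0, 343.0])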
def replace_sub_graph(self, graph: Graph, match: dict):
    node = match['op']
    if not node.has_valid('bias') or (node.has_valid('bias') and node.bias == 1):
        return

    # Calculate the scale value & create a Const op
    scale_value = np.array(1. / (pow(node.bias, node.beta)))
    node.alpha /= node.bias
    const_node = Const(graph, {'value': scale_value, 'shape': scale_value.shape,
                               'name': node.name + "/Const_Mul_"}).create_node()

    # Create Mul node
    mul_node = Mul(graph, {'name': node.name + "/Mul_"}).create_node()

    # Connect nodes
    const_node.out_port(0).connect(mul_node.in_port(1))
    node.out_port(0).get_connection().set_source(mul_node.out_port(0))
    node.out_port(0).connect(mul_node.in_port(0))

    # Delete bias; if it is not deleted, it will appear in IR v6
    del node['bias']
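
# Illustrative NumPy check of the algebra behind the bias normalization above, with s
# standing for the local sum of squares inside the LRN denominator:
# x / (b + a*s)**beta == (1 / b**beta) * x / (1 + (a/b)*s)**beta,
# i.e. LRN(bias=b, alpha=a) == Mul(1/b**beta, LRN(bias=1, alpha=a/b)).
import numpy as np

x, s = np.float64(3.0), np.float64(7.0)
a, b, beta = 1e-4, 2.0, 0.75

lhs = x / (b + a * s) ** beta
rhs = (1.0 / b ** beta) * x / (1.0 + (a / b) * s) ** beta
assert np.isclose(lhs, rhs)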
def replace_sub_graph(self, graph: Graph, match: dict):
    # This replacer replaces the ImageScaler operation with a Mul->Add sequence.
    # It also checks whether the weights and biases are actually needed.
    op = match['op']

    # Check that weights and biases are not useless
    has_bias, has_weights = True, True
    if all([x == 1 for x in np.nditer(op.scale)]):
        has_weights = False
    if all([x == 0 for x in np.nditer(op.bias)]):
        has_bias = False

    assert len(op.in_ports()) == 1
    last_port = op.in_port(0).get_source()

    # Create Mul & Add nodes
    if has_weights:
        mul_weights = Const(graph, dict(value=op.scale, shape=op.scale.shape)).create_node()
        mul_op = Mul(graph, dict(name=op.id + '/mul_')).create_node()
        op.in_port(0).get_connection().set_destination(mul_op.in_port(0))
        mul_weights.out_port(0).connect(mul_op.in_port(1))
        last_port = mul_op.out_port(0)

    if has_bias:
        add_bias = Const(graph, dict(value=op.bias, shape=op.bias.shape)).create_node()
        add_op = Add(graph, dict(name=op.id + '/add_')).create_node()
        last_port.get_connection().set_destination(add_op.in_port(0))
        add_bias.out_port(0).connect(add_op.in_port(1))
        last_port = add_op.out_port(0)

    op.in_port(0).disconnect()
    op.out_port(0).get_connection().set_source(last_port)
def replace_interpolate_pattern(graph: Graph, match: dict):
    split = match['split']
    scale = int64_array([get_split_scale(split)])
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    split_node_name = split.name

    shape_node = Shape(graph, dict(name=split_node_name + '/Shape_')).create_node()
    scales_node = Const(graph, dict(name=split_node_name + '/scales_', value=scale)).create_node()
    mul_node = Mul(graph, dict(name=split_node_name + '/Mul_')).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    slice_begin = Const(graph, dict(name=split_node_name + '/slice_begin_',
                                    value=int64_array([axis]))).create_node()
    slice_end = Const(graph, dict(name=split_node_name + '/slice_end_',
                                  value=int64_array([axis + 1]))).create_node()

    strided_slice_node = StridedSlice(graph,
                                      {'name': split_node_name + '/StridedSlice_',
                                       'begin_mask': int64_array([1]),
                                       'end_mask': int64_array([1]),
                                       'new_axis_mask': int64_array([0]),
                                       'shrink_axis_mask': int64_array([0]),
                                       'ellipsis_mask': int64_array([0]),
                                       }).create_node([shape_node, slice_begin, slice_end])
    strided_slice_node.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(graph, dict(name=split_node_name + '/Interpolate_',
                                          axes=int64_array([axis]), mode='nearest')).create_node()
    mul_node.out_port(0).connect(interp_node.in_port(1))

    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))

    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
def replace_pattern(graph: Graph, match: Dict[str, Node]):
    node = match['sub']

    # Add new nodes
    negate_const = Const(graph, dict(name=node.name + '/negate_const', value=np.array(-1))).create_node()
    negate = Mul(graph, {'name': node.name + '/negate_'}).create_node()
    add = Add(graph, {'name': node.name + '/add_'}).create_node()

    # Connect nodes
    node.in_port(1).get_connection().set_destination(negate.in_port(0))
    # the -1 constant feeds the negating Mul; in_port(1) of the Add is taken by the minuend below
    negate_const.out_port(0).connect(negate.in_port(1))
    node.in_port(0).get_connection().set_destination(add.in_port(1))
    negate.out_port(0).connect(add.in_port(0))

    node.out_port(0).get_connection().set_source(add.out_port(0))
def replace_pattern(self, graph: Graph, match: dict):
    quantize = match['quantize']

    sum_node = Add(graph, dict()).create_node()
    const = Const(graph, {'value': np.array(0.5)}).create_node()
    mul_node = Mul(graph, dict()).create_node()

    mul_node.in_port(0).connect(sum_node.out_port(0))
    mul_node.in_port(1).connect(const.out_port(0))

    quantize.in_port(1).get_connection().get_source().connect(sum_node.in_port(0))
    quantize.in_port(2).get_connection().get_source().connect(sum_node.in_port(1))
    quantize.in_port(1).disconnect()
    quantize.in_port(2).disconnect()

    mul_node.out_port(0).connect(quantize.in_port(1))
    mul_node.out_port(0).connect(quantize.in_port(2))
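
# Illustrative NumPy sketch (not part of the pass) of what the rewiring above achieves:
# feeding the midpoint (low + high) * 0.5 into both input_low and input_high collapses the
# FakeQuantize input range to a single threshold, so quantization degenerates to a binary
# decision. The two-level behaviour modelled here is an assumption about FakeQuantize
# semantics with a collapsed range.
import numpy as np

low, high = np.float32(-1.0), np.float32(3.0)
mid = (low + high) * np.float32(0.5)
x = np.array([-2.0, 0.0, 2.0], dtype=np.float32)
binarized = np.where(x > mid, high, low)
assert np.array_equal(binarized, np.float32([-1.0, -1.0, 3.0]))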
def find_and_replace_pattern(self, graph: Graph):
    for dequantize_node in graph.get_op_nodes(op='DequantizeLinear'):
        node_name = dequantize_node.soft_get('name', dequantize_node.id)
        axis = dequantize_node.soft_get('axis', None)
        scale_y_shape = dequantize_node.in_port(1).data.get_shape()
        model_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
        cast = Cast(graph, {'dst_type': model_data_type, 'name': node_name + '/Cast'}).create_node()
        dequantize_node.in_port(0).get_connection().set_destination(cast.in_port(0))
        mul = Mul(graph, {}).create_node()

        is_second_port_connected = dequantize_node.is_in_port_connected(2)
        if is_second_port_connected:
            sub = Sub(graph, {'name': node_name + '/Sub'}).create_node()
            cast.out_port(0).connect(sub.in_port(0))
            dequantize_node.in_port(2).get_connection().set_destination(sub.in_port(1))
            sub.out_port(0).connect(mul.in_port(0))
        else:
            cast.out_port(0).connect(mul.in_port(0))

        dequantize_node.in_port(1).get_connection().set_destination(mul.in_port(1))
        dequantize_node.out_port(0).get_connection().set_source(mul.out_port(0))
        rename_nodes([(dequantize_node, node_name + '/TBD'), (mul, node_name)])

        assert scale_y_shape is not None
        if axis is not None and len(scale_y_shape) > 0 and scale_y_shape[0] > 1:
            input_shape = cast.in_port(0).data.get_shape()
            target_shape = np.ones(len(input_shape), np.int64)
            target_shape[axis] = input_shape[axis]

            mul_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array(target_shape)},
                                                      {'name': node_name + '/Reshape/Mul'})
            mul.in_port(1).get_connection().set_destination(mul_reshape.in_port(0))
            mul_reshape.out_port(0).connect(mul.in_port(1))

            if is_second_port_connected:
                sub_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array(target_shape)},
                                                          {'name': node_name + '/Reshape/Sub'})
                sub.in_port(1).get_connection().set_destination(sub_reshape.in_port(0))
                sub_reshape.out_port(0).connect(sub.in_port(1))
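
# Illustrative NumPy sketch of the per-axis dequantization built above: the 1-D scale (and
# zero-point) is reshaped to [1, ..., C, ..., 1] so it broadcasts along `axis`, then
# y = (cast(x) - zero_point) * scale. The concrete shapes below are made up for the example.
import numpy as np

x = np.arange(12, dtype=np.uint8).reshape(2, 3, 2)
scale = np.array([0.1, 0.2, 0.4], dtype=np.float32)   # per-channel, axis=1
zero_point = np.array([0, 1, 2], dtype=np.uint8)

axis = 1
target_shape = np.ones(x.ndim, np.int64)
target_shape[axis] = x.shape[axis]                    # [1, 3, 1]

y = (x.astype(np.float32) - zero_point.reshape(target_shape)) * scale.reshape(target_shape)
assert y.shape == x.shape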
def dequantize_data(fake_quantize: Node, dst_type: type, quantized_type: type) -> Node:
    graph = fake_quantize.graph
    quantized_data = fake_quantize.in_port(0).get_source().node
    name = fake_quantize.soft_get('name', fake_quantize.id)

    assert quantized_data.soft_get('type') == 'Convert' and quantized_data.dst_type == quantized_type, \
        "Weights aren't compressed as expected for node {}".format(fake_quantize.soft_get('name', fake_quantize.id))

    dequantizing_cast = Cast(graph, dict(
        name=quantized_data.name + "/to_{}".format(np_data_type_to_destination_type(dst_type)),
        dst_type=dst_type, stop_value_propagation=True)).create_node()
    fake_quantize.in_port(0).get_connection().set_destination(dequantizing_cast.in_port(0))

    # limits of dequantize
    in_low = fake_quantize.in_port(1).get_source()
    in_high = fake_quantize.in_port(2).get_source()
    out_low = fake_quantize.in_port(3).get_source()
    out_high = fake_quantize.in_port(4).get_source()

    # scale calculation
    output_range = Sub(graph, {'name': name + '/output_range'}).create_node()
    output_range.in_port(0).connect(out_high)
    output_range.in_port(1).connect(out_low)

    input_range = Sub(graph, {'name': name + '/input_range'}).create_node()
    input_range.in_port(0).connect(in_high)
    input_range.in_port(1).connect(in_low)

    scale = Div(graph, {'name': name + '/scale'}).create_node()
    scale.in_port(0).connect(output_range.out_port(0))
    scale.in_port(1).connect(input_range.out_port(0))

    # shift calculation
    descaled_output_low = Div(graph, {'name': name + '/descaled_output_low'}).create_node()
    descaled_output_low.in_port(0).connect(out_low)
    descaled_output_low.in_port(1).connect(scale.out_port(0))

    shift = Sub(graph, {'name': name + '/zero_point'}).create_node()
    shift.in_port(0).connect(in_low)
    shift.in_port(1).connect(descaled_output_low.out_port(0))

    # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
    sub_zp = Sub(graph, {'name': name + '/minus_zp'}).create_node()
    sub_zp.in_port(0).connect(dequantizing_cast.out_port(0))
    sub_zp.in_port(1).connect(shift.out_port(0))

    mul_scale = Mul(graph, {'name': name + '/multiply_by_scale'}).create_node()
    mul_scale.in_port(0).connect(sub_zp.out_port(0))
    mul_scale.in_port(1).connect(scale.out_port(0))

    fake_quantize.out_port(0).get_connection().set_source(mul_scale.out_port(0))

    graph.remove_nodes_from([fake_quantize.id, fake_quantize.out_node(0).id])
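
# Illustrative NumPy check of the scale/zero-point algebra assembled above:
#   scale = (out_high - out_low) / (in_high - in_low)
#   zero_point = in_low - out_low / scale
# so that (q - zero_point) * scale maps the quantized limits back onto the output limits.
import numpy as np

in_low, in_high = np.float32(0.0), np.float32(255.0)    # quantized (storage) range
out_low, out_high = np.float32(-1.0), np.float32(1.0)   # original float range

scale = (out_high - out_low) / (in_high - in_low)
zero_point = in_low - out_low / scale

assert np.isclose((in_low - zero_point) * scale, out_low)
assert np.isclose((in_high - zero_point) * scale, out_high)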
def replace_pattern(graph: Graph, match: dict):
    log.debug('================== GNMTBeforeConditionFind ==================')
    input_sequence_lengths = match['Max'].in_port(0).get_source().node
    encoder_sequence_lengths = looking_for_op_in_list(
        [port.node for port in input_sequence_lengths.out_port(0).get_destinations()], 'Identity')

    # Looking for the Sequence_length node in the encoder. The pattern looks like:
    # Sequence_length -> CheckSeqLen -> Max -> Maximum -> Minimum
    check_seq_len = looking_for_op_in_list(
        [port.node for port in encoder_sequence_lengths.out_port(0).get_destinations()], 'Identity')
    max = looking_for_op_in_list(
        [port.node for port in check_seq_len.out_port(0).get_destinations()], 'ReduceMax')
    maximum = max.out_port(0).get_destinations()[0].node
    assert maximum.op == 'Maximum'
    minimum = maximum.out_port(0).get_destinations()[0].node
    assert minimum.op == 'Minimum'
    tensor_seq_len = looking_for_op_in_list(
        [minimum.in_port(port).get_source().node for port in minimum.in_ports()], 'StridedSlice')

    # Create node for multiplying seq_len by 2
    const = Const(graph, {'name': 'FakeSeqLenMultiplyer', 'value': np.array(2)}).create_node()
    mul_op = Mul(graph, {'name': 'FakeSeqLen'}).create_node()

    const.out_port(0).get_connection().set_destination(mul_op.in_port(1))
    tensor_seq_len.out_port(0).get_connection().add_destination(mul_op.in_port(0))

    # Connect seq_len * 2 to the TensorArray from the GNMT loop
    ta_writes = [port.node for port in match['Identity_1'].out_port(0).get_destinations()
                 if port.node.op == 'TensorArrayWriteV3']
    for ta_write in ta_writes:
        ta = ta_write.in_port(0).get_source().node.in_port(0).get_source().node

        ta.in_port(0).disconnect()
        ta.in_port(0).get_connection().set_source(mul_op.out_port(0))
def replace_op(self, graph: Graph, node: Node):
    # Create nodes
    const_neg = Const(graph, dict(value=np.array(-1), name=node.name + '/negate_const_')).create_node()
    negate = Mul(graph, {'name': node.name + '/negate_'}).create_node()
    add = Add(graph, {'name': node.name + '/add_'}).create_node()
    const = Const(graph, {'value': np.array(2)}).create_node()
    squared = Pow(graph, {'name': node.name + '/squared_'}).create_node()

    # Connect nodes
    node.in_port(0).get_connection().set_destination(add.in_port(0))
    node.in_port(1).get_connection().set_destination(negate.in_port(0))
    const_neg.out_port(0).connect(negate.in_port(1))
    negate.out_port(0).connect(add.in_port(1))
    add.out_port(0).connect(squared.in_port(0))
    const.out_port(0).connect(squared.in_port(1))

    # The "explicit" version of the return value is: [(out_node.id, 0)]
    return [squared.id]
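
# Illustrative NumPy check of the SquaredDifference decomposition above:
# (a - b)**2 == Pow(Add(a, Mul(b, -1)), 2).
import numpy as np

a = np.array([3.0, -1.0, 4.0])
b = np.array([1.0, 5.0, 9.0])
assert np.allclose((a - b) ** 2, np.power(a + (-1 * b), 2))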
def replace_op(self, graph: Graph, node: Node):
    name = node.soft_get('name', node.id)

    # create a range of axes for MVN based on `start_axis` and the rank of the input
    rank = Rank(graph, {'name': name + '/Rank'}).create_node()
    rng = create_op_with_const_inputs(graph, Range, {0: int64_array(2), 2: int64_array(1)},
                                      {'name': name + '/Range', 'output_type': np.int64})
    mvn = MVN(graph, {'eps': node.epsilon, 'eps_mode': 'inside_sqrt', 'normalize_variance': 1,
                      'name': name + '/Ins_Norm/MVN_', }).create_node()
    node.in_port(0).get_connection().set_destination(mvn.in_port(0))
    rng.out_port(0).connect(mvn.in_port(1))
    mul = Mul(graph, {'axis': 1, 'name': name + '/Ins_Norm/mul_'}).create_node()
    mvn.out_port(0).connect(mul.in_port(0))
    node.in_port(1).get_connection().set_destination(mul.in_port(1))
    add = Add(graph, {'axis': 1, 'name': name + '/Ins_Norm/add_'}).create_node()
    mul.out_port(0).connect(add.in_port(0))
    node.in_port(2).get_connection().set_destination(add.in_port(1))

    mvn.in_port(0).get_connection().add_destination(rank.in_port(0))
    rng.in_port(1).connect(rank.out_port(0))

    rename_nodes([(node, name + '/TBD'), (add, name)])

    return [add.id]
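
# Illustrative NumPy sketch of the InstanceNormalization decomposition above for NCHW input:
# MVN over the spatial axes Range(2, Rank(x), 1), then a per-channel scale and bias. Here
# scale/bias are pre-reshaped for broadcasting; in the graph the Mul/Add with 'axis': 1
# achieve the same alignment.
import numpy as np

x = np.random.rand(2, 3, 4, 5).astype(np.float32)
scale = np.random.rand(1, 3, 1, 1).astype(np.float32)
bias = np.random.rand(1, 3, 1, 1).astype(np.float32)
eps = 1e-5

axes = tuple(range(2, x.ndim))                            # Range(2, Rank(x), 1)
mvn = (x - x.mean(axis=axes, keepdims=True)) / \
      np.sqrt(x.var(axis=axes, keepdims=True) + eps)      # eps_mode='inside_sqrt'
inst_norm = mvn * scale + bias
assert inst_norm.shape == x.shape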
def replace_op(self, graph: Graph, node: Node):
    node_name = node.soft_get('name', node.id)
    rename_node(node, node_name + '/TBR')
    sqr_node = Mul(graph, {}).create_node()
    reduce_sum_node = ReduceSum(graph, {'keep_dims': node.soft_get('keep_dims', 0),
                                        'axis': node.soft_get('axis', None)}).create_node()
    sqrt_node = create_op_with_const_inputs(graph, Pow, {1: float_array(0.5)})
    rename_node(sqrt_node, node_name)

    # Connect nodes
    node.in_port(0).get_connection().set_destination(sqr_node.in_port(0))
    sqr_node.in_port(0).get_connection().add_destination(sqr_node.in_port(1))
    sqr_node.out_port(0).connect(reduce_sum_node.in_port(0))
    reduce_sum_node.out_port(0).connect(sqrt_node.in_port(0))

    return [sqrt_node.id]
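
# Illustrative NumPy check of the ReduceL2 decomposition above:
# Pow(ReduceSum(Mul(x, x)), 0.5) equals the L2 norm along the reduced axis.
import numpy as np

x = np.array([[3.0, 4.0], [5.0, 12.0]], dtype=np.float32)
decomposed = np.power(np.sum(x * x, axis=1), 0.5)
assert np.allclose(decomposed, np.linalg.norm(x, axis=1))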
def replace_pattern(graph: Graph, match: dict):
    node = match['pool']

    if node.pool_step is None:
        node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])

    # create Reshape before the pooling
    # shape = [in_shape[0], pool_stride, 1, in_shape[1]/pool_stride] (see the Concat wiring below)
    shape = Shape(graph, {}).create_node()
    shape.in_port(0).connect(node.in_port(0).get_source())

    split = create_op_with_const_inputs(graph, VariadicSplit,
                                        {1: int64_array(0), 2: int64_array([1, -1])},
                                        {'out_ports_count': 2}, shape)
    node_pool_stride = Const(graph, {'value': int64_array([node.pool_stride])}).create_node()
    pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]))
    pow_node.in_port(0).connect(node_pool_stride.out_port(0))

    mul = Mul(graph, {}).create_node()
    mul.in_port(0).connect(split.out_port(1))
    mul.in_port(1).connect(pow_node.out_port(0))

    const_1 = Const(graph, {'value': int64_array([1])}).create_node()

    concat = Concat(graph, {'in_ports_count': 4, 'axis': 0}).create_node()
    concat.in_port(0).connect(split.out_port(0))
    concat.in_port(3).connect(mul.out_port(0))
    concat.in_port(2).connect(const_1.out_port(0))
    concat.in_port(1).connect(node_pool_stride.out_port(0))

    reshape_in = Reshape(graph, {'name': '/Reshape/' + node.name}).create_node()
    reshape_in.in_port(1).connect(concat.out_port(0))

    # create Reshape after the pooling
    reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                   {'name': node.name + '/Reshape/'})

    # connect input_reshape_node
    source = node.in_port(0).get_source()
    node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
    reshape_in.in_port(0).connect(source)
    # connect output_reshape_node
    node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
    node.out_port(0).connect(reshape_out.in_port(0))
def replace_pattern(graph: Graph, match: dict):
    node = match['normalize']

    # rename the Normalize node since it will no longer be an output node after the transformation
    output_name = node.soft_get('name', node.id)
    normalizel2_name = output_name + '/normalizel2'
    rename_node(node, normalizel2_name)

    assert node.in_port(0).data.get_shape().size in [2, 3, 4]
    assert node.has_valid('across_spatial')
    assert node.has_valid('channel_shared')
    assert node.has_valid('eps')

    if 'bin' in node.in_edge(1):
        del node.in_edge(1)['bin']

    weights = node.in_port(1).data.get_value()
    assert weights is not None
    # in the code below we intentionally use get_source() to get the out port, because updating
    # the out port will update the Const node 'value' and 'shape' attributes
    if node.channel_shared or all(weights == weights[0]):
        node.in_port(1).get_source().data.set_value(np.array([weights[0]]))
    else:
        new_shape = np.ones(len(node.in_port(0).data.get_shape()), dtype=np.int64)
        new_shape[1] = -1
        node.in_port(1).get_source().data.set_value(np.array(weights).reshape(new_shape))

    mul = Mul(graph, {'name': output_name}).create_node()
    rename_node(mul, output_name)

    if not node.across_spatial:
        axes = int64_array([1])
    else:
        axes = int64_array(np.arange(start=1, stop=node.in_port(0).data.get_shape().size))

    normalizel2 = create_op_with_const_inputs(graph, NormalizeL2Op, {1: axes},
                                              {'eps_mode': 'add', 'eps': node.eps})

    node.out_port(0).get_connection().set_source(mul.out_port(0))
    node.in_port(1).get_connection().get_source().connect(mul.in_port(1))
    normalizel2.out_port(0).connect(mul.in_port(0))
    node.in_port(0).get_connection().set_destination(normalizel2.in_port(0))
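
# Illustrative NumPy sketch of what the NormalizeL2 + Mul pair above computes for a
# non-channel-shared Normalize layer on NCHW input: L2-normalize over the chosen axes, then
# multiply by per-channel weights reshaped to [1, -1, 1, 1] for broadcasting. Values are
# made up for the example.
import numpy as np

x = np.random.rand(1, 3, 4, 4).astype(np.float32)
weights = np.array([0.5, 1.0, 2.0], dtype=np.float32)
eps, across_spatial = 1e-10, False

axes = (1,) if not across_spatial else tuple(range(1, x.ndim))
l2 = x / np.sqrt(np.sum(np.square(x), axis=axes, keepdims=True) + eps)  # eps_mode='add'
out = l2 * weights.reshape([1, -1, 1, 1])
assert out.shape == x.shape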
def find_and_replace_pattern(self, graph: Graph):
    for node in graph.get_op_nodes(op='ThresholdedRelu'):
        name = node.soft_get('name', node.id)

        greater = create_op_with_const_inputs(graph, Greater, {1: float_array([node.alpha])})
        greater.in_port(0).connect(node.in_port(0).get_source())
        float_greater = Cast(graph,
                             {'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
        greater.out_port(0).connect(float_greater.in_port(0))

        mul = Mul(graph, {}).create_node()
        node.out_port(0).get_connection().set_source(mul.out_port(0))
        mul.in_port(0).connect(node.in_port(0).get_source())
        mul.in_port(1).connect(float_greater.out_port(0))

        rename_nodes([(node, name + '/TBR'), (mul, name)])
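
# Illustrative NumPy check of the ThresholdedRelu decomposition above:
# y = x * float(x > alpha).
import numpy as np

x = np.array([-1.0, 0.5, 2.0], dtype=np.float32)
alpha = np.float32(1.0)
decomposed = x * (x > alpha).astype(np.float32)
assert np.array_equal(decomposed, np.float32([0.0, 0.0, 2.0]))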
def replace_pattern(graph: Graph, match: Dict[str, Node]):
    node = match['div']
    power_of_exponent = Const(graph, {'value': np.float64(-1)}).create_node()
    reciprocal = Pow(graph, {'name': node.name + '/reciprocal_'}).create_node()
    mul = Mul(graph, {'name': node.name + '/mul_'}).create_node()

    # Connect nodes
    node.in_port(1).get_connection().set_destination(reciprocal.in_port(0))
    power_of_exponent.out_port(0).connect(reciprocal.in_port(1))
    node.in_port(0).get_connection().set_destination(mul.in_port(1))
    reciprocal.out_port(0).connect(mul.in_port(0))
    node.out_port(0).get_connection().set_source(mul.out_port(0))
def replace_pattern(graph: Graph, match: dict):
    log.debug('================== GNMTBeforeConditionFind ==================')
    input_sequence_lengths = match['Max'].in_port(0).get_source().node
    encoder_sequence_lengths = looking_for_op_in_list(
        [port.node for port in input_sequence_lengths.out_port(0).get_destinations()], 'Identity')

    # Looking for the Sequence_length node in the encoder. The pattern looks like:
    # Sequence_length -> CheckSeqLen -> Max -> Maximum -> Minimum
    check_seq_len = looking_for_op_in_list(
        [port.node for port in encoder_sequence_lengths.out_port(0).get_destinations()], 'Identity')
    max = looking_for_op_in_list(
        [port.node for port in check_seq_len.out_port(0).get_destinations()], 'ReduceMax')
    maximum = max.out_port(0).get_destinations()[0].node
    assert maximum.op == 'Maximum'
    minimum = maximum.out_port(0).get_destinations()[0].node
    assert minimum.op == 'Minimum'
    tensor_seq_len = looking_for_op_in_list(
        [minimum.in_port(port).get_source().node for port in minimum.in_ports()], 'StridedSlice')

    # Create node for multiplying seq_len by 2
    const = Const(graph, {'name': 'FakeSeqLenMultiplyer', 'value': np.array(2)}).create_node()
    mul_op = Mul(graph, {'name': 'FakeSeqLen'}).create_node()

    const.out_port(0).get_connection().set_destination(mul_op.in_port(1))
    tensor_seq_len.out_port(0).get_connection().add_destination(mul_op.in_port(0))

    # Connect seq_len * 2 to the TensorArray from the GNMT loop
    ta_writes = [port.node for port in match['Identity_1'].out_port(0).get_destinations()
                 if port.node.op == 'TensorArrayWriteV3']
    for ta_write in ta_writes:
        ta = ta_write.in_port(0).get_source().node.in_port(0).get_source().node

        ta.in_port(0).disconnect()
        ta.in_port(0).get_connection().set_source(mul_op.out_port(0))

    if not graph.graph['cmd_params'].static_shape:
        log.error(
            "Model can not be translated in a reshape-able way.\n"
            "Model Optimizer key static_shape was turned on to prevent related errors.\n"
            "There will be no success changing input shapes of the model with the help of "
            "InferenceEngine reshape method", extra={'is_warning': True})
        graph.graph['cmd_params'].static_shape = True
def replace_pattern(graph: Graph, match: dict):
    node = match['normalize']
    assert node.in_port(0).data.get_shape().size in [2, 3, 4]
    assert node.has_valid('across_spatial')
    assert node.has_valid('channel_shared')
    assert node.has_valid('eps')

    if 'bin' in node.in_edge(1):
        del node.in_edge(1)['bin']

    weights = node.in_port(1).data.get_value()
    assert weights is not None
    # in the code below we intentionally use get_source() to get the out port, because updating
    # the out port will update the Const node 'value' and 'shape' attributes
    if node.channel_shared or all(weights == weights[0]):
        node.in_port(1).get_source().data.set_value(np.array([weights[0]]))
    else:
        new_shape = np.ones(len(node.in_port(0).data.get_shape()), dtype=np.int64)
        new_shape[1] = -1
        node.in_port(1).get_source().data.set_value(np.array(weights).reshape(new_shape))

    mul = Mul(graph, {'name': node.name + '/Normalize_weights_multiplication'}).create_node()

    node.out_port(0).get_connection().set_source(mul.out_port(0))
    node.out_port(0).connect(mul.in_port(0))

    node.in_port(1).get_connection().get_source().connect(mul.in_port(1))
    node.in_port(1).disconnect()

    node['type'] = 'NormalizeL2'
    node['eps_mode'] = 'add'
    node['force_precision_in_ports'] = {1: 'int64'}

    axes_val = np.array([1]) if not node.across_spatial else \
        np.arange(start=1, stop=node.in_port(0).data.get_shape().size)
    axes = Const(graph, {'value': axes_val}).create_node()
    node.in_port(1).connect(axes.out_port(0))

    del node['across_spatial']
    del node['channel_shared']
def replace_pattern(graph: Graph, match: dict):
    node = match['conv']
    node_name = node.soft_get('name', node.id)

    # create Reshape before convolution
    # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
    shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
    shape.in_port(0).connect(node.in_port(0).get_source())

    split = create_op_with_const_inputs(graph, VariadicSplit,
                                        {1: int64_array(0), 2: int64_array([1, -1])},
                                        {'name': shape.name + '/split_batch', 'out_ports_count': 2}, shape)

    pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]),
                                                {'name': node_name + '/patch_stride/inverse'})
    conv_patch_stride = Const(graph, {'value': int64_array([node.patch_stride]),
                                      'name': node_name + '/patch_stride/'}).create_node()
    pow_node.in_port(0).connect(conv_patch_stride.out_port(0))

    mul = Mul(graph, {'name': node_name + '/mul_inverse_stride_h'}).create_node()
    mul.in_port(0).connect(split.out_port(1))
    mul.in_port(1).connect(pow_node.out_port(0))

    concat = create_op_with_const_inputs(graph, Concat, {2: int64_array([1])},
                                         {'name': node_name + '/concat_all_dims',
                                          'in_ports_count': 4, 'axis': 0})
    concat.in_port(0).connect(split.out_port(0))
    concat.in_port(1).connect(mul.out_port(0))
    concat.in_port(3).connect(conv_patch_stride.out_port(0))

    reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node()
    reshape_in.in_port(1).connect(concat.out_port(0))

    # create Reshape after Convolution
    reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                   {'name': node_name + '/reshape_out'})

    # connect input_reshape_node
    source = node.in_port(0).get_source()
    node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
    reshape_in.in_port(0).connect(source)
    # connect output_reshape_node
    node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
    node.out_port(0).connect(reshape_out.in_port(0))
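
# Illustrative NumPy check of the shape arithmetic assembled above: a flattened [N, C]
# input is reshaped to [N, C/patch_stride, 1, patch_stride] before the convolution. In the
# graph the division is expressed as Mul by Pow(patch_stride, -1); here plain integer
# division is used, and the concrete numbers are made up.
import numpy as np

in_shape = np.array([8, 40], dtype=np.int64)   # [N, C]
patch_stride = 10

new_shape = np.array([in_shape[0], in_shape[1] // patch_stride, 1, patch_stride], dtype=np.int64)
assert np.array_equal(new_shape, [8, 4, 1, 10])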
def replace_op(self, graph: Graph, node: Node):
    # split the input into (i_part, f_part, c_part, o_part, ct_1)
    split_node_axis = Const(graph, {'value': np.int64(1)}).create_node()
    split_node = Split(graph, {'name': 'Split_lstm_input_', 'num_splits': 5}).create_node()
    node.in_port(0).get_connection().set_destination(split_node.in_port(0))
    split_node.in_port(1).connect(split_node_axis.out_port(0))

    # i_t = Sigmoid(i_part + w_ic*ct_1)
    i_scale_attrs = {'name': 'i_scaleshift', 'bias_term': False}
    i_scale = ScaleShiftOp(graph, i_scale_attrs).create_node()
    input_as_const(i_scale, i_scale_attrs, 1, 'weights', node.i_weights)
    split_node.out_port(4).connect(i_scale.in_port(0))

    sum_i_c = Add(graph, {'name': 'sum_i_c_'}).create_node()
    split_node.out_port(0).connect(sum_i_c.in_port(0))
    i_scale.out_port(0).connect(sum_i_c.in_port(1))

    i_sigmoid = Sigmoid(graph, {'name': 'i_sigmoid'}).create_node()
    sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))

    # f_t = Sigmoid(f_part + w_fc*ct_1)
    f_scale_attrs = {'name': 'f_scaleshift', 'bias_term': False}
    f_scale = ScaleShiftOp(graph, f_scale_attrs).create_node()
    input_as_const(f_scale, f_scale_attrs, 1, 'weights', node.f_weights)
    split_node.out_port(4).connect(f_scale.in_port(0))

    sum_f_c = Add(graph, {'name': 'sum_f_c_'}).create_node()
    split_node.out_port(1).connect(sum_f_c.in_port(0))
    f_scale.out_port(0).connect(sum_f_c.in_port(1))

    f_sigmoid = Sigmoid(graph, {'name': 'f_sigmoid'}).create_node()
    sum_f_c.out_port(0).connect(f_sigmoid.in_port(0))

    # c_t = f_t*ct_1 + i_t*Tanh(c_part)
    c_tanh = Tanh(graph, {'name': 'c_tanh'}).create_node()
    split_node.out_port(2).connect(c_tanh.in_port(0))

    prod_i_c_tanh = Mul(graph, {'name': 'prod_i_c_tanh_'}).create_node()
    i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
    c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))

    prod_f_ct_1 = Mul(graph, {'name': 'prod_f_ct_1_'}).create_node()
    f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
    split_node.out_port(4).connect(prod_f_ct_1.in_port(1))

    sum_f_i = Add(graph, {'name': 'sum_f_i_'}).create_node()
    prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
    prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))

    # o_t = Sigmoid(o_part + w_oc*c_t)
    o_scale_attrs = {'name': 'o_scaleshift', 'bias_term': False}
    o_scale = ScaleShiftOp(graph, o_scale_attrs).create_node()
    input_as_const(o_scale, o_scale_attrs, 1, 'weights', node.o_weights)
    sum_f_i.out_port(0).connect(o_scale.in_port(0))

    sum_o_c = Add(graph, {'name': 'sum_o_c_'}).create_node()
    split_node.out_port(3).connect(sum_o_c.in_port(0))
    o_scale.out_port(0).connect(sum_o_c.in_port(1))

    o_sigmoid = Sigmoid(graph, {'name': 'o_sigmoid'}).create_node()
    sum_o_c.out_port(0).connect(o_sigmoid.in_port(0))

    # m_t = o_t * Tanh(c_t)
    c_t_tanh = Tanh(graph, {'name': 'c_t_tanh'}).create_node()
    sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))

    prod_o_c_t_tanh = Mul(graph, {'name': 'prod_o_c_t_tanh_'}).create_node()
    o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
    c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))

    # add Concat to create a single output
    concat = Concat(graph, {'name': 'Concat_c_m'}).create_node()
    concat.add_sequence_of_ports('in', range(2))
    sum_f_i.out_port(0).connect(concat.in_port(0))
    prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))

    return [concat.id]
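
# Illustrative NumPy sketch of the LSTM cell equations wired up above (the peephole weights
# w_ic/w_fc/w_oc act element-wise on the cell state). Shapes and values are made up.
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

n = 4
i_part, f_part, c_part, o_part, ct_1 = (np.random.rand(n).astype(np.float32) for _ in range(5))
w_ic, w_fc, w_oc = (np.random.rand(n).astype(np.float32) for _ in range(3))

i_t = sigmoid(i_part + w_ic * ct_1)
f_t = sigmoid(f_part + w_fc * ct_1)
c_t = f_t * ct_1 + i_t * np.tanh(c_part)
o_t = sigmoid(o_part + w_oc * c_t)
m_t = o_t * np.tanh(c_t)

out = np.concatenate([c_t, m_t])   # the final Concat packs cell state and output together
assert out.shape == (2 * n,)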
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    log.debug('UpsampleToResample is triggered')
    upsample = match['upsample']
    input_shape = upsample.in_port(0).data.get_shape()
    input_shape_rank = len(input_shape)
    if input_shape_rank not in [4, 5]:
        log.warning('The input shape is not 4D or 5D for op {}'.format(upsample.soft_get('name')))
        return

    if len(upsample.in_nodes()) == 2:
        if upsample.in_node(1).value is None:
            return
        scales = upsample.in_node(1).value
        assert scales.shape == (4,)
        if not (math.isclose(scales[0], 1, rel_tol=1e-5) and math.isclose(scales[1], 1, rel_tol=1e-5)):
            return
        height_scale = scales[2]
        width_scale = scales[3]
    else:
        height_scale = upsample['height_scale']
        width_scale = upsample['width_scale']

    if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
        upsample.in_port(1).disconnect()

    factor = Const(graph, {'value': np.array([height_scale, width_scale])}).create_node()

    shape = Shape(graph, {'name': upsample.name + '/0_port'}).create_node()

    layout = graph.graph['layout']
    if input_shape_rank == 4:
        begin = Const(graph, {'value': int64_array([get_height_dim(layout, input_shape_rank)])}).create_node()
    else:
        begin = Const(graph, {'value': int64_array([get_depth_dim(layout, input_shape_rank)])}).create_node()
    end = Const(graph, {'value': int64_array([get_width_dim(layout, input_shape_rank) + 1])}).create_node()

    stride = Const(graph, {'value': int64_array([1])}).create_node()
    ss = StridedSlice(graph, {'name': upsample.name + '/ss_0_port',
                              'begin_mask': np.array([1]),
                              'end_mask': np.array([0]),
                              'new_axis_mask': np.array([0]),
                              'shrink_axis_mask': int64_array([0]),
                              'ellipsis_mask': int64_array([0])}).create_node()

    mul = Mul(graph, {'name': upsample.name + '/factor_mul_'}).create_node()

    source = upsample.in_port(0).get_connection().get_source()
    source.connect(shape.in_port(0))
    shape.out_port(0).connect(ss.in_port(0))
    begin.out_port(0).connect(ss.in_port(1))
    end.out_port(0).connect(ss.in_port(2))
    stride.out_port(0).connect(ss.in_port(3))
    ss.out_port(0).connect(mul.in_port(0))
    factor.out_port(0).connect(mul.in_port(1))

    # Create Interpolate operation
    if input_shape_rank == 4:
        axes = int64_array([get_height_dim(layout, input_shape_rank),
                            get_width_dim(layout, input_shape_rank)])
    else:
        axes = int64_array([get_depth_dim(layout, input_shape_rank),
                            get_height_dim(layout, input_shape_rank),
                            get_width_dim(layout, input_shape_rank)])

    resample_op = Interpolate(graph, dict(name='Interpolate/{}'.format(upsample.name),
                                          axes=axes, mode=upsample.attrs()['mode'],
                                          antialias=0, convert_to_resample=True)).create_node()

    upsample.add_input_port(1, skip_if_exist=True)
    assert upsample.in_port(1).disconnected()
    mul.out_port(0).connect(resample_op.in_port(1))

    upsample.in_port(0).get_connection().set_destination(resample_op.in_port(0))
    upsample.out_port(0).get_connection().set_source(resample_op.out_port(0))
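
# Illustrative NumPy check of the target-size computation wired up above for a 4D NCHW
# input: the H..W slice of the data shape is multiplied by [height_scale, width_scale].
import numpy as np

input_shape = np.array([1, 3, 224, 224], dtype=np.int64)
height_scale, width_scale = 2.0, 2.0

spatial = input_shape[2:4]                               # StridedSlice over the shape
new_sizes = (spatial * np.array([height_scale, width_scale])).astype(np.int64)
assert np.array_equal(new_sizes, [448, 448])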
def replace_pattern(self, graph: Graph, match: dict):
    merge = match['merge']
    power = Pow(graph, {'name': merge.name + '/reciprocal_', 'type': 'PNORM'}).create_node()
    const1 = Const(graph, {'value': -1.0, 'name': merge.name + '/negate_const'}).create_node()
    merge.in_port(0).get_connection().set_destination(power.in_port(0))
    const1.out_port(0).connect(power.in_port(1))

    concat_node = Concat(graph, {'axis': 0, 'name': merge.name + '/Concat_',
                                 'override_output_shape': True}).create_node()
    const3 = Const(graph, {'name': merge.name + '/const_reduce', 'value': 0}).create_node()

    for ii, idx in enumerate(range(merge.significant, merge.to_significant + 1, 1)):
        const_node = Const(graph, {'value': float_array(math.pow(10.0, idx)),
                                   'name': merge.name + '/Const_' + str(ii)}).create_node()

        mul_node = Mul(graph, {'name': merge.name + '/Mul_' + str(ii)}).create_node()
        const_node.out_port(0).connect(mul_node.in_port(0))

        power.out_port(0).connect(mul_node.in_port(1))  # connect to the graph node
        mul_node2 = Mul(graph, {'name': merge.name + '/Mul_Div_' + str(ii)}).create_node()

        const_node2 = Const(graph, {'value': float_array(math.pow(10.0, -1 * idx)),
                                    'name': merge.name + '/Const_Pow_' + str(ii)}).create_node()
        cast_node = Cast(graph, {'name': merge.name + '/Cast_' + str(idx),
                                 'dst_type': np.float32}).create_node()
        mul_node.out_port(0).connect(cast_node.in_port(0))
        const_node2.out_port(0).connect(mul_node2.in_port(1))
        cast_node.out_port(0).connect(mul_node2.in_port(0))
        concat_node.add_input_port(ii, skip_if_exist=True)
        concat_node.in_port(ii).get_connection().set_source(mul_node2.out_port(0))

    reducesum_node = ReduceMean(graph, {'name': merge.id + '/_pnorm_reduced_sum',
                                        'keep_dims': False,
                                        'in_ports_count': 2,
                                        'need_shape_inference': None,
                                        'infer': reduce_infer}).create_node()
    const3.out_port(0).connect(reducesum_node.in_port(1))
    reducesum_node.in_port(0).get_connection().set_source(concat_node.out_port(0))

    reshape = Reshape(graph, {'name': merge.name + '/Reshape_Node'}).create_node()
    reshape_dim = Const(graph, {'value': np.array([1, 5]), 'name': merge.id + '/Reshape_Dim'}).create_node()
    reducesum_node.out_port(0).connect(reshape.in_port(0))
    reshape.in_port(1).connect(reshape_dim.out_port(0))
    merge.out_port(0).get_connection().set_source(reshape.out_port(0))
def replace_pattern(self, graph: Graph, match: dict):
    assert match['operator'].has('multiplication_transparent_ports')

    quantize = match['quantize']

    port = match['operator'].input_ports_with(match['quantized'])
    assert len(port) >= 1
    if len(port) > 1:
        log.debug('BinarizeWeightsM1P1 cannot apply the transformation for data {} because it is consumed'
                  ' more than once'.format(match['quantized'].name))
        return

    assert len(port) == 1
    port = port[0]
    applicable = [pair for pair in match['operator'].multiplication_transparent_ports if pair[0] == port]
    if len(applicable) == 0:
        return

    # Look at the 3-rd and 4-th inputs of FakeQuantize -- they have constants that should be passed through.
    # Assume that the constant that should be passed through is a scalar.
    output_low = quantize.in_node(3)
    output_high = quantize.in_node(4)
    assert len(output_low.out_nodes()) == 1
    assert len(output_high.out_nodes()) == 1

    # both constants must carry a value, since both are examined below
    if not output_low.has_valid('value') or not output_high.has_valid('value'):
        return

    output_low = output_low.value
    output_high = output_high.value

    operator = match['operator']

    weights = operator.in_node(1).value
    weights_rounded = np.round(weights)
    weights_consistent = np.all(np.isclose(weights, weights_rounded)) and \
        set(np.unique(weights_rounded)).issubset({-1, 1})

    if weights_consistent and np.all(np.isclose(output_low, 0)) and np.all(np.isclose(output_high, 1)):
        # The input is quantized to {0, 1} while BinaryConvolution works in the {-1, +1} domain;
        # with a = (a' + 1) / 2 the original result is recovered as y = 0.5 * (y' + sum(weights)),
        # hence the Add of per-channel weight sums followed by the Mul by 0.5 below.
        reduction_indices = set(range(len(weights.shape))) - set([operator.output_feature_channel])
        weights_reduced = np.add.reduce(weights, axis=tuple(reduction_indices))
        weights_reduced = weights_reduced.reshape([len(weights_reduced), 1, 1])  # FIXME: works for NCHW only

        add_term = Const(graph, {'value': weights_reduced}).create_node()
        add = Add(graph, {}).create_node()
        add.in_port(1).connect(add_term.out_port(0))
        mul_term = Const(graph, {'value': np.array(0.5)}).create_node()
        mul = Mul(graph, {}).create_node()
        mul.in_port(1).connect(mul_term.out_port(0))
        add.out_port(0).connect(mul.in_port(0))
        operator.out_port(0).get_connection().set_source(mul.out_port(0))
        add.in_port(0).connect(operator.out_port(0))

        operator['pad_value'] = float(-1.0)
    elif weights_consistent and np.all(np.isclose(output_low, -1)) and np.all(np.isclose(output_high, +1)):
        pass
    else:
        log.debug('ConvToBinaryConv: cannot apply the transformation because the input range is neither in '
                  '[0, +1] nor in [-1, +1].')
        return

    operator['type'] = 'BinaryConvolution'
    operator['mode'] = 'xnor-popcount'
    operator['pad_value'] = operator.soft_get('pad_value', float(0))
    operator['input'] = operator.in_node(0).shape[1]
    # Weights are not bit-packed yet; there should be a separate transformation to do that

    assert output_low.size == 1
    assert output_high.size == 1

    output_low = quantize.in_node(3)
    output_high = quantize.in_node(4)

    # Make sure that low/high values are exactly 0/1
    output_low.value = np.zeros(output_low.shape)
    output_high.value = np.ones(output_high.shape)
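
# Illustrative NumPy check of the output compensation used above: for ±1 weights and 0/1
# activations with a = (a' + 1) / 2, dot(w, a) == 0.5 * (dot(w, a') + sum(w)). The vectors
# here are made up; in the pass the sum runs per output channel of the convolution.
import numpy as np

w = np.array([1, -1, 1, 1], dtype=np.float32)
a = np.array([1, 0, 1, 0], dtype=np.float32)   # quantized to {0, 1}
a_pm1 = 2 * a - 1                              # remapped to {-1, +1}

assert np.isclose(np.dot(w, a), 0.5 * (np.dot(w, a_pm1) + w.sum()))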