def calculate_prior_box_value(value: Node, value_to_div: Port, value_to_add: Port):
    """
    Create the sub-graph that computes a lower and an upper bound around a value:
        min = value[value_to_add] - (value[value_to_div] / 2)
        max = value[value_to_add] + (value[value_to_div] / 2)

    :param value: node supplying the graph and the name prefix of the new nodes (op='Split' is expected)
    :param value_to_div: output port with values to be divided by 2
    :param value_to_add: output port with values the halved values are subtracted from / added to
    :return: tuple of the created Sub (min) and Add (max) nodes
    """
    graph = value.graph
    np_dtype = data_type_str_to_np(graph.graph['cmd_params'].data_type)

    # min branch: value_to_add - value_to_div / 2
    lower = Sub(graph, {'name': value.name + '/Sub'}).create_node()
    two = np.array([2], dtype=np_dtype)
    half = create_op_node_with_second_input(graph, Div, two, op_attrs={'name': value.name + '/Div'})
    half.in_port(0).connect(value_to_div)
    lower.in_port(0).connect(value_to_add)
    lower.in_port(1).connect(half.out_port(0))

    # max branch: value_to_div / 2 + value_to_add (the halved value is shared with the min branch)
    upper = Add(graph, {'name': value.name + '/Add'}).create_node()
    upper.in_port(0).connect(half.out_port(0))
    upper.in_port(1).connect(value_to_add)

    return lower, upper
def replace_op(self, graph: Graph, node: Node):
    """
    Replace the node with the chain MVN -> Mul -> Add.

    Input 0 of the node becomes the MVN input, input 1 becomes the second Mul input and
    input 2 becomes the second Add input.

    :param graph: graph the new nodes are created in
    :param node: node being replaced; its 'epsilon' and 'name' attributes are read
    :return: list with the id of the final Add node
    """
    prefix = node.name + '/Ins_Norm/'

    # New nodes
    mvn = MVN(graph, dict(eps=node.epsilon, name=prefix + 'MVN_')).create_node()
    mul = Mul(graph, dict(axis=1, name=prefix + 'mul_')).create_node()
    add = Add(graph, dict(axis=1, name=prefix + 'add_')).create_node()

    # Move the three inputs of the original node onto the new nodes
    new_destinations = (mvn.in_port(0), mul.in_port(1), add.in_port(1))
    for port_idx, destination in enumerate(new_destinations):
        node.in_port(port_idx).get_connection().set_destination(destination)

    # Chain the new nodes together
    mvn.out_port(0).connect(mul.in_port(0))
    mul.out_port(0).connect(add.in_port(0))

    return [add.id]
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Replace the matched operation (ImageScaler, per the original comment) with a Mul -> Add sequence.

    A Mul is created only when `op.scale` is not all ones and an Add only when `op.bias`
    is not all zeros, so useless multiplications/additions are not inserted.

    :param graph: graph to modify
    :param match: match dictionary; the operation is taken from key 'op'
    :return: None
    """
    op = match['op']

    # Weights of all ones and biases of all zeros do nothing, so they are skipped
    has_weights = not all(x == 1 for x in np.nditer(op.scale))
    has_bias = not all(x == 0 for x in np.nditer(op.bias))

    assert len(op.in_ports()) == 1

    # Port producing the current tail of the chain; starts at the producer of the op input
    tail_port = op.in_port(0).get_source()

    if has_weights:
        scale_const = Const(graph, dict(value=op.scale, shape=op.scale.shape)).create_node()
        scale_mul = Mul(graph, dict(name=op.id + '/mul_')).create_node()
        op.in_port(0).get_connection().set_destination(scale_mul.in_port(0))
        scale_const.out_port(0).connect(scale_mul.in_port(1))
        tail_port = scale_mul.out_port(0)

    if has_bias:
        bias_const = Const(graph, dict(value=op.bias, shape=op.bias.shape)).create_node()
        bias_add = Add(graph, dict(name=op.id + '/add_')).create_node()
        tail_port.get_connection().set_destination(bias_add.in_port(0))
        bias_const.out_port(0).connect(bias_add.in_port(1))
        tail_port = bias_add.out_port(0)

    # Detach the original op and let its consumers read from the end of the new chain
    op.in_port(0).disconnect()
    op.out_port(0).get_connection().set_source(tail_port)
def create_bias_node(graph: Graph, src_node):
    """
    Insert an Add node with an all-zero constant bias right after `src_node`.

    The bias constant has shape [1, C, 1, ...] where C is dimension 1 of the `src_node` output
    shape. Its dtype is taken from the weights of `src_node` when they are found and their data
    type is defined, otherwise np.float32 is used.

    :param graph: graph to modify
    :param src_node: node whose output 0 is re-routed through the new Add node
    :return: None
    """
    logger.debug('Creating new bias for {}'.format(src_node.name))

    # Remember the consumers before output 0 is disconnected
    consumers = list(src_node.out_port(0).get_destinations())

    # Zero bias broadcastable along every axis except axis 1
    out_shape = src_node.out_port(0).data.get_shape()
    zero_bias_shape = [1] * len(out_shape)
    zero_bias_shape[1] = out_shape[1]

    weights = get_weights_for_node(src_node)
    if weights and weights.out_port(0).is_data_type_defined():
        bias_dtype = weights.out_port(0).get_data_type()
    else:
        bias_dtype = np.float32

    bias_const = Const(graph, {
        'value': np.zeros(zero_bias_shape, dtype=bias_dtype),
        'shape': zero_bias_shape,
        'need_shape_inference': True,
    }).create_node()
    bias_add = Add(graph, {
        'name': src_node.name + '/add_',
        'need_shape_inference': True,
    }).create_node()

    # Const -> Add
    bias_add.in_port(1).connect(bias_const.out_port(0))

    # src_node -> consumers becomes src_node -> Add -> consumers
    src_node.out_port(0).disconnect()
    src_node.out_port(0).get_connection().set_destination(bias_add.in_port(0))
    for consumer in consumers:
        bias_add.out_port(0).connect(consumer)

    bias_const.out_node(0)['Insert_Convert_operation_after'] = True
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
    """
    Normalize the matched fully connected node into MatMul (+ Add for biases).

    Steps performed:
      * a connected bias input (port 2) is moved to a separate Add node placed after the node;
      * weights (port 1) get a Reshape to [-1, out-size], or to [out-size, -1] followed by a
        Transpose with order [1, 0] when 'transpose_weights' is set;
      * for 'caffe' and 'mxnet' graphs the data input (port 0) gets a Reshape to [0, -1];
      * the node attributes are updated with MatMul.update_node_stat.

    :param graph: graph to modify
    :param match: match dictionary; the node is taken from key 'op'
    :return: None
    """
    node = match['op']
    name = node.soft_get('name', node.id)

    # biases normalization
    if 2 in node.in_ports() and not node.in_port(2).disconnected():
        bias_node = Add(graph, {'name': name + '/Bias_'}).create_node()
        if not graph.graph['cmd_params'].generate_deprecated_IR_V7:
            # Use the soft-fetched name (falls back to node.id) instead of node.name so that
            # nodes without an explicit 'name' attribute are handled as well
            rename_nodes([(node, name + '/WithoutBiases'), (bias_node, name)])
        node.out_port(0).get_connection().set_source(bias_node.out_port(0))
        node.in_port(2).get_connection().set_destination(bias_node.in_port(1))
        node.out_port(0).connect(bias_node.in_port(0))

    # weights normalization
    assert node.has_valid('out-size')
    out_size = node['out-size']
    transpose_weights = node.has_and_set('transpose_weights')
    reshape_dim = int64_array([out_size, -1]) if transpose_weights else int64_array([-1, out_size])
    node.insert_op_on_input_port(in_port_idx=1, new_op_class=Reshape,
                                 new_op_attrs={'name': name + '/weights_reshape'}, value=reshape_dim)
    if transpose_weights:
        node.insert_op_on_input_port(in_port_idx=1, new_op_class=Transpose,
                                     new_op_attrs={'name': name + '/weights_transpose'},
                                     value=int64_array([1, 0]))

    # input normalization for 4D Caffe and MxNet FullyConnected
    if graph.graph['fw'] in ['caffe', 'mxnet']:
        node.insert_op_on_input_port(in_port_idx=0, new_op_class=Reshape,
                                     new_op_attrs={'name': name + '/flatten_fc_input'},
                                     value=int64_array([0, -1]))

    MatMul.update_node_stat(node, {})
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
    """
    Decompose the matched node into MatMul -> [Mul(alpha)] -> Add(bias), with an optional
    Mul(beta) on the bias input.

    The bias input (port 2) is moved to a new Add node placed after the node. 'alpha' and
    'beta' multipliers are inserted only when the attribute is valid and not close to 1;
    the attribute is deleted from the node after it has been materialized.

    :param graph: graph to modify
    :param match: match dictionary; the node is taken from key 'op'
    :return: None
    """
    node = match['op']
    name = node.soft_get('name', node.id)

    # biases normalization
    bias_node = Add(graph, {'name': name + '/Bias_', 'can_be_scaleshift': False}).create_node()
    if not graph.graph['cmd_params'].generate_deprecated_IR_V7:
        # Use the soft-fetched name (falls back to node.id) instead of node.name so that
        # nodes without an explicit 'name' attribute are handled as well
        rename_nodes([(node, name + '/WithoutBiases'), (bias_node, name)])
    node.out_port(0).get_connection().set_source(bias_node.out_port(0))
    node.in_port(2).get_connection().set_destination(bias_node.in_port(1))
    node.out_port(0).connect(bias_node.in_port(0))

    # alpha scales the MatMul result, i.e. input 0 of the bias Add
    if node.has_valid('alpha') and not math.isclose(node.alpha, 1):
        bias_node.insert_op_on_input_port(in_port_idx=0, new_op_class=Mul, value=np.array(node.alpha),
                                          new_op_attrs={'name': name + '/Alpha_',
                                                        'can_be_scaleshift': False})
        del node['alpha']

    # beta scales the bias, i.e. input 1 of the bias Add
    if node.has_valid('beta') and not math.isclose(node.beta, 1):
        bias_node.insert_op_on_input_port(in_port_idx=1, new_op_class=Mul, value=np.array(node.beta),
                                          new_op_attrs={'name': name + '/Beta_',
                                                        'can_be_scaleshift': False})
        del node['beta']

    MatMul.update_node_stat(node, {
        'transpose_a': node.has_and_set('transpose_a'),
        'transpose_b': node.has_and_set('transpose_b'),
    })
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Replace the matched operation with a chain of up to three element-wise nodes:
    Mul (by 'scale'), Add (of 'shift') and Pow (to 'power').

    A step is created only when the attribute differs from its neutral value
    (scale != 1, shift != 0, power != 1). Consumers of the operation are re-attached
    to the end of the chain.

    :param graph: graph to modify
    :param match: match dictionary; the operation is taken from key 'op'
    :return: None
    """
    op = match['op']

    def append_step(src_port, op_class, const_value, suffix):
        # Create `op_class(src_port, Const(const_value))` and return its output port
        const_node = Const(graph, {'value': np.array(const_value)}).create_node()
        step_node = op_class(graph, {'name': op.name + suffix}).create_node()
        const_node.out_port(0).connect(step_node.in_port(1))
        src_port.connect(step_node.in_port(0))
        return step_node.out_port(0)

    # Start from the producer of the operation input and extend the chain step by step
    tail_port = op.in_port(0).get_source()

    if op.soft_get('scale', 1) != 1:
        tail_port = append_step(tail_port, Mul, op.scale, '/mul_')

    if op.soft_get('shift', 0) != 0:
        tail_port = append_step(tail_port, Add, op.shift, '/add_')

    if op.soft_get('power', 1) != 1:
        tail_port = append_step(tail_port, Pow, op.power, '/pow_')

    op.out_port(0).get_connection().set_source(tail_port)
def replace_pattern(self, graph: Graph, match: dict):
    """
    Replace the matched BiasAdd node with an Add node.

    For the 'NCHW' data format an Unsqueeze is additionally inserted on the bias input so
    that the 1D bias gets a unit dimension for every input axis except the feature axis.

    :param graph: graph to modify
    :param match: match dictionary; the node is taken from key 'BiasAdd'
    :return: None
    """
    bias_add = match['BiasAdd']

    # Replace BiasAdd by Add operation
    add_node = Add(graph, {'name': bias_add.id + '/Add'}).create_node()
    for port_idx in (0, 1):
        bias_add.in_port(port_idx).get_connection().set_destination(add_node.in_port(port_idx))
    bias_add.out_port(0).get_connection().set_source(add_node.out_port(0))

    # Nothing more to do unless the data format is NCHW
    if bias_add.data_format != 'NCHW':
        return

    input_shape = add_node.in_port(0).data.get_shape()
    bias_shape = add_node.in_port(1).data.get_shape()
    assert len(bias_shape) == 1

    # All input axes except the feature one have to be added to the bias
    all_axes = np.arange(len(input_shape))
    features_axis = get_features_dim('NCHW', len(input_shape))
    axes_to_add = np.delete(all_axes, features_axis, 0)

    unsqueeze = Unsqueeze(graph, {'name': add_node.id + '/BiasUnsqueeze'}).create_node()
    axes_const = Const(graph, {'name': add_node.id + '/Dims', 'value': axes_to_add}).create_node()

    # Reconnecting nodes
    unsqueeze.in_port(1).connect(axes_const.out_port(0))
    unsqueeze['override_output_shape'] = True

    add_node.in_port(1).get_connection().insert_node(unsqueeze)
def add_constant_to_negative_values(node: Node, port_idx: int, added_value: np.array):
    """
    Add the given value to the negative elements of the data coming to an input port.

    The created sub-graph computes `x + Less(x, 0) * added_value` and feeds the result to
    `node.in_port(port_idx)`, where `x` is the original producer of that port.

    :param node: node with corrected values in the input port port_idx
    :param port_idx: input port index for negative values
    :param added_value: the value to add
    :return: None
    """
    graph = node.graph
    source_port = node.in_port(port_idx).get_source()
    source_node = source_port.node
    source_name = source_node.soft_get('name', source_node.id)

    # Less(x, 0) -- mask of negative elements
    zero = np.array(0, dtype=added_value.dtype)
    is_negative = create_op_with_const_inputs(graph, Less, {1: zero}, {'name': source_name + '/Less'})
    # mask * added_value -- correction that is non-zero only for negative elements
    correction = create_op_with_const_inputs(graph, Mul, {1: added_value}, {'name': source_name + '/Mul'})

    node.in_port(port_idx).get_connection().set_destination(is_negative.in_port(0))
    is_negative.out_port(0).connect(correction.in_port(0))

    # x + correction goes to the port that used to receive x
    corrected = Add(graph, {}).create_node()
    correction.out_port(0).connect(corrected.in_port(1))
    source_port.connect(corrected.in_port(0))
    corrected.out_port(0).connect(node.in_port(port_idx))
def sub_to_add_replacement(sub: Node):
    """
    Replace a Sub node with the equivalent Add(A, Mul(B, -1)) sub-graph.

    The Add node takes over the name of the Sub node, and the Sub node is renamed with the
    '/to_be_removed' suffix. Nothing is done when both inputs have known values.

    :param sub: Sub node to replace
    :return: None
    """
    # This transformation is executed for V10 IR later on the middle phase despite graph_condition,
    # so Sub nodes of shape-calculating sub-graphs (both inputs constant) are left untouched
    both_inputs_known = all(sub.in_port(idx).data.get_value() is not None for idx in (0, 1))
    if both_inputs_known:
        return

    graph = sub.graph
    name = sub.soft_get('name', sub.id)

    # The output tensors are mathematically equal, so Add keeps the original Sub name
    rename_node(node=sub, name=name + '/to_be_removed')
    add = Add(graph, {'name': name}).create_node()
    rename_node(add, name)

    # Move inputs and output of Sub to Add
    for idx in (0, 1):
        sub.in_port(idx).get_connection().set_destination(add.in_port(idx))
    sub.out_port(0).get_connection().set_source(add.out_port(0))

    # Sub(A, B) = Add(A, Mul(B, -1)): negate the second input
    minus_one = np.array(-1, dtype=sub.soft_get('data_type', np.float32))
    negate = create_op_with_const_inputs(graph, Mul, {1: minus_one}, {'name': name + '/neg_'})
    add.in_port(1).get_connection().insert_node(negate)
def get_range_node_of_idxs(rank: Node, begin: int, end: int, include_begin: bool = True,
                           include_end: bool = False) -> Node:
    """
    Build a Range node producing the 1D sequence of indices from begin to end with step 1.

    :param rank: the node of 0D output shape to get rank of tensor from
    :param begin: integer value from [-rank; rank - 1]
    :param end: integer value from [-rank; +rank]
    :param include_begin: when False, the range starts at begin + 1
    :param include_end: when True, the range stops at end + 1, i.e. end itself is produced
    :return: range node producing 1D output
    """
    graph = rank.graph
    name = rank.soft_get('name', rank.id)

    def plus_one(src_node: Node, suffix: str) -> Node:
        # Return an Add node computing `src_node + 1`
        one = Const(graph, {'value': int64_array(1), 'name': name + suffix + '/value'}).create_node()
        incremented = Add(graph, {'name': name + suffix}).create_node()
        src_node.out_port(0).connect(incremented.in_port(0))
        one.out_port(0).connect(incremented.in_port(1))
        return incremented

    range_start = get_canonical_axis_index_node(rank, begin)
    range_stop = get_canonical_axis_index_node(rank, end)

    if not include_begin:
        range_start = plus_one(range_start, '/exclude_begin')

    if include_end:
        range_stop = plus_one(range_stop, '/including_end')

    step = Const(graph, {'name': name + '/delta', 'value': int64_array(1)}).create_node()
    range_node = Range(graph, {'name': name + '/range_idxs'}).create_node()

    # Range inputs: 0 - start, 1 - stop, 2 - step
    for port_idx, producer in enumerate((range_start, range_stop, step)):
        producer.out_port(0).connect(range_node.in_port(port_idx))

    return range_node
def replace_op(self, graph: Graph, node: Node):
    """
    Replace the node computing A - B with Add(Mul(B, -1), A).

    :param graph: graph the new nodes are created in
    :param node: node being replaced; input 0 is A and input 1 is B
    :return: list with the id of the Add node
    """
    base_name = node.name

    # New nodes: int32 constant -1, the negating Mul and the final Add
    minus_one = Const(graph, {'value': np.array(-1, dtype=np.int32)}).create_node()
    negate = Mul(graph, dict(name=base_name + '/negate_')).create_node()
    add = Add(graph, dict(name=base_name + '/add_')).create_node()

    # B * (-1)
    node.in_port(1).get_connection().set_destination(negate.in_port(0))
    minus_one.out_port(0).connect(negate.in_port(1))

    # (-B) + A
    node.in_port(0).get_connection().set_destination(add.in_port(1))
    negate.out_port(0).connect(add.in_port(0))

    # The "explicit" version of the return value is: [(out_node.id, 0)])
    return [add.id]
def replace_pattern(graph: Graph, match: [str, Node]):
    """
    Replace the matched Sub node computing A - B with Add(Mul(B, -1), A).

    Input 1 of the Sub node (B) is multiplied by a constant -1, the product goes to input 0
    of the new Add node and input 0 of the Sub node (A) goes to input 1 of the Add node.
    Consumers of the Sub node are re-attached to the Add node.

    :param graph: graph to modify
    :param match: match dictionary; the Sub node is taken from key 'sub'
    :return: None
    """
    node = match['sub']

    # Add new nodes
    negate_const = Const(graph, dict(name=node.name + '/negate_const', value=np.array(-1))).create_node()
    negate = Mul(graph, {'name': node.name + '/negate_'}).create_node()
    add = Add(graph, {'name': node.name + '/add_'}).create_node()

    # Connect nodes
    node.in_port(1).get_connection().set_destination(negate.in_port(0))
    # The -1 constant is the multiplier of the negation, so it must feed the Mul node;
    # input 1 of Add is reserved for the first Sub operand connected below
    negate_const.out_port(0).connect(negate.in_port(1))
    node.in_port(0).get_connection().set_destination(add.in_port(1))
    negate.out_port(0).connect(add.in_port(0))

    node.out_port(0).get_connection().set_source(add.out_port(0))
def replace_op(self, graph: Graph, node: Node):
    """
    Replace the node with Log(Add(1, x)).

    The constant 1 uses the 'data_type' attribute of the node when it is valid,
    np.float32 otherwise.

    :param graph: graph the new nodes are created in
    :param node: node being replaced; input 0 is x
    :return: list with the id of the Log node
    """
    one_dtype = node.data_type if node.has_valid('data_type') else np.float32

    one = Const(graph, dict(value=np.array([1], dtype=one_dtype))).create_node()
    plus_one = Add(graph, dict(name=node.name + '/Add_')).create_node()
    log_node = LogOp(graph, dict(name=node.name + '/Log_')).create_node()

    # Connect nodes: input -> Add -> Log
    one.out_port(0).connect(plus_one.in_port(0))
    node.in_port(0).get_connection().set_destination(plus_one.in_port(1))
    plus_one.out_port(0).connect(log_node.in_port(0))

    # The "explicit" version of the return value is: [(out_node.id, 0)])
    return [log_node.id]
def find_and_replace_pattern(self, graph: Graph):
    """
    Decompose every 'Gemm' operation of the graph into MatMul with optional
    Mul(alpha) on the output and Add(bias), where the bias may be scaled by Mul(beta).

    'alpha' and 'beta' are materialized only when the attribute is valid and not close to 1;
    the attribute is removed from the node afterwards. The bias part is created only when
    input port 2 is connected. The Add node takes over the original node name.

    :param graph: graph to modify
    :return: None
    """
    for gemm in graph.get_op_nodes(op='Gemm'):
        name = gemm.soft_get('name', gemm.id)
        # Port currently producing the result of the decomposed sub-graph
        result_port = gemm.out_port(0)

        if gemm.has_valid('alpha') and not math.isclose(gemm.alpha, 1):
            alpha_attrs = {'name': name + '/Alpha', 'can_be_scaleshift': False}
            alpha_mul = create_op_with_const_inputs(graph, Mul, {1: np.array(gemm.alpha)}, alpha_attrs)
            result_port.get_connection().insert_node(alpha_mul)
            result_port = alpha_mul.out_port(0)
            del gemm['alpha']

        if gemm.is_in_port_connected(2):
            # biases normalization
            bias_add = Add(graph, {'name': name + '/Bias_', 'can_be_scaleshift': False}).create_node()
            rename_nodes([(gemm, name + '/WithoutBiases'), (bias_add, name)])
            result_port.get_connection().set_source(bias_add.out_port(0))
            gemm.in_port(2).get_connection().set_destination(bias_add.in_port(1))
            result_port.connect(bias_add.in_port(0))

            if gemm.has_valid('beta') and not math.isclose(gemm.beta, 1):
                beta_attrs = {'name': name + '/Beta', 'can_be_scaleshift': False}
                bias_add.insert_op_on_input_port(in_port_idx=1, new_op_class=Mul,
                                                 value=np.array(gemm.beta), new_op_attrs=beta_attrs)
                del gemm['beta']

        matmul_attrs = {
            'transpose_a': gemm.has_and_set('transpose_a'),
            'transpose_b': gemm.has_and_set('transpose_b'),
        }
        MatMul.update_node_stat(gemm, matmul_attrs)
def _fused_batch_norm_decomposition(graph: Graph, tinput: Port, toutput: Port, gamma: Port, beta: Port,
                                    mean: np.ndarray, variance: np.ndarray, can_be_fused=True):
    """
    This is common function for TF, Caffe and MXNet
    It creates Mul->Add->Mul->Add sub graph

    The first Mul/Add pair uses constants created from `mean` and `variance`; the second pair
    uses the producers connected to the `gamma` and `beta` ports.

    :param graph: graph to modify
    :param tinput: input port of the batch norm node that receives the data
    :param toutput: output port of the batch norm node; its consumers are moved to the last Add
    :param gamma: port whose connection is re-targeted to the second Mul
    :param beta: port whose connection is re-targeted to the second Add
    :param mean: value of the constant multiplied with the input in the first Mul
    :param variance: value of the constant added in the first Add
    :param can_be_fused: value of the 'can_be_fused' attribute set on all created Mul/Add nodes
    :return: None
    """
    batch_norm_name = tinput.get_connection().get_destination().node.name

    # Create first Mul & Add operations
    mul1_node = Mul(graph, dict(name=batch_norm_name + "/mean", can_be_fused=can_be_fused)).create_node()
    add1_node = Add(graph, dict(name=batch_norm_name + "/variance", can_be_fused=can_be_fused)).create_node()

    const_mul1_node = Const(graph, dict(name="data_mul_", value=np.array(mean))).create_node()
    const_add1_node = Const(graph, dict(name="data_add_", value=np.array(variance))).create_node()

    # Broadcast const from scalar
    # We can broadcast only when const.value is scalar
    gamma_shape = gamma.data.get_shape()
    gamma_value = gamma.data.get_value()
    if gamma_shape[0] != gamma_value.shape[0]:
        # ndarray.resize() works in place and returns None, so chaining .fill() on its result
        # raises AttributeError; build the broadcast array explicitly instead
        broadcast_value = np.full(tuple(gamma_shape), gamma_value[0], dtype=gamma_value.dtype)
        gamma.data.set_value(broadcast_value)

    # Create second Mul & Add
    mul2_node = Mul(graph, dict(name=batch_norm_name + "/gamma", can_be_fused=can_be_fused)).create_node()
    add2_node = Add(graph, dict(name=batch_norm_name + "/beta", can_be_fused=can_be_fused)).create_node()

    # Connect edges Mul1->Add1->Mul2->Add2
    tinput.get_connection().set_destination(mul1_node.in_port(0))
    mul1_node.in_port(1).get_connection().set_source(const_mul1_node.out_port(0))

    add1_node.in_port(0).get_connection().set_source(mul1_node.out_port(0))
    add1_node.in_port(1).get_connection().set_source(const_add1_node.out_port(0))

    mul2_node.in_port(0).get_connection().set_source(add1_node.out_port(0))
    gamma.get_connection().set_destination(mul2_node.in_port(1))

    add2_node.in_port(0).get_connection().set_source(mul2_node.out_port(0))
    beta.get_connection().set_destination(add2_node.in_port(1))

    toutput.get_connection().set_source(add2_node.out_port(0))
def replace_pattern(graph: Graph, match: dict):
    """
    Move the bias input (port 2) of the matched node into a separate Add node.

    The Add node takes over the node name and the node gets the '/WithoutBiases' suffix.
    When the bias producer is a Const, the input of Add has more than 2 dimensions and the
    graph layout is 'NCHW', a Reshape to [C, 1, ..., 1] is inserted on the bias input.

    :param graph: graph to modify
    :param match: match dictionary; the node is taken from key 'op'
    :return: None
    """
    node = match['op']

    has_bias_input = node.has_port('in', 2) and not node.in_port(2).disconnected()
    if not (has_bias_input and not node.has_and_set('shape_input')):
        return

    bias_name = node.name
    new_node_name = node.name + '/WithoutBiases'
    add = Add(graph, dict(name=bias_name)).create_node()
    rename_nodes([(node, new_node_name), (add, bias_name)])

    # node -> consumers becomes node -> Add -> consumers, with the bias as the second Add input
    node.out_port(0).get_connection().set_source(add.out_port(0))
    node.out_port(0).connect(add.in_port(0))
    node.in_port(2).get_connection().set_destination(add.in_port(1))

    bias = add.in_port(1).get_source().node
    if not (bias.has_valid("type") and bias.type == "Const"):
        return

    input_shape = add.in_port(0).data.get_shape()
    if len(input_shape) <= 2:
        return

    # Trailing unit dimensions are needed for the NCHW layout only
    if graph.graph['layout'] == 'NCHW':
        dims_to_add = len(input_shape) - 2
    else:
        dims_to_add = 0
    if dims_to_add <= 0:
        return

    target_shape = np.array([input_shape[1]] + [1] * dims_to_add, dtype=np.int64)
    reshape = create_op_node_with_second_input(graph, Reshape, target_shape, {'name': node.id + '/Dims'})
    add.in_port(1).get_connection().set_destination(reshape.in_port(0))
    reshape.out_port(0).connect(add.in_port(1))
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Replace the matched node (inputs: data, begin, size) with a Slice node (inputs: data, begin, end).

    The created sub-graph computes `ends = size + begin`, casts it to int64 and, for the
    elements where `size == -1`, selects the corresponding value of ShapeOf(data) instead.
    The Slice node takes over the original node name.

    :param graph: graph to modify
    :param match: match dictionary; the node is taken from key 'op'
    :return: None
    """
    tf_slice = match['op']
    name = tf_slice.soft_get('name', tf_slice.id)

    slice_node = Slice(graph).create_node()
    rename_nodes([(tf_slice, name + '/to_be_removed'), (slice_node, name)])
    ends = Add(graph, {'name': name + '/ends'}).create_node()

    # reconnect input, begin, and size from TFSlice to the subgraph with Slice
    tf_slice.in_port(0).get_connection().set_destination(slice_node.in_port(0))
    tf_slice.in_port(1).get_connection().set_destination(slice_node.in_port(1))
    tf_slice.in_port(2).get_connection().set_destination(ends.in_port(0))
    slice_node.in_port(1).get_connection().add_destination(ends.in_port(1))

    # shape of the data gives the end values used where size[i] == -1
    shape_of_input = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
    slice_node.in_port(0).get_connection().add_destination(shape_of_input.in_port(0))

    # check if size[i] == -1, will be applied elementwisely: len(size) = len(begin) = input_rank
    size_is_minus_one = create_op_with_const_inputs(
        graph, Equal, {0: int64_array(-1)}, {'name': name + '/where_max_ends_is_needed'})
    ends.in_port(0).get_connection().add_destination(size_is_minus_one.in_port(1))

    # select requires equal dtypes, need to convert ends to I64
    cast_attrs = {'name': name + '/CastToI64', 'dst_type': np.int64}
    ends_i64 = Cast(graph, cast_attrs).create_node([ends])

    # if size[i] == -1 then take the value from the input shape, otherwise the computed end
    select_inputs = [size_is_minus_one, shape_of_input, ends_i64]
    chosen_ends = Select(graph, {'name': name + '/chosen_ends'}).create_node(select_inputs)
    chosen_ends.out_port(0).connect(slice_node.in_port(2))

    tf_slice.out_port(0).get_connection().set_source(slice_node.out_port(0))
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Replace the matched node (inputs: data, begin, size) with a Slice node (inputs: data, begin, end).

    The created sub-graph computes `end = begin + size`, casts it to int64 and, for the
    elements where `size == -1`, selects the maximum int32 value instead.
    The Slice node takes over the original node name.

    :param graph: graph to modify
    :param match: match dictionary; the node is taken from key 'op'
    :return: None
    """
    tf_slice = match['op']
    name = tf_slice.soft_get('name', tf_slice.id)

    slice_node = Slice(graph).create_node()
    rename_nodes([(tf_slice, name + '/to_be_removed'), (slice_node, name)])

    size_is_minus_one = Equal(graph, dict(name=name + '/equal')).create_node()
    minus_one = Const(graph, dict(name=name + '/minus_one', value=np.array(-1))).create_node()
    int32_max = Const(graph, dict(name=name + '/int32_max', value=np.iinfo(np.int32).max)).create_node()
    select = Select(graph, dict(name=name + '/select')).create_node()

    # node to convert sizes to ends
    ends = Add(graph, dict(name=name + '/end_const')).create_node()

    # data: TFSlice input 0 -> Slice input 0
    tf_slice.in_port(0).get_source().connect(slice_node.in_port(0))
    tf_slice.in_port(0).disconnect()

    # begin: TFSlice input 1 -> Slice input 1
    tf_slice.in_port(1).get_source().connect(slice_node.in_port(1))
    tf_slice.in_port(1).disconnect()

    # ends = begin + size
    slice_node.in_port(1).get_source().connect(ends.in_port(0))
    tf_slice.in_port(2).get_source().connect(ends.in_port(1))
    tf_slice.in_port(2).disconnect()

    # if size[i] == -1 then take int32_max as end[i]
    ends.in_port(1).get_source().connect(size_is_minus_one.in_port(0))
    minus_one.out_port(0).connect(size_is_minus_one.in_port(1))

    # Select(condition, int32_max, ends) -> Slice input 2
    size_is_minus_one.out_port(0).connect(select.in_port(0))
    int32_max.out_port(0).connect(select.in_port(1))
    ends.out_port(0).connect(select.in_port(2))
    select.out_port(0).connect(slice_node.in_port(2))

    # the computed ends are cast to int64 on their way to Select
    cast = Cast(graph, {'name': ends.name + '/CastToI64', 'dst_type': np.int64}).create_node()
    select.in_port(2).get_connection().insert_node(cast)

    tf_slice.out_port(0).get_connection().set_source(slice_node.out_port(0))
def replace_pattern(self, graph: Graph, match: dict):
    """
    Feed inputs 1 and 2 of the matched 'quantize' node with the mean of their original producers.

    A sub-graph computing `(in1 + in2) * 0.5` is created and its single output is connected
    to both input 1 and input 2 of the node.

    :param graph: graph to modify
    :param match: match dictionary; the node is taken from key 'quantize'
    :return: None
    """
    quantize = match['quantize']
    threshold_ports = (1, 2)

    total = Add(graph, {}).create_node()
    half = Const(graph, dict(value=np.array(0.5))).create_node()
    mean = Mul(graph, {}).create_node()

    # mean = total * 0.5
    mean.in_port(0).connect(total.out_port(0))
    mean.in_port(1).connect(half.out_port(0))

    # total = in1 + in2
    for sum_port, quantize_port in enumerate(threshold_ports):
        quantize.in_port(quantize_port).get_connection().get_source().connect(total.in_port(sum_port))

    for quantize_port in threshold_ports:
        quantize.in_port(quantize_port).disconnect()

    # both inputs now receive the same averaged value
    for quantize_port in threshold_ports:
        mean.out_port(0).connect(quantize.in_port(quantize_port))
def replace_op(self, graph: Graph, node: Node):
    """
    Replace the node with Pow(Add(A, Mul(B, -1)), 2), i.e. the squared difference of its inputs.

    :param graph: graph the new nodes are created in
    :param node: node being replaced; input 0 is A and input 1 is B
    :return: list with the id of the Pow node
    """
    base_name = node.name

    # Create nodes
    minus_one = Const(graph, {'value': np.array(-1), 'name': base_name + '/negate_const_'}).create_node()
    negate = Mul(graph, dict(name=base_name + '/negate_')).create_node()
    difference = Add(graph, dict(name=base_name + '/add_')).create_node()
    two = Const(graph, dict(value=np.array(2))).create_node()
    squared = Pow(graph, dict(name=base_name + '/squared_')).create_node()

    # A -> Add, B -> Mul(-1) -> Add
    node.in_port(0).get_connection().set_destination(difference.in_port(0))
    node.in_port(1).get_connection().set_destination(negate.in_port(0))
    minus_one.out_port(0).connect(negate.in_port(1))
    negate.out_port(0).connect(difference.in_port(1))

    # (A - B) ^ 2
    difference.out_port(0).connect(squared.in_port(0))
    two.out_port(0).connect(squared.in_port(1))

    # The "explicit" version of the return value is: [(out_node.id, 0)])
    return [squared.id]
def replace_op(self, graph: Graph, node: Node):
    """
    Replace the node with the chain MVN -> Mul -> Add.

    The MVN axes are produced by Range(2, Rank(input), 1). Input 1 of the node becomes the
    second Mul input and input 2 becomes the second Add input. The Add node takes over the
    original node name and the original node gets the '/TBD' suffix.

    :param graph: graph the new nodes are created in
    :param node: node being replaced; its 'epsilon' attribute is read
    :return: list with the id of the final Add node
    """
    name = node.soft_get('name', node.id)

    # create range of axes for MVN based on `start_axis` and rank of input
    rank = Rank(graph, dict(name=name + '/Rank')).create_node()
    range_consts = {0: int64_array(2), 2: int64_array(1)}
    range_attrs = {'name': name + '/Range', 'output_type': np.int64}
    axes = create_op_with_const_inputs(graph, Range, range_consts, range_attrs)

    mvn_attrs = {
        'eps': node.epsilon,
        'eps_mode': 'inside_sqrt',
        'normalize_variance': 1,
        'name': name + '/Ins_Norm/MVN_',
    }
    mvn = MVN(graph, mvn_attrs).create_node()
    node.in_port(0).get_connection().set_destination(mvn.in_port(0))
    axes.out_port(0).connect(mvn.in_port(1))

    scale = Mul(graph, dict(axis=1, name=name + '/Ins_Norm/mul_')).create_node()
    mvn.out_port(0).connect(scale.in_port(0))
    node.in_port(1).get_connection().set_destination(scale.in_port(1))

    shift = Add(graph, dict(axis=1, name=name + '/Ins_Norm/add_')).create_node()
    scale.out_port(0).connect(shift.in_port(0))
    node.in_port(2).get_connection().set_destination(shift.in_port(1))

    # Rank reads the same data as MVN and provides the stop value of the axes range
    mvn.in_port(0).get_connection().add_destination(rank.in_port(0))
    axes.in_port(1).connect(rank.out_port(0))

    rename_nodes([(node, name + '/TBD'), (shift, name)])
    return [shift.id]
def get_canonical_axis_index_node(rank: Node, axis: int) -> Node:
    """
    Build a node producing the non-negative equivalent of the axis.

    For a non-negative axis a constant with the axis value is created; for a negative axis
    an Add node computing `rank + axis` is created.

    :param rank: the node of 0D output shape to get rank of tensor from
    :param axis: integer value from [-rank; rank - 1]
    :return: node producing positive integer value of axis
    """
    graph = rank.graph
    name = rank.soft_get('name', rank.id)

    if axis >= 0:
        # Already canonical -- a plain constant is enough
        return Const(graph, {'name': name + '/positive_axis', 'value': int64_array(axis)}).create_node()

    # Negative axis counts from the end: canonical value is rank + axis
    negative_axis = Const(graph, {'name': name + '/negative_axis', 'value': int64_array(axis)}).create_node()
    canonical_axis = Add(graph, {'name': name + '/positive_axis'}).create_node()
    rank.out_port(0).connect(canonical_axis.in_port(0))
    negative_axis.out_port(0).connect(canonical_axis.in_port(1))
    return canonical_axis
def replace_op(self, graph: Graph, node: Node):
    """
    Replace the node with an explicit LSTM non-linearity sub-graph.

    The input is split along axis 1 into 5 parts (i_part, f_part, c_part, o_part, ct_1) and
    the following is computed:
        i_t = Sigmoid(i_part + w_ic * ct_1)
        f_t = Sigmoid(f_part + w_fc * ct_1)
        c_t = f_t * ct_1 + i_t * Tanh(c_part)
        o_t = Sigmoid(o_part + w_oc * c_t)
        m_t = o_t * Tanh(c_t)
    The outputs c_t and m_t are concatenated into a single output.

    :param graph: graph the new nodes are created in
    :param node: node being replaced; its 'i_weights', 'f_weights' and 'o_weights' attributes are read
    :return: list with the id of the Concat node
    """
    # split input to (i_part, f_part, c_part, o_part, ct_1)
    split_axis = Const(graph, {'value': np.int64(1)}).create_node()
    splitter = Split(graph, {'name': 'Split_lstm_input_', 'num_splits': 5}).create_node()
    node.in_port(0).get_connection().set_destination(splitter.in_port(0))
    splitter.in_port(1).connect(split_axis.out_port(0))

    def build_gate(gate, part_idx, weights, state_node, state_port_idx):
        # gate_t = Sigmoid(splitter[part_idx] + weights * state_node[state_port_idx])
        scale_attrs = {'name': gate + '_scaleshift', 'bias_term': False}
        scale = ScaleShiftOp(graph, scale_attrs).create_node()
        input_as_const(scale, scale_attrs, 1, 'weights', weights)
        state_node.out_port(state_port_idx).connect(scale.in_port(0))

        gate_sum = Add(graph, {'name': 'sum_' + gate + '_c_'}).create_node()
        splitter.out_port(part_idx).connect(gate_sum.in_port(0))
        scale.out_port(0).connect(gate_sum.in_port(1))

        sigmoid = Sigmoid(graph, {'name': gate + '_sigmoid'}).create_node()
        gate_sum.out_port(0).connect(sigmoid.in_port(0))
        return sigmoid

    # i_t = Sigmoid(i_part + w_ic*ct_1)
    i_sigmoid = build_gate('i', 0, node.i_weights, splitter, 4)

    # f_t = Sigmoid(f_part + w_fc*ct_1)
    f_sigmoid = build_gate('f', 1, node.f_weights, splitter, 4)

    # c_t = f_t*ct_1 + i_t * tanh(c_part)
    c_tanh = Tanh(graph, {'name': 'c_tanh'}).create_node()
    splitter.out_port(2).connect(c_tanh.in_port(0))

    input_contribution = Mul(graph, {'name': 'prod_i_c_tanh_'}).create_node()
    i_sigmoid.out_port(0).connect(input_contribution.in_port(0))
    c_tanh.out_port(0).connect(input_contribution.in_port(1))

    forget_contribution = Mul(graph, {'name': 'prod_f_ct_1_'}).create_node()
    f_sigmoid.out_port(0).connect(forget_contribution.in_port(0))
    splitter.out_port(4).connect(forget_contribution.in_port(1))

    cell_state = Add(graph, {'name': 'sum_f_i_'}).create_node()
    forget_contribution.out_port(0).connect(cell_state.in_port(0))
    input_contribution.out_port(0).connect(cell_state.in_port(1))

    # o_t = Sigmoid(o_part + w_oc*c_t)
    o_sigmoid = build_gate('o', 3, node.o_weights, cell_state, 0)

    # m_t = o_t * Tanh(c_t)
    cell_tanh = Tanh(graph, {'name': 'c_t_tanh'}).create_node()
    cell_state.out_port(0).connect(cell_tanh.in_port(0))
    hidden_state = Mul(graph, {'name': 'prod_o_c_t_tanh_'}).create_node()
    o_sigmoid.out_port(0).connect(hidden_state.in_port(0))
    cell_tanh.out_port(0).connect(hidden_state.in_port(1))

    # add concat to create 1 output
    concat = Concat(graph, {'name': 'Concat_c_m'}).create_node()
    concat.add_sequence_of_ports('in', range(2))
    cell_state.out_port(0).connect(concat.in_port(0))
    hidden_state.out_port(0).connect(concat.in_port(1))

    return [concat.id]
def replace_pattern(self, graph: Graph, match: dict):
    """
    Convert the matched operator into a 'BinaryConvolution' with mode 'xnor-popcount'.

    The transformation is applied when the weights (input 1 of the operator) consist of
    -1/+1 values only and the output range of the quantize node (inputs 3 and 4) is either
    [0, 1] or [-1, +1]. For the [0, 1] range an `(x + sum(weights)) * 0.5` correction
    sub-graph is appended after the operator and 'pad_value' is set to -1.
    Finally the output low/high constants of the quantize node are overwritten with
    exact zeros/ones.

    :param graph: graph to modify
    :param match: match dictionary; keys 'operator', 'quantize' and 'quantized' are used
    :return: None
    """
    assert match['operator'].has('multiplication_transparent_ports')

    quantize = match['quantize']

    port = match['operator'].input_ports_with(match['quantized'])
    assert len(port) >= 1
    if len(port) > 1:
        log.debug('BinarizeWeightsM1P1 cannot apply transformation for data {} because it consumed more'
                  ' than once'.format(match['quantized'].name))
        return

    assert len(port) == 1
    port = port[0]
    applicable = [pair for pair in match['operator'].multiplication_transparent_ports if pair[0] == port]
    if len(applicable) == 0:
        return

    # Look at 3-rd and 4-th inputs of FakeQuantize -- they have constants that should be passed through.
    # Assume that the constant that should be passed through is a scalar.
    output_low = quantize.in_node(3)
    output_high = quantize.in_node(4)
    assert len(output_low.out_nodes()) == 1
    assert len(output_high.out_nodes()) == 1

    # Both values are dereferenced below, so bail out when either of them is unavailable
    if not output_low.has_valid('value') or not output_high.has_valid('value'):
        return

    output_low = output_low.value
    output_high = output_high.value

    operator = match['operator']

    weights = operator.in_node(1).value
    weights_rounded = np.round(weights)
    weights_consistent = np.all(np.isclose(weights, weights_rounded)) and \
        set(np.unique(weights_rounded)).issubset({-1, 1})

    if weights_consistent and np.all(np.isclose(output_low, 0)) and np.all(np.isclose(output_high, 1)):
        # Sum the weights over every axis except the output feature channel
        reduction_indices = set(range(len(weights.shape))) - {operator.output_feature_channel}
        weights_reduced = np.add.reduce(weights, axis=tuple(reduction_indices))
        weights_reduced = weights_reduced.reshape([len(weights_reduced), 1, 1])  # FIXME: works for NCHW only

        add_term = Const(graph, {'value': weights_reduced}).create_node()
        add = Add(graph, {}).create_node()
        add.in_port(1).connect(add_term.out_port(0))
        mul_term = Const(graph, {'value': np.array(0.5)}).create_node()
        mul = Mul(graph, {}).create_node()
        mul.in_port(1).connect(mul_term.out_port(0))
        add.out_port(0).connect(mul.in_port(0))

        # operator -> consumers becomes operator -> Add -> Mul -> consumers
        operator.out_port(0).get_connection().set_source(mul.out_port(0))
        add.in_port(0).connect(operator.out_port(0))

        operator['pad_value'] = float(-1.0)
    elif weights_consistent and np.all(np.isclose(output_low, -1)) and np.all(np.isclose(output_high, +1)):
        pass
    else:
        log.debug('ConvToBinaryConv: cannot apply transformation because input range is neither in [0, +1] nor '
                  'in [-1, +1].')
        return

    operator['type'] = 'BinaryConvolution'
    operator['mode'] = 'xnor-popcount'
    operator['pad_value'] = operator.soft_get('pad_value', float(0))
    operator['input'] = operator.in_node(0).shape[1]
    # Weights are not bit-packed yet; there should be a separate transformation to do that

    assert output_low.size == 1
    assert output_high.size == 1

    output_low = quantize.in_node(3)
    output_high = quantize.in_node(4)

    # Make sure that low/high values are exactly 0/1
    output_low.value = np.zeros(output_low.shape)
    output_high.value = np.ones(output_high.shape)
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    """
    Replace the matched GroupNorm node with the chain
    Reshape -> MVN -> Reshape -> Mul(gamma) -> Add(beta).

    The first Reshape target is assembled from [batch * num_groups, features * (1 / num_groups),
    spatial dims]; the arithmetic is performed on values cast to the command line data type and
    the result is cast to int64. The second Reshape restores the original input shape.

    :param graph: graph being transformed
    :param match: matched sub-graph; only the 'op' entry (the GroupNorm node) is read
    """
    gn = match['op']
    input_rank = len(gn.in_port(0).data.get_shape())

    # Shape of the original input; it also feeds the restoring Reshape at the end
    input_shape = Shape(graph, {'name': gn.name + '/Shape'}).create_node()
    input_shape.in_port(0).connect(gn.in_port(0).get_source())

    input_shape_float = Cast(graph, {
        'name': input_shape.name + '/to_float',
        'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type),
    }).create_node()
    input_shape.out_port(0).connect(input_shape_float.in_port(0))

    batch_dim = node_to_get_batch_value(input_shape_float)
    features_dim = node_to_get_features_dimension_value(input_shape_float)
    spatial_dims_int = node_to_get_spatial_dimensions_value(input_shape)
    spatial_dims = Cast(graph, {
        'name': spatial_dims_int.name + '/to_float',
        'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type),
    }).create_node()
    spatial_dims_int.out_port(0).connect(spatial_dims.in_port(0))

    num_groups_const = Const(graph, {
        'value': int64_array([gn.num_groups]),
        'name': gn.name + '/GroupSize',
    }).create_node()

    # features / num_groups is expressed as a multiplication by the reciprocal
    inv_num_groups_const = Const(graph, {
        'value': np.array([1.0 / gn.num_groups]),
        'name': gn.name + '/ReciprocalGroupSize',
    }).create_node()

    features_per_group = Mul(graph, {}).create_node()
    features_per_group.in_port(0).connect(features_dim.out_port(0))
    features_per_group.in_port(1).connect(inv_num_groups_const.out_port(0))

    batch_times_groups = Mul(graph, {}).create_node()
    batch_times_groups.in_port(0).connect(batch_dim.out_port(0))
    batch_times_groups.in_port(1).connect(num_groups_const.out_port(0))

    # Assemble the target shape for the Reshape placed in front of MVN
    grouped_shape_float = new_shape_node_from_shape_nodes(
        [batch_times_groups, features_per_group, spatial_dims])
    grouped_shape = Cast(graph, {
        'name': grouped_shape_float.name + '/to_int64',
        'dst_type': np.int64,
    }).create_node()
    grouped_shape_float.out_port(0).connect(grouped_shape.in_port(0))

    reshape_before_mvn = Reshape(graph, {}).create_node()
    gn.in_port(0).get_connection().set_destination(reshape_before_mvn.in_port(0))
    reshape_before_mvn.in_port(1).connect(grouped_shape.out_port(0))

    # Reshape the gamma and beta constants from [C] to [1,C], [1,C,1], [1,C,1,1] etc
    broadcast_shape = np.ones([input_rank], dtype=np.int64)
    broadcast_shape[1] = -1

    gamma = gn.in_port(1).get_source().data.get_value()
    beta = gn.in_port(2).get_source().data.get_value()
    assert gamma is not None, 'The gamma should be constant'
    assert beta is not None, 'The beta should be constant'
    gn.in_port(1).get_source().data.set_value(np.reshape(gamma, broadcast_shape))
    gn.in_port(2).get_source().data.set_value(np.reshape(beta, broadcast_shape))

    # MVN
    mvn = MVN(graph, {
        'name': gn.name + '/MVN',
        'normalize_variance': 1,
        'eps': gn.eps,
        'eps_mode': 'inside_sqrt',
    }).create_node()
    mvn.in_port(0).connect(reshape_before_mvn.out_port(0))

    # MVN axes: Range(start=1, limit=rank of the MVN input, step=1)
    _, mvn_input_rank = get_shape_and_rank_nodes_by_port(
        mvn.in_port(0).get_connection().get_source(), return_as_a_scalar=True)
    axes_range = create_op_with_const_inputs(
        graph, Range,
        {0: int64_array(1), 2: int64_array(1)},
        {'name': gn.name + '/Range', 'output_type': np.int64})
    mvn.in_port(1).connect(axes_range.out_port(0))
    axes_range.in_port(1).connect(mvn_input_rank.out_port(0))

    # Restore the initial shape before multiplying with gamma and adding beta
    reshape_after_mvn = Reshape(graph, {}).create_node()
    reshape_after_mvn.in_port(0).connect(mvn.out_port(0))
    reshape_after_mvn.in_port(1).connect(input_shape.out_port(0))

    scale = Mul(graph, {'name': mvn.name + '/Mul'}).create_node()
    scale.in_port(0).connect(reshape_after_mvn.out_port(0))
    gn.in_port(1).get_connection().set_destination(scale.in_port(1))

    shift = Add(graph, {'name': scale.name + '/Add'}).create_node()
    shift.in_port(0).connect(scale.out_port(0))
    gn.in_port(2).get_connection().set_destination(shift.in_port(1))

    # Consumers of the GroupNorm output now read from the final Add
    gn.out_port(0).get_connection().set_source(shift.out_port(0))
def replace_pattern(self, graph: Graph, match: [str, Node]):
    """
    Replace the matched Crop node with a StridedSlice whose begin/end inputs are built
    according to which set of Crop attributes is present:

    * two inputs and 'offset'      -- end = Shape(second input) + begin
    * 'dim' and 'offset'           -- begin and end are constants (end = offset + dim)
    * 'crop_begin' and 'crop_end'  -- end = Shape(first input) + (-1 * crop_end)

    :param graph: graph being transformed
    :param match: matched sub-graph; only the 'crop' entry is read
    :raises Exception: when none of the attribute combinations above is present
    """
    crop = match['crop']
    assert crop.has_valid('axis')
    crop.axis = self.list_to_ndarray(crop.axis)

    rank = crop.in_port(0).data.get_shape().size
    # 1 for every axis that is cropped, 0 otherwise; used for both StridedSlice masks
    cropped_axes = int64_array([1 if dim in crop.axis else 0 for dim in range(rank)])
    begin_mask = cropped_axes.copy()
    end_mask = cropped_axes.copy()

    if len(crop.in_nodes()) == 2 and crop.has_valid('offset'):
        # Crop Type 1
        begin_node = Const(graph, {
            'value': self.mask_normalizer(rank, crop.axis, crop.offset),
        }).create_node()
        reference_shape = Shape(graph, {'name': crop.name + '/shape_of_crop'}).create_node()
        end_node = Add(graph, {'name': crop.name + '/end'}).create_node()

        crop.in_port(1).get_connection().get_source().connect(reference_shape.in_port(0))
        crop.in_port(1).disconnect()
        reference_shape.out_port(0).connect(end_node.in_port(0))
        begin_node.out_port(0).connect(end_node.in_port(1))
    elif crop.has_valid('dim') and crop.has_valid('offset'):
        # Crop Type 2
        crop.dim = self.list_to_ndarray(crop.dim)
        crop.offset = self.list_to_ndarray(crop.offset)
        assert crop.dim.size == crop.offset.size == crop.axis.size

        begin_node = Const(graph, {
            'value': self.mask_normalizer(rank, crop.axis, crop.offset),
        }).create_node()
        end_values = np.array([start + size for start, size in zip(crop.offset, crop.dim)])
        end_node = Const(graph, {
            'value': self.mask_normalizer(rank, crop.axis, end_values),
        }).create_node()
    elif crop.has_valid('crop_begin') and crop.has_valid('crop_end'):
        # Crop Type 3
        crop.crop_begin = self.list_to_ndarray(crop.crop_begin)
        crop.crop_end = self.list_to_ndarray(crop.crop_end)
        assert len(crop.crop_begin) == len(crop.crop_end) == len(crop.axis)

        begin_node = Const(graph, {
            'value': self.mask_normalizer(rank, crop.axis, crop.crop_begin),
        }).create_node()
        data_shape = Shape(graph, {'name': crop.name + '/shape_of_crop'}).create_node()
        negated_crop_end = Const(graph, {
            'value': -1 * self.mask_normalizer(rank, crop.axis, crop.crop_end),
        }).create_node()
        end_node = Add(graph, {'name': crop.name + '/end'}).create_node()

        crop.in_port(0).get_connection().get_source().connect(data_shape.in_port(0))
        data_shape.out_port(0).connect(end_node.in_port(0))
        negated_crop_end.out_port(0).connect(end_node.in_port(1))
    else:
        raise Exception("Unknown type of Crop")

    data_source = crop.in_port(0).get_connection().get_source()

    unit_strides = Const(graph, {'value': np.ones(rank, dtype=np.int64)}).create_node()
    strided_slice = StridedSlice(graph, {
        'name': 'Crop_',
        'begin_mask': begin_mask,
        'end_mask': end_mask,
        'new_axis_mask': np.array([0]),
        'shrink_axis_mask': np.array([0]),
        'ellipsis_mask': np.array([0]),
    }).create_node()

    data_source.connect(strided_slice.in_port(0))
    begin_node.out_port(0).connect(strided_slice.in_port(1))
    end_node.out_port(0).connect(strided_slice.in_port(2))
    unit_strides.out_port(0).connect(strided_slice.in_port(3))

    # Detach Crop from its data input and hand its consumers over to the StridedSlice
    crop.in_port(0).disconnect()
    crop.out_port(0).get_connection().set_source(strided_slice.out_port(0))

    strided_slice['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    """
    Replace the matched GroupNorm node with the chain
    Reshape -> MVN(across_channels=1) -> Reshape -> Mul(gamma) -> Add(beta).

    The first Reshape target is assembled from [batch * num_groups, features * (1 / num_groups),
    spatial dims], all taken from the Shape of the original input. The second Reshape restores
    the original input shape.

    :param graph: graph being transformed
    :param match: matched sub-graph; only the 'op' entry (the GroupNorm node) is read
    """
    gn = match['op']
    input_rank = len(gn.in_port(0).data.get_shape())

    # Shape of the original input; it also feeds the restoring Reshape at the end
    input_shape = Shape(graph, {'name': gn.name + '/Shape'}).create_node()
    input_shape.in_port(0).connect(gn.in_port(0).get_source())

    batch_dim = node_to_get_batch_value(input_shape)
    features_dim = node_to_get_features_dimension_value(input_shape)
    spatial_dims = node_to_get_spatial_dimensions_value(input_shape)

    num_groups_const = Const(graph, {
        'value': int64_array([gn.num_groups]),
        'name': gn.name + '/GroupSize',
    }).create_node()

    # features / num_groups is expressed as a multiplication by the reciprocal
    inv_num_groups_const = Const(graph, {
        'value': np.array([1.0 / gn.num_groups]),
        'name': gn.name + '/ReciprocalGroupSize',
    }).create_node()

    features_per_group = Mul(graph, {}).create_node()
    features_per_group.in_port(0).connect(features_dim.out_port(0))
    features_per_group.in_port(1).connect(inv_num_groups_const.out_port(0))

    batch_times_groups = Mul(graph, {}).create_node()
    batch_times_groups.in_port(0).connect(batch_dim.out_port(0))
    batch_times_groups.in_port(1).connect(num_groups_const.out_port(0))

    # Assemble the target shape for the Reshape placed in front of MVN
    grouped_shape = new_shape_node_from_shape_nodes(
        [batch_times_groups, features_per_group, spatial_dims])

    reshape_before_mvn = Reshape(graph, {}).create_node()
    gn.in_port(0).get_connection().set_destination(reshape_before_mvn.in_port(0))
    reshape_before_mvn.in_port(1).connect(grouped_shape.out_port(0))

    # Reshape the gamma and beta constants from [C] to [1,C], [1,C,1], [1,C,1,1] etc
    broadcast_shape = np.ones([input_rank], dtype=np.int64)
    broadcast_shape[1] = -1

    gamma = gn.in_port(1).get_source().data.get_value()
    beta = gn.in_port(2).get_source().data.get_value()
    assert gamma is not None, 'The gamma should be constant'
    assert beta is not None, 'The beta should be constant'
    gn.in_port(1).get_source().data.set_value(np.reshape(gamma, broadcast_shape))
    gn.in_port(2).get_source().data.set_value(np.reshape(beta, broadcast_shape))

    # MVN
    mvn = MVN(graph, {
        'name': gn.name + '/MVN',
        'across_channels': 1,
        'normalize_variance': 1,
        'eps': gn.eps,
    }).create_node()
    mvn.in_port(0).connect(reshape_before_mvn.out_port(0))

    # Restore the initial shape before multiplying with gamma and adding beta
    reshape_after_mvn = Reshape(graph, {}).create_node()
    reshape_after_mvn.in_port(0).connect(mvn.out_port(0))
    reshape_after_mvn.in_port(1).connect(input_shape.out_port(0))

    scale = Mul(graph, {'name': mvn.name + '/Mul'}).create_node()
    scale.in_port(0).connect(reshape_after_mvn.out_port(0))
    gn.in_port(1).get_connection().set_destination(scale.in_port(1))

    shift = Add(graph, {'name': scale.name + '/Add'}).create_node()
    shift.in_port(0).connect(scale.out_port(0))
    gn.in_port(2).get_connection().set_destination(shift.in_port(1))

    # Consumers of the GroupNorm output now read from the final Add
    gn.out_port(0).get_connection().set_source(shift.out_port(0))
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Give the matched node an explicit second input that carries the target spatial size.

    When input port 1 is absent or disconnected, the size is built from the node attributes:
    'factor' (only if 'width' and 'height' are not set), otherwise 'pads_begin'/'pads_end'
    combined with 'shrink_factor', 'zoom_factor' or 'height'/'width'. When port 1 is already
    connected and the node 'fw' attribute is 'caffe', the producer on port 1 is replaced with
    a slice of its shape.

    :param graph: graph being transformed
    :param match: matched sub-graph; only the 'op' entry is read
    :return: None (also returned early, after logging an error, when a shrink/zoom factor is below 1)
    """
    node = match['op']

    def create_dims_slice():
        # Create (without wiring) Shape and StridedSlice nodes with begin=[2], end=[4], stride=[1]
        shape_of = Shape(graph, {'name': node.name + '/shape'}).create_node()
        slice_begin = Const(graph, {'value': np.array([2])}).create_node()
        slice_end = Const(graph, {'value': np.array([4])}).create_node()
        slice_stride = Const(graph, {'value': np.array([1])}).create_node()
        dims_slice = StridedSlice(graph, {'name': node.name + '/ss_0_port',
                                          'begin_mask': np.array([1]),
                                          'end_mask': np.array([0]),
                                          'new_axis_mask': np.array([0]),
                                          'shrink_axis_mask': np.array([0]),
                                          'ellipsis_mask': np.array([0])}).create_node()
        return shape_of, slice_begin, slice_end, slice_stride, dims_slice

    def wire_dims_slice(producer_port, parts):
        # Connect producer -> Shape -> StridedSlice together with the slice constants
        shape_of, slice_begin, slice_end, slice_stride, dims_slice = parts
        producer_port.connect(shape_of.in_port(0))
        shape_of.out_port(0).connect(dims_slice.in_port(0))
        slice_begin.out_port(0).connect(dims_slice.in_port(1))
        slice_end.out_port(0).connect(dims_slice.in_port(2))
        slice_stride.out_port(0).connect(dims_slice.in_port(3))

    def feed_size_input(producer):
        # Make sure input port 1 exists and is free, then connect the producer to it
        node.add_input_port(1, skip_if_exist=True)
        assert node.in_port(1).disconnected()
        producer.out_port(0).connect(node.in_port(1))

    size_input_missing = 1 not in node.in_ports() or node.in_port(1).disconnected()

    if not size_input_missing:
        if node.soft_get('fw') == 'caffe':
            parts = create_dims_slice()
            size_producer = node.in_port(1).get_connection().get_source()
            node.in_port(1).disconnect()
            wire_dims_slice(size_producer, parts)
            parts[-1].out_port(0).connect(node.in_port(1))
        return

    if node.has_valid('factor') and not node.has_valid('width') and not node.has_valid('height'):
        # Output size = sliced input shape * factor
        factor_const = Const(graph, {'value': np.array(node.factor)}).create_node()
        parts = create_dims_slice()
        scaled = Mul(graph, {'name': node.name + '/factor_mul_'}).create_node()

        wire_dims_slice(node.in_port(0).get_connection().get_source(), parts)
        parts[-1].out_port(0).connect(scaled.in_port(0))
        factor_const.out_port(0).connect(scaled.in_port(1))

        feed_size_input(scaled)
        return

    parts = create_dims_slice()
    wire_dims_slice(node.in_port(0).get_connection().get_source(), parts)

    # Sliced input shape + (pads_begin + pads_end)
    pads_total = node.pads_begin + node.pads_end
    pads_const = Const(graph, {'value': np.array(pads_total)}).create_node()
    padded = Add(graph, {'name': node.name + '/pad_add'}).create_node()
    parts[-1].out_port(0).connect(padded.in_port(0))
    padded.in_port(1).connect(pads_const.out_port(0))

    if node.soft_get('shrink_factor') != 1 and node.soft_get('zoom_factor') == 1:
        shrink_factor = node.shrink_factor
        if shrink_factor < 1:
            log.error('Shrink factor should be positive in node {}'.format(node.id))
            return None

        # (padded - 1) * (1 / shrink_factor) + 1
        minus_one = Const(graph, {'name': node.name + '/pre_shrink_sub_const',
                                  'value': np.array(-1)}).create_node()
        decremented = Add(graph, {'name': node.name + '/pre_shrink_sub'}).create_node()
        padded.out_port(0).connect(decremented.in_port(0))
        decremented.in_port(1).connect(minus_one.out_port(0))

        inv_shrink = Const(graph, {'value': np.array(1 / shrink_factor),
                                   'name': node.name + 'shrink_factor_div_const'}).create_node()
        shrunk = Mul(graph, {'name': node.name + 'shrink_factor_div'}).create_node()
        decremented.out_port(0).connect(shrunk.in_port(0))
        shrunk.in_port(1).connect(inv_shrink.out_port(0))

        plus_one = Const(graph, {'name': node.name + '/shrink_factor_add_one_const',
                                 'value': np.array(1)}).create_node()
        result = Add(graph, {'name': node.name + '/shrink_factor_add_one'}).create_node()
        shrunk.out_port(0).connect(result.in_port(0))
        plus_one.out_port(0).connect(result.in_port(1))

        feed_size_input(result)

    elif node.soft_get('shrink_factor') == 1 and node.soft_get('zoom_factor') != 1:
        zoom_factor = node.zoom_factor
        if zoom_factor < 1:
            log.error('Zoom factor should be positive in node {}'.format(node.id))
            return None

        node['debug_message'] = (
            'Interpolate layer replacer may be wrong, please, try to update it in the'
            ' file (extensions/front/InterpolateNormalizer.py at the line {}).'
            .format(inspect.currentframe().f_lineno)
        ) + refer_to_faq_msg(100)

        # Reshape methods can be different in some cases.
        # The commented out section below represents the reshape used in deeplab-caffe.
        # Uncomment it if your model was trained with deeplab-caffe or has the same reshape method.
        # minus_one = Const(graph, {'value': np.array(-1),
        #                           'name': node.name + 'zoom_factor_deeplab-caffe_sub_const'}).create_node()
        # decremented = Add(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_sub'}).create_node()
        # padded.out_port(0).connect(decremented.in_port(0))
        # minus_one.out_port(0).connect(decremented.in_port(1))
        #
        # zoom_minus_one = Const(graph, {'value': np.array(zoom_factor - 1),
        #                                'name': node.name + 'zoom_factor_deeplab-caffe_mul_const'}).create_node()
        # zoomed = Mul(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_mul'}).create_node()
        # decremented.out_port(0).connect(zoomed.in_port(0))
        # zoom_minus_one.out_port(0).connect(zoomed.in_port(1))
        #
        # total = Add(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_sum'}).create_node()
        # padded.out_port(0).connect(total.in_port(0))
        # zoomed.out_port(0).connect(total.in_port(1))
        #
        # feed_size_input(total)

        # Comment out the following lines if you use the reshape method from the previous section
        zoom_const = Const(graph, {'value': np.array(zoom_factor),
                                   'name': node.name + '/zoom_factor_mul_const'}).create_node()
        zoomed = Mul(graph, {'name': node.name + '/zoom_factor_mul'}).create_node()
        padded.out_port(0).connect(zoomed.in_port(0))
        zoom_const.out_port(0).connect(zoomed.in_port(1))

        feed_size_input(zoomed)

    elif node.soft_get('width') != 0 and node.soft_get('height') != 0:
        # Explicit target size taken from the node attributes
        size_const = Const(graph, {'value': np.array([node.height, node.width])}).create_node()
        feed_size_input(size_const)

    elif node.soft_get('shrink_factor') != 1 and node.soft_get('zoom_factor') != 1:
        shrink_factor = node.shrink_factor
        zoom_factor = node.zoom_factor
        if shrink_factor < 1:
            log.error('Shrink factor should be positive in node {}'.format(node.id))
            return None
        if zoom_factor < 1:
            log.error('Zoom factor should be positive in node {}'.format(node.id))
            return None

        # shrunk = (padded - 1) * (1 / (shrink_factor + 1))
        minus_one = Const(graph, {'value': np.array(-1)}).create_node()
        decremented = Add(graph, {'name': node.name + '/shrink_zoom_factor_sub'}).create_node()
        padded.out_port(0).connect(decremented.in_port(0))
        minus_one.out_port(0).connect(decremented.in_port(1))

        inv_shrink = Const(graph, {'value': np.array(1 / (shrink_factor + 1))}).create_node()
        shrunk = Mul(graph, {'name': node.name + '/shrink_factor_div'}).create_node()
        decremented.out_port(0).connect(shrunk.in_port(0))
        inv_shrink.out_port(0).connect(shrunk.in_port(1))

        # result = shrunk + (shrunk - 1) * (zoom_factor - 1)
        minus_one = Const(graph, {'value': np.array(-1),
                                  'name': node.name + 'shrink_zoom_factor_sum_const'}).create_node()
        shrunk_minus_one = Add(graph, {'name': node.name + '/shrink_zoom_factor_sum'}).create_node()
        shrunk.out_port(0).connect(shrunk_minus_one.in_port(0))
        minus_one.out_port(0).connect(shrunk_minus_one.in_port(1))

        zoom_minus_one = Const(graph, {'value': np.array(zoom_factor - 1)}).create_node()
        zoomed = Mul(graph, {'name': node.name + '/zoom_factor_mul'}).create_node()
        shrunk_minus_one.out_port(0).connect(zoomed.in_port(0))
        zoom_minus_one.out_port(0).connect(zoomed.in_port(1))

        result = Add(graph, {'name': node.name + '/final_shrink_zoom_factor_sum'}).create_node()
        shrunk.out_port(0).connect(result.in_port(0))
        zoomed.out_port(0).connect(result.in_port(1))

        feed_size_input(result)
def replace_op(self, graph: Graph, node: Node):
    """
    Expand the given node into an explicit LSTM cell sub-graph built from FullyConnected,
    ScaleShift, Add, Mul, Tanh, Sigmoid, Split and Clamp nodes, with two ReadValue/Assign pairs
    holding the previous output and the previous state.

    Weights and biases are read from the node attributes 'gifo_x_weights', 'gifo_biases',
    'gifo_r_weights', 'input_gate_weights', 'forget_gate_weights', 'output_gate_weights' and
    'projection_weights'; the state is clamped to [-clip_value, clip_value].

    :param graph: graph being transformed
    :param node: node to be replaced
    :return: list with the id of the final (projection) FullyConnected node
    """

    def unary(op_cls, name, src_port):
        # Create a one-input op and connect its input port 0
        created = op_cls(graph, {'name': name}).create_node()
        created.in_port(0).connect(src_port)
        return created

    def binary(op_cls, name, first_port, second_port):
        # Create a two-input op and connect input ports 0 and 1 (in that order)
        created = op_cls(graph, {'name': name}).create_node()
        created.in_port(0).connect(first_port)
        created.in_port(1).connect(second_port)
        return created

    def scale_shift(name, src_port, weights):
        # ScaleShift without bias whose weights are attached as a constant on port 1
        attrs = {'name': name, 'bias_term': False}
        created = ScaleShiftOp(graph, attrs).create_node()
        created.in_port(0).connect(src_port)
        input_as_const(created, attrs, 1, 'weights', weights)
        return created

    data_port = node.in_port(0).get_source()

    # Variable ids pairing every ReadValue with its Assign
    output_variable_id = unique_id('id')
    state_variable_id = unique_id('id')

    # Input -> FullyConnected
    input_fc_attrs = {
        'name': 'input_fullyconnected',
        'out-size': node.gifo_x_weights_shape[0],
        'transpose_weights': True,
        'bias_term': True,
    }
    input_fc = FullyConnected(graph, input_fc_attrs).create_node()
    input_fc.in_port(0).connect(data_port)
    input_as_const(input_fc, input_fc_attrs, 1, 'weights', node.gifo_x_weights)
    input_as_const(input_fc, input_fc_attrs, 2, 'biases', node.gifo_biases)

    # ReadValue(previous output), initialised by create_zero_value_with_batch_from_input
    prev_output_init = create_zero_value_with_batch_from_input(data_port, node.gifo_r_weights_shape[1])
    prev_output = ReadValue(graph, {
        'name': 'prev_memory_output',
        'variable_id': output_variable_id,
    }).create_node()
    prev_output.in_port(0).connect(prev_output_init.out_port(0))

    # ReadValue(previous output) -> FullyConnected
    recurrent_fc_attrs = {
        'name': 'prev_memory_output_fullyconnected',
        'out-size': node.gifo_r_weights_shape[0],
        'transpose_weights': True,
        'bias_term': False,
    }
    recurrent_fc = FullyConnected(graph, recurrent_fc_attrs).create_node()
    recurrent_fc.in_port(0).connect(prev_output.out_port(0))
    input_as_const(recurrent_fc, recurrent_fc_attrs, 1, 'weights', node.gifo_r_weights)

    # Sum of both FullyConnected outputs
    gates_sum = binary(Add, 'join_input_eltwise', recurrent_fc.out_port(0), input_fc.out_port(0))

    # Split the sum into 4 parts along axis 1; the consumer order below is mandatory:
    # output 0 -> Tanh, outputs 1, 2 and 3 -> the three Add nodes
    split_axis = Const(graph, {'value': np.int64(1)}).create_node()
    gates = Split(graph, {
        'name': 'join_input_split',
        'num_splits': 4,
        'out_ports_count': 4,
    }).create_node()
    gates.in_port(0).connect(gates_sum.out_port(0))
    gates.in_port(1).connect(split_axis.out_port(0))

    # ReadValue(previous state), initialised by create_zero_value_with_batch_from_input
    prev_state_init = create_zero_value_with_batch_from_input(gates.out_port(0),
                                                              node.input_gate_weights.shape[0])
    prev_state = ReadValue(graph, {
        'name': 'prev_memory_state',
        'variable_id': state_variable_id,
    }).create_node()
    prev_state.in_port(0).connect(prev_state_init.out_port(0))

    # Previous state scaled by the input-gate and the forget-gate weights
    state_for_input_gate = scale_shift('input_scaleshift', prev_state.out_port(0), node.input_gate_weights)
    state_for_forget_gate = scale_shift('forget_scaleshift', prev_state.out_port(0), node.forget_gate_weights)

    # Split[1] + ScaleShift(input), Split[2] + ScaleShift(forget)
    input_gate_sum = binary(Add, 'join_prev_lstm_input_joined_input_eltwise',
                            gates.out_port(1), state_for_input_gate.out_port(0))
    forget_gate_sum = binary(Add, 'join_prev_lstm_input_joined_forget_sum',
                             gates.out_port(2), state_for_forget_gate.out_port(0))

    # Activations: Tanh(Split[0]), Sigmoid(input sum), Sigmoid(forget sum)
    candidate_tanh = unary(Tanh, 'remember_tahnv', gates.out_port(0))
    input_gate = unary(Sigmoid, 'remember_sigmoid', input_gate_sum.out_port(0))
    forget_gate = unary(Sigmoid, 'forget_sigmoid', forget_gate_sum.out_port(0))

    # forget gate * previous state, and candidate * input gate
    kept_state = binary(Mul, 'join_forget_prev_state_mul',
                        forget_gate.out_port(0), prev_state.out_port(0))
    new_candidates = binary(Mul, 'join_remember_candidates_mul',
                            candidate_tanh.out_port(0), input_gate.out_port(0))

    # Their sum, clamped to [-clip_value, clip_value]
    state_sum = binary(Add, 'join_forget_remember_sum',
                       kept_state.out_port(0), new_candidates.out_port(0))
    clamped_state = create_op_with_const_inputs(
        graph, Clamp,
        {1: np.array(-node.clip_value, dtype=np.float32),
         2: np.array(node.clip_value, dtype=np.float32)},
        {'name': 'join_forget_clamp'},
        state_sum)

    # Clamp -> Assign(state) -> Result
    next_state = Assign(graph, {
        'name': 'next_lstm_state',
        'variable_id': state_variable_id,
    }).create_node()
    next_state.in_port(0).connect(clamped_state.out_port(0))
    unary(Result, 'next_lstm_state_out', next_state.out_port(0))

    # Clamp -> Tanh
    state_tanh = unary(Tanh, 'state_filtered_tahn', clamped_state.out_port(0))

    # Clamp -> ScaleShift(output gate weights), added to Split[3] and passed through Sigmoid
    state_for_output_gate = scale_shift('clamp_scaleshift', clamped_state.out_port(0),
                                        node.output_gate_weights)
    output_gate_sum = binary(Add, 'join_next_lstm_input_joined_input_sum',
                             gates.out_port(3), state_for_output_gate.out_port(0))
    output_gate = unary(Sigmoid, 'output_sigmoid', output_gate_sum.out_port(0))

    # Tanh(state) * output gate
    gated_output = binary(Mul, 'joined_output_mul',
                          state_tanh.out_port(0), output_gate.out_port(0))

    # Projection FullyConnected; the edge to the next node is created automatically after replacement
    projection_attrs = {
        'name': 'FullyConnected',
        'out-size': node.projection_weights_shape[0],
        'transpose_weights': True,
        'bias_term': False,
    }
    projection = FullyConnected(graph, projection_attrs).create_node()
    projection.in_port(0).connect(gated_output.out_port(0))
    input_as_const(projection, projection_attrs, 1, 'weights', node.projection_weights)

    # Projection -> Assign(output) -> Result
    next_output = Assign(graph, {
        'name': 'next_lstm_output',
        'variable_id': output_variable_id,
    }).create_node()
    next_output.in_port(0).connect(projection.out_port(0))
    unary(Result, 'next_lstm_output_out', next_output.out_port(0))

    return [projection.id]