def extract(cls, node):
    """Map a Caffe ``Scale`` layer onto a ScaleShift node.

    With a single bottom, the layer's blobs (scale and optional bias) are
    converted via ``weights_biases``. If the .caffemodel supplied no blobs,
    identity defaults are used: scale 1 and bias 0.

    With two bottoms and ``bias_term`` set, only a bias blob is embedded, at
    port 1. It defaults to 0 when the .caffemodel has no blobs.

    ``axis`` is copied from ``scale_param``. Returns ``cls.enabled``.
    """
    layer = node.pb
    blob_source = node.model_pb
    scale_param = layer.scale_param
    bottom_count = len(layer.bottom)
    attrs = {'axis': scale_param.axis}

    if blob_source is None and bottom_count == 1:
        # No trained blobs available: use scale = 1, bias = 0.
        blob_source = NamedAttrsClass({
            'blobs': np.array([
                NamedAttrsClass({'data': np.array([1])}),
                NamedAttrsClass({'data': np.array([0])}),
            ])
        })

    if blob_source and len(blob_source.blobs) != 0 and bottom_count == 1:
        # One data input with one or two blobs.
        attrs.update(weights_biases(scale_param.bias_term, blob_source))
    elif bottom_count == 2 and scale_param.bias_term:
        # Two data inputs plus a bias blob.
        if blob_source is None or len(blob_source.blobs) == 0:
            blob_source = NamedAttrsClass({
                'blobs': np.array([NamedAttrsClass({'data': np.array([0])})])
            })
        embed_input(attrs, 1, 'biases', blob_source.blobs[0].data)

    ScaleShiftOp.update_node_stat(node, attrs)
    return cls.enabled
def replace_op(self, graph: Graph, node: Node):
    """Fold a 4-blob ``BN`` layer into a single ScaleShift node.

    The arithmetic treats the blobs as (gamma, beta, mean, variance):

        scale = gamma / sqrt(variance + eps)
        shift = beta - mean * scale

    NOTE(review): this blob order is inferred from the formula above, not from
    the BN layer's proto definition — confirm.

    Raises:
        Error: if the layer does not carry exactly 4 blobs.

    Returns:
        A one-element list with the id of the new ScaleShift node.
    """
    layer_attrs = graph.node[node.id]
    bn_param = layer_attrs['pb'].bn_param
    blobs = layer_attrs['model_pb'].blobs

    if len(blobs) != 4:
        raise Error("Incorrect number of blobs in BN layer {}".format(node.id))

    gamma, beta, mean, variance = [np.array(blob.data) for blob in blobs]
    variance = variance + np.repeat(bn_param.eps, variance.shape)

    scale = 1.0 / np.sqrt(variance) * gamma
    shift = beta - mean * scale

    attrs = {'name': node.id + "/ScaleShift_"}
    embed_input(attrs, 1, 'scale', scale, 'weights')
    embed_input(attrs, 2, 'bias', shift, 'biases')

    scale_shift = ScaleShiftOp(graph, attrs).create_node([node.in_node(0)])
    return [scale_shift.id]
def extract(cls, node):
    """Build a ScaleShift from a Kaldi batch-norm style component.

    Reads ``<Dim>``, ``<BlockDim>``, ``<Epsilon>``, ``<TargetRms>``,
    ``<StatsMean>`` and ``<StatsVar>`` (in that order) from
    ``node.parameters`` and folds them into

        weights = target_rms / sqrt(var + eps)
        biases  = -target_rms * mean / sqrt(var + eps)

    Both vectors are repeated ``dim // block_dim`` times with ``np.tile``.
    ``out-size`` is set to ``dim``. Returns ``cls.enabled``.
    """
    stream = node.parameters

    def read_after(tag, reader):
        # Skip forward to ``tag`` and decode the value that follows it.
        collect_until_token(stream, tag)
        return reader(stream)

    dim = read_after(b'<Dim>', read_binary_integer32_token)
    block_dim = read_after(b'<BlockDim>', read_binary_integer32_token)
    eps = read_after(b'<Epsilon>', read_binary_float_token)
    target_rms = read_after(b'<TargetRms>', read_binary_float_token)
    mean = read_after(b'<StatsMean>', read_binary_vector)
    var = read_after(b'<StatsVar>', read_binary_vector)

    std = np.sqrt(var + eps)
    repeats = dim // block_dim
    weights = np.tile(target_rms / std, repeats)
    biases = np.tile(-target_rms * mean / std, repeats)

    attrs = {'out-size': dim}
    embed_input(attrs, 1, 'weights', weights)
    embed_input(attrs, 2, 'biases', biases)
    ScaleShiftOp.update_node_stat(node, attrs)
    return cls.enabled
def extract(cls, node):
    """Turn a Kaldi component holding a single weight vector into a ScaleShift.

    The parameter stream is first passed through ``read_learning_info``.
    The binary vector that follows becomes the ScaleShift weights on
    input port 1. Returns ``cls.enabled``.
    """
    params_stream = node.parameters
    read_learning_info(params_stream)
    scale_vector = read_binary_vector(params_stream)

    rule = {}
    embed_input(rule, 1, 'weights', scale_vector)
    ScaleShiftOp.update_node_stat(node, rule)
    return cls.enabled
def extract(cls, node):
    """Turn a Kaldi bias-only component into a ScaleShift.

    After ``read_learning_info``, the next binary vector is used as the
    biases (port 2). The weights (port 1) are all ones of the same shape, so
    the ScaleShift only adds the bias. Returns ``cls.enabled``.
    """
    params_stream = node.parameters
    read_learning_info(params_stream)
    bias_vector = read_binary_vector(params_stream)

    # Unit weights make the ScaleShift a pure bias addition.
    rule = {'bias_term': True}
    embed_input(rule, 1, 'weights', np.ones(bias_vector.shape))
    embed_input(rule, 2, 'biases', bias_vector)
    ScaleShiftOp.update_node_stat(node, rule)
    return cls.enabled
def extract(cls, node):
    """Turn a Kaldi component with a ``<Params>`` vector into a ScaleShift.

    The vector after ``<Params>`` becomes the weights (port 1). The next tag
    and a placeholder are then consumed from the stream. The node is marked
    with layout ``NCHW``. Returns ``cls.enabled``.
    """
    stream = node.parameters
    collect_until_token(stream, b'<Params>')
    scales = read_binary_vector(stream)
    # Consume the following tag and its placeholder from the stream.
    find_next_tag(stream)
    read_placeholder(stream, 1)

    rule = dict(layout='NCHW')
    embed_input(rule, 1, 'weights', scales)
    ScaleShiftOp.update_node_stat(node, rule)
    return cls.enabled
def replace_op(self, graph: Graph, node: Node):
    """Rewrite a three-input node as ScaleShift followed by Add.

    A ScaleShift (axis 0) is fed by (in_node(1), in_node(0)). Its result is
    summed with in_node(2).

    Returns:
        A one-element list with the id of the final Add node.
    """
    first, second, third = (node.in_node(port) for port in range(3))

    scale_shift = ScaleShiftOp(
        graph, {'name': node.id + "/ScaleShift_", 'axis': 0}
    ).create_node(inputs=[second, first])

    add_node = Add(graph, {'name': node.id + "/Add_"}).create_node(
        inputs=[scale_shift, third])
    return [add_node.id]
def extract(node):
    """Turn a Kaldi normalize-style component into a ScaleShift.

    Only ``<Dim>`` is read; ``target_rms`` is fixed to 1. Every channel gets
    the same weight, ``max(1 / (dim * target_rms**2), 2**-66) ** -0.5``.
    The floor prevents a zero base in the power. Returns
    ``__class__.enabled``.
    """
    stream = node.parameters
    collect_until_token(stream, b'<Dim>')
    dim = read_binary_integer32_token(stream)

    target_rms = 1
    d_scaled = dim * target_rms ** 2
    floor = 2. ** (-66)
    in_norm = np.power(
        np.maximum(np.full([dim], 1.0 / d_scaled, dtype=np.float64), floor),
        -0.5)

    attrs = {}
    embed_input(attrs, 1, 'weights', in_norm)
    ScaleShiftOp.update_node_stat(node, attrs)
    return __class__.enabled
def extract(cls, node):
    """Turn a Kaldi fixed-scale component into a ScaleShift.

    Reads ``<Dim>`` and ``<Scale>``. The weights (port 1) are a vector of
    length ``dim`` filled with that scale. Returns ``cls.enabled``.
    """
    stream = node.parameters

    def read_after(tag, reader):
        # Skip forward to ``tag`` and decode the value that follows it.
        collect_until_token(stream, tag)
        return reader(stream)

    dim = read_after(b'<Dim>', read_binary_integer32_token)
    scale = read_after(b'<Scale>', read_binary_float_token)

    # TODO (carried over from the original): add real batch here
    attrs = {}
    embed_input(attrs, 1, 'weights', np.full([dim], scale))
    ScaleShiftOp.update_node_stat(node, attrs)
    return cls.enabled
def convert_add_or_mul_to_scaleshift(graph: Graph):
    """Replace eligible two-input ``Add``/``Mul`` nodes with ``ScaleShift``.

    A node qualifies when all of these hold:
    - it has exactly two input ports;
    - ``get_tensor_in_port`` finds a connected tensor input;
    - ``get_value_in_port`` finds a value input whose value has more than one element;
    - the node is not marked ``can_be_scaleshift=False``.

    ``Mul`` becomes ScaleShift(weights=value, biases=0). ``Add`` becomes
    ScaleShift(weights=1, biases=value). Nothing is done when
    ``generate_experimental_IR_V10`` is set.

    ``graph.strict_mode`` is disabled while rewiring. It is set back to
    ``True`` afterwards, even if a rewrite raises.
    """
    if graph.graph['cmd_params'].generate_experimental_IR_V10:
        return

    graph.strict_mode = False
    try:
        for node in graph.get_op_nodes():
            if node.soft_get('op') not in ['Add', 'Mul'] or len(node.in_ports()) != 2:
                continue
            tensor_port, value_port = get_tensor_in_port(node), get_value_in_port(node)
            if tensor_port is None or tensor_port.disconnected() or value_port is None \
                    or node.soft_get('can_be_scaleshift') is False:
                continue
            original_value = value_port.data.get_value()
            # Scalar values are left to other transformations.
            if original_value.size == 1:
                continue
            _replace_with_scaleshift(graph, node, tensor_port, value_port, original_value)
    finally:
        # Restore strict mode even if a rewrite above fails.
        graph.strict_mode = True


def _replace_with_scaleshift(graph, node, tensor_port, value_port, original_value):
    """Rewire one Add/Mul ``node`` into a new ScaleShift with a fake const input."""
    # Remove size-1 dims from the value array (it should end up 1D).
    value_port.data.set_value(np.squeeze(original_value))

    scsh_op = ScaleShiftOp(graph, dict(name='ScaleShift/{}'.format(node.name))).create_node()

    is_mul = node.op == 'Mul'
    value_shape = value_port.data.get_shape()
    if is_mul:
        # Mul: the value is the scale; add zero biases.
        fake_name, fake_value = 'biases', np.zeros(value_shape, dtype=np.float32)
    else:
        # Add: the value is the bias; add unit weights.
        fake_name, fake_value = 'weights', np.ones(value_shape, dtype=np.float32)
    const_op = Const(graph, dict(name='{}/{}'.format(scsh_op.name, fake_name),
                                 value=fake_value,
                                 shape=np.array(value_shape),
                                 )).create_node()

    # Port 0 is the data tensor, port 1 the weights, port 2 the biases.
    tensor_port.get_connection().set_destination(scsh_op.in_port(0))
    if is_mul:
        value_port.get_connection().set_destination(scsh_op.in_port(1))
        const_op.out_port(0).connect(scsh_op.in_port(2))
    else:
        const_op.out_port(0).connect(scsh_op.in_port(1))
        value_port.get_connection().set_destination(scsh_op.in_port(2))

    node.out_port(0).get_connection().set_source(scsh_op.out_port(0))

    # Set bin attribute to ScaleShift input ports
    scsh_op.in_port(1).bin = 'weights'
    scsh_op.in_port(2).bin = 'biases'
def extract(cls, node):
    """Turn a Kaldi component with a ``<Bias>`` vector into a bias-only ScaleShift.

    The vector after ``<Bias>`` becomes the biases (port 2). The next tag and
    a placeholder are then consumed. The node gets layout ``NCHW``,
    ``bias_term=True`` and ``out-size`` equal to the bias length. Returns
    ``cls.enabled``.
    """
    stream = node.parameters
    collect_until_token(stream, b'<Bias>')
    bias_vector = read_binary_vector(stream)
    # Consume the following tag and its placeholder from the stream.
    find_next_tag(stream)
    read_placeholder(stream, 1)

    rule = dict(layout='NCHW', bias_term=True)
    rule['out-size'] = bias_vector.shape[0]
    embed_input(rule, 2, 'biases', bias_vector)
    ScaleShiftOp.update_node_stat(node, rule)
    return cls.enabled
def extract(node):
    """Convert a Kaldi BatchNorm component into a ScaleShift node.

    The stored statistics (``<StatsMean>`` / ``<StatsVar>``) are folded into

        weights = target_rms / sqrt(var + eps)
        biases  = -target_rms * mean / sqrt(var + eps)

    These are repeated ``dim // block_dim`` times with ``np.tile`` so they
    cover all ``dim`` channels. This is the same tiling the other Kaldi
    BatchNorm extractor in this code base applies. When ``block_dim == dim``
    the repeat count is 1 and the result matches the previous behavior.

    NOTE(review): assumes channel j uses statistic j % block_dim (Kaldi's
    block layout) — confirm against Kaldi's BatchNormComponent docs.

    Raises:
        Error: if BlockDim is not a positive divisor of Dim, or if TestMode
            is anything other than False.

    Returns ``__class__.enabled``.
    """
    pb = node.parameters
    collect_until_token(pb, b'<Dim>')
    dim = read_binary_integer32_token(pb)
    collect_until_token(pb, b'<BlockDim>')
    block_dim = read_binary_integer32_token(pb)
    if block_dim <= 0 or dim % block_dim != 0:
        raise Error(
            "BlockDim ({}) must be a positive divisor of Dim ({}) for BatchNorm".format(block_dim, dim))
    collect_until_token(pb, b'<Epsilon>')
    eps = read_binary_float_token(pb)
    collect_until_token(pb, b'<TargetRms>')
    target_rms = read_binary_float_token(pb)
    collect_until_token(pb, b'<TestMode>')
    test_mode = read_binary_bool_token(pb)
    if test_mode is not False:
        raise Error("Test mode True for BatchNorm is not supported")
    collect_until_token(pb, b'<StatsMean>')
    mean = read_binary_vector(pb)
    collect_until_token(pb, b'<StatsVar>')
    var = read_binary_vector(pb)

    std = np.sqrt(var + eps)
    repeats = dim // block_dim
    scale = np.tile(target_rms / std, repeats)
    shift = np.tile(-target_rms * mean / std, repeats)

    attrs = {}
    embed_input(attrs, 1, 'weights', scale)
    embed_input(attrs, 2, 'biases', shift)
    ScaleShiftOp.update_node_stat(node, attrs)
    return __class__.enabled
def replace_op(self, graph: Graph, node: Node):
    """Expand an LSTM non-linearity node into elementary operations.

    Input 0 is split into five parts along axis 1 (the Split axis comes from
    a Const ``1``): i_part, f_part, c_part, o_part, ct_1. The subgraph built is

        i_t = Sigmoid(i_part + w_ic * ct_1)
        f_t = Sigmoid(f_part + w_fc * ct_1)
        c_t = f_t * ct_1 + i_t * Tanh(c_part)
        o_t = Sigmoid(o_part + w_oc * c_t)
        m_t = o_t * Tanh(c_t)

    The w_* products are bias-free ScaleShift nodes whose weights come from
    node.i_weights / node.f_weights / node.o_weights. c_t and m_t are joined
    by a Concat that has no axis attribute set here.
    NOTE(review): relies on Concat's default axis — confirm.

    Returns:
        A one-element list with the id of the Concat node.
    """
    # split input to (i_part, f_part, c_part, o_part, ct_1)
    split_node_axis = Const(graph, {'value': np.int64(1)}).create_node()
    split_node = Split(graph, {
        'name': 'Split_lstm_input_',
        'num_splits': 5
    }).create_node()
    node.in_port(0).get_connection().set_destination(split_node.in_port(0))
    split_node.in_port(1).connect(split_node_axis.out_port(0))

    # i_t = Sigmoid(i_part + w_ic*ct_1)
    i_scale_attrs = {'name': 'i_scaleshift', 'bias_term': False}
    i_scale = ScaleShiftOp(graph, i_scale_attrs).create_node()
    input_as_const(i_scale, i_scale_attrs, 1, 'weights', node.i_weights)
    split_node.out_port(4).connect(i_scale.in_port(0))

    sum_i_c = Add(graph, {'name': 'sum_i_c_'}).create_node()
    split_node.out_port(0).connect(sum_i_c.in_port(0))
    i_scale.out_port(0).connect(sum_i_c.in_port(1))

    i_sigmoid = Sigmoid(graph, {'name': 'i_sigmoid'}).create_node()
    sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))

    # f_t = Sigmoid(f_part + w_fc*ct_1)
    f_scale_attrs = {'name': 'f_scaleshift', 'bias_term': False}
    f_scale = ScaleShiftOp(graph, f_scale_attrs).create_node()
    input_as_const(f_scale, f_scale_attrs, 1, 'weights', node.f_weights)
    split_node.out_port(4).connect(f_scale.in_port(0))

    sum_f_c = Add(graph, {'name': 'sum_f_c_'}).create_node()
    split_node.out_port(1).connect(sum_f_c.in_port(0))
    f_scale.out_port(0).connect(sum_f_c.in_port(1))

    f_sigmoid = Sigmoid(graph, {'name': 'f_sigmoid'}).create_node()
    sum_f_c.out_port(0).connect(f_sigmoid.in_port(0))

    # c_t = f_t*ct_1 + i_t * tanh(c_part)
    c_tanh = Tanh(graph, {'name': 'c_tanh'}).create_node()
    split_node.out_port(2).connect(c_tanh.in_port(0))

    prod_i_c_tanh = Mul(graph, {'name': 'prod_i_c_tanh_'}).create_node()
    i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
    c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))

    prod_f_ct_1 = Mul(graph, {'name': 'prod_f_ct_1_'}).create_node()
    f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
    split_node.out_port(4).connect(prod_f_ct_1.in_port(1))

    # sum_f_i is c_t; it feeds the o-gate, the output Tanh and the Concat.
    sum_f_i = Add(graph, {'name': 'sum_f_i_'}).create_node()
    prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
    prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))

    # o_t = Sigmoid(o_part + w_oc*c_t)
    o_scale_attrs = {'name': 'o_scaleshift', 'bias_term': False}
    o_scale = ScaleShiftOp(graph, o_scale_attrs).create_node()
    input_as_const(o_scale, o_scale_attrs, 1, 'weights', node.o_weights)
    sum_f_i.out_port(0).connect(o_scale.in_port(0))

    sum_o_c = Add(graph, {'name': 'sum_o_c_'}).create_node()
    split_node.out_port(3).connect(sum_o_c.in_port(0))
    o_scale.out_port(0).connect(sum_o_c.in_port(1))

    o_sigmoid = Sigmoid(graph, {'name': 'o_sigmoid'}).create_node()
    sum_o_c.out_port(0).connect(o_sigmoid.in_port(0))

    # m_t = o_t * Tanh(c_t)
    c_t_tanh = Tanh(graph, {'name': 'c_t_tanh'}).create_node()
    sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))

    prod_o_c_t_tanh = Mul(graph, {
        'name': 'prod_o_c_t_tanh_'
    }).create_node()
    o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
    c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))

    # add concat to create 1 output: [c_t, m_t]
    concat = Concat(graph, {'name': 'Concat_c_m'}).create_node()
    concat.add_sequence_of_ports('in', range(2))
    sum_f_i.out_port(0).connect(concat.in_port(0))
    prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))

    return [concat.id]
def replace_op(self, graph: Graph, node: Node):
    """Unroll an LSTM-with-projection node into a Memory-based subgraph.

    One recurrent step is built:

        gifo = FC(input; gifo_x_weights, gifo_biases) + FC(prev_output; gifo_r_weights)
        p0, p1, p2, p3 = Split(gifo, axis=1, 4 parts)
        c_t = Clamp(Sigmoid(p2 + w_f * c_prev) * c_prev
                    + Tanh(p0) * Sigmoid(p1 + w_i * c_prev), -clip_value, clip_value)
        o_t = Sigmoid(p3 + w_o * c_t)
        m_t = FC(Tanh(c_t) * o_t; projection_weights)

    w_i / w_f / w_o are bias-free ScaleShifts with node.input_gate_weights /
    node.forget_gate_weights / node.output_gate_weights.

    Memory nodes sharing an ``id`` form a pair: index 1 is read here and
    index 0 is written. This is presumably how c_t and m_t are carried to the
    next time step — TODO confirm Memory semantics. Each written Memory also
    feeds a Result node.

    Returns:
        A one-element list with the id of the projection (InnerProduct) node.
    """
    input_node = node.in_node()

    # NOTE: memory_pair_input is the id of the m_t (output) memory pair and
    # memory_pair_output is the id of the c_t (state) memory pair.
    memory_pair_input = unique_id('id')
    memory_pair_output = unique_id('id')

    # Input -> FullyConnected
    fc_layer_after_input_attrs = {
        'name': 'input_fullyconnected',
        'num_output': node.gifo_x_weights_shape[0],
        'bias_term': True
    }

    embed_input(fc_layer_after_input_attrs, 1, 'weights', node.gifo_x_weights)
    embed_input(fc_layer_after_input_attrs, 2, 'biases', node.gifo_biases)
    fc_layer_after_input = InnerProduct(
        graph, fc_layer_after_input_attrs).create_node([input_node])

    prev_lstm_output = Memory(
        graph, {
            'name': 'prev_memory_output',
            'id': memory_pair_input,
            'index': 1,
            'size': 2,
            'shape': np.array([node.gifo_r_weights_shape[1]], dtype=np.int64)
        }).create_node()

    # *Memory(output) -> FullyConnected
    fc_layer_from_prev_state_attrs = {
        'name': 'prev_memory_output_fullyconnected',
        'num_output': node.gifo_r_weights_shape[0],
        'bias_term': False
    }

    embed_input(fc_layer_from_prev_state_attrs, 1, 'weights', node.gifo_r_weights)
    fc_layer_from_prev_state = InnerProduct(
        graph, fc_layer_from_prev_state_attrs).create_node([prev_lstm_output])

    # Memory -> FullyConnected  \
    #                             *Eltwise(sum)
    # Input -> FullyConnected   /
    join_input_prev_state_sum = Add(graph, {
        'name': 'join_input_eltwise',
    }).create_node([fc_layer_from_prev_state, fc_layer_after_input])

    # *Eltwise(sum) -> Split
    # it is split into 4 nodes: Act, Eltw*3
    # the following order is mandatory
    #       ___Tanh
    #      /
    # Split ---(2)Eltwise(sum)
    #     |\
    #     | \__(3)Eltwise(sum)
    #     |____(4)Eltwise(sum)
    split_joined_input = Split(
        graph, {
            'name': 'join_input_split',
            'axis': 1,
            'num_split': 4,
            'out_ports_count': 4,
        }).create_node([join_input_prev_state_sum])

    prev_lstm_state = Memory(
        graph, {
            'name': 'prev_memory_state',
            'id': memory_pair_output,
            'index': 1,
            'size': 2,
            'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
        }).create_node()

    # *Memory(state) -> *ScaleShift(input)
    state_input_scaleshift_attrs = {
        'name': 'input_scaleshift',
        'bias_term': False
    }
    embed_input(state_input_scaleshift_attrs, 1, 'weights', node.input_gate_weights)
    state_input_scaleshift = ScaleShiftOp(
        graph, state_input_scaleshift_attrs).create_node([prev_lstm_state])

    # *Memory(state) -> *ScaleShift(forget)
    state_forget_scaleshift_attrs = {
        'name': 'forget_scaleshift',
        'bias_term': False
    }
    embed_input(state_forget_scaleshift_attrs, 1, 'weights', node.forget_gate_weights)
    state_forget_scaleshift = ScaleShiftOp(
        graph, state_forget_scaleshift_attrs).create_node([prev_lstm_state])

    # Split                                 \
    #                                        (2)Eltwise(sum)
    # Memory(state) -> *ScaleShift(input)   /
    join_prev_lstm_input_joined_input_sum = Add(
        graph, {
            'name': 'join_prev_lstm_input_joined_input_eltwise',
        }).create_node([(split_joined_input, 1), state_input_scaleshift])

    # Split                                 \
    #                                        (3)Eltwise(sum)
    # Memory(state) -> *ScaleShift(forget)  /
    join_prev_lstm_input_joined_forget_sum = Add(
        graph, {
            'name': 'join_prev_lstm_input_joined_forget_sum',
        }).create_node([(split_joined_input, 2), state_forget_scaleshift])

    # Split -> Tanh
    remember_tahn = Tanh(graph, {
        'name': 'remember_tahnv'
    }).create_node([(split_joined_input, 0)])

    # Split -> (2)Eltwise(sum) -> *Sigmoid
    remember_sigmoid = Sigmoid(graph, {
        'name': 'remember_sigmoid'
    }).create_node([join_prev_lstm_input_joined_input_sum])

    # Split -> (3)Eltwise(sum) -> **Sigmoid
    forget_sigmoid = Sigmoid(graph, {
        'name': 'forget_sigmoid'
    }).create_node([join_prev_lstm_input_joined_forget_sum])

    # *Memory(state)                         \
    #                                         (6)Eltwise(mul)
    # Split -> (3)Eltwise(sum) -> **Sigmoid  /
    join_forget_prev_state_mul = Mul(graph, {
        'name': 'join_forget_prev_state_mul',
    }).create_node([forget_sigmoid, prev_lstm_state])

    # Split -> Tahn                          \
    #                                         (5)Eltwise(mul)
    # Split -> (2)Eltwise(sum) -> *Sigmoid   /
    join_remember_candidates_mul = Mul(
        graph, {
            'name': 'join_remember_candidates_mul',
        }).create_node([remember_tahn, remember_sigmoid])

    # (5)Eltwise(mul)  \
    #                   (7)Eltwise(sum)
    # (6)Eltwise(mul)  /
    join_forget_remember_sum = Add(graph, {
        'name': 'join_forget_remember_sum',
    }).create_node(
        [join_forget_prev_state_mul, join_remember_candidates_mul])

    # (7)Eltwise(sum) -> Clamp to [-clip_value, clip_value]; this is c_t.
    join_forget_clamp = Clamp(
        graph, {
            'name': 'join_forget_clamp',
            'max': node.clip_value,
            'min': -node.clip_value
        }).create_node([join_forget_remember_sum])

    # Clamp -> (2)Memory(state)
    next_lstm_state = Memory(
        graph, {
            'name': 'next_lstm_state',
            'id': memory_pair_output,
            'index': 0,
            'size': 2,
            'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
        }).create_node([join_forget_clamp])
    Result(graph, {
        'name': 'next_lstm_state_out'
    }).create_node([next_lstm_state])

    # Clamp -> (2)Tahn
    state_filtered_tahn = Tanh(graph, {
        'name': 'state_filtered_tahn'
    }).create_node([join_forget_clamp])

    # Clamp -> (2)ScaleShift
    clamp_scaleshift_attrs = {
        'name': 'clamp_scaleshift',
        'bias_term': False
    }
    embed_input(clamp_scaleshift_attrs, 1, 'weights', node.output_gate_weights)
    clamp_scaleshift = ScaleShiftOp(
        graph, clamp_scaleshift_attrs).create_node([join_forget_clamp])

    # Split                     \
    #                            (4)Eltwise(sum)
    # Clamp -> (2)ScaleShift    /
    join_next_lstm_input_joined_input_sum = Add(
        graph, {
            'name': 'join_next_lstm_input_joined_input_sum',
        }).create_node([(split_joined_input, 3), clamp_scaleshift])

    # (4)Eltwise(sum) -> (3)Sigmoid
    output_sigmoid = Sigmoid(graph, {
        'name': 'output_sigmoid'
    }).create_node([join_next_lstm_input_joined_input_sum])

    # (4)Eltwise(sum) -> (3)Sigmoid  \
    #                                 (5)Eltwise(mul)
    # Clamp -> (2)Tahn               /
    joined_output_mul = Mul(graph, {
        'name': 'joined_output_mul'
    }).create_node([state_filtered_tahn, output_sigmoid])

    # (5)Eltwise(mul) -> (3)FullyConnected (projection)
    fc_output_attrs = {
        'name': 'FullyConnected',
        'num_output': node.projection_weights_shape[0],
        'bias_term': False
    }
    embed_input(fc_output_attrs, 1, 'weights', node.projection_weights)
    fc_output = InnerProduct(graph, fc_output_attrs).create_node(
        [joined_output_mul])

    #                   / (2)Memory(output)
    # (3)FullyConnected
    #                   \ Output (any next node) (edge created automatically after replacement)
    next_lstm_output = Memory(
        graph, {
            'name': 'next_lstm_output',
            'id': memory_pair_input,
            'index': 0,
            'size': 2,
            'shape': np.array([node.gifo_r_weights_shape[1]], dtype=np.int64)
        }).create_node([fc_output])
    Result(graph, {
        'name': 'next_lstm_output_out'
    }).create_node([next_lstm_output])

    return [fc_output.id]
def replace_op(self, graph: Graph, node: Node):
    """Unroll an LSTM-with-projection node into a ReadValue/Assign subgraph.

    One recurrent step is built:

        gifo = FC(input; gifo_x_weights, gifo_biases) + FC(prev_output; gifo_r_weights)
        p0, p1, p2, p3 = Split(gifo, axis=Const(1), 4 parts)
        c_t = Clamp(Sigmoid(p2 + w_f * c_prev) * c_prev
                    + Tanh(p0) * Sigmoid(p1 + w_i * c_prev), -clip_value, clip_value)
        o_t = Sigmoid(p3 + w_o * c_t)
        m_t = FC(Tanh(c_t) * o_t; projection_weights)

    The FullyConnected nodes are created with ``transpose_weights=True``.
    w_i / w_f / w_o are bias-free ScaleShifts with node.input_gate_weights /
    node.forget_gate_weights / node.output_gate_weights.

    prev_output and c_prev come from ReadValue nodes. Their initial values
    are built by create_zero_value_with_batch_from_input (presumably zeros
    with the input's batch size — TODO confirm). m_t and c_t are written by
    Assign nodes sharing the same variable_id, and each Assign feeds a
    Result node.

    Returns:
        A one-element list with the id of the projection FullyConnected node.
    """
    input_out_port = node.in_port(0).get_source()

    # NOTE: memory_pair_input is the variable id for m_t (output) and
    # memory_pair_output is the variable id for c_t (state).
    memory_pair_input = unique_id('id')
    memory_pair_output = unique_id('id')

    # Input -> FullyConnected
    fc_layer_after_input_attrs = {
        'name': 'input_fullyconnected',
        'out-size': node.gifo_x_weights_shape[0],
        'transpose_weights': True,
        'bias_term': True,
    }

    fc_layer_after_input = FullyConnected(
        graph, fc_layer_after_input_attrs).create_node()
    fc_layer_after_input.in_port(0).connect(input_out_port)
    input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 1,
                   'weights', node.gifo_x_weights)
    input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 2,
                   'biases', node.gifo_biases)

    init_value_prev_lstm_output = create_zero_value_with_batch_from_input(
        input_out_port, node.gifo_r_weights_shape[1])
    prev_lstm_output = ReadValue(graph, {
        'name': 'prev_memory_output',
        'variable_id': memory_pair_input
    }).create_node()
    prev_lstm_output.in_port(0).connect(
        init_value_prev_lstm_output.out_port(0))

    # *Memory(output) -> FullyConnected
    fc_layer_from_prev_state_attrs = {
        'name': 'prev_memory_output_fullyconnected',
        'out-size': node.gifo_r_weights_shape[0],
        'transpose_weights': True,
        'bias_term': False,
    }

    fc_layer_from_prev_state = FullyConnected(
        graph, fc_layer_from_prev_state_attrs).create_node()
    fc_layer_from_prev_state.in_port(0).connect(
        prev_lstm_output.out_port(0))
    input_as_const(fc_layer_from_prev_state, fc_layer_from_prev_state_attrs,
                   1, 'weights', node.gifo_r_weights)

    # Memory -> FullyConnected  \
    #                             *Eltwise(sum)
    # Input -> FullyConnected   /
    join_input_prev_state_sum = Add(graph, {
        'name': 'join_input_eltwise'
    }).create_node()
    join_input_prev_state_sum.in_port(0).connect(
        fc_layer_from_prev_state.out_port(0))
    join_input_prev_state_sum.in_port(1).connect(
        fc_layer_after_input.out_port(0))

    # *Eltwise(sum) -> Split
    # it is split into 4 nodes: Act, Eltw*3
    # the following order is mandatory
    #       ___Tanh
    #      /
    # Split ---(2)Eltwise(sum)
    #     |\
    #     | \__(3)Eltwise(sum)
    #     |____(4)Eltwise(sum)
    split_joined_input_axis = Const(graph, {
        'value': np.int64(1)
    }).create_node()
    split_joined_input = Split(graph, {
        'name': 'join_input_split',
        'num_splits': 4,
        'out_ports_count': 4
    }).create_node()
    split_joined_input.in_port(0).connect(
        join_input_prev_state_sum.out_port(0))
    split_joined_input.in_port(1).connect(
        split_joined_input_axis.out_port(0))

    # Cell state from the previous step. The zero initial value is sized from
    # Split output 0 and node.input_gate_weights.shape[0].
    init_value_prev_lstm_state = create_zero_value_with_batch_from_input(
        split_joined_input.out_port(0), node.input_gate_weights.shape[0])
    prev_lstm_state = ReadValue(graph, {
        'name': 'prev_memory_state',
        'variable_id': memory_pair_output
    }).create_node()
    prev_lstm_state.in_port(0).connect(
        init_value_prev_lstm_state.out_port(0))

    # *Memory(state) -> *ScaleShift(input)
    state_input_scaleshift_attrs = {
        'name': 'input_scaleshift',
        'bias_term': False
    }
    state_input_scaleshift = ScaleShiftOp(
        graph, state_input_scaleshift_attrs).create_node()
    state_input_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
    input_as_const(state_input_scaleshift, state_input_scaleshift_attrs, 1,
                   'weights', node.input_gate_weights)

    # *Memory(state) -> *ScaleShift(forget)
    state_forget_scaleshift_attrs = {
        'name': 'forget_scaleshift',
        'bias_term': False
    }
    state_forget_scaleshift = ScaleShiftOp(
        graph, state_forget_scaleshift_attrs).create_node()
    state_forget_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
    input_as_const(state_forget_scaleshift, state_forget_scaleshift_attrs, 1,
                   'weights', node.forget_gate_weights)

    # Split                                 \
    #                                        (2)Eltwise(sum)
    # Memory(state) -> *ScaleShift(input)   /
    join_prev_lstm_input_joined_input_sum = Add(
        graph, {
            'name': 'join_prev_lstm_input_joined_input_eltwise'
        }).create_node()
    join_prev_lstm_input_joined_input_sum.in_port(0).connect(
        split_joined_input.out_port(1))
    join_prev_lstm_input_joined_input_sum.in_port(1).connect(
        state_input_scaleshift.out_port(0))

    # Split                                 \
    #                                        (3)Eltwise(sum)
    # Memory(state) -> *ScaleShift(forget)  /
    join_prev_lstm_input_joined_forget_sum = Add(
        graph, {
            'name': 'join_prev_lstm_input_joined_forget_sum',
        }).create_node()
    join_prev_lstm_input_joined_forget_sum.in_port(0).connect(
        split_joined_input.out_port(2))
    join_prev_lstm_input_joined_forget_sum.in_port(1).connect(
        state_forget_scaleshift.out_port(0))

    # Split -> Tanh
    remember_tahn = Tanh(graph, {'name': 'remember_tahnv'}).create_node()
    remember_tahn.in_port(0).connect(split_joined_input.out_port(0))

    # Split -> (2)Eltwise(sum) -> *Sigmoid
    remember_sigmoid = Sigmoid(graph, {
        'name': 'remember_sigmoid'
    }).create_node()
    remember_sigmoid.in_port(0).connect(
        join_prev_lstm_input_joined_input_sum.out_port(0))

    # Split -> (3)Eltwise(sum) -> **Sigmoid
    forget_sigmoid = Sigmoid(graph, {
        'name': 'forget_sigmoid'
    }).create_node()
    forget_sigmoid.in_port(0).connect(
        join_prev_lstm_input_joined_forget_sum.out_port(0))

    # *Memory(state)                         \
    #                                         (6)Eltwise(mul)
    # Split -> (3)Eltwise(sum) -> **Sigmoid  /
    join_forget_prev_state_mul = Mul(graph, {
        'name': 'join_forget_prev_state_mul'
    }).create_node()
    join_forget_prev_state_mul.in_port(0).connect(
        forget_sigmoid.out_port(0))
    join_forget_prev_state_mul.in_port(1).connect(
        prev_lstm_state.out_port(0))

    # Split -> Tahn                          \
    #                                         (5)Eltwise(mul)
    # Split -> (2)Eltwise(sum) -> *Sigmoid   /
    join_remember_candidates_mul = Mul(
        graph, {
            'name': 'join_remember_candidates_mul'
        }).create_node()
    join_remember_candidates_mul.in_port(0).connect(
        remember_tahn.out_port(0))
    join_remember_candidates_mul.in_port(1).connect(
        remember_sigmoid.out_port(0))

    # (5)Eltwise(mul)  \
    #                   (7)Eltwise(sum)
    # (6)Eltwise(mul)  /
    join_forget_remember_sum = Add(graph, {
        'name': 'join_forget_remember_sum'
    }).create_node()
    join_forget_remember_sum.in_port(0).connect(
        join_forget_prev_state_mul.out_port(0))
    join_forget_remember_sum.in_port(1).connect(
        join_remember_candidates_mul.out_port(0))

    # (7)Eltwise(sum) -> Clamp. The min/max bounds are float32 const inputs
    # 1 and 2. This is c_t.
    join_forget_clamp = create_op_with_const_inputs(
        graph, Clamp, {
            1: np.array(-node.clip_value, dtype=np.float32),
            2: np.array(node.clip_value, dtype=np.float32)
        }, {'name': 'join_forget_clamp'}, join_forget_remember_sum)

    # Clamp -> (2)Memory(state)
    next_lstm_state = Assign(graph, {
        'name': 'next_lstm_state',
        'variable_id': memory_pair_output
    }).create_node()
    next_lstm_state.in_port(0).connect(join_forget_clamp.out_port(0))

    res_node = Result(graph, {'name': 'next_lstm_state_out'}).create_node()
    res_node.in_port(0).connect(next_lstm_state.out_port(0))

    # Clamp -> (2)Tahn
    state_filtered_tahn = Tanh(graph, {
        'name': 'state_filtered_tahn'
    }).create_node()
    state_filtered_tahn.in_port(0).connect(join_forget_clamp.out_port(0))

    # Clamp -> (2)ScaleShift
    clamp_scaleshift_attrs = {
        'name': 'clamp_scaleshift',
        'bias_term': False
    }
    clamp_scaleshift = ScaleShiftOp(graph, clamp_scaleshift_attrs).create_node()
    clamp_scaleshift.in_port(0).connect(join_forget_clamp.out_port(0))
    input_as_const(clamp_scaleshift, clamp_scaleshift_attrs, 1, 'weights',
                   node.output_gate_weights)

    # Split                     \
    #                            (4)Eltwise(sum)
    # Clamp -> (2)ScaleShift    /
    join_next_lstm_input_joined_input_sum = Add(
        graph, {
            'name': 'join_next_lstm_input_joined_input_sum',
        }).create_node()
    join_next_lstm_input_joined_input_sum.in_port(0).connect(
        split_joined_input.out_port(3))
    join_next_lstm_input_joined_input_sum.in_port(1).connect(
        clamp_scaleshift.out_port(0))

    # (4)Eltwise(sum) -> (3)Sigmoid
    output_sigmoid = Sigmoid(graph, {
        'name': 'output_sigmoid'
    }).create_node()
    output_sigmoid.in_port(0).connect(
        join_next_lstm_input_joined_input_sum.out_port(0))

    # (4)Eltwise(sum) -> (3)Sigmoid  \
    #                                 (5)Eltwise(mul)
    # Clamp -> (2)Tahn               /
    joined_output_mul = Mul(graph, {
        'name': 'joined_output_mul'
    }).create_node()
    joined_output_mul.in_port(0).connect(state_filtered_tahn.out_port(0))
    joined_output_mul.in_port(1).connect(output_sigmoid.out_port(0))

    # (5)Eltwise(mul) -> (3)FullyConnected (projection)
    fc_output_attrs = {
        'name': 'FullyConnected',
        'out-size': node.projection_weights_shape[0],
        'transpose_weights': True,
        'bias_term': False
    }
    fc_output = FullyConnected(graph, fc_output_attrs).create_node()
    fc_output.in_port(0).connect(joined_output_mul.out_port(0))
    input_as_const(fc_output, fc_output_attrs, 1, 'weights',
                   node.projection_weights)

    #                   / (2)Memory(output)
    # (3)FullyConnected
    #                   \ Output (any next node) (edge created automatically after replacement)
    next_lstm_output = Assign(graph, {
        'name': 'next_lstm_output',
        'variable_id': memory_pair_input
    }).create_node()
    next_lstm_output.in_port(0).connect(fc_output.out_port(0))

    res_node_lstm_output = Result(graph, {
        'name': 'next_lstm_output_out'
    }).create_node()
    res_node_lstm_output.in_port(0).connect(next_lstm_output.out_port(0))

    return [fc_output.id]
def replace_op(self, graph: Graph, node: Node):
    """Expand an LSTM non-linearity node using Eltwise-based operations.

    Input 0 goes to a 5-way Split: i_part, f_part, c_part, o_part, ct_1.
    No axis is given here, so the Split relies on its defaults
    (NOTE(review): confirm the default axis). The subgraph built is

        i_t = Sigmoid(i_part + w_ic * ct_1)
        f_t = Sigmoid(f_part + w_fc * ct_1)
        c_t = f_t * ct_1 + i_t * Tanh(c_part)
        o_t = Sigmoid(o_part + w_oc * c_t)
        m_t = o_t * Tanh(c_t)

    Additions and products are Eltwise nodes with operation 'sum' / 'mul'.
    The w_* products are bias-free ScaleShifts with node.i_weights /
    node.f_weights / node.o_weights. c_t and m_t are joined by a Concat.

    NOTE(review): most node names go through graph.unique_id, but the
    Sigmoid/Tanh names ('i_sigmoid', 'c_tanh', ...) are fixed strings.

    Returns:
        A one-element list with the id of the Concat node.
    """
    # split input to (i_part, f_part, c_part, o_part, ct_1)
    split_node = Split(graph, {
        'name': graph.unique_id(prefix='Split_lstm_input_'),
        'num_split': 5
    }).create_node()
    node.in_port(0).get_connection().set_destination(split_node.in_port(0))
    # Output ports are added explicitly, one per split part.
    for i in range(5):
        split_node.add_output_port(i)

    # i_t = Sigmoid(i_part + w_ic*ct_1)
    i_scale_attrs = {
        'name': graph.unique_id(prefix='i_scaleshift'),
        'bias_term': False
    }
    embed_input(i_scale_attrs, 1, 'weights', node.i_weights)
    i_scale = ScaleShiftOp(graph, i_scale_attrs).create_node()
    split_node.out_port(4).connect(i_scale.in_port(0))

    sum_i_c = Eltwise(graph, {
        'name': graph.unique_id(prefix='sum_i_c_'),
        'operation': 'sum'
    }).create_node()
    split_node.out_port(0).connect(sum_i_c.in_port(0))
    i_scale.out_port(0).connect(sum_i_c.in_port(1))

    i_sigmoid = Sigmoid(graph, {'name': 'i_sigmoid'}).create_node()
    sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))

    # f_t = Sigmoid(f_part + w_fc*ct_1)
    f_scale_attrs = {
        'name': graph.unique_id(prefix='f_scaleshift'),
        'bias_term': False
    }
    embed_input(f_scale_attrs, 1, 'weights', node.f_weights)
    f_scale = ScaleShiftOp(graph, f_scale_attrs).create_node()
    split_node.out_port(4).connect(f_scale.in_port(0))

    sum_f_c = Eltwise(graph, {
        'name': graph.unique_id(prefix='sum_f_c_'),
        'operation': 'sum'
    }).create_node()
    split_node.out_port(1).connect(sum_f_c.in_port(0))
    f_scale.out_port(0).connect(sum_f_c.in_port(1))

    f_sigmoid = Sigmoid(graph, {'name': 'f_sigmoid'}).create_node()
    sum_f_c.out_port(0).connect(f_sigmoid.in_port(0))

    # c_t = f_t*ct_1 + i_t * tanh(c_part)
    c_tanh = Tanh(graph, {'name': 'c_tanh'}).create_node()
    split_node.out_port(2).connect(c_tanh.in_port(0))

    prod_i_c_tanh = Eltwise(
        graph, {
            'name': graph.unique_id(prefix='prod_i_c_tanh_'),
            'operation': 'mul'
        }).create_node()
    i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
    c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))

    prod_f_ct_1 = Eltwise(graph, {
        'name': graph.unique_id(prefix='prod_f_ct_1_'),
        'operation': 'mul'
    }).create_node()
    f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
    split_node.out_port(4).connect(prod_f_ct_1.in_port(1))

    # sum_f_i is c_t; it feeds the o-gate, the output Tanh and the Concat.
    sum_f_i = Eltwise(graph, {
        'name': graph.unique_id(prefix='sum_f_i_'),
        'operation': 'sum'
    }).create_node()
    prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
    prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))

    # o_t = Sigmoid(o_part + w_oc*c_t)
    o_scale_attrs = {
        'name': graph.unique_id(prefix='o_scaleshift'),
        'bias_term': False
    }
    embed_input(o_scale_attrs, 1, 'weights', node.o_weights)
    o_scale = ScaleShiftOp(graph, o_scale_attrs).create_node()
    sum_f_i.out_port(0).connect(o_scale.in_port(0))

    sum_o_c = Eltwise(graph, {
        'name': graph.unique_id(prefix='sum_o_c_'),
        'operation': 'sum'
    }).create_node()
    split_node.out_port(3).connect(sum_o_c.in_port(0))
    o_scale.out_port(0).connect(sum_o_c.in_port(1))

    o_sigmoid = Sigmoid(graph, {'name': 'o_sigmoid'}).create_node()
    sum_o_c.out_port(0).connect(o_sigmoid.in_port(0))

    # m_t = o_t * Tanh(c_t)
    c_t_tanh = Tanh(graph, {'name': 'c_t_tanh'}).create_node()
    sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))

    prod_o_c_t_tanh = Eltwise(
        graph, {
            'name': graph.unique_id(prefix='prod_o_c_t_tanh_'),
            'operation': 'mul'
        }).create_node()
    o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
    c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))

    # add concat to create 1 output: [c_t, m_t]
    concat = Concat(graph, {
        'name': graph.unique_id(prefix='Concat_c_m')
    }).create_node()
    concat.add_sequence_of_ports('in', range(2))
    sum_f_i.out_port(0).connect(concat.in_port(0))
    prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))

    return [concat.id]
def replace_op(self, graph: Graph, node: Node):
    """Expand an LSTM non-linearity node into elementary operations.

    Input 0 is split along axis 1 into (i_part, f_part, c_part, o_part, ct_1)
    and the following subgraph is built:

        i_t = Sigmoid(i_part + w_ic * ct_1)
        f_t = Sigmoid(f_part + w_fc * ct_1)
        c_t = f_t * ct_1 + i_t * Tanh(c_part)
        o_t = Sigmoid(o_part + w_oc * c_t)
        m_t = o_t * Tanh(c_t)

    w_ic / w_fc / w_oc are bias-free ScaleShifts with node.i_weights /
    node.f_weights / node.o_weights.

    When the ``use_dropout`` attribute is set, a VariadicSplit first peels
    the last three size-1 slices off axis 1. These are used as per-gate
    scales that multiply i_t, f_t and o_t. A missing ``use_dropout``
    attribute is treated as "no dropout".

    c_t and m_t are joined by a Concat whose axis is not set here
    (NOTE(review): relies on Concat's default axis).

    Returns:
        A one-element list with the id of the Concat node.
    """
    node_name = node.soft_get('name', node.id)
    # Evaluate the flag once. has_and_set() is False when the attribute is
    # missing, whereas node['use_dropout'] would raise KeyError.
    use_dropout = node.has_and_set('use_dropout')

    # check if we have dropout
    input_port = node.in_port(0)
    if use_dropout:
        split_dropout = AttributedVariadicSplit(
            graph, {
                'name': node_name + '/split_dropout',
                'size_splits': int64_array([-1, 1, 1, 1]),
                'axis': int64_array(1)
            }).create_node()
        input_port.get_connection().set_destination(split_dropout.in_port(0))
        input_port = split_dropout.out_port(0)
        i_drop_scale = split_dropout.out_port(1)
        f_drop_scale = split_dropout.out_port(2)
        o_drop_scale = split_dropout.out_port(3)

    # split input to (i_part, f_part, c_part, o_part, ct_1)
    split_node = create_op_with_const_inputs(
        graph, Split, {1: np.int64(1)}, {
            'name': node_name + '/split_lstm_input',
            'num_splits': 5
        })
    input_port.get_connection().set_destination(split_node.in_port(0))

    i_part = split_node.out_port(0)
    f_part = split_node.out_port(1)
    c_part = split_node.out_port(2)
    o_part = split_node.out_port(3)
    ct_1 = split_node.out_port(4)

    def apply_dropout(gate, drop_scale, suffix):
        # Multiply the gate activation by its dropout scale; return the Mul node.
        mul_dropout = Mul(graph, {
            'name': split_node.soft_get('name', split_node.id) + suffix
        }).create_node()
        mul_dropout.in_port(0).connect(gate.out_port(0))
        mul_dropout.in_port(1).connect(drop_scale)
        return mul_dropout

    # i_t = Sigmoid(i_part + w_ic*ct_1)
    i_scale_attrs = {
        'name': node_name + '/i_scaleshift',
        'bias_term': False
    }
    i_scale = ScaleShiftOp(graph, i_scale_attrs).create_node()
    input_as_const(i_scale, i_scale_attrs, 1, 'weights', node.i_weights)
    ct_1.connect(i_scale.in_port(0))

    sum_i_c = Add(graph, {'name': node_name + '/sum_i_c_'}).create_node()
    i_part.connect(sum_i_c.in_port(0))
    i_scale.out_port(0).connect(sum_i_c.in_port(1))

    i_sigmoid = Sigmoid(graph, {
        'name': node_name + '/i_sigmoid'
    }).create_node()
    sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))

    if use_dropout:
        i_sigmoid = apply_dropout(i_sigmoid, i_drop_scale, '/mul_i')

    # f_t = Sigmoid(f_part + w_fc*ct_1)
    f_scale_attrs = {
        'name': node_name + '/f_scaleshift',
        'bias_term': False
    }
    f_scale = ScaleShiftOp(graph, f_scale_attrs).create_node()
    input_as_const(f_scale, f_scale_attrs, 1, 'weights', node.f_weights)
    ct_1.connect(f_scale.in_port(0))

    sum_f_c = Add(graph, {'name': node_name + '/sum_f_c_'}).create_node()
    f_part.connect(sum_f_c.in_port(0))
    f_scale.out_port(0).connect(sum_f_c.in_port(1))

    f_sigmoid = Sigmoid(graph, {
        'name': node_name + '/f_sigmoid'
    }).create_node()
    sum_f_c.out_port(0).connect(f_sigmoid.in_port(0))

    if use_dropout:
        f_sigmoid = apply_dropout(f_sigmoid, f_drop_scale, '/mul_f')

    # c_t = f_t*ct_1 + i_t * tanh(c_part)
    c_tanh = Tanh(graph, {'name': node_name + '/c_tanh'}).create_node()
    c_part.connect(c_tanh.in_port(0))

    prod_i_c_tanh = Mul(graph, {
        'name': node_name + '/prod_i_c_tanh_'
    }).create_node()
    i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
    c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))

    prod_f_ct_1 = Mul(graph, {
        'name': node_name + '/prod_f_ct_1_'
    }).create_node()
    f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
    ct_1.connect(prod_f_ct_1.in_port(1))

    # sum_f_i is c_t; it feeds the o-gate, the output Tanh and the Concat.
    sum_f_i = Add(graph, {'name': node_name + '/sum_f_i_'}).create_node()
    prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
    prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))

    # o_t = Sigmoid(o_part + w_oc*c_t)
    o_scale_attrs = {
        'name': node_name + '/o_scaleshift',
        'bias_term': False
    }
    o_scale = ScaleShiftOp(graph, o_scale_attrs).create_node()
    input_as_const(o_scale, o_scale_attrs, 1, 'weights', node.o_weights)
    sum_f_i.out_port(0).connect(o_scale.in_port(0))

    sum_o_c = Add(graph, {'name': node_name + '/sum_o_c_'}).create_node()
    o_part.connect(sum_o_c.in_port(0))
    o_scale.out_port(0).connect(sum_o_c.in_port(1))

    o_sigmoid = Sigmoid(graph, {
        'name': node_name + '/o_sigmoid'
    }).create_node()
    sum_o_c.out_port(0).connect(o_sigmoid.in_port(0))

    if use_dropout:
        o_sigmoid = apply_dropout(o_sigmoid, o_drop_scale, '/mul_o')

    # m_t = o_t * Tanh(c_t)
    c_t_tanh = Tanh(graph, {'name': node_name + '/c_t_tanh'}).create_node()
    sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))

    prod_o_c_t_tanh = Mul(graph, {
        'name': node_name + '/prod_o_c_t_tanh_'
    }).create_node()
    o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
    c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))

    # add concat to create 1 output: [c_t, m_t]
    concat = Concat(graph, {
        'name': node_name + '/concat_c_m'
    }).create_node()
    concat.add_sequence_of_ports('in', range(2))
    sum_f_i.out_port(0).connect(concat.in_port(0))
    prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))

    return [concat.id]