def replace_op(self, graph: Graph, node: Node):
    """Express ``node`` as ``input0 * reciprocal(input1)``.

    The producer of input port 1 is moved onto a new Power node configured
    with ``power=-1``, ``scale=1`` and ``shift=0`` (named ``.../reciprocal_``).
    The producer of input port 0 is moved onto port 1 of a new
    ``Eltwise(operation='mul')`` node. The Power output feeds the other
    multiplication input.

    NOTE(review): a reciprocal-and-multiply does not reproduce integer division
    semantics; presumably only floating-point inputs reach this replacer.
    Confirm.

    :param graph: graph being transformed.
    :param node: node with two inputs (port 0 and port 1) to be replaced.
    :return: list with the id of the multiplication node; a bare id stands for
        ``(id, 0)``, i.e. output port 0 of that node.
    """
    base_name = node.name
    reciprocal_attrs = {
        'scale': 1,
        'power': np.float64(-1),
        'shift': 0,
        'name': base_name + '/reciprocal_',
    }
    inverse = Power(graph, reciprocal_attrs).create_node()
    product = Eltwise(graph, {'operation': 'mul', 'name': base_name + '/mul_'}).create_node()

    # Rewire: divisor -> reciprocal, dividend -> mul, reciprocal -> mul.
    divisor_conn = node.in_port(1).get_connection()
    divisor_conn.set_destination(inverse.in_port(0))
    dividend_conn = node.in_port(0).get_connection()
    dividend_conn.set_destination(product.in_port(1))
    inverse.out_port(0).connect(product.in_port(0))

    return [product.id]
def replace_op(self, graph: Graph, node: Node):
    """Unpack a Kaldi eltwise node whose operands arrive on a single input.

    The producer feeding input 0 of ``node`` is connected to a new Split node
    with ``num_split = node['num_inputs']``. Each Split output ``i`` is wired to
    input ``i`` of a new eltwise node that performs ``node['operation']``. The
    eltwise node is an ``Eltwise`` for exactly two operands and an ``EltwiseN``
    for more.

    :param graph: graph being transformed.
    :param node: Kaldi eltwise node carrying ``num_inputs`` and ``operation``
        attributes.
    :return: list with the id of the new eltwise node.
    :raises Error: if ``num_inputs`` is smaller than 2.
    """
    num_inputs = node['num_inputs']

    # Choose the eltwise flavour up front so an unsupported operand count is
    # rejected before any node or edge is added to the graph.
    if num_inputs == 2:
        eltwise_op = Eltwise
    elif num_inputs > 2:
        eltwise_op = EltwiseN
    else:
        raise Error('Error on replacing Kaldi eltwise: node "{}" has {} input(s), '
                    'at least 2 are required'.format(node.name, num_inputs))

    first_input = node.get_inputs()[0]
    in_node, edge_attrs = first_input[0], first_input[1]

    split_node = Split(graph, attrs={
        'name': 'Split_eltwise_' + node.name,
        'num_split': num_inputs
    }).create_node()
    # Connect the producer of node's first input to the Split, copying the
    # attributes of that original edge.
    graph.add_edge(in_node, split_node.id, **edge_attrs)

    eltwise_node = eltwise_op(graph, attrs={
        'name': 'Eltwise_' + node.name,
        'operation': node['operation']
    }).create_node()

    # Split output i becomes operand i of the eltwise node.
    for port_idx in range(num_inputs):
        split_node.add_output_port(port_idx)
        split_node.out_port(port_idx).get_connection().set_destination(
            eltwise_node.in_port(port_idx))
    return [eltwise_node.id]
def replace_op(self, graph: Graph, node: Node):
    """Decompose a Kaldi LSTM non-linearity node into primitive operations.

    Input 0 of ``node`` is split into five parts (Split with ``num_splits=5``
    and a constant axis input of 1):
    ``(i_part, f_part, c_part, o_part, ct_1)``. The block then builds::

        i_t = Sigmoid(i_part + w_ic * ct_1)
        f_t = Sigmoid(f_part + w_fc * ct_1)
        c_t = f_t * ct_1 + i_t * Tanh(c_part)
        o_t = Sigmoid(o_part + w_oc * c_t)
        m_t = o_t * Tanh(c_t)

    It then concatenates ``(c_t, m_t)`` into a single output. Each ``w_*``
    multiplication is a bias-free ScaleShift whose weights come from
    ``node.i_weights``, ``node.f_weights`` and ``node.o_weights``.

    :param graph: graph being transformed.
    :param node: LSTM non-linearity node with ``i_weights``, ``f_weights`` and
        ``o_weights`` attributes.
    :return: list with the id of the Concat node producing ``(c_t, m_t)``.
    """

    def scale(weights, prefix, source_port):
        # Bias-free ScaleShift: multiplies source_port's tensor by constant weights.
        attrs = {'name': graph.unique_id(prefix=prefix), 'bias_term': False}
        scale_node = ScaleShiftOp(graph, attrs).create_node()
        input_as_const(scale_node, attrs, 1, 'weights', weights)
        source_port.connect(scale_node.in_port(0))
        return scale_node

    def eltwise(operation, prefix, first_port, second_port):
        # Binary Eltwise with first_port -> in 0 and second_port -> in 1.
        result = Eltwise(graph, {
            'name': graph.unique_id(prefix=prefix),
            'operation': operation
        }).create_node()
        first_port.connect(result.in_port(0))
        second_port.connect(result.in_port(1))
        return result

    def unary(op_class, prefix, source_port):
        # Single-input activation. The name goes through graph.unique_id like
        # every other node created by this replacer.
        result = op_class(graph, {'name': graph.unique_id(prefix=prefix)}).create_node()
        source_port.connect(result.in_port(0))
        return result

    # split input to (i_part, f_part, c_part, o_part, ct_1)
    split_axis = Const(graph, {'value': np.int64(1)}).create_node()
    split_node = Split(graph, {
        'name': graph.unique_id(prefix='Split_lstm_input_'),
        'num_splits': 5
    }).create_node()
    node.in_port(0).get_connection().set_destination(split_node.in_port(0))
    split_node.in_port(1).connect(split_axis.out_port(0))
    i_part, f_part, c_part, o_part, ct_1 = [split_node.out_port(idx) for idx in range(5)]

    # i_t = Sigmoid(i_part + w_ic*ct_1)
    i_scale = scale(node.i_weights, 'i_scaleshift', ct_1)
    sum_i_c = eltwise('sum', 'sum_i_c_', i_part, i_scale.out_port(0))
    i_sigmoid = unary(Sigmoid, 'i_sigmoid', sum_i_c.out_port(0))

    # f_t = Sigmoid(f_part + w_fc*ct_1)
    f_scale = scale(node.f_weights, 'f_scaleshift', ct_1)
    sum_f_c = eltwise('sum', 'sum_f_c_', f_part, f_scale.out_port(0))
    f_sigmoid = unary(Sigmoid, 'f_sigmoid', sum_f_c.out_port(0))

    # c_t = f_t*ct_1 + i_t * tanh(c_part)
    c_tanh = unary(Tanh, 'c_tanh', c_part)
    prod_i_c_tanh = eltwise('mul', 'prod_i_c_tanh_', i_sigmoid.out_port(0), c_tanh.out_port(0))
    prod_f_ct_1 = eltwise('mul', 'prod_f_ct_1_', f_sigmoid.out_port(0), ct_1)
    c_t = eltwise('sum', 'sum_f_i_', prod_f_ct_1.out_port(0), prod_i_c_tanh.out_port(0))

    # o_t = Sigmoid(o_part + w_oc*c_t)
    o_scale = scale(node.o_weights, 'o_scaleshift', c_t.out_port(0))
    sum_o_c = eltwise('sum', 'sum_o_c_', o_part, o_scale.out_port(0))
    o_sigmoid = unary(Sigmoid, 'o_sigmoid', sum_o_c.out_port(0))

    # m_t = o_t * Tanh(c_t)
    c_t_tanh = unary(Tanh, 'c_t_tanh', c_t.out_port(0))
    m_t = eltwise('mul', 'prod_o_c_t_tanh_', o_sigmoid.out_port(0), c_t_tanh.out_port(0))

    # Concat (c_t, m_t) so the replacement has a single output.
    concat = Concat(graph, {
        'name': graph.unique_id(prefix='Concat_c_m')
    }).create_node()
    concat.add_sequence_of_ports('in', range(2))
    c_t.out_port(0).connect(concat.in_port(0))
    m_t.out_port(0).connect(concat.in_port(1))
    return [concat.id]