def extract(cls, node):
    activation_alpha = onnx_attr(node, 'activation_alpha', 'floats', default=None,
                                 dst_type=lambda x: np.array(x, dtype=np.float32))
    activation_beta = onnx_attr(node, 'activation_beta', 'floats', default=None,
                                dst_type=lambda x: np.array(x, dtype=np.float32))
    activations = onnx_attr(node, 'activations', 'strings', default=None,
                            dst_type=lambda x: list(map(lambda s: s.decode(encoding="utf-8").lower(), list(x))))
    clip = onnx_attr(node, 'clip', 'f', default=None)
    input_forget = onnx_attr(node, 'input_forget', 'i', default=0)

    attrs = {
        'batch_dim': 1,
        'sequence_dim': 0,
        'blobs_wrb': True,
        'has_num_directions': True,
        'num_layers': 1,
        'format': 'onnx',
        'multilayers': False,
        'gate_order': [2, 0, 3, 1],  # iofc --> fico

        # ONNX attrs
        'activation_alpha': activation_alpha,
        'activation_beta': activation_beta,
        'activations': activations,
        'clip': clip,
        'direction': onnx_attr(node, 'direction', 's', b'forward').decode().lower(),
        'hidden_size': np.array(onnx_attr(node, 'hidden_size', 'i'), dtype=np.int64),
        'input_forget': input_forget,
    }

    LSTM.update_node_stat(node, attrs)
    return cls.enabled
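# A minimal standalone sketch (not part of the extractor) of what the 'gate_order' permutation above
# encodes: ONNX packs LSTM gate blobs in iofc order, while the internal representation expects fico.
# The array below is hypothetical and assumes the gates are stacked along axis 0 of the blob.
import numpy as np

hidden_size = 2
# hypothetical ONNX-style gate blob: 4 gates (i, o, f, c) stacked along axis 0
w_iofc = np.arange(4 * hidden_size * 3, dtype=np.float32).reshape(4, hidden_size, 3)

gate_order = [2, 0, 3, 1]  # iofc --> fico
w_fico = np.take(w_iofc, gate_order, axis=0)
assert np.array_equal(w_fico[0], w_iofc[2])  # the f-gate block now comes first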
def extract(cls, node):
    attrs = get_mxnet_layer_attrs(node.symbol_dict)

    mode = attrs.str('mode', None)
    state_size = attrs.int('state_size', None)
    bidirectional = attrs.bool('bidirectional', False)
    num_layers = attrs.int('num_layers', 1)
    layout = attrs.str('layout', 'TNC')  # by default MXNet RNN takes data in [seq_len, batch_size, inp_size] format

    node_attrs = {
        'batch_dim': layout.index('N'),
        'sequence_dim': layout.index('T'),
        'blobs_wrb': False,
        'hidden_size': state_size,
        'has_num_directions': bidirectional,
        'direction': 'bidirectional' if bidirectional else 'forward',
        'num_layers': num_layers,
        'format': 'mxnet',
        'multilayers': num_layers != 1,
        'gate_order': None,
    }

    if mode == 'rnn_tanh':
        node_attrs['gate_order'] = [0]
        node_attrs['activations'] = ['tanh'] if not bidirectional else ['tanh', 'tanh']
        RNN.update_node_stat(node, node_attrs)
    elif mode == 'rnn_relu':
        node_attrs['gate_order'] = [0]
        node_attrs['activations'] = ['relu'] if not bidirectional else ['relu', 'relu']
        RNN.update_node_stat(node, node_attrs)
    elif mode == 'gru':
        node_attrs['gate_order'] = [1, 0, 2]
        node_attrs['linear_before_reset'] = 1
        GRU.update_node_stat(node, node_attrs)
    elif mode == 'lstm':
        node_attrs['gate_order'] = [1, 0, 2, 3]
        LSTM.update_node_stat(node, node_attrs)
    else:
        raise Error("Operation RNN with mode '{}' is not supported." + refer_to_faq_msg(86), mode)

    return cls.enabled
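# A small illustration (plain Python only, no MXNet calls) of how the 'layout' string above is turned
# into axis indices: for the default 'TNC' layout the sequence axis is 0 and the batch axis is 1,
# while a hypothetical 'NTC' layout would swap them.
for layout in ('TNC', 'NTC'):
    print(layout, 'sequence_dim =', layout.index('T'), 'batch_dim =', layout.index('N'))
# TNC sequence_dim = 0 batch_dim = 1
# NTC sequence_dim = 1 batch_dim = 0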
def replace_pattern(graph: Graph, match: dict):
    time_len = match['concatenated_hidden_states'].shape[0]
    """
    We work with the concatenated_cell_states_data part first, because the IE TensorIterator primitive
    does not have a concatenated cell states output, and if we cannot collapse it, this type of BlockLSTM
    is not supported.

    We simplify the sub-graph below by taking another output of BlockLSTM:
    concatenated cell states over the whole time sequence -> last cell state

        BlockLSTM
            || out 1 (concatenated cell states coming out of BlockLSTM)
            \/ in 1
        ConcatV2
            || (concatenation with initial state or another unused data)
            \/
        Reshape
            ||
            \/
        Gather (taking the last cell state from the previous BlockLSTM, if the Gather index == time_len)
    """
    # check that there are no other consumers of the concatenated_cell_states_data data flow
    valid_output_names = ['concat_1', 'concat_1_data', 'reshape_1', 'reshape_1_data', 'gather_1', 'gather_1_data']
    valid_output_node_ids = [match[name].id for name in valid_output_names]
    node_names_to_check_outputs = ['concatenated_cell_states_data', 'concat_1_data', 'reshape_1_data']
    for name in node_names_to_check_outputs:
        for node in match[name].out_nodes():
            if node.id not in valid_output_node_ids:
                raise Error("BlockLSTM node {} has output which contains concatenated cell states over the whole "
                            "time sequence. It is not replaceable by another output and is not supported "
                            "originally".format(match['BlockLSTM'].id))

    # check that we really take the last cell state data by Gather
    gather_indexes = match['gather_1'].in_node(1).value
    if len(gather_indexes) == 1:
        gather_index = gather_indexes[0]
    else:
        raise Error("BlockLSTM node {} has output which contains concatenated cell states over the whole "
                    "time sequence. It is not replaceable by another output and is not supported "
                    "originally".format(match['BlockLSTM'].id))
    if gather_index != time_len:
        raise Error("BlockLSTM node {} has output which contains concatenated cell states over the whole "
                    "time sequence. It is not replaceable by another output and is not supported "
                    "originally".format(match['BlockLSTM'].id))

    """
    We passed stages #1 and #2 from the class description.
    It means that we can translate the rest of the pattern to LSTMSequence even without the following optimizations.
    """

    node = match['BlockLSTM']
    weights_node = node.in_node(1)
    biases_node = node.in_node(2)
    shift_const = node.forget_bias

    # Assign temporary shapes to them for easier manipulation
    # TF stores weights in IO order
    input_size = node.in_node(0).shape[-1]
    hidden_size = node.in_node(3).shape[-1]
    weights = weights_node.value
    biases = biases_node.value
    assert weights.shape[0] == input_size + hidden_size, \
        "weights.shape={} input_size={} hidden_size={}".format(weights.shape, input_size, hidden_size)
    assert weights.shape[1] == biases.shape[0] == 4 * hidden_size, \
        "weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)

    weights = weights.reshape([
        weights.shape[0],
        4,  # gates
        hidden_size
    ])

    biases = biases.reshape([
        4,  # gates
        hidden_size
    ])

    # Reorder gates icfo --> fico for both weights and biases
    gate_reorder = [2, 0, 1, 3]
    weights = np.take(weights, gate_reorder, axis=1)
    biases = np.take(biases, gate_reorder, axis=0)

    # shift_const.value should be added to the first 1/4th part of the biases (f-gate: 0)
    # Note: if this code is moved up before the gate reordering, the addition
    # should be applied at a different place
    biases[0] += shift_const

    # Return to the original shapes
    weights = weights.reshape([weights.shape[0], -1])
    biases = biases.flatten()

    # TF stores weights in IO, but IE requires them in OI: transpose
    weights = weights.transpose()

    weights_node.value = weights
    weights_node.shape = np.array(weights.shape, dtype=np.int64)

    biases_node.value = biases
    biases_node.shape = np.array(biases.shape, dtype=np.int64)

    attrs = dict(graph.get_edge_data(match['gather_1'].id, match['gather_1_data'].id)[0])
    attrs.update({'out': 2})
    graph.remove_edge(match['BlockLSTM'].id, match['concatenated_cell_states_data'].id)
    graph.remove_edge(match['gather_1'].id, match['gather_1_data'].id)
    graph.add_edge(match['BlockLSTM'].id, match['gather_1_data'].id, **attrs)

    """
    #3 Renumber the h_init_state and c_init_state input ports to match the RNNSequence port order.
""" h_init_port = 4 c_init_port = 5 # c_init_state if 4 in node.in_nodes(): assert c_init_port not in node.in_nodes() cell_state_edge = graph.get_edge_data(node.in_node(4).id, node.id) cell_state_edge[0]['in'] = c_init_port #h_init_state if 3 in node.in_nodes(): assert h_init_port not in node.in_nodes() hidden_state_edge = graph.get_edge_data( node.in_node(3).id, node.id) hidden_state_edge[0]['in'] = h_init_port new_attrs = { 'sequence_dim': 0, 'batch_dim': 1, 'direction': 'forward', 'hidden_size': match['concatenated_hidden_states'].shape[-1], 'format': 'tf', } LSTM.update_node_stat(match['BlockLSTM'], new_attrs) """ Optional #4 optimization from class description following """ data_to_mul = [ n for n in match['mul'].in_nodes().values() if n.id != match['concatenated_hidden_states'].id ] if len(data_to_mul) != 1: return # unexpected type of mul data_to_mul = data_to_mul[0] if not data_to_mul.has_valid('value'): return # unexpected type of mul data_to_mul_value = data_to_mul.value if not np.all(data_to_mul_value == 1): return # unexpected type of mul # remove useless mul attrs = dict( graph.get_edge_data(match['BlockLSTM'].id, match['concatenated_hidden_states'].id)[0]) graph.remove_edge(match['BlockLSTM'].id, match['concatenated_hidden_states'].id) graph.remove_edge(match['mul'].id, match['mul_data'].id) graph.add_edge(match['BlockLSTM'].id, match['mul_data'].id, **attrs) # find true usages of concatenated hidden states data (not last hidden state) valid_output_names = [ 'mul_data', 'concat_0', 'concat_0_data', 'reshape_0', 'reshape_0_data', 'gather_0', 'gather_0_data' ] valid_output_node_ids = [match[name].id for name in valid_output_names] node_names_to_check_outputs = [ 'mul_data', 'concat_0_data', 'reshape_0_data' ] list_of_concatenated_hidden_states_children_node_ids = [] for name in node_names_to_check_outputs: for node in match[name].out_nodes(): if node.id not in valid_output_node_ids: list_of_concatenated_hidden_states_children_node_ids.append( node.id) if len(list_of_concatenated_hidden_states_children_node_ids) != 1: return # not supported placement of patten conacenated_child_node_id = list_of_concatenated_hidden_states_children_node_ids[ 0] if conacenated_child_node_id != match[ 'after_mul_op_to_the_rest_of_model'].id: return # not supported placement of patten gather_indexes = match['gather_0'].in_node(1).value if len(gather_indexes) == 1: gather_index = gather_indexes[0] else: return # we have to translate this type of BlockLSTM to LSTMSequence to TensorIterator as is if gather_index != time_len: return # we have to translate this type of BlockLSTM to LSTMSequence to TensorIterator as is attrs = dict( graph.get_edge_data(match['gather_0'].id, match['gather_0_data'].id)[0]) attrs.update({'out': 1}) graph.remove_edge(match['mul_data'].id, match['concat_0'].id) graph.remove_edge(match['gather_0'].id, match['gather_0_data'].id) graph.add_edge(match['BlockLSTM'].id, match['gather_0_data'].id, **attrs)