Example 1
    def extract(cls, node):
        activation_alpha = onnx_attr(
            node,
            'activation_alpha',
            'floats',
            default=None,
            dst_type=lambda x: np.array(x, dtype=np.float32))
        activation_beta = onnx_attr(
            node,
            'activation_beta',
            'floats',
            default=None,
            dst_type=lambda x: np.array(x, dtype=np.float32))
        activations = onnx_attr(
            node,
            'activations',
            'strings',
            default=None,
            dst_type=lambda x: [s.decode('utf-8').lower() for s in x])
        clip = onnx_attr(node, 'clip', 'f', default=None)
        input_forget = onnx_attr(node, 'input_forget', 'i', default=0)

        attrs = {
            'batch_dim': 1,
            'sequence_dim': 0,
            'blobs_wrb': True,
            'has_num_directions': True,
            'num_layers': 1,
            'format': 'onnx',
            'multilayers': False,
            'gate_order': [2, 0, 3, 1],  # iofc --> fico (see the sketch after this example)

            # ONNX attrs
            'activation_alpha': activation_alpha,
            'activation_beta': activation_beta,
            'activations': activations,
            'clip': clip,
            'direction': onnx_attr(node, 'direction', 's', b'forward').decode().lower(),
            'hidden_size': np.array(onnx_attr(node, 'hidden_size', 'i'), dtype=np.int64),
            'input_forget': input_forget,
        }

        LSTM.update_node_stat(node, attrs)
        return cls.enabled
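
The extractor only records the gate permutation in `gate_order`; the actual blob shuffling happens in later transformations. A minimal sketch of what that permutation means, assuming a hypothetical weight matrix stored as four stacked gate blocks in ONNX's i, o, f, c order (names and shapes are illustrative, not the Model Optimizer API):

    import numpy as np

    hidden_size = 2
    # Hypothetical weights: four gate blocks stacked in i, o, f, c order.
    w_iofc = np.arange(4 * hidden_size, dtype=np.float32).reshape(4, hidden_size)

    # gate_order = [2, 0, 3, 1] picks blocks f, i, c, o -> the fico order.
    w_fico = np.take(w_iofc, [2, 0, 3, 1], axis=0)

    assert np.array_equal(w_fico[0], w_iofc[2])  # f gate moved to the front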
Example 2
    def extract(cls, node):
        attrs = get_mxnet_layer_attrs(node.symbol_dict)
        mode = attrs.str('mode', None)
        state_size = attrs.int('state_size', None)
        bidirectional = attrs.bool('bidirectional', False)
        num_layers = attrs.int('num_layers', 1)
        layout = attrs.str('layout', 'TNC')  # MXNet RNN takes data in
        # [seq_len, batch_size, inp_size] format by default

        node_attrs = {
            'batch_dim': layout.index('N'),
            'sequence_dim': layout.index('T'),
            'blobs_wrb': False,
            'hidden_size': state_size,
            'has_num_directions': bidirectional,
            'direction': 'bidirectional' if bidirectional else 'forward',
            'num_layers': num_layers,
            'format': 'mxnet',
            'multilayers': num_layers != 1,
            'gate_order': None,
        }

        if mode == 'rnn_tanh':
            node_attrs['gate_order'] = [0]
            node_attrs['activations'] = ['tanh', 'tanh'] if bidirectional else ['tanh']
            RNN.update_node_stat(node, node_attrs)
        elif mode == 'rnn_relu':
            node_attrs['gate_order'] = [0]
            node_attrs['activations'] = ['relu', 'relu'] if bidirectional else ['relu']
            RNN.update_node_stat(node, node_attrs)
        elif mode == 'gru':
            node_attrs['gate_order'] = [1, 0, 2]
            node_attrs['linear_before_reset'] = 1
            GRU.update_node_stat(node, node_attrs)
        elif mode == 'lstm':
            node_attrs['gate_order'] = [1, 0, 2, 3]
            LSTM.update_node_stat(node, node_attrs)
        else:
            raise Error(
                "Operation RNN with mode '{}' not supported." +
                refer_to_faq_msg(86), mode)
        return cls.enabled
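
Both the batch and sequence dimensions are derived purely from the layout string, so a non-default layout just shifts the indices. A quick illustration in plain Python (not tied to the extractor):

    # 'TNC' = [seq_len, batch_size, inp_size]: time-major layout.
    assert 'TNC'.index('T') == 0  # sequence_dim
    assert 'TNC'.index('N') == 1  # batch_dim

    # A batch-major 'NTC' layout would swap the two.
    assert 'NTC'.index('N') == 0
    assert 'NTC'.index('T') == 1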
Example 3
    def replace_pattern(graph: Graph, match: dict):
        time_len = match['concatenated_hidden_states'].shape[0]
        r"""
        Working with concatenated_cell_states_data part first, because IE TensorIterator primitive doesn't have
        concatenated cell states output and if we can not collapse it, then we does not support this type of BlockLSTM

        We simplify the sub-graph below by taking another output of BlockLSTM:
        concatenated cell states over the whole time sequence -> last cell state

        BlockLSTM
           || out 1 (concatenated cell states coming out of BlockLSTM)
           \/  in 1
        ConcatV2
           || (concatenation with initial state or another unused data)
           \/
        Reshape
           ||
           \/
         Gather (taking the last cell state from previous BlockLSTM, if Gather indexes == time_len)
        """
        # check that there are no other consumers of concatenated_cell_states_data data flow
        valid_output_names = [
            'concat_1', 'concat_1_data', 'reshape_1', 'reshape_1_data',
            'gather_1', 'gather_1_data'
        ]
        valid_output_node_ids = [match[name].id for name in valid_output_names]
        node_names_to_check_outputs = [
            'concatenated_cell_states_data', 'concat_1_data', 'reshape_1_data'
        ]
        for name in node_names_to_check_outputs:
            for node in match[name].out_nodes():
                if node.id not in valid_output_node_ids:
                    raise Error(
                        "BlockLSTM node {} has output which contains concatenated cell states over the whole "
                        "time sequence. It is not replaceable by another output and is not supported "
                        "originally".format(match['BlockLSTM'].id))

        # check that we really take the last cell state data by Gather
        gather_indexes = match['gather_1'].in_node(1).value
        if len(gather_indexes) == 1:
            gather_index = gather_indexes[0]
        else:
            raise Error(
                "BlockLSTM node {} has output which contains concatenated cell states over the whole "
                "time sequence. It is not replaceable by another output and is not supported "
                "originally".format(match['BlockLSTM'].id))
        if gather_index != time_len:
            raise Error(
                "BlockLSTM node {} has output which contains concatenated cell states over the whole "
                "time sequence. It is not replaceable by another output and is not supported "
                "originally".format(match['BlockLSTM'].id))
        """
        We passed #1 and #2 stages from class description. It means that we can translate the rest of the pattern 
        to LSTMSequence even without following optimizations
        """

        node = match['BlockLSTM']
        weights_node = node.in_node(1)
        biases_node = node.in_node(2)
        shift_const = node.forget_bias

        # Temporarily reshape weights/biases for easier manipulation
        # (a standalone sketch of this whole transform follows the example).
        # TF stores weights in IO order.
        input_size = node.in_node(0).shape[-1]
        hidden_size = node.in_node(3).shape[-1]
        weights = weights_node.value
        biases = biases_node.value
        assert weights.shape[0] == input_size + hidden_size, \
            "weights.shape={} input_size={} hidden_size={}".format(weights.shape, input_size, hidden_size)
        assert weights.shape[1] == biases.shape[0] == 4 * hidden_size, \
            "weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)

        weights = weights.reshape([
            weights.shape[0],
            4,  # gates
            hidden_size
        ])

        biases = biases.reshape([
            4,  # gates
            hidden_size
        ])

        # Reorder gates icfo --> fico for both weights and biases
        gate_reorder = [2, 0, 1, 3]
        weights = np.take(weights, gate_reorder, axis=1)
        biases = np.take(biases, gate_reorder, axis=0)
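        # e.g. taking indices [2, 0, 1, 3] from gate blocks [i, c, f, o]
        # yields [f, i, c, o] -- the fico order expected downstream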

        # shift_const should be added to the first quarter of the biases (the f gate, block 0).
        # Note: if this code is moved above the gate reordering, the addition
        # must be applied to a different block.
        biases[0] += shift_const

        # Return to the original shapes
        weights = weights.reshape([weights.shape[0], -1])
        biases = biases.flatten()

        # TF stores weights in IO, but IE requires it in OI: transpose
        weights = weights.transpose()

        weights_node.value = weights
        weights_node.shape = int64_array(weights.shape)
        biases_node.value = biases
        biases_node.shape = int64_array(biases.shape)

        attrs = dict(
            graph.get_edge_data(match['gather_1'].id,
                                match['gather_1_data'].id)[0])
        attrs.update({'out': 2})
        graph.remove_edge(match['BlockLSTM'].id,
                          match['concatenated_cell_states_data'].id)
        graph.remove_edge(match['gather_1'].id, match['gather_1_data'].id)

        match['BlockLSTM'].add_output_port(attrs['out'])
        graph.add_edge(match['BlockLSTM'].id, match['gather_1_data'].id,
                       **attrs)
        """
        #3 Renumbering h_init_state, c_init_state input ports to match RNNSequence ports order.
        """
        h_init_port = 4
        c_init_port = 5
        # c_init_state
        if 4 in node.in_nodes():
            assert c_init_port not in node.in_nodes()
            cell_state_edge = graph.get_edge_data(node.in_node(4).id, node.id)
            cell_state_edge[0]['in'] = c_init_port

        # h_init_state
        if 3 in node.in_nodes():
            assert h_init_port not in node.in_nodes()
            hidden_state_edge = graph.get_edge_data(
                node.in_node(3).id, node.id)
            hidden_state_edge[0]['in'] = h_init_port

        new_attrs = {
            'sequence_dim': 0,
            'batch_dim': 1,
            'direction': 'forward',
            'hidden_size': match['concatenated_hidden_states'].shape[-1],
            'format': 'tf',
        }

        LSTM.update_node_stat(match['BlockLSTM'], new_attrs)
        """
        Optional #4 optimization from class description following
        """
        data_to_mul = [
            n for n in match['mul'].in_nodes().values()
            if n.id != match['concatenated_hidden_states'].id
        ]
        if len(data_to_mul) != 1:
            return  # unexpected type of mul
        data_to_mul = data_to_mul[0]
        if not data_to_mul.has_valid('value'):
            return  # unexpected type of mul
        data_to_mul_value = data_to_mul.value
        if not np.all(data_to_mul_value == 1):
            return  # unexpected type of mul

        # remove the useless Mul: element-wise multiplication by all ones is the identity
        attrs = dict(
            graph.get_edge_data(match['BlockLSTM'].id,
                                match['concatenated_hidden_states'].id)[0])
        graph.remove_edge(match['BlockLSTM'].id,
                          match['concatenated_hidden_states'].id)
        graph.remove_edge(match['mul'].id, match['mul_data'].id)
        graph.add_edge(match['BlockLSTM'].id, match['mul_data'].id, **attrs)

        # find true usages of concatenated hidden states data (not last hidden state)
        valid_output_names = [
            'mul_data', 'concat_0', 'concat_0_data', 'reshape_0',
            'reshape_0_data', 'gather_0', 'gather_0_data'
        ]
        valid_output_node_ids = [match[name].id for name in valid_output_names]
        node_names_to_check_outputs = [
            'mul_data', 'concat_0_data', 'reshape_0_data'
        ]

        list_of_concatenated_hidden_states_children_node_ids = []
        for name in node_names_to_check_outputs:
            for node in match[name].out_nodes():
                if node.id not in valid_output_node_ids:
                    list_of_concatenated_hidden_states_children_node_ids.append(
                        node.id)

        if len(list_of_concatenated_hidden_states_children_node_ids) != 1:
            return  # not supported placement of pattern
        concatenated_child_node_id = \
            list_of_concatenated_hidden_states_children_node_ids[0]
        if concatenated_child_node_id != match['after_mul_op_to_the_rest_of_model'].id:
            return  # not supported placement of pattern

        gather_indexes = match['gather_0'].in_node(1).value
        if len(gather_indexes) == 1:
            gather_index = gather_indexes[0]
        else:
            return  # translate this type of BlockLSTM to LSTMSequence -> TensorIterator as is
        if gather_index != time_len:
            return  # translate this type of BlockLSTM to LSTMSequence -> TensorIterator as is

        attrs = dict(
            graph.get_edge_data(match['gather_0'].id,
                                match['gather_0_data'].id)[0])
        attrs.update({'out': 1})
        graph.remove_edge(match['mul_data'].id, match['concat_0'].id)
        graph.remove_edge(match['gather_0'].id, match['gather_0_data'].id)

        graph.add_edge(match['BlockLSTM'].id, match['gather_0_data'].id,
                       **attrs)
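
The weight massaging in the middle of this pass is just reshape -> gate permutation -> forget-bias shift -> reshape -> transpose. A self-contained numpy sketch of the same sequence with toy sizes (input_size=2, hidden_size=3; all values illustrative, not taken from a real model):

    import numpy as np

    input_size, hidden_size, forget_bias = 2, 3, 1.0
    # TF stores BlockLSTM weights in IO order with gates stacked as i, c, f, o.
    weights = np.random.rand(input_size + hidden_size, 4 * hidden_size)
    biases = np.random.rand(4 * hidden_size)

    w = weights.reshape([weights.shape[0], 4, hidden_size])
    b = biases.reshape([4, hidden_size])

    gate_reorder = [2, 0, 1, 3]  # icfo --> fico
    w = np.take(w, gate_reorder, axis=1)
    b = np.take(b, gate_reorder, axis=0)
    b[0] += forget_bias  # f gate is block 0 after reordering

    w = w.reshape([w.shape[0], -1]).transpose()  # IO --> OI
    b = b.flatten()
    assert w.shape == (4 * hidden_size, input_size + hidden_size)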