Example #1
import torch
import torch.nn as nn
import modules  # project-local helper (provides the Conv2d wrapper used below)
from sru import SRU


class SRUC(nn.Module):

    def __init__(
                self,
                input_dim=257,
                output_dim=257,
                hidden_layers=2,
                hidden_units=512,
                left_context=1,
                right_context=1,
                kernel_size=6,
                kernel_num=9,
                target_mode='MSA',
                dropout=0.2
        ):
        super(SRUC, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_layers = hidden_layers
        self.hidden_units = hidden_units
        self.left_context = left_context
        self.right_context = right_context
        self.kernel_size = kernel_size
        self.kernel_num = kernel_num
        self.target_mode = target_mode

        self.input_layer = nn.Sequential(
                nn.Linear((left_context+1+right_context)*input_dim, hidden_units),
                nn.Tanh()
            )
        
        self.rnn_layer = SRU(
                    input_size=hidden_units,
                    hidden_size=hidden_units,
                    num_layers=self.hidden_layers,
                    dropout=dropout,
                    rescale=True,
                    bidirectional=False,
                    layer_norm=False
            )
        
        # NOTE: the convolution and pooling must preserve the feature dimension
        # so the flattened conv output matches hidden_units*kernel_num below.
        self.conv2d_layer = nn.Sequential(
                #nn.Conv2d(in_channels=1,out_channels=kernel_num,kernel_size=(kernel_size, kernel_size), stride=[1,1],padding=(5,5), dilation=(2,2)),
                modules.Conv2d(in_channels=1, out_channels=kernel_num, kernel_size=(kernel_size, kernel_size)),
                nn.Tanh(),
                nn.MaxPool2d(3, stride=1, padding=(1, 1))
            )
        
        self.output_layer = nn.Sequential(
                nn.Linear(hidden_units*kernel_num, (left_context+1+right_context)*self.output_dim),
                nn.Sigmoid()
            )
        #self.loss_func = nn.MSELoss(reduction='sum')
        #self.loss_func = nn.MSELoss()
        #show_model(self)
        #show_params(self)
        #self.apply(self._init_weights)
        #self.flatten_parameters()

    def flatten_parameters(self):
        self.rnn_layer.flatten_parameters()

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            for name, param in m.named_parameters():
                if 'weight' in name:
                    nn.init.xavier_uniform_(param)
                elif 'bias' in name:
                    nn.init.constant_(param, 0.0)
        elif isinstance(m, nn.GRU):
            for name, param in m.named_parameters():
                if 'weight_ih' in name:
                    for ih in param.chunk(3,0):
                        nn.init.xavier_uniform_(ih)
                elif 'weight_hh' in name:
                    for hh in param.chunk(3,0):
                        nn.init.orthogonal_(hh)
                elif 'bias_ih' in name:
                    nn.init.zeros_(param)

    def forward(self, inputs, lens=None):
        outputs = self.input_layer(inputs)
#        packed_inputs = torch.nn.utils.rnn.pack_padded_sequence(outputs, lens, batch_first=True)
#        outputs, _ = self.rnn_layer(packed_inputs)
#        outputs, lens = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        # SRU expects sequence-first input: [length, batch_size, dims]
        outputs = torch.transpose(outputs, 0, 1)
        outputs, _ = self.rnn_layer(outputs)
        outputs = torch.transpose(outputs, 0, 1)
        # reshape outputs to [batch_size, 1, length, dims]
        outputs = torch.unsqueeze(outputs, 1)
        # conv outputs to [batch_size, channels, length, dims]
        outputs = self.conv2d_layer(outputs)
        # conv outputs to [batch_size, dims, length, channels]
        outputs = torch.transpose(outputs, 1, -1)
        # conv outputs to [batch_size, length, dims, channels]
        outputs = torch.transpose(outputs, 1, 2)
        batch_size, max_len, dims, channels = outputs.size()

        outputs = torch.reshape(outputs, [batch_size, max_len, -1])
        mask = self.output_layer(outputs)
        #outputs = mask
        if self.target_mode == 'PSA' or self.target_mode == 'MSA':
            outputs = mask*inputs
            return outputs, outputs[:, :, self.left_context*self.output_dim:(self.left_context+1)*self.output_dim]
        elif self.target_mode == 'SPEC' or self.target_mode == 'TCS':
            outputs = mask 
            return outputs, outputs[:, :, self.left_context*self.output_dim:(self.left_context+1)*self.output_dim]
        else:
            outputs = mask 
            return outputs, (mask*inputs)[:, :, self.left_context*self.output_dim:(self.left_context+1)*self.output_dim]

    def get_params(self, weight_decay=0.0):
        # add L2 penalty
        weights, biases = [], []
        for name, param in self.named_parameters():
            if 'bias' in name:
                biases += [param]
            else:
                weights += [param]
        params = [{
            'params': weights,
            'weight_decay': weight_decay,
        }, {
            'params': biases,
            'weight_decay': 0.0,
        }]
        return params
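
A minimal smoke test for SRUC (not part of the original source). It assumes the `sru` package and the project-local `modules` package are importable, and that `modules.Conv2d` preserves the time and feature dimensions, as the commented-out `nn.Conv2d` call with padding=(5,5) and dilation=(2,2) would; the batch size, frame count, and 257-bin spectrum are illustrative values only.

import torch

# 4 utterances, 100 frames each; 257-bin magnitude spectra spliced with one
# frame of left and right context: (1 + 1 + 1) * 257 = 771 input features.
model = SRUC(input_dim=257, output_dim=257, left_context=1, right_context=1)
noisy = torch.rand(4, 100, (1 + 1 + 1) * 257)

full, center = model(noisy)
print(full.shape)    # torch.Size([4, 100, 771]) - masked spliced spectrum (MSA mode)
print(center.shape)  # torch.Size([4, 100, 257]) - center frame of the splice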
Example #2
import logging
from typing import Optional

import torch
from torch import Tensor, nn

logger = logging.getLogger(__name__)


class RNNEncoder(nn.Module):
    """Implements a multi-layer RNN.

    This module can be used to create multi-layer RNN models, and
    provides a way to reduce the output of the RNN to a single hidden
    state by pooling the encoder states, either by taking the maximum,
    the average, or the last hidden state before padding.

    Padding is dealt with by using torch's PackedSequence.

    Attributes
    ----------
    rnn: nn.Module
        The rnn submodule

    """
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        n_layers: int = 1,
        rnn_type: str = "lstm",
        dropout: float = 0,
        bidirectional: bool = False,
        layer_norm: bool = False,
        highway_bias: float = 0,
        rescale: bool = True,
        enforce_sorted: bool = False,
        **kwargs,
    ) -> None:
        """Initializes the RNNEncoder object.

        Parameters
        ----------
        input_size : int
            The dimension of the input data
        hidden_size : int
            The hidden dimension to encode the data in
        n_layers : int, optional
            The number of rnn layers, defaults to 1
        rnn_type : str, optional
            The type of rnn cell, one of: `lstm`, `gru`, `sru`;
            defaults to `lstm`
        dropout : float, optional
            Amount of dropout to use between RNN layers, defaults to 0
        bidirectional : bool, optional
            Set to use a bidirectional encoder, defaults to False
        layer_norm : bool, optional
            [SRU only] whether to use layer norm
        highway_bias : float, optional
            [SRU only] value to use for the highway bias
        rescale : bool, optional
            [SRU only] whether to use rescaling
        enforce_sorted : bool, optional
            Whether the rnn should enforce that sequences are ordered by
            length. Must be True for ONNX support. Defaults to False.
        kwargs
            Additional parameters to be passed to SRU when building
            the rnn.

        Raises
        ------
        ValueError
            The rnn type should be one of: `lstm`, `gru`, `sru`

        """
        super().__init__()

        self.rnn_type = rnn_type
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.enforce_sorted = enforce_sorted
        if rnn_type in ["lstm", "gru"]:
            if kwargs:
                logger.warning(
                    f"The following '{kwargs}' will be ignored " +
                    "as they are only considered when using 'sru' as " +
                    "'rnn_type'")

            rnn_fn = nn.LSTM if rnn_type == "lstm" else nn.GRU
            self.rnn = rnn_fn(
                input_size=input_size,
                hidden_size=hidden_size,
                num_layers=n_layers,
                dropout=dropout,
                bidirectional=bidirectional,
            )
        elif rnn_type == "sru":
            try:
                from sru import SRU
            except ImportError:
                raise ImportError(
                    "SRU not installed. You can install it with: `pip install sru`"
                )

            try:
                self.rnn = SRU(
                    input_size,
                    hidden_size,
                    num_layers=n_layers,
                    dropout=dropout,
                    bidirectional=bidirectional,
                    layer_norm=layer_norm,
                    rescale=rescale,
                    highway_bias=highway_bias,
                    **kwargs,
                )
            except TypeError:
                raise ValueError(f"Unkown kwargs passed to SRU: {kwargs}")
        else:
            raise ValueError(
                f"Unknown rnn type: {rnn_type}, use one of: gru, sru, lstm")

    def forward(
        self,  # type: ignore
        data: Tensor,
        state: Optional[Tensor] = None,
        padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Performs a forward pass through the network.

        Parameters
        ----------
        data : Tensor
            The input data, as a float tensor of shape [B x S x E]
        state : Tensor, optional
            An optional previous state of shape [L x B x H]
        padding_mask: Tensor, optional
            The padding mask of shape [B x S]

        Returns
        -------
        Tensor
            The encoded output, as a float tensor of shape [B x H]

        """
        if hasattr(self.rnn, "flatten_parameters"):
            self.rnn.flatten_parameters()

        # Transpose to sequence first
        data = data.transpose(0, 1)
        if padding_mask is not None:
            padding_mask = padding_mask.transpose(0, 1)

        if padding_mask is None:
            # Default RNN behavior
            output, state = self.rnn(data, state)
        elif self.rnn_type == "sru":
            # SRU takes a pad mask (1 = padding) instead of PackedSequence
            # objects; (1 - mask) is written as (-mask + 1) for type checking
            output, state = self.rnn(data,
                                     state,
                                     mask_pad=(-padding_mask + 1).byte())
        else:
            # Deal with variable length sequences
            lengths = padding_mask.long().sum(dim=0)
            # Pass through the RNN
            packed = nn.utils.rnn.pack_padded_sequence(
                data, lengths, enforce_sorted=self.enforce_sorted)
            output, state = self.rnn(packed, state)
            output, _ = nn.utils.rnn.pad_packed_sequence(output)

        # back to batch first
        output = output.transpose(0, 1).contiguous()

        # Compute lengths and pool the last hidden state before padding
        if padding_mask is None:
            lengths = torch.tensor([output.size(1)] * output.size(0)).long()
        else:
            lengths = padding_mask.long().sum(dim=0)

        return output[torch.arange(output.size(0)).long(), lengths - 1, :]
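
A hedged usage sketch for RNNEncoder (not taken from the original project). The padding-mask convention of 1 for real tokens and 0 for padding, with shape [B x S], follows the lengths computation in `forward`; the toy sizes are arbitrary.

import torch

# Two sequences of true lengths 5 and 3, padded to a common length of 5.
encoder = RNNEncoder(input_size=16, hidden_size=32, rnn_type="lstm")
data = torch.randn(2, 5, 16)                    # [B x S x E]
padding_mask = torch.tensor([[1, 1, 1, 1, 1],
                             [1, 1, 1, 0, 0]])  # [B x S], 1 = real token

pooled = encoder(data, padding_mask=padding_mask)
print(pooled.shape)  # torch.Size([2, 32]) - last hidden state before padding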