Example 1
    def __init__(self, embed, rnn_mode, num_layers, latent_dim, hidden_size, bidirectional=True, use_attn=False, dropout=0.0, dropword=0.0):
        super(ShiftRecurrentCore, self).__init__()
        if rnn_mode == 'RNN':
            RNN = nn.RNN
        elif rnn_mode == 'LSTM':
            RNN = nn.LSTM
        elif rnn_mode == 'GRU':
            RNN = nn.GRU
        else:
            raise ValueError('Unknown RNN mode: %s' % rnn_mode)
        assert hidden_size % 2 == 0
        self.tgt_embed = embed
        assert num_layers == 1
        self.bidirectional = bidirectional
        if bidirectional:
            self.rnn = RNN(embed.embedding_dim, hidden_size // 2, num_layers=1, batch_first=True, bidirectional=True)
        else:
            self.rnn = RNN(embed.embedding_dim, hidden_size, num_layers=1, batch_first=True, bidirectional=False)
        self.use_attn = use_attn
        if use_attn:
            self.attn = GlobalAttention(latent_dim, hidden_size, hidden_size)
            self.ctx_proj = nn.Sequential(nn.Linear(hidden_size * 2, hidden_size), nn.ELU())
        else:
            self.ctx_proj = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ELU())
        self.dropout = dropout
        self.dropout2d = nn.Dropout2d(dropword) if dropword > 0. else None  # drop entire tokens
        self.mu = LinearWeightNorm(hidden_size, latent_dim, bias=True)
        self.logvar = LinearWeightNorm(hidden_size, latent_dim, bias=True)
Example 2
    def __init__(self,
                 rnn_mode,
                 src_features,
                 in_features,
                 out_features,
                 hidden_features,
                 dropout=0.0):
        super(NICERecurrentBlock, self).__init__()
        if rnn_mode == 'RNN':
            RNN = nn.RNN
        elif rnn_mode == 'LSTM':
            RNN = nn.LSTM
        elif rnn_mode == 'GRU':
            RNN = nn.GRU
        else:
            raise ValueError('Unknown RNN mode: %s' % rnn_mode)

        self.rnn = RNN(in_features,
                       hidden_features // 2,
                       batch_first=True,
                       bidirectional=True)
        self.attn = GlobalAttention(src_features,
                                    hidden_features,
                                    hidden_features,
                                    dropout=dropout)
        self.linear = LinearWeightNorm(in_features + hidden_features,
                                       out_features,
                                       bias=True)
Example 3
    def __init__(self,
                 src_features,
                 in_features,
                 out_features,
                 hidden_features,
                 kernel_size,
                 dropout=0.0):
        super(NICEConvBlock, self).__init__()
        self.conv1 = Conv1dWeightNorm(in_features,
                                      hidden_features,
                                      kernel_size=kernel_size,
                                      padding=kernel_size // 2,
                                      bias=True)
        self.conv2 = Conv1dWeightNorm(hidden_features,
                                      hidden_features,
                                      kernel_size=kernel_size,
                                      padding=kernel_size // 2,
                                      bias=True)
        self.activation = nn.ELU(inplace=True)
        self.attn = GlobalAttention(src_features,
                                    hidden_features,
                                    hidden_features,
                                    dropout=dropout)
        self.linear = LinearWeightNorm(hidden_features * 2,
                                       out_features,
                                       bias=True)
Example 4
    def __init__(self, vocab_size, latent_dim, hidden_size, dropout=0.0, label_smoothing=0., _shared_weight=None):
        super(SimpleDecoder, self).__init__(vocab_size, latent_dim,
                                            label_smoothing=label_smoothing,
                                            _shared_weight=_shared_weight)
        self.attn = GlobalAttention(latent_dim, latent_dim, latent_dim, hidden_features=hidden_size)
        ctx_features = latent_dim * 2
        self.ctx_proj = nn.Sequential(nn.Linear(ctx_features, latent_dim), nn.ELU())
        self.dropout = dropout
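
The constructor above only wires the layers together; the decoder's forward pass is not shown in this example. As a rough, self-contained illustration of the pattern those layers imply (an assumption on my part, with a random tensor standing in for the GlobalAttention output and a toy dropout rate), the sketch below concatenates latent codes with an attention context of the same width, projects them with the Linear + ELU block, and applies dropout:

import torch
import torch.nn as nn
import torch.nn.functional as F

latent_dim = 32  # toy size
ctx_proj = nn.Sequential(nn.Linear(latent_dim * 2, latent_dim), nn.ELU())

z = torch.randn(4, 10, latent_dim)         # latent codes [batch, length, latent_dim]
attn_ctx = torch.randn(4, 10, latent_dim)  # stand-in for the GlobalAttention output
ctx = torch.cat([z, attn_ctx], dim=2)      # ctx_features = latent_dim * 2
ctx = F.dropout(ctx_proj(ctx), p=0.1, training=True)
print(ctx.shape)                           # torch.Size([4, 10, 32])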
Example 5
    def __init__(self,
                 vocab_size,
                 latent_dim,
                 rnn_mode,
                 num_layers,
                 hidden_size,
                 bidirectional=True,
                 dropout=0.0,
                 dropword=0.0,
                 label_smoothing=0.,
                 _shared_weight=None):
        super(RecurrentDecoder, self).__init__(vocab_size,
                                               latent_dim,
                                               label_smoothing=label_smoothing,
                                               _shared_weight=_shared_weight)

        if rnn_mode == 'RNN':
            RNN = nn.RNN
        elif rnn_mode == 'LSTM':
            RNN = nn.LSTM
        elif rnn_mode == 'GRU':
            RNN = nn.GRU
        else:
            raise ValueError('Unknown RNN mode: %s' % rnn_mode)
        assert hidden_size % 2 == 0
        # RNN for processing latent variables zs
        if bidirectional:
            self.rnn = RNN(latent_dim,
                           hidden_size // 2,
                           num_layers=num_layers,
                           batch_first=True,
                           bidirectional=True)
        else:
            self.rnn = RNN(latent_dim,
                           hidden_size,
                           num_layers=num_layers,
                           batch_first=True,
                           bidirectional=False)

        self.attn = GlobalAttention(latent_dim,
                                    hidden_size,
                                    latent_dim,
                                    hidden_features=hidden_size)
        self.ctx_proj = nn.Sequential(
            nn.Linear(latent_dim + hidden_size, latent_dim), nn.ELU())
        self.dropout = dropout
        self.dropout2d = nn.Dropout2d(
            dropword) if dropword > 0. else None  # drop entire tokens
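
A quick way to see why the constructor asserts hidden_size % 2 == 0 and gives each direction hidden_size // 2 units: the outputs of the two directions are concatenated per time step, so the bidirectional RNN over the latent variables comes back to exactly hidden_size features, which is what self.attn and self.ctx_proj expect. A minimal, self-contained sketch with toy dimensions and the LSTM mode:

import torch
import torch.nn as nn

latent_dim, hidden_size = 32, 64   # hidden_size must be even
rnn = nn.LSTM(latent_dim, hidden_size // 2, num_layers=1,
              batch_first=True, bidirectional=True)

z = torch.randn(4, 10, latent_dim)  # latent variables [batch, length, latent_dim]
out, _ = rnn(z)
print(out.shape)  # torch.Size([4, 10, 64]): both directions concatenated back to hidden_size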
Example 6
class ShiftRecurrentCore(nn.Module):
    def __init__(self, embed, rnn_mode, num_layers, latent_dim, hidden_size, bidirectional=True, use_attn=False, dropout=0.0, dropword=0.0):
        super(ShiftRecurrentCore, self).__init__()
        if rnn_mode == 'RNN':
            RNN = nn.RNN
        elif rnn_mode == 'LSTM':
            RNN = nn.LSTM
        elif rnn_mode == 'GRU':
            RNN = nn.GRU
        else:
            raise ValueError('Unknown RNN mode: %s' % rnn_mode)
        assert hidden_size % 2 == 0
        self.tgt_embed = embed
        assert num_layers == 1
        self.bidirectional = bidirectional
        if bidirectional:
            self.rnn = RNN(embed.embedding_dim, hidden_size // 2, num_layers=1, batch_first=True, bidirectional=True)
        else:
            self.rnn = RNN(embed.embedding_dim, hidden_size, num_layers=1, batch_first=True, bidirectional=False)
        self.use_attn = use_attn
        if use_attn:
            self.attn = GlobalAttention(latent_dim, hidden_size, hidden_size)
            self.ctx_proj = nn.Sequential(nn.Linear(hidden_size * 2, hidden_size), nn.ELU())
        else:
            self.ctx_proj = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ELU())
        self.dropout = dropout
        self.dropout2d = nn.Dropout2d(dropword) if dropword > 0. else None # drop entire tokens
        self.mu = LinearWeightNorm(hidden_size, latent_dim, bias=True)
        self.logvar = LinearWeightNorm(hidden_size, latent_dim, bias=True)

    #@overrides
    def forward(self, tgt_sents, tgt_masks, src_enc, src_masks):
        tgt_embed = self.tgt_embed(tgt_sents)
        if self.dropout2d is not None:
            tgt_embed = self.dropout2d(tgt_embed)
        lengths = tgt_masks.sum(dim=1).long()
        packed_embed = pack_padded_sequence(tgt_embed, lengths, batch_first=True, enforce_sorted=False)
        packed_enc, _ = self.rnn(packed_embed)
        tgt_enc, _ = pad_packed_sequence(packed_enc, batch_first=True, total_length=tgt_masks.size(1))

        if self.bidirectional:
            # split into fwd and bwd
            fwd_tgt_enc, bwd_tgt_enc = tgt_enc.chunk(2, dim=2) # (batch_size, seq_len, hidden_size // 2)
            pad_vector = fwd_tgt_enc.new_zeros((fwd_tgt_enc.size(0), 1, fwd_tgt_enc.size(2)))
            pad_fwd_tgt_enc = torch.cat([pad_vector, fwd_tgt_enc], dim=1)
            pad_bwd_tgt_enc = torch.cat([bwd_tgt_enc, pad_vector], dim=1)
            tgt_enc = torch.cat([pad_fwd_tgt_enc[:, :-1], pad_bwd_tgt_enc[:, 1:]], dim=2)
        else:
            pad_vector = tgt_enc.new_zeros((tgt_enc.size(0), 1, tgt_enc.size(2)))
            tgt_enc = torch.cat([pad_vector, tgt_enc], dim=1)[:, :-1]

        if self.use_attn:
            ctx = self.attn(tgt_enc, src_enc, key_mask=src_masks.eq(0))
            ctx = torch.cat([tgt_enc, ctx], dim=2)
        else:
            ctx = tgt_enc
        ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training)
        mu = self.mu(ctx) * tgt_masks.unsqueeze(2)
        logvar = self.logvar(ctx) * tgt_masks.unsqueeze(2)
        return mu, logvar

    def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True):
        with torch.no_grad():
            tgt_embed = self.tgt_embed(tgt_sents)
            if self.dropout2d is not None:
                tgt_embed = self.dropout2d(tgt_embed)
            lengths = tgt_masks.sum(dim=1).long()
            packed_embed = pack_padded_sequence(tgt_embed, lengths, batch_first=True, enforce_sorted=False)
            packed_enc, _ = self.rnn(packed_embed)
            tgt_enc, _ = pad_packed_sequence(packed_enc, batch_first=True, total_length=tgt_masks.size(1))

            if self.bidirectional:
                fwd_tgt_enc, bwd_tgt_enc = tgt_enc.chunk(2, dim=2)  # (batch_size, seq_len, hidden_size // 2)
                pad_vector = fwd_tgt_enc.new_zeros((fwd_tgt_enc.size(0), 1, fwd_tgt_enc.size(2)))
                pad_fwd_tgt_enc = torch.cat([pad_vector, fwd_tgt_enc], dim=1)
                pad_bwd_tgt_enc = torch.cat([bwd_tgt_enc, pad_vector], dim=1)
                tgt_enc = torch.cat([pad_fwd_tgt_enc[:, :-1], pad_bwd_tgt_enc[:, 1:]], dim=2)
            else:
                pad_vector = tgt_enc.new_zeros((tgt_enc.size(0), 1, tgt_enc.size(2)))
                tgt_enc = torch.cat([pad_vector, tgt_enc], dim=1)[:, :-1]

            if self.use_attn:
                ctx = self.attn.init(tgt_enc, src_enc, key_mask=src_masks.eq(0), init_scale=init_scale)
                ctx = torch.cat([tgt_enc, ctx], dim=2)
            else:
                ctx = tgt_enc
            ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training)
            mu = self.mu.init(ctx, init_scale=0.05 * init_scale) if init_mu else self.mu(ctx)
            logvar = self.logvar.init(ctx, init_scale=0.05 * init_scale) if init_var else self.logvar(ctx)
            mu = mu * tgt_masks.unsqueeze(2)
            logvar = logvar * tgt_masks.unsqueeze(2)
            return mu, logvar
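
The key step in forward and init above is the split-and-shift of the bidirectional outputs: the forward states are shifted one position to the right and the backward states one position to the left, so the representation at position t summarizes the tokens before and after t but not token t itself. The toy sketch below reproduces just that shift on hand-made tensors (no RNN, arbitrary values):

import torch

batch, seq_len, half = 1, 4, 2
fwd = torch.arange(1, seq_len + 1).float().view(1, seq_len, 1).expand(batch, seq_len, half)   # forward states
bwd = -torch.arange(1, seq_len + 1).float().view(1, seq_len, 1).expand(batch, seq_len, half)  # backward states

pad = fwd.new_zeros((batch, 1, half))
shifted_fwd = torch.cat([pad, fwd], dim=1)[:, :-1]   # position t gets the forward state of t - 1
shifted_bwd = torch.cat([bwd, pad], dim=1)[:, 1:]    # position t gets the backward state of t + 1
enc = torch.cat([shifted_fwd, shifted_bwd], dim=2)   # [batch, seq_len, 2 * half]
print(enc[0])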
Example 7
class NICERecurrentBlock(nn.Module):
    def __init__(self,
                 rnn_mode,
                 src_features,
                 in_features,
                 out_features,
                 hidden_features,
                 dropout=0.0):
        super(NICERecurrentBlock, self).__init__()
        if rnn_mode == 'RNN':
            RNN = nn.RNN
        elif rnn_mode == 'LSTM':
            RNN = nn.LSTM
        elif rnn_mode == 'GRU':
            RNN = nn.GRU
        else:
            raise ValueError('Unknown RNN mode: %s' % rnn_mode)

        self.rnn = RNN(in_features,
                       hidden_features // 2,
                       batch_first=True,
                       bidirectional=True)
        self.attn = GlobalAttention(src_features,
                                    hidden_features,
                                    hidden_features,
                                    dropout=dropout)
        self.linear = LinearWeightNorm(in_features + hidden_features,
                                       out_features,
                                       bias=True)

    def forward(self, x, mask, src, src_mask):
        lengths = mask.sum(dim=1).long()
        packed_out = pack_padded_sequence(x,
                                          lengths,
                                          batch_first=True,
                                          enforce_sorted=False)
        packed_out, _ = self.rnn(packed_out)
        out, _ = pad_packed_sequence(packed_out,
                                     batch_first=True,
                                     total_length=mask.size(1))
        # [batch, length, out_features]
        out = self.attn(out, src, key_mask=src_mask.eq(0))
        out = self.linear(torch.cat([x, out], dim=2))
        return out

    def init(self, x, mask, src, src_mask, init_scale=1.0):
        lengths = mask.sum(dim=1).long()
        packed_out = pack_padded_sequence(x,
                                          lengths,
                                          batch_first=True,
                                          enforce_sorted=False)
        packed_out, _ = self.rnn(packed_out)
        out, _ = pad_packed_sequence(packed_out,
                                     batch_first=True,
                                     total_length=mask.size(1))
        # [batch, length, out_features]
        out = self.attn.init(out,
                             src,
                             key_mask=src_mask.eq(0),
                             init_scale=init_scale)
        out = self.linear.init(torch.cat([x, out], dim=2), init_scale=0.0)
        return out
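
Both forward and init above rely on the same packing idiom: the lengths are read off the mask, enforce_sorted=False keeps the batch in its original order, and total_length=mask.size(1) pads the unpacked output back to the full mask width. A self-contained sketch of just that idiom with an ordinary bidirectional LSTM (toy dimensions, no attention):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

in_features, hidden_features = 8, 16
rnn = nn.LSTM(in_features, hidden_features // 2, batch_first=True, bidirectional=True)

x = torch.randn(3, 5, in_features)           # [batch, length, in_features]
mask = torch.tensor([[1., 1., 1., 1., 1.],
                     [1., 1., 1., 0., 0.],
                     [1., 1., 0., 0., 0.]])  # 0 marks padding

lengths = mask.sum(dim=1).long()
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed, _ = rnn(packed)
out, _ = pad_packed_sequence(packed, batch_first=True, total_length=mask.size(1))
print(out.shape)  # torch.Size([3, 5, 16]); steps beyond each length are zero-padded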
Example 8
class NICEConvBlock(nn.Module):
    def __init__(self,
                 src_features,
                 in_features,
                 out_features,
                 hidden_features,
                 kernel_size,
                 dropout=0.0):
        super(NICEConvBlock, self).__init__()
        self.conv1 = Conv1dWeightNorm(in_features,
                                      hidden_features,
                                      kernel_size=kernel_size,
                                      padding=kernel_size // 2,
                                      bias=True)
        self.conv2 = Conv1dWeightNorm(hidden_features,
                                      hidden_features,
                                      kernel_size=kernel_size,
                                      padding=kernel_size // 2,
                                      bias=True)
        self.activation = nn.ELU(inplace=True)
        self.attn = GlobalAttention(src_features,
                                    hidden_features,
                                    hidden_features,
                                    dropout=dropout)
        self.linear = LinearWeightNorm(hidden_features * 2,
                                       out_features,
                                       bias=True)

    def forward(self, x, mask, src, src_mask):
        """

        Args:
            x: Tensor
                input tensor [batch, length, in_features]
            mask: Tensor
                x mask tensor [batch, length]
            src: Tensor
                source input tensor [batch, src_length, src_features]
            src_mask: Tensor
                source mask tensor [batch, src_length]

        Returns: Tensor
            out tensor [batch, length, out_features]

        """
        out = self.activation(self.conv1(x.transpose(1, 2)))
        out = self.activation(self.conv2(out)).transpose(1,
                                                         2) * mask.unsqueeze(2)
        out = self.attn(out, src, key_mask=src_mask.eq(0))
        out = self.linear(torch.cat([x, out], dim=2))
        return out

    def init(self, x, mask, src, src_mask, init_scale=1.0):
        out = self.activation(
            self.conv1.init(x.transpose(1, 2), init_scale=init_scale))
        out = self.activation(self.conv2.init(
            out, init_scale=init_scale)).transpose(1, 2) * mask.unsqueeze(2)
        out = self.attn.init(out,
                             src,
                             key_mask=src_mask.eq(0),
                             init_scale=init_scale)
        out = self.linear.init(torch.cat([x, out], dim=2), init_scale=0.0)
        return out
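
The convolutional path in forward works on [batch, features, length]: the input is transposed before conv1, padding=kernel_size // 2 keeps the sequence length unchanged for odd kernel sizes, and the result is transposed back and re-masked. The sketch below reproduces only that shape handling, with plain nn.Conv1d standing in for the weight-normalized Conv1dWeightNorm used here:

import torch
import torch.nn as nn

in_features, hidden_features, kernel_size = 8, 16, 3   # toy sizes; kernel_size is odd
conv1 = nn.Conv1d(in_features, hidden_features, kernel_size, padding=kernel_size // 2)
conv2 = nn.Conv1d(hidden_features, hidden_features, kernel_size, padding=kernel_size // 2)
activation = nn.ELU(inplace=True)

x = torch.randn(2, 7, in_features)            # [batch, length, in_features]
mask = torch.ones(2, 7)                       # all positions valid in this toy batch

out = activation(conv1(x.transpose(1, 2)))    # [batch, hidden_features, length]
out = activation(conv2(out)).transpose(1, 2)  # back to [batch, length, hidden_features]
out = out * mask.unsqueeze(2)                 # zero out padded positions
print(out.shape)                              # torch.Size([2, 7, 16])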