コード例 #1
0
class TransformerCore(nn.Module):
    """Transformer decoder stack that produces ``mu`` / ``logvar`` for every target position.

    Target tokens are embedded and scaled by ``sqrt(embed_dim)``, word-dropped,
    given positional encodings and passed through ``num_layers``
    ``TransformerDecoderLayer`` modules that also attend to ``src_enc``. Two
    ``LinearWeightNorm`` heads of width ``latent_dim`` then produce ``mu`` and
    ``logvar``; both are zeroed wherever ``tgt_masks`` is 0.

    Args:
        embed: target embedding module; its ``padding_idx`` and
            ``embedding_dim`` attributes are read.
        num_layers: number of decoder layers.
        latent_dim: model width; must equal ``embed.embedding_dim``.
        hidden_size: hidden width passed to every decoder layer.
        heads: number of attention heads passed to every decoder layer.
        dropout: dropout rate passed to every decoder layer.
        dropword: rate at which whole token embeddings are dropped.
        max_length: the positional encoding is built with ``max_length + 1``
            as its size argument.
        input_dropout: dropout applied to the embedded input (after the
            positional encoding is added) before the first layer. Defaults to
            0.2, the value that used to be hard-coded.

    Raises:
        ValueError: if ``embed.embedding_dim != latent_dim``.
    """

    def __init__(self, embed, num_layers, latent_dim, hidden_size, heads, dropout=0.0, dropword=0.0, max_length=100,
                 input_dropout=0.2):
        super(TransformerCore, self).__init__()
        self.tgt_embed = embed
        self.padding_idx = embed.padding_idx

        embed_dim = embed.embedding_dim
        if embed_dim != latent_dim:
            raise ValueError('embedding dim (%s) must equal latent_dim (%s)' % (embed_dim, latent_dim))
        self.embed_scale = math.sqrt(embed_dim)
        layers = [TransformerDecoderLayer(latent_dim, hidden_size, heads, dropout=dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(layers)
        self.pos_enc = PositionalEncoding(latent_dim, self.padding_idx, max_length + 1)
        self.dropword = dropword  # drop entire tokens
        self.input_dropout = input_dropout
        self.mu = LinearWeightNorm(latent_dim, latent_dim, bias=True)
        self.logvar = LinearWeightNorm(latent_dim, latent_dim, bias=True)
        self.reset_parameters()

    def reset_parameters(self):
        """Currently a no-op."""
        pass

    def _embed_inputs(self, tgt_sents):
        """Embed ``tgt_sents``, apply word dropout, add positions, apply input dropout."""
        x = self.embed_scale * self.tgt_embed(tgt_sents)
        # NOTE(review): x is presumably (batch, length, dim) and dropout2d is used so
        # that whole token vectors are dropped. Recent PyTorch releases warn on 3-D
        # inputs to dropout2d -- confirm the behaviour on the installed version.
        x = F.dropout2d(x, p=self.dropword, training=self.training)
        x = x + self.pos_enc(tgt_sents)
        return F.dropout(x, p=self.input_dropout, training=self.training)

    @overrides
    def forward(self, tgt_sents, tgt_masks, src_enc, src_masks):
        """Compute ``(mu, logvar)`` for every target position.

        Args:
            tgt_sents: target token ids fed to ``tgt_embed`` and ``pos_enc``.
            tgt_masks: target mask; positions equal to 0 are treated as padding.
            src_enc: encoded source passed to every decoder layer.
            src_masks: source mask; positions equal to 0 are treated as padding.

        Returns:
            Tuple ``(mu, logvar)``; both are multiplied by ``tgt_masks`` so
            padded positions are zero.
        """
        x = self._embed_inputs(tgt_sents)

        mask = tgt_masks.eq(0)
        key_mask = src_masks.eq(0)
        for layer in self.layers:
            x = layer(x, mask, src_enc, key_mask)

        out_mask = tgt_masks.unsqueeze(2)
        mu = self.mu(x) * out_mask
        logvar = self.logvar(x) * out_mask
        return mu, logvar

    def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True):
        """Run the forward computation under ``torch.no_grad()`` using the ``init`` methods.

        Every decoder layer is run through ``layer.init(..., init_scale=init_scale)``.
        The ``mu`` / ``logvar`` heads are run through their ``init`` methods with
        ``init_scale=0.05 * init_scale`` when ``init_mu`` / ``init_var`` are true,
        and through their plain forward otherwise.

        Returns:
            Tuple ``(mu, logvar)`` masked by ``tgt_masks`` exactly like ``forward``.
        """
        with torch.no_grad():
            x = self._embed_inputs(tgt_sents)

            mask = tgt_masks.eq(0)
            key_mask = src_masks.eq(0)
            for layer in self.layers:
                x = layer.init(x, mask, src_enc, key_mask, init_scale=init_scale)

            out_mask = tgt_masks.unsqueeze(2)
            # Presumably zeroes padded positions so they do not influence the
            # heads' initialisation -- confirm against LinearWeightNorm.init.
            x = x * out_mask
            mu = self.mu.init(x, init_scale=0.05 * init_scale) if init_mu else self.mu(x)
            logvar = self.logvar.init(x, init_scale=0.05 * init_scale) if init_var else self.logvar(x)
            return mu * out_mask, logvar * out_mask
コード例 #2
0
ファイル: blocks.py プロジェクト: yyht/flowseq
class NICESelfAttnBlock(nn.Module):
    """Transformer-based network mapping ``x`` (conditioned on ``src``) to ``out_features``.

    Positional encodings (masked by ``mask``) are either added to ``x``
    (``pos_enc='add'``) or fed through a ``MultiHeadAttention`` together with
    ``x`` (``pos_enc='attn'``). The result goes through one
    ``TransformerDecoderLayer`` attending to ``src`` and a final
    ``LinearWeightNorm`` projection.

    Args:
        src_features: feature size of ``src``; projected (without bias) to
            ``in_features`` when it differs.
        in_features: feature size of ``x``.
        out_features: output feature size.
        hidden_features: hidden size of the decoder layer.
        heads: number of attention heads.
        dropout: dropout rate for the attention modules.
        pos_enc: ``'add'`` or ``'attn'``.
        max_length: the positional encoding is built with
            ``init_size=max_length + 1``.

    Raises:
        ValueError: if ``pos_enc`` is not ``'add'`` or ``'attn'``.
    """

    def __init__(self,
                 src_features,
                 in_features,
                 out_features,
                 hidden_features,
                 heads,
                 dropout=0.0,
                 pos_enc='add',
                 max_length=100):
        super(NICESelfAttnBlock, self).__init__()
        if pos_enc not in ('add', 'attn'):
            raise ValueError("pos_enc must be 'add' or 'attn', got %r" % (pos_enc,))
        self.src_proj = nn.Linear(
            src_features, in_features,
            bias=False) if src_features != in_features else None
        self.pos_enc = PositionalEncoding(in_features,
                                          padding_idx=None,
                                          init_size=max_length + 1)
        self.pos_attn = MultiHeadAttention(
            in_features, heads, dropout=dropout) if pos_enc == 'attn' else None
        self.transformer = TransformerDecoderLayer(in_features,
                                                   hidden_features,
                                                   heads,
                                                   dropout=dropout)
        self.linear = LinearWeightNorm(in_features, out_features, bias=True)

    def _prepare(self, x, mask, src):
        """Project ``src`` if needed and inject positional information into ``x``.

        Returns:
            Tuple ``(x, key_mask, src)`` where ``key_mask`` is ``mask.eq(0)``.
        """
        if self.src_proj is not None:
            src = self.src_proj(src)

        key_mask = mask.eq(0)
        pos_enc = self.pos_enc(x) * mask.unsqueeze(2)
        if self.pos_attn is None:
            x = x + pos_enc
        else:
            # The masked encodings are the first argument (presumably the query);
            # x is passed as the other two (presumably key and value).
            x = self.pos_attn(pos_enc, x, x, key_mask)
        return x, key_mask, src

    def forward(self, x, mask, src, src_mask):
        """Apply the block.

        Args:
            x: input tensor whose last dimension is ``in_features``.
            mask: mask over ``x`` positions; 0 marks padding.
            src: source tensor whose last dimension is ``src_features``.
            src_mask: mask over ``src`` positions; 0 marks padding.

        Returns:
            Output of the final linear layer (last dimension ``out_features``).
        """
        x, key_mask, src = self._prepare(x, mask, src)
        x = self.transformer(x, key_mask, src, src_mask.eq(0))
        return self.linear(x)

    def init(self, x, mask, src, src_mask, init_scale=1.0):
        """Like ``forward`` but calls ``init`` on the decoder layer and output layer.

        The decoder layer receives ``init_scale``; its output is masked before
        ``linear.init`` is called with ``init_scale=0.0``.
        """
        # NOTE(review): pos_attn (if any) runs through its plain forward here,
        # not an ``init`` method -- confirm that is intended.
        x, key_mask, src = self._prepare(x, mask, src)
        x = self.transformer.init(x,
                                  key_mask,
                                  src,
                                  src_mask.eq(0),
                                  init_scale=init_scale)
        x = x * mask.unsqueeze(2)
        return self.linear.init(x, init_scale=0.0)
コード例 #3
0
ファイル: shift_rnn.py プロジェクト: juheeuu/flowseq
class ShiftRecurrentCore(nn.Module):
    """Recurrent encoder whose output at each target position excludes that position's token.

    Target tokens are embedded, optionally passed through token-level dropout
    and run through a single-layer RNN (padding excluded via packing). The RNN
    states are then shifted so position ``t`` only sees other positions:

    * bidirectional: forward state of ``t - 1`` concatenated with the backward
      state of ``t + 1`` (zeros at the sequence boundaries);
    * unidirectional: state of ``t - 1`` (zeros at ``t = 0``).

    The shifted states are optionally combined with attention over ``src_enc``,
    projected by ``ctx_proj`` (Linear + ELU) with dropout, and fed to two
    ``LinearWeightNorm`` heads producing ``mu`` and ``logvar`` of width
    ``latent_dim``.

    Args:
        embed: target embedding module (``embedding_dim`` is read).
        rnn_mode: one of ``'RNN'``, ``'LSTM'``, ``'GRU'``.
        num_layers: must be 1.
        latent_dim: width of ``mu`` / ``logvar``; also the first argument to
            ``GlobalAttention``.
        hidden_size: RNN output width (split in half per direction when
            bidirectional); must be even.
        bidirectional: whether to use a bidirectional RNN.
        use_attn: whether to attend over ``src_enc``.
        dropout: dropout rate applied after ``ctx_proj``.
        dropword: rate at which whole token embeddings are dropped.

    Raises:
        ValueError: on an unknown ``rnn_mode``, an odd ``hidden_size`` or
            ``num_layers != 1``.
    """

    def __init__(self, embed, rnn_mode, num_layers, latent_dim, hidden_size, bidirectional=True, use_attn=False, dropout=0.0, dropword=0.0):
        super(ShiftRecurrentCore, self).__init__()
        rnn_classes = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}
        if rnn_mode not in rnn_classes:
            raise ValueError('Unknown RNN mode: %s' % rnn_mode)
        RNN = rnn_classes[rnn_mode]
        if hidden_size % 2 != 0:
            raise ValueError('hidden_size must be even, got %s' % (hidden_size,))
        if num_layers != 1:
            raise ValueError('only num_layers == 1 is supported, got %s' % (num_layers,))
        self.tgt_embed = embed
        self.bidirectional = bidirectional
        if bidirectional:
            self.rnn = RNN(embed.embedding_dim, hidden_size // 2, num_layers=1, batch_first=True, bidirectional=True)
        else:
            self.rnn = RNN(embed.embedding_dim, hidden_size, num_layers=1, batch_first=True, bidirectional=False)
        self.use_attn = use_attn
        if use_attn:
            self.attn = GlobalAttention(latent_dim, hidden_size, hidden_size)
            # The attention output is concatenated with the RNN states, doubling the width.
            self.ctx_proj = nn.Sequential(nn.Linear(hidden_size * 2, hidden_size), nn.ELU())
        else:
            self.ctx_proj = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ELU())
        self.dropout = dropout
        # Drop entire tokens. NOTE(review): assumes the embedding is
        # (batch, length, dim); confirm how the installed PyTorch treats 3-D
        # inputs to Dropout2d.
        self.dropout2d = nn.Dropout2d(dropword) if dropword > 0. else None
        self.mu = LinearWeightNorm(hidden_size, latent_dim, bias=True)
        self.logvar = LinearWeightNorm(hidden_size, latent_dim, bias=True)

    def _shift_states(self, tgt_enc):
        """Shift RNN states so position ``t`` excludes token ``t`` (see class docstring)."""
        if self.bidirectional:
            fwd, bwd = tgt_enc.chunk(2, dim=2)  # each (batch, seq_len, hidden_size // 2)
            pad = fwd.new_zeros((fwd.size(0), 1, fwd.size(2)))
            fwd = torch.cat([pad, fwd[:, :-1]], dim=1)  # position t <- forward state at t - 1
            bwd = torch.cat([bwd[:, 1:], pad], dim=1)   # position t <- backward state at t + 1
            return torch.cat([fwd, bwd], dim=2)
        pad = tgt_enc.new_zeros((tgt_enc.size(0), 1, tgt_enc.size(2)))
        return torch.cat([pad, tgt_enc[:, :-1]], dim=1)

    def _encode_targets(self, tgt_sents, tgt_masks):
        """Embed, (optionally) word-drop, run the packed RNN and shift its states."""
        tgt_embed = self.tgt_embed(tgt_sents)
        if self.dropout2d is not None:
            tgt_embed = self.dropout2d(tgt_embed)
        # pack_padded_sequence requires CPU lengths since PyTorch 1.7; .cpu() is a
        # no-op for CPU masks and keeps CUDA masks working.
        lengths = tgt_masks.sum(dim=1).long().cpu()
        packed_embed = pack_padded_sequence(tgt_embed, lengths, batch_first=True, enforce_sorted=False)
        packed_enc, _ = self.rnn(packed_embed)
        tgt_enc, _ = pad_packed_sequence(packed_enc, batch_first=True, total_length=tgt_masks.size(1))
        return self._shift_states(tgt_enc)

    def forward(self, tgt_sents, tgt_masks, src_enc, src_masks):
        """Compute ``(mu, logvar)`` for every target position.

        Args:
            tgt_sents: target token ids.
            tgt_masks: target mask; its row sums give the sequence lengths and
                0 marks padding.
            src_enc: encoded source (used only when ``use_attn``).
            src_masks: source mask; 0 marks padding.

        Returns:
            Tuple ``(mu, logvar)``, both multiplied by ``tgt_masks``.
        """
        tgt_enc = self._encode_targets(tgt_sents, tgt_masks)

        if self.use_attn:
            ctx = self.attn(tgt_enc, src_enc, key_mask=src_masks.eq(0))
            ctx = torch.cat([tgt_enc, ctx], dim=2)
        else:
            ctx = tgt_enc
        ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training)
        out_mask = tgt_masks.unsqueeze(2)
        return self.mu(ctx) * out_mask, self.logvar(ctx) * out_mask

    def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True):
        """Run the forward computation under ``torch.no_grad()`` using ``init`` methods.

        The attention (when used) is run through ``attn.init(..., init_scale=init_scale)``;
        the ``mu`` / ``logvar`` heads use ``init`` with ``init_scale=0.05 * init_scale``
        when ``init_mu`` / ``init_var`` are true, and their plain forward otherwise.

        Returns:
            Tuple ``(mu, logvar)`` masked by ``tgt_masks`` exactly like ``forward``.
        """
        with torch.no_grad():
            tgt_enc = self._encode_targets(tgt_sents, tgt_masks)

            if self.use_attn:
                ctx = self.attn.init(tgt_enc, src_enc, key_mask=src_masks.eq(0), init_scale=init_scale)
                ctx = torch.cat([tgt_enc, ctx], dim=2)
            else:
                ctx = tgt_enc
            ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training)
            mu = self.mu.init(ctx, init_scale=0.05 * init_scale) if init_mu else self.mu(ctx)
            logvar = self.logvar.init(ctx, init_scale=0.05 * init_scale) if init_var else self.logvar(ctx)
            out_mask = tgt_masks.unsqueeze(2)
            return mu * out_mask, logvar * out_mask
コード例 #4
0
ファイル: blocks.py プロジェクト: yyht/flowseq
class NICERecurrentBlock(nn.Module):
    """Bidirectional-RNN network mapping ``x`` (conditioned on ``src``) to ``out_features``.

    ``x`` is run through a bidirectional RNN (padding excluded via packing).
    The RNN states attend over ``src`` with ``GlobalAttention``. The attention
    output is concatenated with the original ``x`` and projected by a
    ``LinearWeightNorm`` layer.

    Args:
        rnn_mode: one of ``'RNN'``, ``'LSTM'``, ``'GRU'``.
        src_features: feature size of ``src`` (first ``GlobalAttention`` argument).
        in_features: feature size of ``x``.
        out_features: output feature size.
        hidden_features: total RNN output width (``hidden_features // 2`` units
            per direction).
        dropout: dropout rate passed to ``GlobalAttention``.

    Raises:
        ValueError: on an unknown ``rnn_mode``.
    """

    def __init__(self,
                 rnn_mode,
                 src_features,
                 in_features,
                 out_features,
                 hidden_features,
                 dropout=0.0):
        super(NICERecurrentBlock, self).__init__()
        rnn_classes = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}
        if rnn_mode not in rnn_classes:
            raise ValueError('Unknown RNN mode: %s' % rnn_mode)
        RNN = rnn_classes[rnn_mode]

        # NOTE(review): the concatenated bidirectional output has
        # 2 * (hidden_features // 2) features, i.e. hidden_features - 1 when
        # hidden_features is odd, which presumably mismatches the attention and
        # linear sizes below -- confirm callers always pass an even value.
        self.rnn = RNN(in_features,
                       hidden_features // 2,
                       batch_first=True,
                       bidirectional=True)
        self.attn = GlobalAttention(src_features,
                                    hidden_features,
                                    hidden_features,
                                    dropout=dropout)
        self.linear = LinearWeightNorm(in_features + hidden_features,
                                       out_features,
                                       bias=True)

    def _encode(self, x, mask):
        """Run the bidirectional RNN over the unpadded part of ``x``.

        Returns:
            RNN states padded back to ``mask.size(1)`` positions (padding filled
            with zeros by ``pad_packed_sequence``).
        """
        # pack_padded_sequence requires CPU lengths since PyTorch 1.7; .cpu() is a
        # no-op for CPU masks and keeps CUDA masks working.
        lengths = mask.sum(dim=1).long().cpu()
        packed_out = pack_padded_sequence(x,
                                          lengths,
                                          batch_first=True,
                                          enforce_sorted=False)
        packed_out, _ = self.rnn(packed_out)
        out, _ = pad_packed_sequence(packed_out,
                                     batch_first=True,
                                     total_length=mask.size(1))
        return out

    def forward(self, x, mask, src, src_mask):
        """Apply the block.

        Args:
            x: input tensor, batch-first, last dimension ``in_features``.
            mask: mask over ``x`` positions; its row sums give the lengths.
            src: source tensor attended over.
            src_mask: mask over ``src`` positions; 0 marks padding.

        Returns:
            Output of the final linear layer (last dimension ``out_features``).
        """
        out = self._encode(x, mask)
        out = self.attn(out, src, key_mask=src_mask.eq(0))
        return self.linear(torch.cat([x, out], dim=2))

    def init(self, x, mask, src, src_mask, init_scale=1.0):
        """Like ``forward`` but calls ``init`` on the attention and the output layer.

        The attention receives ``init_scale``; the output layer is initialised
        with ``init_scale=0.0``.
        """
        out = self._encode(x, mask)
        out = self.attn.init(out,
                             src,
                             key_mask=src_mask.eq(0),
                             init_scale=init_scale)
        return self.linear.init(torch.cat([x, out], dim=2), init_scale=0.0)
コード例 #5
0
ファイル: blocks.py プロジェクト: yyht/flowseq
class NICEConvBlock(nn.Module):
    """Convolutional network mapping ``x`` (conditioned on ``src``) to ``out_features``.

    ``x`` goes through two ``Conv1dWeightNorm`` layers with ELU activations.
    The result is masked, attends over ``src`` with ``GlobalAttention``,
    is concatenated with the original ``x`` and is projected by a
    ``LinearWeightNorm`` layer.

    Args:
        src_features: feature size of ``src`` (first ``GlobalAttention`` argument).
        in_features: feature size of ``x``.
        out_features: output feature size.
        hidden_features: channel count of both convolutions.
        kernel_size: convolution kernel size (padding is ``kernel_size // 2``).
        dropout: dropout rate passed to ``GlobalAttention``.
    """

    def __init__(self,
                 src_features,
                 in_features,
                 out_features,
                 hidden_features,
                 kernel_size,
                 dropout=0.0):
        super(NICEConvBlock, self).__init__()
        # NOTE(review): padding=kernel_size // 2 preserves the sequence length
        # only for odd kernel sizes (assuming nn.Conv1d padding semantics); an
        # even kernel_size would make the mask multiplication in forward fail.
        self.conv1 = Conv1dWeightNorm(in_features,
                                      hidden_features,
                                      kernel_size=kernel_size,
                                      padding=kernel_size // 2,
                                      bias=True)
        self.conv2 = Conv1dWeightNorm(hidden_features,
                                      hidden_features,
                                      kernel_size=kernel_size,
                                      padding=kernel_size // 2,
                                      bias=True)
        self.activation = nn.ELU(inplace=True)
        self.attn = GlobalAttention(src_features,
                                    hidden_features,
                                    hidden_features,
                                    dropout=dropout)
        # The linear layer consumes cat([x, attention output]), i.e.
        # in_features + hidden_features channels. This was previously
        # hidden_features * 2, which only worked when in_features == hidden_features;
        # for that case the layer is constructed with the same arguments as before.
        self.linear = LinearWeightNorm(in_features + hidden_features,
                                       out_features,
                                       bias=True)

    def forward(self, x, mask, src, src_mask):
        """Apply the block.

        Args:
            x: Tensor
                input tensor [batch, length, in_features]
            mask: Tensor
                x mask tensor [batch, length]
            src: Tensor
                source input tensor [batch, src_length, src_features]
            src_mask: Tensor
                source mask tensor [batch, src_length]

        Returns: Tensor
            out tensor [batch, length, out_features]

        """
        # NOTE(review): conv1 activations at padded positions are not masked
        # before conv2, so outputs near the end of a shorter sequence can depend
        # on padding -- confirm this is intended.
        out = self.activation(self.conv1(x.transpose(1, 2)))
        out = self.activation(self.conv2(out)).transpose(1, 2) * mask.unsqueeze(2)
        out = self.attn(out, src, key_mask=src_mask.eq(0))
        return self.linear(torch.cat([x, out], dim=2))

    def init(self, x, mask, src, src_mask, init_scale=1.0):
        """Like ``forward`` but calls ``init`` on the convolutions, attention and output layer.

        The convolutions and attention receive ``init_scale``; the output layer
        is initialised with ``init_scale=0.0``.
        """
        out = self.activation(
            self.conv1.init(x.transpose(1, 2), init_scale=init_scale))
        out = self.activation(self.conv2.init(
            out, init_scale=init_scale)).transpose(1, 2) * mask.unsqueeze(2)
        out = self.attn.init(out,
                             src,
                             key_mask=src_mask.eq(0),
                             init_scale=init_scale)
        return self.linear.init(torch.cat([x, out], dim=2), init_scale=0.0)