import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from overrides import overrides

# Project-local building blocks (PositionalEncoding, MultiHeadAttention, GlobalAttention,
# TransformerDecoderLayer, LinearWeightNorm, Conv1dWeightNorm) are assumed to be
# importable from the surrounding package.


class TransformerCore(nn.Module):
    def __init__(self, embed, num_layers, latent_dim, hidden_size, heads,
                 dropout=0.0, dropword=0.0, max_length=100):
        super(TransformerCore, self).__init__()
        self.tgt_embed = embed
        self.padding_idx = embed.padding_idx

        embed_dim = embed.embedding_dim
        self.embed_scale = math.sqrt(embed_dim)
        assert embed_dim == latent_dim
        layers = [TransformerDecoderLayer(latent_dim, hidden_size, heads, dropout=dropout)
                  for _ in range(num_layers)]
        self.layers = nn.ModuleList(layers)
        self.pos_enc = PositionalEncoding(latent_dim, self.padding_idx, max_length + 1)
        self.dropword = dropword  # drop entire tokens
        self.mu = LinearWeightNorm(latent_dim, latent_dim, bias=True)
        self.logvar = LinearWeightNorm(latent_dim, latent_dim, bias=True)
        self.reset_parameters()

    def reset_parameters(self):
        pass

    @overrides
    def forward(self, tgt_sents, tgt_masks, src_enc, src_masks):
        x = self.embed_scale * self.tgt_embed(tgt_sents)
        # word dropout: dropout2d zeroes entire token vectors
        x = F.dropout2d(x, p=self.dropword, training=self.training)
        x += self.pos_enc(tgt_sents)
        x = F.dropout(x, p=0.2, training=self.training)

        # masks are 1 for real tokens and 0 for padding; eq(0) marks positions to ignore
        mask = tgt_masks.eq(0)
        key_mask = src_masks.eq(0)
        for layer in self.layers:
            x = layer(x, mask, src_enc, key_mask)

        mu = self.mu(x) * tgt_masks.unsqueeze(2)
        logvar = self.logvar(x) * tgt_masks.unsqueeze(2)
        return mu, logvar

    def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True):
        # initialization pass for the weight-normalized layers
        with torch.no_grad():
            x = self.embed_scale * self.tgt_embed(tgt_sents)
            x = F.dropout2d(x, p=self.dropword, training=self.training)
            x += self.pos_enc(tgt_sents)
            x = F.dropout(x, p=0.2, training=self.training)

            mask = tgt_masks.eq(0)
            key_mask = src_masks.eq(0)
            for layer in self.layers:
                x = layer.init(x, mask, src_enc, key_mask, init_scale=init_scale)

            x = x * tgt_masks.unsqueeze(2)
            mu = self.mu.init(x, init_scale=0.05 * init_scale) if init_mu else self.mu(x)
            logvar = self.logvar.init(x, init_scale=0.05 * init_scale) if init_var else self.logvar(x)
            mu = mu * tgt_masks.unsqueeze(2)
            logvar = logvar * tgt_masks.unsqueeze(2)
            return mu, logvar

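# Usage sketch (illustrative only; all sizes and tensors below are hypothetical). It assumes
# the project-local modules referenced above are importable and that src_enc comes from a
# source encoder whose hidden size is compatible with TransformerDecoderLayer.
def _transformer_core_example():
    latent_dim = 256
    embed = nn.Embedding(1000, latent_dim, padding_idx=0)
    core = TransformerCore(embed, num_layers=3, latent_dim=latent_dim,
                           hidden_size=512, heads=4, dropout=0.1)
    tgt_sents = torch.randint(1, 1000, (2, 7))   # [batch, tgt_len] token ids (avoid padding id 0)
    tgt_masks = torch.ones(2, 7)                 # 1.0 = real token, 0.0 = padding
    src_enc = torch.randn(2, 9, latent_dim)      # [batch, src_len, latent_dim]
    src_masks = torch.ones(2, 9)
    mu, logvar = core(tgt_sents, tgt_masks, src_enc, src_masks)
    return mu, logvar                            # each [batch, tgt_len, latent_dim], zeroed at padding
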
class NICESelfAttnBlock(nn.Module):
    def __init__(self, src_features, in_features, out_features, hidden_features, heads,
                 dropout=0.0, pos_enc='add', max_length=100):
        super(NICESelfAttnBlock, self).__init__()
        assert pos_enc in ['add', 'attn']
        self.src_proj = nn.Linear(src_features, in_features, bias=False) if src_features != in_features else None
        self.pos_enc = PositionalEncoding(in_features, padding_idx=None, init_size=max_length + 1)
        self.pos_attn = MultiHeadAttention(in_features, heads, dropout=dropout) if pos_enc == 'attn' else None
        self.transformer = TransformerDecoderLayer(in_features, hidden_features, heads, dropout=dropout)
        self.linear = LinearWeightNorm(in_features, out_features, bias=True)

    def forward(self, x, mask, src, src_mask):
        if self.src_proj is not None:
            src = self.src_proj(src)

        key_mask = mask.eq(0)
        pos_enc = self.pos_enc(x) * mask.unsqueeze(2)
        if self.pos_attn is None:
            x = x + pos_enc
        else:
            x = self.pos_attn(pos_enc, x, x, key_mask)

        x = self.transformer(x, key_mask, src, src_mask.eq(0))
        return self.linear(x)

    def init(self, x, mask, src, src_mask, init_scale=1.0):
        if self.src_proj is not None:
            src = self.src_proj(src)

        key_mask = mask.eq(0)
        pos_enc = self.pos_enc(x) * mask.unsqueeze(2)
        if self.pos_attn is None:
            x = x + pos_enc
        else:
            x = self.pos_attn(pos_enc, x, x, key_mask)

        x = self.transformer.init(x, key_mask, src, src_mask.eq(0), init_scale=init_scale)
        x = x * mask.unsqueeze(2)
        return self.linear.init(x, init_scale=0.0)

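# Construction sketch (illustrative, hypothetical sizes): the two pos_enc modes differ in how
# positions are injected. With 'add' the positional encoding is simply added to x; with 'attn'
# it is used as the query of an extra MultiHeadAttention over x before the decoder layer.
# Assumes the project-local layers referenced above are importable.
def _nice_self_attn_variants():
    common = dict(src_features=256, in_features=128, out_features=128,
                  hidden_features=512, heads=4)
    add_block = NICESelfAttnBlock(pos_enc='add', **common)    # x = x + pos_enc
    attn_block = NICESelfAttnBlock(pos_enc='attn', **common)  # x = pos_attn(pos_enc, x, x, key_mask)
    return add_block, attn_block
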
class ShiftRecurrentCore(nn.Module):
    def __init__(self, embed, rnn_mode, num_layers, latent_dim, hidden_size,
                 bidirectional=True, use_attn=False, dropout=0.0, dropword=0.0):
        super(ShiftRecurrentCore, self).__init__()
        if rnn_mode == 'RNN':
            RNN = nn.RNN
        elif rnn_mode == 'LSTM':
            RNN = nn.LSTM
        elif rnn_mode == 'GRU':
            RNN = nn.GRU
        else:
            raise ValueError('Unknown RNN mode: %s' % rnn_mode)
        assert hidden_size % 2 == 0
        self.tgt_embed = embed
        assert num_layers == 1
        self.bidirectional = bidirectional
        if bidirectional:
            self.rnn = RNN(embed.embedding_dim, hidden_size // 2, num_layers=1,
                           batch_first=True, bidirectional=True)
        else:
            self.rnn = RNN(embed.embedding_dim, hidden_size, num_layers=1,
                           batch_first=True, bidirectional=False)
        self.use_attn = use_attn
        if use_attn:
            self.attn = GlobalAttention(latent_dim, hidden_size, hidden_size)
            self.ctx_proj = nn.Sequential(nn.Linear(hidden_size * 2, hidden_size), nn.ELU())
        else:
            self.ctx_proj = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ELU())
        self.dropout = dropout
        self.dropout2d = nn.Dropout2d(dropword) if dropword > 0. else None  # drop entire tokens
        self.mu = LinearWeightNorm(hidden_size, latent_dim, bias=True)
        self.logvar = LinearWeightNorm(hidden_size, latent_dim, bias=True)

    # @overrides
    def forward(self, tgt_sents, tgt_masks, src_enc, src_masks):
        tgt_embed = self.tgt_embed(tgt_sents)
        if self.dropout2d is not None:
            tgt_embed = self.dropout2d(tgt_embed)

        lengths = tgt_masks.sum(dim=1).long()
        packed_embed = pack_padded_sequence(tgt_embed, lengths, batch_first=True, enforce_sorted=False)
        packed_enc, _ = self.rnn(packed_embed)
        tgt_enc, _ = pad_packed_sequence(packed_enc, batch_first=True, total_length=tgt_masks.size(1))

        if self.bidirectional:
            # split into fwd and bwd directions
            fwd_tgt_enc, bwd_tgt_enc = tgt_enc.chunk(2, dim=2)  # (batch_size, seq_len, hidden_size // 2)
            pad_vector = fwd_tgt_enc.new_zeros((fwd_tgt_enc.size(0), 1, fwd_tgt_enc.size(2)))
            # shift forward states one step right and backward states one step left
            pad_fwd_tgt_enc = torch.cat([pad_vector, fwd_tgt_enc], dim=1)
            pad_bwd_tgt_enc = torch.cat([bwd_tgt_enc, pad_vector], dim=1)
            tgt_enc = torch.cat([pad_fwd_tgt_enc[:, :-1], pad_bwd_tgt_enc[:, 1:]], dim=2)
        else:
            # shift states one step right so position i only sees tokens before i
            pad_vector = tgt_enc.new_zeros((tgt_enc.size(0), 1, tgt_enc.size(2)))
            tgt_enc = torch.cat([pad_vector, tgt_enc], dim=1)[:, :-1]

        if self.use_attn:
            ctx = self.attn(tgt_enc, src_enc, key_mask=src_masks.eq(0))
            ctx = torch.cat([tgt_enc, ctx], dim=2)
        else:
            ctx = tgt_enc
        ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training)

        mu = self.mu(ctx) * tgt_masks.unsqueeze(2)
        logvar = self.logvar(ctx) * tgt_masks.unsqueeze(2)
        return mu, logvar

    def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True):
        with torch.no_grad():
            tgt_embed = self.tgt_embed(tgt_sents)
            if self.dropout2d is not None:
                tgt_embed = self.dropout2d(tgt_embed)

            lengths = tgt_masks.sum(dim=1).long()
            packed_embed = pack_padded_sequence(tgt_embed, lengths, batch_first=True, enforce_sorted=False)
            packed_enc, _ = self.rnn(packed_embed)
            tgt_enc, _ = pad_packed_sequence(packed_enc, batch_first=True, total_length=tgt_masks.size(1))

            if self.bidirectional:
                fwd_tgt_enc, bwd_tgt_enc = tgt_enc.chunk(2, dim=2)  # (batch_size, seq_len, hidden_size // 2)
                pad_vector = fwd_tgt_enc.new_zeros((fwd_tgt_enc.size(0), 1, fwd_tgt_enc.size(2)))
                pad_fwd_tgt_enc = torch.cat([pad_vector, fwd_tgt_enc], dim=1)
                pad_bwd_tgt_enc = torch.cat([bwd_tgt_enc, pad_vector], dim=1)
                tgt_enc = torch.cat([pad_fwd_tgt_enc[:, :-1], pad_bwd_tgt_enc[:, 1:]], dim=2)
            else:
                pad_vector = tgt_enc.new_zeros((tgt_enc.size(0), 1, tgt_enc.size(2)))
                tgt_enc = torch.cat([pad_vector, tgt_enc], dim=1)[:, :-1]

            if self.use_attn:
                ctx = self.attn.init(tgt_enc, src_enc, key_mask=src_masks.eq(0), init_scale=init_scale)
                ctx = torch.cat([tgt_enc, ctx], dim=2)
            else:
                ctx = tgt_enc
            ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training)

            mu = self.mu.init(ctx, init_scale=0.05 * init_scale) if init_mu else self.mu(ctx)
            logvar = self.logvar.init(ctx, init_scale=0.05 * init_scale) if init_var else self.logvar(ctx)
            mu = mu * tgt_masks.unsqueeze(2)
            logvar = logvar * tgt_masks.unsqueeze(2)
            return mu, logvar

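# Illustration of the shift used in ShiftRecurrentCore (not part of the original module):
# forward states are shifted one step right and backward states one step left, so the
# context vector at position i summarizes tokens 0..i-1 (forward) and i+1..end (backward),
# i.e. every token except the one at position i itself. Toy demo with plain tensors:
def _shift_demo():
    # toy "states": the position index stored in every feature, batch 1, length 4, dim 2
    fwd = torch.arange(4.).view(1, 4, 1).repeat(1, 1, 2)
    bwd = torch.arange(4.).view(1, 4, 1).repeat(1, 1, 2)
    pad = fwd.new_zeros((1, 1, 2))
    shifted_fwd = torch.cat([pad, fwd], dim=1)[:, :-1]  # positions [0, 0, 1, 2]
    shifted_bwd = torch.cat([bwd, pad], dim=1)[:, 1:]   # positions [1, 2, 3, 0]
    return torch.cat([shifted_fwd, shifted_bwd], dim=2)
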
class NICERecurrentBlock(nn.Module):
    def __init__(self, rnn_mode, src_features, in_features, out_features, hidden_features, dropout=0.0):
        super(NICERecurrentBlock, self).__init__()
        if rnn_mode == 'RNN':
            RNN = nn.RNN
        elif rnn_mode == 'LSTM':
            RNN = nn.LSTM
        elif rnn_mode == 'GRU':
            RNN = nn.GRU
        else:
            raise ValueError('Unknown RNN mode: %s' % rnn_mode)
        self.rnn = RNN(in_features, hidden_features // 2, batch_first=True, bidirectional=True)
        self.attn = GlobalAttention(src_features, hidden_features, hidden_features, dropout=dropout)
        self.linear = LinearWeightNorm(in_features + hidden_features, out_features, bias=True)

    def forward(self, x, mask, src, src_mask):
        lengths = mask.sum(dim=1).long()
        packed_out = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.rnn(packed_out)
        out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=mask.size(1))
        # bidirectional RNN output: [batch, length, hidden_features]
        out = self.attn(out, src, key_mask=src_mask.eq(0))
        out = self.linear(torch.cat([x, out], dim=2))
        return out

    def init(self, x, mask, src, src_mask, init_scale=1.0):
        lengths = mask.sum(dim=1).long()
        packed_out = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.rnn(packed_out)
        out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=mask.size(1))
        # bidirectional RNN output: [batch, length, hidden_features]
        out = self.attn.init(out, src, key_mask=src_mask.eq(0), init_scale=init_scale)
        out = self.linear.init(torch.cat([x, out], dim=2), init_scale=0.0)
        return out

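# Note on the linear.init(..., init_scale=0.0) calls in these NICE blocks (presumably
# data-dependent weight-norm initialization): driving the scale of the final projection
# to zero makes each block output approximately zero at initialization, so the coupling
# transform it parameterizes starts out close to the identity. Minimal stand-in demo
# with a plain, zero-initialized linear layer:
def _zero_init_identity_demo():
    proj = nn.Linear(8, 4)
    nn.init.zeros_(proj.weight)
    nn.init.zeros_(proj.bias)
    shift = proj(torch.randn(2, 5, 8))  # all zeros
    z = torch.randn(2, 5, 4)
    return z + shift                    # an additive coupling with this shift acts as identity
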
class NICEConvBlock(nn.Module):
    def __init__(self, src_features, in_features, out_features, hidden_features, kernel_size, dropout=0.0):
        super(NICEConvBlock, self).__init__()
        self.conv1 = Conv1dWeightNorm(in_features, hidden_features, kernel_size=kernel_size,
                                      padding=kernel_size // 2, bias=True)
        self.conv2 = Conv1dWeightNorm(hidden_features, hidden_features, kernel_size=kernel_size,
                                      padding=kernel_size // 2, bias=True)
        self.activation = nn.ELU(inplace=True)
        self.attn = GlobalAttention(src_features, hidden_features, hidden_features, dropout=dropout)
        # the final projection consumes torch.cat([x, out], dim=2), i.e. in_features + hidden_features
        self.linear = LinearWeightNorm(in_features + hidden_features, out_features, bias=True)

    def forward(self, x, mask, src, src_mask):
        """
        Args:
            x: Tensor
                input tensor [batch, length, in_features]
            mask: Tensor
                x mask tensor [batch, length]
            src: Tensor
                source input tensor [batch, src_length, src_features]
            src_mask: Tensor
                source mask tensor [batch, src_length]

        Returns: Tensor
            out tensor [batch, length, out_features]
        """
        out = self.activation(self.conv1(x.transpose(1, 2)))
        out = self.activation(self.conv2(out)).transpose(1, 2) * mask.unsqueeze(2)
        out = self.attn(out, src, key_mask=src_mask.eq(0))
        out = self.linear(torch.cat([x, out], dim=2))
        return out

    def init(self, x, mask, src, src_mask, init_scale=1.0):
        out = self.activation(self.conv1.init(x.transpose(1, 2), init_scale=init_scale))
        out = self.activation(self.conv2.init(out, init_scale=init_scale)).transpose(1, 2) * mask.unsqueeze(2)
        out = self.attn.init(out, src, key_mask=src_mask.eq(0), init_scale=init_scale)
        out = self.linear.init(torch.cat([x, out], dim=2), init_scale=0.0)
        return out

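# Usage sketch (illustrative, hypothetical sizes). Note that padding=kernel_size // 2 preserves
# the sequence length only for odd kernel sizes; with an even kernel the convolution output is
# one step longer than the mask and the elementwise masking inside forward would fail, so an
# odd kernel_size is effectively required.
def _nice_conv_example():
    block = NICEConvBlock(src_features=256, in_features=128, out_features=128,
                          hidden_features=256, kernel_size=3)
    x = torch.randn(2, 6, 128)
    mask = torch.ones(2, 6)
    src = torch.randn(2, 9, 256)
    src_mask = torch.ones(2, 9)
    return block(x, mask, src, src_mask)  # expected shape: [2, 6, 128]
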