import math

import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class Encoder(nn.Module):
    """
    User utterance encoder.

    Args:
        ntoken: vocab size
        ninp: embedding dimension
        nhead: number of attention heads
        nhid: feed-forward hidden layer size
        nlayers: number of encoder layers
        params: configuration object, stored on the module
        dropout: dropout rate
        embedding: optional pre-built embedding layer to share;
            a new nn.Embedding is created if None
    """

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, params,
                 dropout=0.5, embedding=None):
        super().__init__()
        self.model_type = 'TransformerEncoder'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, ninp) if embedding is None else embedding
        self.ninp = ninp
        self.params = params
        # self.init_weights()

    # def init_weights(self):
    #     initrange = 0.1
    #     self.embedding.weight.data.uniform_(-initrange, initrange)

    def train(self, mode=True):
        # Delegate to nn.Module.train so that self.training and every
        # submodule (embedding, positional encoding, transformer encoder)
        # switch mode together.
        return super().train(mode)

    def forward(self, src):
        # src: (seq_len, batch). Token id 0 is <pad>, so transpose the mask
        # to the (batch, seq_len) shape that src_key_padding_mask expects.
        mask = src.eq(0).transpose(0, 1)
        # Scale embeddings by sqrt(d_model), as in "Attention Is All You Need".
        src = self.embedding(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_key_padding_mask=mask)
        return output
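

# Encoder assumes a PositionalEncoding module defined elsewhere in the repo.
# A minimal sketch of one, following the standard sinusoidal formulation from
# the PyTorch transformer tutorial; the max_len default and exact buffer
# layout here are assumptions, not necessarily the repo's actual definition.
import torch


class PositionalEncoding(nn.Module):
    """Adds sinusoidal position information to a (seq_len, batch, d_model) tensor."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Precompute a (max_len, 1, d_model) table of sin/cos encodings.
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (seq_len, batch, d_model); add the encoding for each position,
        # broadcasting over the batch dimension.
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)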
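

# A quick smoke test of the encoder. The hyper-parameter values below are
# illustrative assumptions; params is passed as None since the class only
# stores it.
if __name__ == '__main__':
    encoder = Encoder(ntoken=1000, ninp=64, nhead=4, nhid=256,
                      nlayers=2, params=None, dropout=0.1)
    # Batch of 3 sequences, 10 tokens each, laid out (seq_len, batch);
    # zeros act as <pad> positions and are masked out via the padding mask.
    src = torch.randint(1, 1000, (10, 3))
    src[7:, 1] = 0  # pad the tail of the second sequence
    out = encoder(src)
    print(out.shape)  # torch.Size([10, 3, 64])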