class XTransformer(nn.Module):
    def __init__(self, *, dim, tie_token_emb=False, **kwargs):
        super().__init__()

        # split keyword arguments into encoder- and decoder-specific groups by their 'enc_' / 'dec_' prefixes
        enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
        dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)

        assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'

        enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
        enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)

        dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)

        self.encoder = TransformerWrapper(
            **enc_transformer_kwargs,
            attn_layers=Encoder(dim=dim, **enc_kwargs)
        )

        # decoder cross-attends to the encoder output
        self.decoder = TransformerWrapper(
            **dec_transformer_kwargs,
            attn_layers=Decoder(dim=dim, cross_attend=True, **dec_kwargs)
        )

        # optionally share token embeddings between encoder and decoder
        if tie_token_emb:
            self.decoder.token_emb = self.encoder.token_emb

        self.decoder = AutoregressiveWrapper(self.decoder)

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, src_mask=None):
        encodings = self.encoder(seq_in, return_embeddings=True, mask=src_mask)
        return self.decoder.generate(seq_out_start, seq_len, context=encodings, context_mask=src_mask)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        enc = self.encoder(src, mask=src_mask, return_embeddings=True)
        out = self.decoder(tgt, context=enc, mask=tgt_mask, context_mask=src_mask)
        return out
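# Usage sketch (not part of the original source): illustrative hyperparameters, assuming
# torch and the XTransformer class above are in scope. Encoder/decoder settings are routed
# through the 'enc_' / 'dec_' prefixes that groupby_prefix_and_trim strips off.
import torch

model = XTransformer(
    dim=512,
    tie_token_emb=True,
    enc_num_tokens=256,
    enc_depth=6,
    enc_heads=8,
    enc_max_seq_len=1024,
    dec_num_tokens=256,
    dec_depth=6,
    dec_heads=8,
    dec_max_seq_len=1024
)

src = torch.randint(0, 256, (1, 1024))
tgt = torch.randint(0, 256, (1, 1024))
src_mask = torch.ones_like(src).bool()
tgt_mask = torch.ones_like(tgt).bool()

# forward returns the autoregressive loss, since the decoder is wrapped in AutoregressiveWrapper
loss = model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
loss.backward()

# generation conditioned on the encoded source would look roughly like:
# start_tokens = torch.zeros((1, 1), dtype=torch.long)
# generated = model.generate(src, start_tokens, seq_len=1024, src_mask=src_mask)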
def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# instantiate GPT-like decoder model

model = TransformerWrapper(
    num_tokens=256,
    max_seq_len=SEQ_LEN,
    attn_layers=Decoder(dim=512, depth=6, heads=8)
)

model = AutoregressiveWrapper(model)
model.cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    # np.frombuffer replaces the deprecated np.fromstring; copy so the array is writable for torch
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len