# Encoder-decoder wrapper around two PerformerLM instances. The helpers
# extract_enc_dec_kwargs / extract_and_set_enc_dec_kwargs (defined alongside
# this class) route keyword arguments by their `enc_` / `dec_` prefix.
import torch
from torch import nn

from performer_pytorch.performer_pytorch import PerformerLM
from performer_pytorch.autoregressive_wrapper import AutoregressiveWrapper

class PerformerEncDec(nn.Module):
    def __init__(self, dim, ignore_index=0, pad_value=0, tie_token_embeds=False,
                 no_projection=False, amp_enabled=False, **kwargs):
        super().__init__()
        enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)

        assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, 'you must set the dim for both encoder and decoder'

        # settings shared by both halves are broadcast into their kwargs
        enc_kwargs['dim'] = dec_kwargs['dim'] = dim
        enc_kwargs['no_projection'] = dec_kwargs['no_projection'] = no_projection
        dec_kwargs['amp_enabled'] = amp_enabled
        dec_kwargs['causal'] = True
        dec_kwargs['cross_attend'] = True

        enc = PerformerLM(**enc_kwargs)
        dec = PerformerLM(**dec_kwargs)

        if tie_token_embeds:
            # PerformerLM stores its embedding as `token_emb`
            enc.token_emb = dec.token_emb

        self.enc = enc
        self.dec = AutoregressiveWrapper(dec, ignore_index=ignore_index, pad_value=pad_value)

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, **kwargs):
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        encodings = self.enc(seq_in, return_encodings=True, **enc_kwargs)
        return self.dec.generate(seq_out_start, seq_len, context=encodings, **{**dec_kwargs, **kwargs})

    def forward(self, seq_in, seq_out, enc_mask=None, return_loss=False, **kwargs):
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        encodings = self.enc(seq_in, mask=enc_mask, return_encodings=True, **enc_kwargs)
        return self.dec(seq_out, context=encodings, context_mask=enc_mask, return_loss=return_loss, **dec_kwargs)
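A minimal usage sketch of the class above, assuming the `enc_`/`dec_` keyword-prefix convention of extract_enc_dec_kwargs; the vocabulary size, depths, and the <bos>/<eos> token ids (0 and 1) are illustrative placeholders, not values fixed by the class itself.

import torch

NUM_TOKENS = 20000   # hypothetical vocabulary size
SRC_SEQ_LEN = 4096
TGT_SEQ_LEN = 4096

enc_dec = PerformerEncDec(
    dim=512,
    tie_token_embeds=True,
    enc_num_tokens=NUM_TOKENS,   # `enc_` kwargs are routed to the encoder
    enc_depth=6,
    enc_heads=8,
    enc_max_seq_len=SRC_SEQ_LEN,
    dec_num_tokens=NUM_TOKENS,   # `dec_` kwargs are routed to the decoder
    dec_depth=6,
    dec_heads=8,
    dec_max_seq_len=TGT_SEQ_LEN
)

src = torch.randint(0, NUM_TOKENS, (1, SRC_SEQ_LEN))
tgt = torch.randint(0, NUM_TOKENS, (1, TGT_SEQ_LEN))
src_mask = torch.ones_like(src).bool()

# training step: the AutoregressiveWrapper computes the next-token loss
loss = enc_dec(src, tgt, enc_mask=src_mask, return_loss=True)
loss.backward()

# generation: prime the decoder with a start token (0 assumed to be <bos>,
# 1 assumed to be <eos> here)
start = torch.zeros((1, 1)).long()
samples = enc_dec.generate(src, start, seq_len=512, eos_token=1)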
# enwik8 language modeling example (`decode_token` and constants such as
# SEQ_LEN are assumed to be defined earlier in the script)
import gzip
import numpy as np
import torch
from torch.utils.data import Dataset
from performer_pytorch import PerformerLM
from performer_pytorch.autoregressive_wrapper import AutoregressiveWrapper

def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# instantiate model

model = PerformerLM(
    num_tokens=256,
    dim=512,
    depth=6,
    max_seq_len=SEQ_LEN,
    heads=8,
    causal=True,
    reversible=True,
    nb_features=256
)

model = AutoregressiveWrapper(model)
model.cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    # np.fromstring is deprecated for bytes input; np.frombuffer is the
    # supported equivalent (copy so torch.from_numpy gets a writable array)
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # sample a random contiguous window of seq_len + 1 tokens, so the
        # wrapper can shift it into input/target pairs
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len
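A minimal sketch of the training loop this dataset feeds, assuming hypothetical constants (BATCH_SIZE, LEARNING_RATE, NUM_BATCHES, GRADIENT_ACCUMULATE_EVERY, GENERATE_LENGTH) and a small `cycle` helper; the schedule and hyperparameters are placeholders, not prescribed values.

import random
from torch.utils.data import DataLoader

def cycle(loader):
    # re-iterate the dataloader forever
    while True:
        for data in loader:
            yield data

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for i in range(NUM_BATCHES):
    model.train()

    # accumulate gradients over several micro-batches before stepping
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss=True)
        loss.backward()

    # clip gradients before stepping to stabilize long-sequence training
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

# sample from the model and decode bytes back to text
model.eval()
inp = random.choice(val_dataset)[:-1]
sample = model.generate(inp, GENERATE_LENGTH)
print(decode_tokens(sample))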