import torch
from torch import nn

# building blocks and kwarg helpers; these import paths assume the
# lucidrains/x-transformers package layout
from x_transformers.x_transformers import (TransformerWrapper, Encoder,
                                           Decoder, groupby_prefix_and_trim,
                                           pick_and_pop)
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper


class XTransformer(nn.Module):
    def __init__(self, *, dim, tie_token_emb=False, **kwargs):
        super().__init__()
        enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
        dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)

        assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
        enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'],
                                              enc_kwargs)
        enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop(
            'num_memory_tokens', None)

        dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'],
                                              dec_kwargs)

        self.encoder = TransformerWrapper(**enc_transformer_kwargs,
                                          attn_layers=Encoder(dim=dim,
                                                              **enc_kwargs))

        self.decoder = TransformerWrapper(**dec_transformer_kwargs,
                                          attn_layers=Decoder(
                                              dim=dim,
                                              cross_attend=True,
                                              **dec_kwargs))

        if tie_token_emb:
            self.decoder.token_emb = self.encoder.token_emb

        self.decoder = AutoregressiveWrapper(self.decoder)

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, src_mask=None):
        encodings = self.encoder(seq_in, return_embeddings=True, mask=src_mask)
        return self.decoder.generate(seq_out_start,
                                     seq_len,
                                     context=encodings,
                                     context_mask=src_mask)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        enc = self.encoder(src, mask=src_mask, return_embeddings=True)
        out = self.decoder(tgt,
                           context=enc,
                           mask=tgt_mask,
                           context_mask=src_mask)
        return out
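
# A minimal usage sketch for the class above. The name `seq2seq` and the
# hyperparameters are illustrative assumptions, not values prescribed by the
# class; any `enc_`/`dec_`-prefixed kwarg is routed to the encoder or decoder.

seq2seq = XTransformer(dim=512,
                       tie_token_emb=True,
                       enc_num_tokens=256, enc_depth=6, enc_heads=8,
                       enc_max_seq_len=1024,
                       dec_num_tokens=256, dec_depth=6, dec_heads=8,
                       dec_max_seq_len=1024)

src = torch.randint(0, 256, (1, 1024))
tgt = torch.randint(0, 256, (1, 1024))
src_mask = torch.ones_like(src).bool()

loss = seq2seq(src, tgt, src_mask=src_mask)  # AutoregressiveWrapper returns the loss
loss.backward()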
Example #3

import gzip

import numpy as np
from torch.utils.data import Dataset

# sequence length used by the sampler and model below; 1024 is an assumed
# value in line with the reference enwik8 training script
SEQ_LEN = 1024

def decode_token(token):
    # clamp control bytes (< 32) to a space before mapping back to a character
    return chr(max(32, token))


def decode_tokens(tokens):
    return ''.join(map(decode_token, tokens))
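
# quick sanity check for the helpers above: tokens are raw byte values, and
# anything below 32 (a control byte) is rendered as a space
assert decode_tokens([104, 105, 33]) == 'hi!'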


# instantiate GPT-like decoder model

model = TransformerWrapper(num_tokens=256,
                           max_seq_len=SEQ_LEN,
                           attn_layers=Decoder(dim=512, depth=6, heads=8))

model = AutoregressiveWrapper(model)
model.cuda()
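
# sampling sketch (only meaningful once the model has been trained; the
# single-byte prime and the 512-token generation length are illustrative
# assumptions, not values set by this script):
#
#   prime = torch.randint(0, 256, (1, 1)).cuda()
#   sample = model.generate(prime, 512)  # AutoregressiveWrapper.generate
#   print(decode_tokens(sample[0].tolist()))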

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    # np.fromstring has been removed from NumPy; np.frombuffer is the
    # replacement. .copy() makes the array writable so torch.from_numpy
    # does not warn. Read the first 95MB of enwik8 as raw bytes.
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    # first 90MB for training, remaining 5MB for validation
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)


class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        # flat tensor of byte tokens and the window length to sample from it
        self.data = data
        self.seq_len = seq_len