Code example #1
class PerformerEncDec(nn.Module):
    def __init__(self,
                 dim,
                 ignore_index=0,
                 pad_value=0,
                 tie_token_embeds=False,
                 amp_enabled=False,
                 **kwargs):
        super().__init__()
        enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)

        assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, 'you must set the dim for both encoder and decoder'

        enc_kwargs['dim'] = dec_kwargs['dim'] = dim
        dec_kwargs['amp_enabled'] = amp_enabled
        dec_kwargs['causal'] = True
        dec_kwargs['cross_attend'] = True

        enc = PerformerLM(**enc_kwargs)
        dec = PerformerLM(**dec_kwargs)

        if tie_token_embeds:
            enc.token_emb = dec.token_emb

        self.enc = enc
        self.dec = AutoregressiveWrapper(dec,
                                         ignore_index=ignore_index,
                                         pad_value=pad_value)
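
Every constructor in these examples first calls extract_enc_dec_kwargs to split the flat **kwargs into an encoder group and a decoder group. The helper itself is not shown on this page; the following is only a sketch of its presumed behaviour, inferred from the enc_/dec_ prefix convention used when instantiating PerformerEncDec, and the name group_by_prefix is hypothetical.

# Hypothetical sketch, not the library source: split kwargs by the 'enc_'/'dec_'
# prefixes, strip the prefix, and hand back whatever is left over as a third dict.
def group_by_prefix(prefix, kwargs):
    matched = {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
    rest = {k: v for k, v in kwargs.items() if not k.startswith(prefix)}
    return matched, rest


def extract_enc_dec_kwargs(kwargs):
    enc_kwargs, kwargs = group_by_prefix('enc_', kwargs)
    dec_kwargs, kwargs = group_by_prefix('dec_', kwargs)
    return enc_kwargs, dec_kwargs, kwargs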
Code example #2
class PerformerEncDec(nn.Module):
    def __init__(self,
                 dim,
                 tie_token_embeds=False,
                 no_projection=False,
                 **kwargs):
        super().__init__()
        enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)

        assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, 'you must set the dim for both encoder and decoder'

        enc_kwargs['dim'] = dec_kwargs['dim'] = dim
        enc_kwargs['no_projection'] = dec_kwargs[
            'no_projection'] = no_projection

        dec_kwargs['causal'] = True
        dec_kwargs['cross_attend'] = True

        enc = PerformerLM(**enc_kwargs)
        dec = PerformerLM(**dec_kwargs)

        if tie_token_embeds:
            enc.token_emb = dec.token_emb

        self.enc = enc
        self.dec = AutoregressiveWrapper(dec)
Code example #3
class PerformerEncDec(nn.Module):
    def __init__(self,
                 dim,
                 tie_token_embeds=False,
                 no_projection=False,
                 **kwargs):
        super().__init__()
        enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)

        assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, 'you must set the dim for both encoder and decoder'

        enc_kwargs['dim'] = dec_kwargs['dim'] = dim
        enc_kwargs['no_projection'] = dec_kwargs[
            'no_projection'] = no_projection

        dec_kwargs['causal'] = True
        dec_kwargs['cross_attend'] = True

        enc = PerformerLM(**enc_kwargs)
        dec = PerformerLM(**dec_kwargs)

        if tie_token_embeds:
            enc.token_emb = dec.token_emb

        self.enc = enc
        self.dec = AutoregressiveWrapper(dec)

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, **kwargs):
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        encodings = self.enc(seq_in, return_encodings=True, **enc_kwargs)
        return self.dec.generate(seq_out_start,
                                 seq_len,
                                 context=encodings,
                                 **{
                                     **dec_kwargs,
                                     **kwargs
                                 })

    def forward(self, seq_in, seq_out, enc_mask=None, **kwargs):
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        encodings = self.enc(seq_in,
                             mask=enc_mask,
                             return_encodings=True,
                             **enc_kwargs)
        return self.dec(seq_out,
                        context=encodings,
                        context_mask=enc_mask,
                        **dec_kwargs)
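
For reference, here is a minimal usage sketch of the class above. The enc_/dec_-prefixed constructor arguments are an assumption based on the kwarg-splitting convention, the inner PerformerLM parameter names (num_tokens, depth, heads, max_seq_len) are taken from code example #5 further down, and the sizes are arbitrary.

import torch

# Minimal usage sketch (assumed argument names, see the note above).
enc_dec = PerformerEncDec(
    dim=512,
    tie_token_embeds=True,
    enc_num_tokens=20000, enc_depth=6, enc_heads=8, enc_max_seq_len=1024,
    dec_num_tokens=20000, dec_depth=6, dec_heads=8, dec_max_seq_len=1024)

src = torch.randint(0, 20000, (1, 1024))
tgt = torch.randint(0, 20000, (1, 1024))
src_mask = torch.ones_like(src).bool()

# training step: the AutoregressiveWrapper is assumed to return a cross-entropy loss here
loss = enc_dec(src, tgt, enc_mask=src_mask)
loss.backward()

# inference: prime the decoder with a start token and decode autoregressively
start_tokens = torch.zeros((1, 1)).long()
generated = enc_dec.generate(src, start_tokens, seq_len=256)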
Code example #4
class PerformerEncDec(nn.Module):
    def __init__(self,
                 dim,
                 ignore_index=0,
                 pad_value=0,
                 tie_token_embeds=False,
                 amp_enabled=False,
                 **kwargs):
        super().__init__()
        enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)

        assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, 'you must set the dim for both encoder and decoder'

        enc_kwargs['dim'] = dec_kwargs['dim'] = dim
        dec_kwargs['amp_enabled'] = amp_enabled
        dec_kwargs['causal'] = True
        dec_kwargs['cross_attend'] = True

        enc = PerformerLM(**enc_kwargs)
        dec = PerformerLM(**dec_kwargs)

        if tie_token_embeds:
            enc.token_emb = dec.token_emb

        self.enc = enc
        self.dec = AutoregressiveWrapper(dec,
                                         ignore_index=ignore_index,
                                         pad_value=pad_value)

    def generate(self, seq_in, seq_out_start, seq_len, **kwargs):
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        encodings = self.enc(seq_in, return_encodings=True, **enc_kwargs)
        return self.dec.generate(seq_out_start,
                                 seq_len,
                                 context=encodings,
                                 **{
                                     **dec_kwargs,
                                     **kwargs
                                 })

    def forward(self, seq_in, seq_out, return_loss=False, **kwargs):
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        encodings = self.enc(seq_in, return_encodings=True, **enc_kwargs)
        return self.dec(seq_out,
                        context=encodings,
                        return_loss=return_loss,
                        **dec_kwargs)
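
This variant differs from code example #3 in that it threads amp_enabled through to the decoder, exposes ignore_index and pad_value on the AutoregressiveWrapper, and takes an explicit return_loss flag instead of an encoder mask. A training call would then look roughly like the following, reusing the enc_dec, src, and tgt tensors from the earlier sketch.

# sketch: with this variant the loss is requested explicitly
loss = enc_dec(src, tgt, return_loss=True)
loss.backward()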
Code example #5
File: train.py  Project: amy-hyunji/performer-pytorch
def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))


# instantiate model

model = PerformerLM(num_tokens=256,
                    dim=512,
                    depth=6,
                    max_seq_len=SEQ_LEN,
                    heads=8,
                    causal=True,
                    reversible=True,
                    nb_features=256)

model = AutoregressiveWrapper(model)
model.cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)


class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len