def test_positional_embedding_output_value():
    layer = PositionalEmbedding(num_embeddings=100, embedding_dim=32)

    # Equal-length inputs must produce identical position embeddings,
    # regardless of the token values.
    x1 = torch.randint(8000, (10, 6, 100))
    x2 = torch.randint(8000, (10, 6, 100))
    assert (layer(x1) == layer(x2)).all()

    # `offset` shifts the position index: embedding a shorter sequence at an
    # offset must match the corresponding slice of the longer sequence.
    x1 = torch.randint(8000, (10,))
    x2 = torch.randint(8000, (5,))
    for offset in range(5):
        assert (layer(x1)[offset:offset + 5] == layer(x2, offset=offset)).all()

    x1 = torch.randint(8000, (8, 5, 10))
    x2 = torch.randint(8000, (8, 5, 5))
    for offset in range(5):
        assert (layer(x1)[..., offset:offset + 5, :]
                == layer(x2, offset=offset)).all()
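# The test above assumes that PositionalEmbedding ignores the token values and
# embeds position indices shifted by `offset`. A minimal sketch of a layer with
# that behavior (an assumption consistent with the asserts, not the project's
# confirmed implementation):
class _PositionalEmbeddingSketch(nn.Embedding):
    def forward(self, x: torch.Tensor, offset: int = 0) -> torch.Tensor:
        # Only the sequence length (last dimension) of `x` matters; the token
        # values themselves are ignored.
        positions = torch.arange(offset, offset + x.size(-1),
                                 dtype=torch.long, device=x.device)
        # Broadcast the position indices to the full input shape so the output
        # has shape x.shape + (embedding_dim,).
        positions = positions.view((1,) * (x.ndim - 1) + (-1,)).expand_as(x)
        return super().forward(positions)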
def test_positional_embedding_load_state_dict_with_wrapper():
    layer_32 = nn.Sequential(
        PositionalEmbedding(num_embeddings=32, embedding_dim=16))
    layer_64 = nn.Sequential(
        PositionalEmbedding(num_embeddings=64, embedding_dim=16))

    # Reduce the embedding matrix to decrease the sequence length.
    layer_32.load_state_dict(layer_64.state_dict())
    assert (layer_32[0].weight == layer_64[0].weight[:32]).all()

    layer_32[0].reset_parameters()
    layer_64[0].reset_parameters()

    # Expand the embedding matrix to increase the sequence length.
    layer_64.load_state_dict(layer_32.state_dict())
    assert (layer_32[0].weight == layer_64[0].weight[:32]).all()
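# The wrapper test above relies on PositionalEmbedding resizing an incoming
# positional weight matrix to its own `num_embeddings` before loading it. A
# minimal sketch of one way to implement that via the load hook (an assumption
# about the behavior under test, not necessarily the project's exact code):
class _ResizablePositionalEmbedding(nn.Embedding):
    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        weight = state_dict[prefix + 'weight']
        if weight.size(0) > self.num_embeddings:
            # Reduce: keep only the first `num_embeddings` position vectors.
            weight = weight[:self.num_embeddings]
        elif weight.size(0) < self.num_embeddings:
            # Expand: pad the tail with this layer's freshly initialised rows.
            weight = torch.cat((weight, self.weight[weight.size(0):]), dim=0)
        state_dict[prefix + 'weight'] = weight
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)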
def test_positional_embedding_output_shape():
    layer = PositionalEmbedding(num_embeddings=100, embedding_dim=32)

    # The output appends an embedding dimension to the input shape.
    x1 = torch.randint(8000, (10, 6, 100))
    x2 = torch.randint(8000, (10, 6))
    assert layer(x1).shape == x1.shape + (32,)
    assert layer(x2).shape == x2.shape + (32,)
def __init__(self,
             layers: int,
             pad_idx: int,
             words: int,
             seq_len: int,
             heads: int,
             dims: int,
             rate: int = 4,
             dropout: float = 0.1,
             bidirectional: bool = True):
    super().__init__()
    self.bidirectional = bidirectional

    self.pad_masking = PadMasking(pad_idx)
    self.future_masking = FutureMasking()

    self.positional_embedding = PositionalEmbedding(seq_len, dims)
    self.token_embedding = TokenEmbedding(words, dims)
    self.dropout_embedding = nn.Dropout(dropout)

    self.transformers = nn.ModuleList([
        TransformerLayer(heads, dims, rate, dropout)
        for _ in range(layers)])
    self.ln_head = LayerNorm(dims)
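# A hedged sketch of how these submodules might be wired together in the
# model's forward pass. The call signatures of PadMasking, FutureMasking,
# TokenEmbedding (including the `transposed` flag) and TransformerLayer are
# assumptions for illustration, not the project's confirmed API.
def forward(self, x: torch.Tensor) -> torch.Tensor:
    # Build the attention mask: always mask padding, and additionally mask
    # future positions when the model is unidirectional (GPT-style).
    mask = self.pad_masking(x)
    if not self.bidirectional:
        mask = mask + self.future_masking(x)

    # Sum token and position embeddings, then apply embedding dropout.
    h = self.token_embedding(x) + self.positional_embedding(x)
    h = self.dropout_embedding(h)

    # Run the transformer stack (assuming each layer returns the transformed
    # hidden states) followed by the final layer norm.
    for layer in self.transformers:
        h = layer(h, mask)
    h = self.ln_head(h)

    # Project back to the vocabulary with the tied token-embedding weights.
    return self.token_embedding(h, transposed=True)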