Exemplo n.º 1
0
    def __init__(self, n_heads: int, d_model: int, dropout: float, d_ff: int, shortening_factors: List[int]):
        """
        * `n_heads` is the number of heads in [multi-head attention layers](../mha.html)
        * `d_model` is the size of the token embeddings
        * `dropout` is the dropout probability
        * `d_ff` is the dimensionality of the hidden layer in [position-wise feed-forward layers](../feed_forward.html)
        * `shortening_factors` is the list of shortening factors. The first factor is used at this level;
          the remaining factors (if any) are passed to a recursively nested `HourGlass`.

        Raises `ValueError` if `shortening_factors` is empty.
        """
        super().__init__()

        # Every level of the hourglass needs a shortening factor of its own.
        # Fail early with a clear message rather than an `IndexError` from `shortening_factors[0]`.
        if len(shortening_factors) == 0:
            raise ValueError('shortening_factors must contain at least one shortening factor')

        def transformer_layer():
            # A fresh transformer layer with a [multi-head attention layer](../mha.html)
            # and a [position-wise feed-forward layer](../feed_forward.html).
            # The attention module is created before the feed-forward module,
            # the same order as the inline construction this replaces.
            return TransformerLayer(d_model=d_model,
                                    self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                    feed_forward=FeedForward(d_model, d_ff, dropout),
                                    dropout_prob=dropout)

        # The transformer layer before down-sampling
        self.pre = transformer_layer()
        # Auto-regressive mask
        self.mask = AutoregressiveMask()

        # The shortening factor $k$ (or the down-sampling rate)
        k = shortening_factors[0]

        # We shift the tokens to the right by $k - 1$ steps to make sure
        # information doesn't leak from the future tokens to past tokens
        # as a result of down-sampling and up-sampling
        self.shift_right = ShiftRight(k - 1)
        # Shortening or the down-sampling layer. We use the simplest form - average pooling.
        # The paper shows that attention based down sampling works best, which we haven't implemented yet.
        self.shortening = AvgPoolShortening(k)

        # If there are no more shortening factors (middle of the hourglass)
        if len(shortening_factors) == 1:
            # The center layer is another transformer layer
            self.shortened = transformer_layer()
            # Autoregressive mask
            self.mask_short = AutoregressiveMask()
            self.hour_glass = None
        else:
            # Insert another hourglass model recursively, consuming the remaining factors
            self.hour_glass = HourGlass(n_heads, d_model, dropout, d_ff, shortening_factors[1:])

        # Up-sampling layer. We use naive up-sampling for simplicity and the paper shows attention based up sampling
        # works better.
        self.up_sampling = NaiveUpSampling(k)

        # The final transformer layer after up-sampling
        self.post = transformer_layer()
Exemplo n.º 2
0
    def __init__(self, configs: Configs):
        """
        Build everything needed to train an autoregressive character model on
        the Tiny Shakespeare dataset, with feed-forward layers of the kind named
        by `configs.glu_variant`.

        Raises `ValueError` when `configs.glu_variant` is not one of
        `GLU`, `Bilinear`, `ReGLU`, `GEGLU`, `SwiGLU`, `ReLU` or `GELU`.
        """
        # Use the first CUDA device when present, otherwise the CPU
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # Dataset of `configs.seq_len`-long samples and a shuffling loader over it
        self.dataset = TinyShakespeareDataset(configs.seq_len)
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=configs.batch_size,
                                     collate_fn=transpose_batch,
                                     shuffle=True)

        # variant name -> (activation module class, build the gated form?)
        # Classes rather than instances are stored so only the chosen activation is created.
        ffn_variants = {
            'GLU': (nn.Sigmoid, True),
            'Bilinear': (nn.Identity, True),
            'ReGLU': (nn.ReLU, True),
            'GEGLU': (nn.GELU, True),
            'SwiGLU': (nn.SiLU, True),
            'ReLU': (nn.ReLU, False),
            'GELU': (nn.GELU, False),
        }
        if configs.glu_variant not in ffn_variants:
            raise ValueError(f'Unknown variant {configs.glu_variant}')
        activation_cls, gated = ffn_variants[configs.glu_variant]

        if gated:
            # NOTE(review): the trailing positional flags are presumably
            # is_gated=True with the bias terms switched off -- confirm against FeedForward
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              activation_cls(), True, False, False, False)
        else:
            # Non-gated form relies on FeedForward's remaining defaults
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              activation_cls())

        # One token per distinct character seen by the dataset
        n_chars = len(self.dataset.stoi)

        # Sub-modules are created in the order they appear in the model:
        # embeddings, attention, transformer layer, encoder, output projection
        embeddings = EmbeddingsWithPositionalEncoding(configs.d_model, n_chars)
        attention = MultiHeadAttention(configs.n_heads, configs.d_model,
                                       configs.dropout)
        layer = TransformerLayer(d_model=configs.d_model,
                                 self_attn=attention,
                                 src_attn=None,
                                 feed_forward=ffn,
                                 dropout_prob=configs.dropout)
        encoder = Encoder(layer, configs.n_layers)
        projection = nn.Linear(configs.d_model, n_chars)

        self.model = AutoregressiveModel(embeddings, encoder, projection)
        self.model.to(self.device)

        # Noam optimizer (`lr=1.0`, `warmup=2_000`)
        self.optimizer = Noam(self.model.parameters(),
                              lr=1.0,
                              warmup=2_000,
                              d_model=configs.d_model)

        # Loss and remaining training hyper-parameters
        self.loss_func = nn.CrossEntropyLoss()
        self.epochs = configs.epochs
        self.grad_norm_clip = configs.grad_norm_clip

        # Set tracker configurations
        tracker.set_scalar("loss.*", True)
Exemplo n.º 3
0
def _model(c: Configs):
    """
    #### Initialize the model

    Assembles an `AutoregressiveTransformer` around a single
    `DeepNormTransformerLayer` (passed together with `c.n_layers`)
    and returns it moved to `c.device`.
    """
    n_tokens, d_model, n_layers = c.n_tokens, c.d_model, c.n_layers

    # Feed-forward block whose hidden size is four times `d_model`
    ffn = FeedForward(d_model=d_model, d_ff=d_model * 4)
    # Self-attention with dropout disabled
    attention = MultiHeadAttention(c.n_heads, d_model, dropout_prob=0.0)

    # DeepNorm layer using the configured alpha / beta
    deep_norm_layer = DeepNormTransformerLayer(d_model=d_model,
                                               deep_norm_alpha=c.deep_norm_alpha,
                                               deep_norm_beta=c.deep_norm_beta,
                                               feed_forward=ffn,
                                               self_attn=attention)

    model = AutoregressiveTransformer(n_tokens, d_model, n_layers, deep_norm_layer)
    return model.to(c.device)
Exemplo n.º 4
0
def autoregressive_model(c: Configs):
    """
    ### Initialize the auto-regressive model

    Wraps a `TransformerXL` stack (built from one `TransformerXLLayer`
    and `c.n_layers`) in an `AutoregressiveModel` and moves it to `c.device`.
    """
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward

    n_tokens, d_model = c.n_tokens, c.d_model

    # Relative multi-head self-attention from `labml_nn.transformers.xl`
    attention = RelativeMultiHeadAttention(c.heads, d_model, c.dropout)
    # Position-wise feed-forward network
    ffn = FeedForward(d_model, c.d_ff, c.dropout)

    xl_layer = TransformerXLLayer(d_model=d_model,
                                  self_attn=attention,
                                  feed_forward=ffn,
                                  dropout_prob=c.dropout)
    transformer = TransformerXL(xl_layer, c.n_layers)

    model = AutoregressiveModel(n_tokens, d_model, transformer)
    return model.to(c.device)
Exemplo n.º 5
0
def transformer_xl_model(c: Configs):
    """
    Build a `TransformerXLModel` from the configs and move it to `c.device`.

    Attention heads and the feed-forward hidden size are read from
    `c.transformer`; the remaining sizes come from `c` directly.
    """
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    from labml_nn.transformers.xl import TransformerXL
    from labml_nn.transformers.xl import TransformerXLLayer
    from python_autocomplete.models.xl import TransformerXLModel

    n_tokens, d_model = c.n_tokens, c.d_model

    # Relative multi-head self-attention
    attention = RelativeMultiHeadAttention(c.transformer.n_heads, d_model, c.dropout)
    # Position-wise feed-forward network sized by the nested FFN config
    ffn = FeedForward(d_model, c.transformer.ffn.d_ff, c.dropout)

    xl_layer = TransformerXLLayer(d_model=d_model,
                                  self_attn=attention,
                                  feed_forward=ffn,
                                  dropout_prob=c.dropout)
    transformer = TransformerXL(xl_layer, c.n_layers)

    model = TransformerXLModel(n_tokens, d_model, transformer)
    return model.to(c.device)
Exemplo n.º 6
0
def switch_transformer(c: Configs):
    """
    ### Initialize the switch transformer

    Builds one `SwitchTransformerLayer`, whose feed-forward part is a
    `SwitchFeedForward` over `c.n_experts` experts, and hands it to
    `SwitchTransformer` together with `c.n_layers`.
    """
    from labml_nn.transformers.switch import SwitchTransformer, SwitchTransformerLayer, SwitchFeedForward
    from labml_nn.transformers import MultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward

    d_model = c.d_model

    # Self-attention sub-layer
    attention = MultiHeadAttention(c.heads, d_model, c.dropout)
    # Feed-forward network given to SwitchFeedForward as its `expert`
    expert = FeedForward(d_model, c.d_ff, c.dropout)
    # Routed mixture-of-experts feed-forward sub-layer
    switch_ffn = SwitchFeedForward(capacity_factor=c.capacity_factor,
                                   drop_tokens=c.drop_tokens,
                                   is_scale_prob=c.is_scale_prob,
                                   n_experts=c.n_experts,
                                   expert=expert,
                                   d_model=d_model)

    layer = SwitchTransformerLayer(d_model=d_model,
                                   attn=attention,
                                   feed_forward=switch_ffn,
                                   dropout_prob=c.dropout)
    return SwitchTransformer(layer, c.n_layers)
Exemplo n.º 7
0
    def __init__(self, configs: Configs):
        """
        Set up the device, data, model, optimizer and loss for training an
        autoregressive character-level transformer whose feed-forward layers
        use the variant selected by `configs.glu_variant`.

        Raises `ValueError` if `configs.glu_variant` is not one of
        `GLU`, `Bilinear`, `ReGLU`, `GEGLU`, `SwiGLU`, `ReLU` or `GELU`.
        """
        # Get the device
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        # Initialize the dataset
        self.dataset = TinyShakespeareDataset(configs.seq_len)
        # Initialize the dataloader
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=configs.batch_size,
                                     collate_fn=transpose_batch,
                                     shuffle=True)

        # FFN with Gated Linear Unit
        # $$FFN_{GLU}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$
        if configs.glu_variant == 'GLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.Sigmoid(), True, False, False, False)
        # FFN with Bilinear hidden layer
        # $$FFN_{Bilinear}(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$
        elif configs.glu_variant == 'Bilinear':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.Identity(), True, False, False, False)
        # FFN with ReLU gate
        # $$FFN_{ReGLU}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$
        elif configs.glu_variant == 'ReGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.ReLU(), True, False, False, False)
        # FFN with GELU gate
        # $$FFN_{GEGLU}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$
        elif configs.glu_variant == 'GEGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.GELU(), True, False, False, False)
        # FFN with Swish gate
        # $$FFN_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$
        # where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
        elif configs.glu_variant == 'SwiGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.SiLU(), True, False, False, False)
        # FFN with ReLU activation
        # $$FFN_{ReLU}(x, W_1, W_2, b_1, b_2) = \max(0, x W_1 + b_1) W_2 + b_2$$
        elif configs.glu_variant == 'ReLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.ReLU())
        # FFN with GELU activation
        # $$FFN_{GELU}(x, W_1, W_2, b_1, b_2) = \text{GELU}(x W_1 + b_1) W_2 + b_2$$
        elif configs.glu_variant == 'GELU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.GELU())
        else:
            raise ValueError(f'Unknown variant {configs.glu_variant}')

        # Number of different characters
        n_chars = len(self.dataset.stoi)

        # Initialize [Multi-Head Attention module](../mha.html)
        mha = MultiHeadAttention(configs.n_heads, configs.d_model,
                                 configs.dropout)
        # Initialize the [Transformer Block](../models.html#TransformerLayer)
        transformer_layer = TransformerLayer(d_model=configs.d_model,
                                             self_attn=mha,
                                             src_attn=None,
                                             feed_forward=ffn,
                                             dropout_prob=configs.dropout)
        # Initialize the model with an
        # [embedding layer](../models.html#EmbeddingsWithPositionalEncoding)
        # (with fixed positional encoding),
        # a [transformer encoder](../models.html#Encoder) and
        # a linear layer to generate logits.
        self.model = AutoregressiveModel(
            EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
            Encoder(transformer_layer, configs.n_layers),
            nn.Linear(configs.d_model, n_chars))

        # Move the model to the current device
        self.model.to(self.device)

        # Initialize [Noam optimizer](../../optimizers/noam.html)
        self.optimizer = Noam(self.model.parameters(),
                              lr=1.0,
                              warmup=2_000,
                              d_model=configs.d_model)

        # Cross-entropy loss
        self.loss_func = nn.CrossEntropyLoss()
        # Number of training epochs.
        # *Note that our dataset definition repeats the data `seq_len` times in a single epoch.*
        self.epochs = configs.epochs
        # Gradient clipping norm
        self.grad_norm_clip = configs.grad_norm_clip

        # Set tracker configurations
        tracker.set_scalar("loss.*", True)