Code Example
class Trainer:
    """
    ## Trainer
    """
    def __init__(self, configs: Configs):
        # Get the device
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        # Initialize the dataset
        self.dataset = TinyShakespeareDataset(configs.seq_len)
        # Initialize the dataloader
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=configs.batch_size,
                                     collate_fn=transpose_batch,
                                     shuffle=True)

        # FFN with Gated Linear Unit
        # $$FFN_{GLU}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$
        if configs.glu_variant == 'GLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.Sigmoid(), True, False, False, False)
        # FFN with Bilinear hidden layer
        # $$FFN_{Bilinear}(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$
        elif configs.glu_variant == 'Bilinear':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.Identity(), True, False, False, False)
        # FFN with ReLU gate
        # $$FFN_{ReGLU}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$
        elif configs.glu_variant == 'ReGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.ReLU(), True, False, False, False)
        # FFN with GELU gate
        # $$FFN_{GEGLU}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$
        elif configs.glu_variant == 'GEGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.GELU(), True, False, False, False)
        # FFN with Swish gate
        # $$FFN_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$
        # where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
        elif configs.glu_variant == 'SwiGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.SiLU(), True, False, False, False)
        # FFN with ReLU activation
        # $$FFN_{ReLU}(x, W_1, W_2, b_1, b_2) = \max(0, x W_1 + b_1) W_2 + b_2$$
        elif configs.glu_variant == 'ReLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.ReLU())
        # FFN with GELU activation
        # $$FFN_{GELU}(x, W_1, W_2, b_1, b_2) = \text{GELU}(x W_1 + b_1) W_2 + b_2$$
        elif configs.glu_variant == 'GELU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout,
                              nn.GELU())
        else:
            raise ValueError(f'Unknown variant {configs.glu_variant}')
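        # Note: the four positional flags passed to `FeedForward` in the gated
        # variants above presumably correspond to `is_gated`, `bias1`, `bias2`
        # and `bias_gate`, i.e. the gate is enabled and all bias terms are
        # disabled, so each gated variant computes $(f(x W_1) \otimes x V) W_2$
        # as in the formulas above.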

        # Number of different characters
        n_chars = len(self.dataset.stoi)

        # Initialize [Multi-Head Attention module](../mha.html)
        mha = MultiHeadAttention(configs.n_heads, configs.d_model,
                                 configs.dropout)
        # Initialize the [Transformer Block](../models.html#TransformerLayer)
        transformer_layer = TransformerLayer(d_model=configs.d_model,
                                             self_attn=mha,
                                             src_attn=None,
                                             feed_forward=ffn,
                                             dropout_prob=configs.dropout)
        # Initialize the model with an
        # [embedding layer](../models.html#EmbeddingsWithPositionalEncoding)
        # (with fixed positional encoding)
        # [transformer encoder](../models.html#Encoder) and
        # a linear layer to generate logits.
        self.model = AutoregressiveModel(
            EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
            Encoder(transformer_layer, configs.n_layers),
            nn.Linear(configs.d_model, n_chars))

        # Move the model to the current device
        self.model.to(self.device)

        # Initialize [Noam optimizer](../../optimizers/noam.html)
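        # The Noam schedule scales the learning rate as
        # $d_{model}^{-0.5} \cdot \min\big(step^{-0.5}, step \cdot warmup^{-1.5}\big)$,
        # so `lr=1.0` below presumably acts as a multiplicative factor rather
        # than a fixed learning rate.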
        self.optimizer = Noam(self.model.parameters(),
                              lr=1.0,
                              warmup=2_000,
                              d_model=configs.d_model)

        # Cross-entropy loss
        self.loss_func = nn.CrossEntropyLoss()
        # Number of training epochs.
        # Note that our dataset definition repeats the data `seq_len` times in a single epoch.
        self.epochs = configs.epochs
        # Gradient clipping norm
        self.grad_norm_clip = configs.grad_norm_clip

        # Set tracker configurations
        tracker.set_scalar("loss.*", True)

    def sample(self):
        """
        ### Sampling function to generate samples periodically while training
        """

        # Starting prompt
        prompt = 'It is'
        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to prompt
            prompt += self.dataset.itos[output[-1].item()]
            # Add the prediction for logging
            log += [(self.dataset.itos[output[-1].item()], Text.value)]

        # Print the sampled output
        logger.log(log)

    def train(self):
        """
        ### Train the model
        """

        # Loop for the given number of epochs
        for _ in monit.loop(self.epochs):
            # Iterate over the minibatches
            for i, batch in monit.enum('Train', self.dataloader):
                # Move data to the device
                data, target = batch[0].to(self.device), batch[1].to(
                    self.device)

                # Set tracker step, as the number of characters trained on
                tracker.add_global_step(data.shape[0] * data.shape[1])

                # Set model state to training
                self.model.train()
                # Evaluate the model
                output = self.model(data)

                # Calculate loss
                loss = self.loss_func(output.view(-1, output.shape[-1]),
                                      target.view(-1))
                # Log the loss
                tracker.add("loss.train", loss)

                # Calculate gradients
                loss.backward()
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               max_norm=self.grad_norm_clip)
                # Take optimizer step
                self.optimizer.step()
                # Log the model parameters and gradients
                if (i + 1) % 100 == 0:
                    tracker.add('model', self.model)
                # Clear the gradients
                self.optimizer.zero_grad()

                # Generate a sample
                if (i + 1) % 100 == 0:
                    self.model.eval()
                    with torch.no_grad():
                        self.sample()

                # Save the tracked metrics
                if (i + 1) % 10 == 0:
                    tracker.save()

            # Save the model
            experiment.save_checkpoint()
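

# A minimal usage sketch. `Configs` is not defined in this excerpt, so the
# fields set below are assumptions based on the attributes `Trainer.__init__`
# reads (`seq_len`, `batch_size`, `glu_variant`, `d_model`, `d_ff`, `n_heads`,
# `n_layers`, `dropout`, `epochs`, `grad_norm_clip`); `train()` also assumes
# a labml `experiment` has been created elsewhere for checkpointing.
def main():
    # Build a configuration and pick one of the variants handled above
    configs = Configs()
    configs.glu_variant = 'SwiGLU'
    # Construct the trainer and run training
    trainer = Trainer(configs)
    trainer.train()


if __name__ == '__main__':
    main()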