Example #1
    def init(self):
        """
        ### Initialization
        """

        # `[MASK]` token
        self.mask_token = self.n_tokens - 1
        # `[PAD]` token
        self.padding_token = self.n_tokens - 2

        # [Masked Language Model (MLM) class](index.html) to generate the mask
        self.mlm = MLM(padding_token=self.padding_token,
                       mask_token=self.mask_token,
                       no_mask_tokens=self.no_mask_tokens,
                       n_tokens=self.n_tokens,
                       masking_prob=self.masking_prob,
                       randomize_prob=self.randomize_prob,
                       no_change_prob=self.no_change_prob)

        # Accuracy metric (ignore the labels equal to `[PAD]`)
        self.accuracy = Accuracy(ignore_index=self.padding_token)
        # Cross entropy loss (ignore the labels equal to `[PAD]`)
        self.loss_func = nn.CrossEntropyLoss(ignore_index=self.padding_token)
        #
        super().init()
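
For context, `ignore_index` is standard PyTorch behaviour: any position whose label equals the given id contributes nothing to the loss. A minimal, self-contained check (the token ids below are made up for illustration):

import torch
import torch.nn as nn

PAD = 8                                      # stand-in for self.padding_token
loss_func = nn.CrossEntropyLoss(ignore_index=PAD)

logits = torch.randn(5, 10)                  # 5 positions, 10-token vocabulary
labels = torch.tensor([3, PAD, 1, PAD, 7])   # `[PAD]` labels are skipped
loss = loss_func(logits, labels)             # averaged over the 3 non-PAD positions only
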
Example #2
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    ## Configurable Experiment Definition
    """
    optimizer: torch.optim.Adam
    model: nn.Module
    set_seed = SeedConfigs()
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    inner_iterations = 10

    accuracy_func = Accuracy()
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        hook_model_outputs(self.mode, self.model, 'model')
        self.state_modules = [self.accuracy_func]

    def step(self, batch: any, batch_idx: BatchIndex):
        # Get the batch
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Add global step if we are in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Run the model and specify whether to log the activations
        with self.mode.update(is_log_activations=batch_idx.is_last):
            output = self.model(data)

        # Calculate the loss
        loss = self.loss_func(output, target)
        # Calculate the accuracy
        self.accuracy_func(output, target)
        # Log the loss
        tracker.add("loss.", loss)

        # Optimize if we are in training mode
        if self.mode.is_train:
            # Calculate the gradients
            loss.backward()

            # Take optimizer step
            self.optimizer.step()
            # Log the parameter and gradient L2 norms once per epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
                tracker.add('optimizer', (self.optimizer, {
                    'model': self.model
                }))
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save logs
        tracker.save()
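
A minimal driver for a `Configs` class like this one, assuming the usual labml experiment API (`experiment.create`, `experiment.configs`, `experiment.start`) and that the inherited `TrainValidConfigs.run()` provides the training loop; the experiment name and overrides below are illustrative:

from labml import experiment

def main():
    # Create the experiment (the name is arbitrary)
    experiment.create(name='mnist_train_valid')
    conf = Configs()
    # Override any configurable property by name before starting
    experiment.configs(conf, {'epochs': 5, 'inner_iterations': 4})
    # Start the experiment and run the training/validation loop
    with experiment.start():
        conf.run()

if __name__ == '__main__':
    main()
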
Example #3
class Configs(MNISTConfigs, TrainValidConfigs):
    optimizer: torch.optim.Adam
    model: nn.Module
    set_seed = SeedConfigs()
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    inner_iterations = 10

    accuracy_func = Accuracy()
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        hook_model_outputs(self.mode, self.model, 'model')
        self.state_modules = [self.accuracy_func]

    def step(self, batch: any, batch_idx: BatchIndex):
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(len(data))

        with self.mode.update(is_log_activations=batch_idx.is_last):
            output = self.model(data)

        loss = self.loss_func(output, target)
        self.accuracy_func(output, target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            loss.backward()

            self.optimizer.step()
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()
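
`hook_model_outputs` registers forward hooks so that layer outputs can be summarized when activation logging is turned on. For context, the underlying PyTorch mechanism is a plain forward hook; this sketch is not the labml implementation:

import torch
import torch.nn as nn

def log_output_shape(module: nn.Module, inputs, output):
    # A forward hook receives the module, its inputs and its output
    print(type(module).__name__, tuple(output.shape))

layer = nn.Linear(4, 2)
handle = layer.register_forward_hook(log_output_shape)
layer(torch.randn(3, 4))    # prints: Linear (3, 2)
handle.remove()
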
Example #4
class NLPAutoRegressionConfigs(TrainValidConfigs):
    """
    <a id="NLPAutoRegressionConfigs"></a>

    ## Trainer configurations

    This has the basic configurations for NLP auto-regressive task training.
    All the properties are configurable.
    """

    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Autoregressive model
    model: Module
    # Text dataset
    text: TextDataset
    # Batch size
    batch_size: int = 16
    # Length of the sequence, or context size
    seq_len: int = 512
    # Number of tokens in the vocabulary
    n_tokens: int
    # Tokenizer
    tokenizer: Callable = 'character'

    # Text prompt to start sampling (for illustration)
    prompt: str
    # The token separator when sampling (blank for character level tokenization)
    prompt_separator: str

    # Whether to periodically save models
    is_save_models = True

    # Loss function
    loss_func = CrossEntropyLoss()
    # Accuracy function
    accuracy = Accuracy()
    # Model embedding size
    d_model: int = 512
    # Gradient clipping
    grad_norm_clip: float = 1.0

    # Training data loader
    train_loader: DataLoader = 'shuffled_train_loader'
    # Validation data loader
    valid_loader: DataLoader = 'shuffled_valid_loader'

    # Data loaders shuffle with replacement
    dataloader_shuffle_with_replacement: bool = False

    # Whether to log model parameters and gradients (once per epoch).
    # These are summarized stats per layer, but it could still lead
    # to many indicators for very deep networks.
    is_log_model_params_grads: bool = False

    # Whether to log model activations (once per epoch).
    # These are summarized stats per layer, but it could still lead
    # to many indicators for very deep networks.
    is_log_model_activations: bool = False

    def init(self):
        """
        ### Initialization
        """
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Add accuracy as a state module.
        # The name is probably confusing, since it's meant to store
        # states between training and validation for RNNs.
        # This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        """Override to calculate and log other metrics"""
        pass

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training or validation step
        """

        # Set training/eval mode
        self.model.train(self.mode.is_train)

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last
                              and self.is_log_model_activations):
            # Get model outputs.
            # It's returning a tuple for states when using RNNs.
            # This is not implemented yet. 😜
            output, *_ = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        self.other_metrics(output, target)

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()

    def sample(self):
        """
        ### Sampling function to generate samples periodically while training
        """

        # Starting prompt
        prompt = self.prompt
        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output, *_ = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]],
                     Text.value)]

        # Print the sampled output
        logger.log(log)
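
The `sample` loop above always takes the `argmax`. A hedged alternative is to sample from the softmax distribution with a temperature; the helper below is an illustration and not part of the configurations above:

import torch

def pick_next_token(output: torch.Tensor, temperature: float = 1.0) -> int:
    """Sample a token id from the logits at the last position of a [seq_len, 1, n_tokens] output."""
    probs = torch.softmax(output[-1, 0] / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())
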
Example #5
class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
    """
    <a id="MNISTConfigs">
    ## Trainer configurations
    </a>
    """

    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Classification model
    model: Module
    # Number of epochs to train for
    epochs: int = 10

    # Number of times to switch between training and validation within an epoch
    inner_iterations = 10

    # Accuracy function
    accuracy = Accuracy()
    # Loss function
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        """
        ### Initialization
        """
        # Set tracker configurations
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Add accuracy as a state module.
        # The name is probably confusing, since it's meant to store
        # states between training and validation for RNNs.
        # This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training or validation step
        """

        # Training/Evaluation mode
        self.model.train(self.mode.is_train)

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get model outputs.
            output = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()
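
Any `nn.Module` that maps a batch of `[batch, 1, 28, 28]` images to 10 class logits can fill the `model` slot here. A hypothetical minimal example (not the model used in the original experiment):

import torch.nn as nn

class SimpleMNISTModel(nn.Module):
    """A minimal classifier compatible with the configuration above (illustrative only)."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(),
                                 nn.Linear(28 * 28, 128), nn.ReLU(),
                                 nn.Linear(128, 10))

    def forward(self, x):
        return self.net(x)
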
Example #6
class NLPClassificationConfigs(TrainValidConfigs):
    """
    <a id="NLPClassificationConfigs"></a>

    ## Trainer configurations

    This has the basic configurations for NLP classification task training.
    All the properties are configurable.
    """

    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Autoregressive model
    model: Module
    # Batch size
    batch_size: int = 16
    # Length of the sequence, or context size
    seq_len: int = 512
    # Vocabulary
    vocab: Vocab = 'ag_news'
    # Number of tokens in the vocabulary
    n_tokens: int
    # Number of classes
    n_classes: int = 'ag_news'
    # Tokenizer
    tokenizer: Callable = 'character'

    # Whether to periodically save models
    is_save_models = True

    # Loss function
    loss_func = nn.CrossEntropyLoss()
    # Accuracy function
    accuracy = Accuracy()
    # Model embedding size
    d_model: int = 512
    # Gradient clipping
    grad_norm_clip: float = 1.0

    # Training data loader
    train_loader: DataLoader = 'ag_news'
    # Validation data loader
    valid_loader: DataLoader = 'ag_news'

    # Whether to log model parameters and gradients (once per epoch).
    # These are summarized stats per layer, but it could still lead
    # to many indicators for very deep networks.
    is_log_model_params_grads: bool = False

    # Whether to log model activations (once per epoch).
    # These are summarized stats per layer, but it could still lead
    # to many indicators for very deep networks.
    is_log_model_activations: bool = False

    def init(self):
        """
        ### Initialization
        """
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Add accuracy as a state module.
        # The name is probably confusing, since it's meant to store
        # states between training and validation for RNNs.
        # This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training or validation step
        """

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last
                              and self.is_log_model_activations):
            # Get model outputs.
            # It's returning a tuple for states when using RNNs.
            # This is not implemented yet. 😜
            output, *_ = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()
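
String defaults such as `'ag_news'` are labml config aliases: they name a registered calculator function rather than a literal value. A hedged sketch of that registration pattern, assuming labml's `option` decorator; the body is illustrative (AG News has four classes):

from labml.configs import option

@option(NLPClassificationConfigs.n_classes)
def ag_news(c: NLPClassificationConfigs):
    # Calculator registered under the name 'ag_news'; returns the class count
    return 4
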
Example #7
class Configs(TrainValidConfigs):
    optimizer: torch.optim.Adam
    device: torch.device = DeviceConfigs()

    model: Module
    text: TextDataset
    batch_size: int = 16
    seq_len: int = 512
    n_tokens: int
    n_layers: int = 2
    dropout: float = 0.2
    d_model: int = 512
    rnn_size: int = 512
    rhn_depth: int = 1
    tokenizer: Callable
    inner_iterations = 100

    is_save_models = True

    transformer: TransformerConfigs

    accuracy_func = Accuracy()
    loss_func: 'CrossEntropyLoss'

    def init(self):
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        hook_model_outputs(self.mode, self.model, 'model')
        self.state_modules = [self.accuracy_func]

    def step(self, batch: any, batch_idx: BatchIndex):
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(len(data))

        with self.mode.update(is_log_activations=batch_idx.is_last):
            output, *_ = self.model(data)

        loss = self.loss_func(output, target)
        self.accuracy_func(output, target)
        self.accuracy_func.track()
        tracker.add("loss.", loss)

        if self.mode.is_train:
            loss.backward()

            self.optimizer.step()
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()

    def sample(self):
        prompt = 'def train('
        log = [(prompt, Text.subtle)]
        for i in monit.iterate('Sample', 25):
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            output, *_ = self.model(data)
            output = output.argmax(dim=-1).squeeze()
            prompt += '' + self.text.itos[output[-1]]
            log += [('' + self.text.itos[output[-1]], Text.value)]

        logger.log(log)
Example #8
class Configs(TrainValidConfigs):
    optimizer: torch.optim.Adam
    device: torch.device = DeviceConfigs()

    model: Module
    text = SourceCodeDataConfigs()
    n_tokens: int
    n_layers: int = 2
    dropout: float = 0.2
    d_model: int = 512
    rnn_size: int = 512
    rhn_depth: int = 1
    inner_iterations = 100

    is_save_models = True

    transformer: TransformerConfigs

    accuracy = Accuracy()
    loss_func: 'CrossEntropyLoss'

    state_updater: 'StateUpdater'
    state = SimpleStateModule()
    mem_len: int = 512
    grad_norm_clip: float = 1.0
    is_token_by_token: bool = False

    def init(self):
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        hook_model_outputs(self.mode, self.model, 'model')
        self.state_modules = [self.accuracy, self.state]

    def step(self, batch: any, batch_idx: BatchIndex):
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(target.shape[0] * target.shape[1])

        with self.mode.update(is_log_activations=batch_idx.is_last):
            state = self.state.get()
            output, new_state = self.model(data, state)
            state = self.state_updater(state, new_state)
            self.state.set(state)

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        self.accuracy(output, target)
        self.accuracy.track()

        if self.mode.is_train:
            loss.backward()

            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           max_norm=self.grad_norm_clip)
            self.optimizer.step()
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()

    def sample(self):
        prompt = 'def train('
        log = [(prompt, Text.subtle)]
        state = None
        for i in monit.iterate('Sample', 25):
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            output, new_state = self.model(data, state)
            output = output.argmax(dim=-1).squeeze(1)
            if self.is_token_by_token:
                prompt = self.text.tokenizer.itos[output[-1]]
            else:
                prompt += '' + self.text.tokenizer.itos[output[-1]]
            log += [('' + self.text.tokenizer.itos[output[-1]], Text.value)]
            state = self.state_updater(state, new_state)

        logger.log(log)
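
The `state_updater` carries memory/state across batches. A minimal sketch of what such an updater typically does (keep the new state, detached, so back-propagation is truncated at batch boundaries); this is an illustration, not the labml `StateUpdater`:

import torch
from typing import List, Optional

def detach_state(old_state: Optional[List[torch.Tensor]],
                 new_state: List[torch.Tensor]) -> List[torch.Tensor]:
    # Keep the latest state tensors, detached from the autograd graph
    return [s.detach() for s in new_state]
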
Example #9
class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    This inherits from
    [`NLPAutoRegressionConfigs`](../../experiments/nlp_autoregression.html)
    because it has the data pipeline implementations that we reuse here.
    We have implemented a custom training step for MLM.
    """

    # MLM model
    model: TransformerMLM
    # Transformer
    transformer: TransformerConfigs

    # Number of tokens
    n_tokens: int = 'n_tokens_mlm'
    # Tokens that shouldn't be masked
    no_mask_tokens: List[int] = []
    # Probability of masking a token
    masking_prob: float = 0.15
    # Probability of replacing the mask with a random token
    randomize_prob: float = 0.1
    # Probability of replacing the mask with original token
    no_change_prob: float = 0.1
    # [Masked Language Model (MLM) class](index.html) to generate the mask
    mlm: MLM

    # `[MASK]` token
    mask_token: int
    # `[PAD]` token
    padding_token: int

    # Prompt to sample
    prompt: List[str] = [
        "We are accounted poor citizens, the patricians good.",
        "What authority surfeits on would relieve us: if they",
        "would yield us but the superfluity, while it were",
        "wholesome, we might guess they relieved us humanely;",
        "but they think we are too dear: the leanness that",
        "afflicts us, the object of our misery, is as an",
        "inventory to particularise their abundance; our",
        "sufferance is a gain to them Let us revenge this with",
        "our pikes, ere we become rakes: for the gods know I",
        "speak this in hunger for bread, not in thirst for revenge.",
    ]

    def init(self):
        """
        ### Initialization
        """

        # `[MASK]` token
        self.mask_token = self.n_tokens - 1
        # `[PAD]` token
        self.padding_token = self.n_tokens - 2

        # [Masked Language Model (MLM) class](index.html) to generate the mask
        self.mlm = MLM(padding_token=self.padding_token,
                       mask_token=self.mask_token,
                       no_mask_tokens=self.no_mask_tokens,
                       n_tokens=self.n_tokens,
                       masking_prob=self.masking_prob,
                       randomize_prob=self.randomize_prob,
                       no_change_prob=self.no_change_prob)

        # Accuracy metric (ignore the labels equal to `[PAD]`)
        self.accuracy = Accuracy(ignore_index=self.padding_token)
        # Cross entropy loss (ignore the labels equal to `[PAD]`)
        self.loss_func = nn.CrossEntropyLoss(ignore_index=self.padding_token)
        #
        super().init()

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training or validation step
        """

        # Move the input to the device
        data = batch[0].to(self.device)

        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Get the masked input and labels
        with torch.no_grad():
            data, labels = self.mlm(data)

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get model outputs.
            # It's returning a tuple for states when using RNNs.
            # This is not implemented yet.
            output, *_ = self.model(data)

        # Calculate and log the loss
        loss = self.loss_func(output.view(-1, output.shape[-1]),
                              labels.view(-1))
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, labels)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()

    @torch.no_grad()
    def sample(self):
        """
        ### Sampling function to generate samples periodically while training
        """

        # Empty tensor for data filled with `[PAD]`.
        data = torch.full((self.seq_len, len(self.prompt)),
                          self.padding_token,
                          dtype=torch.long)
        # Add the prompts one by one
        for i, p in enumerate(self.prompt):
            # Get token indexes
            d = self.text.text_to_i(p)
            # Add to the tensor
            s = min(self.seq_len, len(d))
            data[:s, i] = d[:s]
        # Move the tensor to current device
        data = data.to(self.device)

        # Get masked input and labels
        data, labels = self.mlm(data)
        # Get model outputs
        output, *_ = self.model(data)

        # Print the samples generated
        for j in range(data.shape[1]):
            # Collect output for printing
            log = []
            # For each token
            for i in range(len(data)):
                # If the label is not `[PAD]`
                if labels[i, j] != self.padding_token:
                    # Get the prediction
                    t = output[i, j].argmax().item()
                    # If it's a printable character
                    if t < len(self.text.itos):
                        # Correct prediction
                        if t == labels[i, j]:
                            log.append((self.text.itos[t], Text.value))
                        # Incorrect prediction
                        else:
                            log.append((self.text.itos[t], Text.danger))
                    # If it's not a printable character
                    else:
                        log.append(('*', Text.danger))
                # If the label is `[PAD]` (unmasked) print the original.
                elif data[i, j] < len(self.text.itos):
                    log.append((self.text.itos[data[i, j]], Text.subtle))

            # Print
            logger.log(log)
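
A toy illustration of the masking contract the step above relies on: the masked input replaces roughly `masking_prob` of the tokens with `[MASK]`, and the labels keep the original token only at masked positions, with `[PAD]` everywhere else so the loss and accuracy ignore them. This sketch is hand-rolled and not the labml `MLM` class (it omits the random-replacement and keep-original branches):

import torch

def toy_mlm(x: torch.Tensor, mask_token: int, padding_token: int,
            masking_prob: float = 0.15):
    """Return (masked input, labels) in the form the training step expects."""
    mask = torch.rand(x.shape, device=x.device) < masking_prob
    labels = torch.where(mask, x, torch.full_like(x, padding_token))
    masked = torch.where(mask, torch.full_like(x, mask_token), x)
    return masked, labels
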
Example #10
class NLPAutoRegressionConfigs(TrainValidConfigs):
    optimizer: torch.optim.Adam
    device: torch.device = DeviceConfigs()

    model: Module
    text: TextDataset
    batch_size: int = 16
    seq_len: int = 512
    n_tokens: int
    tokenizer: Callable = 'character'

    prompt: str
    prompt_separator: str

    is_save_models = True

    loss_func = CrossEntropyLoss()
    accuracy = Accuracy()
    d_model: int = 512

    def init(self):
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        hook_model_outputs(self.mode, self.model, 'model')
        self.state_modules = [self.accuracy]

    def step(self, batch: any, batch_idx: BatchIndex):
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        with self.mode.update(is_log_activations=batch_idx.is_last):
            output, *_ = self.model(data)

        loss = self.loss_func(output, target)
        self.accuracy(output, target)
        self.accuracy.track()
        tracker.add("loss.", loss)

        if self.mode.is_train:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           max_norm=1.)
            self.optimizer.step()
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()

    def sample(self):
        """
        Sampling function to generate samples periodically while training
        """
        prompt = self.prompt
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output, *_ = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]],
                     Text.value)]

        logger.log(log)
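
One small variation worth noting: `clip_grad_norm_` returns the total gradient norm it measured, so the training branch above could log it with no extra pass over the parameters. A hedged sketch of that variant (`'grad_norm.'` is a made-up indicator name):

import torch
from labml import tracker

def clip_and_log(model: torch.nn.Module, max_norm: float = 1.0):
    """Clip gradients and log the pre-clipping total norm (illustrative helper)."""
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    tracker.add('grad_norm.', total_norm)
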
Example #11
class Configs(SimpleTrainValidConfigs):
    """
    ## Configurations

    The default configs can and will be overridden when we start the experiment
    """

    transformer: TransformerConfigs
    model: AutoregressiveModel
    text: TextDataset
    batch_size: int = 20
    seq_len: int = 32
    n_tokens: int
    tokenizer: Callable = 'character'

    is_save_models = True
    prompt: str
    prompt_separator: str

    is_save_ff_input = False
    optimizer: torch.optim.Adam = 'transformer_optimizer'

    accuracy = Accuracy()
    loss_func = CrossEntropyLoss()

    def init(self):
        # Create a configurable optimizer.
        # Parameters like learning rate can be changed by passing a dictionary when starting the experiment.
        optimizer = OptimizerConfigs()
        optimizer.parameters = self.model.parameters()
        optimizer.d_model = self.transformer.d_model
        optimizer.optimizer = 'Noam'
        self.optimizer = optimizer

        # Create a sequential data loader for training
        self.train_loader = SequentialDataLoader(text=self.text.train,
                                                 dataset=self.text,
                                                 batch_size=self.batch_size,
                                                 seq_len=self.seq_len)

        # Create a sequential data loader for validation
        self.valid_loader = SequentialDataLoader(text=self.text.valid,
                                                 dataset=self.text,
                                                 batch_size=self.batch_size,
                                                 seq_len=self.seq_len)

        self.state_modules = [self.accuracy]

    def sample(self):
        """
        Sampling function to generate samples periodically while training
        """
        prompt = self.prompt
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

        logger.log(log)

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        This method is called for each batch
        """
        self.model.train(self.mode.is_train)

        # Get data and target labels
        data, target = batch[0].to(self.model.device), batch[1].to(self.model.device)

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Run the model
        output = self.model(data)

        # Calculate loss
        loss = self.loss_func(output, target)
        # Calculate accuracy
        self.accuracy(output, target)

        # Log the loss
        tracker.add("loss.", loss)

        # If we are in training mode, calculate the gradients and take an optimizer step
        if self.mode.is_train:
            loss.backward()
            self.optimizer.step()
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()
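
The `'Noam'` option selects the warm-up/decay learning-rate schedule from the original Transformer paper, scaled by `d_model`. For reference, a standalone sketch of that rule (the `warmup` value here is illustrative):

def noam_lr(step: int, d_model: int = 512, warmup: int = 4000) -> float:
    """lr ∝ d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)"""
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)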