class Configs(TrainValidConfigs):
    device = DeviceConfigs()
    model: Module
    text: TextDataset
    batch_size: int = 16
    seq_len: int = 512
    n_tokens: int
    n_layers: int = 2
    dropout: float = 0.2
    d_model: int = 512
    rnn_size: int = 512
    rhn_depth: int = 1
    tokenizer: Callable
    inner_iterations = 100

    is_save_models = True

    transformer: TransformerConfigs

    def run(self):
        for _ in self.training_loop:
            prompt = 'def train('
            log = [(prompt, Text.subtle)]
            for i in monit.iterate('Sample', 25):
                data = self.text.text_to_i(prompt).unsqueeze(-1)
                data = data.to(self.device)
                output, *_ = self.model(data)
                output = output.argmax(dim=-1).squeeze()
                prompt += '' + self.text.itos[output[-1]]
                log += [('' + self.text.itos[output[-1]], Text.value)]

            logger.log(log)

            self.run_step()
Example #2
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    ## Configurable Experiment Definition
    """
    optimizer: torch.optim.Adam
    model: nn.Module
    set_seed = SeedConfigs()
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    model: nn.Module
    inner_iterations = 10

    accuracy_func = Accuracy()
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        hook_model_outputs(self.mode, self.model, 'model')
        self.state_modules = [self.accuracy_func]

    def step(self, batch: any, batch_idx: BatchIndex):
        # Get the batch
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Add global step if we are in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Run the model and specify whether to log the activations
        with self.mode.update(is_log_activations=batch_idx.is_last):
            output = self.model(data)

        # Calculate the loss
        loss = self.loss_func(output, target)
        # Calculate the accuracy
        self.accuracy_func(output, target)
        # Log the loss
        tracker.add("loss.", loss)

        # Optimize if we are in training mode
        if self.mode.is_train:
            # Calculate the gradients
            loss.backward()

            # Take optimizer step
            self.optimizer.step()
            # Log the parameter and gradient L2 norms once per epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
                tracker.add('optimizer', (self.optimizer, {
                    'model': self.model
                }))
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save logs
        tracker.save()
Example #3
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    Configurations with MNIST data and Train & Validation setup
    """
    batch_step = 'capsule_network_batch_step'
    device: torch.device = DeviceConfigs()
    epochs: int = 10
    model = 'capsule_network_model'
Example #4
class Configs(MNISTConfigs, TrainValidConfigs):
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    discriminator: Module
    generator: Module
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    discriminator_loss = DiscriminatorLogitsLoss()
    generator_loss = GeneratorLogitsLoss()
    batch_step = 'gan_batch_step'
Example #5
class Configs(MNISTConfigs, TrainValidConfigs):
    seed = SeedConfigs()
    device: torch.device = DeviceConfigs()
    epochs: int = 10
    train_batch_size = 1
    valid_batch_size = 1

    is_save_models = True
    model: nn.Module

    loss_func = nn.CrossEntropyLoss()
    accuracy_func = SimpleAccuracy()
Example #6
class Configs(MNISTConfigs, TrainValidConfigs):
    seed = SeedConfigs()
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    model: nn.Module

    learning_rate: float = 2.5e-4
    momentum: float = 0.5

    loss_func = 'cross_entropy_loss'
    accuracy_func = 'simple_accuracy'
Example #7
class Configs(MNISTConfigs, TrainValidConfigs):
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    discriminator: Module
    generator: Module
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    generator_loss: GeneratorLogitsLoss
    discriminator_loss: DiscriminatorLogitsLoss
    batch_step = 'gan_batch_step'
    label_smoothing: float = 0.2
    discriminator_k: int = 1
Example #8
class Configs(TrainValidConfigs):
    """
    ## Configurations

    The default configs can and will be overridden when we start the experiment
    """

    transformer: TransformerConfigs
    model: AutoregressiveModel
    device: torch.device = DeviceConfigs()
    text: TextDataset
    batch_size: int = 20
    seq_len: int = 32
    n_tokens: int
    tokenizer: Callable = 'character'

    is_save_models = True
    prompt: str
    prompt_separator: str

    is_save_ff_input = False
    optimizer: torch.optim.Adam = 'transformer_optimizer'

    batch_step = 'auto_regression_batch_step'

    def sample(self):
        """
        Sampling function to generate samples periodically while training
        """
        prompt = self.prompt
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]],
                     Text.value)]

        logger.log(log)
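
The docstring above notes that the defaults are overridden when the experiment starts.
A minimal sketch of how that usually looks with labml's experiment API, assuming the
module that defines this class also registers options for `model`, `text`, etc.
(the experiment name and override values below are illustrative):

from labml import experiment

conf = Configs()
experiment.create(name='autoregressive_demo')
# Override defaults and pick named config options declared above
experiment.configs(conf, {
    'tokenizer': 'character',
    'prompt_separator': '',
    'prompt': 'def train(',
    'seq_len': 128,
})
with experiment.start():
    conf.run()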
Example #9
class Configs(MNISTConfigs, TrainValidConfigs):
    optimizer: torch.optim.Adam
    model: nn.Module
    set_seed = SeedConfigs()
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    model: nn.Module
    inner_iterations = 10

    accuracy_func = Accuracy()
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        hook_model_outputs(self.mode, self.model, 'model')
        self.state_modules = [self.accuracy_func]

    def step(self, batch: any, batch_idx: BatchIndex):
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(len(data))

        with self.mode.update(is_log_activations=batch_idx.is_last):
            output = self.model(data)

        loss = self.loss_func(output, target)
        self.accuracy_func(output, target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            loss.backward()

            self.optimizer.step()
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()
Example #10
class Configs(TrainValidConfigs):
    device = DeviceConfigs()
    model: Module
    text: TextDataset
    batch_size: int = 16
    seq_len: int = 512
    n_tokens: int
    d_model: int = 512
    n_layers: int = 2
    dropout: float = 0.2
    d_lstm: int = 512
    tokenizer: Callable

    is_save_models = True

    transformer: TransformerConfigs

    def run(self):
        for _ in self.training_loop:
            prompt = 'def train('
            log = [(prompt, Text.subtle)]
            for i in monit.iterate('Sample', 25):
                data = self.text.text_to_i(prompt).unsqueeze(-1)
                data = data.to(self.device)
                output, *_ = self.model(data)
                output = output.argmax(dim=-1).squeeze()
                prompt += '' + self.text.itos[output[-1]]
                log += [('' + self.text.itos[output[-1]], Text.value)]

            logger.log(log)

            with Mode(is_train=True,
                      is_log_parameters=self.is_log_parameters,
                      is_log_activations=self.is_log_activations):
                with tracker.namespace('train'):
                    self.trainer()
            with tracker.namespace('valid'):
                self.validator()
Example #11
class NLPAutoRegressionConfigs(TrainValidConfigs):
    optimizer: torch.optim.Adam
    device: torch.device = DeviceConfigs()

    model: Module
    text: TextDataset
    batch_size: int = 16
    seq_len: int = 512
    n_tokens: int
    tokenizer: Callable = 'character'

    prompt: str
    prompt_separator: str

    is_save_models = True

    loss_func = CrossEntropyLoss()
    accuracy = Accuracy()
    d_model: int = 512

    def init(self):
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        hook_model_outputs(self.mode, self.model, 'model')
        self.state_modules = [self.accuracy]

    def step(self, batch: any, batch_idx: BatchIndex):
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        with self.mode.update(is_log_activations=batch_idx.is_last):
            output, *_ = self.model(data)

        loss = self.loss_func(output, target)
        self.accuracy(output, target)
        self.accuracy.track()
        tracker.add("loss.", loss)

        if self.mode.is_train:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           max_norm=1.)
            self.optimizer.step()
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()

    def sample(self):
        """
        Sampling function to generate samples periodically while training
        """
        prompt = self.prompt
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output, *_ = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]],
                     Text.value)]

        logger.log(log)
Example #12
class Configs(BaseConfigs):
    """## Configurations"""

    # `DeviceConfigs` will pick a GPU if available
    device: torch.device = DeviceConfigs()

    # Hyper-parameters
    epochs: int = 200
    dataset_name: str = 'monet2photo'
    batch_size: int = 1

    data_loader_workers = 8

    learning_rate = 0.0002
    adam_betas = (0.5, 0.999)
    decay_start = 100

    # The paper suggests using a least-squares loss instead of
    # negative log-likelihood, as it is found to be more stable.
    gan_loss = torch.nn.MSELoss()

    # L1 loss is used for cycle loss and identity loss
    cycle_loss = torch.nn.L1Loss()
    identity_loss = torch.nn.L1Loss()

    # Image dimensions
    img_height = 256
    img_width = 256
    img_channels = 3

    # Number of residual blocks in the generator
    n_residual_blocks = 9

    # Loss coefficients
    cyclic_loss_coefficient = 10.0
    identity_loss_coefficient = 5.

    sample_interval = 500

    # Models
    generator_xy: GeneratorResNet
    generator_yx: GeneratorResNet
    discriminator_x: Discriminator
    discriminator_y: Discriminator

    # Optimizers
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam

    # Learning rate schedules
    generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
    discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR

    # Data loaders
    dataloader: DataLoader
    valid_dataloader: DataLoader

    def sample_images(self, n: int):
        """Generate samples from the test set and plot them"""
        batch = next(iter(self.valid_dataloader))
        self.generator_xy.eval()
        self.generator_yx.eval()
        with torch.no_grad():
            data_x, data_y = batch['x'].to(
                self.generator_xy.device), batch['y'].to(
                    self.generator_yx.device)
            gen_y = self.generator_xy(data_x)
            gen_x = self.generator_yx(data_y)

            # Arrange images along x-axis
            data_x = make_grid(data_x, nrow=5, normalize=True)
            data_y = make_grid(data_y, nrow=5, normalize=True)
            gen_x = make_grid(gen_x, nrow=5, normalize=True)
            gen_y = make_grid(gen_y, nrow=5, normalize=True)

            # Arrange images along y-axis
            image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)

        # Show samples
        plot_image(image_grid)

    def initialize(self):
        """
        ## Initialize models and data loaders
        """
        input_shape = (self.img_channels, self.img_height, self.img_width)

        # Create the models
        self.generator_xy = GeneratorResNet(
            self.img_channels, self.n_residual_blocks).to(self.device)
        self.generator_yx = GeneratorResNet(
            self.img_channels, self.n_residual_blocks).to(self.device)
        self.discriminator_x = Discriminator(input_shape).to(self.device)
        self.discriminator_y = Discriminator(input_shape).to(self.device)

        # Create the optimizers
        self.generator_optimizer = torch.optim.Adam(itertools.chain(
            self.generator_xy.parameters(), self.generator_yx.parameters()),
                                                    lr=self.learning_rate,
                                                    betas=self.adam_betas)
        self.discriminator_optimizer = torch.optim.Adam(itertools.chain(
            self.discriminator_x.parameters(),
            self.discriminator_y.parameters()),
                                                        lr=self.learning_rate,
                                                        betas=self.adam_betas)

        # Create the learning rate schedules.
        # The learning rate stays flat until `decay_start` epochs,
        # and then linearly decays to $0$ at the end of training.
        decay_epochs = self.epochs - self.decay_start
        self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.generator_optimizer,
            lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start
                                          ) / decay_epochs)
        self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.discriminator_optimizer,
            lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start
                                          ) / decay_epochs)

        # Image transformations
        transforms_ = [
            transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
            transforms.RandomCrop((self.img_height, self.img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

        # Training data loader
        self.dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, 'train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

        # Validation data loader
        self.valid_dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, "test"),
            batch_size=5,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

    def run(self):
        """
        ## Training

        We aim to solve:
        $$G^{*}, F^{*} = \arg \min_{G,F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$

        where,
        $G$ translates images from $X \rightarrow Y$,
        $F$ translates images from $Y \rightarrow X$,
        $D_X$ tests if images are from $X$ space,
        $D_Y$ tests if images are from $Y$ space, and
        \begin{align}
        \mathcal{L}(G, F, D_X, D_Y)
            &= \mathcal{L}_{GAN}(G, D_Y, X, Y) \\
            &+ \mathcal{L}_{GAN}(F, D_X, Y, X) \\
            &+ \lambda_1 \mathcal{L}_{cyc}(G, F) \\
            &+ \lambda_2 \mathcal{L}_{identity}(G, F) \\
        \\
        \mathcal{L}_{GAN}(G, F, D_Y, X, Y)
            &= \mathbb{E}_{y \sim p_{data}(y)} \Big[log D_Y(y)\Big] \\
            &+ \mathbb{E}_{x \sim p_{data}(x)} \bigg[log\Big(1 - D_Y(G(x))\Big)\bigg] \\
            &+ \mathbb{E}_{x \sim p_{data}(x)} \Big[log D_X(x)\Big] \\
            &+ \mathbb{E}_{y \sim p_{data}(y)} \bigg[log\Big(1 - D_X(F(y))\Big)\bigg] \\
        \\
        \mathcal{L}_{cyc}(G, F)
            &= \mathbb{E}_{x \sim p_{data}(x)} \Big[\lVert F(G(x)) - x \rVert_1\Big] \\
            &+ \mathbb{E}_{y \sim p_{data}(y)} \Big[\lVert G(F(y)) - y \rVert_1\Big] \\
        \\
        \mathcal{L}_{identity}(G, F)
            &= \mathbb{E}_{x \sim p_{data}(x)} \Big[\lVert F(x) - x \rVert_1\Big] \\
            &+ \mathbb{E}_{y \sim p_{data}(y)} \Big[\lVert G(y) - y \rVert_1\Big] \\
        \end{align}

        $\mathcal{L}_{GAN}$ is the generative adversarial loss from the original
        GAN paper.

        $\mathcal{L}_{cyc}$ is the cyclic loss, where we try to get $F(G(x))$ to be similar to $x$,
        and $G(F(y))$ to be similar to $y$.
        Basically if the two generators (transformations) are applied in series it should give back the
        original image.
        This is the main contribution of this paper.
        It trains the generators to generate an image of the other distribution that is similar to
        the original image.
        Without this loss $G(x)$ could generate anything that's from the distribution of $Y$.
        Now it needs to generate something from the distribution of $Y$ that still has the properties of $x$,
        so that $F(G(x))$ can re-generate something like $x$.

        $\mathcal{L}_{identity}$ is the identity loss.
        This was used to encourage the mapping to preserve color composition between
        the input and the output.

        To solve for $G^{*}$ and $F^{*}$,
        discriminators $D_X$ and $D_Y$ should **ascend** on the gradient,
        \begin{align}
        \nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^m
        &\Bigg[
        \log D_Y\Big(y^{(i)}\Big) \\
        &+ \log \Big(1 - D_Y\Big(G\Big(x^{(i)}\Big)\Big)\Big) \\
        &+ \log D_X\Big(x^{(i)}\Big) \\
        & +\log\Big(1 - D_X\Big(F\Big(y^{(i)}\Big)\Big)\Big)
        \Bigg]
        \end{align}
        That is, they descend on the *negative* log-likelihood loss.

        In order to stabilize the training, the negative log-likelihood objective
        is replaced by a least-squares loss:
        the least-squares error of the discriminator, labelling real images with $1$
        and generated images with $0$.
        So we want to descend on the gradient,
        \begin{align}
        \nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^m
        &\Bigg[
            \bigg(D_Y\Big(y^{(i)}\Big) - 1\bigg)^2 \\
            &+ D_Y\Big(G\Big(x^{(i)}\Big)\Big)^2 \\
            &+ \bigg(D_X\Big(x^{(i)}\Big) - 1\bigg)^2 \\
            &+ D_X\Big(F\Big(y^{(i)}\Big)\Big)^2
        \Bigg]
        \end{align}

        We use least-squares for generators also.
        The generators should *descend* on the gradient,
        \begin{align}
        \nabla_{\theta_{F, G}} \frac{1}{m} \sum_{i=1}^m
        &\Bigg[
            \bigg(D_Y\Big(G\Big(x^{(i)}\Big)\Big) - 1\bigg)^2 \\
            &+ \bigg(D_X\Big(F\Big(y^{(i)}\Big)\Big) - 1\bigg)^2 \\
            &+ \mathcal{L}_{cyc}(G, F)
            + \mathcal{L}_{identity}(G, F)
        \Bigg]
        \end{align}

        We use `generator_xy` for $G$ and `generator_yx` for $F$.
        We use `discriminator_x` for $D_X$ and `discriminator_y` for $D_Y$.
        """

        # Replay buffers to keep generated samples
        gen_x_buffer = ReplayBuffer()
        gen_y_buffer = ReplayBuffer()

        # Loop through epochs
        for epoch in monit.loop(self.epochs):
            # Loop through the dataset
            for i, batch in monit.enum('Train', self.dataloader):
                # Move images to the device
                data_x, data_y = batch['x'].to(self.device), batch['y'].to(
                    self.device)

                # true labels equal to $1$
                true_labels = torch.ones(data_x.size(0),
                                         *self.discriminator_x.output_shape,
                                         device=self.device,
                                         requires_grad=False)
                # false labels equal to $0$
                false_labels = torch.zeros(data_x.size(0),
                                           *self.discriminator_x.output_shape,
                                           device=self.device,
                                           requires_grad=False)

                # Train the generators.
                # This returns the generated images.
                gen_x, gen_y = self.optimize_generators(
                    data_x, data_y, true_labels)

                #  Train discriminators
                self.optimize_discriminator(data_x, data_y,
                                            gen_x_buffer.push_and_pop(gen_x),
                                            gen_y_buffer.push_and_pop(gen_y),
                                            true_labels, false_labels)

                # Save training statistics and increment the global step counter
                tracker.save()
                tracker.add_global_step(max(len(data_x), len(data_y)))

                # Save images at intervals
                batches_done = epoch * len(self.dataloader) + i
                if batches_done % self.sample_interval == 0:
                    # Save models when sampling images
                    experiment.save_checkpoint()
                    # Sample images
                    self.sample_images(batches_done)

            # Update learning rates
            self.generator_lr_scheduler.step()
            self.discriminator_lr_scheduler.step()
            # New line
            tracker.new_line()

    def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor,
                            true_labels: torch.Tensor):
        """
        ### Optimize the generators with identity, gan and cycle losses.
        """

        #  Change to training mode
        self.generator_xy.train()
        self.generator_yx.train()

        # Identity loss
        # $$\lVert F(x^{(i)}) - x^{(i)} \rVert_1 +
        #   \lVert G(y^{(i)}) - y^{(i)} \rVert_1$$
        loss_identity = (
            self.identity_loss(self.generator_yx(data_x), data_x) +
            self.identity_loss(self.generator_xy(data_y), data_y))

        # Generate images $G(x)$ and $F(y)$
        gen_y = self.generator_xy(data_x)
        gen_x = self.generator_yx(data_y)

        # GAN loss
        # $$\bigg(D_Y\Big(G\Big(x^{(i)}\Big)\Big) - 1\bigg)^2
        #  + \bigg(D_X\Big(F\Big(y^{(i)}\Big)\Big) - 1\bigg)^2$$
        loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
                    self.gan_loss(self.discriminator_x(gen_x), true_labels))

        # Cycle loss
        # $$
        # \lVert F(G(x^{(i)})) - x^{(i)} \rVert_1 +
        # \lVert G(F(y^{(i)})) - y^{(i)} \rVert_1
        # $$
        loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
                      self.cycle_loss(self.generator_xy(gen_x), data_y))

        # Total loss
        loss_generator = (loss_gan +
                          self.cyclic_loss_coefficient * loss_cycle +
                          self.identity_loss_coefficient * loss_identity)

        # Take a step in the optimizer
        self.generator_optimizer.zero_grad()
        loss_generator.backward()
        self.generator_optimizer.step()

        # Log losses
        tracker.add({
            'loss.generator': loss_generator,
            'loss.generator.cycle': loss_cycle,
            'loss.generator.gan': loss_gan,
            'loss.generator.identity': loss_identity
        })

        # Return generated images
        return gen_x, gen_y

    def optimize_discriminator(self, data_x: torch.Tensor,
                               data_y: torch.Tensor, gen_x: torch.Tensor,
                               gen_y: torch.Tensor, true_labels: torch.Tensor,
                               false_labels: torch.Tensor):
        """
        ### Optimize the discriminators with gan loss.
        """
        # GAN Loss
        # \begin{align}
        # \bigg(D_Y\Big(y^{(i)}\Big) - 1\bigg)^2
        # + D_Y\Big(G\Big(x^{(i)}\Big)\Big)^2 + \\
        # \bigg(D_X\Big(x^{(i)}\Big) - 1\bigg)^2
        # + D_X\Big(F\Big(y^{(i)}\Big)\Big)^2
        # \end{align}
        loss_discriminator = (
            self.gan_loss(self.discriminator_x(data_x), true_labels) +
            self.gan_loss(self.discriminator_x(gen_x), false_labels) +
            self.gan_loss(self.discriminator_y(data_y), true_labels) +
            self.gan_loss(self.discriminator_y(gen_y), false_labels))

        # Take a step in the optimizer
        self.discriminator_optimizer.zero_grad()
        loss_discriminator.backward()
        self.discriminator_optimizer.step()

        # Log losses
        tracker.add({'loss.discriminator': loss_discriminator})
Example #13
class NLPAutoRegressionConfigs(TrainValidConfigs):
    """
    <a id="NLPAutoRegressionConfigs"></a>

    ## Trainer configurations

    This has the basic configurations for NLP auto-regressive task training.
    All the properties are configurable.
    """

    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Autoregressive model
    model: Module
    # Text dataset
    text: TextDataset
    # Batch size
    batch_size: int = 16
    # Length of the sequence, or context size
    seq_len: int = 512
    # Number of tokens in the vocabulary
    n_tokens: int
    # Tokenizer
    tokenizer: Callable = 'character'

    # Text prompt to start sampling (for illustration)
    prompt: str
    # The token separator when sampling (blank for character level tokenization)
    prompt_separator: str

    # Whether to periodically save models
    is_save_models = True

    # Loss function
    loss_func = CrossEntropyLoss()
    # Accuracy function
    accuracy = Accuracy()
    # Model embedding size
    d_model: int = 512
    # Gradient clipping
    grad_norm_clip: float = 1.0

    # Training data loader
    train_loader: DataLoader = 'shuffled_train_loader'
    # Validation data loader
    valid_loader: DataLoader = 'shuffled_valid_loader'

    # Data loaders shuffle with replacement
    dataloader_shuffle_with_replacement: bool = False

    # Whether to log model parameters and gradients (once per epoch).
    # These are summarized stats per layer, but it could still lead
    # to many indicators for very deep networks.
    is_log_model_params_grads: bool = False

    # Whether to log model activations (once per epoch).
    # These are summarized stats per layer, but it could still lead
    # to many indicators for very deep networks.
    is_log_model_activations: bool = False

    def init(self):
        """
        ### Initialization
        """
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Add accuracy as a state module.
        # The name is probably confusing, since it's meant to store
        # states between training and validation for RNNs.
        # This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        """Override to calculate and log other metrics"""
        pass

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training or validation step
        """

        # Set training/eval mode
        self.model.train(self.mode.is_train)

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last
                              and self.is_log_model_activations):
            # Get model outputs.
            # It's returning a tuple for states when using RNNs.
            # This is not implemented yet. 😜
            output, *_ = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        self.other_metrics(output, target)

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()

    def sample(self):
        """
        ### Sampling function to generate samples periodically while training
        """

        # Starting prompt
        prompt = self.prompt
        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output, *_ = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]],
                     Text.value)]

        # Print the sampled output
        logger.log(log)
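
Properties such as `model` and `optimizer` are declared above without defaults; they are
typically supplied as calculated config options. A minimal sketch of that pattern, assuming
`NLPAutoRegressionConfigs` is importable; `TinyLM` and `tiny_lm` are illustrative stand-ins,
not part of the original class:

import torch.nn as nn

from labml.configs import option


class TinyLM(nn.Module):
    """An illustrative stand-in for an autoregressive model"""
    def __init__(self, n_tokens: int, d_model: int):
        super().__init__()
        self.embed = nn.Embedding(n_tokens, d_model)
        self.rnn = nn.LSTM(d_model, d_model)
        self.out = nn.Linear(d_model, n_tokens)

    def forward(self, x):
        x, state = self.rnn(self.embed(x))
        # Return a tuple so that `output, *_ = self.model(data)` in `step` works unchanged
        return self.out(x), state


@option(NLPAutoRegressionConfigs.model)
def tiny_lm(c: NLPAutoRegressionConfigs):
    # Calculated config: evaluated lazily, and can read other config values
    return TinyLM(c.n_tokens, c.d_model).to(c.device)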
Example #14
class Configs(TrainValidConfigs):
    optimizer: torch.optim.Adam
    device: torch.device = DeviceConfigs()

    model: Module
    text = SourceCodeDataConfigs()
    n_tokens: int
    n_layers: int = 2
    dropout: float = 0.2
    d_model: int = 512
    rnn_size: int = 512
    rhn_depth: int = 1
    inner_iterations = 100

    is_save_models = True

    transformer: TransformerConfigs

    accuracy = Accuracy()
    loss_func: 'CrossEntropyLoss'

    state_updater: 'StateUpdater'
    state = SimpleStateModule()
    mem_len: int = 512
    grad_norm_clip: float = 1.0
    is_token_by_token: bool = False

    def init(self):
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        hook_model_outputs(self.mode, self.model, 'model')
        self.state_modules = [self.accuracy, self.state]

    def step(self, batch: any, batch_idx: BatchIndex):
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(target.shape[0] * target.shape[1])

        with self.mode.update(is_log_activations=batch_idx.is_last):
            state = self.state.get()
            output, new_state = self.model(data, state)
            state = self.state_updater(state, new_state)
            self.state.set(state)

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        self.accuracy(output, target)
        self.accuracy.track()

        if self.mode.is_train:
            loss.backward()

            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           max_norm=self.grad_norm_clip)
            self.optimizer.step()
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()

    def sample(self):
        prompt = 'def train('
        log = [(prompt, Text.subtle)]
        state = None
        for i in monit.iterate('Sample', 25):
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            output, new_state = self.model(data, state)
            output = output.argmax(dim=-1).squeeze(1)
            if self.is_token_by_token:
                prompt = self.text.tokenizer.itos[output[-1]]
            else:
                prompt += '' + self.text.tokenizer.itos[output[-1]]
            log += [('' + self.text.tokenizer.itos[output[-1]], Text.value)]
            state = self.state_updater(state, new_state)

        logger.log(log)
Example #15
class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
    """
    <a id="MNISTConfigs">
    ## Trainer configurations
    </a>
    """

    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Classification model
    model: Module
    # Number of epochs to train for
    epochs: int = 10

    # Number of times to switch between training and validation within an epoch
    inner_iterations = 10

    # Accuracy function
    accuracy = Accuracy()
    # Loss function
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        """
        ### Initialization
        """
        # Set tracker configurations
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Add accuracy as a state module.
        # The name is probably confusing, since it's meant to store
        # states between training and validation for RNNs.
        # This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training or validation step
        """

        # Training/Evaluation mode
        self.model.train(self.mode.is_train)

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get model outputs.
            output = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()
Example #16
class Configs(TrainValidConfigs):
    """
    ## Configurations

    These are default configurations which can later be adjusted by passing a `dict`.
    """

    # Device configurations to pick the device to run the experiment
    device: torch.device = DeviceConfigs()
    #
    encoder: EncoderRNN
    decoder: DecoderRNN
    optimizer: optim.Adam
    sampler: Sampler

    dataset_name: str
    train_loader: DataLoader
    valid_loader: DataLoader
    train_dataset: StrokesDataset
    valid_dataset: StrokesDataset

    # Encoder and decoder sizes
    enc_hidden_size = 256
    dec_hidden_size = 512

    # Batch size
    batch_size = 100

    # Number of features in $z$
    d_z = 128
    # Number of distributions in the mixture, $M$
    n_distributions = 20

    # Weight of KL divergence loss, $w_{KL}$
    kl_div_loss_weight = 0.5
    # Gradient clipping
    grad_clip = 1.
    # Temperature $\tau$ for sampling
    temperature = 0.4

    # Filter out stroke sequences longer than $200$
    max_seq_length = 200

    epochs = 100

    kl_div_loss = KLDivLoss()
    reconstruction_loss = ReconstructionLoss()

    def init(self):
        # Initialize encoder & decoder
        self.encoder = EncoderRNN(self.d_z,
                                  self.enc_hidden_size).to(self.device)
        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size,
                                  self.n_distributions).to(self.device)

        # Set optimizer. Things like type of optimizer and learning rate are configurable
        optimizer = OptimizerConfigs()
        optimizer.parameters = list(self.encoder.parameters()) + list(
            self.decoder.parameters())
        self.optimizer = optimizer
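        # With `OptimizerConfigs`, the optimizer type and hyper-parameters can then be
        # chosen through experiment overrides such as
        # `{'optimizer.optimizer': 'Adam', 'optimizer.learning_rate': 1e-3}`
        # (the values here are illustrative).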

        # Create sampler
        self.sampler = Sampler(self.encoder, self.decoder)

        # `npz` file path is `data/sketch/[DATASET NAME].npz`
        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
        # Load the numpy file
        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)

        # Create training dataset
        self.train_dataset = StrokesDataset(dataset['train'],
                                            self.max_seq_length)
        # Create validation dataset
        self.valid_dataset = StrokesDataset(dataset['valid'],
                                            self.max_seq_length,
                                            self.train_dataset.scale)

        # Create training data loader
        self.train_loader = DataLoader(self.train_dataset,
                                       self.batch_size,
                                       shuffle=True)
        # Create validation data loader
        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)

        # Add hooks to monitor layer outputs on Tensorboard
        hook_model_outputs(self.mode, self.encoder, 'encoder')
        hook_model_outputs(self.mode, self.decoder, 'decoder')

        # Configure the tracker to print the total train/validation loss
        tracker.set_scalar("loss.total.*", True)

        self.state_modules = []

    def step(self, batch: Any, batch_idx: BatchIndex):
        self.encoder.train(self.mode.is_train)
        self.decoder.train(self.mode.is_train)

        # Move `data` and `mask` to device and swap the sequence and batch dimensions.
        # `data` will have shape `[seq_len, batch_size, 5]` and
        # `mask` will have shape `[seq_len, batch_size]`.
        data = batch[0].to(self.device).transpose(0, 1)
        mask = batch[1].to(self.device).transpose(0, 1)

        # Increment step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Encode the sequence of strokes
        with monit.section("encoder"):
            # Get $z$, $\mu$, and $\hat{\sigma}$
            z, mu, sigma_hat = self.encoder(data)

        # Decode the mixture of distributions and $\hat{q}$
        with monit.section("decoder"):
            # Concatenate $[(\Delta x, \Delta y, p_1, p_2, p_3); z]$
            z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
            inputs = torch.cat([data[:-1], z_stack], 2)
            # Get mixture of distributions and $\hat{q}$
            dist, q_logits, _ = self.decoder(inputs, z, None)

        # Compute the loss
        with monit.section('loss'):
            # $L_{KL}$
            kl_loss = self.kl_div_loss(sigma_hat, mu)
            # $L_R$
            reconstruction_loss = self.reconstruction_loss(
                mask, data[1:], dist, q_logits)
            # $Loss = L_R + w_{KL} L_{KL}$
            loss = reconstruction_loss + self.kl_div_loss_weight * kl_loss

            # Track losses
            tracker.add("loss.kl.", kl_loss)
            tracker.add("loss.reconstruction.", reconstruction_loss)
            tracker.add("loss.total.", loss)

        # Only if we are in training state
        if self.mode.is_train:
            # Run optimizer
            with monit.section('optimize'):
                # Set `grad` to zero
                self.optimizer.zero_grad()
                # Compute gradients
                loss.backward()
                # Log model parameters and gradients
                if batch_idx.is_last:
                    tracker.add(encoder=self.encoder, decoder=self.decoder)
                # Clip gradients
                nn.utils.clip_grad_norm_(self.encoder.parameters(),
                                         self.grad_clip)
                nn.utils.clip_grad_norm_(self.decoder.parameters(),
                                         self.grad_clip)
                # Optimize
                self.optimizer.step()

        tracker.save()

    def sample(self):
        # Randomly pick a sample from the validation dataset to encode
        data, *_ = self.valid_dataset[np.random.choice(len(
            self.valid_dataset))]
        # Add batch dimension and move it to device
        data = data.unsqueeze(1).to(self.device)
        # Sample
        self.sampler.sample(data, self.temperature)
Example #17
class Configs(BaseConfigs):
    """
    ## Configurations
    """
    # Device to train the model on.
    # [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
    #  picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # U-Net model for $\color{cyan}{\epsilon_\theta}(x_t, t)$
    eps_model: UNet
    # [DDPM algorithm](index.html)
    diffusion: DenoiseDiffusion

    # Number of channels in the image. $3$ for RGB.
    image_channels: int = 3
    # Image size
    image_size: int = 32
    # Number of channels in the initial feature map
    n_channels: int = 64
    # The list of channel numbers at each resolution.
    # The number of channels is `channel_multipliers[i] * n_channels`
    channel_multipliers: List[int] = [1, 2, 2, 4]
    # The list of booleans that indicate whether to use attention at each resolution
    is_attention: List[bool] = [False, False, False, True]

    # Number of time steps $T$
    n_steps: int = 1_000
    # Batch size
    batch_size: int = 64
    # Number of samples to generate
    n_samples: int = 16
    # Learning rate
    learning_rate: float = 2e-5

    # Number of training epochs
    epochs: int = 1_000

    # Dataset
    dataset: torch.utils.data.Dataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        # Create $\color{cyan}{\epsilon_\theta}(x_t, t)$ model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)

        # Create [DDPM class](index.html)
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

        # Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset,
                                                       self.batch_size,
                                                       shuffle=True,
                                                       pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(),
                                          lr=self.learning_rate)

        # Image logging
        tracker.set_image("sample", True)

    def sample(self):
        """
        ### Sample images
        """
        with torch.no_grad():
            # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
            x = torch.randn([
                self.n_samples, self.image_channels, self.image_size,
                self.image_size
            ],
                            device=self.device)

            # Remove noise for $T$ steps
            for t_ in monit.iterate('Sample', self.n_steps):
                # $t$
                t = self.n_steps - t_ - 1
                # Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$
                x = self.diffusion.p_sample(
                    x, x.new_full((self.n_samples, ), t, dtype=torch.long))

            # Log samples
            tracker.save('sample', x)

    def train(self):
        """
        ### Train
        """

        # Iterate through the dataset
        for data in monit.iterate('Train', self.data_loader):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            data = data.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Calculate loss
            loss = self.diffusion.loss(data)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop
        """
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # Sample some images
            self.sample()
            # New line in the console
            tracker.new_line()
            # Save the model
            experiment.save_checkpoint()
Example #18
class Configs(TrainValidConfigs):
    optimizer: torch.optim.Adam
    device: torch.device = DeviceConfigs()

    model: Module
    text: TextDataset
    batch_size: int = 16
    seq_len: int = 512
    n_tokens: int
    n_layers: int = 2
    dropout: float = 0.2
    d_model: int = 512
    rnn_size: int = 512
    rhn_depth: int = 1
    tokenizer: Callable
    inner_iterations = 100

    is_save_models = True

    transformer: TransformerConfigs

    accuracy_func = Accuracy()
    loss_func: 'CrossEntropyLoss'

    def init(self):
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        hook_model_outputs(self.mode, self.model, 'model')
        self.state_modules = [self.accuracy_func]

    def step(self, batch: any, batch_idx: BatchIndex):
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(len(data))

        with self.mode.update(is_log_activations=batch_idx.is_last):
            output, *_ = self.model(data)

        loss = self.loss_func(output, target)
        self.accuracy_func(output, target)
        self.accuracy_func.track()
        tracker.add("loss.", loss)

        if self.mode.is_train:
            loss.backward()

            self.optimizer.step()
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()

    def sample(self):
        prompt = 'def train('
        log = [(prompt, Text.subtle)]
        for i in monit.iterate('Sample', 25):
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            output, *_ = self.model(data)
            output = output.argmax(dim=-1).squeeze()
            prompt += '' + self.text.itos[output[-1]]
            log += [('' + self.text.itos[output[-1]], Text.value)]

        logger.log(log)
Example #19
class Configs(BaseConfigs):
    """
    ## Configurations
    """

    # Model
    model: GAT
    # Number of nodes to train on
    training_samples: int = 500
    # Number of features per node in the input
    in_features: int
    # Number of features in the first graph attention layer
    n_hidden: int = 64
    # Number of heads
    n_heads: int = 8
    # Number of classes for classification
    n_classes: int
    # Dropout probability
    dropout: float = 0.6
    # Whether to include the citation network
    include_edges: bool = True
    # Dataset
    dataset: CoraDataset
    # Number of training iterations
    epochs: int = 1_000
    # Loss function
    loss_func = nn.CrossEntropyLoss()
    # Device to train on
    #
    # This creates configs for device, so that
    # we can change the device by passing a config value
    device: torch.device = DeviceConfigs()
    # Optimizer
    optimizer: torch.optim.Adam

    def run(self):
        """
        ### Training loop

        We do full batch training since the dataset is small.
        If we were to sample and train, we would have to sample a set of
        nodes for each training step, along with the edges that span
        those selected nodes.
        """
        # Move the feature vectors to the device
        features = self.dataset.features.to(self.device)
        # Move the labels to the device
        labels = self.dataset.labels.to(self.device)
        # Move the adjacency matrix to the device
        edges_adj = self.dataset.adj_mat.to(self.device)
        # Add an empty third dimension for the heads
        edges_adj = edges_adj.unsqueeze(-1)

        # Random indexes
        idx_rand = torch.randperm(len(labels))
        # Nodes for training
        idx_train = idx_rand[:self.training_samples]
        # Nodes for validation
        idx_valid = idx_rand[self.training_samples:]

        # Training loop
        for epoch in monit.loop(self.epochs):
            # Set the model to training mode
            self.model.train()
            # Make all the gradients zero
            self.optimizer.zero_grad()
            # Evaluate the model
            output = self.model(features, edges_adj)
            # Get the loss for training nodes
            loss = self.loss_func(output[idx_train], labels[idx_train])
            # Calculate gradients
            loss.backward()
            # Take optimization step
            self.optimizer.step()
            # Log the loss
            tracker.add('loss.train', loss)
            # Log the accuracy
            tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))

            # Set mode to evaluation mode for validation
            self.model.eval()

            # No need to compute gradients
            with torch.no_grad():
                # Evaluate the model again
                output = self.model(features, edges_adj)
                # Calculate the loss for validation nodes
                loss = self.loss_func(output[idx_valid], labels[idx_valid])
                # Log the loss
                tracker.add('loss.valid', loss)
                # Log the accuracy
                tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))

            # Save logs
            tracker.save()
Example #20
class NLPClassificationConfigs(TrainValidConfigs):
    """
    <a id="NLPClassificationConfigs"></a>

    ## Trainer configurations

    This has the basic configurations for NLP classification task training.
    All the properties are configurable.
    """

    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Autoregressive model
    model: Module
    # Batch size
    batch_size: int = 16
    # Length of the sequence, or context size
    seq_len: int = 512
    # Vocabulary
    vocab: Vocab = 'ag_news'
    # Number of tokens in the vocabulary
    n_tokens: int
    # Number of classes
    n_classes: int = 'ag_news'
    # Tokenizer
    tokenizer: Callable = 'character'
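    # String defaults such as `'ag_news'` and `'character'` name calculated config
    # options (registered with `@option`) that labml resolves when the experiment starts.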

    # Whether to periodically save models
    is_save_models = True

    # Loss function
    loss_func = nn.CrossEntropyLoss()
    # Accuracy function
    accuracy = Accuracy()
    # Model embedding size
    d_model: int = 512
    # Gradient clipping
    grad_norm_clip: float = 1.0

    # Training data loader
    train_loader: DataLoader = 'ag_news'
    # Validation data loader
    valid_loader: DataLoader = 'ag_news'

    # Whether to log model parameters and gradients (once per epoch).
    # These are summarized stats per layer, but it could still lead
    # to many indicators for very deep networks.
    is_log_model_params_grads: bool = False

    # Whether to log model activations (once per epoch).
    # These are summarized stats per layer, but it could still lead
    # to many indicators for very deep networks.
    is_log_model_activations: bool = False

    def init(self):
        """
        ### Initialization
        """
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Add accuracy as a state module.
        # The name is probably confusing, since it's meant to store
        # states between training and validation for RNNs.
        # This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training or validation step
        """

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last
                              and self.is_log_model_activations):
            # Get model outputs.
            # It's returning a tuple for states when using RNNs.
            # This is not implemented yet. 😜
            output, *_ = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()
Example #21
class Configs(BaseConfigs):
    """
    ## Configurations
    """

    # Device to train the model on.
    # [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
    #  picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # [StyleGAN2 Discriminator](index.html#discriminator)
    discriminator: Discriminator
    # [StyleGAN2 Generator](index.html#generator)
    generator: Generator
    # [Mapping network](index.html#mapping_network)
    mapping_network: MappingNetwork

    # Discriminator and generator loss functions.
    # We use [Wasserstein loss](../wasserstein/index.html)
    discriminator_loss: DiscriminatorLoss
    generator_loss: GeneratorLoss

    # Optimizers
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    mapping_network_optimizer: torch.optim.Adam

    # [Gradient Penalty Regularization Loss](index.html#gradient_penalty)
    gradient_penalty = GradientPenalty()
    # Gradient penalty coefficient $\gamma$
    gradient_penalty_coefficient: float = 10.

    # [Path length penalty](index.html#path_length_penalty)
    path_length_penalty: PathLengthPenalty

    # Data loader
    loader: Iterator

    # Batch size
    batch_size: int = 32
    # Dimensionality of $z$ and $w$
    d_latent: int = 512
    # Height/width of the image
    image_size: int = 32
    # Number of layers in the mapping network
    mapping_network_layers: int = 8
    # Generator & Discriminator learning rate
    learning_rate: float = 1e-3
    # Mapping network learning rate ($100 \times$ lower than the others)
    mapping_network_learning_rate: float = 1e-5
    # Number of steps to accumulate gradients on. Use this to increase the effective batch size.
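    # For example, `batch_size = 32` with `gradient_accumulate_steps = 4` gives an effective batch size of 128.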
    gradient_accumulate_steps: int = 1
    # $\beta_1$ and $\beta_2$ for Adam optimizer
    adam_betas: Tuple[float, float] = (0.0, 0.99)
    # Probability of mixing styles
    style_mixing_prob: float = 0.9

    # Total number of training steps
    training_steps: int = 150_000

    # Number of blocks in the generator (calculated based on image resolution)
    n_gen_blocks: int

    # ### Lazy regularization
    # Instead of computing the regularization losses at every training step,
    # the paper proposes lazy regularization, where the regularization terms
    # are computed only once every few steps. This improves training efficiency considerably.

    # The interval at which to compute gradient penalty
    lazy_gradient_penalty_interval: int = 4
    # Path length penalty calculation interval
    lazy_path_penalty_interval: int = 32
    # Skip calculating path length penalty during the initial phase of training
    lazy_path_penalty_after: int = 5_000
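    # With these defaults, the gradient penalty is computed on one out of every 4 training steps,
    # and the path length penalty on one out of every 32 steps once the first 5,000 steps have passed.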

    # How often to log generated images
    log_generated_interval: int = 500
    # How often to save model checkpoints
    save_checkpoint_interval: int = 2_000

    # Training mode state for logging activations
    mode: ModeState
    # Whether to log model layer outputs
    log_layer_outputs: bool = False

    # <a id="dataset_path"></a>
    # We trained this on [CelebA-HQ dataset](https://github.com/tkarras/progressive_growing_of_gans).
    # You can find the download instruction in this
    # [discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
    # Save the images inside `data/stylegan` folder.
    dataset_path: str = str(lab.get_data_path() / 'stylegan2')

    def init(self):
        """
        ### Initialize
        """
        # Create dataset
        dataset = Dataset(self.dataset_path, self.image_size)
        # Create data loader
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=self.batch_size,
                                                 num_workers=8,
                                                 shuffle=True,
                                                 drop_last=True,
                                                 pin_memory=True)
        # Continuous [cyclic loader](../../utils.html#cycle_dataloader)
        self.loader = cycle_dataloader(dataloader)

        # $\log_2$ of image resolution
        log_resolution = int(math.log2(self.image_size))

        # Create discriminator and generator
        self.discriminator = Discriminator(log_resolution).to(self.device)
        self.generator = Generator(log_resolution,
                                   self.d_latent).to(self.device)
        # Get number of generator blocks for creating style and noise inputs
        self.n_gen_blocks = self.generator.n_blocks
        # Create mapping network
        self.mapping_network = MappingNetwork(
            self.d_latent, self.mapping_network_layers).to(self.device)
        # Create path length penalty loss
        self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)

        # Add model hooks to monitor layer outputs
        if self.log_layer_outputs:
            hook_model_outputs(self.mode, self.discriminator, 'discriminator')
            hook_model_outputs(self.mode, self.generator, 'generator')
            hook_model_outputs(self.mode, self.mapping_network,
                               'mapping_network')

        # Discriminator and generator losses
        self.discriminator_loss = DiscriminatorLoss().to(self.device)
        self.generator_loss = GeneratorLoss().to(self.device)

        # Create optimizers
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=self.learning_rate,
            betas=self.adam_betas)
        self.generator_optimizer = torch.optim.Adam(
            self.generator.parameters(),
            lr=self.learning_rate,
            betas=self.adam_betas)
        self.mapping_network_optimizer = torch.optim.Adam(
            self.mapping_network.parameters(),
            lr=self.mapping_network_learning_rate,
            betas=self.adam_betas)

        # Set tracker configurations
        tracker.set_image("generated", True)

    def get_w(self, batch_size: int):
        """
        ### Sample $w$

        This samples $z$ randomly and gets $w$ from the mapping network.

        With probability `style_mixing_prob` we also apply style mixing: we generate
        two latent variables $z_1$ and $z_2$ and get the corresponding $w_1$ and $w_2$.
        Then we randomly sample a cross-over point and apply $w_1$ to
        the generator blocks before the cross-over point and
        $w_2$ to the blocks after it.
        """

        # Mix styles
        if torch.rand(()).item() < self.style_mixing_prob:
            # Random cross-over point
            cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)
            # Sample $z_1$ and $z_2$
            z2 = torch.randn(batch_size, self.d_latent).to(self.device)
            z1 = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w_1$ and $w_2$
            w1 = self.mapping_network(z1)
            w2 = self.mapping_network(z2)
            # Expand $w_1$ and $w_2$ for the generator blocks and concatenate
            w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
            w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point,
                                       -1, -1)
            return torch.cat((w1, w2), dim=0)
        # Without mixing
        else:
            # Sample $z$
            z = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w$
            w = self.mapping_network(z)
            # Expand $w$ for the generator blocks
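            # (in both branches the result has shape `[n_gen_blocks, batch_size, d_latent]`)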
            return w[None, :, :].expand(self.n_gen_blocks, -1, -1)

    def get_noise(self, batch_size: int):
        """
        ### Generate noise

        This generates noise for each [generator block](index.html#generator_block)
        """
        # List to store noise
        noise = []
        # Noise resolution starts from $4$
        resolution = 4
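        # (e.g. for a $32 \times 32$ image the blocks get noise at resolutions $4, 8, 16, 32$)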

        # Generate noise for each generator block
        for i in range(self.n_gen_blocks):
            # The first block has only one $3 \times 3$ convolution
            if i == 0:
                n1 = None
            # Generate noise to add after the first convolution layer
            else:
                n1 = torch.randn(batch_size,
                                 1,
                                 resolution,
                                 resolution,
                                 device=self.device)
            # Generate noise to add after the second convolution layer
            n2 = torch.randn(batch_size,
                             1,
                             resolution,
                             resolution,
                             device=self.device)

            # Add noise tensors to the list
            noise.append((n1, n2))

            # Next block has $2 \times$ resolution
            resolution *= 2

        # Return noise tensors
        return noise

    def generate_images(self, batch_size: int):
        """
        ### Generate images

        This generates images using the generator
        """

        # Get $w$
        w = self.get_w(batch_size)
        # Get noise
        noise = self.get_noise(batch_size)

        # Generate images
        images = self.generator(w, noise)

        # Return images and $w$
        return images, w

    def step(self, idx: int):
        """
        ### Training Step
        """

        # Train the discriminator
        with monit.section('Discriminator'):
            # Reset gradients
            self.discriminator_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Update `mode`. Set whether to log activations
                with self.mode.update(is_log_activations=(idx + 1) %
                                      self.log_generated_interval == 0):
                    # Sample images from generator
                    generated_images, _ = self.generate_images(self.batch_size)
                    # Discriminator classification for generated images
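                    # `detach()` keeps the generator out of this backward pass, so only the discriminator gets gradients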
                    fake_output = self.discriminator(generated_images.detach())

                    # Get real images from the data loader
                    real_images = next(self.loader).to(self.device)
                    # We need to calculate gradients w.r.t. real images for gradient penalty
                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                        real_images.requires_grad_()
                    # Discriminator classification for real images
                    real_output = self.discriminator(real_images)

                    # Get discriminator loss
                    real_loss, fake_loss = self.discriminator_loss(
                        real_output, fake_output)
                    disc_loss = real_loss + fake_loss

                    # Add gradient penalty
                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                        # Calculate and log gradient penalty
                        gp = self.gradient_penalty(real_images, real_output)
                        tracker.add('loss.gp', gp)
                        # Multiply by coefficient and add gradient penalty
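                        # The $\frac{\gamma}{2}$ factor comes from the $R_1$ penalty formulation, and scaling by
                        # `lazy_gradient_penalty_interval` compensates for computing the penalty only once per interval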
                        disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval

                    # Compute gradients
                    disc_loss.backward()

                    # Log discriminator loss
                    tracker.add('loss.discriminator', disc_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log discriminator model parameters occasionally
                tracker.add('discriminator', self.discriminator)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(),
                                           max_norm=1.0)
            # Take optimizer step
            self.discriminator_optimizer.step()

        # Train the generator
        with monit.section('Generator'):
            # Reset gradients
            self.generator_optimizer.zero_grad()
            self.mapping_network_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Sample images from generator
                generated_images, w = self.generate_images(self.batch_size)
                # Discriminator classification for generated images
                fake_output = self.discriminator(generated_images)

                # Get generator loss
                gen_loss = self.generator_loss(fake_output)

                # Add path length penalty
                if idx > self.lazy_path_penalty_after and (
                        idx + 1) % self.lazy_path_penalty_interval == 0:
                    # Calculate path length penalty
                    plp = self.path_length_penalty(w, generated_images)
                    # Ignore if `nan`
                    if not torch.isnan(plp):
                        tracker.add('loss.plp', plp)
                        gen_loss = gen_loss + plp

                # Calculate gradients
                gen_loss.backward()

                # Log generator loss
                tracker.add('loss.generator', gen_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log generator and mapping network parameters occasionally
                tracker.add('generator', self.generator)
                tracker.add('mapping_network', self.mapping_network)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.generator.parameters(),
                                           max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(),
                                           max_norm=1.0)

            # Take optimizer step
            self.generator_optimizer.step()
            self.mapping_network_optimizer.step()

        # Log generated images
        if (idx + 1) % self.log_generated_interval == 0:
            tracker.add(
                'generated',
                torch.cat([generated_images[:6], real_images[:3]], dim=0))
        # Save model checkpoints
        if (idx + 1) % self.save_checkpoint_interval == 0:
            experiment.save_checkpoint()

        # Flush tracker
        tracker.save()

    def train(self):
        """
        ## Train model
        """

        # Loop for `training_steps`
        for i in monit.loop(self.training_steps):
            # Take a training step
            self.step(i)
            #
            if (i + 1) % self.log_generated_interval == 0:
                tracker.new_line()
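
# A minimal sketch (an assumption, not from the original source) of how this StyleGAN2
# `Configs` class could be launched with the `labml` experiment API. The experiment name,
# the override value and the checkpointed model names are illustrative.
def main():
    from labml import experiment

    # Create the experiment (name is hypothetical)
    experiment.create(name='stylegan2')
    configs = Configs()
    # Override defaults declared above if needed (value shown is illustrative)
    experiment.configs(configs, {'image_size': 64})
    # Build the dataset, models, losses and optimizers
    configs.init()
    # Register the models so that checkpoints include them
    experiment.add_pytorch_models({'generator': configs.generator,
                                   'discriminator': configs.discriminator,
                                   'mapping_network': configs.mapping_network})
    # Start the experiment and train
    with experiment.start():
        configs.train()
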
Пример #22
class Configs(MNISTConfigs, TrainValidConfigs):
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    discriminator: Module
    generator: Module
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    generator_loss: GeneratorLogitsLoss
    discriminator_loss: DiscriminatorLogitsLoss
    label_smoothing: float = 0.2
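    # Number of discriminator updates per generator update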
    discriminator_k: int = 1

    def init(self):
        self.state_modules = []
        self.generator = Generator().to(self.device)
        self.discriminator = Discriminator().to(self.device)
        self.generator_loss = GeneratorLogitsLoss(self.label_smoothing).to(self.device)
        self.discriminator_loss = DiscriminatorLogitsLoss(self.label_smoothing).to(self.device)

        hook_model_outputs(self.mode, self.generator, 'generator')
        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
        tracker.set_scalar("loss.generator.*", True)
        tracker.set_scalar("loss.discriminator.*", True)
        tracker.set_image("generated", True, 1 / 100)

    def step(self, batch: Any, batch_idx: BatchIndex):
        self.generator.train(self.mode.is_train)
        self.discriminator.train(self.mode.is_train)

        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Increment step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Train the discriminator
        with monit.section("discriminator"):
            for _ in range(self.discriminator_k):
                latent = torch.randn(data.shape[0], 100, device=self.device)
                logits_true = self.discriminator(data)
                logits_false = self.discriminator(self.generator(latent).detach())
                loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
                loss = loss_true + loss_false

                # Log the discriminator losses
                tracker.add("loss.discriminator.true.", loss_true)
                tracker.add("loss.discriminator.false.", loss_false)
                tracker.add("loss.discriminator.", loss)

                # Train
                if self.mode.is_train:
                    self.discriminator_optimizer.zero_grad()
                    loss.backward()
                    if batch_idx.is_last:
                        tracker.add('discriminator', self.discriminator)
                    self.discriminator_optimizer.step()

        # Train the generator
        with monit.section("generator"):
            latent = torch.randn(data.shape[0], 100, device=self.device)
            generated_images = self.generator(latent)
            logits = self.discriminator(generated_images)
            loss = self.generator_loss(logits)

            # Log the generated image samples and the generator loss
            tracker.add('generated', generated_images[0:5])
            tracker.add("loss.generator.", loss)

            # Train
            if self.mode.is_train:
                self.generator_optimizer.zero_grad()
                loss.backward()
                if batch_idx.is_last:
                    tracker.add('generator', self.generator)
                self.generator_optimizer.step()

        tracker.save()
Пример #23
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    ## Configurations

    This extends the MNIST configurations to get the data loaders, and the
    training and validation loop configurations, to simplify our implementation.
    """

    device: torch.device = DeviceConfigs()
    dataset_transforms = 'mnist_gan_transforms'
    epochs: int = 10

    is_save_models = True
    discriminator: Module = 'mlp'
    generator: Module = 'mlp'
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    generator_loss: GeneratorLogitsLoss = 'original'
    discriminator_loss: DiscriminatorLogitsLoss = 'original'
    label_smoothing: float = 0.2
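    # The generator is updated only once every `discriminator_k` batches (see `step` below)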
    discriminator_k: int = 1

    def init(self):
        """
        Initializations
        """
        self.state_modules = []

        hook_model_outputs(self.mode, self.generator, 'generator')
        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
        tracker.set_scalar("loss.generator.*", True)
        tracker.set_scalar("loss.discriminator.*", True)
        tracker.set_image("generated", True, 1 / 100)

    def sample_z(self, batch_size: int):
        """
        $$z \sim p(z)$$
        """
        return torch.randn(batch_size, 100, device=self.device)

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        Take a training step
        """

        # Set model states
        self.generator.train(self.mode.is_train)
        self.discriminator.train(self.mode.is_train)

        # Get MNIST images
        data = batch[0].to(self.device)

        # Increment step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Train the discriminator
        with monit.section("discriminator"):
            # Get discriminator loss
            loss = self.calc_discriminator_loss(data)

            # Train
            if self.mode.is_train:
                self.discriminator_optimizer.zero_grad()
                loss.backward()
                if batch_idx.is_last:
                    tracker.add('discriminator', self.discriminator)
                self.discriminator_optimizer.step()

        # Train the generator once every `discriminator_k` batches
        if batch_idx.is_interval(self.discriminator_k):
            with monit.section("generator"):
                loss = self.calc_generator_loss(data.shape[0])

                # Train
                if self.mode.is_train:
                    self.generator_optimizer.zero_grad()
                    loss.backward()
                    if batch_idx.is_last:
                        tracker.add('generator', self.generator)
                    self.generator_optimizer.step()

        tracker.save()

    def calc_discriminator_loss(self, data):
        """
        Calculate discriminator loss
        """
        latent = self.sample_z(data.shape[0])
        logits_true = self.discriminator(data)
        logits_false = self.discriminator(self.generator(latent).detach())
        loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
        loss = loss_true + loss_false

        # Log the discriminator losses
        tracker.add("loss.discriminator.true.", loss_true)
        tracker.add("loss.discriminator.false.", loss_false)
        tracker.add("loss.discriminator.", loss)

        return loss

    def calc_generator_loss(self, batch_size: int):
        """
        Calculate generator loss
        """
        latent = self.sample_z(batch_size)
        generated_images = self.generator(latent)
        logits = self.discriminator(generated_images)
        loss = self.generator_loss(logits)

        # Log the generated image samples and the generator loss
        tracker.add('generated', generated_images[0:6])
        tracker.add("loss.generator.", loss)

        return loss
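
# The string defaults above (`'mlp'`, `'original'`, `'mnist_gan_transforms'`) refer to named
# configuration options that are registered elsewhere. A minimal sketch, assuming labml's
# `@option` decorator accepts a name for the option; the MLP architecture below is an
# illustrative assumption, not the original implementation.
from torch import nn

from labml.configs import option


@option(Configs.generator, 'mlp')
def mlp_generator(c: Configs):
    # Hypothetical MLP generator: maps a 100-dimensional latent vector to a 1x28x28 image
    return nn.Sequential(
        nn.Linear(100, 256), nn.LeakyReLU(0.2),
        nn.Linear(256, 512), nn.LeakyReLU(0.2),
        nn.Linear(512, 28 * 28), nn.Tanh(),
        nn.Unflatten(1, (1, 28, 28)),
    ).to(c.device)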