class Configs(TrainValidConfigs):
    """
    Configurations for training an auto-regressive source-code model
    (RHN/LSTM sizes are configurable), sampling from the prompt
    ``'def train('`` before every training step.
    """
    # Training device; picks a GPU if one is available
    device = DeviceConfigs()
    # Auto-regressive model
    model: Module
    # Text dataset
    text: TextDataset
    batch_size: int = 16
    # Context size
    seq_len: int = 512
    # Vocabulary size
    n_tokens: int
    n_layers: int = 2
    dropout: float = 0.2
    d_model: int = 512
    rnn_size: int = 512
    rhn_depth: int = 1
    tokenizer: Callable
    # Number of train/valid switches within an epoch
    inner_iterations = 100
    is_save_models = True
    transformer: TransformerConfigs

    def run(self):
        """Training loop: greedily sample 25 tokens, then run a step."""
        for _ in self.training_loop:
            # Starting prompt
            prompt = 'def train('
            log = [(prompt, Text.subtle)]
            # Greedy sampling, one token at a time
            for i in monit.iterate('Sample', 25):
                # Tokenize the prompt and add a batch dimension
                data = self.text.text_to_i(prompt).unsqueeze(-1)
                data = data.to(self.device)
                # Model may return extra state; keep only the logits
                output, *_ = self.model(data)
                # Greedy prediction
                output = output.argmax(dim=-1).squeeze()
                # Append the predicted token to the prompt and the log
                prompt += self.text.itos[output[-1]]
                log += [(self.text.itos[output[-1]], Text.value)]
            logger.log(log)
            # Run a training/validation step
            self.run_step()
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    ## Configurable Experiment Definition
    """
    # Optimizer
    optimizer: torch.optim.Adam
    # Classification model (the original declared this attribute twice;
    # the duplicate annotation has been removed)
    model: nn.Module
    # Seed configuration for reproducibility
    set_seed = SeedConfigs()
    # Training device; picks a GPU if available
    device: torch.device = DeviceConfigs()
    epochs: int = 10
    is_save_models = True
    # Number of train/valid switches within an epoch
    inner_iterations = 10
    accuracy_func = Accuracy()
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        """Set up tracker indicators, output hooks and state modules."""
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        hook_model_outputs(self.mode, self.model, 'model')
        self.state_modules = [self.accuracy_func]

    def step(self, batch: any, batch_idx: BatchIndex):
        """Run a single training or validation step on `batch`."""
        # Get the batch
        data, target = batch[0].to(self.device), batch[1].to(self.device)
        # Add global step if we are in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))
        # Run the model and specify whether to log the activations
        with self.mode.update(is_log_activations=batch_idx.is_last):
            output = self.model(data)
        # Calculate the loss
        loss = self.loss_func(output, target)
        # Calculate the accuracy
        self.accuracy_func(output, target)
        # Log the loss
        tracker.add("loss.", loss)
        # Optimize if we are in training mode
        if self.mode.is_train:
            # Calculate the gradients
            loss.backward()
            # Take optimizer step
            self.optimizer.step()
            # Log the parameter and gradient L2 norms once per epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
                tracker.add('optimizer', (self.optimizer, {'model': self.model}))
            # Clear the gradients
            self.optimizer.zero_grad()
        # Save logs
        tracker.save()
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    Configurations with MNIST data and Train & Validation setup
    """
    # Batch-step implementation, resolved by name by the configs framework
    batch_step = 'capsule_network_batch_step'
    # Training device; picks a GPU if available
    device: torch.device = DeviceConfigs()
    # Number of training epochs
    epochs: int = 10
    # Model, resolved by name by the configs framework
    model = 'capsule_network_model'
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    GAN experiment configurations on MNIST: separate generator/discriminator
    models, one optimizer each, and logit-based losses.
    """
    # Training device; picks a GPU if available
    device: torch.device = DeviceConfigs()
    # Number of training epochs
    epochs: int = 10
    is_save_models = True
    # The two adversarial models
    discriminator: Module
    generator: Module
    # One optimizer per model
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    # Losses computed on discriminator logits
    discriminator_loss = DiscriminatorLogitsLoss()
    generator_loss = GeneratorLogitsLoss()
    # Batch-step implementation, resolved by name by the configs framework
    batch_step = 'gan_batch_step'
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    MNIST classification experiment configurations.
    """
    # Seed configuration for reproducibility
    seed = SeedConfigs()
    # Training device; picks a GPU if available
    device: torch.device = DeviceConfigs()
    # Number of training epochs
    epochs: int = 10
    # NOTE(review): batch size of 1 for both train and valid is unusual —
    # presumably intentional for this experiment; confirm.
    train_batch_size = 1
    valid_batch_size = 1
    is_save_models = True
    # Classification model
    model: nn.Module
    # Loss and accuracy metrics
    loss_func = nn.CrossEntropyLoss()
    accuracy_func = SimpleAccuracy()
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    MNIST classification experiment configurations with explicit
    learning-rate/momentum hyper-parameters.
    """
    # Seed configuration for reproducibility
    seed = SeedConfigs()
    # Training device; picks a GPU if available
    device: torch.device = DeviceConfigs()
    # Number of training epochs
    epochs: int = 10
    is_save_models = True
    # Classification model
    model: nn.Module
    # Optimizer hyper-parameters
    learning_rate: float = 2.5e-4
    momentum: float = 0.5
    # Loss and accuracy, resolved by name by the configs framework
    loss_func = 'cross_entropy_loss'
    accuracy_func = 'simple_accuracy'
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    GAN experiment configurations on MNIST with label smoothing and a
    configurable number of discriminator steps per generator step.
    """
    # Training device; picks a GPU if available
    device: torch.device = DeviceConfigs()
    # Number of training epochs
    epochs: int = 10
    is_save_models = True
    # The two adversarial models
    discriminator: Module
    generator: Module
    # One optimizer per model
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    # Losses computed on discriminator logits
    generator_loss: GeneratorLogitsLoss
    discriminator_loss: DiscriminatorLogitsLoss
    # Batch-step implementation, resolved by name by the configs framework
    batch_step = 'gan_batch_step'
    # Smoothing applied to the "real" labels
    label_smoothing: float = 0.2
    # Number of discriminator updates per generator update
    discriminator_k: int = 1
class Configs(TrainValidConfigs):
    """
    ## Configurations

    The default configs can and will be over-ridden when we start the experiment
    """
    # Transformer sub-configurations
    transformer: TransformerConfigs
    # Auto-regressive model
    model: AutoregressiveModel
    # Training device; picks a GPU if available
    device: torch.device = DeviceConfigs()
    # Text dataset
    text: TextDataset
    batch_size: int = 20
    # Context size
    seq_len: int = 32
    # Vocabulary size
    n_tokens: int
    # Tokenizer, resolved by name by the configs framework
    tokenizer: Callable = 'character'
    is_save_models = True
    # Prompt to start sampling from, and separator between sampled tokens
    prompt: str
    prompt_separator: str
    is_save_ff_input = False
    # Optimizer and batch step, resolved by name by the configs framework
    optimizer: torch.optim.Adam = 'transformer_optimizer'
    batch_step = 'auto_regression_batch_step'

    def sample(self):
        """
        Sampling function to generate samples periodically while training
        """
        # Starting prompt
        prompt = self.prompt
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
        logger.log(log)
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    Configurable MNIST classification experiment.
    """
    # Optimizer
    optimizer: torch.optim.Adam
    # Classification model (the original declared this attribute twice;
    # the duplicate annotation has been removed)
    model: nn.Module
    # Seed configuration for reproducibility
    set_seed = SeedConfigs()
    # Training device; picks a GPU if available
    device: torch.device = DeviceConfigs()
    epochs: int = 10
    is_save_models = True
    # Number of train/valid switches within an epoch
    inner_iterations = 10
    accuracy_func = Accuracy()
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        """Set up tracker indicators, output hooks and state modules."""
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        hook_model_outputs(self.mode, self.model, 'model')
        self.state_modules = [self.accuracy_func]

    def step(self, batch: any, batch_idx: BatchIndex):
        """Run a single training or validation step on `batch`."""
        # Move the batch to the training device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
        # Count samples processed when training
        if self.mode.is_train:
            tracker.add_global_step(len(data))
        # Log activations only on the last batch of the epoch
        with self.mode.update(is_log_activations=batch_idx.is_last):
            output = self.model(data)
        # Loss and accuracy
        loss = self.loss_func(output, target)
        self.accuracy_func(output, target)
        tracker.add("loss.", loss)
        # Optimize only when training
        if self.mode.is_train:
            loss.backward()
            self.optimizer.step()
            # Log parameter/gradient stats once per epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()
        # Save tracked metrics
        tracker.save()
class Configs(TrainValidConfigs):
    """
    Configurations for training an LSTM source-code model. Each iteration
    of the training loop samples from the model, then runs a training pass
    and a validation pass.
    """
    # Training device; picks a GPU if one is available
    device = DeviceConfigs()
    # Auto-regressive model
    model: Module
    # Text dataset
    text: TextDataset
    batch_size: int = 16
    # Context size
    seq_len: int = 512
    # Vocabulary size
    n_tokens: int
    d_model: int = 512
    n_layers: int = 2
    dropout: float = 0.2
    d_lstm: int = 512
    tokenizer: Callable
    is_save_models = True
    transformer: TransformerConfigs

    def run(self):
        """Training loop: sample 25 tokens, then train and validate."""
        for _ in self.training_loop:
            # Greedily sample starting from this prompt
            prompt = 'def train('
            log = [(prompt, Text.subtle)]
            for i in monit.iterate('Sample', 25):
                # Tokenize the prompt and add a batch dimension
                data = self.text.text_to_i(prompt).unsqueeze(-1)
                data = data.to(self.device)
                # Model may return extra state; keep only the logits
                output, *_ = self.model(data)
                # Greedy prediction
                output = output.argmax(dim=-1).squeeze()
                # Append the predicted token to the prompt and the log
                prompt += self.text.itos[output[-1]]
                log += [(self.text.itos[output[-1]], Text.value)]
            logger.log(log)
            # NOTE(review): the validator also runs under `is_train=True`
            # mode here — confirm this is intentional.
            with Mode(is_train=True,
                      is_log_parameters=self.is_log_parameters,
                      is_log_activations=self.is_log_activations):
                with tracker.namespace('train'):
                    self.trainer()
                with tracker.namespace('valid'):
                    self.validator()
class NLPAutoRegressionConfigs(TrainValidConfigs):
    """
    Trainer configurations for NLP auto-regressive tasks.
    """
    # Optimizer
    optimizer: torch.optim.Adam
    # Training device; picks a GPU if available
    device: torch.device = DeviceConfigs()
    # Auto-regressive model
    model: Module
    # Text dataset
    text: TextDataset
    # Batch size
    batch_size: int = 16
    # Context size
    seq_len: int = 512
    # Vocabulary size
    n_tokens: int
    # Tokenizer, resolved by name by the configs framework
    tokenizer: Callable = 'character'
    # Prompt to start sampling from, and separator between sampled tokens
    prompt: str
    prompt_separator: str
    is_save_models = True
    # Loss and accuracy metrics
    loss_func = CrossEntropyLoss()
    accuracy = Accuracy()
    # Model embedding size
    d_model: int = 512

    def init(self):
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Keep accuracy metric stats separate for training and validation
        self.state_modules = [self.accuracy]

    def step(self, batch: any, batch_idx: BatchIndex):
        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
        # Update global step (number of tokens processed) when training
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
        # Capture model activations only on the last batch of the epoch
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Model may return extra state; keep only the logits
            output, *_ = self.model(data)
        # Calculate the loss and accuracy
        loss = self.loss_func(output, target)
        self.accuracy(output, target)
        self.accuracy.track()
        tracker.add("loss.", loss)
        # Optimize only when training
        if self.mode.is_train:
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.)
            self.optimizer.step()
            # Log model parameter/gradient stats once per epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()
        # Save tracked metrics
        tracker.save()

    def sample(self):
        """
        Sampling function to generate samples periodically while training
        """
        # Starting prompt
        prompt = self.prompt
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output, *_ = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
        logger.log(log)
class Configs(BaseConfigs):
    """## Configurations"""
    # `DeviceConfigs` will pick a GPU if available
    device: torch.device = DeviceConfigs()

    # Hyper-parameters
    epochs: int = 200
    dataset_name: str = 'monet2photo'
    batch_size: int = 1

    data_loader_workers = 8

    learning_rate = 0.0002
    adam_betas = (0.5, 0.999)
    decay_start = 100

    # The paper suggests using a least-squares loss instead of
    # negative log-likelihood, as it is found to be more stable.
    gan_loss = torch.nn.MSELoss()

    # L1 loss is used for cycle loss and identity loss
    cycle_loss = torch.nn.L1Loss()
    identity_loss = torch.nn.L1Loss()

    # Image dimensions
    img_height = 256
    img_width = 256
    img_channels = 3

    # Number of residual blocks in the generator
    n_residual_blocks = 9

    # Loss coefficients
    cyclic_loss_coefficient = 10.0
    identity_loss_coefficient = 5.

    # How often (in batches) to sample images and checkpoint
    sample_interval = 500

    # Models
    generator_xy: GeneratorResNet
    generator_yx: GeneratorResNet
    discriminator_x: Discriminator
    discriminator_y: Discriminator

    # Optimizers
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam

    # Learning rate schedules
    generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
    discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR

    # Data loaders
    dataloader: DataLoader
    valid_dataloader: DataLoader

    def sample_images(self, n: int):
        """Generate samples from test set and save them"""
        # NOTE(review): `n` is unused; callers pass the batch index — confirm
        # whether it was meant to select/label the sample.
        batch = next(iter(self.valid_dataloader))
        # Put generators in evaluation mode
        self.generator_xy.eval()
        self.generator_yx.eval()
        with torch.no_grad():
            # NOTE(review): relies on the generators exposing a `.device`
            # attribute — confirm `GeneratorResNet` defines it.
            data_x, data_y = batch['x'].to(
                self.generator_xy.device), batch['y'].to(
                    self.generator_yx.device)
            # Generate translations in both directions
            gen_y = self.generator_xy(data_x)
            gen_x = self.generator_yx(data_y)

            # Arrange images along x-axis
            data_x = make_grid(data_x, nrow=5, normalize=True)
            data_y = make_grid(data_y, nrow=5, normalize=True)
            gen_x = make_grid(gen_x, nrow=5, normalize=True)
            gen_y = make_grid(gen_y, nrow=5, normalize=True)

            # Arrange images along y-axis
            image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)

        # Show samples
        plot_image(image_grid)

    def initialize(self):
        """
        ## Initialize models and data loaders
        """
        input_shape = (self.img_channels, self.img_height, self.img_width)

        # Create the models
        self.generator_xy = GeneratorResNet(
            self.img_channels, self.n_residual_blocks).to(self.device)
        self.generator_yx = GeneratorResNet(
            self.img_channels, self.n_residual_blocks).to(self.device)
        self.discriminator_x = Discriminator(input_shape).to(self.device)
        self.discriminator_y = Discriminator(input_shape).to(self.device)

        # Create the optimizers; both generators (and both discriminators)
        # share a single optimizer.
        self.generator_optimizer = torch.optim.Adam(itertools.chain(
            self.generator_xy.parameters(), self.generator_yx.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
        self.discriminator_optimizer = torch.optim.Adam(itertools.chain(
            self.discriminator_x.parameters(), self.discriminator_y.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)

        # Create the learning rate schedules.
        # The learning rate stays flat until `decay_start` epochs,
        # and then linearly reduces to $0$ at end of training.
        decay_epochs = self.epochs - self.decay_start
        self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.generator_optimizer,
            lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
        self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.discriminator_optimizer,
            lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)

        # Image transformations
        transforms_ = [
            transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC),
            transforms.RandomCrop((self.img_height, self.img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

        # Training data loader
        self.dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, 'train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

        # Validation data loader
        self.valid_dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, "test"),
            batch_size=5,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

    def run(self):
        """
        ## Training

        We aim to solve
        $$G^{*}, F^{*} = \arg \min_{G,F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$
        where $G: X \rightarrow Y$ and $F: Y \rightarrow X$ are the generators and
        $D_X$, $D_Y$ the discriminators. The total loss combines the adversarial
        (GAN) losses in both directions, the cycle-consistency loss
        $\mathcal{L}_{cyc}(G, F)$ (applying both generators in series should
        recover the original image), and the identity loss
        $\mathcal{L}_{identity}(G, F)$ (used to preserve color composition).

        Following the paper, the negative log-likelihood adversarial objective
        is replaced with a least-squares loss for stability: discriminators are
        trained to output $1$ for real images and $0$ for generated ones, and
        generators are trained to make the discriminators output $1$ on
        generated images.

        `generator_xy` is $G$, `generator_yx` is $F$, `discriminator_x` is
        $D_X$ and `discriminator_y` is $D_Y$.
        """
        # Replay buffers to keep generated samples
        gen_x_buffer = ReplayBuffer()
        gen_y_buffer = ReplayBuffer()

        # Loop through epochs
        for epoch in monit.loop(self.epochs):
            # Loop through the dataset
            for i, batch in monit.enum('Train', self.dataloader):
                # Move images to the device
                data_x, data_y = batch['x'].to(self.device), batch['y'].to(
                    self.device)

                # true labels equal to $1$
                true_labels = torch.ones(data_x.size(0),
                                         *self.discriminator_x.output_shape,
                                         device=self.device,
                                         requires_grad=False)
                # false labels equal to $0$
                false_labels = torch.zeros(data_x.size(0),
                                           *self.discriminator_x.output_shape,
                                           device=self.device,
                                           requires_grad=False)

                # Train the generators. This returns the generated images.
                gen_x, gen_y = self.optimize_generators(
                    data_x, data_y, true_labels)

                # Train discriminators on replayed (buffered) generated images
                self.optimize_discriminator(data_x, data_y,
                                            gen_x_buffer.push_and_pop(gen_x),
                                            gen_y_buffer.push_and_pop(gen_y),
                                            true_labels, false_labels)

                # Save training statistics and increment the global step counter
                tracker.save()
                tracker.add_global_step(max(len(data_x), len(data_y)))

                # Save images at intervals
                batches_done = epoch * len(self.dataloader) + i
                if batches_done % self.sample_interval == 0:
                    # Save models when sampling images
                    experiment.save_checkpoint()
                    # Sample images
                    self.sample_images(batches_done)

            # Update learning rates
            self.generator_lr_scheduler.step()
            self.discriminator_lr_scheduler.step()
            # New line
            tracker.new_line()

    def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor,
                            true_labels: torch.Tensor):
        """
        ### Optimize the generators with identity, gan and cycle losses.
        """
        # Change to training mode
        self.generator_xy.train()
        self.generator_yx.train()

        # Identity loss
        # $$\lVert F(G(x^{(i)})) - x^{(i)} \lVert_1\
        # \lVert G(F(y^{(i)})) - y^{(i)} \rVert_1$$
        loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
                         self.identity_loss(self.generator_xy(data_y), data_y))

        # Generate images $G(x)$ and $F(y)$
        gen_y = self.generator_xy(data_x)
        gen_x = self.generator_yx(data_y)

        # GAN loss
        # $$\bigg(D_Y\Big(G\Big(x^{(i)}\Big)\Big) - 1\bigg)^2
        # + \bigg(D_X\Big(F\Big(y^{(i)}\Big)\Big) - 1\bigg)^2$$
        loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
                    self.gan_loss(self.discriminator_x(gen_x), true_labels))

        # Cycle loss
        # $$
        # \lVert F(G(x^{(i)})) - x^{(i)} \lVert_1 +
        # \lVert G(F(y^{(i)})) - y^{(i)} \rVert_1
        # $$
        loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
                      self.cycle_loss(self.generator_xy(gen_x), data_y))

        # Total loss
        loss_generator = (loss_gan +
                          self.cyclic_loss_coefficient * loss_cycle +
                          self.identity_loss_coefficient * loss_identity)

        # Take a step in the optimizer
        self.generator_optimizer.zero_grad()
        loss_generator.backward()
        self.generator_optimizer.step()

        # Log losses
        tracker.add({'loss.generator': loss_generator,
                     'loss.generator.cycle': loss_cycle,
                     'loss.generator.gan': loss_gan,
                     'loss.generator.identity': loss_identity})

        # Return generated images
        return gen_x, gen_y

    def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
                               gen_x: torch.Tensor, gen_y: torch.Tensor,
                               true_labels: torch.Tensor, false_labels: torch.Tensor):
        """
        ### Optimize the discriminators with gan loss.
        """
        # GAN Loss
        # \begin{align}
        # \bigg(D_Y\Big(y ^ {(i)}\Big) - 1\bigg) ^ 2
        # + D_Y\Big(G\Big(x ^ {(i)}\Big)\Big) ^ 2 + \\
        # \bigg(D_X\Big(x ^ {(i)}\Big) - 1\bigg) ^ 2
        # + D_X\Big(F\Big(y ^ {(i)}\Big)\Big) ^ 2
        # \end{align}
        loss_discriminator = (
            self.gan_loss(self.discriminator_x(data_x), true_labels) +
            self.gan_loss(self.discriminator_x(gen_x), false_labels) +
            self.gan_loss(self.discriminator_y(data_y), true_labels) +
            self.gan_loss(self.discriminator_y(gen_y), false_labels))

        # Take a step in the optimizer
        self.discriminator_optimizer.zero_grad()
        loss_discriminator.backward()
        self.discriminator_optimizer.step()

        # Log losses
        tracker.add({'loss.discriminator': loss_discriminator})
class NLPAutoRegressionConfigs(TrainValidConfigs):
    """
    <a id="NLPAutoRegressionConfigs"></a>

    ## Trainer configurations

    This has the basic configurations for NLP auto-regressive task training.
    All the properties are configurable.
    """
    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Autoregressive model
    model: Module
    # Text dataset
    text: TextDataset
    # Batch size
    batch_size: int = 16
    # Length of the sequence, or context size
    seq_len: int = 512
    # Number of token in vocabulary
    n_tokens: int
    # Tokenizer
    tokenizer: Callable = 'character'

    # Text prompt to start sampling (for illustration)
    prompt: str
    # The token separator when sampling (blank for character level tokenization)
    prompt_separator: str

    # Whether to periodically save models
    is_save_models = True

    # Loss function
    loss_func = CrossEntropyLoss()
    # Accuracy function
    accuracy = Accuracy()
    # Model embedding size
    d_model: int = 512
    # Gradient clipping
    grad_norm_clip: float = 1.0

    # Training data loader
    train_loader: DataLoader = 'shuffled_train_loader'
    # Validation data loader
    valid_loader: DataLoader = 'shuffled_valid_loader'

    # Data loaders shuffle with replacement
    dataloader_shuffle_with_replacement: bool = False

    # Whether to log model parameters and gradients (once per epoch).
    # These are summarized stats per layer, but it could still lead
    # to many indicators for very deep networks.
    is_log_model_params_grads: bool = False

    # Whether to log model activations (once per epoch).
    # These are summarized stats per layer, but it could still lead
    # to many indicators for very deep networks.
    is_log_model_activations: bool = False

    def init(self):
        """
        ### Initialization
        """
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Add accuracy as a state module.
        # The name is probably confusing, since it's meant to store
        # states between training and validation for RNNs.
        # This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        """Override to calculate and log other metrics"""
        pass

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training or validation step
        """
        # Set training/eval mode
        self.model.train(self.mode.is_train)

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
            # Get model outputs.
            # It's returning a tuple for states when using RNNs.
            # This is not implemented yet. 😜
            output, *_ = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        self.other_metrics(output, target)

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()

    def sample(self):
        """
        ### Sampling function to generate samples periodically while training
        """
        # Starting prompt
        prompt = self.prompt

        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output, *_ = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

        # Print the sampled output
        logger.log(log)
class Configs(TrainValidConfigs):
    """
    Configurations for training a stateful auto-regressive source-code model,
    carrying model state (e.g. Transformer-XL memories / RNN state) across
    batches, with periodic greedy sampling.
    """
    # Optimizer
    optimizer: torch.optim.Adam
    # Training device; picks a GPU if available
    device: torch.device = DeviceConfigs()
    # Auto-regressive model
    model: Module
    # Source code dataset configurations
    text = SourceCodeDataConfigs()
    # Vocabulary size
    n_tokens: int
    n_layers: int = 2
    dropout: float = 0.2
    d_model: int = 512
    rnn_size: int = 512
    rhn_depth: int = 1
    # Number of train/valid switches within an epoch
    inner_iterations = 100
    is_save_models = True
    transformer: TransformerConfigs
    accuracy = Accuracy()
    # Loss and state updater, resolved by the configs framework
    loss_func: 'CrossEntropyLoss'
    state_updater: 'StateUpdater'
    # Keeps model state separate for training and validation
    state = SimpleStateModule()
    # Memory length
    mem_len: int = 512
    # Gradient clipping
    grad_norm_clip: float = 1.0
    # When sampling, feed only the last sampled token back (instead of the
    # full prompt), relying on the carried state for context
    is_token_by_token: bool = False

    def init(self):
        """Set up tracker indicators, output hooks and state modules."""
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        hook_model_outputs(self.mode, self.model, 'model')
        # Keep accuracy stats and model state separate for train/valid
        self.state_modules = [self.accuracy, self.state]

    def step(self, batch: any, batch_idx: BatchIndex):
        """Run one training or validation step, carrying state across batches."""
        # Move the batch to the training device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
        # Count tokens processed when training
        if self.mode.is_train:
            tracker.add_global_step(target.shape[0] * target.shape[1])
        # Log activations only on the last batch of the epoch
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Restore carried state, run the model, and save the new state
            state = self.state.get()
            output, new_state = self.model(data, state)
            state = self.state_updater(state, new_state)
            self.state.set(state)
        # Loss and accuracy
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
        self.accuracy(output, target)
        self.accuracy.track()
        # Optimize only when training
        if self.mode.is_train:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           max_norm=self.grad_norm_clip)
            self.optimizer.step()
            # Log parameter/gradient stats once per epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()
        # Save tracked metrics
        tracker.save()

    def sample(self):
        """Greedily sample 25 tokens, carrying model state between tokens."""
        prompt = 'def train('
        log = [(prompt, Text.subtle)]
        state = None
        for i in monit.iterate('Sample', 25):
            # Tokenize the current prompt and add a batch dimension
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Run the model with the carried state
            output, new_state = self.model(data, state)
            # Greedy prediction
            output = output.argmax(dim=-1).squeeze(1)
            # Bug fix: the original appended the sampled token to `prompt`
            # unconditionally here AND again in the `else` branch below,
            # duplicating every sampled token; the stray append was removed.
            if self.is_token_by_token:
                # Feed only the last token; context comes from `state`
                prompt = self.text.tokenizer.itos[output[-1]]
            else:
                prompt += self.text.tokenizer.itos[output[-1]]
            log += [(self.text.tokenizer.itos[output[-1]], Text.value)]
            # Carry the state forward
            state = self.state_updater(state, new_state)
        logger.log(log)
class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs): """ <a id="MNISTConfigs"> ## Trainer configurations </a> """ # Optimizer optimizer: torch.optim.Adam # Training device device: torch.device = DeviceConfigs() # Classification model model: Module # Number of epochs to train for epochs: int = 10 # Number of times to switch between training and validation within an epoch inner_iterations = 10 # Accuracy function accuracy = Accuracy() # Loss function loss_func = nn.CrossEntropyLoss() def init(self): """ ### Initialization """ # Set tracker configurations tracker.set_scalar("loss.*", True) tracker.set_scalar("accuracy.*", True) # Add a hook to log module outputs hook_model_outputs(self.mode, self.model, 'model') # Add accuracy as a state module. # The name is probably confusing, since it's meant to store # states between training and validation for RNNs. # This will keep the accuracy metric stats separate for training and validation. self.state_modules = [self.accuracy] def step(self, batch: any, batch_idx: BatchIndex): """ ### Training or validation step """ # Training/Evaluation mode self.model.train(self.mode.is_train) # Move data to the device data, target = batch[0].to(self.device), batch[1].to(self.device) # Update global step (number of samples processed) when in training mode if self.mode.is_train: tracker.add_global_step(len(data)) # Whether to capture model outputs with self.mode.update(is_log_activations=batch_idx.is_last): # Get model outputs. 
output = self.model(data) # Calculate and log loss loss = self.loss_func(output, target) tracker.add("loss.", loss) # Calculate and log accuracy self.accuracy(output, target) self.accuracy.track() # Train the model if self.mode.is_train: # Calculate gradients loss.backward() # Take optimizer step self.optimizer.step() # Log the model parameters and gradients on last batch of every epoch if batch_idx.is_last: tracker.add('model', self.model) # Clear the gradients self.optimizer.zero_grad() # Save the tracked metrics tracker.save()
class Configs(TrainValidConfigs):
    """
    ## Configurations

    These are default configurations which can later be adjusted by passing a `dict`.
    """
    # Device configurations to pick the device to run the experiment
    device: torch.device = DeviceConfigs()

    # Encoder and decoder networks
    encoder: EncoderRNN
    decoder: DecoderRNN
    # Optimizer
    optimizer: optim.Adam
    # Sampler used to generate sketches
    sampler: Sampler

    # Name of the `.npz` dataset to load
    dataset_name: str
    # Data loaders and datasets
    train_loader: DataLoader
    valid_loader: DataLoader
    train_dataset: StrokesDataset
    valid_dataset: StrokesDataset

    # Encoder and decoder sizes
    enc_hidden_size = 256
    dec_hidden_size = 512

    # Batch size
    batch_size = 100

    # Number of features in $z$
    d_z = 128
    # Number of distributions in the mixture, $M$
    n_distributions = 20

    # Weight of KL divergence loss, $w_{KL}$
    kl_div_loss_weight = 0.5
    # Gradient clipping
    grad_clip = 1.
    # Temperature $\tau$ for sampling
    temperature = 0.4

    # Filter out stroke sequences longer than $200$
    max_seq_length = 200

    epochs = 100

    kl_div_loss = KLDivLoss()
    reconstruction_loss = ReconstructionLoss()

    def init(self):
        # Initialize encoder & decoder
        self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size,
                                  self.n_distributions).to(self.device)

        # Set optimizer. Things like type of optimizer and learning rate are configurable
        optimizer = OptimizerConfigs()
        optimizer.parameters = list(self.encoder.parameters()) + list(
            self.decoder.parameters())
        self.optimizer = optimizer

        # Create sampler
        self.sampler = Sampler(self.encoder, self.decoder)

        # `npz` file path is `data/sketch/[DATASET NAME].npz`
        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
        # Load the numpy file
        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)

        # Create training dataset
        self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
        # Create validation dataset; reuses the training scale so both are
        # normalized consistently
        self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length,
                                            self.train_dataset.scale)

        # Create training data loader
        self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
        # Create validation data loader
        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)

        # Add hooks to monitor layer outputs on Tensorboard
        hook_model_outputs(self.mode, self.encoder, 'encoder')
        hook_model_outputs(self.mode, self.decoder, 'decoder')

        # Configure the tracker to print the total train/validation loss
        tracker.set_scalar("loss.total.*", True)

        self.state_modules = []

    def step(self, batch: Any, batch_idx: BatchIndex):
        self.encoder.train(self.mode.is_train)
        self.decoder.train(self.mode.is_train)

        # Move `data` and `mask` to device and swap the sequence and batch dimensions.
        # `data` will have shape `[seq_len, batch_size, 5]` and
        # `mask` will have shape `[seq_len, batch_size]`.
        data = batch[0].to(self.device).transpose(0, 1)
        mask = batch[1].to(self.device).transpose(0, 1)

        # Increment step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Encode the sequence of strokes
        with monit.section("encoder"):
            # Get $z$, $\mu$, and $\hat{\sigma}$
            z, mu, sigma_hat = self.encoder(data)

        # Decode the mixture of distributions and $\hat{q}$
        with monit.section("decoder"):
            # Concatenate $[(\Delta x, \Delta y, p_1, p_2, p_3); z]$
            z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
            inputs = torch.cat([data[:-1], z_stack], 2)
            # Get mixture of distributions and $\hat{q}$
            dist, q_logits, _ = self.decoder(inputs, z, None)

        # Compute the loss
        with monit.section('loss'):
            # $L_{KL}$
            kl_loss = self.kl_div_loss(sigma_hat, mu)
            # $L_R$
            reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)
            # $Loss = L_R + w_{KL} L_{KL}$
            loss = reconstruction_loss + self.kl_div_loss_weight * kl_loss

            # Track losses
            tracker.add("loss.kl.", kl_loss)
            tracker.add("loss.reconstruction.", reconstruction_loss)
            tracker.add("loss.total.", loss)

        # Only if we are in training state
        if self.mode.is_train:
            # Run optimizer
            with monit.section('optimize'):
                # Set `grad` to zero
                self.optimizer.zero_grad()
                # Compute gradients
                loss.backward()
                # Log model parameters and gradients
                if batch_idx.is_last:
                    tracker.add(encoder=self.encoder, decoder=self.decoder)
                # Clip gradients
                nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
                nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)
                # Optimize
                self.optimizer.step()

        tracker.save()

    def sample(self):
        # Randomly pick a sample from validation dataset to encode
        data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]
        # Add batch dimension and move it to device
        data = data.unsqueeze(1).to(self.device)
        # Sample
        self.sampler.sample(data, self.temperature)
class Configs(BaseConfigs):
    """
    ## Configurations

    Trainer for the DDPM: builds the U-Net noise predictor, the diffusion
    process, the data loader and optimizer, then alternates training epochs
    with image sampling.
    """
    # Device to train the model on.
    # [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
    # picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # U-Net model for $\color{cyan}{\epsilon_\theta}(x_t, t)$
    eps_model: UNet
    # [DDPM algorithm](index.html)
    diffusion: DenoiseDiffusion

    # Number of channels in the image. $3$ for RGB.
    image_channels: int = 3
    # Image size
    image_size: int = 32
    # Number of channels in the initial feature map
    n_channels: int = 64
    # The list of channel numbers at each resolution.
    # The number of channels is `channel_multipliers[i] * n_channels`
    channel_multipliers: List[int] = [1, 2, 2, 4]
    # The list of booleans that indicate whether to use attention at each resolution
    is_attention: List[int] = [False, False, False, True]

    # Number of time steps $T$
    n_steps: int = 1_000
    # Batch size
    batch_size: int = 64
    # Number of samples to generate
    n_samples: int = 16
    # Learning rate
    learning_rate: float = 2e-5

    # Number of training epochs
    epochs: int = 1_000

    # Dataset
    dataset: torch.utils.data.Dataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        # Create $\color{cyan}{\epsilon_\theta}(x_t, t)$ model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)

        # Create [DDPM class](index.html)
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

        # Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset,
                                                       self.batch_size,
                                                       shuffle=True,
                                                       pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

        # Image logging
        tracker.set_image("sample", True)

    def sample(self):
        """
        ### Sample images

        Runs the full reverse process from pure noise and logs the result.
        """
        # No gradients needed for sampling
        with torch.no_grad():
            # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
            x = torch.randn([
                self.n_samples, self.image_channels, self.image_size,
                self.image_size
            ],
                            device=self.device)

            # Remove noise for $T$ steps
            for t_ in monit.iterate('Sample', self.n_steps):
                # $t$ counts down from $T-1$ to $0$
                t = self.n_steps - t_ - 1
                # Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$;
                # the time step is broadcast to all samples in the batch
                x = self.diffusion.p_sample(
                    x, x.new_full((self.n_samples, ), t, dtype=torch.long))

            # Log samples
            tracker.save('sample', x)

    def train(self):
        """
        ### Train

        One epoch over the dataset.
        """
        # Iterate through the dataset
        for data in monit.iterate('Train', self.data_loader):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            data = data.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Calculate loss
            loss = self.diffusion.loss(data)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop

        Train, then sample, once per epoch; checkpoint at the end of each epoch.
        """
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # Sample some images
            self.sample()
            # New line in the console
            tracker.new_line()
            # Save the model
            experiment.save_checkpoint()
class Configs(TrainValidConfigs): optimizer: torch.optim.Adam device: torch.device = DeviceConfigs() model: Module text: TextDataset batch_size: int = 16 seq_len: int = 512 n_tokens: int n_layers: int = 2 dropout: float = 0.2 d_model: int = 512 rnn_size: int = 512 rhn_depth: int = 1 tokenizer: Callable inner_iterations = 100 is_save_models = True transformer: TransformerConfigs accuracy_func = Accuracy() loss_func: 'CrossEntropyLoss' def init(self): tracker.set_queue("loss.*", 20, True) tracker.set_scalar("accuracy.*", True) hook_model_outputs(self.mode, self.model, 'model') self.state_modules = [self.accuracy_func] def step(self, batch: any, batch_idx: BatchIndex): data, target = batch[0].to(self.device), batch[1].to(self.device) if self.mode.is_train: tracker.add_global_step(len(data)) with self.mode.update(is_log_activations=batch_idx.is_last): output, *_ = self.model(data) loss = self.loss_func(output, target) self.accuracy_func(output, target) self.accuracy_func.track() tracker.add("loss.", loss) if self.mode.is_train: loss.backward() self.optimizer.step() if batch_idx.is_last: tracker.add('model', self.model) self.optimizer.zero_grad() tracker.save() def sample(self): prompt = 'def train(' log = [(prompt, Text.subtle)] for i in monit.iterate('Sample', 25): data = self.text.text_to_i(prompt).unsqueeze(-1) data = data.to(self.device) output, *_ = self.model(data) output = output.argmax(dim=-1).squeeze() prompt += '' + self.text.itos[output[-1]] log += [('' + self.text.itos[output[-1]], Text.value)] logger.log(log)
class Configs(BaseConfigs):
    """
    ## Configurations

    Full-batch trainer for a Graph Attention Network on the Cora
    citation dataset.
    """
    # Model
    model: GAT
    # Number of nodes to train on
    training_samples: int = 500
    # Number of features per node in the input
    in_features: int
    # Number of features in the first graph attention layer
    n_hidden: int = 64
    # Number of heads
    n_heads: int = 8
    # Number of classes for classification
    n_classes: int
    # Dropout probability
    dropout: float = 0.6
    # Whether to include the citation network
    include_edges: bool = True
    # Dataset
    dataset: CoraDataset
    # Number of training iterations
    epochs: int = 1_000
    # Loss function
    loss_func = nn.CrossEntropyLoss()
    # Device to train on
    #
    # This creates configs for device, so that
    # we can change the device by passing a config value
    device: torch.device = DeviceConfigs()
    # Optimizer
    optimizer: torch.optim.Adam

    def run(self):
        """
        ### Training loop

        We do full batch training since the dataset is small.
        If we were to sample and train we will have to sample a set of
        nodes for each training step along with the edges that span
        across those selected nodes.
        """
        # Move the feature vectors to the device
        features = self.dataset.features.to(self.device)
        # Move the labels to the device
        labels = self.dataset.labels.to(self.device)
        # Move the adjacency matrix to the device
        edges_adj = self.dataset.adj_mat.to(self.device)
        # Add an empty third dimension for the heads
        edges_adj = edges_adj.unsqueeze(-1)

        # Random indexes for the train/validation split
        idx_rand = torch.randperm(len(labels))
        # Nodes for training
        idx_train = idx_rand[:self.training_samples]
        # Nodes for validation
        idx_valid = idx_rand[self.training_samples:]

        # Training loop
        for epoch in monit.loop(self.epochs):
            # Set the model to training mode
            self.model.train()
            # Make all the gradients zero
            self.optimizer.zero_grad()
            # Evaluate the model on the full graph
            output = self.model(features, edges_adj)
            # Get the loss for training nodes only
            loss = self.loss_func(output[idx_train], labels[idx_train])
            # Calculate gradients
            loss.backward()
            # Take optimization step
            self.optimizer.step()
            # Log the loss
            tracker.add('loss.train', loss)
            # Log the accuracy
            tracker.add('accuracy.train',
                        accuracy(output[idx_train], labels[idx_train]))

            # Set mode to evaluation mode for validation
            self.model.eval()

            # No need to compute gradients
            with torch.no_grad():
                # Evaluate the model again (dropout off in eval mode)
                output = self.model(features, edges_adj)
                # Calculate the loss for validation nodes
                loss = self.loss_func(output[idx_valid], labels[idx_valid])
                # Log the loss
                tracker.add('loss.valid', loss)
                # Log the accuracy
                tracker.add('accuracy.valid',
                            accuracy(output[idx_valid], labels[idx_valid]))

            # Save logs
            tracker.save()
class NLPClassificationConfigs(TrainValidConfigs):
    """
    <a id="NLPClassificationConfigs"></a>

    ## Trainer configurations

    This has the basic configurations for NLP classification task training.
    All the properties are configurable.
    """
    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Autoregressive model
    model: Module
    # Batch size
    batch_size: int = 16
    # Length of the sequence, or context size
    seq_len: int = 512
    # Vocabulary
    vocab: Vocab = 'ag_news'
    # Number of token in vocabulary
    n_tokens: int
    # Number of classes
    n_classes: int = 'ag_news'
    # Tokenizer
    tokenizer: Callable = 'character'
    # Whether to periodically save models
    is_save_models = True
    # Loss function
    loss_func = nn.CrossEntropyLoss()
    # Accuracy function
    accuracy = Accuracy()
    # Model embedding size
    d_model: int = 512
    # Gradient clipping
    grad_norm_clip: float = 1.0

    # Training data loader
    train_loader: DataLoader = 'ag_news'
    # Validation data loader
    valid_loader: DataLoader = 'ag_news'

    # Whether to log model parameters and gradients (once per epoch).
    # These are summarized stats per layer, but it could still lead
    # to many indicators for very deep networks.
    is_log_model_params_grads: bool = False

    # Whether to log model activations (once per epoch).
    # These are summarized stats per layer, but it could still lead
    # to many indicators for very deep networks.
    is_log_model_activations: bool = False

    def init(self):
        """
        ### Initialization
        """
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Add accuracy as a state module.
        # The name is probably confusing, since it's meant to store
        # states between training and validation for RNNs.
        # This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training or validation step
        """

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
            # Get model outputs.
            # It's returning a tuple for states when using RNNs.
            # This is not implemented yet. 😜
            output, *_ = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()
class Configs(BaseConfigs):
    """
    ## Configurations

    StyleGAN2 trainer: builds the generator, discriminator and mapping
    network, and runs adversarial training with lazy regularization
    (gradient penalty and path length penalty computed only every few steps).
    """
    # Device to train the model on.
    # [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
    # picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # [StyleGAN2 Discriminator](index.html#discriminator)
    discriminator: Discriminator
    # [StyleGAN2 Generator](index.html#generator)
    generator: Generator
    # [Mapping network](index.html#mapping_network)
    mapping_network: MappingNetwork

    # Discriminator and generator loss functions.
    # We use [Wasserstein loss](../wasserstein/index.html)
    discriminator_loss: DiscriminatorLoss
    generator_loss: GeneratorLoss

    # Optimizers
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    mapping_network_optimizer: torch.optim.Adam

    # [Gradient Penalty Regularization Loss](index.html#gradient_penalty)
    gradient_penalty = GradientPenalty()
    # Gradient penalty coefficient $\gamma$
    gradient_penalty_coefficient: float = 10.

    # [Path length penalty](index.html#path_length_penalty)
    path_length_penalty: PathLengthPenalty

    # Data loader
    loader: Iterator

    # Batch size
    batch_size: int = 32
    # Dimensionality of $z$ and $w$
    d_latent: int = 512
    # Height/width of the image
    image_size: int = 32
    # Number of layers in the mapping network
    mapping_network_layers: int = 8
    # Generator & Discriminator learning rate
    learning_rate: float = 1e-3
    # Mapping network learning rate ($100 \times$ lower than the others)
    mapping_network_learning_rate: float = 1e-5
    # Number of steps to accumulate gradients on. Use this to increase the effective batch size.
    gradient_accumulate_steps: int = 1
    # $\beta_1$ and $\beta_2$ for Adam optimizer
    adam_betas: Tuple[float, float] = (0.0, 0.99)
    # Probability of mixing styles
    style_mixing_prob: float = 0.9

    # Total number of training steps
    training_steps: int = 150_000

    # Number of blocks in the generator (calculated based on image resolution)
    n_gen_blocks: int

    # ### Lazy regularization
    # Instead of calculating the regularization losses, the paper proposes lazy regularization
    # where the regularization terms are calculated once in a while.
    # This improves the training efficiency a lot.

    # The interval at which to compute gradient penalty
    lazy_gradient_penalty_interval: int = 4
    # Path length penalty calculation interval
    lazy_path_penalty_interval: int = 32
    # Skip calculating path length penalty during the initial phase of training
    lazy_path_penalty_after: int = 5_000

    # How often to log generated images
    log_generated_interval: int = 500
    # How often to save model checkpoints
    save_checkpoint_interval: int = 2_000

    # Training mode state for logging activations
    mode: ModeState
    # Whether to log model layer outputs
    log_layer_outputs: bool = False

    # <a id="dataset_path"></a>
    # We trained this on [CelebA-HQ dataset](https://github.com/tkarras/progressive_growing_of_gans).
    # You can find the download instruction in this
    # [discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
    # Save the images inside `data/stylegan` folder.
    dataset_path: str = str(lab.get_data_path() / 'stylegan2')

    def init(self):
        """
        ### Initialize

        Build the dataset pipeline, the three networks, the losses and
        their optimizers.
        """
        # Create dataset
        dataset = Dataset(self.dataset_path, self.image_size)
        # Create data loader
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=self.batch_size,
                                                 num_workers=8,
                                                 shuffle=True,
                                                 drop_last=True,
                                                 pin_memory=True)
        # Continuous [cyclic loader](../../utils.html#cycle_dataloader)
        self.loader = cycle_dataloader(dataloader)

        # $\log_2$ of image resolution
        log_resolution = int(math.log2(self.image_size))

        # Create discriminator and generator
        self.discriminator = Discriminator(log_resolution).to(self.device)
        self.generator = Generator(log_resolution, self.d_latent).to(self.device)
        # Get number of generator blocks for creating style and noise inputs
        self.n_gen_blocks = self.generator.n_blocks
        # Create mapping network
        self.mapping_network = MappingNetwork(
            self.d_latent, self.mapping_network_layers).to(self.device)
        # Create path length penalty loss
        self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)

        # Add model hooks to monitor layer outputs
        if self.log_layer_outputs:
            hook_model_outputs(self.mode, self.discriminator, 'discriminator')
            hook_model_outputs(self.mode, self.generator, 'generator')
            hook_model_outputs(self.mode, self.mapping_network,
                               'mapping_network')

        # Discriminator and generator losses
        self.discriminator_loss = DiscriminatorLoss().to(self.device)
        self.generator_loss = GeneratorLoss().to(self.device)

        # Create optimizers
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=self.learning_rate,
            betas=self.adam_betas)
        self.generator_optimizer = torch.optim.Adam(
            self.generator.parameters(),
            lr=self.learning_rate,
            betas=self.adam_betas)
        self.mapping_network_optimizer = torch.optim.Adam(
            self.mapping_network.parameters(),
            lr=self.mapping_network_learning_rate,
            betas=self.adam_betas)

        # Set tracker configurations
        tracker.set_image("generated", True)

    def get_w(self, batch_size: int):
        """
        ### Sample $w$

        This samples $z$ randomly and get $w$ from the mapping network.

        We also apply style mixing sometimes where we generate
        two latent variables $z_1$ and $z_2$ and get corresponding
        $w_1$ and $w_2$.
        Then we randomly sample a cross-over point and apply $w_1$ to
        the generator blocks before the cross-over point and
        $w_2$ to the blocks after.
        """

        # Mix styles
        if torch.rand(()).item() < self.style_mixing_prob:
            # Random cross-over point
            cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)
            # Sample $z_1$ and $z_2$
            z2 = torch.randn(batch_size, self.d_latent).to(self.device)
            z1 = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w_1$ and $w_2$
            w1 = self.mapping_network(z1)
            w2 = self.mapping_network(z2)
            # Expand $w_1$ and $w_2$ for the generator blocks and concatenate
            w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
            w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point,
                                       -1, -1)
            return torch.cat((w1, w2), dim=0)
        # Without mixing
        else:
            # Sample $z$ and $z$
            z = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w$ and $w$
            w = self.mapping_network(z)
            # Expand $w$ for the generator blocks
            return w[None, :, :].expand(self.n_gen_blocks, -1, -1)

    def get_noise(self, batch_size: int):
        """
        ### Generate noise

        This generates noise for each [generator block](index.html#generator_block)
        """
        # List to store noise
        noise = []
        # Noise resolution starts from $4$
        resolution = 4

        # Generate noise for each generator block
        for i in range(self.n_gen_blocks):
            # The first block has only one $3 \times 3$ convolution
            if i == 0:
                n1 = None
            # Generate noise to add after the first convolution layer
            else:
                n1 = torch.randn(batch_size, 1, resolution, resolution,
                                 device=self.device)
            # Generate noise to add after the second convolution layer
            n2 = torch.randn(batch_size, 1, resolution, resolution,
                             device=self.device)

            # Add noise tensors to the list
            noise.append((n1, n2))

            # Next block has $2 \times$ resolution
            resolution *= 2

        # Return noise tensors
        return noise

    def generate_images(self, batch_size: int):
        """
        ### Generate images

        This generate images using the generator
        """

        # Get $w$
        w = self.get_w(batch_size)
        # Get noise
        noise = self.get_noise(batch_size)

        # Generate images
        images = self.generator(w, noise)

        # Return images and $w$
        return images, w

    def step(self, idx: int):
        """
        ### Training Step

        One discriminator update followed by one generator (and mapping
        network) update, each with gradient accumulation.
        """

        # Train the discriminator
        with monit.section('Discriminator'):
            # Reset gradients
            self.discriminator_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Update `mode`. Set whether to log activation
                with self.mode.update(is_log_activations=(idx + 1) % self.log_generated_interval == 0):
                    # Sample images from generator
                    generated_images, _ = self.generate_images(self.batch_size)
                    # Discriminator classification for generated images
                    fake_output = self.discriminator(generated_images.detach())

                    # Get real images from the data loader
                    real_images = next(self.loader).to(self.device)
                    # We need to calculate gradients w.r.t. real images for gradient penalty
                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                        real_images.requires_grad_()
                    # Discriminator classification for real images
                    real_output = self.discriminator(real_images)

                    # Get discriminator loss
                    real_loss, fake_loss = self.discriminator_loss(
                        real_output, fake_output)
                    disc_loss = real_loss + fake_loss

                    # Add gradient penalty (lazy regularization)
                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                        # Calculate and log gradient penalty
                        gp = self.gradient_penalty(real_images, real_output)
                        tracker.add('loss.gp', gp)
                        # Multiply by coefficient and add gradient penalty.
                        # Scaled by the interval to compensate for being applied lazily.
                        disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval

                    # Compute gradients
                    disc_loss.backward()

                    # Log discriminator loss
                    tracker.add('loss.discriminator', disc_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log discriminator model parameters occasionally
                tracker.add('discriminator', self.discriminator)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(),
                                           max_norm=1.0)
            # Take optimizer step
            self.discriminator_optimizer.step()

        # Train the generator
        with monit.section('Generator'):
            # Reset gradients
            self.generator_optimizer.zero_grad()
            self.mapping_network_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Sample images from generator
                generated_images, w = self.generate_images(self.batch_size)
                # Discriminator classification for generated images
                fake_output = self.discriminator(generated_images)

                # Get generator loss
                gen_loss = self.generator_loss(fake_output)

                # Add path length penalty (lazy, and only after the warm-up phase)
                if idx > self.lazy_path_penalty_after and (
                        idx + 1) % self.lazy_path_penalty_interval == 0:
                    # Calculate path length penalty
                    plp = self.path_length_penalty(w, generated_images)
                    # Ignore if `nan`
                    if not torch.isnan(plp):
                        tracker.add('loss.plp', plp)
                        gen_loss = gen_loss + plp

                # Calculate gradients
                gen_loss.backward()

                # Log generator loss
                tracker.add('loss.generator', gen_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log discriminator model parameters occasionally
                tracker.add('generator', self.generator)
                tracker.add('mapping_network', self.mapping_network)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.generator.parameters(),
                                           max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(),
                                           max_norm=1.0)

            # Take optimizer step
            self.generator_optimizer.step()
            self.mapping_network_optimizer.step()

        # Log generated images (with a few real images for comparison)
        if (idx + 1) % self.log_generated_interval == 0:
            tracker.add(
                'generated',
                torch.cat([generated_images[:6], real_images[:3]], dim=0))
        # Save model checkpoints
        if (idx + 1) % self.save_checkpoint_interval == 0:
            experiment.save_checkpoint()

        # Flush tracker
        tracker.save()

    def train(self):
        """
        ## Train model
        """

        # Loop for `training_steps`
        for i in monit.loop(self.training_steps):
            # Take a training step
            self.step(i)
            # Start a fresh console line after image-logging steps
            if (i + 1) % self.log_generated_interval == 0:
                tracker.new_line()
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    ## GAN trainer configurations

    Trains a generator/discriminator pair on MNIST with logits-based GAN
    losses and label smoothing. Extends MNIST data configs and the
    train/validation loop configs.
    """
    # Training device (CUDA if available, else CPU)
    device: torch.device = DeviceConfigs()
    # Number of training epochs
    epochs: int = 10
    # Whether to periodically save models
    is_save_models = True
    # Discriminator model
    discriminator: Module
    # Generator model
    generator: Module
    # Optimizers for the two networks
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    # Loss functions (logits-based, with label smoothing)
    generator_loss: GeneratorLogitsLoss
    discriminator_loss: DiscriminatorLogitsLoss
    # Label smoothing amount used by both losses
    label_smoothing: float = 0.2
    # Number of discriminator updates per step
    discriminator_k: int = 1

    def init(self):
        # No RNN-style state to carry between train/valid
        self.state_modules = []
        # Build models and losses, and move them to the device
        self.generator = Generator().to(self.device)
        self.discriminator = Discriminator().to(self.device)
        self.generator_loss = GeneratorLogitsLoss(self.label_smoothing).to(self.device)
        self.discriminator_loss = DiscriminatorLogitsLoss(self.label_smoothing).to(self.device)

        # Monitor layer outputs of both networks
        hook_model_outputs(self.mode, self.generator, 'generator')
        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
        # Print losses; log generated images occasionally (1/100 of the time)
        tracker.set_scalar("loss.generator.*", True)
        tracker.set_scalar("loss.discriminator.*", True)
        tracker.set_image("generated", True, 1 / 100)

    def step(self, batch: Any, batch_idx: BatchIndex):
        # Put models in the right train/eval mode
        self.generator.train(self.mode.is_train)
        self.discriminator.train(self.mode.is_train)

        # Get MNIST images (targets unused by the GAN objectives)
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Increment step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Train the discriminator (`discriminator_k` updates per step)
        with monit.section("discriminator"):
            for _ in range(self.discriminator_k):
                # Sample latent vectors; `detach` stops gradients into the generator
                latent = torch.randn(data.shape[0], 100, device=self.device)
                logits_true = self.discriminator(data)
                logits_false = self.discriminator(self.generator(latent).detach())
                loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
                loss = loss_true + loss_false

                # Log stuff
                tracker.add("loss.discriminator.true.", loss_true)
                tracker.add("loss.discriminator.false.", loss_false)
                tracker.add("loss.discriminator.", loss)

                # Train
                if self.mode.is_train:
                    self.discriminator_optimizer.zero_grad()
                    loss.backward()
                    # Parameter stats once per epoch
                    if batch_idx.is_last:
                        tracker.add('discriminator', self.discriminator)
                    self.discriminator_optimizer.step()

        # Train the generator
        with monit.section("generator"):
            latent = torch.randn(data.shape[0], 100, device=self.device)
            generated_images = self.generator(latent)
            # Gradients flow through the discriminator into the generator here
            logits = self.discriminator(generated_images)
            loss = self.generator_loss(logits)

            # Log stuff
            tracker.add('generated', generated_images[0:5])
            tracker.add("loss.generator.", loss)

            # Train
            if self.mode.is_train:
                self.generator_optimizer.zero_grad()
                loss.backward()
                # Parameter stats once per epoch
                if batch_idx.is_last:
                    tracker.add('generator', self.generator)
                self.generator_optimizer.step()

        tracker.save()
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    ## Configurations

    This extends MNIST configurations to get the data loaders and
    Training and validation loop configurations to simplify our
    implementation.
    """
    # Training device (CUDA if available, else CPU)
    device: torch.device = DeviceConfigs()
    # Named transform config for GAN-style MNIST preprocessing
    dataset_transforms = 'mnist_gan_transforms'
    # Number of training epochs
    epochs: int = 10
    # Whether to periodically save models
    is_save_models = True
    # Models — defaults are named config options (resolved by the config system)
    discriminator: Module = 'mlp'
    generator: Module = 'mlp'
    # Optimizers for the two networks
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    # Losses — `'original'` selects the original GAN loss configuration
    generator_loss: GeneratorLogitsLoss = 'original'
    discriminator_loss: DiscriminatorLogitsLoss = 'original'
    # Label smoothing amount used by the losses
    label_smoothing: float = 0.2
    # Number of discriminator updates per generator update
    discriminator_k: int = 1

    def init(self):
        """
        Initializations
        """
        # No RNN-style state to carry between train/valid
        self.state_modules = []

        # Monitor layer outputs of both networks
        hook_model_outputs(self.mode, self.generator, 'generator')
        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
        # Print losses; log generated images occasionally (1/100 of the time)
        tracker.set_scalar("loss.generator.*", True)
        tracker.set_scalar("loss.discriminator.*", True)
        tracker.set_image("generated", True, 1 / 100)

    def sample_z(self, batch_size: int):
        """
        $$z \sim p(z)$$

        Sample a batch of 100-dimensional standard-normal latent vectors.
        """
        return torch.randn(batch_size, 100, device=self.device)

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        Take a training step
        """

        # Set model states
        self.generator.train(self.mode.is_train)
        self.discriminator.train(self.mode.is_train)

        # Get MNIST images
        data = batch[0].to(self.device)

        # Increment step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Train the discriminator
        with monit.section("discriminator"):
            # Get discriminator loss
            loss = self.calc_discriminator_loss(data)

            # Train
            if self.mode.is_train:
                self.discriminator_optimizer.zero_grad()
                loss.backward()
                # Parameter stats once per epoch
                if batch_idx.is_last:
                    tracker.add('discriminator', self.discriminator)
                self.discriminator_optimizer.step()

        # Train the generator once in every `discriminator_k`
        if batch_idx.is_interval(self.discriminator_k):
            with monit.section("generator"):
                loss = self.calc_generator_loss(data.shape[0])

                # Train
                if self.mode.is_train:
                    self.generator_optimizer.zero_grad()
                    loss.backward()
                    # Parameter stats once per epoch
                    if batch_idx.is_last:
                        tracker.add('generator', self.generator)
                    self.generator_optimizer.step()

        tracker.save()

    def calc_discriminator_loss(self, data):
        """
        Calculate discriminator loss

        Classifies real images and detached generated images, and sums
        the two loss terms.
        """
        latent = self.sample_z(data.shape[0])
        logits_true = self.discriminator(data)
        # `detach` stops gradients from flowing into the generator
        logits_false = self.discriminator(self.generator(latent).detach())
        loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
        loss = loss_true + loss_false

        # Log stuff
        tracker.add("loss.discriminator.true.", loss_true)
        tracker.add("loss.discriminator.false.", loss_false)
        tracker.add("loss.discriminator.", loss)

        return loss

    def calc_generator_loss(self, batch_size: int):
        """
        Calculate generator loss

        Generates a batch of images and scores them with the discriminator;
        gradients flow through the discriminator into the generator.
        """
        latent = self.sample_z(batch_size)
        generated_images = self.generator(latent)
        logits = self.discriminator(generated_images)
        loss = self.generator_loss(logits)

        # Log stuff
        tracker.add('generated', generated_images[0:6])
        tracker.add("loss.generator.", loss)

        return loss