# Assumed module-level imports for this recipe (paths per torchelie's layout;
# adjust to your version):
import torch
import torch.nn.functional as F
import torchelie.callbacks as tcb
from torchelie.optim import RAdamW
from torchelie.recipes.trainandcall import TrainAndCall


def train(model, loader):
    def train_step(batch):
        x = batch[0]
        # Expand grayscale images to 3 channels for the model.
        x = x.expand(-1, 3, -1, -1)
        x2 = model(x * 2 - 1)
        # Per-pixel 256-way classification against the quantized input.
        loss = F.cross_entropy(x2, (x * 255).long())
        loss.backward()

        reconstruction = x2.argmax(dim=1).float() / 255.0
        return {'loss': loss, 'reconstruction': reconstruction}

    def after_train():
        imgs = model.sample(1, 4).expand(-1, 3, -1, -1)
        return {'imgs': imgs}

    opt = RAdamW(model.parameters(), lr=3e-3)
    trainer = TrainAndCall(model,
                           train_step,
                           after_train,
                           loader,
                           test_every=500,
                           visdom_env='pixelcnn')
    trainer.callbacks.add_callbacks([
        tcb.WindowedMetricAvg('loss'),
        tcb.Log('reconstruction', 'reconstruction'),
        tcb.Optimizer(opt, log_lr=True),
        tcb.LRSched(torch.optim.lr_scheduler.ReduceLROnPlateau(opt))
    ])
    trainer.test_loop.callbacks.add_callbacks([
        tcb.Log('imgs', 'imgs'),
    ])
    trainer.to('cuda')
    trainer.run(10)
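# A minimal usage sketch for the recipe above. The PixelCNN import path and
# constructor signature, and the MNIST pipeline, are assumptions for
# illustration; any autoregressive model returning per-pixel logits and
# exposing a `sample()` method should fit the same interface.
if __name__ == '__main__':
    import torchvision.transforms as TF
    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST
    from torchelie.models import PixelCNN  # assumed import path

    pixcnn = PixelCNN(64, (28, 28), channels=1)  # hypothetical signature
    dl = DataLoader(MNIST('.', download=True, transform=TF.ToTensor()),
                    batch_size=32,
                    shuffle=True,
                    num_workers=4)
    train(pixcnn, dl)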
def make_optimizer(self):
    return RAdamW(self.model.parameters(), lr=1e-2)
def make_optimizer(self):
    return RAdamW(self.model.parameters(), lr=3e-3, betas=(0.95, 0.9995))
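# Context sketch for the two overrides above: they read as overrides of a
# recipe class's `make_optimizer` hook. `SomeBaseRecipe` is a hypothetical
# stand-in; the only assumption is that the base recipe calls
# `make_optimizer()` when building its optimizer, so a subclass can swap
# hyperparameters without touching the rest of the training logic.
from torchelie.optim import RAdamW


class MyRecipe(SomeBaseRecipe):  # hypothetical base class
    def make_optimizer(self):
        # Same hook as above, with custom lr and betas.
        return RAdamW(self.model.parameters(), lr=3e-3, betas=(0.95, 0.9995))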
# Note: ADATF, PPL, GradientPenalty, gan_loss, tu (presumably
# torchelie.utils), tcb and tch come from the surrounding module's imports,
# elided here.
def StyleGAN2Recipe(G: nn.Module,
                    D: nn.Module,
                    dataloader,
                    noise_size: int,
                    gpu_id: int,
                    total_num_gpus: int,
                    *,
                    G_lr: float = 2e-3,
                    D_lr: float = 2e-3,
                    tag: str = 'model',
                    ada: bool = True):
    """
    StyleGAN2 Recipe distributed with DistributedDataParallel

    Args:
        G (nn.Module): a Generator.
        D (nn.Module): a Discriminator.
        dataloader: a dataloader conforming to torchvision's API.
        noise_size (int): the size of the input noise vector.
        gpu_id (int): the GPU index on which to run.
        total_num_gpus (int): the total number of GPUs used for training.
        G_lr (float): RAdamW lr for G
        D_lr (float): RAdamW lr for D
        tag (str): tag for Visdom and checkpoints
        ada (bool): whether to enable Adaptive Data Augmentation

    Returns:
        recipe, G EMA model
    """
    G_polyak = copy.copy(G)
    G = nn.parallel.DistributedDataParallel(G.to(gpu_id), [gpu_id], gpu_id)
    D = nn.parallel.DistributedDataParallel(D.to(gpu_id), [gpu_id], gpu_id)
    print(G)
    print(D)

    optG: torch.optim.Optimizer = RAdamW(G.parameters(),
                                         G_lr,
                                         betas=(0., 0.99),
                                         weight_decay=0)
    optD: torch.optim.Optimizer = RAdamW(D.parameters(),
                                         D_lr,
                                         betas=(0., 0.99),
                                         weight_decay=0)

    batch_size = len(next(iter(dataloader))[0])
    diffTF = ADATF(-2 if not ada else -0.9,
                   50000 / (batch_size * total_num_gpus))
    ppl = PPL(4)

    def G_train(batch):
        with G.no_sync():
            pl = ppl(G, torch.randn(batch_size, noise_size, device=gpu_id))
        ##############
        #   G pass   #
        ##############
        imgs = G(torch.randn(batch_size, noise_size, device=gpu_id))
        pred = D(diffTF(imgs) * 2 - 1)
        score = gan_loss.generated(pred)
        score.backward()
        return {'G_loss': score.item(), 'ppl': pl}

    gradient_penalty = GradientPenalty(0.1)

    def D_train(batch):
        ###################
        #    Fake pass    #
        ###################
        with D.no_sync():
            # Sync the gradient on the last backward
            noise = torch.randn(batch_size, noise_size, device=gpu_id)
            with torch.no_grad():
                fake = G(noise)
            fake.requires_grad_(True)
            fake.retain_grad()
            fake_tf = diffTF(fake) * 2 - 1
            fakeness = D(fake_tf).squeeze(1)
            fake_loss = gan_loss.fake(fakeness)
            fake_loss.backward()

            correct = (fakeness < 0).int().eq(1).float().sum()

            fake_grad = fake.grad.detach().norm(dim=1, keepdim=True)
            fake_grad /= fake_grad.max()

        tfmed = diffTF(batch[0]) * 2 - 1

        with D.no_sync():
            grad_norm = gradient_penalty(D, batch[0] * 2 - 1,
                                         fake.detach() * 2 - 1)

        ###################
        #    Real pass    #
        ###################
        real_out = D(tfmed)
        correct += (real_out > 0).detach().int().eq(1).float().sum()
        real_loss = gan_loss.real(real_out)
        real_loss.backward()

        pos_ratio = real_out.gt(0).float().mean().cpu().item()
        diffTF.log_loss(-pos_ratio)
        return {
            'imgs': fake.detach(),
            'i_grad': fake_grad,
            'loss': real_loss.item() + fake_loss.item(),
            'fake_loss': fake_loss.item(),
            'real_loss': real_loss.item(),
            'ADA-p': diffTF.p,
            'D-correct': correct / (2 * real_out.numel()),
            'grad_norm': grad_norm
        }

    tu.freeze(G_polyak)

    def test(batch):
        G_polyak.eval()

        def sample(N, n_iter, alpha=0.01, show_every=10):
            noise = torch.randn(N,
                                noise_size,
                                device=gpu_id,
                                requires_grad=True)
            opt = torch.optim.Adam([noise], lr=alpha)
            fakes = []
            for i in range(n_iter):
                noise += torch.randn_like(noise) / 10
                fake_batch = []
                opt.zero_grad()
                for j in range(0, N, batch_size):
                    # Re-enable autograd locally: the enclosing test loop
                    # runs without grad.
                    with torch.enable_grad():
                        n_batch = noise[j:j + batch_size]
                        fake = G_polyak(n_batch, mixing=False)
                        fake_batch.append(fake)
                        log_prob = n_batch[:, 32:].pow(2).mul_(-0.5)
                        fakeness = -D(fake * 2 - 1).sum() - log_prob.sum()
                        fakeness.backward()
                opt.step()
                fake_batch = torch.cat(fake_batch, dim=0)

                if i % show_every == 0:
                    fakes.append(fake_batch.cpu().detach().clone())

            fakes.append(fake_batch.cpu().detach().clone())
            return torch.cat(fakes, dim=0)

        fake = sample(8, 50, alpha=0.001, show_every=10)

        noise1 = torch.randn(batch_size * 2 // 8, 1, noise_size,
                             device=gpu_id)
        noise2 = torch.randn(batch_size * 2 // 8, 1, noise_size,
                             device=gpu_id)
        t = torch.linspace(0, 1, 8, device=noise1.device).view(8, 1)
        noise = noise1 * t + noise2 * (1 - t)
        noise = noise.view(-1, noise_size)
        interp = torch.cat([
            G_polyak(n, mixing=False) for n in torch.split(noise, batch_size)
        ], dim=0)

        return {
            'polyak_imgs': fake,
            'polyak_interp': interp,
        }

    recipe = GANRecipe(G,
                       D,
                       G_train,
                       D_train,
                       test,
                       dataloader,
                       visdom_env=tag if gpu_id == 0 else None,
                       log_every=10,
                       test_every=1000,
                       checkpoint=tag if gpu_id == 0 else None,
                       g_every=1)
    recipe.callbacks.add_callbacks([
        tcb.Log('batch.0', 'x'),
        tcb.WindowedMetricAvg('fake_loss'),
        tcb.WindowedMetricAvg('real_loss'),
        tcb.WindowedMetricAvg('grad_norm'),
        tcb.WindowedMetricAvg('ADA-p'),
        tcb.WindowedMetricAvg('D-correct'),
        tcb.Log('i_grad', 'img_grad'),
        tch.callbacks.Optimizer(optD),
    ])
    recipe.G_loop.callbacks.add_callbacks([
        tch.callbacks.Optimizer(optG),
        tcb.Polyak(G.module, G_polyak,
                   0.5**((batch_size * total_num_gpus) / 20000)),
        tcb.WindowedMetricAvg('ppl'),
    ])
    recipe.test_loop.callbacks.add_callbacks([
        tcb.Log('polyak_imgs', 'polyak'),
        tcb.Log('polyak_interp', 'interp'),
    ])
    recipe.register('G_polyak', G_polyak)
    recipe.to(gpu_id)
    return recipe, G_polyak
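# A minimal single-node launch sketch for StyleGAN2Recipe. The rendezvous
# environment variables, the `build_G` / `build_D` / `build_loader` helpers
# and the epoch count are assumptions for illustration, not part of the
# recipe above.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank, world_size):
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    G, D = build_G(), build_D()  # hypothetical model builders
    dl = build_loader(rank)      # hypothetical per-rank dataloader
    recipe, G_ema = StyleGAN2Recipe(G, D, dl,
                                    noise_size=128,
                                    gpu_id=rank,
                                    total_num_gpus=world_size)
    recipe.run(100)  # epoch count is arbitrary here


if __name__ == '__main__':
    n_gpus = torch.cuda.device_count()
    mp.spawn(_worker, args=(n_gpus,), nprocs=n_gpus)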
def CrossEntropyClassification(model,
                               train_loader,
                               test_loader,
                               classes,
                               lr=3e-3,
                               beta1=0.9,
                               wd=1e-2,
                               visdom_env='main',
                               test_every=1000,
                               log_every=100):
    """
    A Classification recipe with a default forward training / testing pass
    using cross entropy, and extended with RAdamW and ReduceLROnPlateau.

    Args:
        model (nn.Module): a model learnable with cross entropy
        train_loader (DataLoader): Training set dataloader
        test_loader (DataLoader): Testing set dataloader
        classes (list of str): class names, in order
        lr (float): the learning rate
        beta1 (float): RAdamW's beta1
        wd (float): weight decay
        visdom_env (str): name of the visdom environment to use, or None
            to disable Visdom (default: 'main')
        test_every (int): testing frequency, in number of iterations
            (default: 1000)
        log_every (int): logging frequency, in number of iterations
            (default: 100)
    """

    def train_step(batch):
        x, y = batch
        pred = model(x)
        loss = torch.nn.functional.cross_entropy(pred, y)
        loss.backward()
        return {'loss': loss, 'pred': pred}

    def validation_step(batch):
        x, y = batch
        pred = model(x)
        loss = torch.nn.functional.cross_entropy(pred, y)
        return {'loss': loss, 'pred': pred}

    loop = Classification(model,
                          train_step,
                          validation_step,
                          train_loader,
                          test_loader,
                          classes,
                          visdom_env=visdom_env,
                          test_every=test_every,
                          log_every=log_every)

    opt = RAdamW(model.parameters(),
                 lr=lr,
                 betas=(beta1, 0.999),
                 weight_decay=wd)
    loop.callbacks.add_callbacks([
        tcb.Optimizer(opt, log_lr=True),
        tcb.LRSched(torch.optim.lr_scheduler.ReduceLROnPlateau(opt))
    ])
    return loop
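# A minimal usage sketch; the CIFAR-10 pipeline and the ResNet choice are
# assumptions for illustration.
if __name__ == '__main__':
    import torchvision.transforms as TF
    from torch.utils.data import DataLoader
    from torchvision.datasets import CIFAR10
    from torchvision.models import resnet18

    trainset = CIFAR10('.', train=True, download=True,
                       transform=TF.ToTensor())
    testset = CIFAR10('.', train=False, download=True,
                      transform=TF.ToTensor())
    clf = CrossEntropyClassification(
        resnet18(num_classes=10),
        DataLoader(trainset, batch_size=64, shuffle=True),
        DataLoader(testset, batch_size=64),
        trainset.classes)
    clf.to('cuda')
    clf.run(5)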
def MixupClassification(model,
                        train_loader,
                        test_loader,
                        classes,
                        *,
                        lr=3e-3,
                        beta1=0.9,
                        wd=1e-2,
                        visdom_env='main',
                        test_every=1000,
                        log_every=100):
    """
    A Classification recipe with a default forward training / testing pass
    using cross entropy and mixup, and extended with RAdamW and
    ReduceLROnPlateau.

    Args:
        model (nn.Module): a model learnable with cross entropy
        train_loader (DataLoader): Training set dataloader. Must have soft
            targets. Should be a DataLoader loading a MixupDataset or
            compatible.
        test_loader (DataLoader): Testing set dataloader. Dataset must have
            categorical targets.
        classes (list of str): class names, in order
        lr (float): the learning rate
        beta1 (float): RAdamW's beta1
        wd (float): weight decay
        visdom_env (str): name of the visdom environment to use, or None
            to disable Visdom (default: 'main')
        test_every (int): testing frequency, in number of iterations
            (default: 1000)
        log_every (int): logging frequency, in number of iterations
            (default: 100)
    """
    from torchelie.loss import continuous_cross_entropy

    def train_step(batch):
        x, y = batch
        pred = model(x)
        loss = continuous_cross_entropy(pred, y)
        loss.backward()
        return {'loss': loss}

    def validation_step(batch):
        x, y = batch
        pred = model(x)
        loss = torch.nn.functional.cross_entropy(pred, y)
        return {'loss': loss, 'pred': pred}

    loop = TrainAndTest(model,
                        train_step,
                        validation_step,
                        train_loader,
                        test_loader,
                        visdom_env=visdom_env,
                        test_every=test_every,
                        log_every=log_every)

    loop.callbacks.add_callbacks([
        tcb.WindowedMetricAvg('loss'),
    ])

    loop.register('classes', classes)

    loop.test_loop.callbacks.add_callbacks([
        tcb.AccAvg(post_each_batch=False),
        tcb.WindowedMetricAvg('loss', False),
    ])

    if visdom_env is not None:
        loop.callbacks.add_epilogues(
            [tcb.ImageGradientVis(), tcb.MetricsTable()])

    if len(classes) <= 25:
        loop.test_loop.callbacks.add_callbacks([
            tcb.ConfusionMatrix(classes),
        ])

    loop.test_loop.callbacks.add_callbacks([
        tcb.ClassificationInspector(30, classes, False),
        tcb.MetricsTable(False)
    ])

    opt = RAdamW(model.parameters(),
                 lr=lr,
                 betas=(beta1, 0.999),
                 weight_decay=wd)
    loop.callbacks.add_callbacks([
        tcb.Optimizer(opt, log_lr=True),
        tcb.LRSched(torch.optim.lr_scheduler.ReduceLROnPlateau(opt))
    ])
    return loop
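# A minimal usage sketch. Per the docstring, the training loader must yield
# soft targets; `MixupDataset` is the wrapper the docstring names, but its
# import path and signature are assumptions here.
if __name__ == '__main__':
    import torchvision.transforms as TF
    from torch.utils.data import DataLoader
    from torchvision.datasets import CIFAR10
    from torchvision.models import resnet18

    trainset = CIFAR10('.', train=True, download=True,
                       transform=TF.ToTensor())
    testset = CIFAR10('.', train=False, download=True,
                      transform=TF.ToTensor())
    mixed = MixupDataset(trainset)  # hypothetical wrapper with soft targets
    clf = MixupClassification(
        resnet18(num_classes=10),
        DataLoader(mixed, batch_size=64, shuffle=True),
        DataLoader(testset, batch_size=64),
        trainset.classes)
    clf.to('cuda')
    clf.run(5)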
def CrossEntropyClassification(model,
                               train_loader,
                               test_loader,
                               classes,
                               lr=3e-3,
                               beta1=0.9,
                               wd=1e-2,
                               visdom_env='main',
                               test_every=1000,
                               log_every=100):
    """
    Extends Classification with default cross entropy forward passes. Also
    adds RAdamW and ReduceLROnPlateau.

    Inherited training callbacks:

    - AccAvg for displaying accuracy
    - WindowedMetricAvg for displaying loss
    - ConfusionMatrix if len(classes) <= 25
    - ClassificationInspector
    - MetricsTable
    - ImageGradientVis
    - Counter for counting iterations, connected to the testing loop as well
    - VisdomLogger
    - StdoutLogger

    Training callbacks:

    - Optimizer with RAdamW
    - LRSched with ReduceLROnPlateau

    Testing:
        Testing loop is in :code:`.test_loop`.

    Inherited testing callbacks:

    - AccAvg
    - WindowedMetricAvg
    - ConfusionMatrix if :code:`len(classes) <= 25`
    - ClassificationInspector
    - MetricsTable
    - VisdomLogger
    - StdoutLogger
    - Checkpoint saving the best testing loss

    Args:
        model (nn.Module): a model learnable with cross entropy
        train_loader (DataLoader): Training set dataloader
        test_loader (DataLoader): Testing set dataloader
        classes (list of str): class names, in order
        lr (float): the learning rate
        beta1 (float): RAdamW's beta1
        wd (float): weight decay
        visdom_env (str): name of the visdom environment to use, or None
            to disable Visdom (default: 'main')
        test_every (int): testing frequency, in number of iterations
            (default: 1000)
        log_every (int): logging frequency, in number of iterations
            (default: 100)
    """

    def train_step(batch):
        x, y = batch
        pred = model(x)
        loss = torch.nn.functional.cross_entropy(pred, y)
        loss.backward()
        return {'loss': loss, 'pred': pred}

    def validation_step(batch):
        x, y = batch
        pred = model(x)
        loss = torch.nn.functional.cross_entropy(pred, y)
        return {'loss': loss, 'pred': pred}

    loop = Classification(model,
                          train_step,
                          validation_step,
                          train_loader,
                          test_loader,
                          classes,
                          visdom_env=visdom_env,
                          test_every=test_every,
                          log_every=log_every)

    opt = RAdamW(model.parameters(),
                 lr=lr,
                 betas=(beta1, 0.999),
                 weight_decay=wd)
    loop.callbacks.add_callbacks([
        tcb.Optimizer(opt, log_lr=True),
        tcb.LRSched(torch.optim.lr_scheduler.ReduceLROnPlateau(opt))
    ])
    return loop