def train_net(Gen, Discr):
    G = Gen(in_noise=128, out_ch=3)
    # Polyak-averaged copy of G, used for evaluation only.
    G_polyak = copy.deepcopy(G).eval()
    D = Discr()
    print(G)
    print(D)

    def G_fun(batch):
        z = torch.randn(BS, 128, device=device)
        fake = G(z)
        preds = D(fake * 2 - 1).squeeze()
        loss = gan_loss.generated(preds)
        loss.backward()
        return {'loss': loss.item(), 'imgs': fake.detach()}

    def G_polyak_fun(batch):
        z = torch.randn(BS, 128, device=device)
        fake = G_polyak(z)
        return {'imgs': fake.detach()}

    def D_fun(batch):
        z = torch.randn(BS, 128, device=device)
        fake = G(z)
        fake_loss = gan_loss.fake(D(fake * 2 - 1))
        fake_loss.backward()

        x = batch[0]
        real_loss = gan_loss.real(D(x * 2 - 1))
        real_loss.backward()

        loss = real_loss.item() + fake_loss.item()
        return {
            'loss': loss,
            'real_loss': real_loss.item(),
            'fake_loss': fake_loss.item()
        }

    loop = GANRecipe(G, D, G_fun, D_fun, G_polyak_fun, dl, log_every=100)
    # Register the EMA model once so it is saved with checkpoints.
    loop.register('G_polyak', G_polyak)
    loop.G_loop.callbacks.add_callbacks([
        tcb.Optimizer(
            tch.optim.RAdamW(G.parameters(), lr=1e-4, betas=(0., 0.99))),
        tcb.Polyak(G, G_polyak),
    ])
    loop.callbacks.add_callbacks([
        tcb.Log('batch.0', 'x'),
        tcb.WindowedMetricAvg('real_loss'),
        tcb.WindowedMetricAvg('fake_loss'),
        tcb.Optimizer(
            tch.optim.RAdamW(D.parameters(), lr=4e-4, betas=(0., 0.99))),
    ])
    loop.test_loop.callbacks.add_callbacks([
        tcb.Log('imgs', 'polyak_imgs'),
        tcb.VisdomLogger('main', prefix='test')
    ])
    loop.to(device).run(100)
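
# Invocation sketch for train_net (illustrative, not part of the original
# file): train_net expects constructors, not instances. `Generator` and
# `Discriminator` are hypothetical classes matching the `Gen(in_noise=...,
# out_ch=...)` and `Discr()` calls above, and the module-level globals `BS`,
# `device`, `dl` (a DataLoader of real images) and `gan_loss` must already
# be defined.
#
#     train_net(Generator, Discriminator)
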
def test_tensorboard():
    from torchelie.recipes import Recipe
    batch_size = 4

    class Dataset:
        """Yields `batch_size` samples of each FashionMNIST class, as
        3-channel images with the blue channel zeroed out."""

        def __init__(self, batch_size):
            self.batch_size = batch_size
            self.mnist = FashionMNIST('.', download=True,
                                      transform=PILToTensor())
            self.classes = self.mnist.classes
            self.num_classes = len(self.mnist.class_to_idx)
            self.target_by_classes = [[
                idx for idx in range(len(self.mnist))
                if self.mnist.targets[idx] == i
            ] for i in range(self.num_classes)]

        def __len__(self):
            return self.batch_size * self.num_classes

        def __getitem__(self, item):
            idx = self.target_by_classes[item // self.batch_size][item]
            x, y = self.mnist[idx]
            x = torch.stack(3 * [x]).squeeze()
            x[2] = 0
            return x, y

    dst = Dataset(batch_size)

    def train(b):
        x, y = b
        return {
            'letter_number_int': int(y[0]),
            'letter_number_tensor': y[0],
            'letter_text': dst.classes[int(y[0])],
            'test_html': '<b>test HTML</b>',
            'letter_gray_img_HW': x[0, 0],
            'letter_gray_img_CHW': x[0, :1],
            'letter_gray_imgs_NCHW': x[:, :1],
            'letter_color_img_CHW': x[0],
            'letter_color_imgs_NCHW': x
        }

    r = Recipe(train, DataLoader(dst, batch_size))
    r.callbacks.add_callbacks([
        tcb.Counter(),
        tcb.TensorboardLogger(log_every=1),
        tcb.Log('letter_number_int', 'letter_number_int'),
        tcb.Log('letter_number_tensor', 'letter_number_tensor'),
        tcb.Log('letter_text', 'letter_text'),
        tcb.Log('test_html', 'test_html'),
        tcb.Log('letter_gray_img_HW', 'letter_gray_img_HW'),
        tcb.Log('letter_gray_img_CHW', 'letter_gray_img_CHW'),
        tcb.Log('letter_gray_imgs_NCHW', 'letter_gray_imgs_NCHW'),
        tcb.Log('letter_color_img_CHW', 'letter_color_img_CHW'),
        tcb.Log('letter_color_imgs_NCHW', 'letter_color_imgs_NCHW'),
    ])
    r.run(1)
def fit(self, iters, content_img, style_img, style_ratio,
        content_layers=None):
    """
    Run the recipe

    Args:
        iters (int): number of iterations to run
        content_img (PIL.Image): content image
        style_img (PIL.Image): style image
        style_ratio (float): weight of the style loss
        content_layers (list of str): layers on which to reconstruct content
    """
    self.loss.to(self.device)
    self.loss.set_style(pil2t(style_img).to(self.device), style_ratio)
    self.loss.set_content(pil2t(content_img).to(self.device), content_layers)

    canvas = ParameterizedImg(3, content_img.height, content_img.width,
                              init_sd=0.00)

    self.opt = tch.optim.RAdamW(canvas.parameters(), 1e-2, (0.7, 0.7),
                                eps=0.00001, weight_decay=0)

    def forward(_):
        self.opt.zero_grad()
        img = canvas()
        loss, losses = self.loss(img)
        loss.backward()
        return {
            'loss': loss,
            'content_loss': losses['content_loss'],
            'style_loss': losses['style_loss'],
            'img': img
        }

    loop = Recipe(forward, range(iters))
    loop.register('canvas', canvas)
    loop.register('model', self)
    loop.callbacks.add_callbacks([
        tcb.Counter(),
        tcb.WindowedMetricAvg('loss'),
        tcb.WindowedMetricAvg('content_loss'),
        tcb.WindowedMetricAvg('style_loss'),
        tcb.Log('img', 'img'),
        tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
        tcb.StdoutLogger(log_every=10),
        tcb.Optimizer(self.opt, log_lr=True),
        tcb.LRSched(torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.opt, threshold=0.001, cooldown=500),
                    step_each_batch=True)
    ])
    loop.to(self.device)
    loop.run(1)
    return canvas.render().cpu()
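
# Usage sketch for fit() above (illustrative; the owning class name and its
# constructor are assumptions, not part of this module):
#
#     from PIL import Image
#     stylizer = NeuralStyleRecipe(device='cuda')       # hypothetical ctor
#     result = stylizer.fit(1000,
#                           Image.open('content.jpg'),
#                           Image.open('style.jpg'),
#                           style_ratio=1.0)
#     # `result` is a CHW tensor in [0, 1]; convert it for saving:
#     TF.ToPILImage()(result).save('stylized.png')
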
def make_loop(hourglass, body, display, num_iter, lr):
    loop = TrainAndCall(hourglass, body, display, range(num_iter),
                        test_every=50, checkpoint=None)
    opt = tch.optim.RAdamW(hourglass.parameters(), lr=lr)
    loop.callbacks.add_callbacks([
        tcb.WindowedMetricAvg('loss'),
        tcb.Optimizer(opt, clip_grad_norm=0.5, log_lr=True),
    ])
    loop.test_loop.callbacks.add_callbacks([
        tcb.Log('recon', 'img'),
        tcb.Log('orig', 'orig'),
        tcb.Log('loss', 'loss'),
    ])
    return loop
def fit(self, n_iters, neuron):
    """
    Run the recipe

    Args:
        n_iters (int): number of iterations to run
        neuron (int): the feature map to maximize

    Returns:
        the optimized image
    """
    # Over-sized canvas so that a random crop can jitter the image each step.
    canvas = ParameterizedImg(3, self.input_size + 10, self.input_size + 10)

    def forward(_):
        cim = canvas()
        rnd = random.randint(0, cim.shape[2] // 10)
        im = cim[:, :, rnd:, rnd:]
        im = torch.nn.functional.interpolate(im,
                                             size=(self.input_size,
                                                   self.input_size),
                                             mode='bilinear')
        _, acts = self.model(self.norm(im), detach=False)
        fmap = acts[self.layer]
        loss = -fmap[0][neuron].sum()
        loss.backward()
        return {'loss': loss, 'img': cim}

    loop = Recipe(forward, range(n_iters))
    loop.register('canvas', canvas)
    loop.register('model', self)
    loop.callbacks.add_callbacks([
        tcb.Counter(),
        tcb.Log('loss', 'loss'),
        tcb.Log('img', 'img'),
        tcb.Optimizer(DeepDreamOptim(canvas.parameters(), lr=self.lr)),
        tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
        tcb.StdoutLogger(log_every=10)
    ])
    loop.to(self.device)
    loop.run(1)
    return canvas.render().cpu()
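
# Usage sketch for the feature-visualization fit() above (illustrative; the
# owning class is assumed to be constructed with a model, a layer name, an
# input size and a learning rate — the class name and constructor arguments
# below are assumptions):
#
#     vis = FeatureVis(model, 'layer3.2.conv2', input_size=224, lr=1e-3)
#     img = vis.fit(4000, neuron=42)   # maximize channel 42 of that layer
#     TF.ToPILImage()(img).save('neuron42.png')
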
def fit(self, ref, iters, lr=3e-4, device='cpu', visdom_env='deepdream'):
    """
    Args:
        ref (PIL.Image): the reference image to start dreaming from
        iters (int): number of iterations to run
        lr (float, optional): the learning rate
        device (str, optional): the device on which to run
        visdom_env (str or None): the name of the visdom env to use, or None
            to disable Visdom
    """
    ref_tensor = TF.ToTensor()(ref).unsqueeze(0)
    canvas = ParameterizedImg(1, 3, ref_tensor.shape[2], ref_tensor.shape[3],
                              init_img=ref_tensor, space='spectral',
                              colors='uncorr')

    def forward(_):
        img = canvas()
        # Jitter the crop origin to reduce high-frequency artifacts.
        rnd = random.randint(0, 10)
        loss = self.loss(self.norm(img[:, :, rnd:, rnd:]))
        loss.backward()
        return {'loss': loss, 'img': img}

    loop = Recipe(forward, range(iters))
    loop.register('model', self)
    loop.register('canvas', canvas)
    loop.callbacks.add_callbacks([
        tcb.Counter(),
        tcb.Log('loss', 'loss'),
        tcb.Log('img', 'img'),
        tcb.Optimizer(DeepDreamOptim(canvas.parameters(), lr=lr)),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=10),
        tcb.StdoutLogger(log_every=10)
    ])
    loop.to(device)
    loop.run(1)
    return canvas.render().cpu()
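
# Usage sketch for the DeepDream-style fit() above (illustrative; the owning
# class is assumed to wrap a model and a target layer whose activations define
# `self.loss` — the constructor below is hypothetical):
#
#     from PIL import Image
#     dd = DeepDreamRecipe(model, dream_layer='features.28')  # hypothetical
#     out = dd.fit(Image.open('photo.jpg'), iters=2000, device='cuda',
#                  visdom_env=None)
#     TF.ToPILImage()(out).save('dream.png')
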
def inpainting(img, mask, hourglass, input_dim, iters, lr, noise_std=1 / 30,
               device='cuda'):
    im = TFF.to_tensor(img)[None].to(device)
    mask = TFF.to_tensor(mask)[None].to(device)
    z = input_noise((im.shape[2], im.shape[3]), input_dim)
    z = z.to(device)
    print(hourglass)

    def body(batch):
        # Perturb the input noise at each step (Deep Image Prior style) and
        # penalize reconstruction error only inside the known (masked) region.
        recon = hourglass(z + torch.randn_like(z) * noise_std)
        loss = torch.sum(
            F.mse_loss(F.interpolate(recon, size=im.shape[2:],
                                     mode='nearest'),
                       im, reduction='none') * mask / mask.sum())
        loss.backward()
        return {"loss": loss}

    def display():
        recon = hourglass(z)
        recon = F.interpolate(recon, size=im.shape[2:], mode='nearest')
        loss = F.mse_loss(recon * mask, im)
        # Paste the network's prediction into the holes, keep known pixels.
        result = recon * (1 - mask) + im * mask
        return {
            "loss": loss,
            "recon": recon.clamp(0, 1),
            'orig': im,
            'result': result.clamp(0, 1)
        }

    loop = make_loop(hourglass, body, display, iters, lr)
    loop.test_loop.callbacks.add_callbacks([tcb.Log('result', 'result')])
    loop.to(device)
    loop.run(1)

    with torch.no_grad():
        hourglass.eval()
        return TFF.to_pil_image(hourglass(z)[0].cpu())
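
# Usage sketch for inpainting() (illustrative): any encoder-decoder mapping
# `input_dim` noise channels to 3 image channels works as the hourglass;
# `make_hourglass` below is a hypothetical helper, not part of this module.
#
#     from PIL import Image
#     net = make_hourglass(input_dim=32).to('cuda')
#     out = inpainting(Image.open('damaged.png'),
#                      Image.open('mask.png'),   # 1 where pixels are known
#                      net, input_dim=32, iters=4000, lr=1e-3)
#     out.save('inpainted.png')
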
def GANRecipe(G, D, G_fun, D_fun, loader, *, visdom_env='main',
              checkpoint='model', log_every=10):
    def D_wrap(batch):
        tu.freeze(G)
        G.eval()
        tu.unfreeze(D)
        D.train()
        return D_fun(batch)

    def G_wrap(batch):
        tu.freeze(D)
        D.eval()
        tu.unfreeze(G)
        G.train()
        return G_fun(batch)

    D_loop = Recipe(D_wrap, loader)
    D_loop.register('G', G)
    D_loop.register('D', D)
    G_loop = Recipe(G_wrap, range(1))
    D_loop.G_loop = G_loop
    D_loop.register('G_loop', G_loop)

    def prepare_G(state):
        # Propagate the training counters to the G loop so that its loggers
        # see consistent iteration numbers.
        G_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    D_loop.callbacks.add_prologues([tcb.Counter()])
    D_loop.callbacks.add_epilogues([
        tcb.CallRecipe(G_loop, 1, init_fun=prepare_G, prefix='G'),
        tcb.WindowedMetricAvg('loss'),
        tcb.Log('G_metrics.loss', 'G_loss'),
        tcb.Log('G_metrics.imgs', 'G_imgs'),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
        tcb.Checkpoint((visdom_env or 'model') + '/ckpt_{iters}.pth', D_loop)
    ])

    if checkpoint is not None:
        D_loop.callbacks.add_epilogues(
            [tcb.Checkpoint(checkpoint + '/ckpt', D_loop)])

    G_loop.callbacks.add_epilogues([
        tcb.Log('loss', 'loss'),
        tcb.Log('imgs', 'imgs'),
        tcb.WindowedMetricAvg('loss')
    ])

    return D_loop
def StyleGAN2Recipe(G: nn.Module,
                    D: nn.Module,
                    dataloader,
                    noise_size: int,
                    gpu_id: int,
                    total_num_gpus: int,
                    *,
                    G_lr: float = 2e-3,
                    D_lr: float = 2e-3,
                    tag: str = 'model',
                    ada: bool = True):
    """
    StyleGAN2 recipe distributed with DistributedDataParallel

    Args:
        G (nn.Module): a Generator.
        D (nn.Module): a Discriminator.
        dataloader: a dataloader conforming to torchvision's API.
        noise_size (int): the size of the input noise vector.
        gpu_id (int): the GPU index on which to run.
        total_num_gpus (int): the total number of GPUs.
        G_lr (float): RAdamW learning rate for G.
        D_lr (float): RAdamW learning rate for D.
        tag (str): tag for Visdom and checkpoints.
        ada (bool): whether to enable Adaptive Data Augmentation.

    Returns:
        recipe, G EMA model
    """
    G_polyak = copy.copy(G)
    G = nn.parallel.DistributedDataParallel(G.to(gpu_id), [gpu_id], gpu_id)
    D = nn.parallel.DistributedDataParallel(D.to(gpu_id), [gpu_id], gpu_id)
    print(G)
    print(D)

    optG: torch.optim.Optimizer = RAdamW(G.parameters(), G_lr,
                                         betas=(0., 0.99), weight_decay=0)
    optD: torch.optim.Optimizer = RAdamW(D.parameters(), D_lr,
                                         betas=(0., 0.99), weight_decay=0)

    batch_size = len(next(iter(dataloader))[0])
    diffTF = ADATF(-2 if not ada else -0.9,
                   50000 / (batch_size * total_num_gpus))

    ppl = PPL(4)

    def G_train(batch):
        with G.no_sync():
            pl = ppl(G, torch.randn(batch_size, noise_size, device=gpu_id))
        ##############
        #   G pass   #
        ##############
        imgs = G(torch.randn(batch_size, noise_size, device=gpu_id))
        pred = D(diffTF(imgs) * 2 - 1)
        score = gan_loss.generated(pred)
        score.backward()
        return {'G_loss': score.item(), 'ppl': pl}

    gradient_penalty = GradientPenalty(0.1)

    def D_train(batch):
        ###################
        #    Fake pass    #
        ###################
        with D.no_sync():  # Sync the gradient on the last backward
            noise = torch.randn(batch_size, noise_size, device=gpu_id)
            with torch.no_grad():
                fake = G(noise)
            fake.requires_grad_(True)
            fake.retain_grad()
            fake_tf = diffTF(fake) * 2 - 1
            fakeness = D(fake_tf).squeeze(1)
            fake_loss = gan_loss.fake(fakeness)
            fake_loss.backward()
            correct = (fakeness < 0).int().eq(1).float().sum()

        fake_grad = fake.grad.detach().norm(dim=1, keepdim=True)
        fake_grad /= fake_grad.max()

        tfmed = diffTF(batch[0]) * 2 - 1

        with D.no_sync():
            grad_norm = gradient_penalty(D, batch[0] * 2 - 1,
                                         fake.detach() * 2 - 1)

        ###################
        #    Real pass    #
        ###################
        real_out = D(tfmed)
        correct += (real_out > 0).detach().int().eq(1).float().sum()
        real_loss = gan_loss.real(real_out)
        real_loss.backward()

        pos_ratio = real_out.gt(0).float().mean().cpu().item()
        diffTF.log_loss(-pos_ratio)
        return {
            'imgs': fake.detach(),
            'i_grad': fake_grad,
            'loss': real_loss.item() + fake_loss.item(),
            'fake_loss': fake_loss.item(),
            'real_loss': real_loss.item(),
            'ADA-p': diffTF.p,
            'D-correct': correct / (2 * real_out.numel()),
            'grad_norm': grad_norm
        }

    tu.freeze(G_polyak)

    def test(batch):
        G_polyak.eval()

        def sample(N, n_iter, alpha=0.01, show_every=10):
            # test() runs under torch.no_grad() in GANRecipe, so gradients
            # must be re-enabled explicitly for the latent optimization.
            noise = torch.randn(N, noise_size, device=gpu_id,
                                requires_grad=True)
            opt = torch.optim.Adam([noise], lr=alpha)
            fakes = []
            for i in range(n_iter):
                # In-place updates of a leaf that requires grad must not be
                # tracked by autograd.
                with torch.no_grad():
                    noise += torch.randn_like(noise) / 10
                fake_batch = []
                opt.zero_grad()
                for j in range(0, N, batch_size):
                    with torch.enable_grad():
                        n_batch = noise[j:j + batch_size]
                        fake = G_polyak(n_batch, mixing=False)
                        fake_batch.append(fake)
                        log_prob = n_batch[:, 32:].pow(2).mul_(-0.5)
                        fakeness = -D(fake * 2 - 1).sum() - log_prob.sum()
                        fakeness.backward()
                opt.step()
                fake_batch = torch.cat(fake_batch, dim=0)

                if i % show_every == 0:
                    fakes.append(fake_batch.cpu().detach().clone())

            fakes.append(fake_batch.cpu().detach().clone())
            return torch.cat(fakes, dim=0)
        fake = sample(8, 50, alpha=0.001, show_every=10)

        # Interpolate between pairs of noise vectors for a latent sweep.
        noise1 = torch.randn(batch_size * 2 // 8, 1, noise_size,
                             device=gpu_id)
        noise2 = torch.randn(batch_size * 2 // 8, 1, noise_size,
                             device=gpu_id)
        t = torch.linspace(0, 1, 8, device=noise1.device).view(8, 1)
        noise = noise1 * t + noise2 * (1 - t)
        noise = noise.view(-1, noise_size)
        interp = torch.cat([
            G_polyak(n, mixing=False) for n in torch.split(noise, batch_size)
        ], dim=0)

        return {
            'polyak_imgs': fake,
            'polyak_interp': interp,
        }

    recipe = GANRecipe(G, D, G_train, D_train, test, dataloader,
                       visdom_env=tag if gpu_id == 0 else None,
                       log_every=10, test_every=1000,
                       checkpoint=tag if gpu_id == 0 else None,
                       g_every=1)

    recipe.callbacks.add_callbacks([
        tcb.Log('batch.0', 'x'),
        tcb.WindowedMetricAvg('fake_loss'),
        tcb.WindowedMetricAvg('real_loss'),
        tcb.WindowedMetricAvg('grad_norm'),
        tcb.WindowedMetricAvg('ADA-p'),
        tcb.WindowedMetricAvg('D-correct'),
        tcb.Log('i_grad', 'img_grad'),
        tch.callbacks.Optimizer(optD),
    ])
    recipe.G_loop.callbacks.add_callbacks([
        tch.callbacks.Optimizer(optG),
        tcb.Polyak(G.module, G_polyak,
                   0.5**((batch_size * total_num_gpus) / 20000)),
        tcb.WindowedMetricAvg('ppl'),
    ])
    recipe.test_loop.callbacks.add_callbacks([
        tcb.Log('polyak_imgs', 'polyak'),
        tcb.Log('polyak_interp', 'interp'),
    ])
    recipe.register('G_polyak', G_polyak)
    recipe.to(gpu_id)
    return recipe, G_polyak
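
# Launch sketch for StyleGAN2Recipe (illustrative): the recipe assumes one
# process per GPU with torch.distributed usable. The G/D factories and the
# dataloader factory below are user-supplied placeholders; `world_size`
# typically comes from torchrun's environment.
def _launch_stylegan2(rank, world_size, make_G, make_D, make_dataloader):
    import torch.distributed as dist
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    recipe, G_ema = StyleGAN2Recipe(make_G(), make_D(), make_dataloader(),
                                    noise_size=128, gpu_id=rank,
                                    total_num_gpus=world_size,
                                    tag='stylegan2')
    recipe.run(10)  # number of epochs over the dataloader
    return G_ema
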
def fit(self, iters, content_img, style_img, style_ratio,
        content_layers=None):
    """
    Run the recipe

    Args:
        iters (int): number of iterations to run
        content_img (PIL.Image): content image
        style_img (PIL.Image): style image
        style_ratio (float): weight of the style loss
        content_layers (list of str): layers on which to reconstruct content
    """
    self.loss.to(self.device)
    self.loss.set_style(pil2t(style_img).to(self.device), style_ratio)
    self.loss.set_content(pil2t(content_img).to(self.device), content_layers)

    # Second loss working at half resolution, for coarser style statistics.
    self.loss2.to(self.device)
    self.loss2.set_style(
        torch.nn.functional.interpolate(pil2t(style_img)[None],
                                        scale_factor=0.5,
                                        mode='bilinear')[0].to(self.device),
        style_ratio)
    self.loss2.set_content(
        torch.nn.functional.interpolate(pil2t(content_img)[None],
                                        scale_factor=0.5,
                                        mode='bilinear')[0].to(self.device),
        content_layers)

    canvas = ParameterizedImg(3, content_img.height, content_img.width,
                              init_img=pil2t(content_img))

    self.opt = tch.optim.RAdamW(canvas.parameters(), 3e-2)

    def forward(_):
        img = canvas()
        loss, losses = self.loss(img)
        loss.backward()
        # Backprop the half-resolution loss as well; note that only the
        # half-resolution metrics end up in the returned dict.
        loss, losses = self.loss2(
            torch.nn.functional.interpolate(canvas(), scale_factor=0.5,
                                            mode='bilinear'))
        loss.backward()
        return {
            'loss': loss,
            'content_loss': losses['content_loss'],
            'style_loss': losses['style_loss'],
            'img': img
        }

    loop = Recipe(forward, range(iters))
    loop.register('canvas', canvas)
    loop.register('model', self)
    loop.callbacks.add_callbacks([
        tcb.Counter(),
        tcb.WindowedMetricAvg('loss'),
        tcb.WindowedMetricAvg('content_loss'),
        tcb.WindowedMetricAvg('style_loss'),
        tcb.Log('img', 'img'),
        tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
        tcb.StdoutLogger(log_every=10),
        tcb.Optimizer(self.opt, log_lr=True),
    ])
    loop.to(self.device)
    loop.run(1)
    return canvas.render().cpu()
def GANRecipe(G: nn.Module,
              D: nn.Module,
              G_fun,
              D_fun,
              test_fun,
              loader: Iterable[Any],
              *,
              visdom_env: Optional[str] = 'main',
              checkpoint: Optional[str] = 'model',
              test_every: int = 1000,
              log_every: int = 10,
              g_every: int = 1) -> Recipe:
    def D_wrap(batch):
        tu.freeze(G)
        tu.unfreeze(D)
        return D_fun(batch)

    def G_wrap(batch):
        tu.freeze(D)
        tu.unfreeze(G)
        return G_fun(batch)

    def test_wrap(batch):
        tu.freeze(G)
        tu.freeze(D)
        D.eval()
        G.eval()
        with torch.no_grad():
            out = test_fun(batch)
        D.train()
        G.train()
        return out

    class NoLim:
        """Infinite view over `loader` yielding one batch per iteration and
        restarting the underlying iterator when it is exhausted."""

        def __init__(self):
            self.i = iter(loader)
            self.did_send = False

        def __iter__(self):
            return self

        def __next__(self):
            if self.did_send:
                self.did_send = False
                raise StopIteration
            self.did_send = True
            try:
                return next(self.i)
            except StopIteration:
                self.i = iter(loader)
                return next(self.i)

    D_loop = Recipe(D_wrap, loader)
    D_loop.register('G', G)
    D_loop.register('D', D)
    G_loop = Recipe(G_wrap, NoLim())
    D_loop.G_loop = G_loop
    D_loop.register('G_loop', G_loop)
    test_loop = Recipe(test_wrap, NoLim())
    D_loop.test_loop = test_loop
    D_loop.register('test_loop', test_loop)

    def prepare_G(state):
        G_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    def prepare_test(state):
        test_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    D_loop.callbacks.add_prologues([tcb.Counter()])
    D_loop.callbacks.add_epilogues([
        tcb.Log('imgs', 'G_imgs'),
        tcb.CallRecipe(G_loop, g_every, init_fun=prepare_G, prefix='G'),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
        tcb.CallRecipe(test_loop, test_every, init_fun=prepare_test,
                       prefix='Test'),
    ])
    G_loop.callbacks.add_epilogues([
        tcb.WindowedMetricAvg('G_loss'),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every,
                         post_epoch_ends=False)
    ])

    if checkpoint is not None:
        test_loop.callbacks.add_epilogues([
            tcb.Checkpoint(checkpoint + '/ckpt_{iters}.pth', D_loop),
            tcb.VisdomLogger(visdom_env=visdom_env),
        ])
    return D_loop
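
# Usage sketch for GANRecipe (illustrative): a toy fully-connected GAN wired
# into the recipe. Shapes, losses and the dataset are placeholders; real
# recipes supply convolutional models and a proper gan_loss, as in
# StyleGAN2Recipe above.
def _gan_recipe_example():
    from torch.utils.data import DataLoader, TensorDataset

    G = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
    D = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 1))
    real = DataLoader(TensorDataset(torch.randn(256, 2)), batch_size=32)

    def G_fun(batch):
        # Generator tries to raise D's score on fresh fakes.
        loss = -D(G(torch.randn(32, 16))).mean()
        loss.backward()
        return {'G_loss': loss.item()}

    def D_fun(batch):
        # Discriminator pushes fakes down and real samples up.
        fake = G(torch.randn(32, 16)).detach()
        loss = D(fake).mean() - D(batch[0]).mean()
        loss.backward()
        return {'loss': loss.item(), 'imgs': fake}

    def test_fun(batch):
        return {'imgs': G(torch.randn(32, 16))}

    loop = GANRecipe(G, D, G_fun, D_fun, test_fun, real,
                     visdom_env=None, checkpoint=None, test_every=100)
    loop.callbacks.add_callbacks(
        [tcb.Optimizer(torch.optim.Adam(D.parameters(), lr=1e-3))])
    loop.G_loop.callbacks.add_callbacks(
        [tcb.Optimizer(torch.optim.Adam(G.parameters(), lr=1e-3))])
    loop.run(1)
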