def fit(self, iters, content_img, style_img, style_ratio, content_layers=None):
    """
    Run the style-transfer recipe.

    Args:
        iters (int): number of iterations to run
        content_img (PIL.Image): content image; its height and width set the
            size of the optimized canvas
        style_img (PIL.Image): style image
        style_ratio (float): weight of the style loss
        content_layers (list of str, optional): layers on which to
            reconstruct content; forwarded to ``self.loss.set_content``

    Returns:
        the optimized image, as rendered by the canvas, moved to the CPU
    """
    self.loss.to(self.device)
    self.loss.set_style(pil2t(style_img).to(self.device), style_ratio)
    self.loss.set_content(pil2t(content_img).to(self.device), content_layers)

    canvas = ParameterizedImg(3, content_img.height, content_img.width,
                              init_sd=0.00)

    self.opt = tch.optim.RAdamW(canvas.parameters(),
                                1e-2, (0.7, 0.7),
                                eps=0.00001,
                                weight_decay=0)

    def forward(_):
        self.opt.zero_grad()
        img = canvas()
        loss, losses = self.loss(img)
        loss.backward()

        return {
            'loss': loss,
            'content_loss': losses['content_loss'],
            'style_loss': losses['style_loss'],
            'img': img
        }

    loop = Recipe(forward, range(iters))
    loop.register('canvas', canvas)
    loop.register('model', self)
    loop.callbacks.add_callbacks([
        tcb.Counter(),
        tcb.WindowedMetricAvg('loss'),
        tcb.WindowedMetricAvg('content_loss'),
        tcb.WindowedMetricAvg('style_loss'),
        tcb.Log('img', 'img'),
        tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
        tcb.StdoutLogger(log_every=10),
        tcb.Optimizer(self.opt, log_lr=True),
        # Steps the plateau scheduler every batch rather than every epoch.
        tcb.LRSched(torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.opt, threshold=0.001, cooldown=500),
                    step_each_batch=True)
    ])
    loop.to(self.device)
    loop.run(1)
    return canvas.render().cpu()
def fit(self, n_iters, neuron):
    """
    Synthesize an image that maximizes one feature map.

    Every step renders the canvas (``self.input_size + 10`` pixels a side),
    drops a random number of leading rows/columns (at most a tenth of
    dimension 2) as jitter, resizes the crop back to ``self.input_size`` and
    backpropagates the negated sum of entry ``[0][neuron]`` (presumably
    batch 0, channel ``neuron``) of the activations stored under
    ``self.layer``.

    Args:
        n_iters (int): number of iterations to run
        neuron (int): index of the feature map to maximize

    Returns:
        the optimized image, rendered by the canvas and moved to the CPU
    """
    padded_side = self.input_size + 10
    canvas = ParameterizedImg(3, padded_side, padded_side)

    def forward(_):
        full_img = canvas()
        jitter = random.randint(0, full_img.shape[2] // 10)
        crop = full_img[:, :, jitter:, jitter:]
        target_size = (self.input_size, self.input_size)
        crop = torch.nn.functional.interpolate(crop,
                                               size=target_size,
                                               mode='bilinear')
        _, activations = self.model(self.norm(crop), detach=False)
        # Gradient ascent on the activation, expressed as a loss to minimize.
        objective = -activations[self.layer][0][neuron].sum()
        objective.backward()
        return {'loss': objective, 'img': full_img}

    recipe = Recipe(forward, range(n_iters))
    recipe.register('canvas', canvas)
    recipe.register('model', self)
    callbacks = [
        tcb.Counter(),
        tcb.Log('loss', 'loss'),
        tcb.Log('img', 'img'),
        tcb.Optimizer(DeepDreamOptim(canvas.parameters(), lr=self.lr)),
        tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
        tcb.StdoutLogger(log_every=10),
    ]
    recipe.callbacks.add_callbacks(callbacks)
    recipe.to(self.device)
    recipe.run(1)
    return canvas.render().cpu()
def fit(self, ref, iters, lr=3e-4, device='cpu', visdom_env='deepdream'):
    """
    Run deep dream starting from a reference image.

    Args:
        ref: starting image, converted with ``TF.ToTensor`` (presumably a
            PIL image); it initializes the canvas and sets its size
        iters (int): number of iterations to run
        lr (float, optional): the learning rate
        device (str, optional): device the recipe is moved to before running
        visdom_env (str or None): the name of the visdom env to use, or None
            to disable Visdom

    Returns:
        the optimized image, rendered by the canvas and moved to the CPU
    """
    init = TF.ToTensor()(ref)[None]
    height, width = init.shape[2], init.shape[3]
    canvas = ParameterizedImg(1, 3, height, width,
                              init_img=init,
                              space='spectral',
                              colors='uncorr')

    def forward(_):
        dream = canvas()
        # Random crop offset applied before scoring (jitter).
        shift = random.randint(0, 10)
        loss = self.loss(self.norm(dream[:, :, shift:, shift:]))
        loss.backward()
        return {'loss': loss, 'img': dream}

    recipe = Recipe(forward, range(iters))
    recipe.register('model', self)
    recipe.register('canvas', canvas)
    callbacks = [
        tcb.Counter(),
        tcb.Log('loss', 'loss'),
        tcb.Log('img', 'img'),
        tcb.Optimizer(DeepDreamOptim(canvas.parameters(), lr=lr)),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=10),
        tcb.StdoutLogger(log_every=10),
    ]
    recipe.callbacks.add_callbacks(callbacks)
    recipe.to(device)
    recipe.run(1)
    return canvas.render().cpu()
def GANRecipe(G,
              D,
              G_fun,
              D_fun,
              loader,
              *,
              visdom_env='main',
              checkpoint='model',
              log_every=10):
    """
    Build a GAN recipe alternating discriminator and generator steps.

    The returned recipe iterates over ``loader`` calling ``D_fun`` with ``G``
    frozen and in eval mode and ``D`` unfrozen and in train mode. Its
    epilogues call the generator recipe through ``tcb.CallRecipe`` with
    interval 1 (presumably once per D iteration), which runs ``G_fun`` with
    ``D`` frozen/eval and ``G`` unfrozen/train.

    Args:
        G (nn.Module): generator
        D (nn.Module): discriminator
        G_fun (callable): generator step; its only "batch" comes from
            ``range(1)``. Its results presumably include ``'loss'`` and
            ``'imgs'`` (both are logged).
        D_fun (callable): discriminator step, called with each batch of
            ``loader``
        loader (iterable): data for the discriminator steps
        visdom_env (str or None): Visdom env name, or None to disable Visdom.
            Also used (falling back to ``'model'``) as the path prefix of the
            ``ckpt_{iters}.pth`` checkpoints.
        checkpoint (str or None): path prefix of the ``ckpt`` checkpoint, or
            None to disable checkpointing entirely
        log_every (int): logging interval in iterations

    Returns:
        the discriminator Recipe; the generator recipe is reachable as its
        ``G_loop`` attribute
    """

    def D_wrap(batch):
        tu.freeze(G)
        G.eval()
        tu.unfreeze(D)
        D.train()
        return D_fun(batch)

    def G_wrap(batch):
        tu.freeze(D)
        D.eval()
        tu.unfreeze(G)
        G.train()
        return G_fun(batch)

    D_loop = Recipe(D_wrap, loader)
    D_loop.register('G', G)
    D_loop.register('D', D)
    G_loop = Recipe(G_wrap, range(1))
    D_loop.G_loop = G_loop
    D_loop.register('G_loop', G_loop)

    def sync_G_state(state):
        # Copy the D loop's counters into the G loop so their logs line up.
        G_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    D_loop.callbacks.add_prologues([tcb.Counter()])
    D_loop.callbacks.add_epilogues([
        tcb.CallRecipe(G_loop, 1, init_fun=sync_G_state, prefix='G'),
        tcb.WindowedMetricAvg('loss'),
        tcb.Log('G_metrics.loss', 'G_loss'),
        tcb.Log('G_metrics.imgs', 'G_imgs'),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
    ])

    if checkpoint is not None:
        # Both checkpoints are gated here so that checkpoint=None really
        # disables saving; paths and callback order are otherwise unchanged.
        D_loop.callbacks.add_epilogues([
            tcb.Checkpoint((visdom_env or 'model') + '/ckpt_{iters}.pth',
                           D_loop),
            tcb.Checkpoint(checkpoint + '/ckpt', D_loop),
        ])

    G_loop.callbacks.add_epilogues([
        tcb.Log('loss', 'loss'),
        tcb.Log('imgs', 'imgs'),
        tcb.WindowedMetricAvg('loss')
    ])
    return D_loop
def fit(self, iters, content_img, style_img, style_ratio, content_layers=None):
    """
    Run a two-scale style-transfer recipe.

    The canvas is optimized against ``self.loss`` at full resolution and
    against ``self.loss2`` on a bilinearly half-scaled rendering. The
    gradients of both objectives accumulate in each step, and the reported
    metrics are the sums over both scales.

    Args:
        iters (int): number of iterations to run
        content_img (PIL.Image): content image; also initializes the canvas
            and sets its size
        style_img (PIL.Image): style image
        style_ratio (float): weight of the style loss (both scales)
        content_layers (list of str, optional): layers on which to
            reconstruct content (both scales)

    Returns:
        the optimized image, as rendered by the canvas, moved to the CPU
    """

    def half(t):
        # Bilinear 2x downscale of a single image tensor: a batch axis is
        # added for interpolate and removed afterwards.
        return torch.nn.functional.interpolate(t[None],
                                               scale_factor=0.5,
                                               mode='bilinear')[0]

    self.loss.to(self.device)
    self.loss.set_style(pil2t(style_img).to(self.device), style_ratio)
    self.loss.set_content(pil2t(content_img).to(self.device), content_layers)

    self.loss2.to(self.device)
    self.loss2.set_style(half(pil2t(style_img)).to(self.device), style_ratio)
    self.loss2.set_content(
        half(pil2t(content_img)).to(self.device), content_layers)

    canvas = ParameterizedImg(3,
                              content_img.height,
                              content_img.width,
                              init_img=pil2t(content_img))

    self.opt = tch.optim.RAdamW(canvas.parameters(), 3e-2)

    def forward(_):
        img = canvas()
        full_loss, full_losses = self.loss(img)
        full_loss.backward()

        # canvas() is rendered again because backward() (without
        # retain_graph) releases the graph built through `img`.
        half_img = torch.nn.functional.interpolate(canvas(),
                                                   scale_factor=0.5,
                                                   mode='bilinear')
        half_loss, half_losses = self.loss2(half_img)
        half_loss.backward()

        # Report the objective actually optimized (both scales), not only
        # the half-resolution term.
        return {
            'loss': full_loss + half_loss,
            'content_loss': (full_losses['content_loss'] +
                             half_losses['content_loss']),
            'style_loss': (full_losses['style_loss'] +
                           half_losses['style_loss']),
            'img': img
        }

    loop = Recipe(forward, range(iters))
    loop.register('canvas', canvas)
    loop.register('model', self)
    loop.callbacks.add_callbacks([
        tcb.Counter(),
        tcb.WindowedMetricAvg('loss'),
        tcb.WindowedMetricAvg('content_loss'),
        tcb.WindowedMetricAvg('style_loss'),
        tcb.Log('img', 'img'),
        tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
        tcb.StdoutLogger(log_every=10),
        tcb.Optimizer(self.opt, log_lr=True),
    ])
    loop.to(self.device)
    loop.run(1)
    return canvas.render().cpu()
def GANRecipe(G: nn.Module,
              D: nn.Module,
              G_fun,
              D_fun,
              test_fun,
              loader: Iterable[Any],
              *,
              visdom_env: Optional[str] = 'main',
              checkpoint: Optional[str] = 'model',
              test_every: int = 1000,
              log_every: int = 10,
              g_every: int = 1) -> Recipe:
    """
    Build a GAN recipe with separate discriminator, generator and test loops.

    The returned recipe iterates over ``loader`` calling ``D_fun`` with ``G``
    frozen and ``D`` unfrozen. Its epilogues run the generator loop
    (``G_fun`` with ``D`` frozen and ``G`` unfrozen) and the test loop
    (``test_fun`` under ``torch.no_grad()`` with both models frozen and in
    eval mode) through ``tcb.CallRecipe``. Each call of the G or test loop
    draws exactly one batch from ``loader``, restarting it when exhausted.

    Args:
        G: generator
        D: discriminator
        G_fun: generator step, called with one batch of ``loader``. Its
            results presumably include ``'G_loss'`` (averaged by a callback).
        D_fun: discriminator step, called with each batch of ``loader``.
            Its results presumably include ``'imgs'`` (logged as ``G_imgs``).
        test_fun: evaluation step, called with one batch of ``loader``
        loader: data source shared by the three loops
        visdom_env: Visdom env name, or None to disable Visdom
        checkpoint: path prefix for ``ckpt_{iters}.pth`` checkpoints written
            after test runs, or None to disable checkpointing
        test_every: interval given to the test loop's ``tcb.CallRecipe``
            (presumably: test every ``test_every`` D iterations)
        log_every: logging interval in iterations
        g_every: interval given to the G loop's ``tcb.CallRecipe``
            (presumably: one G step every ``g_every`` D iterations)

    Returns:
        the D loop; the G and test loops are reachable as its ``G_loop`` and
        ``test_loop`` attributes
    """

    def D_wrap(batch):
        tu.freeze(G)
        tu.unfreeze(D)
        return D_fun(batch)

    def G_wrap(batch):
        tu.freeze(D)
        tu.unfreeze(G)
        return G_fun(batch)

    def test_wrap(batch):
        tu.freeze(G)
        tu.freeze(D)
        D.eval()
        G.eval()
        try:
            with torch.no_grad():
                return test_fun(batch)
        finally:
            # Restore training mode even if test_fun raises.
            D.train()
            G.train()

    class NoLim:
        """Yield exactly one batch per pass, cycling through ``loader``."""

        def __init__(self):
            self.i = iter(loader)
            self.did_send = False

        def __iter__(self):
            return self

        def __next__(self):
            if self.did_send:
                # End the pass after one batch; the next pass starts fresh.
                self.did_send = False
                raise StopIteration
            self.did_send = True
            try:
                return next(self.i)
            except StopIteration:
                # Only exhaustion restarts the loader; other errors propagate.
                self.i = iter(loader)
                return next(self.i)

    D_loop = Recipe(D_wrap, loader)
    D_loop.register('G', G)
    D_loop.register('D', D)
    G_loop = Recipe(G_wrap, NoLim())
    D_loop.G_loop = G_loop
    D_loop.register('G_loop', G_loop)
    test_loop = Recipe(test_wrap, NoLim())
    D_loop.test_loop = test_loop
    D_loop.register('test_loop', test_loop)

    def G_test(state):
        # Copy the D loop's counters into the G loop so their logs line up.
        G_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    def prepare_test(state):
        # Same synchronization, for the test loop.
        test_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    D_loop.callbacks.add_prologues([tcb.Counter()])
    D_loop.callbacks.add_epilogues([
        tcb.Log('imgs', 'G_imgs'),
        tcb.CallRecipe(G_loop, g_every, init_fun=G_test, prefix='G'),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
        tcb.CallRecipe(test_loop,
                       test_every,
                       init_fun=prepare_test,
                       prefix='Test'),
    ])

    G_loop.callbacks.add_epilogues([
        tcb.WindowedMetricAvg('G_loss'),
        tcb.VisdomLogger(visdom_env=visdom_env,
                         log_every=log_every,
                         post_epoch_ends=False)
    ])

    if checkpoint is not None:
        # NOTE(review): the test loop's VisdomLogger is only registered when
        # checkpointing is enabled, so checkpoint=None also hides test
        # results from Visdom — confirm this coupling is intended.
        test_loop.callbacks.add_epilogues([
            tcb.Checkpoint(checkpoint + '/ckpt_{iters}.pth', D_loop),
            tcb.VisdomLogger(visdom_env=visdom_env),
        ])

    return D_loop