def fit(self, iters, content_img, style_img, style_ratio, content_layers=None):
    """
    Run the recipe

    Args:
        iters (int): number of iterations to run
        content_img (PIL.Image): content image
        style_img (PIL.Image): style image
        style_ratio (float): weight of the style loss
        content_layers (list of str): layers on which to reconstruct content
    """
    self.loss.to(self.device)
    self.loss.set_style(pil2t(style_img).to(self.device), style_ratio)
    self.loss.set_content(pil2t(content_img).to(self.device), content_layers)

    canvas = ParameterizedImg(3,
                              content_img.height,
                              content_img.width,
                              init_sd=0.00)

    self.opt = tch.optim.RAdamW(canvas.parameters(),
                                1e-2, (0.7, 0.7),
                                eps=0.00001,
                                weight_decay=0)

    def forward(_):
        self.opt.zero_grad()
        img = canvas()
        loss, losses = self.loss(img)
        loss.backward()
        return {
            'loss': loss,
            'content_loss': losses['content_loss'],
            'style_loss': losses['style_loss'],
            'img': img
        }

    loop = Recipe(forward, range(iters))
    loop.register('canvas', canvas)
    loop.register('model', self)
    loop.callbacks.add_callbacks([
        tcb.Counter(),
        tcb.WindowedMetricAvg('loss'),
        tcb.WindowedMetricAvg('content_loss'),
        tcb.WindowedMetricAvg('style_loss'),
        tcb.Log('img', 'img'),
        tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
        tcb.StdoutLogger(log_every=10),
        tcb.Optimizer(self.opt, log_lr=True),
        tcb.LRSched(torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.opt, threshold=0.001, cooldown=500),
                    step_each_batch=True)
    ])
    loop.to(self.device)
    loop.run(1)
    return canvas.render().cpu()
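# --- Usage sketch (illustrative; not part of the original source) ---
# A minimal way to drive the fit() above. Only the fit() signature comes
# from the code; `recipe` stands for an instance of the enclosing
# neural-style recipe class, and the paths, iteration count and
# style_ratio value are placeholders.
def _example_style_transfer(recipe, content_path, style_path):
    from PIL import Image

    content = Image.open(content_path).convert('RGB')
    style = Image.open(style_path).convert('RGB')

    # style_ratio balances the style loss against the content loss; larger
    # values give a more heavily stylized result.
    result = recipe.fit(1000,
                        content_img=content,
                        style_img=style,
                        style_ratio=1e2,
                        content_layers=None)

    # fit() returns the optimized canvas as a CPU tensor (3 x H x W).
    return result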
def fit(self, n_iters, neuron):
    """
    Run the recipe

    Args:
        n_iters (int): number of iterations to run
        neuron (int): the feature map to maximize

    Returns:
        the optimized image
    """
    canvas = ParameterizedImg(3, self.input_size + 10, self.input_size + 10)

    def forward(_):
        cim = canvas()
        rnd = random.randint(0, cim.shape[2] // 10)
        im = cim[:, :, rnd:, rnd:]
        im = torch.nn.functional.interpolate(im,
                                             size=(self.input_size,
                                                   self.input_size),
                                             mode='bilinear')
        _, acts = self.model(self.norm(im), detach=False)
        fmap = acts[self.layer]
        loss = -fmap[0][neuron].sum()
        loss.backward()
        return {'loss': loss, 'img': cim}

    loop = Recipe(forward, range(n_iters))
    loop.register('canvas', canvas)
    loop.register('model', self)
    loop.callbacks.add_callbacks([
        tcb.Counter(),
        tcb.Log('loss', 'loss'),
        tcb.Log('img', 'img'),
        tcb.Optimizer(DeepDreamOptim(canvas.parameters(), lr=self.lr)),
        tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
        tcb.StdoutLogger(log_every=10)
    ])
    loop.to(self.device)
    loop.run(1)
    return canvas.render().cpu()
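# --- Usage sketch (illustrative; not part of the original source) ---
# Visualizing a few channels of the configured layer with the fit() above.
# `vis` stands for an instance of the enclosing feature-visualization
# recipe (model, layer, input_size and lr already set); the neuron indices
# and iteration count are placeholders.
def _example_feature_vis(vis, neurons=(0, 1, 2), n_iters=500):
    images = []
    for neuron in neurons:
        # Each call jitters the canvas with a random crop, resizes it to
        # the model's input size and ascends the chosen feature map.
        images.append(vis.fit(n_iters, neuron))
    return images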
def fit(self, ref, iters, lr=3e-4, device='cpu', visdom_env='deepdream'):
    """
    Run the recipe

    Args:
        ref (PIL.Image): the reference image to dream from
        iters (int): number of iterations to run
        lr (float, optional): the learning rate
        device (str): a torch device string
        visdom_env (str or None): the name of the visdom env to use, or None
            to disable Visdom
    """
    ref_tensor = TF.ToTensor()(ref).unsqueeze(0)
    canvas = ParameterizedImg(1,
                              3,
                              ref_tensor.shape[2],
                              ref_tensor.shape[3],
                              init_img=ref_tensor,
                              space='spectral',
                              colors='uncorr')

    def forward(_):
        img = canvas()
        rnd = random.randint(0, 10)
        loss = self.loss(self.norm(img[:, :, rnd:, rnd:]))
        loss.backward()
        return {'loss': loss, 'img': img}

    loop = Recipe(forward, range(iters))
    loop.register('model', self)
    loop.register('canvas', canvas)
    loop.callbacks.add_callbacks([
        tcb.Counter(),
        tcb.Log('loss', 'loss'),
        tcb.Log('img', 'img'),
        tcb.Optimizer(DeepDreamOptim(canvas.parameters(), lr=lr)),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=10),
        tcb.StdoutLogger(log_every=10)
    ])
    loop.to(device)
    loop.run(1)
    return canvas.render().cpu()
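# --- Usage sketch (illustrative; not part of the original source) ---
# Dreaming from a reference photo with the fit() above. `dreamer` stands
# for an instance of the enclosing DeepDream-style recipe; the path,
# iteration count and learning rate are placeholders. visdom_env=None
# relies on the behavior documented in the docstring above.
def _example_deepdream(dreamer, image_path):
    from PIL import Image

    ref = Image.open(image_path).convert('RGB')
    return dreamer.fit(ref,
                       iters=4000,
                       lr=3e-4,
                       device='cpu',
                       visdom_env=None)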
def TrainAndTest(model,
                 train_fun,
                 test_fun,
                 train_loader,
                 test_loader,
                 test_every=100,
                 visdom_env='main',
                 log_every=10,
                 checkpoint='model'):
    """
    Perform training and testing on datasets. The model is automatically
    checkpointed, and VisdomLogger and StdoutLogger callbacks are already
    provided. Gradients are disabled and the model is automatically set to
    evaluation mode for the evaluation procedure.

    Args:
        model (nn.Module): a model
        train_fun (Callable): a function that takes a batch as a single
            argument, performs a training step and returns a dict of values
            to populate the recipe's state.
        test_fun (Callable): a function that takes a batch as a single
            argument, evaluates the model on it and returns a dict to
            populate the state.
        train_loader (DataLoader): the training set dataloader
        test_loader (DataLoader): the testing set dataloader
        test_every (int): testing frequency, in number of iterations
            (default: 100)
        visdom_env (str): name of the visdom environment to use, or None
            to disable Visdom (default: 'main')
        log_every (int): logging frequency, in number of iterations
            (default: 10)
        checkpoint (str): checkpointing path, or None to disable
            checkpointing (default: 'model')
    """

    def eval_call(batch):
        model.eval()
        with torch.no_grad():
            out = test_fun(batch)
        model.train()
        return out

    train_loop = Recipe(train_fun, train_loader)
    train_loop.register('model', model)

    test_loop = Recipe(eval_call, test_loader)
    train_loop.test_loop = test_loop
    train_loop.register('test_loop', test_loop)

    def prepare_test(state):
        test_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    train_loop.callbacks.add_prologues([tcb.Counter()])

    train_loop.callbacks.add_epilogues([
        tcb.CallRecipe(test_loop, test_every, init_fun=prepare_test),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
    ])

    test_loop.callbacks.add_epilogues([
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=-1, prefix='test_'),
        tcb.StdoutLogger(log_every=-1, prefix='Test'),
    ])

    if checkpoint is not None:
        test_loop.callbacks.add_epilogues([
            tcb.Checkpoint(checkpoint + '/ckpt_{iters}.pth',
                           train_loop,
                           key_best=lambda state: -state['test_loop'][
                               'callbacks']['state']['metrics']['loss'])
        ])

    return train_loop
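# --- Usage sketch (illustrative; not part of the original source) ---
# What the train_fun / test_fun callables expected by TrainAndTest can look
# like for a plain classifier. Only TrainAndTest's signature and the
# .to()/.run() pattern come from the code above; the optimizer, loss and
# metric choices are placeholders, and module-level imports (torch, tcb)
# are assumed available as in the surrounding code.
def _example_train_and_test(model, train_loader, test_loader, device='cpu'):
    import torch.nn.functional as F

    opt = torch.optim.Adam(model.parameters(), lr=1e-3)

    def train_fun(batch):
        dev = next(model.parameters()).device
        x, y = batch[0].to(dev), batch[1].to(dev)
        opt.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        opt.step()
        # The returned dict populates the recipe state (and the loggers).
        return {'loss': loss}

    def test_fun(batch):
        # Called by TrainAndTest with gradients disabled and model.eval().
        dev = next(model.parameters()).device
        x, y = batch[0].to(dev), batch[1].to(dev)
        out = model(x)
        return {
            'loss': F.cross_entropy(out, y),
            'acc': (out.argmax(1) == y).float().mean()
        }

    loop = TrainAndTest(model,
                        train_fun,
                        test_fun,
                        train_loader,
                        test_loader,
                        test_every=100,
                        visdom_env=None,
                        log_every=10,
                        checkpoint=None)
    # Aggregate the returned 'loss' so the loggers have something to show.
    loop.callbacks.add_callbacks([tcb.WindowedMetricAvg('loss')])
    loop.test_loop.callbacks.add_callbacks([tcb.WindowedMetricAvg('loss')])
    loop.to(device)
    loop.run(1)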
def GANRecipe(G,
              D,
              G_fun,
              D_fun,
              loader,
              *,
              visdom_env='main',
              checkpoint='model',
              log_every=10):
    """
    Set up a GAN training recipe: `D_fun` is called on every batch of
    `loader`, and `G_fun` is called once after each discriminator step.
    While one network trains, the other is frozen and put in eval mode.
    Progress is logged to stdout and Visdom, and the recipe is checkpointed
    if `checkpoint` is not None. Returns the discriminator Recipe.
    """

    def D_wrap(batch):
        tu.freeze(G)
        G.eval()
        tu.unfreeze(D)
        D.train()

        return D_fun(batch)

    def G_wrap(batch):
        tu.freeze(D)
        D.eval()
        tu.unfreeze(G)
        G.train()

        return G_fun(batch)

    D_loop = Recipe(D_wrap, loader)
    D_loop.register('G', G)
    D_loop.register('D', D)

    G_loop = Recipe(G_wrap, range(1))
    D_loop.G_loop = G_loop
    D_loop.register('G_loop', G_loop)

    def prepare_test(state):
        G_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    D_loop.callbacks.add_prologues([tcb.Counter()])

    D_loop.callbacks.add_epilogues([
        tcb.CallRecipe(G_loop, 1, init_fun=prepare_test, prefix='G'),
        tcb.WindowedMetricAvg('loss'),
        tcb.Log('G_metrics.loss', 'G_loss'),
        tcb.Log('G_metrics.imgs', 'G_imgs'),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
    ])

    if checkpoint is not None:
        D_loop.callbacks.add_epilogues([
            tcb.Checkpoint(checkpoint + '/ckpt_{iters}.pth', D_loop)
        ])

    G_loop.callbacks.add_epilogues([
        tcb.Log('loss', 'loss'),
        tcb.Log('imgs', 'imgs'),
        tcb.WindowedMetricAvg('loss')
    ])

    return D_loop
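# --- Usage sketch (illustrative; not part of the original source) ---
# The shape of the closures this recipe expects: each receives a batch,
# performs one optimization step for its network (the recipe has already
# frozen the other one), and returns a dict that populates the recipe
# state. The hinge loss, Adam settings and noise dimension are placeholder
# choices, and the loader is assumed to yield (images, labels) pairs.
def _example_gan_closures(G, D, batch_size=32, noise_dim=128):
    optG = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
    optD = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

    def D_fun(batch):
        real = batch[0]
        fake = G(torch.randn(real.shape[0], noise_dim, device=real.device))
        loss = (torch.relu(1 - D(real)).mean() +
                torch.relu(1 + D(fake.detach())).mean())
        optD.zero_grad()
        loss.backward()
        optD.step()
        # 'loss' feeds the WindowedMetricAvg('loss') registered above.
        return {'loss': loss}

    def G_fun(_):
        # The generator loop iterates over range(1), so its argument is just
        # an index; draw a fresh noise batch instead.
        dev = next(G.parameters()).device
        imgs = G(torch.randn(batch_size, noise_dim, device=dev))
        loss = -D(imgs).mean()
        optG.zero_grad()
        loss.backward()
        optG.step()
        # 'loss' and 'imgs' match the keys logged by G_loop above.
        return {'loss': loss, 'imgs': imgs.detach()}

    return G_fun, D_fun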
def fit(self, iters, content_img, style_img, style_ratio, content_layers=None):
    """
    Run the recipe

    Args:
        iters (int): number of iterations to run
        content_img (PIL.Image): content image
        style_img (PIL.Image): style image
        style_ratio (float): weight of the style loss
        content_layers (list of str): layers on which to reconstruct content
    """
    self.loss.to(self.device)
    self.loss.set_style(pil2t(style_img).to(self.device), style_ratio)
    self.loss.set_content(pil2t(content_img).to(self.device), content_layers)

    # Second loss working on half-resolution versions of the style and
    # content images.
    self.loss2.to(self.device)
    self.loss2.set_style(
        torch.nn.functional.interpolate(pil2t(style_img)[None],
                                        scale_factor=0.5,
                                        mode='bilinear')[0].to(self.device),
        style_ratio)
    self.loss2.set_content(
        torch.nn.functional.interpolate(pil2t(content_img)[None],
                                        scale_factor=0.5,
                                        mode='bilinear')[0].to(self.device),
        content_layers)

    canvas = ParameterizedImg(3,
                              content_img.height,
                              content_img.width,
                              init_img=pil2t(content_img))

    self.opt = tch.optim.RAdamW(canvas.parameters(), 3e-2)

    def forward(_):
        img = canvas()
        loss, losses = self.loss(img)
        loss.backward()

        loss, losses = self.loss2(
            torch.nn.functional.interpolate(canvas(),
                                            scale_factor=0.5,
                                            mode='bilinear'))
        loss.backward()
        return {
            'loss': loss,
            'content_loss': losses['content_loss'],
            'style_loss': losses['style_loss'],
            'img': img
        }

    loop = Recipe(forward, range(iters))
    loop.register('canvas', canvas)
    loop.register('model', self)
    loop.callbacks.add_callbacks([
        tcb.Counter(),
        tcb.WindowedMetricAvg('loss'),
        tcb.WindowedMetricAvg('content_loss'),
        tcb.WindowedMetricAvg('style_loss'),
        tcb.Log('img', 'img'),
        tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
        tcb.StdoutLogger(log_every=10),
        tcb.Optimizer(self.opt, log_lr=True),
    ])
    loop.to(self.device)
    loop.run(1)
    return canvas.render().cpu()
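# --- Note on the two-scale loss above (illustrative sketch) ---
# The second loss sees the style image, the content image and the canvas at
# half resolution, so coarse structure is matched as well as fine texture.
# pil2t() is assumed to return a 3 x H x W tensor, hence the [None] / [0]
# around interpolate(), which expects a batch dimension:
def _half_resolution(img_chw):
    # 3 x H x W -> 1 x 3 x H x W -> bilinear downsample -> 3 x H/2 x W/2
    return torch.nn.functional.interpolate(img_chw[None],
                                           scale_factor=0.5,
                                           mode='bilinear')[0]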
def GANRecipe(G: nn.Module,
              D: nn.Module,
              G_fun,
              D_fun,
              test_fun,
              loader: Iterable[Any],
              *,
              visdom_env: Optional[str] = 'main',
              checkpoint: Optional[str] = 'model',
              test_every: int = 1000,
              log_every: int = 10,
              g_every: int = 1) -> Recipe:
    """
    Set up a GAN training recipe: `D_fun` is called on every batch of
    `loader`, `G_fun` every `g_every` discriminator steps, and `test_fun`
    every `test_every` steps with gradients disabled and both networks in
    eval mode. While one network trains, the other is frozen. Progress is
    logged to stdout and Visdom, and the whole recipe is checkpointed at
    test time if `checkpoint` is not None.

    Args:
        G (nn.Module): the generator
        D (nn.Module): the discriminator
        G_fun (Callable): takes a batch, performs a generator step and
            returns a dict to populate the state
        D_fun (Callable): same, for the discriminator
        test_fun (Callable): same, for evaluation
        loader (Iterable): the training data loader
        visdom_env (str or None): the visdom environment, or None to
            disable Visdom
        checkpoint (str or None): checkpointing path, or None to disable
            checkpointing
        test_every (int): testing frequency, in number of iterations
        log_every (int): logging frequency, in number of iterations
        g_every (int): number of discriminator steps per generator step

    Returns:
        the discriminator Recipe driving the whole training
    """

    def D_wrap(batch):
        tu.freeze(G)
        tu.unfreeze(D)

        return D_fun(batch)

    def G_wrap(batch):
        tu.freeze(D)
        tu.unfreeze(G)

        return G_fun(batch)

    def test_wrap(batch):
        tu.freeze(G)
        tu.freeze(D)
        D.eval()
        G.eval()

        with torch.no_grad():
            out = test_fun(batch)

        D.train()
        G.train()
        return out

    class NoLim:
        """Yields a single batch per iteration, restarting `loader` forever."""

        def __init__(self):
            self.i = iter(loader)
            self.did_send = False

        def __iter__(self):
            return self

        def __next__(self):
            if self.did_send:
                self.did_send = False
                raise StopIteration
            self.did_send = True
            try:
                return next(self.i)
            except StopIteration:
                self.i = iter(loader)
                return next(self.i)

    D_loop = Recipe(D_wrap, loader)
    D_loop.register('G', G)
    D_loop.register('D', D)

    G_loop = Recipe(G_wrap, NoLim())
    D_loop.G_loop = G_loop
    D_loop.register('G_loop', G_loop)

    test_loop = Recipe(test_wrap, NoLim())
    D_loop.test_loop = test_loop
    D_loop.register('test_loop', test_loop)

    def G_test(state):
        G_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    def prepare_test(state):
        test_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    D_loop.callbacks.add_prologues([tcb.Counter()])

    D_loop.callbacks.add_epilogues([
        tcb.Log('imgs', 'G_imgs'),
        tcb.CallRecipe(G_loop, g_every, init_fun=G_test, prefix='G'),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
        tcb.CallRecipe(test_loop,
                       test_every,
                       init_fun=prepare_test,
                       prefix='Test'),
    ])

    G_loop.callbacks.add_epilogues([
        tcb.WindowedMetricAvg('G_loss'),
        tcb.VisdomLogger(visdom_env=visdom_env,
                         log_every=log_every,
                         post_epoch_ends=False)
    ])

    if checkpoint is not None:
        test_loop.callbacks.add_epilogues([
            tcb.Checkpoint(checkpoint + '/ckpt_{iters}.pth', D_loop),
            tcb.VisdomLogger(visdom_env=visdom_env),
        ])

    return D_loop
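# --- Usage sketch (illustrative; not part of the original source) ---
# Wiring networks and closures into the recipe above. Only GANRecipe's
# signature and the .to()/.run() pattern come from the code; G_fun / D_fun
# are user code (e.g. as sketched after the earlier GANRecipe version,
# except that here the generator closure receives a real data batch), and
# the noise dimension and sample count below are placeholders.
def _example_gan_training(G, D, G_fun, D_fun, loader, device='cpu'):
    fixed_z = torch.randn(16, 128)

    def test_fun(_):
        # Runs with both networks frozen and in eval mode (see test_wrap).
        z = fixed_z.to(next(G.parameters()).device)
        return {'imgs': G(z)}

    recipe = GANRecipe(G, D, G_fun, D_fun, test_fun, loader,
                       visdom_env=None,
                       checkpoint=None,
                       test_every=1000,
                       log_every=10,
                       g_every=1)
    recipe.to(device)
    recipe.run(1)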