Exemplo n.º 1
0
    def fit(self,
            iters,
            content_img,
            style_img,
            style_ratio,
            content_layers=None):
        """
        Optimize a canvas image against the configured style/content loss.

        Args:
            iters (int): number of iterations to run
            content_img (PIL.Image): content image; its height and width
                also give the size of the optimized canvas
            style_img (PIL.Image): style image
            style_ratio (float): weight of style loss
            content_layers (list of str): layers on which to reconstruct
                content

        Returns:
            the output of ``canvas.render()``, moved to CPU
        """
        device = self.device

        def as_device_tensor(pil_img):
            # PIL image -> tensor on the working device
            return pil2t(pil_img).to(device)

        self.loss.to(device)
        self.loss.set_style(as_device_tensor(style_img), style_ratio)
        self.loss.set_content(as_device_tensor(content_img), content_layers)

        canvas = ParameterizedImg(3,
                                  content_img.height,
                                  content_img.width,
                                  init_sd=0.00)

        # The optimizer is kept on the instance (self.opt) and is shared by
        # the training step and the callbacks below.
        self.opt = tch.optim.RAdamW(canvas.parameters(),
                                    1e-2, (0.7, 0.7),
                                    eps=0.00001,
                                    weight_decay=0)

        def train_step(_):
            self.opt.zero_grad()
            rendered = canvas()
            total, parts = self.loss(rendered)
            total.backward()

            return dict([
                ('loss', total),
                ('content_loss', parts['content_loss']),
                ('style_loss', parts['style_loss']),
                ('img', rendered),
            ])

        loop = Recipe(train_step, range(iters))
        loop.register('canvas', canvas)
        loop.register('model', self)

        callbacks = [
            tcb.Counter(),
            tcb.WindowedMetricAvg('loss'),
            tcb.WindowedMetricAvg('content_loss'),
            tcb.WindowedMetricAvg('style_loss'),
            tcb.Log('img', 'img'),
            tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
            tcb.StdoutLogger(log_every=10),
            tcb.Optimizer(self.opt, log_lr=True),
        ]
        # Plateau-based learning rate scheduler, stepped on every batch.
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(self.opt,
                                                             threshold=0.001,
                                                             cooldown=500)
        callbacks.append(tcb.LRSched(plateau, step_each_batch=True))
        loop.callbacks.add_callbacks(callbacks)

        loop.to(device)
        loop.run(1)
        return canvas.render().cpu()
Exemplo n.º 2
0
    def fit(self, n_iters, neuron):
        """
        Optimize a canvas so that one feature map of ``self.layer`` is
        maximized.

        Args:
            n_iters (int): number of iterations to run
            neuron (int): the feature map to maximize

        Returns:
            the optimized image (``canvas.render()`` moved to CPU)
        """
        # The canvas is 10 pixels larger than the model input on each side
        # so that a randomly offset crop can be taken at every step.
        side = self.input_size + 10
        canvas = ParameterizedImg(3, side, side)

        def step(_batch):
            full = canvas()
            # Random offset of at most a tenth of the canvas height; the
            # same offset is applied to both spatial dimensions.
            offset = random.randint(0, full.shape[2] // 10)
            cropped = full[:, :, offset:, offset:]
            target_size = (self.input_size, self.input_size)
            resized = torch.nn.functional.interpolate(cropped,
                                                      size=target_size,
                                                      mode='bilinear')
            _unused, acts = self.model(self.norm(resized), detach=False)
            activation = acts[self.layer][0][neuron]
            # Negated so that minimizing the loss maximizes the activation.
            loss = -activation.sum()
            loss.backward()

            return {'loss': loss, 'img': full}

        loop = Recipe(step, range(n_iters))
        loop.register('canvas', canvas)
        loop.register('model', self)

        optimizer = DeepDreamOptim(canvas.parameters(), lr=self.lr)
        callbacks = [
            tcb.Counter(),
            tcb.Log('loss', 'loss'),
            tcb.Log('img', 'img'),
            tcb.Optimizer(optimizer),
            tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
            tcb.StdoutLogger(log_every=10),
        ]
        loop.callbacks.add_callbacks(callbacks)

        loop.to(self.device)
        loop.run(1)
        return canvas.render().cpu()
Exemplo n.º 3
0
    def fit(self, ref, iters, lr=3e-4, device='cpu', visdom_env='deepdream'):
        """
        Optimize a canvas initialized from ``ref`` against ``self.loss``.

        Args:
            ref: the reference image, converted with ``TF.ToTensor``
                (presumably a PIL image -- confirm against callers)
            iters (int): number of iterations to run
            lr (float, optional): the learning rate
            device: the device the recipe is moved to before running
            visdom_env (str or None): the name of the visdom env to use, or None
                to disable Visdom

        Returns:
            the output of ``canvas.render()``, moved to CPU
        """
        to_tensor = TF.ToTensor()
        # Add a leading batch dimension to the converted image.
        start = to_tensor(ref).unsqueeze(0)
        height, width = start.shape[2], start.shape[3]
        canvas = ParameterizedImg(1,
                                  3,
                                  height,
                                  width,
                                  init_img=start,
                                  space='spectral',
                                  colors='uncorr')

        def step(_):
            img = canvas()
            # Random crop offset in [0, 10], applied to both spatial dims.
            shift = random.randint(0, 10)
            shifted = img[:, :, shift:, shift:]
            loss = self.loss(self.norm(shifted))
            loss.backward()
            return {'loss': loss, 'img': img}

        loop = Recipe(step, range(iters))
        loop.register('model', self)
        loop.register('canvas', canvas)

        optimizer = DeepDreamOptim(canvas.parameters(), lr=lr)
        callbacks = [
            tcb.Counter(),
            tcb.Log('loss', 'loss'),
            tcb.Log('img', 'img'),
            tcb.Optimizer(optimizer),
            tcb.VisdomLogger(visdom_env=visdom_env, log_every=10),
            tcb.StdoutLogger(log_every=10),
        ]
        loop.callbacks.add_callbacks(callbacks)

        loop.to(device)
        loop.run(1)
        return canvas.render().cpu()
Exemplo n.º 4
0
def GANRecipe(G,
              D,
              G_fun,
              D_fun,
              loader,
              *,
              visdom_env='main',
              checkpoint='model',
              log_every=10):
    """
    Build a discriminator Recipe over ``loader`` with a nested generator
    Recipe invoked from its epilogue callbacks.

    Args:
        G: the generator; frozen and in eval mode during ``D_fun``,
            unfrozen and in train mode during ``G_fun``
        D: the discriminator; treated symmetrically to ``G``
        G_fun: callable run on a batch for the generator step
        D_fun: callable run on a batch for the discriminator step
        loader: iterable of batches driving the discriminator recipe
        visdom_env (str or None): visdom env name; also used as the
            directory of the per-iteration checkpoint (falls back to
            ``'model'`` when falsy)
        checkpoint (str or None): if not None, an additional Checkpoint
            callback targeting ``checkpoint + '/ckpt'`` is added
        log_every (int): logging period passed to the loggers

    Returns:
        the discriminator Recipe; the generator Recipe is reachable as its
        ``G_loop`` attribute
    """
    counter_keys = ('epoch', 'iters', 'epoch_batch')

    def activate(trained, frozen):
        # Freeze one network (eval mode), unfreeze the other (train mode).
        tu.freeze(frozen)
        frozen.eval()
        tu.unfreeze(trained)
        trained.train()

    def discriminator_step(batch):
        activate(trained=D, frozen=G)
        return D_fun(batch)

    def generator_step(batch):
        activate(trained=G, frozen=D)
        return G_fun(batch)

    D_loop = Recipe(discriminator_step, loader)
    D_loop.register('G', G)
    D_loop.register('D', D)
    # The generator recipe runs over a single dummy item per invocation.
    G_loop = Recipe(generator_step, range(1))
    D_loop.G_loop = G_loop
    D_loop.register('G_loop', G_loop)

    def sync_counters(state):
        # Copy the discriminator loop's counters into the generator loop.
        G_loop.callbacks.update_state({key: state[key]
                                       for key in counter_keys})

    D_loop.callbacks.add_prologues([tcb.Counter()])

    d_epilogues = [
        tcb.CallRecipe(G_loop, 1, init_fun=sync_counters, prefix='G'),
        tcb.WindowedMetricAvg('loss'),
        tcb.Log('G_metrics.loss', 'G_loss'),
        tcb.Log('G_metrics.imgs', 'G_imgs'),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
        tcb.Checkpoint((visdom_env or 'model') + '/ckpt_{iters}.pth', D_loop),
    ]
    D_loop.callbacks.add_epilogues(d_epilogues)

    if checkpoint is not None:
        extra_checkpoint = tcb.Checkpoint(checkpoint + '/ckpt', D_loop)
        D_loop.callbacks.add_epilogues([extra_checkpoint])

    g_epilogues = [
        tcb.Log('loss', 'loss'),
        tcb.Log('imgs', 'imgs'),
        tcb.WindowedMetricAvg('loss'),
    ]
    G_loop.callbacks.add_epilogues(g_epilogues)

    return D_loop
Exemplo n.º 5
0
    def fit(self,
            iters,
            content_img,
            style_img,
            style_ratio,
            content_layers=None):
        """
        Optimize a canvas against two losses: ``self.loss`` at full
        resolution and ``self.loss2`` at half resolution.

        Args:
            iters (int): number of iterations to run
            content_img (PIL.Image): content image; also used to size and
                initialize the canvas
            style_img (PIL.Image): style image
            style_ratio (float): weight of style loss
            content_layers (list of str): layers on which to reconstruct
                content

        Returns:
            the output of ``canvas.render()``, moved to CPU
        """
        device = self.device

        def half_res(batch):
            # Bilinear downscale of a batched tensor by a factor of two.
            return torch.nn.functional.interpolate(batch,
                                                   scale_factor=0.5,
                                                   mode='bilinear')

        def half_res_image(pil_img):
            # Add a batch dim for interpolate, then drop it again.
            return half_res(pil2t(pil_img)[None])[0].to(device)

        # Full-resolution targets.
        self.loss.to(device)
        self.loss.set_style(pil2t(style_img).to(device), style_ratio)
        self.loss.set_content(pil2t(content_img).to(device), content_layers)

        # Half-resolution targets.
        self.loss2.to(device)
        self.loss2.set_style(half_res_image(style_img), style_ratio)
        self.loss2.set_content(half_res_image(content_img), content_layers)

        canvas = ParameterizedImg(3,
                                  content_img.height,
                                  content_img.width,
                                  init_img=pil2t(content_img))

        self.opt = tch.optim.RAdamW(canvas.parameters(), 3e-2)

        def train_step(_):
            img = canvas()
            full_loss, _full_parts = self.loss(img)
            full_loss.backward()
            # The canvas is rendered a second time for the half-resolution
            # pass; gradients of both backward calls accumulate.
            small_loss, small_parts = self.loss2(half_res(canvas()))
            small_loss.backward()

            # Reported metrics come from the half-resolution pass only; the
            # logged image is the full-resolution render.
            return dict([
                ('loss', small_loss),
                ('content_loss', small_parts['content_loss']),
                ('style_loss', small_parts['style_loss']),
                ('img', img),
            ])

        loop = Recipe(train_step, range(iters))
        loop.register('canvas', canvas)
        loop.register('model', self)

        callbacks = [
            tcb.Counter(),
            tcb.WindowedMetricAvg('loss'),
            tcb.WindowedMetricAvg('content_loss'),
            tcb.WindowedMetricAvg('style_loss'),
            tcb.Log('img', 'img'),
            tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
            tcb.StdoutLogger(log_every=10),
            tcb.Optimizer(self.opt, log_lr=True),
        ]
        loop.callbacks.add_callbacks(callbacks)

        loop.to(device)
        loop.run(1)
        return canvas.render().cpu()
Exemplo n.º 6
0
def GANRecipe(G: nn.Module,
              D: nn.Module,
              G_fun,
              D_fun,
              test_fun,
              loader: Iterable[Any],
              *,
              visdom_env: Optional[str] = 'main',
              checkpoint: Optional[str] = 'model',
              test_every: int = 1000,
              log_every: int = 10,
              g_every: int = 1) -> Recipe:
    """
    Build a discriminator Recipe over ``loader`` with nested generator and
    test Recipes invoked from its epilogue callbacks.

    Args:
        G (nn.Module): the generator
        D (nn.Module): the discriminator
        G_fun: callable run on a batch for the generator step (D frozen,
            G unfrozen)
        D_fun: callable run on a batch for the discriminator step (G
            frozen, D unfrozen)
        test_fun: callable run on a batch under ``torch.no_grad()`` with
            both networks frozen and in eval mode
        loader (Iterable): iterable of batches; it drives the discriminator
            recipe and also feeds the generator and test recipes one batch
            per invocation
        visdom_env (str or None): visdom env name passed to the loggers
        checkpoint (str or None): if not None, the test recipe gets a
            Checkpoint callback targeting
            ``checkpoint + '/ckpt_{iters}.pth'``
        test_every (int): period passed to the test ``CallRecipe``
        log_every (int): logging period passed to the loggers
        g_every (int): period passed to the generator ``CallRecipe``

    Returns:
        the discriminator Recipe; the nested recipes are reachable as its
        ``G_loop`` and ``test_loop`` attributes
    """
    def D_wrap(batch):
        tu.freeze(G)
        tu.unfreeze(D)

        return D_fun(batch)

    def G_wrap(batch):
        tu.freeze(D)
        tu.unfreeze(G)

        return G_fun(batch)

    def test_wrap(batch):
        tu.freeze(G)
        tu.freeze(D)
        D.eval()
        G.eval()

        with torch.no_grad():
            out = test_fun(batch)

        # Back to train mode; both networks stay frozen until the next
        # D_wrap / G_wrap call unfreezes the one it trains.
        D.train()
        G.train()
        return out

    class NoLim:
        """
        Iterable yielding exactly one batch of ``loader`` per pass.

        The underlying iterator is kept between passes and restarted when
        it is exhausted, so successive passes walk through ``loader``
        and wrap around at its end.
        """

        def __init__(self):
            self.i = iter(loader)
            self.did_send = False

        def __iter__(self):
            return self

        def __next__(self):
            if self.did_send:
                # A batch was already delivered in this pass: end it and
                # re-arm for the next pass.
                self.did_send = False
                raise StopIteration
            self.did_send = True
            try:
                return next(self.i)
            except StopIteration:
                # Only exhaustion triggers a restart. Any other error from
                # the loader (or a KeyboardInterrupt) propagates instead of
                # being swallowed.
                self.i = iter(loader)
                return next(self.i)

    D_loop = Recipe(D_wrap, loader)
    D_loop.register('G', G)
    D_loop.register('D', D)
    G_loop = Recipe(G_wrap, NoLim())
    D_loop.G_loop = G_loop
    D_loop.register('G_loop', G_loop)

    test_loop = Recipe(test_wrap, NoLim())
    D_loop.test_loop = test_loop
    D_loop.register('test_loop', test_loop)

    def G_test(state):
        # Copy the discriminator loop's counters into the generator loop.
        G_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    def prepare_test(state):
        # Copy the discriminator loop's counters into the test loop.
        test_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    D_loop.callbacks.add_prologues([tcb.Counter()])

    D_loop.callbacks.add_epilogues([
        tcb.Log('imgs', 'G_imgs'),
        tcb.CallRecipe(G_loop, g_every, init_fun=G_test, prefix='G'),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
        tcb.CallRecipe(test_loop,
                       test_every,
                       init_fun=prepare_test,
                       prefix='Test'),
    ])

    G_loop.callbacks.add_epilogues([
        tcb.WindowedMetricAvg('G_loss'),
        tcb.VisdomLogger(visdom_env=visdom_env,
                         log_every=log_every,
                         post_epoch_ends=False)
    ])

    if checkpoint is not None:
        test_loop.callbacks.add_epilogues([
            tcb.Checkpoint(checkpoint + '/ckpt_{iters}.pth', D_loop),
            tcb.VisdomLogger(visdom_env=visdom_env),
        ])

    return D_loop