示例#1
0
    def fit(self,
            iters,
            content_img,
            style_img,
            style_ratio,
            content_layers=None):
        """
        Run the recipe: optimize a fresh canvas against ``self.loss`` and
        return the rendered result.

        Args:
            iters (int): number of iterations to run
            content_img (PIL.Image): content image; the canvas takes its
                height and width
            style_img (PIL.Image): style image
            style_ratio (float): weight of style loss
            content_layers (list of str): layers on which to reconstruct
                content

        Returns:
            the output of ``canvas.render()``, moved to CPU
        """
        device = self.device

        # Configure the loss with both targets before building the canvas.
        self.loss.to(device)
        style_target = pil2t(style_img).to(device)
        self.loss.set_style(style_target, style_ratio)
        content_target = pil2t(content_img).to(device)
        self.loss.set_content(content_target, content_layers)

        canvas = ParameterizedImg(3,
                                  content_img.height,
                                  content_img.width,
                                  init_sd=0.00)

        self.opt = tch.optim.RAdamW(canvas.parameters(),
                                    1e-2, (0.7, 0.7),
                                    eps=0.00001,
                                    weight_decay=0)

        def forward(_batch):
            # The recipe's iterable only counts iterations; the batch value
            # itself is ignored.
            self.opt.zero_grad()
            rendered = canvas()
            total, parts = self.loss(rendered)
            total.backward()

            return dict(loss=total,
                        content_loss=parts['content_loss'],
                        style_loss=parts['style_loss'],
                        img=rendered)

        loop = Recipe(forward, range(iters))
        loop.register('canvas', canvas)
        loop.register('model', self)

        callbacks = [
            tcb.Counter(),
            *[
                tcb.WindowedMetricAvg(metric)
                for metric in ('loss', 'content_loss', 'style_loss')
            ],
            tcb.Log('img', 'img'),
            tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
            tcb.StdoutLogger(log_every=10),
            tcb.Optimizer(self.opt, log_lr=True),
            tcb.LRSched(
                torch.optim.lr_scheduler.ReduceLROnPlateau(self.opt,
                                                           threshold=0.001,
                                                           cooldown=500),
                step_each_batch=True),
        ]
        loop.callbacks.add_callbacks(callbacks)

        loop.to(device)
        loop.run(1)
        return canvas.render().cpu()
示例#2
0
    def fit(self, n_iters, neuron):
        """
        Run the recipe

        Args:
            n_iters (int): number of iterations to run
            neuron (int): the feature map to maximize

        Returns:
            the optimized image
        """
        # The canvas is 10 px larger than the model input on each side so
        # that a randomly offset crop can be taken at every step.
        side = self.input_size + 10
        canvas = ParameterizedImg(3, side, side)

        def forward(_batch):
            full = canvas()

            # Drop a random number of leading rows/columns, then resize the
            # remainder back to the model's input resolution.
            offset = random.randint(0, full.shape[2] // 10)
            cropped = full[:, :, offset:, offset:]
            target_size = (self.input_size, self.input_size)
            resized = torch.nn.functional.interpolate(cropped,
                                                      size=target_size,
                                                      mode='bilinear')

            _, acts = self.model(self.norm(resized), detach=False)
            # Negated so that minimizing the loss maximizes the activation.
            loss = -acts[self.layer][0][neuron].sum()
            loss.backward()

            return {'loss': loss, 'img': full}

        loop = Recipe(forward, range(n_iters))
        loop.register('canvas', canvas)
        loop.register('model', self)

        callbacks = [
            tcb.Counter(),
            tcb.Log('loss', 'loss'),
            tcb.Log('img', 'img'),
            tcb.Optimizer(DeepDreamOptim(canvas.parameters(), lr=self.lr)),
            tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
            tcb.StdoutLogger(log_every=10),
        ]
        loop.callbacks.add_callbacks(callbacks)

        loop.to(self.device)
        loop.run(1)
        return canvas.render().cpu()
示例#3
0
    def fit(self, ref, iters, lr=3e-4, device='cpu', visdom_env='deepdream'):
        """
        Run the recipe: optimize a canvas initialized from ``ref`` against
        ``self.loss`` and return the rendered result.

        Args:
            ref: reference image, converted with ``TF.ToTensor``; it both
                initializes and sizes the canvas
            iters (int): number of iterations to run
            lr (float, optional): the learning rate
            device: device the recipe is moved to before running
            visdom_env (str or None): the name of the visdom env to use, or None
                to disable Visdom

        Returns:
            the output of ``canvas.render()``, moved to CPU
        """
        ref_tensor = TF.ToTensor()(ref).unsqueeze(0)
        height, width = ref_tensor.shape[2:4]
        canvas = ParameterizedImg(1,
                                  3,
                                  height,
                                  width,
                                  init_img=ref_tensor,
                                  space='spectral',
                                  colors='uncorr')

        def forward(_batch):
            img = canvas()
            # Drop a random number (0..10) of leading rows/columns before
            # computing the loss.
            offset = random.randint(0, 10)
            cropped = img[:, :, offset:, offset:]
            loss = self.loss(self.norm(cropped))
            loss.backward()
            return {'loss': loss, 'img': img}

        loop = Recipe(forward, range(iters))
        loop.register('model', self)
        loop.register('canvas', canvas)

        callbacks = [
            tcb.Counter(),
            tcb.Log('loss', 'loss'),
            tcb.Log('img', 'img'),
            tcb.Optimizer(DeepDreamOptim(canvas.parameters(), lr=lr)),
            tcb.VisdomLogger(visdom_env=visdom_env, log_every=10),
            tcb.StdoutLogger(log_every=10),
        ]
        loop.callbacks.add_callbacks(callbacks)

        loop.to(device)
        loop.run(1)
        return canvas.render().cpu()
示例#4
0
def TrainAndTest(model,
                 train_fun,
                 test_fun,
                 train_loader,
                 test_loader,
                 test_every=100,
                 visdom_env='main',
                 log_every=10,
                 checkpoint='model'):
    """
    Perform training and testing on datasets. The model is
    automatically checkpointed, VisdomLogger and StdoutLogger callbacks are
    also already provided. The gradients are disabled and the model is
    automatically set to evaluation mode for the evaluation procedure.

    Args:
        model (nn.Module): a model
        train_fun (Callable): a function that takes a batch as a single
            argument, performs a training step and return a dict of values to
            populate the recipe's state.
        test_fun (Callable): a function taking a batch as a single argument
            then performs something to evaluate your model and returns a dict
            to populate the state.
        train_loader (DataLoader): Training set dataloader
        test_loader (DataLoader): Testing set dataloader
        test_every (int): testing frequency, in number of iterations (default:
            100)
        visdom_env (str): name of the visdom environment to use, or None for
            not using Visdom (default: 'main')
        log_every (int): logging frequency, in number of iterations (default:
            10)
        checkpoint (str): checkpointing path or None for no checkpointing
            (default: 'model')

    Returns:
        the training Recipe; the testing recipe is registered on it and also
        reachable through its ``test_loop`` attribute
    """
    def eval_call(batch):
        model.eval()
        try:
            with torch.no_grad():
                return test_fun(batch)
        finally:
            # Hand the model back in training mode even when test_fun raises,
            # so a failed evaluation cannot leave it stuck in eval mode.
            model.train()

    train_loop = Recipe(train_fun, train_loader)
    train_loop.register('model', model)

    test_loop = Recipe(eval_call, test_loader)
    train_loop.test_loop = test_loop
    train_loop.register('test_loop', test_loop)

    def prepare_test(state):
        # Copy the training counters into the test loop's state before each
        # evaluation run.
        test_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    train_loop.callbacks.add_prologues([tcb.Counter()])
    train_loop.callbacks.add_epilogues([
        tcb.CallRecipe(test_loop, test_every, init_fun=prepare_test),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
    ])

    test_loop.callbacks.add_epilogues([
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=-1, prefix='test_'),
        tcb.StdoutLogger(log_every=-1, prefix='Test'),
    ])

    if checkpoint is not None:
        test_loop.callbacks.add_epilogues([
            tcb.Checkpoint(checkpoint + '/ckpt_{iters}.pth',
                           train_loop,
                           # The key is the negated test loss read from the
                           # training recipe's state.
                           key_best=lambda state: -state['test_loop'][
                               'callbacks']['state']['metrics']['loss'])
        ])

    return train_loop
示例#5
0
文件: gan.py 项目: AriMKatz/Torchelie
def GANRecipe(G,
              D,
              G_fun,
              D_fun,
              loader,
              *,
              visdom_env='main',
              checkpoint='model',
              log_every=10):
    """
    Build a GAN training recipe: a discriminator loop over ``loader`` that
    calls a one-step generator loop after every discriminator iteration.

    Args:
        G: the generator; frozen and set to eval mode during discriminator
            steps, unfrozen and set to train mode during generator steps
        D: the discriminator; treated symmetrically to ``G``
        G_fun (Callable): called with a batch to perform a generator step;
            its return value is returned to the generator recipe
        D_fun (Callable): called with a batch to perform a discriminator
            step; its return value is returned to the discriminator recipe
        loader: iterable of batches driving the discriminator loop
        visdom_env (str or None): passed to VisdomLogger; also names the
            directory of the per-iteration checkpoints (``'model'`` is used
            when it is None)
        checkpoint (str or None): checkpointing path, or None for no
            checkpointing
        log_every (int): passed to the Visdom and stdout loggers

    Returns:
        the discriminator Recipe; the generator recipe is registered on it
        and also reachable through its ``G_loop`` attribute
    """
    def D_wrap(batch):
        tu.freeze(G)
        G.eval()
        tu.unfreeze(D)
        D.train()

        return D_fun(batch)

    def G_wrap(batch):
        tu.freeze(D)
        D.eval()
        tu.unfreeze(G)
        G.train()

        return G_fun(batch)

    D_loop = Recipe(D_wrap, loader)
    D_loop.register('G', G)
    D_loop.register('D', D)
    # range(1): the generator loop performs a single step per invocation.
    G_loop = Recipe(G_wrap, range(1))
    D_loop.G_loop = G_loop
    D_loop.register('G_loop', G_loop)

    def prepare_test(state):
        # Copy the discriminator loop's counters into the generator loop's
        # state before each generator run.
        G_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    D_loop.callbacks.add_prologues([tcb.Counter()])

    D_loop.callbacks.add_epilogues([
        tcb.CallRecipe(G_loop, 1, init_fun=prepare_test, prefix='G'),
        tcb.WindowedMetricAvg('loss'),
        tcb.Log('G_metrics.loss', 'G_loss'),
        tcb.Log('G_metrics.imgs', 'G_imgs'),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
    ])

    if checkpoint is not None:
        # Both checkpoint callbacks are gated here so that checkpoint=None
        # really disables checkpointing. Previously the first one was
        # registered unconditionally. For any non-None value the same two
        # callbacks are registered in the same order as before.
        D_loop.callbacks.add_epilogues([
            tcb.Checkpoint((visdom_env or 'model') + '/ckpt_{iters}.pth',
                           D_loop),
            tcb.Checkpoint(checkpoint + '/ckpt', D_loop),
        ])

    G_loop.callbacks.add_epilogues([
        tcb.Log('loss', 'loss'),
        tcb.Log('imgs', 'imgs'),
        tcb.WindowedMetricAvg('loss')
    ])

    return D_loop
示例#6
0
    def fit(self,
            iters,
            content_img,
            style_img,
            style_ratio,
            content_layers=None):
        """
        Run the recipe: optimize a canvas against two losses, one at full
        resolution (``self.loss``) and one at half resolution
        (``self.loss2``), and return the rendered result.

        Args:
            iters (int): number of iterations to run
            content_img (PIL.Image): content image; it sizes and initializes
                the canvas
            style_img (PIL.Image): style image
            style_ratio (float): weight of style loss
            content_layers (list of str): layers on which to reconstruct
                content

        Returns:
            the output of ``canvas.render()``, moved to CPU
        """
        def halve(batch):
            # Bilinear downscale by a factor of two.
            return torch.nn.functional.interpolate(batch,
                                                   scale_factor=0.5,
                                                   mode='bilinear')

        # Full-resolution loss.
        self.loss.to(self.device)
        self.loss.set_style(pil2t(style_img).to(self.device), style_ratio)
        self.loss.set_content(pil2t(content_img).to(self.device),
                              content_layers)

        # Half-resolution loss. [None] adds a leading dimension for
        # interpolate and [0] removes it again.
        self.loss2.to(self.device)
        small_style = halve(pil2t(style_img)[None])[0].to(self.device)
        self.loss2.set_style(small_style, style_ratio)
        small_content = halve(pil2t(content_img)[None])[0].to(self.device)
        self.loss2.set_content(small_content, content_layers)

        canvas = ParameterizedImg(3,
                                  content_img.height,
                                  content_img.width,
                                  init_img=pil2t(content_img))

        self.opt = tch.optim.RAdamW(canvas.parameters(), 3e-2)

        def forward(_batch):
            img = canvas()
            full_loss, _ = self.loss(img)
            full_loss.backward()

            # The canvas is rendered a second time for the half-scale pass.
            loss, losses = self.loss2(halve(canvas()))
            loss.backward()

            # NOTE(review): the reported metrics come from the half-scale
            # pass only; the full-scale losses are not returned. Confirm
            # that this is intended.
            return dict(loss=loss,
                        content_loss=losses['content_loss'],
                        style_loss=losses['style_loss'],
                        img=img)

        loop = Recipe(forward, range(iters))
        loop.register('canvas', canvas)
        loop.register('model', self)

        callbacks = [
            tcb.Counter(),
            *[
                tcb.WindowedMetricAvg(metric)
                for metric in ('loss', 'content_loss', 'style_loss')
            ],
            tcb.Log('img', 'img'),
            tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
            tcb.StdoutLogger(log_every=10),
            tcb.Optimizer(self.opt, log_lr=True),
        ]
        loop.callbacks.add_callbacks(callbacks)

        loop.to(self.device)
        loop.run(1)
        return canvas.render().cpu()
示例#7
0
def GANRecipe(G: nn.Module,
              D: nn.Module,
              G_fun,
              D_fun,
              test_fun,
              loader: Iterable[Any],
              *,
              visdom_env: Optional[str] = 'main',
              checkpoint: Optional[str] = 'model',
              test_every: int = 1000,
              log_every: int = 10,
              g_every: int = 1) -> Recipe:
    """
    Build a GAN training recipe: a discriminator loop over ``loader`` that
    periodically calls a generator loop and a test loop, each of which
    consumes one batch per invocation.

    Args:
        G: the generator; frozen during discriminator steps and unfrozen
            during generator steps
        D: the discriminator; treated symmetrically to ``G``
        G_fun (Callable): called with a batch to perform a generator step
        D_fun (Callable): called with a batch to perform a discriminator
            step
        test_fun (Callable): called with a batch, with both networks frozen
            and in eval mode and with gradients disabled
        loader: iterable of batches; drives the discriminator loop and also
            feeds the generator and test loops
        visdom_env: passed to the VisdomLogger callbacks
        checkpoint: checkpointing path, or None for no checkpointing
        test_every: passed to ``tcb.CallRecipe`` for the test loop
            (presumably its frequency in iterations; confirm against
            CallRecipe)
        log_every: passed to the Visdom and stdout loggers
        g_every: passed to ``tcb.CallRecipe`` for the generator loop
            (presumably its frequency in iterations; confirm against
            CallRecipe)

    Returns:
        the discriminator Recipe; the generator and test recipes are
        registered on it and also reachable through its ``G_loop`` and
        ``test_loop`` attributes
    """
    def D_wrap(batch):
        tu.freeze(G)
        tu.unfreeze(D)

        return D_fun(batch)

    def G_wrap(batch):
        tu.freeze(D)
        tu.unfreeze(G)

        return G_fun(batch)

    def test_wrap(batch):
        tu.freeze(G)
        tu.freeze(D)
        D.eval()
        G.eval()

        try:
            with torch.no_grad():
                return test_fun(batch)
        finally:
            # Restore training mode even when test_fun raises, so a failed
            # evaluation cannot leave the networks stuck in eval mode.
            D.train()
            G.train()

    class NoLim:
        """
        Iterable over ``loader`` that yields exactly one batch per pass.

        Every other call to ``__next__`` raises StopIteration, so each
        ``for`` loop over an instance sees a single batch. The underlying
        iterator is restarted once ``loader`` is exhausted.
        """
        def __init__(self):
            self.i = iter(loader)
            self.did_send = False

        def __iter__(self):
            return self

        def __next__(self):
            if self.did_send:
                self.did_send = False
                raise StopIteration
            self.did_send = True
            try:
                return next(self.i)
            except StopIteration:
                # Only exhaustion restarts the loader; any other error from
                # the loader (or a KeyboardInterrupt) now propagates rather
                # than being swallowed and retried.
                self.i = iter(loader)
                return next(self.i)

    D_loop = Recipe(D_wrap, loader)
    D_loop.register('G', G)
    D_loop.register('D', D)
    G_loop = Recipe(G_wrap, NoLim())
    D_loop.G_loop = G_loop
    D_loop.register('G_loop', G_loop)

    test_loop = Recipe(test_wrap, NoLim())
    D_loop.test_loop = test_loop
    D_loop.register('test_loop', test_loop)

    def G_test(state):
        # Copy the discriminator loop's counters into the generator loop.
        G_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    def prepare_test(state):
        # Copy the discriminator loop's counters into the test loop.
        test_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    D_loop.callbacks.add_prologues([tcb.Counter()])

    D_loop.callbacks.add_epilogues([
        tcb.Log('imgs', 'G_imgs'),
        tcb.CallRecipe(G_loop, g_every, init_fun=G_test, prefix='G'),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
        tcb.CallRecipe(test_loop,
                       test_every,
                       init_fun=prepare_test,
                       prefix='Test'),
    ])

    G_loop.callbacks.add_epilogues([
        tcb.WindowedMetricAvg('G_loss'),
        tcb.VisdomLogger(visdom_env=visdom_env,
                         log_every=log_every,
                         post_epoch_ends=False)
    ])

    if checkpoint is not None:
        # NOTE(review): the test loop's VisdomLogger is only registered when
        # checkpointing is enabled. Confirm that this coupling is intended.
        test_loop.callbacks.add_epilogues([
            tcb.Checkpoint(checkpoint + '/ckpt_{iters}.pth', D_loop),
            tcb.VisdomLogger(visdom_env=visdom_env),
        ])

    return D_loop