Exemple #1
0
def train_net(Gen, Discr):
    """Build a GAN training recipe for the given architectures and run it.

    A generator ``Gen(in_noise=128, out_ch=3)``, a deep-copied eval-mode
    twin of it (``G_polyak``) and a discriminator
    ``Discr(in_ch=3, out_ch=1)`` are created, wired into a ``GANRecipe``
    together with their step functions, optimizers and logging callbacks,
    and the resulting recipe is started with ``run(1)``.

    ``BS``, ``device``, ``gan_loss`` and ``dl`` are read from the enclosing
    scope, which is not visible from this function.

    Args:
        Gen: callable building the generator; invoked with the ``in_noise``
            and ``out_ch`` keyword arguments.
        Discr: callable building the discriminator; invoked with the
            ``in_ch`` and ``out_ch`` keyword arguments.
    """
    noise_dim = 128
    G = Gen(in_noise=noise_dim, out_ch=3)
    G_polyak = copy.deepcopy(G).eval()
    D = Discr(in_ch=3, out_ch=1)

    def sample_noise():
        # One batch of latent vectors, drawn fresh for every step.
        return torch.randn(BS, noise_dim, device=device)

    def rescale(images):
        # presumably maps images from [0, 1] to [-1, 1] -- TODO confirm
        return images * 2 - 1

    def G_fun(batch):
        # Generator step: ``batch`` is ignored, only noise is used.
        generated = G(sample_noise())
        scores = D(rescale(generated)).squeeze()
        g_loss = gan_loss.generated(scores)
        g_loss.backward()
        return {'loss': g_loss.item(), 'imgs': generated.detach()}

    def G_polyak_fun(batch):
        # Sampling step using the copied generator; no loss is computed.
        print('POLYAK')
        return {'imgs': G_polyak(sample_noise()).detach()}

    def D_fun(batch):
        # Discriminator step: backward on generated samples first, then on
        # the data batch.
        fake_loss = gan_loss.fake(D(rescale(G(sample_noise()))))
        fake_loss.backward()

        # Dimension 1 is expanded to size 3 (presumably single-channel
        # images broadcast to three channels -- TODO confirm).
        real = batch[0].expand(-1, 3, -1, -1)
        real_loss = gan_loss.real(D(rescale(real)))
        real_loss.backward()

        real_value = real_loss.item()
        fake_value = fake_loss.item()
        return {
            'loss': real_value + fake_value,
            'real_loss': real_value,
            'fake_loss': fake_value,
        }

    polyak_test = Recipe(G_polyak_fun, range(1))
    polyak_test.callbacks.add_callbacks([tcb.Log('imgs', 'imgs')])

    loop = GANRecipe(G, D, G_fun, D_fun, G_polyak_fun, dl, log_every=100)
    loop = loop.to(device)

    g_optimizer = tcb.Optimizer(
        AdamW(G.parameters(), lr=1e-4, betas=(0., 0.99)))
    loop.G_loop.callbacks.add_callbacks([g_optimizer])

    loop.register('G_polyak', G_polyak)

    # Registration order is kept exactly as is: callbacks may depend on it.
    d_callbacks = [
        tcb.Polyak(G, G_polyak),
        tcb.CallRecipe(polyak_test, prefix='polyak'),
        tcb.Log('polyak_metrics.imgs', 'polyak_imgs'),
        tcb.Log('batch.0', 'x'),
        tcb.WindowedMetricAvg('real_loss'),
        tcb.WindowedMetricAvg('fake_loss'),
        tcb.Optimizer(AdamW(D.parameters(), lr=4e-4, betas=(0., 0.99))),
    ]
    loop.callbacks.add_callbacks(d_callbacks)

    # ``to(device)`` is applied again after ``G_polyak`` was registered.
    loop.to(device).run(1)
Exemple #2
0
def TrainAndTest(model,
                 train_fun,
                 test_fun,
                 train_loader,
                 test_loader,
                 test_every=100,
                 visdom_env='main',
                 log_every=10,
                 checkpoint='model'):
    """
    Perform training and testing on datasets. The model is
    automatically checkpointed, VisdomLogger and StdoutLogger callbacks are
    also already provided. The gradients are disabled and the model is
    automatically set to evaluation mode for the evaluation procedure; it is
    put back in training mode afterwards, even if the evaluation raises.

    Args:
        model (nn.Module): a model
        train_fun (Callable): a function that takes a batch as a single
            argument, performs a training step and return a dict of values to
            populate the recipe's state.
        test_fun (Callable): a function taking a batch as a single argument
            then performs something to evaluate your model and returns a dict
            to populate the state. The checkpoint callback ranks checkpoints
            on the ``'loss'`` entry of the test loop metrics, so this
            function presumably has to report a ``'loss'`` when
            checkpointing is enabled (confirm against ``tcb.Checkpoint``).
        train_loader (DataLoader): Training set dataloader
        test_loader (DataLoader): Testing set dataloader
        test_every (int): testing frequency, in number of iterations (default:
            100)
        visdom_env (str): name of the visdom environment to use, or None for
            not using Visdom (default: 'main')
        log_every (int): logging frequency, in number of iterations (default:
            10)
        checkpoint (str): checkpointing path or None for no checkpointing
            (default: 'model')

    Returns:
        Recipe: the training recipe. The testing recipe is reachable as its
        ``test_loop`` attribute and is also registered under ``'test_loop'``.
    """
    def eval_call(batch):
        model.eval()
        try:
            with torch.no_grad():
                return test_fun(batch)
        finally:
            # Restore training mode even when test_fun raises, so a failed
            # evaluation cannot leave the model stuck in eval mode.
            model.train()

    train_loop = Recipe(train_fun, train_loader)
    train_loop.register('model', model)

    test_loop = Recipe(eval_call, test_loader)
    train_loop.test_loop = test_loop
    train_loop.register('test_loop', test_loop)

    def prepare_test(state):
        # Copy the training counters into the test loop state before it runs.
        test_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    train_loop.callbacks.add_prologues([tcb.Counter()])
    train_loop.callbacks.add_epilogues([
        tcb.CallRecipe(test_loop, test_every, init_fun=prepare_test),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
    ])

    test_loop.callbacks.add_epilogues([
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=-1, prefix='test_'),
        tcb.StdoutLogger(log_every=-1, prefix='Test'),
    ])

    if checkpoint is not None:
        # The checkpoint callback lives on the test loop but saves the whole
        # training recipe, ranked on the negated test loss.
        test_loop.callbacks.add_epilogues([
            tcb.Checkpoint(checkpoint + '/ckpt_{iters}.pth',
                           train_loop,
                           key_best=lambda state: -state['test_loop'][
                               'callbacks']['state']['metrics']['loss'])
        ])

    return train_loop
Exemple #3
0
def GANRecipe(G,
              D,
              G_fun,
              D_fun,
              loader,
              *,
              visdom_env='main',
              checkpoint='model',
              log_every=10):
    """
    Build a recipe that alternates discriminator and generator steps.

    The returned recipe iterates over ``loader`` and calls ``D_fun`` on each
    batch with G frozen and in eval mode, D unfrozen and in train mode. A
    nested generator recipe, which iterates over ``range(1)``, is attached
    through ``tcb.CallRecipe`` and calls ``G_fun`` with the roles swapped.

    Args:
        G: the generator; must support ``eval()`` / ``train()`` and
            ``tu.freeze`` / ``tu.unfreeze``.
        D: the discriminator; same requirements as ``G``.
        G_fun (Callable): generator step, taking a batch as single argument
            and returning a dict. The ``'loss'`` and ``'imgs'`` entries are
            logged, so it presumably has to provide them.
        D_fun (Callable): discriminator step, taking a batch as single
            argument and returning a dict. Its ``'loss'`` entry is averaged
            over a window.
        loader: iterable of batches driving the discriminator loop.
        visdom_env (str): passed to ``tcb.VisdomLogger``; also used as the
            directory of one of the checkpoint callbacks (``'model'`` is
            used instead when it is falsy).
        checkpoint (str): checkpointing path, or None for no checkpointing
            at all.
        log_every (int): passed as ``log_every`` to the Visdom and stdout
            loggers.

    Returns:
        Recipe: the discriminator recipe. The generator recipe is reachable
        as its ``G_loop`` attribute and is also registered under
        ``'G_loop'``.
    """
    def D_wrap(batch):
        tu.freeze(G)
        G.eval()
        tu.unfreeze(D)
        D.train()

        return D_fun(batch)

    def G_wrap(batch):
        tu.freeze(D)
        D.eval()
        tu.unfreeze(G)
        G.train()

        return G_fun(batch)

    D_loop = Recipe(D_wrap, loader)
    D_loop.register('G', G)
    D_loop.register('D', D)
    G_loop = Recipe(G_wrap, range(1))
    D_loop.G_loop = G_loop
    D_loop.register('G_loop', G_loop)

    def prepare_test(state):
        # Copy the discriminator loop counters into the generator loop state
        # before each generator run.
        G_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    D_loop.callbacks.add_prologues([tcb.Counter()])

    D_loop.callbacks.add_epilogues([
        tcb.CallRecipe(G_loop, 1, init_fun=prepare_test, prefix='G'),
        tcb.WindowedMetricAvg('loss'),
        tcb.Log('G_metrics.loss', 'G_loss'),
        tcb.Log('G_metrics.imgs', 'G_imgs'),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
    ])

    if checkpoint is not None:
        # Every checkpoint callback is gated here so that checkpoint=None
        # really disables checkpointing; one of them used to be registered
        # unconditionally.
        # NOTE(review): two Checkpoint callbacks save the same recipe under
        # different path templates; one of them looks like a leftover --
        # confirm which one should stay.
        D_loop.callbacks.add_epilogues([
            tcb.Checkpoint((visdom_env or 'model') + '/ckpt_{iters}.pth',
                           D_loop),
            tcb.Checkpoint(checkpoint + '/ckpt', D_loop),
        ])

    G_loop.callbacks.add_epilogues([
        tcb.Log('loss', 'loss'),
        tcb.Log('imgs', 'imgs'),
        tcb.WindowedMetricAvg('loss')
    ])

    return D_loop
Exemple #4
0
def GANRecipe(G: nn.Module,
              D: nn.Module,
              G_fun,
              D_fun,
              test_fun,
              loader: Iterable[Any],
              *,
              visdom_env: Optional[str] = 'main',
              checkpoint: Optional[str] = 'model',
              test_every: int = 1000,
              log_every: int = 10,
              g_every: int = 1) -> Recipe:
    """
    Build a recipe that alternates discriminator, generator and test steps.

    The returned recipe iterates over ``loader`` and calls ``D_fun`` on each
    batch with G frozen and D unfrozen. Two nested recipes are attached
    through ``tcb.CallRecipe``: a generator recipe calling ``G_fun`` with D
    frozen and G unfrozen, and a test recipe calling ``test_fun`` with both
    networks frozen, in eval mode and with gradients disabled. Each nested
    recipe consumes a single batch of ``loader`` per run.

    Args:
        G: the generator.
        D: the discriminator.
        G_fun (Callable): generator step, taking a batch as single argument
            and returning a dict.
        D_fun (Callable): discriminator step, taking a batch as single
            argument and returning a dict.
        test_fun (Callable): evaluation step, taking a batch as single
            argument and returning a dict.
        loader: iterable of batches. It is iterated by the discriminator
            loop and, independently, by the generator and test loops.
        visdom_env: passed as ``visdom_env`` to the ``tcb.VisdomLogger``
            callbacks.
        checkpoint: checkpointing path, or None for no checkpointing. When
            set, the checkpoint callback is attached to the test recipe.
        test_every: passed to ``tcb.CallRecipe`` for the test recipe
            (presumably its frequency in discriminator iterations --
            confirm against ``tcb.CallRecipe``).
        log_every: passed as ``log_every`` to the loggers.
        g_every: passed to ``tcb.CallRecipe`` for the generator recipe
            (presumably its frequency in discriminator iterations --
            confirm against ``tcb.CallRecipe``).

    Returns:
        The discriminator recipe. The generator and test recipes are
        reachable as its ``G_loop`` and ``test_loop`` attributes and are
        also registered under ``'G_loop'`` and ``'test_loop'``.
    """
    def D_wrap(batch):
        tu.freeze(G)
        tu.unfreeze(D)

        return D_fun(batch)

    def G_wrap(batch):
        tu.freeze(D)
        tu.unfreeze(G)

        return G_fun(batch)

    def test_wrap(batch):
        tu.freeze(G)
        tu.freeze(D)
        D.eval()
        G.eval()

        try:
            with torch.no_grad():
                return test_fun(batch)
        finally:
            # Restore training mode even when test_fun raises, so a failed
            # evaluation cannot leave the networks stuck in eval mode.
            D.train()
            G.train()

    class NoLim:
        """Iterable handing out one batch of ``loader`` per iteration pass.

        Every pass yields exactly one batch and then stops. The underlying
        iterator is kept between passes and restarted once exhausted, so
        successive passes walk through ``loader`` again and again.
        """

        def __init__(self):
            self.i = iter(loader)
            self.did_send = False

        def __iter__(self):
            return self

        def __next__(self):
            if self.did_send:
                # A batch was already handed out in this pass: end the pass
                # and re-arm for the next one.
                self.did_send = False
                raise StopIteration
            self.did_send = True
            try:
                return next(self.i)
            except StopIteration:
                # Only exhaustion restarts the loader. Genuine loader errors
                # and KeyboardInterrupt propagate instead of being swallowed
                # and answered with a silent restart.
                self.i = iter(loader)
                return next(self.i)

    D_loop = Recipe(D_wrap, loader)
    D_loop.register('G', G)
    D_loop.register('D', D)
    G_loop = Recipe(G_wrap, NoLim())
    D_loop.G_loop = G_loop
    D_loop.register('G_loop', G_loop)

    test_loop = Recipe(test_wrap, NoLim())
    D_loop.test_loop = test_loop
    D_loop.register('test_loop', test_loop)

    def G_test(state):
        # Copy the discriminator loop counters into the generator loop state.
        G_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    def prepare_test(state):
        # Copy the discriminator loop counters into the test loop state.
        test_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    D_loop.callbacks.add_prologues([tcb.Counter()])

    D_loop.callbacks.add_epilogues([
        tcb.Log('imgs', 'G_imgs'),
        tcb.CallRecipe(G_loop, g_every, init_fun=G_test, prefix='G'),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
        tcb.CallRecipe(test_loop,
                       test_every,
                       init_fun=prepare_test,
                       prefix='Test'),
    ])

    G_loop.callbacks.add_epilogues([
        tcb.WindowedMetricAvg('G_loss'),
        tcb.VisdomLogger(visdom_env=visdom_env,
                         log_every=log_every,
                         post_epoch_ends=False)
    ])

    if checkpoint is not None:
        test_loop.callbacks.add_epilogues([
            tcb.Checkpoint(checkpoint + '/ckpt_{iters}.pth', D_loop),
            tcb.VisdomLogger(visdom_env=visdom_env),
        ])

    return D_loop