import copy
from typing import Any, Iterable, Optional

import torch
import torch.nn as nn
from torch.optim import AdamW

# NOTE: the torchelie import paths below are assumed from its public layout
# and may differ across versions.
import torchelie.utils as tu
import torchelie.callbacks as tcb
from torchelie.recipes import Recipe


def train_net(Gen, Discr):
    # Build the generator, a Polyak-averaged copy of it for evaluation, and
    # the discriminator.
    G = Gen(in_noise=128, out_ch=3)
    G_polyak = copy.deepcopy(G).eval()
    D = Discr(in_ch=3, out_ch=1)

    def G_fun(batch):
        # Generator step: make D score a batch of generated images as real.
        z = torch.randn(BS, 128, device=device)
        fake = G(z)
        preds = D(fake * 2 - 1).squeeze()
        loss = gan_loss.generated(preds)
        loss.backward()
        return {'loss': loss.item(), 'imgs': fake.detach()}

    def G_polyak_fun(batch):
        # Evaluation step: sample from the Polyak-averaged generator.
        print('POLYAK')
        z = torch.randn(BS, 128, device=device)
        fake = G_polyak(z)
        return {'imgs': fake.detach()}

    def D_fun(batch):
        # Discriminator step: loss on generated samples, then on real ones.
        z = torch.randn(BS, 128, device=device)
        fake = G(z)
        fake_loss = gan_loss.fake(D(fake * 2 - 1))
        fake_loss.backward()

        x = batch[0]
        x = x.expand(-1, 3, -1, -1)
        real_loss = gan_loss.real(D(x * 2 - 1))
        real_loss.backward()

        loss = real_loss.item() + fake_loss.item()
        return {
            'loss': loss,
            'real_loss': real_loss.item(),
            'fake_loss': fake_loss.item()
        }

    polyak_test = Recipe(G_polyak_fun, range(1))
    polyak_test.callbacks.add_callbacks([tcb.Log('imgs', 'imgs')])

    loop = GANRecipe(G, D, G_fun, D_fun, G_polyak_fun, dl,
                     log_every=100).to(device)
    # The generator optimizer steps inside the G sub-loop; everything else is
    # attached to the main (discriminator) loop.
    loop.G_loop.callbacks.add_callbacks([
        tcb.Optimizer(AdamW(G.parameters(), lr=1e-4, betas=(0., 0.99))),
    ])
    loop.register('G_polyak', G_polyak)
    loop.callbacks.add_callbacks([
        tcb.Polyak(G, G_polyak),
        tcb.CallRecipe(polyak_test, prefix='polyak'),
        tcb.Log('polyak_metrics.imgs', 'polyak_imgs'),
        tcb.Log('batch.0', 'x'),
        tcb.WindowedMetricAvg('real_loss'),
        tcb.WindowedMetricAvg('fake_loss'),
        tcb.Optimizer(AdamW(D.parameters(), lr=4e-4, betas=(0., 0.99))),
    ])
    loop.to(device).run(1)
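# Hypothetical setup sketch (not part of the original source). train_net reads
# the module-level names BS, device, dl and gan_loss; the lines below show one
# plausible way to provide them. The dataset, the loss module path and the
# model classes are placeholders chosen for illustration only.
#
#   BS = 32
#   device = 'cuda' if torch.cuda.is_available() else 'cpu'
#   dl = torch.utils.data.DataLoader(some_image_dataset, batch_size=BS,
#                                    shuffle=True, drop_last=True)
#   gan_loss = torchelie.loss.gan.hinge   # assumed path; must expose
#                                         # real(), fake() and generated()
#   train_net(SomeGenerator, SomeDiscriminator)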
def TrainAndTest(model, train_fun, test_fun, train_loader, test_loader,
                 test_every=100, visdom_env='main', log_every=10,
                 checkpoint='model'):
    """
    Perform training and testing on datasets. The model is automatically
    checkpointed, and VisdomLogger and StdoutLogger callbacks are already
    provided. Gradients are disabled and the model is set to evaluation mode
    during the evaluation procedure.

    Args:
        model (nn.Module): a model
        train_fun (Callable): a function that takes a batch as a single
            argument, performs a training step and returns a dict of values
            to populate the recipe's state.
        test_fun (Callable): a function that takes a batch as a single
            argument, evaluates the model on it and returns a dict to
            populate the state.
        train_loader (DataLoader): training set dataloader
        test_loader (DataLoader): testing set dataloader
        test_every (int): testing frequency, in number of iterations
            (default: 100)
        visdom_env (str): name of the visdom environment to use, or None to
            disable Visdom (default: 'main')
        log_every (int): logging frequency, in number of iterations
            (default: 10)
        checkpoint (str): checkpointing path, or None to disable
            checkpointing (default: 'model')
    """

    def eval_call(batch):
        # Run the evaluation function with gradients disabled and the model
        # in eval mode, then restore training mode.
        model.eval()
        with torch.no_grad():
            out = test_fun(batch)
        model.train()
        return out

    train_loop = Recipe(train_fun, train_loader)
    train_loop.register('model', model)

    test_loop = Recipe(eval_call, test_loader)
    train_loop.test_loop = test_loop
    train_loop.register('test_loop', test_loop)

    def prepare_test(state):
        # Propagate the training loop's progress counters to the test loop
        # so that its logs and checkpoints are indexed consistently.
        test_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    train_loop.callbacks.add_prologues([tcb.Counter()])

    train_loop.callbacks.add_epilogues([
        tcb.CallRecipe(test_loop, test_every, init_fun=prepare_test),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
    ])

    test_loop.callbacks.add_epilogues([
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=-1, prefix='test_'),
        tcb.StdoutLogger(log_every=-1, prefix='Test'),
    ])

    if checkpoint is not None:
        test_loop.callbacks.add_epilogues([
            tcb.Checkpoint(
                checkpoint + '/ckpt_{iters}.pth',
                train_loop,
                key_best=lambda state: -state['test_loop']['callbacks'][
                    'state']['metrics']['loss'])
        ])

    return train_loop
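# Minimal usage sketch for TrainAndTest (not part of the original source). The
# model, loaders and metric names are illustrative; only the train_fun /
# test_fun contract (batch in, dict of state values out) comes from the
# docstring above.

def _train_and_test_example(model, train_loader, test_loader):
    import torch.nn.functional as F

    def train_step(batch):
        x, y = batch
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        return {'loss': loss.item()}

    def test_step(batch):
        x, y = batch
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(1) == y).float().mean()
        return {'loss': loss.item(), 'acc': acc.item()}

    loop = TrainAndTest(model, train_step, test_step, train_loader,
                        test_loader, visdom_env=None, checkpoint=None)
    # The returned recipe still needs an optimizer callback and a run() call,
    # e.g.:
    #   loop.callbacks.add_callbacks(
    #       [tcb.Optimizer(torch.optim.Adam(model.parameters(), lr=1e-3))])
    #   loop.run(1)
    return loop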
def GANRecipe(G, D, G_fun, D_fun, loader, *, visdom_env='main',
              checkpoint='model', log_every=10):

    def D_wrap(batch):
        # Update D against a frozen G.
        tu.freeze(G)
        G.eval()
        tu.unfreeze(D)
        D.train()
        return D_fun(batch)

    def G_wrap(batch):
        # Update G against a frozen D.
        tu.freeze(D)
        D.eval()
        tu.unfreeze(G)
        G.train()
        return G_fun(batch)

    D_loop = Recipe(D_wrap, loader)
    D_loop.register('G', G)
    D_loop.register('D', D)

    G_loop = Recipe(G_wrap, range(1))
    D_loop.G_loop = G_loop
    D_loop.register('G_loop', G_loop)

    def prepare_test(state):
        # Propagate the main loop's progress counters to the G sub-loop.
        G_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    D_loop.callbacks.add_prologues([tcb.Counter()])

    D_loop.callbacks.add_epilogues([
        tcb.CallRecipe(G_loop, 1, init_fun=prepare_test, prefix='G'),
        tcb.WindowedMetricAvg('loss'),
        tcb.Log('G_metrics.loss', 'G_loss'),
        tcb.Log('G_metrics.imgs', 'G_imgs'),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
    ])

    if checkpoint is not None:
        # Checkpointing is gated on the `checkpoint` argument; None disables it.
        D_loop.callbacks.add_epilogues(
            [tcb.Checkpoint(checkpoint + '/ckpt_{iters}.pth', D_loop)])

    G_loop.callbacks.add_epilogues([
        tcb.Log('loss', 'loss'),
        tcb.Log('imgs', 'imgs'),
        tcb.WindowedMetricAvg('loss')
    ])

    return D_loop
def GANRecipe(G: nn.Module,
              D: nn.Module,
              G_fun,
              D_fun,
              test_fun,
              loader: Iterable[Any],
              *,
              visdom_env: Optional[str] = 'main',
              checkpoint: Optional[str] = 'model',
              test_every: int = 1000,
              log_every: int = 10,
              g_every: int = 1) -> Recipe:

    def D_wrap(batch):
        # Update D with G frozen.
        tu.freeze(G)
        tu.unfreeze(D)
        return D_fun(batch)

    def G_wrap(batch):
        # Update G with D frozen.
        tu.freeze(D)
        tu.unfreeze(G)
        return G_fun(batch)

    def test_wrap(batch):
        # Evaluate with both networks frozen, in eval mode and without grads.
        tu.freeze(G)
        tu.freeze(D)
        D.eval()
        G.eval()

        with torch.no_grad():
            out = test_fun(batch)

        D.train()
        G.train()
        return out

    class NoLim:
        """Iterable that yields exactly one batch per iteration, restarting
        the underlying loader when it is exhausted."""

        def __init__(self):
            self.i = iter(loader)
            self.did_send = False

        def __iter__(self):
            return self

        def __next__(self):
            if self.did_send:
                self.did_send = False
                raise StopIteration
            self.did_send = True
            try:
                return next(self.i)
            except StopIteration:
                self.i = iter(loader)
                return next(self.i)

    D_loop = Recipe(D_wrap, loader)
    D_loop.register('G', G)
    D_loop.register('D', D)

    G_loop = Recipe(G_wrap, NoLim())
    D_loop.G_loop = G_loop
    D_loop.register('G_loop', G_loop)

    test_loop = Recipe(test_wrap, NoLim())
    D_loop.test_loop = test_loop
    D_loop.register('test_loop', test_loop)

    def G_test(state):
        # Propagate the main loop's progress counters to the G sub-loop.
        G_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    def prepare_test(state):
        # Propagate the main loop's progress counters to the test sub-loop.
        test_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    D_loop.callbacks.add_prologues([tcb.Counter()])

    D_loop.callbacks.add_epilogues([
        tcb.Log('imgs', 'G_imgs'),
        tcb.CallRecipe(G_loop, g_every, init_fun=G_test, prefix='G'),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
        tcb.CallRecipe(test_loop, test_every, init_fun=prepare_test,
                       prefix='Test'),
    ])

    G_loop.callbacks.add_epilogues([
        tcb.WindowedMetricAvg('G_loss'),
        tcb.VisdomLogger(visdom_env=visdom_env,
                         log_every=log_every,
                         post_epoch_ends=False)
    ])

    if checkpoint is not None:
        test_loop.callbacks.add_epilogues([
            tcb.Checkpoint(checkpoint + '/ckpt_{iters}.pth', D_loop),
            tcb.VisdomLogger(visdom_env=visdom_env),
        ])

    return D_loop
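# Wiring sketch (not part of the original source): what a caller still has to
# do with the Recipe returned by GANRecipe. It mirrors train_net above; the
# learning rates, betas and single epoch are placeholder values.

def _gan_recipe_example(G, D, G_fun, D_fun, test_fun, loader, device='cpu'):
    loop = GANRecipe(G, D, G_fun, D_fun, test_fun, loader,
                     visdom_env=None, checkpoint=None).to(device)
    # The generator optimizer must be attached to the G sub-loop...
    loop.G_loop.callbacks.add_callbacks([
        tcb.Optimizer(AdamW(G.parameters(), lr=1e-4, betas=(0., 0.99))),
    ])
    # ...and the discriminator optimizer to the main (D) loop.
    loop.callbacks.add_callbacks([
        tcb.Optimizer(AdamW(D.parameters(), lr=4e-4, betas=(0., 0.99))),
    ])
    return loop.run(1)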