Example #1
def train_net(Gen, Discr):
    G = Gen(in_noise=128, out_ch=3)
    G_polyak = copy.deepcopy(G).eval()
    D = Discr()
    print(G)
    print(D)

    def G_fun(batch):
        z = torch.randn(BS, 128, device=device)
        fake = G(z)
        preds = D(fake * 2 - 1).squeeze()
        loss = gan_loss.generated(preds)
        loss.backward()
        return {'loss': loss.item(), 'imgs': fake.detach()}

    def G_polyak_fun(batch):
        z = torch.randn(BS, 128, device=device)
        fake = G_polyak(z)
        return {'imgs': fake.detach()}

    def D_fun(batch):
        z = torch.randn(BS, 128, device=device)
        fake = G(z)
        fake_loss = gan_loss.fake(D(fake * 2 - 1))
        fake_loss.backward()

        x = batch[0]

        real_loss = gan_loss.real(D(x * 2 - 1))
        real_loss.backward()

        loss = real_loss.item() + fake_loss.item()
        return {
            'loss': loss,
            'real_loss': real_loss.item(),
            'fake_loss': fake_loss.item()
        }

    loop = GANRecipe(G, D, G_fun, D_fun, G_polyak_fun, dl,
                     log_every=100).to(device)
    loop.register('polyak', G_polyak)
    loop.G_loop.callbacks.add_callbacks([
        tcb.Optimizer(
            tch.optim.RAdamW(G.parameters(), lr=1e-4, betas=(0., 0.99))),
        tcb.Polyak(G, G_polyak),
    ])
    loop.register('G_polyak', G_polyak)
    loop.callbacks.add_callbacks([
        tcb.Log('batch.0', 'x'),
        tcb.WindowedMetricAvg('real_loss'),
        tcb.WindowedMetricAvg('fake_loss'),
        tcb.Optimizer(
            tch.optim.RAdamW(D.parameters(), lr=4e-4, betas=(0., 0.99))),
    ])
    loop.test_loop.callbacks.add_callbacks([
        tcb.Log('imgs', 'polyak_imgs'),
        tcb.VisdomLogger('main', prefix='test')
    ])
    loop.to(device).run(100)
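
The function above leans on module-level globals (BS, device, dl, gan_loss) that the snippet does not define. A minimal, hypothetical wiring: the toy modules, the BCE-based gan_loss stand-in, and the random dataset are invented for illustration, and a running Visdom server is assumed since the recipe logs to one by default.

# Hypothetical setup; only train_net itself comes from the snippet above.
import torch
import torch.nn as nn
import torch.nn.functional as F
from types import SimpleNamespace
from torch.utils.data import DataLoader, TensorDataset

BS = 32
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dl = DataLoader(TensorDataset(torch.rand(256, 3, 8, 8)), batch_size=BS)

# Any namespace exposing real/fake/generated losses works here.
gan_loss = SimpleNamespace(
    real=lambda p: F.binary_cross_entropy_with_logits(p, torch.ones_like(p)),
    fake=lambda p: F.binary_cross_entropy_with_logits(p, torch.zeros_like(p)),
    generated=lambda p: F.binary_cross_entropy_with_logits(p, torch.ones_like(p)))

class ToyGen(nn.Module):
    def __init__(self, in_noise, out_ch):
        super().__init__()
        self.out_ch = out_ch
        self.fc = nn.Linear(in_noise, out_ch * 8 * 8)

    def forward(self, z):
        return torch.sigmoid(self.fc(z)).view(-1, self.out_ch, 8, 8)

class ToyDisc(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3 * 8 * 8, 1)

    def forward(self, x):
        return self.fc(x.flatten(1))

train_net(ToyGen, ToyDisc)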
Example #2
def test_tensorboard():
    from torchelie.recipes import Recipe

    batch_size = 4

    class Dataset:
        def __init__(self, batch_size):
            self.batch_size = batch_size
            self.mnist = FashionMNIST('.', download=True, transform=PILToTensor())
            self.classes = self.mnist.classes
            self.num_classes = len(self.mnist.class_to_idx)
            self.target_by_classes = [
                [idx for idx in range(len(self.mnist))
                 if self.mnist.targets[idx] == i]
                for i in range(self.num_classes)]

        def __len__(self):
            return self.batch_size * self.num_classes

        def __getitem__(self, item):
            idx = self.target_by_classes[item//self.batch_size][item]
            x, y = self.mnist[idx]
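            # Replicate the single grey channel to 3 and zero the blue one,
            # presumably so colour logging is visibly distinct from grey.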
            x = torch.stack(3*[x]).squeeze()
            x[2] = 0
            return x, y

    dst = Dataset(batch_size)

    def train(b):
        x, y = b
        return {'letter_number_int':     int(y[0]),
                'letter_number_tensor':  y[0],
                'letter_text':  dst.classes[int(y[0])],
                'test_html':  '<b>test HTML</b>',
                'letter_gray_img_HW':   x[0, 0],
                'letter_gray_img_CHW':   x[0, :1],
                'letter_gray_imgs_NCHW':  x[:, :1],
                'letter_color_img_CHW':  x[0],
                'letter_color_imgs_NCHW': x}

    r = Recipe(train, DataLoader(dst, batch_size))
    r.callbacks.add_callbacks([
        tcb.Counter(),
        tcb.TensorboardLogger(log_every=1),
        tcb.Log('letter_number_int', 'letter_number_int'),
        tcb.Log('letter_number_tensor', 'letter_number_tensor'),
        tcb.Log('letter_text', 'letter_text'),
        tcb.Log('test_html', 'test_html'),
        tcb.Log('letter_gray_img_HW', 'letter_gray_img_HW'),
        tcb.Log('letter_gray_img_CHW', 'letter_gray_img_CHW'),
        tcb.Log('letter_gray_imgs_NCHW', 'letter_gray_imgs_NCHW'),
        tcb.Log('letter_color_img_CHW', 'letter_color_img_CHW'),
        tcb.Log('letter_color_imgs_NCHW', 'letter_color_imgs_NCHW'),
    ])
    r.run(1)
Example #3
    def fit(self,
            iters,
            content_img,
            style_img,
            style_ratio,
            content_layers=None):
        """
        Run the recipe

        Args:
            iters (int): number of iterations to run
            content_img (PIL.Image): content image
            style_img (PIL.Image): style image
            style_ratio (float): weight of the style loss
            content_layers (list of str): layers on which to reconstruct
                content
        """
        self.loss.to(self.device)
        self.loss.set_style(pil2t(style_img).to(self.device), style_ratio)
        self.loss.set_content(pil2t(content_img).to(self.device), content_layers)

        canvas = ParameterizedImg(3, content_img.height,
                                  content_img.width, init_sd=0.00)

        self.opt = tch.optim.RAdamW(canvas.parameters(), 1e-2, (0.7, 0.7),
                                    eps=0.00001, weight_decay=0)

        def forward(_):
            self.opt.zero_grad()
            img = canvas()
            loss, losses = self.loss(img)
            loss.backward()

            return {
                'loss': loss,
                'content_loss': losses['content_loss'],
                'style_loss': losses['style_loss'],
                'img': img
            }

        loop = Recipe(forward, range(iters))
        loop.register('canvas', canvas)
        loop.register('model', self)
        loop.callbacks.add_callbacks([
            tcb.Counter(),
            tcb.WindowedMetricAvg('loss'),
            tcb.WindowedMetricAvg('content_loss'),
            tcb.WindowedMetricAvg('style_loss'),
            tcb.Log('img', 'img'),
            tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
            tcb.StdoutLogger(log_every=10),
            tcb.Optimizer(self.opt, log_lr=True),
            tcb.LRSched(torch.optim.lr_scheduler.ReduceLROnPlateau(self.opt,
                threshold=0.001, cooldown=500),
                step_each_batch=True)
        ])
        loop.to(self.device)
        loop.run(1)
        return canvas.render().cpu()
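
The canvas-optimisation pattern above also works in isolation. A small sketch, assuming the import paths below (they follow torchelie's apparent layout but are not shown in the snippet), that fits a ParameterizedImg to a fixed target with plain MSE instead of the perceptual losses:

# Sketch only: the import paths and the MSE objective are assumptions.
import torch
import torch.nn.functional as F
import torchelie.callbacks as tcb
from torchelie.recipes import Recipe
from torchelie.data_learning import ParameterizedImg  # assumed path

target = torch.rand(1, 3, 64, 64)
canvas = ParameterizedImg(3, 64, 64)
opt = torch.optim.Adam(canvas.parameters(), lr=1e-2)

def forward(_):
    img = canvas()
    loss = F.mse_loss(img, target.expand_as(img))
    loss.backward()
    return {'loss': loss, 'img': img}

loop = Recipe(forward, range(200))
loop.callbacks.add_callbacks([
    tcb.Counter(),
    tcb.WindowedMetricAvg('loss'),
    tcb.Optimizer(opt),
])
loop.run(1)
result = canvas.render().cpu()  # final image, as in the recipes above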
Example #4
def make_loop(hourglass, body, display, num_iter, lr):
    loop = TrainAndCall(hourglass,
                        body,
                        display,
                        range(num_iter),
                        test_every=50,
                        checkpoint=None)
    opt = tch.optim.RAdamW(hourglass.parameters(), lr=lr)
    loop.callbacks.add_callbacks([
        tcb.WindowedMetricAvg('loss'),
        tcb.Optimizer(opt, clip_grad_norm=0.5, log_lr=True),
    ])
    loop.test_loop.callbacks.add_callbacks([
        tcb.Log('recon', 'img'),
        tcb.Log('orig', 'orig'),
        tcb.Log('loss', 'loss'),
    ])
    return loop
Example #5
    def fit(self, n_iters, neuron):
        """
        Run the recipe

        Args:
            n_iters (int): number of iterations to run
            neuron (int): the feature map to maximize

        Returns:
            the optimized image
        """
        canvas = ParameterizedImg(3, self.input_size + 10,
                                  self.input_size + 10)

        def forward(_):
            cim = canvas()
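            # Jitter: optimise through a random crop, resized back to the
            # model's input size, to keep the result translation-robust.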
            rnd = random.randint(0, cim.shape[2] // 10)
            im = cim[:, :, rnd:, rnd:]
            im = torch.nn.functional.interpolate(im,
                                                 size=(self.input_size,
                                                       self.input_size),
                                                 mode='bilinear')
            _, acts = self.model(self.norm(im), detach=False)
            fmap = acts[self.layer]
            loss = -fmap[0][neuron].sum()
            loss.backward()

            return {'loss': loss, 'img': cim}

        loop = Recipe(forward, range(n_iters))
        loop.register('canvas', canvas)
        loop.register('model', self)
        loop.callbacks.add_callbacks([
            tcb.Counter(),
            tcb.Log('loss', 'loss'),
            tcb.Log('img', 'img'),
            tcb.Optimizer(DeepDreamOptim(canvas.parameters(), lr=self.lr)),
            tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
            tcb.StdoutLogger(log_every=10)
        ])
        loop.to(self.device)
        loop.run(1)
        return canvas.render().cpu()
Example #6
    def fit(self, ref, iters, lr=3e-4, device='cpu', visdom_env='deepdream'):
        """
        Args:
            ref (PIL.Image): the image to start dreaming from
            iters (int): number of iterations to run
            lr (float, optional): the learning rate
            device (str, optional): the device to run on
            visdom_env (str or None): the name of the visdom env to use, or None
                to disable Visdom
        """
        ref_tensor = TF.ToTensor()(ref).unsqueeze(0)
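        # Fourier-space parameterisation with decorrelated colours tends to
        # give smoother, less noisy optimisation results.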
        canvas = ParameterizedImg(1,
                                  3,
                                  ref_tensor.shape[2],
                                  ref_tensor.shape[3],
                                  init_img=ref_tensor,
                                  space='spectral',
                                  colors='uncorr')

        def forward(_):
            img = canvas()
            rnd = random.randint(0, 10)
            loss = self.loss(self.norm(img[:, :, rnd:, rnd:]))
            loss.backward()
            return {'loss': loss, 'img': img}

        loop = Recipe(forward, range(iters))
        loop.register('model', self)
        loop.register('canvas', canvas)
        loop.callbacks.add_callbacks([
            tcb.Counter(),
            tcb.Log('loss', 'loss'),
            tcb.Log('img', 'img'),
            tcb.Optimizer(DeepDreamOptim(canvas.parameters(), lr=lr)),
            tcb.VisdomLogger(visdom_env=visdom_env, log_every=10),
            tcb.StdoutLogger(log_every=10)
        ])
        loop.to(device)
        loop.run(1)
        return canvas.render().cpu()
Example #7
def inpainting(img,
               mask,
               hourglass,
               input_dim,
               iters,
               lr,
               noise_std=1 / 30,
               device='cuda'):
    im = TFF.to_tensor(img)[None].to(device)
    mask = TFF.to_tensor(mask)[None].to(device)
    z = input_noise((im.shape[2], im.shape[3]), input_dim)
    z = z.to(device)
    print(hourglass)

    def body(batch):
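        # Deep-image-prior update: the fixed input z is re-noised each step
        # as regularisation, and the MSE only counts pixels where mask == 1
        # (the known region), normalised by the number of known pixels.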
        recon = hourglass(z + torch.randn_like(z) * noise_std)
        loss = torch.sum(
            F.mse_loss(F.interpolate(recon, size=im.shape[2:], mode='nearest'),
                       im,
                       reduction='none') * mask / mask.sum())
        loss.backward()
        return {"loss": loss}

    def display():
        recon = hourglass(z)
        recon = F.interpolate(recon, size=im.shape[2:], mode='nearest')
        loss = F.mse_loss(recon * mask, im)

        result = recon * (1 - mask) + im * mask
        return {
            "loss": loss,
            "recon": recon.clamp(0, 1),
            'orig': im,
            'result': result.clamp(0, 1)
        }

    loop = make_loop(hourglass, body, display, iters, lr)
    loop.test_loop.callbacks.add_callbacks([tcb.Log('result', 'result')])
    loop.to(device)
    loop.run(1)
    with torch.no_grad():
        hourglass.eval()
        return TFF.to_pil_image(hourglass(z)[0].cpu())
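
A hypothetical call site (the file paths and the stand-in network are invented; input_noise and make_loop are assumed to come from the same module as inpainting, and input_dim is assumed to set the channel count of the noise input):

# Hypothetical usage; a real run would use a proper hourglass/U-Net.
import torch.nn as nn
from PIL import Image

img = Image.open('photo.png').convert('RGB')
mask = Image.open('mask.png').convert('L')  # 1 = known pixel, 0 = hole
net = nn.Sequential(  # input channels must match input_dim below
    nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
    nn.Conv2d(64, 3, 3, padding=1), nn.Sigmoid())
restored = inpainting(img, mask, net, input_dim=32, iters=3000, lr=1e-3)
restored.save('inpainted.png')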
Example #8
def GANRecipe(G,
              D,
              G_fun,
              D_fun,
              loader,
              *,
              visdom_env='main',
              checkpoint='model',
              log_every=10):
    def D_wrap(batch):
        tu.freeze(G)
        G.eval()
        tu.unfreeze(D)
        D.train()

        return D_fun(batch)

    def G_wrap(batch):
        tu.freeze(D)
        D.eval()
        tu.unfreeze(G)
        G.train()

        return G_fun(batch)

    D_loop = Recipe(D_wrap, loader)
    D_loop.register('G', G)
    D_loop.register('D', D)
    G_loop = Recipe(G_wrap, range(1))
    D_loop.G_loop = G_loop
    D_loop.register('G_loop', G_loop)

    def prepare_test(state):
        G_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    D_loop.callbacks.add_prologues([tcb.Counter()])

    D_loop.callbacks.add_epilogues([
        tcb.CallRecipe(G_loop, 1, init_fun=prepare_test, prefix='G'),
        tcb.WindowedMetricAvg('loss'),
        tcb.Log('G_metrics.loss', 'G_loss'),
        tcb.Log('G_metrics.imgs', 'G_imgs'),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
        tcb.Checkpoint((visdom_env or 'model') + '/ckpt_{iters}.pth', D_loop)
    ])

    if checkpoint is not None:
        D_loop.callbacks.add_epilogues(
            [tcb.Checkpoint(checkpoint + '/ckpt', D_loop)])

    G_loop.callbacks.add_epilogues([
        tcb.Log('loss', 'loss'),
        tcb.Log('imgs', 'imgs'),
        tcb.WindowedMetricAvg('loss')
    ])

    return D_loop
Example #9
def StyleGAN2Recipe(G: nn.Module,
                    D: nn.Module,
                    dataloader,
                    noise_size: int,
                    gpu_id: int,
                    total_num_gpus: int,
                    *,
                    G_lr: float = 2e-3,
                    D_lr: float = 2e-3,
                    tag: str = 'model',
                    ada: bool = True):
    """
    StyleGAN2 Recipe distributed with DistributedDataParallel

    Args:
        G (nn.Module): a Generator.
        D (nn.Module): a Discriminator.
        dataloader: a dataloader conforming to torchvision's API.
        noise_size (int): the size of the input noise vector.
        gpu_id (int): the GPU index on which to run.
        total_num_gpus (int): the total number of GPUs used for training
        G_lr (float): RAdamW lr for G
        D_lr (float): RAdamW lr for D
        tag (str): tag for Visdom and checkpoints
        ada (bool): whether to enable Adaptive Data Augmentation

    Returns:
        the recipe and the EMA (Polyak-averaged) copy of G
    """
    # Deep copy, so the EMA weights do not alias G's parameters.
    G_polyak = copy.deepcopy(G)

    G = nn.parallel.DistributedDataParallel(G.to(gpu_id), [gpu_id], gpu_id)
    D = nn.parallel.DistributedDataParallel(D.to(gpu_id), [gpu_id], gpu_id)
    print(G)
    print(D)

    optG: torch.optim.Optimizer = RAdamW(G.parameters(),
                                         G_lr,
                                         betas=(0., 0.99),
                                         weight_decay=0)
    optD: torch.optim.Optimizer = RAdamW(D.parameters(),
                                         D_lr,
                                         betas=(0., 0.99),
                                         weight_decay=0)

    batch_size = len(next(iter(dataloader))[0])
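    # Adaptive data augmentation: D_train feeds log_loss() -pos_ratio, so a
    # target of -0.9 aims at ~90% of reals judged real; -2 is unreachable
    # and effectively disables augmentation.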
    diffTF = ADATF(-2 if not ada else -0.9,
                   50000 / (batch_size * total_num_gpus))

    ppl = PPL(4)

    def G_train(batch):
        with G.no_sync():
            pl = ppl(G, torch.randn(batch_size, noise_size, device=gpu_id))
        ##############
        #   G pass   #
        ##############
        imgs = G(torch.randn(batch_size, noise_size, device=gpu_id))
        pred = D(diffTF(imgs) * 2 - 1)
        score = gan_loss.generated(pred)
        score.backward()

        return {'G_loss': score.item(), 'ppl': pl}

    gradient_penalty = GradientPenalty(0.1)

    def D_train(batch):
        ###################
        #    Fake pass    #
        ###################
        with D.no_sync():
            # Sync the gradient on the last backward
            noise = torch.randn(batch_size, noise_size, device=gpu_id)
            with torch.no_grad():
                fake = G(noise)
            fake.requires_grad_(True)
            fake.retain_grad()
            fake_tf = diffTF(fake) * 2 - 1
            fakeness = D(fake_tf).squeeze(1)
            fake_loss = gan_loss.fake(fakeness)
            fake_loss.backward()

            correct = (fakeness < 0).float().sum()
        fake_grad = fake.grad.detach().norm(dim=1, keepdim=True)
        fake_grad /= fake_grad.max()

        tfmed = diffTF(batch[0]) * 2 - 1

        with D.no_sync():
            grad_norm = gradient_penalty(D, batch[0] * 2 - 1,
                                         fake.detach() * 2 - 1)

        ###################
        #    Real pass    #
        ###################
        real_out = D(tfmed)
        correct += (real_out > 0).float().sum()
        real_loss = gan_loss.real(real_out)
        real_loss.backward()
        pos_ratio = real_out.gt(0).float().mean().cpu().item()
        diffTF.log_loss(-pos_ratio)
        return {
            'imgs': fake.detach(),
            'i_grad': fake_grad,
            'loss': real_loss.item() + fake_loss.item(),
            'fake_loss': fake_loss.item(),
            'real_loss': real_loss.item(),
            'ADA-p': diffTF.p,
            'D-correct': correct / (2 * real_out.numel()),
            'grad_norm': grad_norm
        }

    tu.freeze(G_polyak)

    def test(batch):
        G_polyak.eval()

        def sample(N, n_iter, alpha=0.01, show_every=10):
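            # Langevin-style refinement: perturb the latents, then take
            # optimiser steps that raise D's realism score while log_prob
            # keeps the latents close to the Gaussian prior.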
            noise = torch.randn(N,
                                noise_size,
                                device=gpu_id,
                                requires_grad=True)
            opt = torch.optim.Adam([noise], lr=alpha)
            fakes = []
            for i in range(n_iter):
                noise += torch.randn_like(noise) / 10
                fake_batch = []
                opt.zero_grad()
                for j in range(0, N, batch_size):
                    with torch.enable_grad():
                        n_batch = noise[j:j + batch_size]
                        fake = G_polyak(n_batch, mixing=False)
                        fake_batch.append(fake)
                        log_prob = n_batch[:, 32:].pow(2).mul_(-0.5)
                        fakeness = -D(fake * 2 - 1).sum() - log_prob.sum()
                        fakeness.backward()
                opt.step()
                fake_batch = torch.cat(fake_batch, dim=0)

                if i % show_every == 0:
                    fakes.append(fake_batch.cpu().detach().clone())

            fakes.append(fake_batch.cpu().detach().clone())

            return torch.cat(fakes, dim=0)

        fake = sample(8, 50, alpha=0.001, show_every=10)

        noise1 = torch.randn(batch_size * 2 // 8, 1, noise_size, device=gpu_id)
        noise2 = torch.randn(batch_size * 2 // 8, 1, noise_size, device=gpu_id)
        t = torch.linspace(0, 1, 8, device=noise1.device).view(8, 1)
        noise = noise1 * t + noise2 * (1 - t)
        noise = noise.view(-1, noise_size)
        interp = torch.cat(
            [G_polyak(n, mixing=False)
             for n in torch.split(noise, batch_size)], dim=0)
        return {
            'polyak_imgs': fake,
            'polyak_interp': interp,
        }

    recipe = GANRecipe(G,
                       D,
                       G_train,
                       D_train,
                       test,
                       dataloader,
                       visdom_env=tag if gpu_id == 0 else None,
                       log_every=10,
                       test_every=1000,
                       checkpoint=tag if gpu_id == 0 else None,
                       g_every=1)
    recipe.callbacks.add_callbacks([
        tcb.Log('batch.0', 'x'),
        tcb.WindowedMetricAvg('fake_loss'),
        tcb.WindowedMetricAvg('real_loss'),
        tcb.WindowedMetricAvg('grad_norm'),
        tcb.WindowedMetricAvg('ADA-p'),
        tcb.WindowedMetricAvg('D-correct'),
        tcb.Log('i_grad', 'img_grad'),
        tch.callbacks.Optimizer(optD),
    ])
    recipe.G_loop.callbacks.add_callbacks([
        tch.callbacks.Optimizer(optG),
        tcb.Polyak(G.module, G_polyak,
                   0.5**((batch_size * total_num_gpus) / 20000)),
        tcb.WindowedMetricAvg('ppl'),
    ])
    recipe.test_loop.callbacks.add_callbacks([
        tcb.Log('polyak_imgs', 'polyak'),
        tcb.Log('polyak_interp', 'interp'),
    ])
    recipe.register('G_polyak', G_polyak)
    recipe.to(gpu_id)
    return recipe, G_polyak
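
Since the recipe wraps G and D in DistributedDataParallel, each process must have initialised torch.distributed before calling it. A hypothetical per-GPU launcher (MyGen, MyDisc and make_dataloader are placeholders, not part of the snippet):

# Hypothetical launcher; MyGen/MyDisc/make_dataloader are placeholders.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    recipe, G_ema = StyleGAN2Recipe(MyGen(), MyDisc(), make_dataloader(rank),
                                    noise_size=128, gpu_id=rank,
                                    total_num_gpus=world_size, tag='sg2')
    recipe.run(100)

if __name__ == '__main__':
    n_gpus = torch.cuda.device_count()
    mp.spawn(worker, args=(n_gpus,), nprocs=n_gpus)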
Example #10
    def fit(self,
            iters,
            content_img,
            style_img,
            style_ratio,
            content_layers=None):
        """
        Run the recipe

        Args:
            iters (int): number of iterations to run
            content_img (PIL.Image): content image
            style_img (PIL.Image): style image
            style_ratio (float): weight of the style loss
            content_layers (list of str): layers on which to reconstruct
                content
        """
        self.loss.to(self.device)
        self.loss.set_style(pil2t(style_img).to(self.device), style_ratio)
        self.loss.set_content(
            pil2t(content_img).to(self.device), content_layers)

        self.loss2.to(self.device)
        self.loss2.set_style(
            torch.nn.functional.interpolate(pil2t(style_img)[None],
                                            scale_factor=0.5,
                                            mode='bilinear')[0].to(
                                                self.device), style_ratio)
        self.loss2.set_content(
            torch.nn.functional.interpolate(pil2t(content_img)[None],
                                            scale_factor=0.5,
                                            mode='bilinear')[0].to(
                                                self.device), content_layers)

        canvas = ParameterizedImg(3,
                                  content_img.height,
                                  content_img.width,
                                  init_img=pil2t(content_img))

        self.opt = tch.optim.RAdamW(canvas.parameters(), 3e-2)

        def forward(_):
            img = canvas()
            loss, losses = self.loss(img)
            loss.backward()
            loss, losses = self.loss2(
                torch.nn.functional.interpolate(canvas(),
                                                scale_factor=0.5,
                                                mode='bilinear'))
            loss.backward()
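            # NOTE: loss/losses were rebound by self.loss2, so the metrics
            # returned below reflect the half-resolution pass.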

            return {
                'loss': loss,
                'content_loss': losses['content_loss'],
                'style_loss': losses['style_loss'],
                'img': img
            }

        loop = Recipe(forward, range(iters))
        loop.register('canvas', canvas)
        loop.register('model', self)
        loop.callbacks.add_callbacks([
            tcb.Counter(),
            tcb.WindowedMetricAvg('loss'),
            tcb.WindowedMetricAvg('content_loss'),
            tcb.WindowedMetricAvg('style_loss'),
            tcb.Log('img', 'img'),
            tcb.VisdomLogger(visdom_env=self.visdom_env, log_every=10),
            tcb.StdoutLogger(log_every=10),
            tcb.Optimizer(self.opt, log_lr=True),
        ])
        loop.to(self.device)
        loop.run(1)
        return canvas.render().cpu()
Example #11
def GANRecipe(G: nn.Module,
              D: nn.Module,
              G_fun,
              D_fun,
              test_fun,
              loader: Iterable[Any],
              *,
              visdom_env: Optional[str] = 'main',
              checkpoint: Optional[str] = 'model',
              test_every: int = 1000,
              log_every: int = 10,
              g_every: int = 1) -> Recipe:
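    # Only one network trains at a time: each wrapper freezes the other so
    # no gradient accumulates into it.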
    def D_wrap(batch):
        tu.freeze(G)
        tu.unfreeze(D)

        return D_fun(batch)

    def G_wrap(batch):
        tu.freeze(D)
        tu.unfreeze(G)

        return G_fun(batch)

    def test_wrap(batch):
        tu.freeze(G)
        tu.freeze(D)
        D.eval()
        G.eval()

        with torch.no_grad():
            out = test_fun(batch)

        D.train()
        G.train()
        return out

    class NoLim:
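        """Yield exactly one batch per pass, restarting the underlying
        loader transparently when it is exhausted."""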
        def __init__(self):
            self.i = iter(loader)
            self.did_send = False

        def __iter__(self):
            return self

        def __next__(self):
            if self.did_send:
                self.did_send = False
                raise StopIteration
            self.did_send = True
            try:
                return next(self.i)
            except StopIteration:
                self.i = iter(loader)
                return next(self.i)

    D_loop = Recipe(D_wrap, loader)
    D_loop.register('G', G)
    D_loop.register('D', D)
    G_loop = Recipe(G_wrap, NoLim())
    D_loop.G_loop = G_loop
    D_loop.register('G_loop', G_loop)

    test_loop = Recipe(test_wrap, NoLim())
    D_loop.test_loop = test_loop
    D_loop.register('test_loop', test_loop)

    def G_test(state):
        G_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    def prepare_test(state):
        test_loop.callbacks.update_state({
            'epoch': state['epoch'],
            'iters': state['iters'],
            'epoch_batch': state['epoch_batch']
        })

    D_loop.callbacks.add_prologues([tcb.Counter()])

    D_loop.callbacks.add_epilogues([
        tcb.Log('imgs', 'G_imgs'),
        tcb.CallRecipe(G_loop, g_every, init_fun=G_test, prefix='G'),
        tcb.VisdomLogger(visdom_env=visdom_env, log_every=log_every),
        tcb.StdoutLogger(log_every=log_every),
        tcb.CallRecipe(test_loop,
                       test_every,
                       init_fun=prepare_test,
                       prefix='Test'),
    ])

    G_loop.callbacks.add_epilogues([
        tcb.WindowedMetricAvg('G_loss'),
        tcb.VisdomLogger(visdom_env=visdom_env,
                         log_every=log_every,
                         post_epoch_ends=False)
    ])

    if checkpoint is not None:
        test_loop.callbacks.add_epilogues([
            tcb.Checkpoint(checkpoint + '/ckpt_{iters}.pth', D_loop),
            tcb.VisdomLogger(visdom_env=visdom_env),
        ])

    return D_loop
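
Everything needed to drive this recipe is visible in its signature. A minimal, hypothetical smoke test on toy data (the MLPs, the fake loader and the BCE losses are invented; Visdom and checkpointing are disabled via the None arguments, per the Optional types above):

# Hypothetical smoke test; tcb is assumed to alias torchelie.callbacks,
# matching the convention in the snippets above.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchelie.callbacks as tcb

G = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 32))
D = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1))
loader = [(torch.randn(8, 32),) for _ in range(10)]  # stand-in "real" data

def G_fun(batch):
    fake = G(torch.randn(8, 16))
    loss = F.binary_cross_entropy_with_logits(D(fake), torch.ones(8, 1))
    loss.backward()
    return {'G_loss': loss.item(), 'imgs': fake.detach()}

def D_fun(batch):
    fake = G(torch.randn(8, 16)).detach()
    loss = (F.binary_cross_entropy_with_logits(D(fake), torch.zeros(8, 1)) +
            F.binary_cross_entropy_with_logits(D(batch[0]), torch.ones(8, 1)))
    loss.backward()
    return {'loss': loss.item(), 'imgs': fake}

def test_fun(batch):
    return {'imgs': G(torch.randn(8, 16))}

loop = GANRecipe(G, D, G_fun, D_fun, test_fun, loader,
                 visdom_env=None, checkpoint=None, log_every=5)
loop.G_loop.callbacks.add_callbacks(
    [tcb.Optimizer(torch.optim.Adam(G.parameters(), lr=1e-3))])
loop.callbacks.add_callbacks(
    [tcb.Optimizer(torch.optim.Adam(D.parameters(), lr=1e-3))])
loop.run(1)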