Example #1
def train(model, loader):
    def train_step(batch):
        x = batch[0]
        # repeat the single channel so the model sees a 3-channel image
        x = x.expand(-1, 3, -1, -1)

        # rescale to [-1, 1]; the targets are the 8-bit pixel intensities
        x2 = model(x * 2 - 1)
        loss = F.cross_entropy(x2, (x * 255).long())
        loss.backward()
        # most likely intensity per pixel, mapped back to [0, 1]
        reconstruction = x2.argmax(dim=1).float() / 255.0
        return {'loss': loss, 'reconstruction': reconstruction}

    def after_train():
        imgs = model.sample(1, 4).expand(-1, 3, -1, -1)
        return {'imgs': imgs}

    opt = RAdamW(model.parameters(), lr=3e-3)
    trainer = TrainAndCall(model,
                           train_step,
                           after_train,
                           loader,
                           test_every=500,
                           visdom_env='pixelcnn')
    trainer.callbacks.add_callbacks([
        tcb.WindowedMetricAvg('loss'),
        tcb.Log('reconstruction', 'reconstruction'),
        tcb.Optimizer(opt, log_lr=True),
        tcb.LRSched(torch.optim.lr_scheduler.ReduceLROnPlateau(opt))
    ])
    trainer.test_loop.callbacks.add_callbacks([
        tcb.Log('imgs', 'imgs'),
    ])

    trainer.to('cuda')
    trainer.run(10)
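The per-pixel cross entropy in train_step only fits together if the model returns 256 logits per pixel and channel. A minimal shape check in plain PyTorch (the batch and image sizes below are made up for illustration, not taken from the original model):

import torch
import torch.nn.functional as F

# stand-in for one batch: 4 grayscale 32x32 images in [0, 1], repeated to 3 channels
x = torch.rand(4, 1, 32, 32).expand(-1, 3, -1, -1)

# the model is assumed to emit a (N, 256, C, H, W) tensor of logits
logits = torch.randn(4, 256, 3, 32, 32)

# targets are the 8-bit pixel intensities, shape (N, C, H, W)
target = (x * 255).long()

loss = F.cross_entropy(logits, target)
reconstruction = logits.argmax(dim=1).float() / 255.0
print(loss.item(), reconstruction.shape)  # scalar loss, torch.Size([4, 3, 32, 32])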
Example #2
    def make_optimizer(self):
        return RAdamW(self.model.parameters(), lr=1e-2)
Example #3
    def make_optimizer(self):
        return RAdamW(self.model.parameters(), lr=3e-3, betas=(0.95, 0.9995))
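Examples #2 and #3 override a recipe's make_optimizer hook to swap in RAdamW with different hyperparameters. The same construction works outside a recipe; a minimal sketch, assuming RAdamW is importable from torchelie.optim and using a throwaway linear model:

import torch.nn as nn
from torchelie.optim import RAdamW

model = nn.Linear(128, 10)

# Same settings as Example #3: beta1=0.95 smooths gradients slightly more than
# the usual 0.9, and beta2=0.9995 keeps a longer-running variance estimate.
opt = RAdamW(model.parameters(), lr=3e-3, betas=(0.95, 0.9995))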
Example #4
def StyleGAN2Recipe(G: nn.Module,
                    D: nn.Module,
                    dataloader,
                    noise_size: int,
                    gpu_id: int,
                    total_num_gpus: int,
                    *,
                    G_lr: float = 2e-3,
                    D_lr: float = 2e-3,
                    tag: str = 'model',
                    ada: bool = True):
    """
    StyleGAN2 Recipe distributed with DistributedDataParallel

    Args:
        G (nn.Module): a Generator.
        D (nn.Module): a Discriminator.
        dataloader: a dataloader conforming to torchvision's API.
        noise_size (int): the size of the input noise vector.
        gpu_id (int): the GPU index on which to run.
        total_num_gpus (int): the total number of GPUs used for training
        G_lr (float): RAdamW learning rate for G
        D_lr (float): RAdamW learning rate for D
        tag (str): tag for Visdom and checkpoints
        ada (bool): whether to enable Adaptive Data Augmentation

    Returns:
        the recipe and the Polyak-averaged (EMA) copy of G
    """
    G_polyak = copy.copy(G)

    G = nn.parallel.DistributedDataParallel(G.to(gpu_id), [gpu_id], gpu_id)
    D = nn.parallel.DistributedDataParallel(D.to(gpu_id), [gpu_id], gpu_id)
    print(G)
    print(D)

    optG: torch.optim.Optimizer = RAdamW(G.parameters(),
                                         G_lr,
                                         betas=(0., 0.99),
                                         weight_decay=0)
    optD: torch.optim.Optimizer = RAdamW(D.parameters(),
                                         D_lr,
                                         betas=(0., 0.99),
                                         weight_decay=0)

    batch_size = len(next(iter(dataloader))[0])
    diffTF = ADATF(-2 if not ada else -0.9,
                   50000 / (batch_size * total_num_gpus))

    ppl = PPL(4)

    def G_train(batch):
        with G.no_sync():
            pl = ppl(G, torch.randn(batch_size, noise_size, device=gpu_id))
        ##############
        #   G pass   #
        ##############
        imgs = G(torch.randn(batch_size, noise_size, device=gpu_id))
        pred = D(diffTF(imgs) * 2 - 1)
        score = gan_loss.generated(pred)
        score.backward()

        return {'G_loss': score.item(), 'ppl': pl}

    gradient_penalty = GradientPenalty(0.1)

    def D_train(batch):
        ###################
        #    Fake pass    #
        ###################
        with D.no_sync():
            # Sync the gradient on the last backward
            noise = torch.randn(batch_size, noise_size, device=gpu_id)
            with torch.no_grad():
                fake = G(noise)
            fake.requires_grad_(True)
            fake.retain_grad()
            fake_tf = diffTF(fake) * 2 - 1
            fakeness = D(fake_tf).squeeze(1)
            fake_loss = gan_loss.fake(fakeness)
            fake_loss.backward()

            correct = (fakeness < 0).int().eq(1).float().sum()
        fake_grad = fake.grad.detach().norm(dim=1, keepdim=True)
        fake_grad /= fake_grad.max()

        tfmed = diffTF(batch[0]) * 2 - 1

        with D.no_sync():
            grad_norm = gradient_penalty(D, batch[0] * 2 - 1,
                                         fake.detach() * 2 - 1)

        ###################
        #    Real pass    #
        ###################
        real_out = D(tfmed)
        correct += (real_out > 0).detach().int().eq(1).float().sum()
        real_loss = gan_loss.real(real_out)
        real_loss.backward()
        pos_ratio = real_out.gt(0).float().mean().cpu().item()
        diffTF.log_loss(-pos_ratio)
        return {
            'imgs': fake.detach(),
            'i_grad': fake_grad,
            'loss': real_loss.item() + fake_loss.item(),
            'fake_loss': fake_loss.item(),
            'real_loss': real_loss.item(),
            'ADA-p': diffTF.p,
            'D-correct': correct / (2 * real_out.numel()),
            'grad_norm': grad_norm
        }

    tu.freeze(G_polyak)

    def test(batch):
        G_polyak.eval()

        def sample(N, n_iter, alpha=0.01, show_every=10):
            noise = torch.randn(N,
                                noise_size,
                                device=gpu_id,
                                requires_grad=True)
            opt = torch.optim.Adam([noise], lr=alpha)
            fakes = []
            for i in range(n_iter):
                noise += torch.randn_like(noise) / 10
                fake_batch = []
                opt.zero_grad()
                for j in range(0, N, batch_size):
                    with torch.enable_grad():
                        n_batch = noise[j:j + batch_size]
                        fake = G_polyak(n_batch, mixing=False)
                        fake_batch.append(fake)
                        log_prob = n_batch[:, 32:].pow(2).mul_(-0.5)
                        fakeness = -D(fake * 2 - 1).sum() - log_prob.sum()
                        fakeness.backward()
                opt.step()
                fake_batch = torch.cat(fake_batch, dim=0)

                if i % show_every == 0:
                    fakes.append(fake_batch.cpu().detach().clone())

            fakes.append(fake_batch.cpu().detach().clone())

            return torch.cat(fakes, dim=0)

        fake = sample(8, 50, alpha=0.001, show_every=10)

        noise1 = torch.randn(batch_size * 2 // 8, 1, noise_size, device=gpu_id)
        noise2 = torch.randn(batch_size * 2 // 8, 1, noise_size, device=gpu_id)
        t = torch.linspace(0, 1, 8, device=noise1.device).view(8, 1)
        noise = noise1 * t + noise2 * (1 - t)
        noise = noise.view(-1, noise_size)
        interp = torch.cat([
            G_polyak(n, mixing=False) for n in torch.split(noise, batch_size)
        ],
                           dim=0)
        return {
            'polyak_imgs': fake,
            'polyak_interp': interp,
        }

    recipe = GANRecipe(G,
                       D,
                       G_train,
                       D_train,
                       test,
                       dataloader,
                       visdom_env=tag if gpu_id == 0 else None,
                       log_every=10,
                       test_every=1000,
                       checkpoint=tag if gpu_id == 0 else None,
                       g_every=1)
    recipe.callbacks.add_callbacks([
        tcb.Log('batch.0', 'x'),
        tcb.WindowedMetricAvg('fake_loss'),
        tcb.WindowedMetricAvg('real_loss'),
        tcb.WindowedMetricAvg('grad_norm'),
        tcb.WindowedMetricAvg('ADA-p'),
        tcb.WindowedMetricAvg('D-correct'),
        tcb.Log('i_grad', 'img_grad'),
        tch.callbacks.Optimizer(optD),
    ])
    recipe.G_loop.callbacks.add_callbacks([
        tch.callbacks.Optimizer(optG),
        tcb.Polyak(G.module, G_polyak,
                   0.5**((batch_size * total_num_gpus) / 20000)),
        tcb.WindowedMetricAvg('ppl'),
    ])
    recipe.test_loop.callbacks.add_callbacks([
        tcb.Log('polyak_imgs', 'polyak'),
        tcb.Log('polyak_interp', 'interp'),
    ])
    recipe.register('G_polyak', G_polyak)
    recipe.to(gpu_id)
    return recipe, G_polyak
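The decay handed to tcb.Polyak is 0.5 ** ((batch_size * total_num_gpus) / 20000), i.e. the EMA of G's weights has a half-life of 20,000 training images regardless of batch size or GPU count. A quick numeric check (the batch size and GPU count below are illustrative):

# per-step EMA decay for a 20,000-image half-life
batch_size, total_num_gpus = 32, 4
images_per_step = batch_size * total_num_gpus    # 128 images per G step
decay = 0.5 ** (images_per_step / 20000)
print(decay)                                     # ~0.99557 per G step
print(decay ** (20000 / images_per_step))        # exactly 0.5 after 20,000 images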
Example #5
def CrossEntropyClassification(model,
                               train_loader,
                               test_loader,
                               classes,
                               lr=3e-3,
                               beta1=0.9,
                               wd=1e-2,
                               visdom_env='main',
                               test_every=1000,
                               log_every=100):
    """
    A Classification recipe with a default forward training / testing pass
    using cross entropy, extended with RAdamW and ReduceLROnPlateau.

    Args:
        model (nn.Module): a model learnable with cross entropy
        train_loader (DataLoader): Training set dataloader
        test_loader (DataLoader): Testing set dataloader
        classes (list of str): class names, in order
        lr (float): the learning rate
        beta1 (float): RAdamW's beta1
        wd (float): weight decay
        visdom_env (str): name of the visdom environment to use, or None to
            disable Visdom (default: 'main')
        test_every (int): testing frequency, in number of iterations (default:
            1000)
        log_every (int): logging frequency, in number of iterations (default:
            100)
    """

    def train_step(batch):
        x, y = batch
        pred = model(x)
        loss = torch.nn.functional.cross_entropy(pred, y)
        loss.backward()
        return {'loss': loss, 'pred': pred}

    def validation_step(batch):
        x, y = batch
        pred = model(x)
        loss = torch.nn.functional.cross_entropy(pred, y)
        return {'loss': loss, 'pred': pred}

    loop = Classification(model,
                          train_step,
                          validation_step,
                          train_loader,
                          test_loader,
                          classes,
                          visdom_env=visdom_env,
                          test_every=test_every,
                          log_every=log_every)

    opt = RAdamW(model.parameters(),
                 lr=lr,
                 betas=(beta1, 0.999),
                 weight_decay=wd)
    loop.callbacks.add_callbacks([
        tcb.Optimizer(opt, log_lr=True),
        tcb.LRSched(torch.optim.lr_scheduler.ReduceLROnPlateau(opt))
    ])
    return loop
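A sketch of how this recipe might be driven end to end, assuming the function above is importable (e.g. from torchelie.recipes) and using CIFAR-10 with a torchvision ResNet purely as an illustrative choice:

import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

tfm = T.ToTensor()
train_ds = torchvision.datasets.CIFAR10('.', train=True, download=True, transform=tfm)
test_ds = torchvision.datasets.CIFAR10('.', train=False, download=True, transform=tfm)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(test_ds, batch_size=64, num_workers=4)

model = torchvision.models.resnet18(num_classes=10)

loop = CrossEntropyClassification(model,
                                  train_loader,
                                  test_loader,
                                  train_ds.classes,
                                  lr=3e-3,
                                  visdom_env=None)
loop.to('cuda')
loop.run(5)  # same call pattern as Example #1; the argument is the number of epochs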
Example #6
def MixupClassification(model,
                        train_loader,
                        test_loader,
                        classes,
                        *,
                        lr=3e-3,
                        beta1=0.9,
                        wd=1e-2,
                        visdom_env='main',
                        test_every=1000,
                        log_every=100):
    """
    A Classification recipe with a default forward training / testing pass
    using cross entropy and mixup, extended with RAdamW and
    ReduceLROnPlateau.

    Args:
        model (nn.Module): a model learnable with cross entropy
        train_loader (DataLoader): Training set dataloader. Must have soft
            targets. Should be a DataLoader loading a MixupDataset or
            compatible.
        test_loader (DataLoader): Testing set dataloader. Dataset must have
            categorical targets.
        classes (list of str): class names, in order
        lr (float): the learning rate
        beta1 (float): RAdamW's beta1
        wd (float): weight decay
        visdom_env (str): name of the visdom environment to use, or None to
            disable Visdom (default: 'main')
        test_every (int): testing frequency, in number of iterations (default:
            1000)
        log_every (int): logging frequency, in number of iterations (default:
            100)
    """

    from torchelie.loss import continuous_cross_entropy

    def train_step(batch):
        x, y = batch
        pred = model(x)
        loss = continuous_cross_entropy(pred, y)
        loss.backward()
        return {'loss': loss}

    def validation_step(batch):
        x, y = batch
        pred = model(x)
        loss = torch.nn.functional.cross_entropy(pred, y)
        return {'loss': loss, 'pred': pred}

    loop = TrainAndTest(model,
                        train_step,
                        validation_step,
                        train_loader,
                        test_loader,
                        visdom_env=visdom_env,
                        test_every=test_every,
                        log_every=log_every)

    loop.callbacks.add_callbacks([
        tcb.WindowedMetricAvg('loss'),
    ])
    loop.register('classes', classes)

    loop.test_loop.callbacks.add_callbacks([
        tcb.AccAvg(post_each_batch=False),
        tcb.WindowedMetricAvg('loss', False),
    ])

    if visdom_env is not None:
        loop.callbacks.add_epilogues(
            [tcb.ImageGradientVis(),
             tcb.MetricsTable()])

    if len(classes) <= 25:
        loop.test_loop.callbacks.add_callbacks([
            tcb.ConfusionMatrix(classes),
        ])

    loop.test_loop.callbacks.add_callbacks([
        tcb.ClassificationInspector(30, classes, False),
        tcb.MetricsTable(False)
    ])

    opt = RAdamW(model.parameters(),
                 lr=lr,
                 betas=(beta1, 0.999),
                 weight_decay=wd)
    loop.callbacks.add_callbacks([
        tcb.Optimizer(opt, log_lr=True),
        tcb.LRSched(torch.optim.lr_scheduler.ReduceLROnPlateau(opt))
    ])
    return loop
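MixupClassification expects the training loader to yield soft (probability) targets, which is what continuous_cross_entropy consumes in train_step. A plain-PyTorch sketch of what one mixed sample looks like (the image size, class count, and Beta parameters are made up for illustration):

import torch
import torch.nn.functional as F

num_classes = 10
x1, x2 = torch.rand(3, 32, 32), torch.rand(3, 32, 32)
y1, y2 = 3, 7

# draw a mixing coefficient and blend both the images and the one-hot targets
lam = torch.distributions.Beta(0.2, 0.2).sample()
x = lam * x1 + (1 - lam) * x2
y = lam * F.one_hot(torch.tensor(y1), num_classes).float() \
    + (1 - lam) * F.one_hot(torch.tensor(y2), num_classes).float()

print(x.shape, y)  # y sums to 1 but is not a class index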
Example #7
def CrossEntropyClassification(model,
                               train_loader,
                               test_loader,
                               classes,
                               lr=3e-3,
                               beta1=0.9,
                               wd=1e-2,
                               visdom_env='main',
                               test_every=1000,
                               log_every=100):
    """
    Extends Classification with default cross entropy forward passes. Also adds
    RAdamW and ReduceLROnPlateau.

    Inherited training callbacks:

    - AccAvg for displaying accuracy
    - WindowedMetricAvg for displaying loss
    - ConfusionMatrix if len(classes) <= 25
    - ClassificationInspector
    - MetricsTable
    - ImageGradientVis
    - Counter for counting iterations, connected to the testing loop as well
    - VisdomLogger
    - StdoutLogger

    Training callbacks:

    - Optimizer with RAdamW
    - LRSched with ReduceLROnPlateau

    Testing:

    Testing loop is in :code:`.test_loop`.

    Inherited testing callbacks:

    - AccAvg
    - WindowedMetricAvg
    - ConfusionMatrix if :code:`len(classes) <= 25`
    - ClassificationInspector
    - MetricsTable
    - VisdomLogger
    - StdoutLogger
    - Checkpoint saving the best testing loss

    Args:
        model (nn.Module): a model learnable with cross entropy
        train_loader (DataLoader): Training set dataloader
        test_loader (DataLoader): Testing set dataloader
        classes (list of str): class names, in order
        lr (float): the learning rate
        beta1 (float): RAdamW's beta1
        wd (float): weight decay
        visdom_env (str): name of the visdom environment to use, or None to
            disable Visdom (default: 'main')
        test_every (int): testing frequency, in number of iterations (default:
            1000)
        log_every (int): logging frequency, in number of iterations (default:
            100)
    """
    def train_step(batch):
        x, y = batch
        pred = model(x)
        loss = torch.nn.functional.cross_entropy(pred, y)
        loss.backward()
        return {'loss': loss, 'pred': pred}

    def validation_step(batch):
        x, y = batch
        pred = model(x)
        loss = torch.nn.functional.cross_entropy(pred, y)
        return {'loss': loss, 'pred': pred}

    loop = Classification(model,
                          train_step,
                          validation_step,
                          train_loader,
                          test_loader,
                          classes,
                          visdom_env=visdom_env,
                          test_every=test_every,
                          log_every=log_every)

    opt = RAdamW(model.parameters(),
                 lr=lr,
                 betas=(beta1, 0.999),
                 weight_decay=wd)
    loop.callbacks.add_callbacks([
        tcb.Optimizer(opt, log_lr=True),
        tcb.LRSched(torch.optim.lr_scheduler.ReduceLROnPlateau(opt))
    ])
    return loop
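The callbacks listed in the docstring live on loop.callbacks and loop.test_loop.callbacks, and more can be attached after construction exactly the way the function itself does. A small sketch reusing the model and loaders from the sketch after Example #5; the extra Log callback mirrors the one used by the StyleGAN2 recipe in Example #4 and is only illustrative:

loop = CrossEntropyClassification(model, train_loader, test_loader,
                                  train_ds.classes, visdom_env='cifar10')

# also log the raw input batch to Visdom, as Example #4 does with 'batch.0'
loop.callbacks.add_callbacks([
    tcb.Log('batch.0', 'x'),
])

loop.to('cuda')
loop.run(10)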