Пример #1
0
def batch_apply(inputs, model, func=lambda x: x, batch_size=32, device=None):
    """Run `model` over `inputs` in mini-batches and concatenate the outputs.

    Parameters
    ----------
    inputs : a tensor, a sequence of tensors, or anything accepted by
        `_ImageDataset` (presumably raw images — verify against callers).
    model : the module to evaluate; it is switched to eval mode and moved
        to `device`.
    func : post-processing applied to each batch of model outputs.
    batch_size : mini-batch size for both the tensor and dataset paths.
    device : target device; defaults to 'cuda' when the global `CUDA` flag
        is set, else 'cpu'.

    Returns the per-batch results concatenated along dim 0.
    """
    if device is None:
        device = 'cuda' if CUDA else 'cpu'

    model.eval()
    model.to(device)

    # A single tensor is treated as a 1-tuple of model arguments.
    if torch.is_tensor(inputs):
        inputs = (inputs,)

    tensor_sequence = isinstance(inputs, Sequence) and all(
        torch.is_tensor(t) for t in inputs)
    if tensor_sequence:
        batches = batchify(inputs, batch_size=batch_size)
    else:
        # Non-tensor inputs go through an image dataset + DataLoader.
        ds = _ImageDataset(inputs, Compose([ToTensor()]))
        batches = DataLoader(ds, batch_size=batch_size)

    outputs = []
    for batch in batches:
        moved = to_device(batch, device)
        if torch.is_tensor(moved):
            moved = (moved,)
        with torch.no_grad():
            outputs.append(func(model(*moved)))
    return torch.cat(outputs, dim=0)
Пример #2
0
    def _update(engine, batch):
        """One ACGAN-style training step: update D, then update G.

        The discriminator's output column 0 is the real/fake score and the
        remaining columns are class logits.  Returns a dict with the scalar
        losses and the batch size.
        """
        inputs, targets = prepare_batch(batch, device=device)
        real_images = inputs[0]
        class_labels = targets[0]
        n = real_images.size(0)

        # ---- Discriminator phase ----
        unfreeze(D)
        D.train()
        optimizerD.zero_grad()

        out_real = D(real_images)
        real_cls = out_real[:, 1:]   # class logits
        real_score = out_real[:, 0]  # real/fake score

        noise = to_device(make_latent(n), device)
        gen_input = torch.cat([noise, one_hot(class_labels, num_classes)], dim=1)
        # G is not being updated here, so skip building its graph.
        with torch.no_grad():
            generated = G(gen_input)
        out_fake = D(generated)
        fake_cls = out_fake[:, 1:]
        fake_score = out_fake[:, 0]

        lossD = criterionD(real_score, fake_score, real_cls, fake_cls, class_labels)
        lossD.backward()
        optimizerD.step()

        # ---- Generator phase (D's parameters frozen) ----
        freeze(D)
        G.train()
        optimizerG.zero_grad()

        noise = to_device(make_latent(n), device)
        gen_input = torch.cat([noise, one_hot(class_labels, num_classes)], dim=1)
        out_fake = D(G(gen_input))
        fake_cls = out_fake[:, 1:]
        fake_score = out_fake[:, 0]

        lossG = criterionG(fake_score, fake_cls, class_labels)
        lossG.backward()
        optimizerG.step()

        return {
            "lossD": lossD.item(),
            "lossG": lossG.item(),
            "batch_size": n,
        }
Пример #3
0
    def _update(engine, batch):
        """One conditional-GAN training step: update D, then update G.

        Labels are fed directly to both D and G.  Returns a dict with the
        scalar losses and the batch size.
        """
        inputs, targets = prepare_batch(batch, device=device)
        real_images = inputs[0]
        class_labels = targets[0]
        n = real_images.size(0)

        # ---- Discriminator phase ----
        unfreeze(D)
        D.train()
        optimizerD.zero_grad()

        score_real = D(real_images, class_labels)

        noise = to_device(make_latent(n), device)
        # G is not updated in this phase, so no graph is needed for it.
        with torch.no_grad():
            generated = G(noise, class_labels)
        score_fake = D(generated, class_labels)

        lossD = criterionD(score_real, score_fake)
        lossD.backward()
        optimizerD.step()

        # ---- Generator phase: D frozen so only G receives gradients ----
        freeze(D)
        G.train()
        optimizerG.zero_grad()

        noise = to_device(make_latent(n), device)
        score_fake = D(G(noise, class_labels), class_labels)
        lossG = criterionG(score_fake)
        lossG.backward()
        optimizerG.step()

        return {
            "lossD": lossD.item(),
            "lossG": lossG.item(),
            "batch_size": n,
        }
Пример #4
0
def save_generated(trainer, save_interval, fixed_inputs, sharpen=True):
    """Periodically render samples from `trainer.G` and save them as a grid.

    Runs only every `save_interval` trainer iterations.  The image grid is
    written to <trainer.save_path>/images/<iteration>.jpg.  `fixed_inputs`
    is a tensor or tuple of tensors passed verbatim to the generator.
    """
    if trainer.iterations() % save_interval != 0:
        return
    import matplotlib.pyplot as plt

    trainer.G.eval()
    if torch.is_tensor(fixed_inputs):
        fixed_inputs = (fixed_inputs, )
    fixed_inputs = to_device(fixed_inputs, trainer.device)
    with torch.no_grad():
        samples = trainer.G(*fixed_inputs).cpu()
    trainer.G.train()

    grid = make_grid(samples, padding=2, normalize=True).numpy()
    img = np.transpose(grid, (1, 2, 0))
    if not sharpen:
        # NOTE(review): rescales [-1, 1] -> [0, 1]; make_grid(normalize=True)
        # already maps to [0, 1], so confirm this branch is intentional.
        img = (img + 1) / 2

    out_path = trainer.save_path / "images" / ("%d.jpg" % trainer.iterations())
    out_path.parent.mkdir(exist_ok=True, parents=True)
    plt.imsave(out_path, img)
Пример #5
0
    def _update(engine, batch):
        """Single InfoGAN-style training step.

        First updates the discriminator (features + d_head) on real vs. fake
        samples, then updates the generator together with the Q head (latent
        recovery).  Returns a dict of scalar losses and the batch size.
        """
        inputs, _ = prepare_batch(batch, device=device)
        real_x = inputs[0]

        batch_size = real_x.size(0)

        # --- Discriminator phase: Q head disabled; only the shared features
        # and the d_head are trainable.
        D.q = False
        unfreeze(D.features)
        unfreeze(D.d_head)
        D.features.train()
        D.d_head.train()
        optimizerD.zero_grad()

        real_p = D(real_x)

        lat = make_latent(batch_size)
        lat = to_device(lat, device)
        # with torch.no_grad():
        # NOTE: fake_x is intentionally built WITH grad so the same tensor can
        # be reused (un-detached) in the generator phase below; here it is
        # detached so lossD does not backprop into G.
        fake_x = G(lat)
        fake_p = D(fake_x.detach())
        lossD = criterionD(real_p, fake_p)
        lossD.backward()
        optimizerD.step()

        # --- Generator/Q phase: freeze D's shared parts, enable the Q head.
        D.q = True
        freeze(D.features)
        freeze(D.d_head)
        G.train()
        D.q_head.train()
        optimizerG.zero_grad()

        # lat = make_latent(batch_size)
        # lat = to_device(lat, device)
        # Reuse fake_x from the D phase (still connected to G's graph); with
        # D.q=True the discriminator also returns the latent prediction lat_p,
        # which criterionG compares against the original latent `lat`.
        fake_p, lat_p = D(fake_x)
        lossG = criterionG(fake_p, lat_p, lat)
        lossG.backward()
        optimizerG.step()

        output = {
            "lossD": lossD.item(),
            "lossG": lossG.item(),
            "batch_size": batch_size,
        }
        return output