def batch_apply(inputs, model, func=lambda x: x, batch_size=32, device=None):
    """Apply ``model`` to ``inputs`` batch by batch and concatenate the outputs.

    Side effects: switches ``model`` to eval mode and moves it to ``device``;
    neither change is undone on return.

    Args:
        inputs: a tensor, a sequence of tensors (batched via ``batchify``), or
            anything else, which is wrapped in ``_ImageDataset`` with a
            ``ToTensor`` transform and iterated through a ``DataLoader``.
        model: called as ``model(*batch)`` for every batch.
        func: post-processing applied to each batch's model output
            (identity by default); runs under ``torch.no_grad()``.
        batch_size: number of samples per batch.
        device: target device. Any falsy value (``None`` by default) selects
            ``'cuda'`` when ``CUDA`` is truthy, otherwise ``'cpu'``.

    Returns:
        The per-batch ``func`` results concatenated along dim 0. If the input
        yields no batches, ``torch.cat`` raises on the empty list.
    """
    if not device:
        device = 'cuda' if CUDA else 'cpu'
    model.eval()
    model.to(device)

    # A bare tensor is treated as a one-element sequence of tensors.
    if torch.is_tensor(inputs):
        inputs = (inputs,)

    all_tensors = isinstance(inputs, Sequence) and all(torch.is_tensor(t) for t in inputs)
    if all_tensors:
        loader = batchify(inputs, batch_size=batch_size)
    else:
        # NOTE(review): non-tensor inputs are presumably images/paths that
        # _ImageDataset understands — confirm against its definition.
        loader = DataLoader(_ImageDataset(inputs, Compose([ToTensor()])), batch_size=batch_size)

    outputs = []
    for batch in loader:
        args = to_device(batch, device)
        # The model is always called with positional args, so wrap a single tensor.
        if torch.is_tensor(args):
            args = (args,)
        with torch.no_grad():
            outputs.append(func(model(*args)))
    return torch.cat(outputs, dim=0)
def _update(engine, batch):
    """Run one training iteration: a discriminator step, then a generator step.

    ``D`` produces a matrix whose column 0 is scored as the real/fake output
    and whose remaining columns are passed to the losses alongside the labels
    (presumably class logits — confirm against criterionD/criterionG).
    ``G`` is fed the latent noise concatenated with a one-hot encoding of the
    batch labels.

    Relies on enclosing-scope names: device, D, G, optimizerD, optimizerG,
    criterionD, criterionG, num_classes, make_latent.

    Returns:
        dict with the scalar ``lossD``, ``lossG`` and the ``batch_size``.
    """
    def split_head(out):
        # column 0 -> adversarial score, columns 1: -> auxiliary outputs
        return out[:, 0], out[:, 1:]

    def conditioned_latent(n, y):
        noise = to_device(make_latent(n), device)
        return torch.cat([noise, one_hot(y, num_classes)], dim=1)

    inputs, targets = prepare_batch(batch, device=device)
    x_real, y = inputs[0], targets[0]
    n = x_real.size(0)

    # Discriminator step: fakes are generated without tracking G's graph.
    unfreeze(D)
    D.train()
    optimizerD.zero_grad()
    score_real, aux_real = split_head(D(x_real))
    z = conditioned_latent(n, y)
    with torch.no_grad():
        x_fake = G(z)
    score_fake, aux_fake = split_head(D(x_fake))
    loss_d = criterionD(score_real, score_fake, aux_real, aux_fake, y)
    loss_d.backward()
    optimizerD.step()

    # Generator step: D is frozen, and a fresh latent is drawn.
    freeze(D)
    G.train()
    optimizerG.zero_grad()
    score_gen, aux_gen = split_head(D(G(conditioned_latent(n, y))))
    loss_g = criterionG(score_gen, aux_gen, y)
    loss_g.backward()
    optimizerG.step()

    return dict(lossD=loss_d.item(), lossG=loss_g.item(), batch_size=n)
def _update(engine, batch):
    """Run one label-conditioned training iteration: D step, then G step.

    Both networks receive the batch labels as a second argument:
    ``D(x, labels)`` and ``G(latent, labels)``.

    Relies on enclosing-scope names: device, D, G, optimizerD, optimizerG,
    criterionD, criterionG, make_latent.

    Returns:
        dict with the scalar ``lossD``, ``lossG`` and the ``batch_size``.
    """
    inputs, targets = prepare_batch(batch, device=device)
    real_images, cond = inputs[0], targets[0]
    n = real_images.size(0)

    # Discriminator step: fakes are generated without tracking G's graph.
    unfreeze(D)
    D.train()
    optimizerD.zero_grad()
    real_logits = D(real_images, cond)
    noise = to_device(make_latent(n), device)
    with torch.no_grad():
        generated = G(noise, cond)
    fake_logits = D(generated, cond)
    d_loss = criterionD(real_logits, fake_logits)
    d_loss.backward()
    optimizerD.step()

    # Generator step: D is frozen (but left in its current mode), and a
    # fresh latent is drawn.
    freeze(D)
    G.train()
    optimizerG.zero_grad()
    noise = to_device(make_latent(n), device)
    g_loss = criterionG(D(G(noise, cond), cond))
    g_loss.backward()
    optimizerG.step()

    return {"lossD": d_loss.item(), "lossG": g_loss.item(), "batch_size": n}
def save_generated(trainer, save_interval, fixed_inputs, sharpen=True):
    """Periodically write a grid of generator samples to disk.

    Does nothing unless ``trainer.iterations()`` is a multiple of
    ``save_interval``. Otherwise runs ``trainer.G`` on ``fixed_inputs`` and
    saves the grid to ``trainer.save_path / "images" / "<iteration>.jpg"``,
    creating the directory if needed.

    Args:
        trainer: object exposing ``G``, ``device``, ``save_path`` (a
            ``pathlib``-style path) and ``iterations()``.
        save_interval: save every this many iterations.
        fixed_inputs: a tensor, or a sequence of tensors passed positionally
            to ``trainer.G``.
        sharpen: if True, min-max stretch the grid to [0, 1]
            (``make_grid(normalize=True)``). If False, map the raw output
            linearly with ``(x + 1) / 2``, which assumes the generator output
            lies in [-1, 1] (e.g. a tanh output — confirm for your G).
    """
    iteration = trainer.iterations()
    if iteration % save_interval != 0:
        return
    import matplotlib.pyplot as plt

    gen = trainer.G
    was_training = gen.training
    gen.eval()
    if torch.is_tensor(fixed_inputs):
        fixed_inputs = (fixed_inputs,)
    fixed_inputs = to_device(fixed_inputs, trainer.device)
    try:
        with torch.no_grad():
            fake_x = gen(*fixed_inputs).cpu()
    finally:
        # Restore whatever mode G was in, even if the forward pass raised.
        gen.train(was_training)

    # Only normalize when sharpening. Normalizing and then applying
    # (x + 1) / 2 would squeeze the image into [0.5, 1].
    grid = make_grid(fake_x, padding=2, normalize=sharpen)
    img = np.transpose(grid.numpy(), (1, 2, 0))
    if not sharpen:
        img = (img + 1) / 2
    # plt.imsave rejects float RGB data outside [0, 1].
    img = np.clip(img, 0, 1)

    fp = trainer.save_path / "images" / ("%d.jpg" % iteration)
    fp.parent.mkdir(exist_ok=True, parents=True)
    plt.imsave(fp, img)
def _update(engine, batch):
    """Run one training iteration for a discriminator with a latent-predicting Q head.

    ``D.q`` toggles D's output. With ``D.q = False``, ``D(x)`` is used as a
    single real/fake output. With ``D.q = True``, it returns
    ``(fake_p, lat_p)``, where ``lat_p`` is scored against the latent that
    generated the batch (InfoGAN-style, judging by the Q head — confirm).

    The discriminator step updates only ``D.features`` and ``D.d_head``. The
    generator step freezes both and trains through the Q head. One fake batch
    is generated per iteration and shared by both steps.

    Relies on enclosing-scope names: device, D, G, optimizerD, optimizerG,
    criterionD, criterionG, make_latent.

    Returns:
        dict with the scalar ``lossD``, ``lossG`` and the ``batch_size``.
    """
    inputs, _ = prepare_batch(batch, device=device)
    real_x = inputs[0]
    batch_size = real_x.size(0)

    # G must be in training mode *before* its forward pass. fake_x below is
    # reused for the generator's gradient step, so switching modes after
    # generating it would not affect this iteration.
    G.train()

    # ---- Discriminator step: shared features + real/fake head ----
    D.q = False
    unfreeze(D.features)
    unfreeze(D.d_head)
    D.features.train()
    D.d_head.train()
    optimizerD.zero_grad()

    real_p = D(real_x)
    lat = to_device(make_latent(batch_size), device)
    # Deliberately not under torch.no_grad(): G's graph is kept so the
    # generator step can backpropagate through this same batch. D only sees
    # a detached copy here.
    fake_x = G(lat)
    fake_p = D(fake_x.detach())
    lossD = criterionD(real_p, fake_p)
    lossD.backward()
    optimizerD.step()

    # ---- Generator step: features/d_head frozen, Q head active ----
    D.q = True
    freeze(D.features)
    freeze(D.d_head)
    D.q_head.train()
    optimizerG.zero_grad()
    # Reuse fake_x and lat, so that lat_p is compared with the exact latent
    # that produced fake_x.
    fake_p, lat_p = D(fake_x)
    lossG = criterionG(fake_p, lat_p, lat)
    lossG.backward()
    optimizerG.step()

    return {
        "lossD": lossD.item(),
        "lossG": lossG.item(),
        "batch_size": batch_size,
    }