def main():
    """Resume AEGAN training from the epoch-1049 checkpoints and write progress images.

    Creates the ``results/`` output directories and builds an augmented
    ``ImageFolder`` dataset from ``data/``. It saves a 6x6 grid of two fixed
    test batches, then loads the generator, encoder and both discriminators
    from ``results/checkpoints/<prefix>.01049.pt``.

    It then runs ``EPOCHS`` iterations of ``gan.train_epoch(max_steps=100)``.
    Before each epoch, training waits for as long as ``pause.json`` contains
    a non-zero ``{"pause": <seconds>}``. Every 50 epochs the four
    sub-network state dicts are checkpointed. Every epoch a sample grid
    (``save_images``) and a reconstruction grid of the fixed test images are
    written. Finally a sample grid is displayed with matplotlib.

    NOTE(review): this file later rebinds ``main`` with a second definition,
    so this one is unreachable through the module-level name unless that
    later definition is removed.
    """
    os.makedirs("results/generated", exist_ok=True)
    os.makedirs("results/reconstructed", exist_ok=True)
    os.makedirs("results/checkpoints", exist_ok=True)
    root = "data"
    checkpoint_dir = os.path.join("results", "checkpoints")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def save_grid(images, path):
        """Save the first 36 images of ``images`` as a normalized 6x6 PNG grid."""
        grid = tv.utils.make_grid(images[:36], normalize=True, nrow=6)
        # make_grid returns channels-first; PIL wants (H, W, C) uint8 in [0, 255].
        grid = grid.numpy().transpose((1, 2, 0))
        Image.fromarray(np.array(grid * 255, dtype=np.uint8)).save(path)

    transform = tv.transforms.Compose([
        # `fill` is the current keyword; `fillcolor` was removed from newer torchvision.
        tv.transforms.RandomAffine(0, translate=(5 / 96, 5 / 96), fill=(255, 255, 255)),
        tv.transforms.ColorJitter(hue=0.5),
        tv.transforms.RandomHorizontalFlip(p=0.5),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = ImageFolder(root=root, transform=transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                            num_workers=8, drop_last=True)

    # Fixed test set: two batches drawn once and reused for every reconstruction.
    batches = iter(dataloader)
    test_ims1, _ = next(batches)
    test_ims2, _ = next(batches)
    # Drop the one-off iterator so its worker processes don't idle for the whole run.
    del batches
    test_ims = torch.cat((test_ims1, test_ims2), 0)
    save_grid(test_ims, "results/reconstructed/test_images.png")
    # Move once, on whichever device was selected (a bare .cuda() breaks CPU-only runs).
    test_ims_device = test_ims.to(device)

    noise_fn = lambda x: torch.randn((x, LATENT_DIM), device=device)
    test_noise = noise_fn(36)
    gan = AEGAN(
        LATENT_DIM,
        noise_fn,
        dataloader,
        device=device,
        batch_size=BATCH_SIZE,
    )

    # (sub-network, checkpoint filename prefix) pairs used for both loading and saving.
    modules = (
        (gan.generator, "gen"),
        (gan.encoder, "enc"),
        (gan.discriminator_image, "dis_i"),
        (gan.discriminator_latent, "dis_l"),
    )
    for module, prefix in modules:
        # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
        state = torch.load(os.path.join(checkpoint_dir, f"{prefix}.01049.pt"),
                           map_location=device)
        module.load_state_dict(state)

    start = time.time()
    for i in range(EPOCHS):
        # Poll pause.json: a non-zero "pause" value sleeps that many seconds and re-checks;
        # a missing/malformed file or a zero value resumes training.
        while True:
            try:
                with open("pause.json") as f:
                    pause = json.load(f)
                if pause['pause'] == 0:
                    break
                print(f"Pausing for {pause['pause']} seconds")
                time.sleep(pause["pause"])
            except (KeyError, json.decoder.JSONDecodeError, FileNotFoundError):
                break

        elapsed = int(time.time() - start)
        elapsed = f"{elapsed // 3600:02d}:{(elapsed % 3600) // 60:02d}:{elapsed % 60:02d}"
        print(f"Epoch {i+1}; Elapsed time = {elapsed}s")
        gan.train_epoch(max_steps=100)

        # Checkpoint filenames use the 0-based epoch index (e.g. 00049 after 50 epochs).
        if (i + 1) % 50 == 0:
            for module, prefix in modules:
                torch.save(module.state_dict(),
                           os.path.join(checkpoint_dir, f"{prefix}.{i:05d}.pt"))

        save_images(gan, test_noise,
                    os.path.join("results", "generated", f"gen.{i:04d}.png"))
        with torch.no_grad():
            reconstructed = gan.generator(gan.encoder(test_ims_device)).cpu()
        save_grid(reconstructed,
                  os.path.join("results", "reconstructed", f"gen.{i:04d}.png"))

    images = gan.generate_samples()
    ims = tv.utils.make_grid(images, normalize=True)
    plt.imshow(ims.numpy().transpose((1, 2, 0)))
    plt.show()
def main():
    """Train AEGAN, resuming after the epoch returned by ``load_checkpoint``.

    Creates the ``results/`` output directories and builds an ``ImageFolder``
    dataset from ``data/`` whose transform is wrapped in ``Transform2Times``.
    It collects at least 36 fixed test images (first element of each
    sample's transform output) and saves them as a 6x6 grid.

    Training runs ``gan.train_epoch(cache)`` for epochs
    ``last_epoch + 1 .. EPOCHS - 1``. Every ``CHECKPOINTS_PERIOD`` epochs
    (excluding epoch 0) the whole ``gan.state_dict()`` is saved to
    ``CHECKPOINTS_DIR``. Every ``SAVE_IMAGES_PERIOD`` epochs a sample grid
    and a reconstruction grid of the fixed test images are written.
    Finally a sample grid is displayed with matplotlib.
    """
    os.makedirs("results/generated", exist_ok=True)
    os.makedirs("results/reconstructed", exist_ok=True)
    os.makedirs("results/checkpoints", exist_ok=True)
    root = "data"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def save_grid(images, path):
        """Save the first 36 images of ``images`` as a normalized 6x6 PNG grid."""
        grid = tv.utils.make_grid(images[:36], normalize=True, nrow=6)
        # make_grid returns channels-first; PIL wants (H, W, C) uint8 in [0, 255].
        grid = grid.numpy().transpose((1, 2, 0))
        Image.fromarray(np.array(grid * 255, dtype=np.uint8)).save(path)

    transform = tv.transforms.Compose([
        tv.transforms.RandomAffine(0, translate=(5 / 96, 5 / 96), fill=(255, 255, 255)),
        tv.transforms.ColorJitter(hue=0.5),
        tv.transforms.RandomHorizontalFlip(p=0.5),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # NOTE(review): Transform2Times presumably returns two augmented views per image;
    # the unpacking below only relies on each batch's image entry being a 2-element pair.
    dataset = ImageFolder(root=root, transform=Transform2Times(transform))
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                            num_workers=2, drop_last=True)

    # Gather batches (keeping only the first view) until there are 36 test images.
    batches = iter(dataloader)
    [test_ims, _], _ = next(batches)
    while len(test_ims) < 36:
        [more_ims, _], _ = next(batches)
        test_ims = torch.cat((test_ims, more_ims), 0)
    # Drop the one-off iterator so its worker processes don't idle for the whole run.
    del batches
    save_grid(test_ims, "results/reconstructed/test_images.png")
    # Move once, on whichever device was selected (a bare .cuda() breaks CPU-only runs).
    test_ims_device = test_ims.to(device)

    noise_fn = lambda x: torch.randn((x, LATENT_DIM), device=device)
    test_noise = noise_fn(36)
    gan = AEGAN(LATENT_DIM, noise_fn, dataloader, device=device,
                batch_size=BATCH_SIZE, checkpoints_dir=CHECKPOINTS_DIR)
    # NOTE(review): capacity units of ImageCache are not visible here (images? batches?).
    cache = ImageCache(BATCH_SIZE * 4 * 1024)
    last_epoch = load_checkpoint(gan)

    start = time.time()
    for i in range(last_epoch + 1, EPOCHS):
        elapsed = int(time.time() - start)
        elapsed = f"{elapsed // 3600:02d}:{(elapsed % 3600) // 60:02d}:{elapsed % 60:02d}"
        print(f"Epoch {i}; Elapsed time = {elapsed}s")
        gan.train_epoch(cache)

        if i > 0 and i % CHECKPOINTS_PERIOD == 0:
            torch.save(gan.state_dict(), os.path.join(CHECKPOINTS_DIR, f"gen.{i:05d}.pt"))

        if i % SAVE_IMAGES_PERIOD == 0:
            save_images(gan, test_noise,
                        os.path.join("results", "generated", f"gen.{i:05d}.png"))
            with torch.no_grad():
                reconstructed = gan.generator(gan.encoder(test_ims_device)).cpu()
            save_grid(reconstructed,
                      os.path.join("results", "reconstructed", f"gen.{i:05d}.png"))

    images = gan.generate_samples()
    ims = tv.utils.make_grid(images, normalize=True)
    plt.imshow(ims.numpy().transpose((1, 2, 0)))
    plt.show()