import unittest

import torch
from torchsummary import summary  # assumed source of `summary`

from models import VQVAE  # repo-local model definition


class TestVQVAE(unittest.TestCase):

    def setUp(self) -> None:
        # self.model2 = VAE(3, 10)
        self.model = VQVAE(3, 64, 512)

    def test_summary(self):
        print(summary(self.model, (3, 64, 64), device='cpu'))
        # print(summary(self.model2, (3, 64, 64), device='cpu'))

    def test_forward(self):
        # Number of trainable parameters.
        print(sum(p.numel() for p in self.model.parameters()
                  if p.requires_grad))
        x = torch.randn(16, 3, 64, 64)
        y = self.model(x)
        print("Model Output size:", y[0].size())
        # print("Model2 Output size:", self.model2(x)[0].size())

    def test_loss(self):
        x = torch.randn(16, 3, 64, 64)
        result = self.model(x)
        loss = self.model.loss_function(*result, M_N=0.005)
        print(loss)

    def test_sample(self):
        self.model.cuda()
        y = self.model.sample(8, 'cuda')
        print(y.shape)

    def test_generate(self):
        x = torch.randn(16, 3, 64, 64)
        y = self.model.generate(x)
        print(y.shape)
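# A minimal sketch of the vector-quantization step behind the `loss_function`
# tested above and the `vq_loss` in the training loop below, assuming the
# standard VQ-VAE formulation (nearest codebook entry, straight-through
# gradient, commitment weight `beta`). The real VectorQuantizer in models.py
# may differ (e.g. EMA codebook updates via `decay`); all names here are
# illustrative only.
import torch.nn.functional as F


def quantize_sketch(z_e, codebook, beta=0.25):
    """z_e: (B, D) encoder outputs; codebook: (K, D) embedding vectors."""
    # Squared L2 distance from each encoding to every codebook vector.
    distances = (z_e.pow(2).sum(1, keepdim=True)
                 - 2 * z_e @ codebook.t()
                 + codebook.pow(2).sum(1))
    indices = distances.argmin(1)  # index of the nearest codebook entry
    z_q = codebook[indices]        # quantized latents

    # Codebook loss pulls embeddings toward the encodings; the commitment
    # loss (scaled by beta) keeps encodings close to their chosen embedding.
    vq_loss = (F.mse_loss(z_q, z_e.detach())
               + beta * F.mse_loss(z_e, z_q.detach()))

    # Straight-through estimator: copy gradients from z_q back to z_e.
    z_q = z_e + (z_q - z_e).detach()
    return z_q, vq_loss, indices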
import datetime
import os
import shutil

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

# Repo-local modules assumed by this script.
from models import Decoder, Encoder, VectorQuantizer, VQVAE
from utils import get_config


def train_CIFAR10(opt):
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms
    from torchvision.utils import make_grid
    from matplotlib import pyplot as plt

    params = get_config(opt.config)
    save_path = os.path.join(
        params['save_path'],
        datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))
    os.makedirs(save_path, exist_ok=True)
    # Snapshot the code and config alongside the run for reproducibility.
    shutil.copy('models.py', os.path.join(save_path, 'models.py'))
    shutil.copy('train.py', os.path.join(save_path, 'train.py'))
    shutil.copy(opt.config,
                os.path.join(save_path, os.path.basename(opt.config)))

    cuda = torch.cuda.is_available()
    gpu_ids = list(range(torch.cuda.device_count()))
    TensorType = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    data_path = os.path.join(params['data_root'], 'cifar10')
    os.makedirs(data_path, exist_ok=True)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    train_dataset = datasets.CIFAR10(root=data_path, train=True,
                                     download=True, transform=transform)
    val_dataset = datasets.CIFAR10(root=data_path, train=False,
                                   download=True, transform=transform)
    # Guard against a zero batch size when no GPU is present.
    train_loader = DataLoader(
        train_dataset,
        batch_size=params['batch_size'] * max(len(gpu_ids), 1),
        shuffle=True,
        num_workers=params['num_workers'],
        pin_memory=cuda)
    val_loader = DataLoader(val_dataset, batch_size=1,
                            num_workers=params['num_workers'],
                            pin_memory=cuda)
    # `train_data` was removed from torchvision datasets; use `.data`.
    data_variance = np.var(train_dataset.data / 255.0)

    encoder = Encoder(params['dim'], params['residual_channels'],
                      params['n_layers'], params['d'])
    decoder = Decoder(params['dim'], params['residual_channels'],
                      params['n_layers'], params['d'])
    vq = VectorQuantizer(params['k'], params['d'], params['beta'],
                         params['decay'], TensorType)
    if params['checkpoint'] is not None:
        checkpoint = torch.load(params['checkpoint'])
        params['start_epoch'] = checkpoint['epoch']
        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
        vq.load_state_dict(checkpoint['vq'])
    model = VQVAE(encoder, decoder, vq)
    if cuda:
        model = nn.DataParallel(model.cuda(), device_ids=gpu_ids)

    # Use a distinct name so the `opt` argument is not shadowed.
    parameters = list(model.parameters())
    optimizer = torch.optim.Adam(
        [p for p in parameters if p.requires_grad], lr=params['lr'])

    for epoch in range(params['start_epoch'], params['num_epochs']):
        train_bar = tqdm(train_loader)
        for data, _ in train_bar:
            if cuda:
                data = data.cuda()
            optimizer.zero_grad()

            vq_loss, data_recon, _ = model(data)
            # Normalize the reconstruction error by the data variance, as in
            # the original VQ-VAE training recipe.
            recon_error = torch.mean((data_recon - data) ** 2) / data_variance
            loss = recon_error + vq_loss.mean()
            loss.backward()
            optimizer.step()

            train_bar.set_description('Epoch {}: loss {:.4f}'.format(
                epoch + 1, loss.item()))

        # Save a reconstruction of one validation image each epoch.
        model.eval()
        with torch.no_grad():
            data_val, _ = next(iter(val_loader))
            if cuda:
                data_val = data_val.cuda()
            _, data_recon_val, _ = model(data_val)
        # Undo the (0.5, 0.5) normalization before saving: x * 0.5 + 0.5.
        plt.imsave(os.path.join(save_path, 'latest_val_recon.png'),
                   (make_grid(data_recon_val.cpu().data) * 0.5 + 0.5)
                   .clamp(0, 1).numpy().transpose(1, 2, 0))
        plt.imsave(os.path.join(save_path, 'latest_val_orig.png'),
                   (make_grid(data_val.cpu().data) * 0.5 + 0.5)
                   .clamp(0, 1).numpy().transpose(1, 2, 0))
        model.train()

        torch.save(
            {
                'epoch': epoch,
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'vq': vq.state_dict(),
            }, os.path.join(save_path, '{}_checkpoint.pth'.format(epoch)))
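# A possible entry point, assuming `opt` carries the config path (only
# `opt.config` is read above); the flag name is illustrative.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train a VQ-VAE on CIFAR-10')
    parser.add_argument('--config', type=str, required=True,
                        help='path to the config file read by get_config')
    train_CIFAR10(parser.parse_args())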