Code example #1
Score: 0
class TestVQVAE(unittest.TestCase):
    """Smoke tests for the VQVAE model: summary, forward, loss, sample, generate."""

    def setUp(self) -> None:
        # VQVAE(in_channels, embedding_dim, num_embeddings) — presumably;
        # TODO confirm argument meaning against the VQVAE definition.
        self.model = VQVAE(3, 64, 512)

    def test_summary(self):
        # Print a layer-by-layer summary for a 3x64x64 input on CPU.
        print(summary(self.model, (3, 64, 64), device='cpu'))

    def test_forward(self):
        # Report trainable-parameter count, then run one forward pass.
        print(
            sum(p.numel() for p in self.model.parameters() if p.requires_grad))
        x = torch.randn(16, 3, 64, 64)
        y = self.model(x)
        print("Model Output size:", y[0].size())

    def test_loss(self):
        x = torch.randn(16, 3, 64, 64)

        result = self.model(x)
        # M_N is presumably the loss weighting term (batch/dataset ratio) —
        # TODO confirm against loss_function's signature.
        loss = self.model.loss_function(*result, M_N=0.005)
        print(loss)

    @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available")
    def test_sample(self):
        # Fix: previously crashed on CPU-only machines; now skipped instead.
        self.model.cuda()
        y = self.model.sample(8, 'cuda')
        print(y.shape)

    def test_generate(self):
        x = torch.randn(16, 3, 64, 64)
        y = self.model.generate(x)
        print(y.shape)
Code example #2
Score: 0
File: train.py — Project: HankABC/VQVAE-pytorch
def train_CIFAR10(opt):
    """Train a VQ-VAE on CIFAR-10.

    Args:
        opt: parsed CLI options; must provide ``opt.config``, the path to a
            config file readable by ``get_config``.

    Side effects:
        Downloads CIFAR-10 into ``params['data_root']/cifar10``, snapshots
        code/config and writes checkpoints plus sample reconstructions under
        a timestamped directory beneath ``params['save_path']``.
    """
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms
    from torchvision.utils import make_grid
    from matplotlib import pyplot as plt

    params = get_config(opt.config)

    # Snapshot code + config alongside the run so results stay reproducible.
    save_path = os.path.join(
        params['save_path'],
        datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))
    os.makedirs(save_path, exist_ok=True)
    shutil.copy('models.py', os.path.join(save_path, 'models.py'))
    shutil.copy('train.py', os.path.join(save_path, 'train.py'))
    shutil.copy(opt.config,
                os.path.join(save_path, os.path.basename(opt.config)))

    cuda = torch.cuda.is_available()
    gpu_ids = list(range(torch.cuda.device_count()))

    TensorType = torch.cuda.FloatTensor if cuda else torch.Tensor

    data_path = os.path.join(params['data_root'], 'cifar10')
    os.makedirs(data_path, exist_ok=True)

    # Inputs are normalized to [-1, 1]; this must be inverted when saving
    # images (see the imsave calls below).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_dataset = datasets.CIFAR10(root=data_path,
                                     train=True,
                                     download=True,
                                     transform=transform)

    val_dataset = datasets.CIFAR10(root=data_path,
                                   train=False,
                                   download=True,
                                   transform=transform)

    # Fix: len(gpu_ids) is 0 on CPU-only machines, which made batch_size 0.
    train_loader = DataLoader(train_dataset,
                              batch_size=params['batch_size'] *
                              max(1, len(gpu_ids)),
                              shuffle=True,
                              num_workers=params['num_workers'],
                              pin_memory=cuda)
    val_loader = DataLoader(val_dataset,
                            batch_size=1,
                            num_workers=params['num_workers'],
                            pin_memory=cuda)

    # Fix: `train_data` was removed from torchvision datasets in favor of
    # `.data`; keep a fallback for very old torchvision versions.
    raw_images = getattr(train_dataset, 'data', None)
    if raw_images is None:
        raw_images = train_dataset.train_data
    data_variance = np.var(raw_images / 255.0)

    encoder = Encoder(params['dim'], params['residual_channels'],
                      params['n_layers'], params['d'])
    decoder = Decoder(params['dim'], params['residual_channels'],
                      params['n_layers'], params['d'])

    vq = VectorQuantizer(params['k'], params['d'], params['beta'],
                         params['decay'], TensorType)

    if params['checkpoint'] is not None:
        checkpoint = torch.load(params['checkpoint'])

        # Fix: the checkpoint records the epoch that already FINISHED, so
        # resume from the next one instead of repeating it.
        params['start_epoch'] = checkpoint['epoch'] + 1
        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
        vq.load_state_dict(checkpoint['vq'])

    model = VQVAE(encoder, decoder, vq)

    if cuda:
        model = nn.DataParallel(model.cuda(), device_ids=gpu_ids)

    # Fix: renamed from `opt`, which shadowed the function argument.
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=params['lr'])

    for epoch in range(params['start_epoch'], params['num_epochs']):
        train_bar = tqdm(train_loader)
        for data, _ in train_bar:
            if cuda:
                data = data.cuda()
            optimizer.zero_grad()

            vq_loss, data_recon, _ = model(data)
            # Normalizing by the data variance puts the reconstruction error
            # on a comparable scale to the VQ commitment loss.
            recon_error = torch.mean((data_recon - data)**2) / data_variance
            loss = recon_error + vq_loss.mean()
            loss.backward()
            optimizer.step()

            train_bar.set_description('Epoch {}: loss {:.4f}'.format(
                epoch + 1,
                loss.mean().item()))

        # Validation snapshot: one batch, no gradient tracking needed.
        model.eval()
        with torch.no_grad():
            data_val, _ = next(iter(val_loader))
            if cuda:
                data_val = data_val.cuda()
            _, data_recon_val, _ = model(data_val)

        # Fix: invert Normalize((0.5,...), (0.5,...)) with x*0.5 + 0.5
        # (the original only added 0.5, leaving values outside [0, 1],
        # which plt.imsave rejects for float images); clamp for safety.
        recon_grid = make_grid(data_recon_val.cpu().data) * 0.5 + 0.5
        orig_grid = make_grid(data_val.cpu().data) * 0.5 + 0.5
        plt.imsave(os.path.join(save_path, 'latest_val_recon.png'),
                   recon_grid.clamp(0, 1).numpy().transpose(1, 2, 0))
        plt.imsave(os.path.join(save_path, 'latest_val_orig.png'),
                   orig_grid.clamp(0, 1).numpy().transpose(1, 2, 0))

        model.train()

        torch.save(
            {
                'epoch': epoch,
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'vq': vq.state_dict(),
            }, os.path.join(save_path, '{}_checkpoint.pth'.format(epoch)))