コード例 #1
0
ファイル: fixedrunner.py プロジェクト: jgalle29/gait
    def post_epoch_visualize(self, epoch, split):
        """Decode a 20x20 grid of latent codes and save it as one image.

        When ``self.flags.z_size == 2`` every row is a slice of a regular sweep
        over the open unit square. If ``self.flags.normal_latent`` is set, the
        sweep is pushed through the standard-normal inverse CDF. For any other
        latent size, each of the 20 rows holds 20 random latent samples, drawn
        from ``torch.randn`` if ``normal_latent`` is set and ``torch.rand``
        otherwise.

        Args:
            epoch: Unused; kept for the shared ``post_epoch_visualize`` signature.
            split: Split name. ``'test'`` writes ``<log_dir>/test.png``; any
                other value writes ``<log_dir>/vis_<train_steps>.png``.
        """
        print('* Visualizing', split)
        grid_size = 20
        flags = self.flags

        # Keep the sweep strictly inside (0, 1): the standard-normal inverse CDF
        # is infinite at 0 and 1.
        Z = torch.linspace(1e-3, 1.0 - 1e-3, steps=grid_size)
        Z = torch.cartesian_prod(Z, Z).view(grid_size, grid_size, 2)
        if flags.z_size == 2 and flags.normal_latent:
            dist = normal.Normal(0.0, 1.0)

        x_gens = []
        for row in range(grid_size):
            if flags.z_size == 2:
                z = Z[row]
                if flags.normal_latent:
                    z = dist.icdf(z)
            elif flags.normal_latent:
                z = torch.randn(grid_size, flags.z_size)
            else:
                z = torch.rand(grid_size, flags.z_size)
            z = self.model.prepare_batch(z)
            x_gen = self.model.run_batch([z]).view(
                grid_size, self.img_chan, self.img_size, self.img_size)
            x_gens.append(x_gen.detach().cpu())

        # Rows are concatenated in order, so the saved grid mirrors the sweep.
        x_full = torch.cat(x_gens, dim=0).numpy()
        if split == 'test':
            fname = flags.log_dir + '/test.png'
        else:
            fname = flags.log_dir + '/vis_%03d.png' % self.model.get_train_steps()
        misc.save_comparison_grid(fname, x_full, border_width=0, retain_sequence=True)
        print('* Visualizations saved to', fname)
コード例 #2
0
    def post_epoch_visualize(self, epoch, split):
        """Save a visualization grid for ``split``.

        For ``'train'``, decode ``self.flags.batch_size`` random normal latents
        of shape ``(batch_size, latent_size, 1, 1)``. For any other split, take
        one shuffled batch from ``self.reader`` and show each input next to its
        reconstruction (inputs and reconstructions alternate in the grid).

        The image is written to ``<log_dir>/vis_<train_steps>_<split>.png``.

        Args:
            epoch: Unused; kept for the shared ``post_epoch_visualize`` signature.
            split: Split name passed to the reader and used in the file name.
        """
        print('* Visualizing', split)
        fname = self.flags.log_dir + '/vis_{}_{}.png'.format(
            self.model.get_train_steps(), split)
        if split == 'train':
            z = torch.randn(self.flags.batch_size, self.flags.latent_size, 1, 1)
            z = self.model.prepare_batch(z)
            x = self.model.run_batch([None, z], visualize=True)
            # detach() is a no-op without autograd history, but keeps .numpy()
            # from raising if run_batch hands back a tensor that requires grad.
            vis_data = x.detach().cpu().numpy()
            aspect = 1.0
        else:
            # NOTE(review): this branch reads self.batch_size / self.threads,
            # while the train branch reads self.flags.batch_size; confirm both
            # attributes exist and are meant to differ.
            batch = next(
                self.reader.iter_batches(split,
                                         self.batch_size,
                                         shuffle=True,
                                         partial_batching=True,
                                         threads=self.threads,
                                         max_batches=1))
            batch = self.model.prepare_batch(batch[0])
            x = self.model.run_batch([batch, None], visualize=True)
            y = batch.detach().cpu().numpy()[None, ...]
            x = x.detach().cpu().numpy()[None, ...]
            # (2, B, ...) -> (B, 2, ...) -> (2B, ...): each input is immediately
            # followed by its reconstruction.
            vis_data = np.concatenate([y, x], axis=0)
            vis_data = np.swapaxes(vis_data, 0, 1).reshape(-1, *x.shape[2:])
            aspect = 2.0

        misc.save_comparison_grid(fname,
                                  vis_data,
                                  border_shade=0.5,
                                  retain_sequence=True,
                                  desired_aspect=aspect)
        print('* Visualizations saved to', fname)
コード例 #3
0
ファイル: tdvaerunner.py プロジェクト: ankitkv/TD-VAE
    def post_epoch_visualize(self, epoch, split):
        """Write visualization grids for every non-training split.

        ``'train'`` produces nothing. Any other split gets one grid from
        ``self._visualize_split(split, min(seq_len - 1, 10), 5)``. That grid is
        saved as ``<log_dir>/test.png`` for ``'test'`` and as
        ``<log_dir>/valNNN.png`` (NNN = zero-padded ``epoch``) otherwise.
        ``'test'`` additionally gets ``<log_dir>/test_more.png`` from
        ``self._visualize_split(split, 0, 15)``.
        """
        def save_grid(fname, vis_data, rows_cols):
            # Both grids share the same rendering options.
            misc.save_comparison_grid(fname,
                                      vis_data,
                                      border_shade=1.0,
                                      rows_cols=rows_cols,
                                      retain_sequence=True)

        if split != 'train':
            print('* Visualizing', split)
            length = min(self.flags.seq_len - 1, 10)
            grid, layout = self._visualize_split(split, length, 5)
            out_path = self.flags.log_dir + (
                '/test.png' if split == 'test' else '/val%03d.png' % epoch)
            save_grid(out_path, grid, layout)
            print('* Visualizations saved to', out_path)

        if split == 'test':
            print('* Generating more visualizations for', split)
            grid, layout = self._visualize_split(split, 0, 15)
            out_path = self.flags.log_dir + '/test_more.png'
            save_grid(out_path, grid, layout)
            print('* More visualizations saved to', out_path)
コード例 #4
0
 def post_epoch_visualize(self, epoch, split):
     """Render one visualization grid for ``split`` and save it.

     ``self._visualize_split`` may return ``None``; in that case nothing is
     written and a skip message is printed. Otherwise its result is unpacked
     as ``(vis_data, rows_cols)`` and saved to ``<log_dir>/<split>NNN.png``,
     where NNN is the zero-padded ``epoch``.
     """
     print('* Visualizing', split)
     result = self._visualize_split(split, self.history_length, 5)
     if result is None:
         print('* Visualization skipped')
         return

     vis_data, rows_cols = result
     fname = self.flags.log_dir + '/{}'.format(split) + '%03d.png' % epoch
     misc.save_comparison_grid(fname,
                               vis_data,
                               rows_cols=rows_cols,
                               border_shade=1.0,
                               retain_sequence=True)
     print('* Visualizations saved to', fname)
コード例 #5
0
 def post_epoch_visualize(self, epoch, split):
     """Render one visualization grid for ``split`` and save it.

     Calls ``self._visualize_split(split, length, 5)`` with
     ``length = min(10, seq_len - 1)``. The result is unpacked as
     ``(vis_data, rows_cols)`` and saved to ``<log_dir>/<split>NNN.png``,
     where NNN is the zero-padded ``epoch``.
     """
     print('* Visualizing', split)
     # Cap at 10 and keep strictly below the sequence length.
     length = min(10, self.flags.seq_len - 1)
     # NOTE(review): the third argument is a fixed 5 and is not clamped against
     # seq_len - length; confirm _visualize_split does not need that bound.
     vis_data, rows_cols = self._visualize_split(split, length, 5)
     fname = self.flags.log_dir + '/{}'.format(split) + '%03d.png' % epoch
     misc.save_comparison_grid(fname,
                               vis_data,
                               rows_cols=rows_cols,
                               border_shade=1.0,
                               retain_sequence=True)
     print('* Visualizations saved to', fname)
コード例 #6
0
    def post_epoch_visualize(self, epoch, split):
        """Decode a 20x20 set of latent codes and write it as one image grid.

        For a 2-D latent (``self.flags.z_size == 2``), the 400 codes are the
        cartesian product of 20 evenly spaced points on [0, 1]. Otherwise each
        row of 20 codes comes from ``torch.rand``. ``epoch`` is unused. The
        output path is ``<log_dir>/test.png`` for ``'test'`` and
        ``<log_dir>/vis_<train_steps>.png`` otherwise.
        """
        print('* Visualizing', split)
        axis = torch.linspace(0.0, 1.0, steps=20)
        lattice = torch.cartesian_prod(axis, axis).view(20, 20, 2)

        def decode_row(row_idx):
            # One row of the output grid: 20 latent codes -> 20 decoded images.
            if self.flags.z_size == 2:
                codes = lattice[row_idx]
            else:
                codes = torch.rand(20, self.flags.z_size)
            codes = self.model.prepare_batch(codes)
            return self.model.run_batch([codes], visualize=True).detach().cpu()

        images = torch.cat([decode_row(r) for r in range(20)], dim=0).numpy()
        out_path = self.flags.log_dir + (
            '/test.png' if split == 'test'
            else '/vis_%03d.png' % self.model.get_train_steps())
        misc.save_comparison_grid(out_path, images, border_width=0, retain_sequence=True)
        print('* Visualizations saved to', out_path)
コード例 #7
0
ファイル: test_reader.py プロジェクト: ankitkv/TD-VAE
from pylego.misc import save_comparison_grid

from readers.moving_mnist import MovingMNISTReader

if __name__ == '__main__':
    # Smoke-test MovingMNISTReader: print the size of up to five training
    # batches (batch size 4) and save the first item of three of them as a grid.
    reader = MovingMNISTReader('data/MNIST')
    for i, batch in enumerate(reader.iter_batches('train', 4, max_batches=5)):
        print(batch.size())
        if i < 3:
            # One 1x28x28 frame per grid cell. -1 lets the frame count follow
            # the reader's sequence length instead of hard-coding 20.
            frames = batch[0].numpy().reshape(-1, 1, 28, 28)
            save_comparison_grid('seq%d.png' % i, frames)
コード例 #8
0
if __name__ == '__main__':
    # Smoke test: pull a few batches from a GymReader emulator and from a
    # ReplayBuffer built on top of it, print tensor shapes, and save the
    # observations of the first three batches as image grids.
    # NOTE(review): the meaning of the positional constructor arguments
    # (6, 4, 2, 100 and 5000, 100, 1, 4, 0.99) is not visible here; check the
    # GymReader / ReplayBuffer signatures.
    emulator = GymReader('Seaquest-v0', 6, 4, 2, 100)
    reader = ReplayBuffer(emulator, 5000, 100, 1, 4, 0.99)

    print('EMULATOR:')
    for i, batch in enumerate(
            emulator.iter_batches('train', 4, max_batches=5, threads=2)):
        # Only the first four fields returned by get_next() are inspected.
        obs, actions, rewards, done = batch.get_next()[:4]
        print('obs', obs.size())
        print('actions', actions.shape)
        print('rewards', rewards.shape)
        print()
        if i < 3:
            # Merge the first two obs dims into a single frame axis; assumes
            # each frame is 3x80x80 (TODO confirm against GymReader).
            batch = obs.numpy().reshape(obs.shape[0] * obs.shape[1], 3, 80, 80)
            save_comparison_grid('eseq%d.png' % i,
                                 batch,
                                 rows_cols=obs.shape[:2],
                                 retain_sequence=True)
            print(actions)
            print()

    print('REPLAY BUFFER READER:')
    for i, batch in enumerate(reader.iter_batches('train', 4, max_batches=5)):
        obs, actions, rewards, done, t1, t2, returns, is_weight, idx = batch
        print('obs', obs.size())
        print('actions', actions.shape)
        print('rewards', rewards.shape)
        print('is_weight', is_weight.shape)
        print()
        if i < 3:
            # NOTE(review): 4 * 6 hard-codes the first two obs dims. 4 matches
            # the batch size requested above; 6 presumably matches GymReader's
            # second argument. obs.shape[0] * obs.shape[1] would be safer.
            batch = obs.numpy().reshape(4 * 6, 3, 80, 80)
            save_comparison_grid('rseq%d.png' % i,
コード例 #9
0
ファイル: test_reader.py プロジェクト: ankitkv/ALI
from pylego.misc import save_comparison_grid

from readers.cifar10 import CIFAR10Reader

if __name__ == '__main__':
    # Smoke-test CIFAR10Reader: for up to five training batches of 256, print
    # the image array's shape and value range, then save it as an image grid.
    cifar = CIFAR10Reader('data')
    train_batches = cifar.iter_batches('train', 256, max_batches=5)
    for idx, batch in enumerate(train_batches):
        images = batch[0].numpy()
        print(images.shape, images.min(), images.max())
        save_comparison_grid('seq%d.png' % idx, images, border_shade=0.75)
コード例 #10
0
        with open(batches_fname, 'wb') as f:
            pickle.dump(batches, f)
    else:
        with open(batches_fname, 'rb') as f:
            batches = pickle.load(f)
    return batches


if __name__ == '__main__':
    try:
        os.makedirs(DATA_DIR)
    except IOError as e:
        pass

    batches = get_batches(DATA_DIR + '/norm_batches.pk')
    misc.save_comparison_grid('example1.png', batches[:16], border_shade=0.8)
    h, w = batches.shape[2:]
    crop_top, crop_bottom, crop_left, crop_right = 34, 16, 0, 0
    print(batches.shape)
    batches = batches[:, :, crop_top:h-crop_bottom, crop_left:w-crop_right]
    print(batches.shape)
    flat_batches = batches.transpose(1, 0, 2, 3).reshape(batches.shape[1], -1)
    mean = flat_batches.mean(axis=1)[None, :, None, None]
    std = flat_batches.std(axis=1)[None, :, None, None]
    batches -= mean
    batches /= std
    batches = batches.mean(axis=1, keepdims=True)  # greyscale
    true_min = batches.min()
    true_max = batches.max()
    print(true_min, true_max)
    bmin = true_min  # np.percentile(batches, 0)