def post_epoch_visualize(self, epoch, split):
    """Decode a 20x20 sheet of latent samples and save it as one PNG.

    When ``self.flags.z_size == 2`` the sheet is a regular sweep over
    ``[1e-3, 1 - 1e-3]^2``. The sweep is mapped through the standard-normal
    inverse CDF when ``self.flags.normal_latent`` is set. For any other latent
    size, each of the 20 rows is 20 fresh latents drawn from ``torch.randn``
    (``normal_latent``) or ``torch.rand``.

    Args:
        epoch: Unused here.
        split: Split name. ``'test'`` writes ``<log_dir>/test.png``; any other
            value writes ``<log_dir>/vis_<train_steps>.png``, with the step
            count zero-padded to at least 3 digits.
    """
    print('* Visualizing', split)
    grid = 20  # rows in the sheet, and images per row
    sweep_2d = self.flags.z_size == 2
    if sweep_2d:
        # Keep the sweep off 0 and 1, where the normal inverse CDF is infinite.
        ticks = torch.linspace(0.0 + 1e-3, 1.0 - 1e-3, steps=grid)
        Z = torch.cartesian_prod(ticks, ticks).view(grid, grid, 2)
        if self.flags.normal_latent:
            dist = normal.Normal(0.0, 1.0)

    x_gens = []
    for row in range(grid):
        if sweep_2d:
            z = Z[row]
            if self.flags.normal_latent:
                z = dist.icdf(z)
        elif self.flags.normal_latent:
            z = torch.randn(grid, self.flags.z_size)
        else:
            z = torch.rand(grid, self.flags.z_size)
        z = self.model.prepare_batch(z)
        x_gen = self.model.run_batch([z]).view(
            grid, self.img_chan, self.img_size, self.img_size)
        x_gens.append(x_gen.detach().cpu())
    x_full = torch.cat(x_gens, dim=0).numpy()

    if split == 'test':
        fname = self.flags.log_dir + '/test.png'
    else:
        fname = self.flags.log_dir + '/vis_%03d.png' % self.model.get_train_steps()
    misc.save_comparison_grid(fname, x_full, border_width=0, retain_sequence=True)
    print('* Visualizations saved to', fname)
def post_epoch_visualize(self, epoch, split):
    """Save a visualization grid for ``split`` under ``self.flags.log_dir``.

    For ``'train'`` the model decodes a batch of ``torch.randn`` latents shaped
    ``(flags.batch_size, flags.latent_size, 1, 1)``. For any other split, one
    batch is read from ``self.reader``. Each input image is then followed by
    the model's output for it in the saved grid.
    """
    print('* Visualizing', split)
    steps = self.model.get_train_steps()
    out_path = self.flags.log_dir + '/vis_{}_{}.png'.format(steps, split)

    if split != 'train':
        # NOTE(review): this branch sizes by self.batch_size, the train branch
        # by self.flags.batch_size -- confirm both are intended.
        batches = self.reader.iter_batches(split, self.batch_size, shuffle=True,
                                           partial_batching=True,
                                           threads=self.threads, max_batches=1)
        inputs = self.model.prepare_batch(next(batches)[0])
        outputs = self.model.run_batch([inputs, None], visualize=True)
        originals = inputs.cpu().numpy()[None, ...]
        recons = outputs.cpu().numpy()[None, ...]
        # (2, B, ...) -> (B, 2, ...) -> (2B, ...): input i is followed by output i.
        paired = np.swapaxes(np.concatenate([originals, recons], axis=0), 0, 1)
        grid = paired.reshape(-1, *recons.shape[2:])
        aspect = 2.0
    else:
        noise = torch.randn(self.flags.batch_size, self.flags.latent_size, 1, 1)
        latents = self.model.prepare_batch(noise)
        grid = self.model.run_batch([None, latents], visualize=True).cpu().numpy()
        aspect = 1.0

    misc.save_comparison_grid(out_path, grid, border_shade=0.5,
                              retain_sequence=True, desired_aspect=aspect)
    print('* Visualizations saved to', out_path)
def post_epoch_visualize(self, epoch, split):
    """Save prediction grids for every split except ``'train'``.

    Non-train splits get one grid from
    ``self._visualize_split(split, min(seq_len - 1, 10), 5)``. It is written to
    ``test.png`` for ``'test'`` and to ``val<epoch>.png`` (epoch zero-padded to
    3 digits) otherwise. ``'test'`` additionally gets ``test_more.png`` from
    ``self._visualize_split(split, 0, 15)``.
    """
    if split == 'train':
        return

    def save(path, data, layout):
        misc.save_comparison_grid(path, data, border_shade=1.0,
                                  rows_cols=layout, retain_sequence=True)

    print('* Visualizing', split)
    length = min(self.flags.seq_len - 1, 10)
    data, layout = self._visualize_split(split, length, 5)
    is_test = split == 'test'
    fname = self.flags.log_dir + ('/test.png' if is_test else '/val%03d.png' % epoch)
    save(fname, data, layout)
    print('* Visualizations saved to', fname)

    if not is_test:
        return
    print('* Generating more visualizations for', split)
    data, layout = self._visualize_split(split, 0, 15)
    fname = self.flags.log_dir + '/test_more.png'
    save(fname, data, layout)
    print('* More visualizations saved to', fname)
def post_epoch_visualize(self, epoch, split):
    """Save ``self._visualize_split`` output for ``split`` as one PNG.

    The file is ``<log_dir>/<split><epoch>.png``, with the epoch zero-padded
    to 3 digits. Nothing is written when ``_visualize_split`` returns ``None``.
    """
    print('* Visualizing', split)
    result = self._visualize_split(split, self.history_length, 5)
    if result is None:
        print('* Visualization skipped')
        return
    grid, layout = result  # layout is forwarded as rows_cols
    target = self.flags.log_dir + '/%s%03d.png' % (split, epoch)
    misc.save_comparison_grid(target, grid, rows_cols=layout,
                              border_shade=1.0, retain_sequence=True)
    print('* Visualizations saved to', target)
def post_epoch_visualize(self, epoch, split):
    """Save a prediction grid for ``split`` to ``<log_dir>/<split><epoch>.png``.

    Calls ``self._visualize_split(split, length, n)`` with
    ``length = min(10, seq_len - 1)`` and ``n = min(5, seq_len - length)``.
    Therefore ``length + n`` never exceeds ``self.flags.seq_len``.

    Args:
        epoch: Epoch number, zero-padded to 3 digits in the file name.
        split: Split name, also used as the file-name prefix.
    """
    print('* Visualizing', split)
    length = min(10, self.flags.seq_len - 1)
    # Cap the second count by what remains after `length`. It was computed
    # but a literal 5 used to be passed instead, which ignored this bound.
    # NOTE(review): presumably `n` is the number of steps shown after
    # `length` -- confirm against _visualize_split.
    n = min(5, self.flags.seq_len - length)
    vis_data, rows_cols = self._visualize_split(split, length, n)
    fname = self.flags.log_dir + '/{}'.format(split) + '%03d.png' % epoch
    misc.save_comparison_grid(fname, vis_data, rows_cols=rows_cols,
                              border_shade=1.0, retain_sequence=True)
    print('* Visualizations saved to', fname)
def post_epoch_visualize(self, epoch, split):
    """Decode 20 rows of 20 latents each and save them as one PNG.

    With ``self.flags.z_size == 2`` row ``r`` is row ``r`` of a uniform
    20x20 sweep over ``[0, 1]^2``. Otherwise each row is a fresh
    ``torch.rand(20, z_size)`` draw. ``'test'`` writes ``<log_dir>/test.png``;
    other splits write ``<log_dir>/vis_<train_steps>.png``.
    """
    print('* Visualizing', split)
    ticks = torch.linspace(0.0, 1.0, steps=20)
    sweep = torch.cartesian_prod(ticks, ticks).view(20, 20, 2)

    def decode(r):
        # Draws (if any) happen here, row by row, before that row is decoded.
        latent = sweep[r] if self.flags.z_size == 2 else torch.rand(20, self.flags.z_size)
        latent = self.model.prepare_batch(latent)
        return self.model.run_batch([latent], visualize=True).detach().cpu()

    sheet = torch.cat([decode(r) for r in range(20)], dim=0).numpy()

    if split == 'test':
        name = 'test.png'
    else:
        name = 'vis_%03d.png' % self.model.get_train_steps()
    fname = self.flags.log_dir + '/' + name
    misc.save_comparison_grid(fname, sheet, border_width=0, retain_sequence=True)
    print('* Visualizations saved to', fname)
from pylego.misc import save_comparison_grid
from readers.moving_mnist import MovingMNISTReader

if __name__ == '__main__':
    # Smoke test for MovingMNISTReader: print the size of up to 5 training
    # batches of 4, and for the first 3 batches save item 0 (reshaped to
    # 20 x 1 x 28 x 28) as seq<i>.png.
    reader = MovingMNISTReader('data/MNIST')
    batches = reader.iter_batches('train', 4, max_batches=5)
    for i, batch in enumerate(batches):
        print(batch.size())
        if i >= 3:
            continue
        frames = batch[0].numpy().reshape(20, 1, 28, 28)
        save_comparison_grid('seq%d.png' % i, frames)
# Smoke test for the Gym emulator and the replay buffer built on it. It builds
# GymReader('Seaquest-v0', 6, 4, 2, 100) and ReplayBuffer(emulator, 5000, 100,
# 1, 4, 0.99), then iterates up to 5 training batches of 4 from each. For each
# batch it prints the shapes of obs/actions/rewards (plus is_weight for the
# replay buffer). For the first 3 batches it saves frame grids as
# eseq<i>.png / rseq<i>.png.
# NOTE(review): the replay-buffer loop hard-codes reshape(4 * 6, 3, 80, 80).
# The 4 matches the iter_batches batch size; the 6 presumably matches
# GymReader's second argument (sequence length?) -- confirm. The emulator
# loop derives the same product from obs.shape instead.
# NOTE(review): this fragment ends mid-call to save_comparison_grid; the
# remaining arguments are not visible here.
if __name__ == '__main__': emulator = GymReader('Seaquest-v0', 6, 4, 2, 100) reader = ReplayBuffer(emulator, 5000, 100, 1, 4, 0.99) print('EMULATOR:') for i, batch in enumerate( emulator.iter_batches('train', 4, max_batches=5, threads=2)): obs, actions, rewards, done = batch.get_next()[:4] print('obs', obs.size()) print('actions', actions.shape) print('rewards', rewards.shape) print() if i < 3: batch = obs.numpy().reshape(obs.shape[0] * obs.shape[1], 3, 80, 80) save_comparison_grid('eseq%d.png' % i, batch, rows_cols=obs.shape[:2], retain_sequence=True) print(actions) print() print('REPLAY BUFFER READER:') for i, batch in enumerate(reader.iter_batches('train', 4, max_batches=5)): obs, actions, rewards, done, t1, t2, returns, is_weight, idx = batch print('obs', obs.size()) print('actions', actions.shape) print('rewards', rewards.shape) print('is_weight', is_weight.shape) print() if i < 3: batch = obs.numpy().reshape(4 * 6, 3, 80, 80) save_comparison_grid('rseq%d.png' % i,
from pylego.misc import save_comparison_grid
from readers.cifar10 import CIFAR10Reader

if __name__ == '__main__':
    # Smoke test for CIFAR10Reader: for up to 5 training batches of 256, print
    # the shape and value range of item 0 of the batch and save it as a grid
    # image seq<i>.png.
    reader = CIFAR10Reader('data')
    batches = reader.iter_batches('train', 256, max_batches=5)
    for i, batch in enumerate(batches):
        img = batch[0].numpy()
        lo, hi = img.min(), img.max()
        print(img.shape, lo, hi)
        save_comparison_grid('seq%d.png' % i, img, border_shade=0.75)
# Tail of get_batches() (its start is outside this view). One branch pickles
# `batches` to batches_fname; the else-branch loads them back from it. The
# fragment then ends with `return batches`.
# NOTE(review): pickle.load can execute arbitrary code from the file. That is
# fine for a locally generated cache, but unsafe if batches_fname can come
# from elsewhere.
# The __main__ part (it may continue past this view) does the following:
# ensures DATA_DIR exists, loads the cached batches, and saves the first 16 to
# example1.png. It crops 34 rows off the top and 16 off the bottom of every
# image, and standardizes each axis-1 slice (presumably colour channels) by
# its own mean/std. It then averages axis 1 into a single greyscale channel
# and prints the resulting min/max.
# NOTE(review): `except IOError as e: pass` hides every OSError from
# os.makedirs (e.g. permission errors), not just "already exists";
# os.makedirs(DATA_DIR, exist_ok=True) would be narrower.
with open(batches_fname, 'wb') as f: pickle.dump(batches, f) else: with open(batches_fname, 'rb') as f: batches = pickle.load(f) return batches if __name__ == '__main__': try: os.makedirs(DATA_DIR) except IOError as e: pass batches = get_batches(DATA_DIR + '/norm_batches.pk') misc.save_comparison_grid('example1.png', batches[:16], border_shade=0.8) h, w = batches.shape[2:] crop_top, crop_bottom, crop_left, crop_right = 34, 16, 0, 0 print(batches.shape) batches = batches[:, :, crop_top:h-crop_bottom, crop_left:w-crop_right] print(batches.shape) flat_batches = batches.transpose(1, 0, 2, 3).reshape(batches.shape[1], -1) mean = flat_batches.mean(axis=1)[None, :, None, None] std = flat_batches.std(axis=1)[None, :, None, None] batches -= mean batches /= std batches = batches.mean(axis=1, keepdims=True) # greyscale true_min = batches.min() true_max = batches.max() print(true_min, true_max) bmin = true_min # np.percentile(batches, 0)