def save_checkpoint(ckpt_file, model, optimiser, beta, geco_err_ema,
                    iter_idx, verbose=True):
    if verbose:
        fprint(f"Saving model training checkpoint to: {ckpt_file}")
    ckpt_dict = {'model_state_dict': model.state_dict(),
                 'optimiser_state_dict': optimiser.state_dict(),
                 'beta': beta,
                 'iter_idx': iter_idx}
    if geco_err_ema is not None:
        ckpt_dict['err_ema'] = geco_err_ema
    torch.save(ckpt_dict, ckpt_file)
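# Illustrative only: a minimal sketch of the matching restore step, assuming
# the checkpoint keys written by save_checkpoint above. The helper name
# `load_checkpoint` is hypothetical and not part of the original code; the
# training script restores these fields inline in the same way.
def load_checkpoint(ckpt_file, model, optimiser, geco=None):
    checkpoint = torch.load(ckpt_file, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimiser.load_state_dict(checkpoint['optimiser_state_dict'])
    if geco is not None and 'beta' in checkpoint:
        geco.beta = checkpoint['beta']
    if geco is not None and 'err_ema' in checkpoint:
        geco.err_ema = checkpoint['err_ema']
    return checkpoint['iter_idx']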
def load(cfg, **unused_kwargs):
    del unused_kwargs
    if not os.path.exists(cfg.data_folder):
        raise Exception("Data folder does not exist.")
    print(f"Using {cfg.num_workers} data workers.")
    # Copy all images and splits to /tmp
    if cfg.copy_to_tmp:
        for directory in ['/recordings', '/splits']:
            src = cfg.data_folder + directory
            dst = '/tmp' + directory
            fprint(f"Copying dataset from {src} to {dst}.")
            copytree(src, dst)
        cfg.data_folder = '/tmp'
    # Training
    tng_set = ShapeStacksDataset(cfg.data_folder, cfg.split_name, 'train',
                                 cfg.img_size)
    tng_loader = DataLoader(tng_set,
                            batch_size=cfg.batch_size,
                            shuffle=True,
                            num_workers=cfg.num_workers)
    # Validation
    val_set = ShapeStacksDataset(cfg.data_folder, cfg.split_name, 'eval',
                                 cfg.img_size)
    val_loader = DataLoader(val_set,
                            batch_size=cfg.batch_size,
                            shuffle=False,
                            num_workers=cfg.num_workers)
    # Test
    tst_set = ShapeStacksDataset(cfg.data_folder, cfg.split_name, 'test',
                                 cfg.img_size, shuffle_files=cfg.shuffle_test)
    tst_loader = DataLoader(tst_set,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)
    # Throughput stats
    loader_throughput(tng_loader)
    return (tng_loader, val_loader, tst_loader)
def load(cfg, **unused_kwargs):
    # Fix TensorFlow seed
    global SEED
    SEED = cfg.seed
    tf.set_random_seed(SEED)
    if cfg.num_workers == 0:
        fprint("Need to use at least one worker for loading tfrecords.")
        cfg.num_workers = 1
    del unused_kwargs
    if not os.path.exists(cfg.data_folder):
        raise Exception("Data folder does not exist.")
    print(f"Using {cfg.num_workers} data workers.")
    # Create data iterators
    train_loader = GQNLoader(data_folder=cfg.data_folder,
                             mode='devel_train',
                             img_size=cfg.img_size,
                             val_frac=cfg.val_frac,
                             batch_size=cfg.batch_size,
                             num_workers=cfg.num_workers,
                             buffer_size=cfg.buffer_size)
    val_loader = GQNLoader(data_folder=cfg.data_folder,
                           mode='devel_val',
                           img_size=cfg.img_size,
                           val_frac=cfg.val_frac,
                           batch_size=cfg.batch_size,
                           num_workers=cfg.num_workers,
                           buffer_size=cfg.buffer_size)
    test_loader = GQNLoader(data_folder=cfg.data_folder,
                            mode='test',
                            img_size=cfg.img_size,
                            val_frac=cfg.val_frac,
                            batch_size=1,
                            num_workers=1,
                            buffer_size=cfg.buffer_size)
    # Create session to be used by loaders
    sess = tf.InteractiveSession()
    train_loader.sess = sess
    val_loader.sess = sess
    test_loader.sess = sess
    # Throughput stats
    if not cfg.debug:
        loader_throughput(train_loader)
    return (train_loader, val_loader, test_loader)
def loader_throughput(loader, num_batches=100, burn_in=5):
    assert num_batches > 0
    if burn_in is None:
        burn_in = num_batches // 10
    num_samples = 0
    fprint(f"Train loader throughput stats on {num_batches} batches...")
    for i, batch in enumerate(loader):
        if i == burn_in:
            timer = time.time()
        if i >= burn_in:
            num_samples += batch['input'].size(0)
        if i == num_batches + burn_in:
            break
    dt = time.time() - timer
    spb = dt / num_batches
    ips = num_samples / dt
    fprint(f"{spb:.3f} s/b, {ips:.1f} im/s")
def load(cfg, **unused_kwargs):
    # Fix TensorFlow seed
    global SEED
    SEED = cfg.seed
    if cfg.num_workers == 0:
        fprint("Need to use at least one worker for loading.")
        cfg.num_workers = 1
    del unused_kwargs
    print(f"Using {cfg.num_workers} data workers.")
    # Create data iterators
    train_loader = MineRLLoader(mode="devel_train",
                                img_size=cfg.img_size,
                                val_frac=cfg.val_frac,
                                batch_size=cfg.batch_size,
                                num_workers=cfg.num_workers,
                                buffer_size=cfg.buffer_size)
    val_loader = MineRLLoader(mode="devel_val",
                              img_size=cfg.img_size,
                              val_frac=cfg.val_frac,
                              batch_size=cfg.batch_size,
                              num_workers=cfg.num_workers,
                              buffer_size=cfg.buffer_size)
    test_loader = MineRLLoader(mode="test",
                               img_size=cfg.img_size,
                               val_frac=cfg.val_frac,
                               batch_size=1,
                               num_workers=1,
                               buffer_size=cfg.buffer_size)
    # Throughput stats
    loader_throughput(train_loader)
    return (train_loader, val_loader, test_loader)
def visualise_outputs(model, vis_batch, writer, mode, iter_idx):
    model.eval()
    # Only visualise for eight images
    # Forward pass
    vis_input = vis_batch['input'][:8]
    if next(model.parameters()).is_cuda:
        vis_input = vis_input.cuda()
    output, losses, stats, att_stats, comp_stats = model(vis_input)
    # Input and recon
    writer.add_image(mode+'_input', make_grid(vis_batch['input'][:8]),
                     iter_idx)
    writer.add_image(mode+'_recon', make_grid(output), iter_idx)
    # Decomposition
    for key in ['mx_r_k', 'x_r_k', 'log_m_k', 'log_m_r_k']:
        if key not in stats:
            continue
        for step, val in enumerate(stats[key]):
            if 'log' in key:
                val = val.exp()
            writer.add_image(f'{mode}_{key}/k{step}', make_grid(val),
                             iter_idx)
    # Generation
    try:
        output, stats = model.sample(batch_size=8, K_steps=model.K_steps)
        writer.add_image('samples', make_grid(output), iter_idx)
        for key in ['x_k', 'log_m_k', 'mx_k']:
            if key not in stats:
                continue
            for step, val in enumerate(stats[key]):
                if 'log' in key:
                    val = val.exp()
                writer.add_image(f'gen_{key}/k{step}', make_grid(val),
                                 iter_idx)
    except NotImplementedError:
        fprint("Sampling not implemented for this model.")
    model.train()
def dataset_ari(model, data_loader, num_images=1000):
    model.eval()
    fprint("Computing ARI on dataset")
    ari = []
    ari_fg = []
    for bidx, batch in enumerate(data_loader):
        if next(model.parameters()).is_cuda:
            batch['input'] = batch['input'].cuda()
        with torch.no_grad():
            _, _, stats, _, _ = model(batch['input'])
        # Return zero if labels or segmentations are not available
        if 'instances' not in batch or not hasattr(stats, 'log_m_k'):
            return 0., 0., [0], [0]
        _, ari_list = average_ari(stats.log_m_k, batch['instances'])
        _, ari_fg_list = average_ari(stats.log_m_k, batch['instances'], True)
        ari += ari_list
        ari_fg += ari_fg_list
        if bidx % 1 == 0:
            log_ari = sum(ari) / len(ari)
            log_ari_fg = sum(ari_fg) / len(ari_fg)
            t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            fprint(f"{t} | After [{len(ari)} / {num_images}] images: " +
                   f"ARI {log_ari:.4f}, FG ARI {log_ari_fg:.4f}")
        if len(ari) >= num_images:
            break
    assert len(ari) == len(ari_fg)
    ari = ari[:num_images]
    ari_fg = ari_fg[:num_images]
    avg_ari = sum(ari) / len(ari)
    avg_ari_fg = sum(ari_fg) / len(ari_fg)
    fprint(f"FINAL ARI for {len(ari)} images: {avg_ari:.4f}")
    fprint(f"FINAL FG ARI for {len(ari_fg)} images: {avg_ari_fg:.4f}")
    model.train()
    return avg_ari, avg_ari_fg, ari, ari_fg
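# Illustrative only: a per-image (foreground) ARI in the spirit of average_ari
# above, computed with sklearn. It assumes `log_m_k` is a list of K [B, 1, H, W]
# log-masks and `instances` is a [B, 1, H, W] integer map with 0 as background;
# the repo's average_ari may differ in details such as masking and averaging.
from sklearn.metrics import adjusted_rand_score

def ari_sketch(log_m_k, instances, foreground_only=False):
    # Hard segmentation: most likely component per pixel
    preds = torch.argmax(torch.stack(log_m_k, dim=1), dim=1)  # [B, 1, H, W]
    scores = []
    for b in range(instances.shape[0]):
        gt = instances[b].flatten().cpu().numpy()
        pr = preds[b].flatten().cpu().numpy()
        if foreground_only:
            keep = gt > 0  # drop background pixels from the comparison
            gt, pr = gt[keep], pr[keep]
        scores.append(adjusted_rand_score(gt, pr))
    return sum(scores) / len(scores), scores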
def fid_from_model(model, test_loader, batch_size=10, num_images=10000,
                   feat_dim=2048, img_dir='/tmp'):
    model.eval()
    # Save images from test set as pngs
    fprint("Saving images from test set as pngs.", True)
    test_dir = osp.join(img_dir, 'test_images')
    os.makedirs(test_dir)
    count = 0
    for bidx, batch in enumerate(test_loader):
        count = tensor_to_png(batch['input'], test_dir, count, num_images)
        if count >= num_images:
            break
    # Generate images and save as pngs
    fprint("Generate images and save as pngs.", True)
    gen_dir = osp.join(img_dir, 'generated_images')
    os.makedirs(gen_dir)
    count = 0
    for _ in tqdm(range(num_images // batch_size + 1)):
        if count >= num_images:
            break
        with torch.no_grad():
            gen_img, _ = model.sample(batch_size)
        count = tensor_to_png(gen_img, gen_dir, count, num_images)
    # Compute FID
    fprint("Computing FID.", True)
    gpu = next(model.parameters()).is_cuda
    fid_value = FID.calculate_fid_given_paths([test_dir, gen_dir],
                                              batch_size, gpu, feat_dim)
    fprint(f"FID: {fid_value}", True)
    model.train()
    return fid_value
def main():
    # Parse flags
    config = forge.config()
    fet.EXPERIMENT_FOLDER = config.model_dir
    fet.FPRINT_FILE = 'fid_evaluation.txt'
    config.shuffle_test = True

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    random.seed(config.seed)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Using GPU?
    if torch.cuda.is_available() and config.gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        config.gpu = False
        torch.set_default_tensor_type('torch.FloatTensor')
    fet.print_flags()

    # Load data
    _, _, test_loader = fet.load(config.data_config, config)

    # Load model
    flag_path = osp.join(config.model_dir, 'flags.json')
    fprint(f"Restoring flags from {flag_path}")
    pretrained_flags = AttrDict(fet.json_load(flag_path))
    model = fet.load(config.model_config, pretrained_flags)
    model_path = osp.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1', None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2', None)
    model.load_state_dict(model_state_dict)
    fprint(model)

    # Put model on GPU
    if config.gpu:
        model = model.cuda()

    # Compute FID
    fid_from_model(model, test_loader, config.batch_size,
                   config.num_fid_images, config.feat_dim, config.img_dir)
def __next__(self):
    try:
        frame = self.sess.run(self.frames)
        self.count += 1
        # Parse image
        img = frame['image']
        img = np.moveaxis(img, 3, 1)
        shape = img.shape
        # TODO(martin): use more explicit CLEVR flag?
        if shape[2] != shape[3]:
            img = np_img_centre_crop(img, CLEVR_CROP, batch=True)
        img = torch.FloatTensor(img) / 255.
        if self.img_size != shape[2]:
            img = F.interpolate(img, size=self.img_size)
        # Parse masks
        raw_masks = frame['mask']
        masks = np.zeros((shape[0], 1, shape[2], shape[3]), dtype='int')
        # Convert to boolean masks
        cond = np.where(raw_masks[:, :, :, :, 0] == 255, True, False)
        # Ignore background entities
        num_entities = cond.shape[1]
        for o_idx in range(self.background_entities, num_entities):
            masks[cond[:, o_idx:o_idx + 1, :, :]] = o_idx + 1
        masks = torch.FloatTensor(masks)
        if shape[2] != shape[3]:
            masks = np_img_centre_crop(masks, CLEVR_CROP, batch=True)
            masks = torch.FloatTensor(masks)
        if self.img_size != shape[2]:
            masks = F.interpolate(masks, size=self.img_size)
        masks = masks.type(torch.LongTensor)
        return {'input': img, 'instances': masks}
    except tf.errors.OutOfRangeError:
        fprint("Reached end of epoch.")
        fprint(f"Counted {self.count} batches, expected {self.length}.")
        fprint("Creating new iterator.")
        self.count = 0
        raise StopIteration
def main():
    # ------------------------
    # SETUP
    # ------------------------

    # Parse flags
    config = forge.config()
    if config.debug:
        config.num_workers = 0
        config.batch_size = 2

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Setup checkpoint or resume
    logdir = osp.join(config.results_dir, config.run_name)
    logdir, resume_checkpoint = fet.init_checkpoint(
        logdir, config.data_config, config.model_config, config.resume)
    checkpoint_name = osp.join(logdir, 'model.ckpt')

    # Using GPU(S)?
    if torch.cuda.is_available() and config.gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        config.gpu = False
        torch.set_default_tensor_type('torch.FloatTensor')
    fprint(f"Use GPU: {config.gpu}")
    if config.gpu and config.multi_gpu and torch.cuda.device_count() > 1:
        fprint(f"Using {torch.cuda.device_count()} GPUs!")
        config.num_workers = torch.cuda.device_count() * config.num_workers
    else:
        config.multi_gpu = False

    # Print flags
    # fet.print_flags()
    # TODO(martin) make this cleaner
    fprint(json.dumps(fet._flags.FLAGS.__flags, indent=4, sort_keys=True))

    # Setup TensorboardX SummaryWriter
    writer = SummaryWriter(logdir)

    # Load data
    train_loader, val_loader, test_loader = fet.load(config.data_config, config)
    num_elements = 3 * config.img_size**2  # Assume three input channels

    # Load model
    model = fet.load(config.model_config, config)
    fprint(model)

    if config.geco:
        # Goal is specified per pixel & channel so it doesn't need to
        # be changed for different resolutions etc.
        geco_goal = config.g_goal * num_elements
        # Scale step size to get similar update at different resolutions
        geco_lr = config.g_lr * (64**2 / config.img_size**2)
        geco = GECO(geco_goal, geco_lr, config.g_alpha, config.g_init,
                    config.g_min, config.g_speedup)
        beta = geco.beta
    else:
        beta = torch.tensor(config.beta)

    # Setup optimiser
    if config.optimiser == 'rmsprop':
        optimiser = optim.RMSprop(model.parameters(), config.learning_rate)
    elif config.optimiser == 'adam':
        optimiser = optim.Adam(model.parameters(), config.learning_rate)
    elif config.optimiser == 'sgd':
        optimiser = optim.SGD(model.parameters(), config.learning_rate, 0.9)

    # Try to restore model and optimiser from checkpoint
    iter_idx = 0
    if resume_checkpoint is not None:
        fprint(f"Restoring checkpoint from {resume_checkpoint}")
        checkpoint = torch.load(resume_checkpoint, map_location='cpu')
        # Restore model & optimiser
        model_state_dict = checkpoint['model_state_dict']
        model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1', None)
        model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2', None)
        model.load_state_dict(model_state_dict)
        optimiser.load_state_dict(checkpoint['optimiser_state_dict'])
        # Restore GECO
        if config.geco and 'beta' in checkpoint:
            geco.beta = checkpoint['beta']
        if config.geco and 'err_ema' in checkpoint:
            geco.err_ema = checkpoint['err_ema']
        # Update starting iter
        iter_idx = checkpoint['iter_idx'] + 1
    fprint(f"Starting training at iter = {iter_idx}")

    # Push model to GPU(s)?
    if config.multi_gpu:
        fprint("Wrapping model in DataParallel.")
        model = nn.DataParallel(model)
    if config.gpu:
        fprint("Pushing model to GPU.")
        model = model.cuda()
        if config.geco:
            geco.to_cuda()

    # ------------------------
    # TRAINING
    # ------------------------

    model.train()
    timer = time.time()
    while iter_idx <= config.train_iter:
        for train_batch in train_loader:
            # Parse data
            train_input = train_batch['input']
            if config.gpu:
                train_input = train_input.cuda()

            # Forward propagation
            optimiser.zero_grad()
            output, losses, stats, att_stats, comp_stats = model(train_input)

            # Reconstruction error
            err = losses.err.mean(0)
            # KL divergences
            kl_m, kl_l = torch.tensor(0), torch.tensor(0)
            # -- KL stage 1
            if 'kl_m' in losses:
                kl_m = losses.kl_m.mean(0)
            elif 'kl_m_k' in losses:
                kl_m = torch.stack(losses.kl_m_k, dim=1).mean(dim=0).sum()
            # -- KL stage 2
            if 'kl_l' in losses:
                kl_l = losses.kl_l.mean(0)
            elif 'kl_l_k' in losses:
                kl_l = torch.stack(losses.kl_l_k, dim=1).mean(dim=0).sum()

            # Compute ELBO
            elbo = (err + kl_l + kl_m).detach()
            err_new = err.detach()
            kl_new = (kl_m + kl_l).detach()
            # Compute MSE / RMSE
            mse_batched = ((train_input - output)**2).mean((1, 2, 3)).detach()
            rmse_batched = mse_batched.sqrt()
            mse, rmse = mse_batched.mean(0), rmse_batched.mean(0)

            # Main objective
            if config.geco:
                loss = geco.loss(err, kl_l + kl_m)
                beta = geco.beta
            else:
                if config.beta_warmup:
                    # Increase beta linearly over 20% of training
                    beta = config.beta * iter_idx / (0.2 * config.train_iter)
                    beta = torch.tensor(beta).clamp(0, config.beta)
                else:
                    beta = config.beta
                loss = err + beta * (kl_l + kl_m)

            # Backprop and optimise
            loss.backward()
            optimiser.step()

            # Heartbeat log
            if (iter_idx % config.report_loss_every == 0 or
                    float(elbo) > ELBO_DIV or config.debug):
                # Print output and write to file
                ps = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                ps += f' {config.run_name} | '
                ps += f'[{iter_idx}/{config.train_iter:.0e}]'
                ps += f' elb: {float(elbo):.0f} err: {float(err):.0f} '
                if 'kl_m' in losses or 'kl_m_k' in losses:
                    ps += f' klm: {float(kl_m):.1f}'
                if 'kl_l' in losses or 'kl_l_k' in losses:
                    ps += f' kll: {float(kl_l):.1f}'
                ps += f' bet: {float(beta):.1e}'
                s_per_b = (time.time() - timer)
                if not config.debug:
                    s_per_b /= config.report_loss_every
                timer = time.time()  # Reset timer
                ps += f' - {s_per_b:.2f} s/b'
                fprint(ps)

                # TensorBoard logging
                # -- Optimisation stats
                writer.add_scalar('optim/beta', beta, iter_idx)
                writer.add_scalar('optim/s_per_batch', s_per_b, iter_idx)
                if config.geco:
                    writer.add_scalar('optim/geco_err_ema',
                                      geco.err_ema, iter_idx)
                    writer.add_scalar('optim/geco_err_ema_element',
                                      geco.err_ema / num_elements, iter_idx)
                # -- Main loss terms
                writer.add_scalar('train/err', err, iter_idx)
                writer.add_scalar('train/err_element',
                                  err / num_elements, iter_idx)
                writer.add_scalar('train/kl_m', kl_m, iter_idx)
                writer.add_scalar('train/kl_l', kl_l, iter_idx)
                writer.add_scalar('train/elbo', elbo, iter_idx)
                writer.add_scalar('train/loss', loss, iter_idx)
                writer.add_scalar('train/mse', mse, iter_idx)
                writer.add_scalar('train/rmse', rmse, iter_idx)
                # -- Per step loss terms
                for key in ['kl_l_k', 'kl_m_k']:
                    if key not in losses:
                        continue
                    for step, val in enumerate(losses[key]):
                        writer.add_scalar(f'train_steps/{key}{step}',
                                          val.mean(0), iter_idx)
                # -- Attention stats
                if config.log_distributions and att_stats is not None:
                    for key in ['mu_k', 'sigma_k', 'pmu_k', 'psigma_k']:
                        if key not in att_stats:
                            continue
                        for step, val in enumerate(att_stats[key]):
                            writer.add_histogram(f'att_{key}_{step}',
                                                 val, iter_idx)
                # -- Component stats
                if config.log_distributions and comp_stats is not None:
                    for key in ['mu_k', 'sigma_k', 'pmu_k', 'psigma_k']:
                        if key not in comp_stats:
                            continue
                        for step, val in enumerate(comp_stats[key]):
                            writer.add_histogram(f'comp_{key}_{step}',
                                                 val, iter_idx)

            # Save checkpoints
            ckpt_freq = config.train_iter / config.num_checkpoints
            if iter_idx % ckpt_freq == 0:
                ckpt_file = '{}-{}'.format(checkpoint_name, iter_idx)
                fprint(f"Saving model training checkpoint to: {ckpt_file}")
                if config.multi_gpu:
                    model_state_dict = model.module.state_dict()
                else:
                    model_state_dict = model.state_dict()
                ckpt_dict = {'iter_idx': iter_idx,
                             'model_state_dict': model_state_dict,
                             'optimiser_state_dict': optimiser.state_dict(),
                             'elbo': elbo}
                if config.geco:
                    ckpt_dict['beta'] = geco.beta
                    ckpt_dict['err_ema'] = geco.err_ema
                torch.save(ckpt_dict, ckpt_file)

            # Run validation and log images
            if (iter_idx % config.run_validation_every == 0 or
                    float(elbo) > ELBO_DIV):
                # Weight and gradient histograms
                if config.log_grads_and_weights:
                    for name, param in model.named_parameters():
                        writer.add_histogram(f'weights/{name}',
                                             param.data, iter_idx)
                        writer.add_histogram(f'grads/{name}',
                                             param.grad, iter_idx)
                # TensorboardX logging - images
                visualise_inference(model, train_batch, writer, 'train',
                                    iter_idx)
                # Validation
                fprint("Running validation...")
                eval_model = model.module if config.multi_gpu else model
                evaluation(eval_model, val_loader, writer, config, iter_idx,
                           N_eval=config.N_eval)

            # Increment counter
            iter_idx += 1
            if iter_idx > config.train_iter:
                break

            # Exit if training has diverged
            if elbo.item() > ELBO_DIV:
                fprint(f"ELBO: {elbo.item()}")
                fprint(f"ELBO has exceeded {ELBO_DIV} - training has diverged.")
                sys.exit()

    # ------------------------
    # TESTING
    # ------------------------

    # Save final checkpoint
    ckpt_file = '{}-{}'.format(checkpoint_name, 'FINAL')
    fprint(f"Saving model training checkpoint to: {ckpt_file}")
    if config.multi_gpu:
        model_state_dict = model.module.state_dict()
    else:
        model_state_dict = model.state_dict()
    ckpt_dict = {'iter_idx': iter_idx,
                 'model_state_dict': model_state_dict,
                 'optimiser_state_dict': optimiser.state_dict()}
    if config.geco:
        ckpt_dict['beta'] = geco.beta
        ckpt_dict['err_ema'] = geco.err_ema
    torch.save(ckpt_dict, ckpt_file)

    # Test evaluation
    fprint("STARTING TESTING...")
    eval_model = model.module if config.gpu and config.multi_gpu else model
    final_elbo = evaluation(eval_model, test_loader, None, config, iter_idx,
                            N_eval=config.N_eval)
    fprint(f"TEST ELBO = {float(final_elbo)}")

    # FID computation
    try:
        fid_from_model(model, test_loader, img_dir=osp.join('/tmp', logdir))
    except NotImplementedError:
        fprint("Sampling not implemented for this model.")

    # Close writer
    writer.close()
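# Illustrative only: a minimal sketch of a GECO-style controller with the
# interface the training loop above relies on (geco.loss, geco.beta,
# geco.err_ema, geco.to_cuda). The exact update rule, argument names and
# defaults of the repo's GECO class may differ; this just shows the
# constrained-optimisation idea of adapting beta from an exponential moving
# average of the reconstruction-error constraint.
class GECOSketch:

    def __init__(self, goal, step_size, alpha=0.99, beta_init=1.0,
                 beta_min=1e-10):
        self.goal = goal            # target reconstruction error
        self.step_size = step_size  # controls how fast beta is adapted
        self.alpha = alpha          # EMA decay for the error
        self.beta = torch.tensor(beta_init)
        self.beta_min = beta_min
        self.err_ema = None

    def to_cuda(self):
        self.beta = self.beta.cuda()

    def loss(self, err, kl):
        # Lagrangian-style objective with beta weighting the KL term
        loss = err + self.beta * kl
        # Update beta without tracking gradients
        with torch.no_grad():
            if self.err_ema is None:
                self.err_ema = err.detach()
            else:
                self.err_ema = (self.alpha * self.err_ema
                                + (1. - self.alpha) * err.detach())
            # If the error is above the goal, decrease beta so the model
            # focuses on reconstruction; once the goal is met, increase beta
            # to put more weight on the KL term.
            constraint = self.goal - self.err_ema
            self.beta = self.beta * torch.exp(self.step_size * constraint)
            self.beta = self.beta.clamp(min=self.beta_min)
        return loss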
def load(cfg, **unused_kwargs):
    # Fix TensorFlow seed
    global SEED
    SEED = cfg.seed
    tf.set_random_seed(SEED)

    del unused_kwargs
    fprint(f"Using {cfg.num_workers} data workers.")
    sess = tf.InteractiveSession()

    if cfg.dataset == 'multi_dsprites':
        cfg.img_size = 64 if cfg.img_size < 0 else cfg.img_size
        cfg.K_steps = 5 if cfg.K_steps < 0 else cfg.K_steps
        background_entities = 1
        max_frames = 60000
        raw_dataset = multi_dsprites.dataset(
            cfg.data_folder + MULTI_DSPRITES,
            'colored_on_colored',
            map_parallel_calls=cfg.num_workers if cfg.num_workers > 0 else None)
    elif cfg.dataset == 'objects_room':
        cfg.img_size = 64 if cfg.img_size < 0 else cfg.img_size
        cfg.K_steps = 7 if cfg.K_steps < 0 else cfg.K_steps
        background_entities = 4
        max_frames = 1000000
        raw_dataset = objects_room.dataset(
            cfg.data_folder + OBJECTS_ROOM,
            'train',
            map_parallel_calls=cfg.num_workers if cfg.num_workers > 0 else None)
    elif cfg.dataset == 'clevr':
        cfg.img_size = 128 if cfg.img_size < 0 else cfg.img_size
        cfg.K_steps = 11 if cfg.K_steps < 0 else cfg.K_steps
        background_entities = 1
        max_frames = 70000
        raw_dataset = clevr_with_masks.dataset(
            cfg.data_folder + CLEVR,
            map_parallel_calls=cfg.num_workers if cfg.num_workers > 0 else None)
    elif cfg.dataset == 'tetrominoes':
        cfg.img_size = 32 if cfg.img_size < 0 else cfg.img_size
        cfg.K_steps = 4 if cfg.K_steps < 0 else cfg.K_steps
        background_entities = 1
        max_frames = 60000
        raw_dataset = tetrominoes.dataset(
            cfg.data_folder + TETROMINOS,
            map_parallel_calls=cfg.num_workers if cfg.num_workers > 0 else None)
    else:
        raise NotImplementedError(f"{cfg.dataset} not a valid dataset.")

    # Split into train / val / test
    if cfg.dataset_size > max_frames:
        fprint(f"WARNING: {cfg.dataset_size} frames requested, "
               f"but only {max_frames} available.")
        cfg.dataset_size = max_frames
    if cfg.dataset_size > 0:
        total_sz = cfg.dataset_size
        raw_dataset = raw_dataset.take(total_sz)
    else:
        total_sz = max_frames
    if total_sz < 0:
        fprint("Determining size of dataset...")
        total_sz = len_tfrecords(raw_dataset, sess)
    fprint(f"Dataset has {total_sz} frames")
    val_sz = 10000
    tst_sz = 10000
    tng_sz = total_sz - val_sz - tst_sz
    assert tng_sz > 0
    fprint(f"Splitting into {tng_sz}/{val_sz}/{tst_sz} for tng/val/tst")
    tst_dataset = raw_dataset.take(tst_sz)
    val_dataset = raw_dataset.skip(tst_sz).take(val_sz)
    tng_dataset = raw_dataset.skip(tst_sz + val_sz)

    tng_loader = MultiOjectLoader(sess, tng_dataset, background_entities,
                                  tng_sz, cfg.batch_size,
                                  cfg.img_size, cfg.buffer_size)
    val_loader = MultiOjectLoader(sess, val_dataset, background_entities,
                                  val_sz, cfg.batch_size,
                                  cfg.img_size, cfg.buffer_size)
    tst_loader = MultiOjectLoader(sess, tst_dataset, background_entities,
                                  tst_sz, cfg.batch_size,
                                  cfg.img_size, cfg.buffer_size)

    # Throughput stats
    if not cfg.debug:
        loader_throughput(tng_loader)

    return (tng_loader, val_loader, tst_loader)
def __iter__(self):
    fprint("Creating new one_shot_iterator.")
    it = self.dataset.make_one_shot_iterator()
    self.frames = it.get_next()
    return self
def main():
    # Parse flags
    config = forge.config()

    # Restore flags of pretrained model
    flag_path = osp.join(config.model_dir, 'flags.json')
    fprint(f"Restoring flags from {flag_path}")
    pretrained_flags = AttrDict(fet.json_load(flag_path))
    pretrained_flags.batch_size = 1
    pretrained_flags.gpu = False
    pretrained_flags.debug = True
    fet.print_flags()

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Load model
    model = fet.load(config.model_config, pretrained_flags)
    model_path = osp.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1', None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2', None)
    model.load_state_dict(model_state_dict)
    fprint(model)

    # Visualise
    model.eval()
    for _ in range(100):
        y, stats = model.sample(1, pretrained_flags.K_steps)
        fig, axes = plt.subplots(nrows=4, ncols=1 + pretrained_flags.K_steps)

        # Generated
        plot(axes, 0, 0, y, title='Generated scene', fontsize=12)
        # Empty plots
        plot(axes, 1, 0, fontsize=12)
        plot(axes, 2, 0, fontsize=12)
        plot(axes, 3, 0, fontsize=12)

        # Put K generation steps in separate subfigures
        for step in range(pretrained_flags.K_steps):
            x_step = stats['x_k'][step]
            m_step = stats['log_m_k'][step].exp()
            mx_step = stats['mx_k'][step]
            if 'log_s_k' in stats:
                s_step = stats['log_s_k'][step].exp()
            pre = 'Mask x RGB ' if step == 0 else ''
            plot(axes, 0, 1 + step, mx_step, pre + f'k={step+1}', fontsize=12)
            pre = 'RGB ' if step == 0 else ''
            plot(axes, 1, 1 + step, x_step, pre + f'k={step+1}', fontsize=12)
            pre = 'Mask ' if step == 0 else ''
            plot(axes, 2, 1 + step, m_step, pre + f'k={step+1}', True,
                 fontsize=12)
            if 'log_s_k' in stats:
                pre = 'Scope ' if step == 0 else ''
                plot(axes, 3, 1 + step, s_step, pre + f'k={step+1}', True,
                     axis=step == 0, fontsize=12)

        # Beautify and show figure
        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        manager = plt.get_current_fig_manager()
        manager.resize(*manager.window.maxsize())
        plt.show()
def main():
    # Parse flags
    config = forge.config()
    config.batch_size = 1
    config.load_instances = True
    fet.print_flags()

    # Restore original model flags
    pretrained_flags = AttrDict(
        fet.json_load(os.path.join(config.model_dir, 'flags.json')))

    # Get validation loader
    train_loader, val_loader, test_loader = fet.load(config.data_config, config)
    fprint(f"Split: {config.split}")
    if config.split == 'train':
        batch_loader = train_loader
    elif config.split == 'val':
        batch_loader = val_loader
    elif config.split == 'test':
        batch_loader = test_loader
    # Shuffle and prefetch to get same data for different models
    if 'gqn' not in config.data_config:
        batch_loader = torch.utils.data.DataLoader(batch_loader.dataset,
                                                   batch_size=1,
                                                   num_workers=0,
                                                   shuffle=True)
    # Prefetch batches
    prefetched_batches = []
    for i, x in enumerate(batch_loader):
        if i == config.num_images:
            break
        prefetched_batches.append(x)

    # Load model
    model = fet.load(config.model_config, pretrained_flags)
    fprint(model)
    model_path = os.path.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1', None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2', None)
    model.load_state_dict(model_state_dict)

    # Set experiment folder and fprint file for logging
    fet.EXPERIMENT_FOLDER = config.model_dir
    fet.FPRINT_FILE = 'segmentation_metrics.txt'

    # Compute metrics
    model.eval()
    ari_fg_list, sc_fg_list, msc_fg_list = [], [], []
    with torch.no_grad():
        for i, x in enumerate(tqdm(prefetched_batches)):
            _, _, stats, _, _ = model(x['input'])
            # ARI
            ari_fg, _ = average_ari(stats.log_m_k, x['instances'],
                                    foreground_only=True)
            # Segmentation covering - foreground only
            gt_instances = x['instances'].clone()
            gt_instances[gt_instances == 0] = -100
            ins_preds = torch.argmax(torch.stack(stats.log_m_k, dim=1), dim=1)
            sc_fg = average_segcover(gt_instances, ins_preds)
            msc_fg = average_segcover(gt_instances, ins_preds, False)
            # Recording
            ari_fg_list.append(ari_fg)
            sc_fg_list.append(sc_fg)
            msc_fg_list.append(msc_fg)

    # Print average metrics
    fprint(f"Average FG ARI: {sum(ari_fg_list)/len(ari_fg_list)}")
    fprint(f"Average FG SegCover: {sum(sc_fg_list)/len(sc_fg_list)}")
    fprint(f"Average FG MeanSegCover: {sum(msc_fg_list)/len(msc_fg_list)}")
def evaluation(model, data_loader, writer, config, iter_idx, N_eval=None):
    # TODO(martin): make interface cleaner
    model.eval()
    t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    if iter_idx == 0 or config.debug:
        num_batches = 1
    elif (N_eval is not None and
          N_eval <= len(data_loader) * data_loader.batch_size):
        num_batches = N_eval // data_loader.batch_size
        fprint(t + f" | Evaluating only on first {N_eval} examples in loader")
    else:
        num_batches = len(data_loader)
        fprint(t + f" | Evaluating on all {num_batches} batches in loader")
    start_t = time.time()
    err, kl_l, kl_m, elbo = 0., 0., 0., 0.
    batch = None
    # Don't compute gradients to run faster
    with torch.no_grad():
        # Loop over loader
        for b_idx, batch in enumerate(data_loader):
            if config.gpu:
                batch['input'] = batch['input'].cuda()
            if b_idx == num_batches:
                fprint(f"Breaking from eval loop after {b_idx} batches")
                break
            if b_idx % 100 == 0:
                t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                fprint(t + f" | Validation batch [{b_idx+1} | {num_batches}]")
            _, losses, stats, _, _ = model(batch['input'])
            new_err = losses.err.mean(0)
            err += float(new_err) / num_batches
            # Parse different loss types
            if 'kl_m' in losses:
                new_kl_m = losses.kl_m.mean(0)
                kl_m += float(new_kl_m) / num_batches
            elif 'kl_m_k' in losses:
                new_kl_m = torch.stack(losses.kl_m_k, dim=1).sum(1).mean(0)
                kl_m += float(new_kl_m) / num_batches
            if 'kl_l' in losses:
                new_kl_l = losses.kl_l.mean(0)
                kl_l += float(new_kl_l) / num_batches
            elif 'kl_l_k' in losses:
                new_kl_l = torch.stack(losses.kl_l_k, dim=1).sum(1).mean(0)
                kl_l += float(new_kl_l) / num_batches
            # Update ELBO
            if 'elbo' not in losses:
                # Assign current "estimate"
                elbo += float(new_err + new_kl_l + new_kl_m) / num_batches
            else:
                # Add over steps
                elbo += float(losses.elbo.mean(0)) / num_batches

    # Printing
    duration = time.time() - start_t
    pstr = f'Evaluation elbo: {elbo:.1f}'
    pstr += f', err: {err:.1f}, kl_l: {kl_l:.1f}'
    pstr += f', kl_m: {kl_m:.1f}'
    pstr += f' --- {num_batches / duration:.1f} b/s'
    fprint(pstr)

    # TensorBoard logging
    if writer is not None:
        # TensorBoard logging - scalars
        writer.add_scalar('val/elbo', elbo, iter_idx)
        writer.add_scalar('val/err', err, iter_idx)
        writer.add_scalar('val/kl_l', kl_l, iter_idx)
        writer.add_scalar('val/kl_m', kl_m, iter_idx)
        # TensorBoard logging - inference (limit to 8)
        visualise_inference(model, batch, writer, 'val', iter_idx)
        # TensorBoard logging - generation (limit to 8)
        try:
            output, stats = model.sample(batch_size=8, K_steps=config.K_steps)
            writer.add_image('samples', make_grid(output), iter_idx)
            for key in ['x_k', 'log_m_k', 'mx_k']:
                if key not in stats:
                    continue
                for step, val in enumerate(stats[key]):
                    if 'log' in key:
                        val = val.exp()
                    writer.add_image(f'gen_{key}/k{step}', make_grid(val),
                                     iter_idx)
        except NotImplementedError:
            fprint("Sampling not implemented for this model.")

    model.train()
    return elbo
def evaluation(model, data_loader, writer, config, iter_idx,
               N_eval=None, N_seg_metrics=50):
    model.eval()
    torch.set_grad_enabled(False)
    batch_size = data_loader.batch_size

    if iter_idx == 0 or config.debug:
        num_batches = 1
        fprint(f"ITER 0 / DEBUG - eval on {num_batches} batches", True)
    elif N_eval is not None and N_eval <= len(data_loader)*batch_size:
        num_batches = int(N_eval // batch_size)
        fprint(f"N_eval = {N_eval}, eval on {num_batches} batches", True)
    else:
        num_batches = len(data_loader)
        fprint(f"Eval on all {num_batches} batches")

    start_t = time.time()
    eval_stats = AttrDefault(list, {})
    batch = None
    # Loop over loader
    for b_idx, batch in enumerate(data_loader):
        if b_idx == num_batches:
            fprint(f"Breaking from eval loop after {b_idx} batches")
            break
        if config.gpu:
            for key, val in batch.items():
                batch[key] = val.cuda()

        # Forward pass
        _, losses, stats, _, _ = model(batch['input'])

        # Track individual loss terms
        for key, val in losses.items():
            # Sum over steps if needed
            if isinstance(val, list):
                eval_stats[key].append(torch.stack(val, 1).sum(1).mean(0))
            else:
                eval_stats[key].append(val.mean(0))

        # Track ELBO
        kl_m, kl_l = torch.tensor(0), torch.tensor(0)
        if 'kl_m_k' in losses:
            kl_m = torch.stack(losses.kl_m_k, dim=1).sum(1).mean(0)
        elif 'kl_m' in losses:
            kl_m = losses.kl_m.mean(0)
        if 'kl_l_k' in losses:
            kl_l = torch.stack(losses.kl_l_k, dim=1).sum(1).mean(0)
        elif 'kl_l' in losses:
            kl_l = losses.kl_l.mean(0)
        eval_stats['elbo'].append(losses.err.mean(0) + kl_m + kl_l)

        # Track segmentation metrics
        if ('instances' in batch and 'log_m_k' in stats and
                b_idx*batch_size < N_seg_metrics):
            # ARI
            new_ari, _ = average_ari(stats.log_m_k, batch['instances'])
            new_ari_fg, _ = average_ari(stats.log_m_k, batch['instances'],
                                        True)
            eval_stats['ari'].append(new_ari)
            eval_stats['ari_fg'].append(new_ari_fg)
            # Segmentation Covering
            iseg = torch.argmax(torch.cat(stats.log_m_k, 1), 1, True)
            msc, _ = average_segcover(batch['instances'], iseg)
            msc_fg, _ = average_segcover(batch['instances'], iseg,
                                         ignore_background=True)
            eval_stats['msc'].append(msc)
            eval_stats['msc_fg'].append(msc_fg)

    # Average over batches
    for key, val in eval_stats.items():
        # Sanity check
        if ('ari' in key or 'msc' in key) and not config.debug and iter_idx > 0:
            assert len(val)*batch_size >= N_seg_metrics
            assert len(val)*batch_size < N_seg_metrics+batch_size
        eval_stats[key] = sum(val) / len(val)

    # Track element-wise error
    nelements = np.prod(batch['input'].shape[1:4])
    eval_stats['err_element'] = eval_stats['err'] / nelements

    # Printing
    duration = time.time() - start_t
    fprint(f'Eval duration: {duration:.1f}s, {num_batches / duration:.1f} b/s')
    eval_stats['duration'] = duration
    eval_stats['num_batches'] = num_batches
    eval_stats = dict(eval_stats)
    for key, val in eval_stats.items():
        eval_stats[key] = float(val)

    # TensorBoard logging
    if writer is not None:
        log_scalars(eval_stats, 'val', iter_idx, writer)

    model.train()
    torch.set_grad_enabled(True)
    return eval_stats
def main():
    # Parse flags
    config = forge.config()
    fet.print_flags()

    # Restore flags of pretrained model
    flag_path = osp.join(config.model_dir, 'flags.json')
    fprint(f"Restoring flags from {flag_path}")
    pretrained_flags = AttrDict(fet.json_load(flag_path))
    pretrained_flags.debug = True

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Load data
    config.batch_size = 1
    _, _, test_loader = fet.load(config.data_config, config)

    # Load model
    model = fet.load(config.model_config, pretrained_flags)
    model_path = osp.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1', None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2', None)
    model.load_state_dict(model_state_dict)
    fprint(model)

    # Visualise
    model.eval()
    for count, batch in enumerate(test_loader):
        if count >= config.num_images:
            break

        # Forward pass
        output, _, stats, _, _ = model(batch['input'])

        # Set up figure
        fig, axes = plt.subplots(nrows=4, ncols=1 + pretrained_flags.K_steps)

        # Input and reconstruction
        plot(axes, 0, 0, batch['input'], title='Input image', fontsize=12)
        plot(axes, 1, 0, output, title='Reconstruction', fontsize=12)
        # Empty plots
        plot(axes, 2, 0, fontsize=12)
        plot(axes, 3, 0, fontsize=12)

        # Put K reconstruction steps into separate subfigures
        x_k = stats['x_r_k']
        log_m_k = stats['log_m_k']
        mx_k = [x * m.exp() for x, m in zip(x_k, log_m_k)]
        log_s_k = stats['log_s_k'] if 'log_s_k' in stats else None
        for step in range(pretrained_flags.K_steps):
            mx_step = mx_k[step]
            x_step = x_k[step]
            m_step = log_m_k[step].exp()
            if log_s_k:
                s_step = log_s_k[step].exp()
            pre = 'Mask x RGB ' if step == 0 else ''
            plot(axes, 0, 1 + step, mx_step, pre + f'k={step+1}', fontsize=12)
            pre = 'RGB ' if step == 0 else ''
            plot(axes, 1, 1 + step, x_step, pre + f'k={step+1}', fontsize=12)
            pre = 'Mask ' if step == 0 else ''
            plot(axes, 2, 1 + step, m_step, pre + f'k={step+1}', True,
                 fontsize=12)
            if log_s_k:
                pre = 'Scope ' if step == 0 else ''
                plot(axes, 3, 1 + step, s_step, pre + f'k={step+1}', True,
                     axis=step == 0, fontsize=12)

        # Beautify and show figure
        plt.subplots_adjust(wspace=0.05, hspace=0.15)
        manager = plt.get_current_fig_manager()
        manager.resize(*manager.window.maxsize())
        plt.show()