def eval_policy(policy, rng, state, model, test_ds, epoch):
  """Evaluates the model with a fixed policy for one pass over `test_ds`.

  Args:
    policy: Policy handed to `eval_step_policy` for every batch.
    rng: PRNG key threaded through successive `eval_step_policy` calls.
    state: Train state forwarded to `eval_step_policy`.
    model: Model forwarded to `eval_step_policy`.
    test_ds: Dataset wrapped with `util_fns.get_iterator` and iterated once.
    epoch: Epoch index; only used in the log message.

  Returns:
    A dict mapping every metric name that does not contain 'batch' to its
    mean over all evaluated batches.
  """
  # Round-trip through replicate/unreplicate before evaluating. Per the
  # original note, the step function is recompiled for this specific policy;
  # presumably the round trip turns `policy` into a device array (confirm).
  policy = flax.jax_utils.unreplicate(flax.jax_utils.replicate(policy))

  per_batch = []
  for batch in util_fns.get_iterator(test_ds):
    step_metrics, rng = eval_step_policy(rng, batch, state, model, policy)
    # Results stay on device; they are pulled to host once after the loop.
    per_batch.append(step_metrics)

  # One host transfer for everything collected above.
  per_batch = jax.device_get(flax.jax_utils.unreplicate(per_batch))

  # Per-batch-only entries (names containing 'batch') are not averaged.
  metric_names = [name for name in per_batch[0] if 'batch' not in name]
  epoch_metrics_np = {}
  for name in metric_names:
    epoch_metrics_np[name] = np.mean([m[name] for m in per_batch])

  nelbo = epoch_metrics_np['nelbo']
  logging.info(f'eval policy epoch: {epoch}, nelbo: {nelbo:.4f}')
  return epoch_metrics_np
def eval_model(p_eval_step, rng, state, test_ds, epoch):
  """Runs one evaluation pass over `test_ds` using `p_eval_step`.

  Args:
    p_eval_step: Step function called as `p_eval_step(rng, batch, state)`
      and returning `(metrics, rng)`.
    rng: PRNG key threaded through the step calls.
    state: Train state forwarded to every step.
    test_ds: Dataset wrapped with `util_fns.get_iterator` and iterated once.
    epoch: Epoch index; only used in the log message.

  Returns:
    `(epoch_metrics_np, rng)`. `epoch_metrics_np` maps each metric name not
    containing 'batch' to its mean over all batches. `rng` is the key
    returned by the last step, or the input key if no step ran.
  """
  tic = time.time()
  collected = []
  for batch in util_fns.get_iterator(test_ds):
    step_metrics, rng = p_eval_step(rng, batch, state)
    # Results stay on device; they are pulled to host once after the loop.
    collected.append(step_metrics)

  # One host transfer for everything collected above.
  collected = jax.device_get(flax.jax_utils.unreplicate(collected))

  epoch_metrics_np = {
      name: np.mean([m[name] for m in collected])
      for name in collected[0]
      if 'batch' not in name
  }
  nelbo = epoch_metrics_np['nelbo']

  elapsed = time.time() - tic
  logging.info(f'Eval epoch took {elapsed:.1f} seconds.')
  logging.info(f'eval epoch: {epoch}, nelbo: {nelbo:.4f}')
  return epoch_metrics_np, rng
def eval_policy_and_sigma(policy, sigma, rng, state, model, dataset):
  """Evaluates nelbo under a given policy and order `sigma` for one pass.

  Args:
    policy: Policy forwarded to `eval_step_policy_and_sigma`.
    sigma: Order forwarded to `eval_step_policy_and_sigma`.
    rng: PRNG key threaded through the step calls.
    state: Train state forwarded to every step.
    model: Model forwarded to every step.
    dataset: Dataset wrapped with `util_fns.get_iterator` and iterated once.

  Returns:
    `(nelbo, stdev, num_samples)`. These are the mean and the population
    standard deviation (numpy default `ddof=0`) of the per-batch 'nelbo'
    metric, and the number of batches evaluated.
  """
  collected = []
  for batch in util_fns.get_iterator(dataset):
    step_metrics, rng = eval_step_policy_and_sigma(
        rng, batch, state, model, policy, sigma)
    # Results stay on device; they are pulled to host once after the loop.
    collected.append(step_metrics)

  # One host transfer for everything collected above.
  collected = jax.device_get(flax.jax_utils.unreplicate(collected))

  kept = [name for name in collected[0] if 'batch' not in name]
  epoch_metrics_np = {
      name: np.mean([m[name] for m in collected]) for name in kept
  }
  stdev_metrics_np = {
      name: np.std([m[name] for m in collected]) for name in kept
  }

  nelbo = epoch_metrics_np['nelbo']
  stdev = stdev_metrics_np['nelbo']
  # Each batch contributes one value, so "samples" here counts batches.
  num_samples = len(collected)
  std_error = stdev / np.sqrt(num_samples)
  logging.info(f'eval policy with sigma scores nelbo: {nelbo:.4f} +/- '
               f'{stdev:.4f} on {num_samples} samples. So expected stdev '
               f'is stdev/sqrt(n) is {std_error:.4f}')
  return nelbo, stdev, num_samples
def train_epoch(p_train_step, state, train_ds, batch_size, epoch, rng,
                kl_tracker):
  """Trains for a single epoch.

  Args:
    p_train_step: Step function called as `p_train_step(rng, batch, state)`
      and returning `(state, metrics, rng)`.
    state: Train state; replaced by each step's returned state.
    train_ds: Dataset wrapped with `util_fns.get_iterator` and iterated once.
    batch_size: Number of elements each step's 't_batch' and
      'nelbo_per_t_batch' metrics are reshaped to before being passed to
      `kl_tracker.update`.
    epoch: Epoch index; only used in the log message.
    rng: PRNG key threaded through the step calls.
    kl_tracker: Object whose `update(t_batch, nelbo_per_t_batch)` is called
      once per training step, after the epoch finishes.

  Returns:
    `(state, epoch_metrics, rng)`. `epoch_metrics` maps each metric name not
    containing 'batch' to its mean over the epoch.

  Raises:
    ValueError: If `train_ds` yields no batches.
  """
  start_time = time.time()
  batch_metrics = []
  for batch in util_fns.get_iterator(train_ds):
    # Annotate each step on its own so profiler traces show per-step
    # boundaries tagged with the current step number. Wrapping the whole loop
    # would evaluate `state.step` once and record one "step" per epoch.
    with jax.profiler.StepTraceAnnotation('train', step_num=state.step):
      state, metrics, rng = p_train_step(rng, batch, state)
    # Results stay on device; they are pulled to host once after the loop.
    batch_metrics.append(metrics)

  if not batch_metrics:
    raise ValueError('train_ds yielded no batches; cannot compute metrics.')

  # One host transfer; `unreplicate` keeps a single replica's copy.
  batch_metrics = jax.device_get(flax.jax_utils.unreplicate(batch_metrics))

  # Per-timestep losses for the tracker. All reshapes happen before any
  # update, so a shape mismatch fails before the tracker is touched. This
  # nested loop (one more inside kl_tracker) does not meaningfully affect
  # epoch time.
  per_t_pairs = [(metrics['t_batch'].reshape(batch_size),
                  metrics['nelbo_per_t_batch'].reshape(batch_size))
                 for metrics in batch_metrics]
  for t_batch, nelbo_per_t_batch in per_t_pairs:
    kl_tracker.update(t_batch, nelbo_per_t_batch)

  # Mean over the epoch of every scalar metric; per-batch arrays are skipped.
  epoch_metrics = {
      key: np.mean([metrics[key] for metrics in batch_metrics])
      for key in batch_metrics[0]
      if 'batch' not in key
  }
  logging.info(f'Epoch took {time.time() - start_time:.1f} seconds.')
  logging.info(
      f'train epoch: {epoch}, loss: {epoch_metrics["loss"]:.4f} '
      f'nelbo: {epoch_metrics["nelbo"]:.4f} ce: {epoch_metrics["ce"]:.4f}')
  return state, epoch_metrics, rng
def compress_dataset(state, model, test_ds, sigma=None, policy=None):
  """Compress a dataset.

  Every example is encoded into its own bitstream. The streams are then
  decoded again, and the round trip is verified to reproduce the input
  exactly.

  Args:
    state: A train state containing the params.
    model: The model class that contains all necessary methods.
    test_ds: The dataset to compress.
    sigma: An optional order of the generative process.
    policy: A policy describing which steps to take in parallel. An arange
      would do each step individually, but would be very slow.

  Returns:
    The final bits per dimension to compress the dataset, where each example
    is encoded with its own bitstream.

  Raises:
    ValueError: If exactly one of `sigma` and `policy` is given, or if
      `test_ds` yields no examples.
    RuntimeError: If decoding does not reproduce the encoded input.
  """
  # Explicit checks instead of `assert`: asserts are removed under
  # `python -O`, which would silently disable both this argument check and
  # the lossless round-trip check below.
  if (sigma is None) != (policy is None):
    raise ValueError('sigma and policy must either both be given or both be '
                     'None.')
  if sigma is None:
    logging.info('Compressing with random order since none was given.')
    policy = model.get_naive_policy(25)
    sigma = model.get_random_order(jax.random.PRNGKey(0))

  total_size = 0
  total_count = 0
  d = int(np.prod(model.config.data_shape))

  for idx, batch in enumerate(util_fns.get_iterator(test_ds)):
    x = batch['image']
    x = x.reshape(-1, *x.shape[2:])  # Flatten, we are not using pmap.
    batch_size = x.shape[0]

    # Scale bits determine the rounding precision, 32 seems to work well.
    # No `init_bits` buffer is passed; it is not necessary for this model
    # class.
    streams = [
        ans_template.Bitstream(scale_bits=32) for _ in range(batch_size)
    ]
    if idx == 0:
      logging.info('Initial total bits in bitstream: %d', len(streams[0]))

    logging.info('Encoding...')
    start = time.time()
    streams = model.encode_with_policy_and_sigma(streams, state.ema_params, x,
                                                 policy, sigma)
    logging.info('Encoding took %.2f seconds', time.time() - start)

    # Adds the number of bits in the streams to the total size.
    total_size += sum(len(stream) for stream in streams)
    total_count += batch_size
    bits_per_dim = total_size / total_count / d
    logging.info('Encoded %d Current bits per dim %f', total_count,
                 bits_per_dim)

    logging.info('Decoding...')
    start = time.time()
    decoded, streams = model.decode_with_policy_and_sigma(
        streams, state.ema_params, policy, sigma, batch_size)
    del streams
    logging.info('Decoding took %.2f seconds', time.time() - start)

    # The codec must be lossless; any difference is a hard failure.
    coding_error = jnp.abs(x - decoded).sum()
    if coding_error != 0:
      raise RuntimeError(f'Coding error non-zero: {coding_error}')

  if total_count == 0:
    raise ValueError('test_ds yielded no examples to compress.')
  bits_per_dim = total_size / total_count / d
  logging.info('Bits per dim %f', bits_per_dim)
  return bits_per_dim