def evaluate_auto_regressive_prior(autoregressive_prior: AutoRegressivePrior, output_dir):
    """
    Store ground-truth voxel grids next to the prior's deprojection of the image.

    Iterates over the dataset returned by `build_dataset()`. For each batch the graphs
    are voxelised with `grid_graphs`, the leading image of the batch is passed through
    `autoregressive_prior._deproject_images`, and every resulting
    (voxels, mu_3d, b_3d, image) tuple is written to
    `<output_dir>/evaluated_pairs/evaluation_XXXXX.npz`.

    Args:
        autoregressive_prior: AutoRegressivePrior, must expose `_deproject_images` and
            `discrete_voxel_vae.voxels_per_dimension`.
        output_dir: directory in which the `evaluated_pairs` folder is created.
    """
    dataset = build_dataset()

    @tf.function
    def tf_grid_graphs(graphs):
        # Compiled wrapper around grid_graphs using the voxel VAE's resolution.
        return grid_graphs(
            graphs,
            autoregressive_prior.discrete_voxel_vae.voxels_per_dimension)

    fig_dir = os.path.join(output_dir, 'evaluated_pairs')
    os.makedirs(fig_dir, exist_ok=True)

    pair_idx = 0  # running index over all saved pairs, across batches
    for graphs, images in iter(dataset):
        actual_voxels = tf_grid_graphs(graphs).numpy()  # batch, H,W,D,C

        # NOTE(review): only images[0:1] is deprojected, so the zip below presumably
        # stops after a single pair per batch -- confirm that this is intended.
        deprojected = autoregressive_prior._deproject_images(images[0:1])
        mu_3d, b_3d = deprojected
        mu_3d = mu_3d.numpy()
        b_3d = b_3d.numpy()

        for _actual_voxels, _mu_3d, _b_3d, _image in zip(actual_voxels, mu_3d, b_3d, images):
            file_name = "evaluation_{:05}.npz".format(pair_idx)
            np.savez(os.path.join(fig_dir, file_name),
                     actual_voxels=_actual_voxels,
                     mu_3d=_mu_3d,
                     b_3d=_b_3d,
                     image=_image)
            pair_idx += 1
def train_discrete_voxel_vae(config, kwargs):
    """
    Train the discrete voxel VAE on the example dataset.

    Builds the training object from `config` and `kwargs`, shows the first batch of the
    example dataset (image and depth-averaged voxel grid) as a sanity check, strips the
    images from the dataset, runs the model once so that variables get created, writes
    `config` as JSON into the checkpoint directory and finally hands over to
    `vanilla_training_loop`.

    Args:
        config: dict of model/training options; forwarded to `build_training`, used to
            name the log and checkpoint directories, and dumped to `config.json`.
        kwargs: dict of extra keyword arguments forwarded to `build_training`.
    """
    # with strategy.scope():
    train_one_epoch = build_training(**config, **kwargs)

    dataset = build_example_dataset(100,
                                    batch_size=2,
                                    num_blobs=3,
                                    num_nodes=64**3,
                                    image_dim=256)

    # The model calls grid_graphs internally to learn the 3D autoencoder.
    # Visualise what that produces for the first batch of graphs only.
    for graphs, image in iter(dataset):
        image_np = image.numpy()
        assert image_np.shape == (2, 256, 256, 1)
        plt.imshow(image_np[0])
        plt.colorbar()
        plt.show()

        voxels = grid_graphs(graphs, 64)
        assert voxels.numpy().shape == (2, 64, 64, 64, 1)
        depth_average = tf.reduce_mean(voxels[0], axis=-2)
        plt.imshow(depth_average)
        plt.colorbar()
        plt.show()
        break

    def _keep_graphs_only(graphs, images):
        # The model expects a 1-tuple holding only the graphs.
        return (graphs, )

    dataset = dataset.map(_keep_graphs_only)

    # One forward pass on the first element fixes the variable shapes.
    first_batch = next(iter(dataset), None)
    if first_batch is not None:
        train_one_epoch.model(*first_batch)

    log_dir = build_log_dir('log_dir', config)
    checkpoint_dir = build_checkpoint_dir('checkpointing', config)
    os.makedirs(checkpoint_dir, exist_ok=True)

    config_path = os.path.join(checkpoint_dir, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(config, f)

    vanilla_training_loop(train_one_epoch=train_one_epoch,
                          training_dataset=dataset,
                          num_epochs=1,
                          early_stop_patience=5,
                          checkpoint_dir=checkpoint_dir,
                          log_dir=log_dir,
                          save_model_dir=checkpoint_dir,
                          debug=False)
def compute_logits(self, graphs):
    """
    Encode the voxelised graphs into normalised logits of the variational posterior
    over latent tokens.

    The encoder output is divided by its standard deviation along the last axis and
    then log-softmax normalised, so `exp(logits)` sums to one along that axis.

    Args:
        graphs: GraphsTuple

    Returns:
        logits: [batch, W, H, D, num_embeddings]
    """
    # [batch, H', W', D', num_channels]
    voxels = grid_graphs(graphs, voxels_per_dimension=self.voxels_per_dimension)
    # conv3d needs a statically known channel count
    voxels.set_shape([None, None, None, None, self.num_channels])

    raw_logits = self._encoder(voxels)

    # Rescale to roughly unit spread; the 1e-6 guards against a zero std.
    spread = 1e-6 + tf.math.reduce_std(raw_logits, axis=-1, keepdims=True)
    scaled_logits = raw_logits / spread

    # Subtracting logsumexp turns the logits into log-probabilities.
    log_normaliser = tf.reduce_logsumexp(scaled_logits, axis=-1, keepdims=True)
    return scaled_logits - log_normaliser
def log_likelihood(self, graphs: GraphsTuple, mu, logb):
    """
    Log-probability of the voxelised properties under a log-Laplace distribution.

    With x the raw property and log(x) the value stored in the voxels, the density
    evaluated here is

        P(x | mu, b) = 1 / (2 * x * b) * exp(-|log(x) - mu| / b)

    i.e. log P = -|log(x) - mu| / b - log(2) - log(x) - log(b).

    Args:
        graphs: GraphsTuple in standard form. Assumes properties are of the form
            log(maximum(1e-5, properties))
        mu: [num_samples, batch, H', W', D', channels]
        logb: [num_samples, batch, H', W', D', channels]

    Returns:
        log_prob [num_samples, batch]
    """
    # Voxelised log-properties; presumably [batch, H', W', D', num_properties] and
    # broadcast against the leading num_samples axis of mu / logb -- TODO confirm.
    log_x = grid_graphs(graphs, self.voxels_per_dimension)

    # -|log(x) - mu| / b
    neg_scaled_residual = -tf.math.abs(log_x - mu) / tf.math.exp(logb)
    # [num_samples, batch, H', W', D', num_properties]
    log_prob = neg_scaled_residual - tf.math.log(2.) - log_x - logb

    # Sum over channels and the three spatial axes -> [num_samples, batch]
    return tf.reduce_sum(log_prob, axis=[-1, -2, -3, -4])
def _build(self, graphs, **kwargs) -> dict:
    """
    Forward pass of the discrete voxel VAE: computes the negative-ELBO loss and,
    every few calls, writes diagnostic summaries.

    Steps: encode graphs to latent logits, sample latent tokens, decode them to the
    likelihood parameters (mu, logb), then combine the expected log-likelihood and
    the KL term into `elbo = var_exp - beta * kl_div`.

    Args:
        graphs: GraphsTuple in standard form.
        **kwargs: accepted for interface compatibility; not used here.

    Returns:
        dict with
            loss: scalar, negative mean of the ELBO over the batch.
            metrics: dict(var_exp=[batch], kl_div=[batch], mean_perplexity=scalar)
    """

    def _unit_range(values, axis=None):
        # Min-max rescale to [0, 1], globally (axis=None) or along `axis`.
        # divide_no_nan yields 0 instead of NaN when the values are constant.
        keepdims = axis is not None
        shifted = values - tf.reduce_min(values, axis=axis, keepdims=keepdims)
        span = tf.reduce_max(shifted, axis=axis, keepdims=keepdims)
        return tf.math.divide_no_nan(shifted, span)

    latent_logits = self.compute_logits(graphs)  # [batch, H, W, D, num_embeddings]

    # [num_samples, batch, H, W, D, num_embeddings] (x2),
    # [num_samples, batch, H, W, D, embedding_size]
    log_token_samples_onehot, token_samples_onehot, latent_tokens = self.sample_latent(
        latent_logits, self.temperature, self.num_token_samples)

    # [num_samples, batch, H', W', D', C], [num_samples, batch, H', W', D', C]
    mu, logb = self.compute_likelihood_parameters(latent_tokens)

    log_likelihood = self.log_likelihood(graphs, mu, logb)  # [num_samples, batch]
    kl_term = self.kl_term(latent_logits, log_token_samples_onehot)  # [num_samples, batch]

    var_exp = tf.reduce_mean(log_likelihood, axis=0)  # [batch]
    kl_div = tf.reduce_mean(kl_term, axis=0)  # [batch]
    elbo = var_exp - self.beta * kl_div  # [batch]
    loss = -tf.reduce_mean(elbo)  # scalar

    # latent_logits are log-probabilities, so this is the categorical entropy.
    entropy = -tf.reduce_sum(tf.math.exp(latent_logits) * latent_logits,
                             axis=[-1])  # [batch, H, W, D]
    perplexity = 2.**(entropy / tf.math.log(2.))  # [batch, H, W, D]
    mean_perplexity = tf.reduce_mean(perplexity)  # scalar

    # Log less often when fewer samples are drawn; never let the interval hit 0
    # (which would raise on the modulo when num_samples > 128).
    log_every = max(1, int(128 / mu.get_shape()[0]))
    if self.log_counter % log_every == 0:
        tf.summary.scalar('perplexity', mean_perplexity, step=self.step)
        tf.summary.scalar('var_exp', tf.reduce_mean(var_exp), step=self.step)
        tf.summary.scalar('kl_div', tf.reduce_mean(kl_div), step=self.step)
        tf.summary.scalar('temperature', self.temperature, step=self.step)
        tf.summary.scalar('beta', self.beta, step=self.step)

        # Project prediction (first sample) and ground truth along the depth axis.
        projected_mu = tf.reduce_sum(mu[0], axis=-2)  # [batch, H', W', C]
        voxels = grid_graphs(graphs, self.voxels_per_dimension)  # [batch, H', W', D', C]
        projected_img = tf.reduce_sum(voxels, axis=-2)  # [batch, H', W', C]

        for i in range(self.num_channels):
            # [batch, H', W', 1], each rescaled with its own global min / max
            _projected_mu = tf.clip_by_value(
                _unit_range(projected_mu[..., i:i + 1]), 0., 1.)
            _projected_img = tf.clip_by_value(
                _unit_range(projected_img[..., i:i + 1]), 0., 1.)
            tf.summary.image(f'voxels_predict[{i}]', _projected_mu, step=self.step)
            tf.summary.image(f'voxels_actual[{i}]', _projected_img, step=self.step)

        batch, H, W, D, _ = get_shape(latent_logits)
        panel_shape = [batch, H * W * D, self.num_embedding, 1]

        # Per-voxel rescale of the logits over the embedding axis, shown as an image.
        _latent_logits = _unit_range(latent_logits, axis=-1)  # [batch, H, W, D, num_embeddings]
        _latent_logits = tf.reshape(_latent_logits, panel_shape)  # [batch, H*W*D, num_embedding, 1]
        tf.summary.image('latent_logits', _latent_logits, step=self.step)

        token_sample_onehot = token_samples_onehot[0]  # [batch, H, W, D, num_embeddings]
        token_sample_onehot = tf.reshape(token_sample_onehot,
                                         panel_shape)  # [batch, H*W*D, num_embedding, 1]
        tf.summary.image('latent_samples_onehot', token_sample_onehot, step=self.step)

    return dict(loss=loss,
                metrics=dict(var_exp=var_exp,
                             kl_div=kl_div,
                             mean_perplexity=mean_perplexity))
def write_summary(self, images, graphs, latent_logits_2d, latent_logits_3d,
                  prior_latent_logits_2d, prior_latent_logits_3d):
    """
    Write diagnostic scalar and image summaries comparing the encoders' latent
    distributions with the prior's, and the prior's reconstructions with the data.

    Logged quantities: KL(posterior || prior) for the 2D and 3D latents, perplexity
    of the prior distributions, expected log-likelihood of the data under tokens
    sampled from the prior, min-max normalised predicted vs. actual images / voxel
    projections, and image panels of the latent logits and one-hot samples.

    Args:
        images: [batch, H', W', C] image batch.
        graphs: GraphsTuple, voxelised with `grid_graphs` for comparison.
        latent_logits_2d: [batch, H2, W2, num_embedding] logits from the image VAE.
        latent_logits_3d: [batch, H3, W3, D3, num_embedding] logits from the voxel VAE.
        prior_latent_logits_2d: prior logits, same shape as `latent_logits_2d`.
        prior_latent_logits_3d: prior logits, same shape as `latent_logits_3d`.
    """

    def _unit_range(values, axis=None):
        # Min-max rescale to [0, 1], globally (axis=None) or along `axis`.
        # divide_no_nan yields 0 instead of NaN when the values are constant.
        keepdims = axis is not None
        shifted = values - tf.reduce_min(values, axis=axis, keepdims=keepdims)
        span = tf.reduce_max(shifted, axis=axis, keepdims=keepdims)
        return tf.math.divide_no_nan(shifted, span)

    def _unit_range_channel(values, i):
        # Channel i of `values`, rank kept, rescaled with its global min / max.
        return tf.clip_by_value(_unit_range(values[..., i:i + 1]), 0., 1.)

    # Posterior distributions and one sample from each.
    dist_2d = tfp.distributions.OneHotCategorical(logits=latent_logits_2d,
                                                  dtype=latent_logits_2d.dtype)
    dist_3d = tfp.distributions.OneHotCategorical(logits=latent_logits_3d,
                                                  dtype=latent_logits_3d.dtype)
    token_samples_onehot_2d = dist_2d.sample(1)[0]
    token_samples_onehot_3d = dist_3d.sample(1)[0]

    # Prior distributions and one sample from each.
    dist_2d_prior = tfp.distributions.OneHotCategorical(logits=prior_latent_logits_2d,
                                                        dtype=prior_latent_logits_2d.dtype)
    dist_3d_prior = tfp.distributions.OneHotCategorical(logits=prior_latent_logits_3d,
                                                        dtype=prior_latent_logits_3d.dtype)
    prior_token_samples_onehot_2d = dist_2d_prior.sample(1)[0]
    prior_token_samples_onehot_3d = dist_3d_prior.sample(1)[0]

    # KL summed over the spatial axes, averaged over the batch.
    kl_div_2d = tf.reduce_mean(
        tf.reduce_sum(dist_2d.kl_divergence(dist_2d_prior), axis=[-1, -2]))
    kl_div_3d = tf.reduce_mean(
        tf.reduce_sum(dist_3d.kl_divergence(dist_3d_prior), axis=[-1, -2, -3]))
    tf.summary.scalar('kl_div_2d', kl_div_2d, step=self.step)
    tf.summary.scalar('kl_div_3d', kl_div_3d, step=self.step)
    tf.summary.scalar('kl_div', kl_div_2d + kl_div_3d, step=self.step)

    perplexity_2d = 2.**(dist_2d_prior.entropy() / tf.math.log(2.))
    mean_perplexity_2d = tf.reduce_mean(perplexity_2d)  # scalar
    perplexity_3d = 2.**(dist_3d_prior.entropy() / tf.math.log(2.))
    mean_perplexity_3d = tf.reduce_mean(perplexity_3d)  # scalar
    tf.summary.scalar('perplexity_2d_prior', mean_perplexity_2d, step=self.step)
    tf.summary.scalar('perplexity_3d_prior', mean_perplexity_3d, step=self.step)

    # Look up embeddings for the prior samples (leading axis: a single sample).
    prior_latent_tokens_2d = tf.einsum('sbhwd,de->sbhwe',
                                       prior_token_samples_onehot_2d[None],
                                       self.discrete_image_vae.embeddings)
    prior_latent_tokens_3d = tf.einsum('sbhwdn,ne->sbhwde',
                                       prior_token_samples_onehot_3d[None],
                                       self.discrete_voxel_vae.embeddings)

    # [num_samples, batch, H', W', C], [num_samples, batch, H', W', C]
    mu_2d, logb_2d = self.discrete_image_vae.compute_likelihood_parameters(
        prior_latent_tokens_2d)
    log_likelihood_2d = self.discrete_image_vae.log_likelihood(
        images, mu_2d, logb_2d)  # [num_samples, batch]
    var_exp_2d = tf.reduce_mean(log_likelihood_2d)  # scalar

    # [num_samples, batch, H', W', D', C], [num_samples, batch, H', W', D', C]
    mu_3d, logb_3d = self.discrete_voxel_vae.compute_likelihood_parameters(
        prior_latent_tokens_3d)
    log_likelihood_3d = self.discrete_voxel_vae.log_likelihood(
        graphs, mu_3d, logb_3d)  # [num_samples, batch]
    var_exp_3d = tf.reduce_mean(log_likelihood_3d)  # scalar

    var_exp = log_likelihood_2d + log_likelihood_3d  # [num_samples, batch]

    tf.summary.scalar('var_exp_3d', var_exp_3d, step=self.step)
    tf.summary.scalar('var_exp_2d', var_exp_2d, step=self.step)
    tf.summary.scalar('var_exp', tf.reduce_mean(var_exp), step=self.step)

    # --- 3D: predicted vs. actual voxels, projected along the depth axis ---
    projected_mu = tf.reduce_sum(mu_3d[0], axis=-2)  # [batch, H', W', C]
    voxels = grid_graphs(
        graphs, self.discrete_voxel_vae.voxels_per_dimension)  # [batch, H', W', D', C]
    projected_img = tf.reduce_sum(voxels, axis=-2)  # [batch, H', W', C]
    for i in range(self.discrete_voxel_vae.num_channels):
        _projected_mu = _unit_range_channel(projected_mu, i)  # [batch, H', W', 1]
        _projected_img = _unit_range_channel(projected_img, i)  # [batch, H', W', 1]
        tf.summary.image(f'voxels_predict_prior[{i}]', _projected_mu, step=self.step)
        tf.summary.image(f'voxels_actual[{i}]', _projected_img, step=self.step)

    # --- 3D: latent logits and one-hot samples, posterior ('') and prior ---
    for name, _latent_logits_3d, _tokens_onehot_3d in zip(
            ['', '_prior'],
            [latent_logits_3d, prior_latent_logits_3d],
            [token_samples_onehot_3d, prior_token_samples_onehot_3d]):
        batch, H3, W3, D3, _ = get_shape(_latent_logits_3d)
        panel_shape_3d = [batch, H3 * W3 * D3, self.discrete_voxel_vae.num_embedding, 1]

        _latent_logits_3d = _unit_range(_latent_logits_3d, axis=-1)
        _latent_logits_3d = tf.reshape(_latent_logits_3d,
                                       panel_shape_3d)  # [batch, H*W*D, num_embedding, 1]
        tf.summary.image(f"latent_logits_3d{name}", _latent_logits_3d, step=self.step)

        _tokens_onehot_3d = tf.reshape(_tokens_onehot_3d,
                                       panel_shape_3d)  # [batch, H*W*D, num_embedding, 1]
        tf.summary.image(f'latent_samples_onehot_3d{name}', _tokens_onehot_3d, step=self.step)

    # --- 2D: predicted vs. actual images ---
    _mu = mu_2d[0]  # [batch, H', W', C]
    _img = images  # [batch, H', W', C]
    for i in range(self.discrete_image_vae.num_channels):
        _projected_mu = _unit_range_channel(_mu, i)  # [batch, H', W', 1]
        _projected_img = _unit_range_channel(_img, i)  # [batch, H', W', 1]
        tf.summary.image(f'image_predict_prior[{i}]', _projected_mu, step=self.step)
        tf.summary.image(f'image_actual[{i}]', _projected_img, step=self.step)

    # --- 2D: latent logits and one-hot samples, posterior ('') and prior ---
    for name, _latent_logits_2d, _tokens_onehot_2d in zip(
            ['', '_prior'],
            [latent_logits_2d, prior_latent_logits_2d],
            [token_samples_onehot_2d, prior_token_samples_onehot_2d]):
        batch, H2, W2, _ = get_shape(_latent_logits_2d)
        panel_shape_2d = [batch, H2 * W2, self.discrete_image_vae.num_embedding, 1]

        _latent_logits_2d = _unit_range(_latent_logits_2d, axis=-1)
        _latent_logits_2d = tf.reshape(_latent_logits_2d,
                                       panel_shape_2d)  # [batch, H*W, num_embedding, 1]
        tf.summary.image(f"latent_logits_2d{name}", _latent_logits_2d, step=self.step)

        _tokens_onehot_2d = tf.reshape(_tokens_onehot_2d,
                                       panel_shape_2d)  # [batch, H*W, num_embedding, 1]
        tf.summary.image(f'latent_samples_onehot_2d{name}', _tokens_onehot_2d, step=self.step)