def _build(self, latent):
    """Apply one post-LN multi-head self-attention layer to the nodes of `latent`.

    `latent` is also used as the attention graph, so its connectivity decides
    which nodes attend to which.

    Args:
        latent: GraphsTuple with nodes of shape [n_node, F].

    Returns:
        A copy of `latent` whose nodes are the layer output.
    """
    self.initialize(latent)
    nodes_in = latent.nodes
    n_node, _ = get_shape(nodes_in)

    # Per-head projections, each [n_node, num_heads, input_node_size].
    values = self.v_linear(nodes_in)
    keys = self.k_linear(nodes_in)
    queries = self.q_linear(nodes_in)
    keys = self.ln_keys(keys)
    queries = self.ln_queries(queries)

    # Scaled dot-product attention: fold 1/sqrt(d_k) into the queries.
    _, _, key_dim = get_shape(keys)
    queries /= tf.math.sqrt(tf.cast(key_dim, queries.dtype))

    attended = self.self_attention(node_values=values,
                                   node_keys=keys,
                                   node_queries=queries,
                                   attention_graph=latent)

    # Merge heads: [n_node, num_heads, size] -> [n_node, num_heads * size].
    merged_heads = tf.reshape(attended.nodes,
                              (n_node, self.num_heads * self.input_node_size))
    # Output projection + residual, then layer norm.
    hidden = self.ln1(self.output_linear(merged_heads) + nodes_in)
    # NOTE(review): the feed-forward sub-layer has no residual connection
    # (ln2(FFN(h)) rather than ln2(FFN(h) + h)); confirm this is intended.
    hidden = self.ln2(self.FFN(hidden))
    return latent.replace(nodes=hidden)
def construct_sequence(self, token_samples_idx_2d, token_samples_idx_3d):
    """Join 2D and 3D token ids into one sequence framed by special tokens.

    Layout per row: [start, 2D ids..., delimiter, 3D ids + offset..., eos].
    The special tokens use the last three ids of the joint vocabulary
    (num_embedding - 3, - 2, - 1). The 3D ids are shifted by the image VAE's
    vocabulary size so they do not collide with the 2D ids.

    Args:
        token_samples_idx_2d: [G, H2*W2]
        token_samples_idx_3d: [G, H3*W3*D3]

    Returns:
        sequence: [G, 1 + H2*W2 + 1 + H3*W3*D3 + 1]
    """
    dtype = token_samples_idx_2d.dtype
    # Unpacking also checks that both inputs are rank 2.
    G, _ = get_shape(token_samples_idx_2d)
    G, _ = get_shape(token_samples_idx_3d)

    def _special_column(offset_from_end):
        token_id = tf.constant(self.num_embedding - offset_from_end, dtype=dtype)
        return tf.fill((G, 1), token_id)

    start_col = _special_column(3)
    delimiter_col = _special_column(2)
    eos_col = _special_column(1)
    shifted_3d = token_samples_idx_3d + self.discrete_image_vae.num_embedding

    pieces = [start_col, token_samples_idx_2d, delimiter_col, shifted_3d, eos_col]
    return tf.concat(pieces, axis=-1)
def compute_likelihood_parameters(self, latent_tokens):
    """Decode latent token embeddings into per-pixel likelihood parameters.

    Args:
        latent_tokens: [num_samples, batch, H, W, embedding_dim]

    Returns:
        (mu, logb): each [num_samples, batch, H', W', num_channels]. They are the
        first and second halves of the decoder's 2 * num_channels output channels.
    """
    num_samples, batch, H, W, _ = get_shape(latent_tokens)
    # Fold the sample axis into the batch axis for the decoder.
    flat_tokens = tf.reshape(latent_tokens,
                             [num_samples * batch, H, W, self.embedding_dim])
    decoded = self._decoder(flat_tokens)  # [num_samples * batch, H', W', 2C]
    decoded.set_shape([None, None, None, self.num_channels * 2])

    _, H_out, W_out, _ = get_shape(decoded)
    decoded = tf.reshape(decoded,
                         [num_samples, batch, H_out, W_out, 2 * self.num_channels])
    n_ch = self.num_channels
    mu = decoded[..., :n_ch]
    logb = decoded[..., n_ch:]
    return mu, logb
def sample_decoder(self, img_logits, temperature, num_samples):
    """Decode `num_samples` relaxed token samples and average the decoder outputs.

    Args:
        img_logits: [batch, H, W, num_embedding] token logits.
        temperature: temperature of the RelaxedOneHotCategorical.
        num_samples: number of relaxed samples S.

    Returns:
        [batch, H', W', C*2]: decoder output averaged over the S samples.
    """
    batch, H, W, _ = get_shape(img_logits)
    flat_logits = tf.reshape(img_logits, [batch * H * W, self.num_embedding])
    # Normalise to log-probabilities by subtracting each row's log-partition.
    flat_logits -= tf.math.reduce_logsumexp(flat_logits, axis=-1, keepdims=True)

    relaxed = tfp.distributions.RelaxedOneHotCategorical(temperature,
                                                         logits=flat_logits)
    # [S, batch*H*W, num_embedding]
    onehot_samples = relaxed.sample((num_samples, ), name='token_samples')

    def _decode_one(onehot):
        # Soft embedding lookup: [batch*H*W, K] @ [K, embedding_dim].
        embedded = tf.matmul(onehot, self.embeddings)
        latent_img = tf.reshape(embedded, [batch, H, W, self.embedding_dim])
        return self.decoder(latent_img)  # [batch, H', W', C*2]

    decoded_all = tf.vectorized_map(_decode_one, onehot_samples)  # [S, batch, H', W', C*2]
    return tf.reduce_mean(decoded_all, axis=0)
def construct_input_graph(self, input_sequence, N2):
    """Embed a token sequence and connect it into the prior's attention graph.

    Args:
        input_sequence: [G, N] token ids, laid out as
            [start, 2D tokens (N2 of them), delimiter, 3D tokens...].
        N2: number of 2D tokens (H2*W2).

    Returns:
        GraphsTuple holding G graphs of N nodes each. Edges are given by
        `edge_connect_rule` below.
    """
    G, N = get_shape(input_sequence)
    # [G, N, embedding_dim]
    input_tokens = tf.nn.embedding_lookup(self.embeddings, input_sequence)
    self.initialize_positional_encodings(input_tokens)
    nodes = input_tokens + self.positional_encodings
    n_node = tf.fill([G], N)
    n_edge = tf.zeros_like(n_node)
    concat_graphs = GraphsTuple(nodes=nodes,
                                edges=None,
                                senders=None,
                                receivers=None,
                                globals=None,
                                n_node=n_node,
                                n_edge=n_edge)
    # nodes: [G * N, embedding_dim]
    concat_graphs = graph_unbatch_reshape(concat_graphs)

    def edge_connect_rule(sender, receiver):
        # Node p's output predicts token p + 1:  ". a . b" -> "a . b .".
        # Prefix (start token + 2D tokens, positions 0..N2): fully connected,
        # except the sender one to the RIGHT of the receiver. That node holds
        # exactly the token the receiver must predict, so seeing it would let
        # the model learn a copy.
        # NOTE(review): only the direct edge is removed; with several
        # message-passing layers the token can still arrive via two hops.
        complete_2d = (sender < N2 + 1) & (receiver < N2 + 1) & (
            sender != receiver + 1)
        # Remainder: causal (each receiver sees itself and all earlier nodes).
        auto_regressive_3d = (sender <= receiver) & (receiver >= N2 + 1)
        return complete_2d | auto_regressive_3d

    return connect_graph_dynamic(concat_graphs, edge_connect_rule)
def _sample_decoder(self, img_logits, temperature, num_samples):
    """Decode relaxed token samples into per-sample image means.

    Each position's logits are standardised (zero mean, unit std), then
    normalised to log-probabilities before sampling.

    Args:
        img_logits: [batch, H, W, num_embedding]
        temperature: RelaxedOneHotCategorical temperature.
        num_samples: number of samples S.

    Returns:
        [S, batch, H', W', C]: the first `num_channels` decoder output channels.
    """
    batch, H, W, _ = get_shape(img_logits)
    logits = tf.reshape(img_logits, [batch * H * W, self.num_embedding])
    logits -= tf.reduce_mean(logits, axis=-1, keepdims=True)
    # NOTE(review): a row whose logits are all equal has zero std -> NaN here.
    logits /= tf.math.reduce_std(logits, axis=-1, keepdims=True)
    logits -= tf.math.reduce_logsumexp(logits, axis=-1, keepdims=True)

    dist = tfp.distributions.RelaxedOneHotCategorical(temperature, logits=logits)
    # [S, batch*H*W, num_embedding]
    samples = dist.sample((num_samples, ), name='token_samples')

    def _to_latent_img(onehot):
        embedded = tf.matmul(onehot, self.embeddings)  # [batch*H*W, embedding_dim]
        return tf.reshape(embedded, [batch, H, W, self.embedding_dim])

    latent_imgs = tf.vectorized_map(_to_latent_img, samples)  # [S, batch, H, W, E]
    latent_imgs = tf.reshape(latent_imgs,
                             [num_samples * batch, H, W, self.embedding_dim])
    decoded = self.decoder(latent_imgs)  # [S * batch, H', W', 2C]
    _, H_out, W_out, _ = get_shape(decoded)
    decoded = tf.reshape(decoded,
                         [num_samples, batch, H_out, W_out, 2 * self.num_channels])
    return decoded[..., :self.num_channels]
def initialize(self, graphs):
    """Create the attention sub-modules, sized to the node feature size of `graphs`.

    Sets `self.input_node_size` and builds the per-head value/key/query
    projections, the feed-forward MLP and the output projection.
    """
    node_size = get_shape(graphs.nodes)[-1]
    self.input_node_size = node_size

    def _head_projection(name):
        return MultiHeadLinear(output_size=node_size,
                               num_heads=self.num_heads,
                               name=name)

    self.v_linear = _head_projection('mhl1')  # values
    self.k_linear = _head_projection('mhl2')  # keys
    self.q_linear = _head_projection('mhl3')  # queries
    # Feed-forward network.
    self.FFN = snt.nets.MLP([node_size, node_size],
                            activate_final=False,
                            name='ffn')
    self.output_linear = snt.Linear(output_size=node_size, name='output_linear')
def log_prob_q(self, latent_logits, log_token_samples_onehot):
    """Log-density of log-space relaxed samples under the posterior q.

    Args:
        latent_logits: [batch, H, W, num_embeddings] (normalised)
        log_token_samples_onehot: [num_samples, batch, H, W, num_embeddings].
            Evaluated by ExpRelaxedOneHotCategorical.log_prob, i.e. log-space samples.

    Returns:
        [num_samples, batch, H, W] log-probabilities.
    """
    # Unpacking checks that latent_logits is rank 4.
    _, _, _, _ = get_shape(latent_logits)
    posterior = tfp.distributions.ExpRelaxedOneHotCategorical(
        self.temperature, logits=latent_logits)
    return posterior.log_prob(log_token_samples_onehot)
def initialize(self, graphs: GraphsTuple):
    """Build the edge and node blocks, sized to the node feature size of `graphs`.

    Both blocks use two-layer ReLU MLPs (final activation included) whose width
    equals the input node size.
    - Edge block: uses sender nodes, not receiver nodes.
    - Node block: aggregates received edges, not sent ones.
    """
    width = get_shape(graphs.nodes)[-1]

    def _mlp_factory(name):
        # EdgeBlock/NodeBlock take a zero-argument constructor for their model.
        return lambda: snt.nets.MLP([width, width],
                                    activate_final=True,
                                    activation=tf.nn.relu,
                                    name=name)

    self.edge_block = blocks.EdgeBlock(_mlp_factory('edge_fn'),
                                       use_edges=self.use_edges,
                                       use_receiver_nodes=False,
                                       use_sender_nodes=True,
                                       use_globals=self.use_globals)
    self.node_block = blocks.NodeBlock(_mlp_factory('node_fn'),
                                       use_received_edges=True,
                                       use_sent_edges=False,
                                       use_nodes=True,
                                       use_globals=self.use_globals)
def write_summary(self, images, voxels, latent_logits_2d, latent_logits_3d,
                  prior_latent_logits_2d, prior_latent_logits_3d):
    """Write TensorBoard summaries comparing posterior and prior token distributions.

    Logs:
    - KL(posterior || prior) for the 2D and 3D tokens.
    - The prior perplexities.
    - The log-likelihood of `images`/`voxels` under tokens sampled from the prior.
    - Per-channel images of the prior reconstructions and the targets.
    - Images of the (rescaled) logits and of the sampled one-hot tokens.

    Args:
        images: [batch, H', W', C] target images.
        voxels: [batch, H', W', D', C] target voxels.
        latent_logits_2d, prior_latent_logits_2d: [batch, H2, W2, num_embedding_2d].
        latent_logits_3d, prior_latent_logits_3d: [batch, H3, W3, D3, num_embedding_3d].
    """

    def _unit_range(x):
        """Min-max rescale `x` over all its elements into [0, 1].

        divide_no_nan maps a constant tensor to 0 instead of NaN.
        """
        lo = tf.reduce_min(x)
        hi = tf.reduce_max(x)
        return tf.clip_by_value(tf.math.divide_no_nan(x - lo, hi - lo), 0., 1.)

    def _token_image(x, num_embedding):
        """Lay out [batch, *positions, num_embedding] as [batch, n_positions, num_embedding, 1]."""
        return tf.reshape(x, [get_shape(x)[0], -1, num_embedding, 1])

    def _logits_image(logits, num_embedding):
        """Rescale logits to [0, 1] per position (constant rows -> 0) and lay them out as an image."""
        shifted = logits - tf.reduce_min(logits, axis=-1, keepdims=True)
        scaled = tf.math.divide_no_nan(
            shifted, tf.reduce_max(shifted, axis=-1, keepdims=True))
        return _token_image(scaled, num_embedding)

    one_hot = tfp.distributions.OneHotCategorical

    # Posterior token distributions and one sample from each.
    dist_2d = one_hot(logits=latent_logits_2d, dtype=latent_logits_2d.dtype)
    dist_3d = one_hot(logits=latent_logits_3d, dtype=latent_logits_3d.dtype)
    token_samples_onehot_2d = dist_2d.sample(1)[0]
    token_samples_onehot_3d = dist_3d.sample(1)[0]
    # Prior token distributions and one sample from each.
    dist_2d_prior = one_hot(logits=prior_latent_logits_2d,
                            dtype=prior_latent_logits_2d.dtype)
    dist_3d_prior = one_hot(logits=prior_latent_logits_3d,
                            dtype=prior_latent_logits_3d.dtype)
    prior_token_samples_onehot_2d = dist_2d_prior.sample(1)[0]
    prior_token_samples_onehot_3d = dist_3d_prior.sample(1)[0]

    # KL(posterior || prior): summed over token positions, averaged over the batch.
    kl_div_2d = tf.reduce_mean(
        tf.reduce_sum(dist_2d.kl_divergence(dist_2d_prior), axis=[-1, -2]))
    kl_div_3d = tf.reduce_mean(
        tf.reduce_sum(dist_3d.kl_divergence(dist_3d_prior), axis=[-1, -2, -3]))
    tf.summary.scalar('kl_div_2d', kl_div_2d, step=self.step)
    tf.summary.scalar('kl_div_3d', kl_div_3d, step=self.step)
    tf.summary.scalar('kl_div', kl_div_2d + kl_div_3d, step=self.step)

    # Prior perplexity 2**(entropy in bits), averaged over positions.
    mean_perplexity_2d = tf.reduce_mean(
        2. ** (dist_2d_prior.entropy() / tf.math.log(2.)))
    mean_perplexity_3d = tf.reduce_mean(
        2. ** (dist_3d_prior.entropy() / tf.math.log(2.)))
    tf.summary.scalar('perplexity_2d_prior', mean_perplexity_2d, step=self.step)
    tf.summary.scalar('perplexity_3d_prior', mean_perplexity_3d, step=self.step)

    # Embed the prior samples (leading sample axis of size 1) and decode them.
    prior_latent_tokens_2d = tf.einsum('sbhwd,de->sbhwe',
                                       prior_token_samples_onehot_2d[None],
                                       self.discrete_image_vae.embeddings)
    prior_latent_tokens_3d = tf.einsum('sbhwdn,ne->sbhwde',
                                       prior_token_samples_onehot_3d[None],
                                       self.discrete_voxel_vae.embeddings)
    # [num_samples, batch, H', W', C] each
    mu_2d, logb_2d = self.discrete_image_vae.compute_likelihood_parameters(
        prior_latent_tokens_2d)
    log_likelihood_2d = self.discrete_image_vae.log_likelihood(
        images, mu_2d, logb_2d)  # [num_samples, batch]
    var_exp_2d = tf.reduce_mean(log_likelihood_2d)
    # [num_samples, batch, H', W', D', C] each
    mu_3d, logb_3d = self.discrete_voxel_vae.compute_likelihood_parameters(
        prior_latent_tokens_3d)
    log_likelihood_3d = self.discrete_voxel_vae.log_likelihood(
        voxels, mu_3d, logb_3d)  # [num_samples, batch]
    var_exp_3d = tf.reduce_mean(log_likelihood_3d)
    var_exp = log_likelihood_2d + log_likelihood_3d
    tf.summary.scalar('var_exp_3d', var_exp_3d, step=self.step)
    tf.summary.scalar('var_exp_2d', var_exp_2d, step=self.step)
    tf.summary.scalar('var_exp', tf.reduce_mean(var_exp), step=self.step)

    # Voxels: sum along the depth axis, then one image per channel.
    projected_mu = tf.reduce_sum(mu_3d[0], axis=-2)  # [batch, H', W', C]
    projected_voxels = tf.reduce_sum(voxels, axis=-2)  # [batch, H', W', C]
    for i in range(self.discrete_voxel_vae.num_channels):
        tf.summary.image(f'voxels_predict_prior[{i}]',
                         _unit_range(projected_mu[..., i:i + 1]), step=self.step)
        tf.summary.image(f'voxels_actual[{i}]',
                         _unit_range(projected_voxels[..., i:i + 1]), step=self.step)

    voxel_vocab = self.discrete_voxel_vae.num_embedding
    for name, logits, onehot in zip(
            ['', '_prior'], [latent_logits_3d, prior_latent_logits_3d],
            [token_samples_onehot_3d, prior_token_samples_onehot_3d]):
        tf.summary.image(f"latent_logits_3d{name}",
                         _logits_image(logits, voxel_vocab), step=self.step)
        tf.summary.image(f'latent_samples_onehot_3d{name}',
                         _token_image(onehot, voxel_vocab), step=self.step)

    # Images: one image per channel for the prior reconstruction and the target.
    mu_img = mu_2d[0]  # [batch, H', W', C]
    for i in range(self.discrete_image_vae.num_channels):
        tf.summary.image(f'image_predict_prior[{i}]',
                         _unit_range(mu_img[..., i:i + 1]), step=self.step)
        tf.summary.image(f'image_actual[{i}]',
                         _unit_range(images[..., i:i + 1]), step=self.step)

    image_vocab = self.discrete_image_vae.num_embedding
    for name, logits, onehot in zip(
            ['', '_prior'], [latent_logits_2d, prior_latent_logits_2d],
            [token_samples_onehot_2d, prior_token_samples_onehot_2d]):
        tf.summary.image(f"latent_logits_2d{name}",
                         _logits_image(logits, image_vocab), step=self.step)
        tf.summary.image(f'latent_samples_onehot_2d{name}',
                         _token_image(onehot, image_vocab), step=self.step)
def _build(self, img, **kwargs) -> dict:
    """Compute the negative ELBO of the discrete VAE using relaxed one-hot tokens.

    Args:
        img: [batch, H', W', num_channels]
        **kwargs: unused.

    Returns:
        dict(loss=scalar, metrics=dict(var_exp=[batch], kl_div=[batch],
                                       mean_perplexity=scalar))
    """
    encoded_img_logits = self.encoder(img)  # [batch, H, W, num_embedding]
    batch, H, W, _ = get_shape(encoded_img_logits)
    logits = tf.reshape(encoded_img_logits, [batch * H * W, self.num_embedding])
    # Normalise to log-probabilities per position.
    logits -= tf.math.reduce_logsumexp(logits, axis=-1, keepdims=True)

    # NOTE(review): this is max(0.1, 1 - 100 / step), which RISES towards 1 as
    # training proceeds; relaxed-categorical temperatures are usually annealed
    # downwards -- confirm the intent.
    temperature = tf.maximum(
        0.1, tf.cast(1. - 0.1 / (self.step / 1000), tf.float32))
    token_distribution = tfp.distributions.RelaxedOneHotCategorical(
        temperature, logits=logits)
    # [S, batch*H*W, num_embedding]
    token_samples_onehot = token_distribution.sample(
        (self.num_token_samples, ), name='token_samples')

    def _single_decode(token_sample_onehot):
        """Per-sample log-likelihood [batch], log-q term [batch] and decoder output."""
        # Soft embedding lookup: [batch*H*W, K] @ [K, embedding_dim].
        token_sample = tf.matmul(token_sample_onehot, self.embeddings)
        latent_img = tf.reshape(token_sample, [batch, H, W, self.embedding_dim])
        decoded_img = self.decoder(latent_img)  # [batch, H', W', 2C]
        img_mu = decoded_img[..., :self.num_channels]
        img_logb = decoded_img[..., self.num_channels:]
        log_likelihood = self.log_likelihood(img, img_mu, img_logb)  # [batch, H', W', C]
        log_likelihood = tf.reduce_sum(log_likelihood, axis=[-3, -2, -1])  # [batch]
        # Sum over positions of the (relaxed) selected log-probability.
        # Presumably the KL to a uniform prior up to a constant -- confirm.
        sum_selected_logits = tf.reduce_sum(token_sample_onehot * logits, axis=-1)
        sum_selected_logits = tf.reshape(sum_selected_logits, [batch, H, W])
        kl_term = tf.reduce_sum(sum_selected_logits, axis=[-2, -1])  # [batch]
        return log_likelihood, kl_term, decoded_img

    # [S, batch], [S, batch], [S, batch, H', W', 2C]
    log_likelihood_samples, kl_term_samples, decoded_ims = tf.vectorized_map(
        _single_decode, token_samples_onehot)

    if self.step % 50 == 0:
        img_mu_0 = tf.reduce_mean(decoded_ims, axis=0)[..., :self.num_channels]
        img_mu_0 -= tf.reduce_min(img_mu_0)
        img_mu_0 /= tf.reduce_max(img_mu_0)
        tf.summary.image('mu', img_mu_0, step=self.step)
        # img has num_channels channels; slicing from num_channels onward (as
        # before) selected an empty tensor.
        smoothed_img = img[..., :self.num_channels]
        smoothed_img = (smoothed_img - tf.reduce_min(smoothed_img)) / (
            tf.reduce_max(smoothed_img) - tf.reduce_min(smoothed_img))
        tf.summary.image('img_before_autoencoder', smoothed_img, step=self.step)

    var_exp = tf.reduce_mean(log_likelihood_samples, axis=0)  # [batch]
    kl_div = tf.reduce_mean(kl_term_samples, axis=0)  # [batch]
    elbo = var_exp - kl_div  # [batch]
    loss = -tf.reduce_mean(elbo)  # scalar

    # logits are log-probabilities, so entropy = -sum p log p (nats).
    entropy = -tf.reduce_sum(logits * tf.math.exp(logits), axis=-1)  # [batch*H*W]
    # Perplexity = 2**(entropy in bits) = exp(entropy). The previous
    # 2**(-entropy / ln 2) gave exp(-entropy).
    perplexity = 2. ** (entropy / tf.math.log(2.))  # [batch*H*W]
    mean_perplexity = tf.reduce_mean(perplexity)  # scalar

    if self.step % 2 == 0:
        probs = tf.nn.softmax(logits, axis=-1)  # [batch*H*W, num_embedding]
        probs -= tf.reduce_min(probs)
        probs /= tf.reduce_max(probs)
        probs = tf.reshape(probs, [batch, H * W, self.num_embedding])[0]  # [H*W, K]
        tf.summary.image('logits', probs[None, :, :, None], step=self.step)
        tf.summary.scalar('perplexity', mean_perplexity, step=self.step)
        tf.summary.scalar('var_exp', tf.reduce_mean(var_exp), step=self.step)
        tf.summary.scalar('kl_div', tf.reduce_mean(kl_div), step=self.step)

    return dict(loss=loss,
                metrics=dict(var_exp=var_exp,
                             kl_div=kl_div,
                             mean_perplexity=mean_perplexity))
def _build(self, img, **kwargs) -> dict:
    """Compute the beta-weighted negative ELBO of the discrete VAE and log summaries.

    Args:
        img: [batch, H, W, num_channels]
        **kwargs: unused.

    Returns:
        dict(loss=scalar, metrics=dict(var_exp=[batch], kl_div=[batch],
                                       mean_perplexity=scalar))
    """
    latent_logits = self.compute_logits(img)  # [batch, H, W, num_embeddings]
    # [num_samples, batch, H, W, num_embeddings] x2, [num_samples, batch, H, W, embedding_size]
    log_token_samples_onehot, token_samples_onehot, latent_tokens = self.sample_latent(
        latent_logits, self.temperature, self.num_token_samples)
    # [num_samples, batch, H', W', C] each
    mu, logb = self.compute_likelihood_parameters(latent_tokens)
    log_likelihood = self.log_likelihood(img, mu, logb)  # [num_samples, batch]
    kl_term = self.kl_term(latent_logits, log_token_samples_onehot)  # [num_samples, batch]

    var_exp = tf.reduce_mean(log_likelihood, axis=0)  # [batch]
    kl_div = tf.reduce_mean(kl_term, axis=0)  # [batch]
    elbo = var_exp - self.beta * kl_div  # [batch]
    loss = -tf.reduce_mean(elbo)  # scalar

    # latent_logits are treated as log-probabilities: entropy = -sum p log p.
    entropy = -tf.reduce_sum(tf.math.exp(latent_logits) * latent_logits, axis=[-1])  # [batch, H, W]
    perplexity = 2. ** (entropy / tf.math.log(2.))  # [batch, H, W]
    mean_perplexity = tf.reduce_mean(perplexity)  # scalar

    # Log once every `log_every` calls. The former test `log_counter % N` was
    # truthy on every call EXCEPT multiples of N, and never fired when N == 1.
    # max(1, .) avoids a zero modulus when the leading dim exceeds 128.
    # NOTE(review): mu's leading dim is num_samples, not batch -- confirm which
    # was intended here.
    log_every = max(1, int(128 / mu.get_shape()[0]))
    if self.log_counter % log_every == 0:

        def _unit_range(x):
            """Min-max rescale x over all its elements to [0, 1]; constant x -> 0."""
            lo = tf.reduce_min(x)
            hi = tf.reduce_max(x)
            return tf.clip_by_value(tf.math.divide_no_nan(x - lo, hi - lo), 0., 1.)

        tf.summary.scalar('perplexity', mean_perplexity, step=self.step)
        tf.summary.scalar('var_exp', tf.reduce_mean(var_exp), step=self.step)
        tf.summary.scalar('kl_div', tf.reduce_mean(kl_div), step=self.step)
        tf.summary.scalar('temperature', self.temperature, step=self.step)
        tf.summary.scalar('beta', self.beta, step=self.step)

        mu_first = mu[0]  # [batch, H', W', C]
        for i in range(self.num_channels):
            tf.summary.image(f'image_predict[{i}]',
                             _unit_range(mu_first[..., i:i + 1]), step=self.step)
            tf.summary.image(f'image_actual[{i}]',
                             _unit_range(img[..., i:i + 1]), step=self.step)

        batch, H, W, _ = get_shape(latent_logits)
        # Rescale each position's logits to [0, 1] (constant rows -> 0).
        logits_img = latent_logits - tf.reduce_min(latent_logits, axis=-1, keepdims=True)
        logits_img = tf.math.divide_no_nan(
            logits_img, tf.reduce_max(logits_img, axis=-1, keepdims=True))
        logits_img = tf.reshape(logits_img, [batch, H * W, self.num_embedding, 1])
        tf.summary.image('latent_logits', logits_img, step=self.step)

        token_sample_onehot = tf.reshape(
            token_samples_onehot[0],  # [batch, H, W, num_embeddings]
            [batch, H * W, self.num_embedding, 1])
        tf.summary.image('latent_samples_onehot', token_sample_onehot, step=self.step)

    return dict(loss=loss,
                metrics=dict(var_exp=var_exp,
                             kl_div=kl_div,
                             mean_perplexity=mean_perplexity))
def _build(self, graphs, temperature):
    """Autoregressively generate `num_output` token nodes for each input graph.

    Setup:
    - A learned block of starting nodes (`self.starting_node_variable`) is
      appended to every graph.
    - The enlarged graph is connected with `autoregressive_connect_graph_dynamic`.

    Then, for output_token_idx = 0 .. num_output - 1:
    - The graph passes through self-attention, an edge block and a node block.
    - The last `num_output` nodes are read as token logits and a relaxed
      one-hot sample is drawn.
    - Only the node at `output_token_idx` is replaced by its sampled embedding.
      The other nodes are reset to their pre-step values; edges and globals
      keep their updated values.

    Args:
        graphs: batched GraphsTuple, nodes [n_graphs * num_input, input_embedding_size].
        temperature: scalar > 0, temperature of the RelaxedOneHotCategorical.

    Returns:
        token_nodes: [n_graphs, num_output, embedding_size], final output nodes.
        kl_div: [n_graphs]. For each step, the sum over tokens of
            onehot * normalised logits at that step's position, accumulated
            over all steps.
        token_3d_samples_onehot: [n_graphs, num_output, num_embedding], from the last step.
        basis_weights: [n_graphs, num_output]. Returned only when
            `self.do_basis_weight`; it is a cumulative product of sigmoids, so it
            is non-increasing along the last axis.
        logits_3d: [n_graphs, num_output, num_embedding], normalised logits of the last step.
    """
    # Give graphs edges and a new node dimension (linear transformation).
    graphs = self.projection_node_block(graphs)  # nodes [n_graphs * num_input, embedding_size]
    batched_graphs = graph_batch_reshape(graphs)  # nodes [n_graphs, num_input, embedding_size]
    [n_graphs, n_node_per_graph_before_concat, _] = get_shape(batched_graphs.nodes)
    n_node_total = n_node_per_graph_before_concat + self.num_output

    concat_nodes = tf.concat(
        [
            batched_graphs.nodes,
            tf.tile(self.starting_node_variable[None, :], [n_graphs, 1, 1])
        ],
        axis=-2)  # [n_graphs, num_input + num_output, embedding_size]
    batched_graphs = batched_graphs.replace(
        nodes=concat_nodes,
        globals=tf.tile(self.starting_global_variable[None, :], [n_graphs, 1]),
        n_node=tf.fill([n_graphs], n_node_total))
    # [n_graphs * (num_input + num_output), embedding_size]
    concat_graphs = graph_unbatch_reshape(batched_graphs)
    concat_graphs = autoregressive_connect_graph_dynamic(concat_graphs)

    # Edge count for a causal graph WITH self edges: n(n-1)/2 + n per graph.
    # TODO: only valid if autoregressive_connect_graph_dynamic keeps self edges.
    n_edge = n_graphs * (n_node_total * (n_node_total - 1) // 2 + n_node_total)
    latent_graphs = concat_graphs.replace(edges=tf.tile(
        tf.constant(self.edge_size * [0.])[None, :], [n_edge, 1]))
    latent_graphs.receivers.set_shape([n_edge])
    latent_graphs.senders.set_shape([n_edge])

    def _core(output_token_idx, latent_graphs, prev_kl_term,
              prev_token_3d_samples_onehot, prev_logits_3d):
        """One autoregressive step.

        The prev_* arguments are loop state that this step does not read.
        """
        batched_latent_graphs = graph_batch_reshape(latent_graphs)
        # [n_graphs, num_input + num_output, embedding_size]
        batched_input_nodes = batched_latent_graphs.nodes

        latent_graphs = self.selfattention_core(latent_graphs)
        latent_graphs = self.edge_block(latent_graphs)
        latent_graphs = self.node_block(latent_graphs)
        batched_latent_graphs = graph_batch_reshape(latent_graphs)

        # [n_graphs, num_output, num_embedding]
        token_3d_logits = batched_latent_graphs.nodes[:, -self.num_output:, :]
        # Standardise per position, then normalise to log-probabilities.
        # NOTE(review): zero std (all-equal logits) gives NaN here.
        token_3d_logits -= tf.reduce_mean(token_3d_logits, axis=-1, keepdims=True)
        token_3d_logits /= tf.math.reduce_std(token_3d_logits, axis=-1, keepdims=True)
        token_3d_logits -= tf.math.reduce_logsumexp(token_3d_logits, axis=-1, keepdims=True)

        token_distribution = tfp.distributions.RelaxedOneHotCategorical(
            temperature, logits=token_3d_logits)
        # [n_graphs, num_output, num_embedding]
        token_3d_samples_onehot = token_distribution.sample(
            (1, ), name='token_samples')[0]
        # [n_graphs, num_output, embedding_dim]
        token_3d_samples = tf.einsum('goe,ed->god', token_3d_samples_onehot,
                                     self.embeddings)

        _mask = tf.range(self.num_output) == output_token_idx  # [num_output]
        mask = tf.concat(
            [tf.zeros(n_node_per_graph_before_concat, dtype=tf.bool), _mask],
            axis=0)  # [num_input + num_output]
        # [n_graphs, num_input + num_output, embedding_size]
        mask = tf.tile(mask[None, :, None], [n_graphs, 1, self.embedding_size])

        kl_term = tf.reduce_sum(token_3d_samples_onehot * token_3d_logits,
                                axis=-1)  # [n_graphs, num_output]
        # Keep only the term for the position sampled in this step.
        kl_term = tf.reduce_sum(tf.cast(_mask, tf.float32) * kl_term, axis=-1)  # [n_graphs]
        kl_term += prev_kl_term

        # Replace only the current output node; every other node reverts to its
        # value from before this step.
        output_nodes = tf.where(
            mask,
            tf.concat([
                tf.zeros([n_graphs, n_node_per_graph_before_concat, self.embedding_size]),
                token_3d_samples
            ], axis=1),
            batched_input_nodes)
        batched_latent_graphs = batched_latent_graphs.replace(nodes=output_nodes)
        latent_graphs = graph_unbatch_reshape(batched_latent_graphs)

        return (output_token_idx + 1, latent_graphs, kl_term,
                token_3d_samples_onehot, token_3d_logits)

    _, latent_graphs, kl_div, token_3d_samples_onehot, logits_3d = tf.while_loop(
        cond=lambda output_token_idx, state, _, __, ___: output_token_idx < self.num_output,
        body=_core,
        # Scalar step counter (the loop predicate must be a scalar boolean).
        loop_vars=(tf.constant(0),
                   latent_graphs,
                   tf.zeros((n_graphs, ), dtype=tf.float32),
                   tf.zeros((n_graphs, self.num_output, self.num_embedding), dtype=tf.float32),
                   tf.zeros((n_graphs, self.num_output, self.num_embedding), dtype=tf.float32)))

    batched_latent_graphs = graph_batch_reshape(latent_graphs)
    token_nodes = batched_latent_graphs.nodes[:, -self.num_output:, :]

    if self.do_basis_weight:
        # Weights for how much each basis function contributes, forcing later
        # ones to contribute less. The modules get the UNBATCHED graph (the same
        # layout the loop's modules receive); their output is batch-reshaped below.
        basis_weight_graphs = self.selfattention_weights(latent_graphs)
        basis_weight_graphs = self.basis_weight_node_block(
            self.basis_weight_edge_block(basis_weight_graphs))
        basis_weight_graphs = graph_batch_reshape(basis_weight_graphs)
        basis_weights = basis_weight_graphs.nodes[:, -self.num_output:, 0]  # [n_graphs, num_output]
        # Cumulative product of sigmoids: weights shrink with component index.
        basis_weights = tf.math.cumprod(tf.nn.sigmoid(basis_weights), axis=-1)
        return token_nodes, kl_div, token_3d_samples_onehot, basis_weights, logits_3d
    return token_nodes, kl_div, token_3d_samples_onehot, logits_3d
def _sample_encoder(self, img):
    """Run the encoder on `img` after pinning its channel dimension.

    The reshape makes the channel count statically equal to `self.num_channels`,
    because conv2d needs the number of channels to be known.

    Args:
        img: [batch, H, W, num_channels]

    Returns:
        The encoder output for `img`.
    """
    batch, height, width, _ = get_shape(img)
    static_channel_img = tf.reshape(img, [batch, height, width, self.num_channels])
    return self.encoder(static_channel_img)
def reconstruct_field(self, field_component_tokens, positions, basis_weights):
    """Evaluate per-component field predictions at `positions`.

    For each (token sample, batch element, component), `self.field_reconstruction`
    is applied at every node to the features [position, component token]. In the
    output, the first `num_properties` channels are scaled by that component's
    basis weight and the remaining `num_properties` channels are left as is.

    Args:
        field_component_tokens: [num_token_samples, batch, num_components, embedding_dim_3d]
        positions: [num_token_samples, batch, n_node, 3]
        basis_weights: [num_token_samples, batch, num_components]

    Returns:
        (mu, log_stddev): the two halves of the last axis, each
        [num_token_samples, batch, num_components, n_node, num_properties].
    """
    n_node = get_shape(positions)[2]

    def _per_batch_element(element_args):
        """tokens [C, E], node positions [n_node, 3], weights [C] -> [C, n_node, 2P]."""
        tokens, node_positions, weights = element_args
        # [num_components, 2 * num_properties]: basis weight on the first half, 1 on the rest.
        scale = tf.concat([
            tf.tile(weights[:, None], [1, self.num_properties]),
            tf.ones([self.num_components, self.num_properties])
        ], axis=-1)

        def _per_component(token):
            # [n_node, 3 + embedding_dim_3d]
            features = tf.concat(
                [node_positions, tf.tile(token[None, :], [n_node, 1])], axis=-1)
            return self.field_reconstruction(features)

        return scale[:, None, :] * tf.vectorized_map(_per_component, tokens)

    def _per_sample(sample_args):
        """One token sample: maps `_per_batch_element` over the batch axis."""
        tokens, node_positions, weights = sample_args
        return tf.vectorized_map(_per_batch_element,
                                 (tokens, node_positions, weights))

    stacked = tf.vectorized_map(
        _per_sample, (field_component_tokens, positions, basis_weights))
    mu, log_stddev = tf.split(stacked, num_or_size_splits=2, axis=-1)
    return mu, log_stddev
def reconstruct_field(self, field_component_tokens, positions):
    """Decode latent voxels and interpolate (order=1) the result at `positions`.

    The decoded N x N x N grid is taken to span, per axis, the bounding box
    [min, max] of each sample's positions.

    Args:
        field_component_tokens: [num_token_samples, batch, num_components, embedding_dim_3d].
            The reshape below requires num_components == voxel_per_dimension**3
            and embedding_dim_3d == self.component_size.
        positions: [num_token_samples, batch, n_node, 3]

    Returns:
        (mu, log_stddev): each [num_token_samples, batch, n_node, num_properties].
    """
    # Infer the sample count from the data rather than assuming
    # self.num_token_samples (identical whenever they agree).
    num_samples, batch, n_node, _ = get_shape(positions)
    side = self.voxel_per_dimension
    # [num_samples * batch, V, V, V, component_size]
    latent_voxels = tf.reshape(
        field_component_tokens,
        [num_samples * batch, side, side, side, self.component_size])
    # [num_samples * batch, N, N, N, num_properties * 2]
    voxels = self.decoder_3d(latent_voxels)
    _, N, _, _, _ = get_shape(voxels)
    flat_positions = tf.reshape(positions, (num_samples * batch, n_node, 3))

    def _interp(x, xp, fp):
        """Piecewise-linear interpolation of fp (given at ascending xp) at x.

        The result is scaled by 0.99999 so the upper edge lands strictly inside
        the grid.
        """
        # Right-hand bracketing index. A lower bound of 1 keeps i - 1 a valid
        # gather index (it was 0 before, which would gather index -1).
        i = tf.clip_by_value(tf.searchsorted(xp, x, side='right'), 1,
                             xp.shape[0] - 1)
        df = tf.gather(fp, i) - tf.gather(fp, i - 1)
        dx = tf.gather(xp, i) - tf.gather(xp, i - 1)
        delta = x - tf.gather(xp, i - 1)
        # divide_no_nan: the unselected branch must not produce NaN/inf when
        # dx == 0, or tf.where propagates NaN gradients.
        f = tf.where(
            dx == 0,
            tf.cast(tf.gather(fp, i), tf.float32),
            tf.cast(tf.gather(fp, i - 1), tf.float32) +
            tf.math.divide_no_nan(delta, dx) * tf.cast(df, tf.float32))
        return f * 0.99999

    def _interpolate_single_batch(args):
        """grid [N, N, N, 2P], node positions [n_node, 3] -> [n_node, 2P]."""
        grid_values, node_positions = args
        pmin = tf.reduce_min(node_positions, axis=0)
        pmax = tf.reduce_max(node_positions, axis=0)
        axes = [tf.linspace(pmin[d], pmax[d], N) for d in range(3)]
        coords = [node_positions[:, d] for d in range(3)]
        fractional_coordinates = [
            _interp(coord, axis, tf.range(N)) for axis, coord in zip(axes, coords)
        ]
        channels_first = tf.transpose(grid_values, (3, 0, 1, 2))  # [2P, N, N, N]
        values = tf.vectorized_map(
            lambda channel: map_coordinates(channel, fractional_coordinates, order=1),
            channels_first)  # [2P, n_node]
        return tf.transpose(values, (1, 0))  # [n_node, 2P]

    values_at_positions = tf.vectorized_map(_interpolate_single_batch,
                                            (voxels, flat_positions))
    values_at_positions = tf.reshape(
        values_at_positions,
        (num_samples, batch, n_node, 2 * self.num_properties))
    mu = values_at_positions[..., :self.num_properties]
    log_stddev = values_at_positions[..., self.num_properties:]
    return mu, log_stddev
def _incrementally_decode(self, token_samples_idx_2d):
    """
    Autoregressively sample 3D tokens conditioned on given 2D tokens.

    The 3D tokens start as zeros and are filled in one position per loop
    iteration. Each iteration rebuilds the full sequence from the current 3D
    tokens and recomputes the prior logits. It samples every 3D position, but
    keeps only the sample at the current position.

    Args:
        token_samples_idx_2d: [batch, H2, W2] integer token indices.

    Returns:
        token_samples_idx_3d: [batch, H3, W3, D3] integer token indices, where
            H3 = W3 = D3 = voxels_per_dimension // shrink_factor.
    """
    idx_dtype = token_samples_idx_2d.dtype
    batch, H2, W2 = get_shape(token_samples_idx_2d)
    voxel_vae = self.discrete_voxel_vae
    side = voxel_vae.voxels_per_dimension // voxel_vae.shrink_factor
    num_2d = H2 * W2
    num_3d = side * side * side

    flat_2d = tf.reshape(token_samples_idx_2d, (batch, num_2d))
    initial_3d = tf.zeros((batch, num_3d), dtype=idx_dtype)

    # Outputs num_2d+1 .. num_2d+num_3d of the shifted input predict the 3D
    # tokens; 3D token ids sit after the image vocabulary.
    first_3d_output = num_2d + 1
    vocab_start = self.discrete_image_vae.num_embedding
    vocab_stop = vocab_start + voxel_vae.num_embedding

    def _step(position, current_3d):
        # [batch, 1 + num_2d + 1 + num_3d + 1]
        full_sequence = self.construct_sequence(flat_2d, current_3d)
        graphs = self.construct_input_graph(full_sequence[:, :-1], num_2d)
        logits = self.compute_logits(graphs)
        # [batch, num_3d, num_embedding_3d]
        logits_3d = logits[:, first_3d_output:first_3d_output + num_3d,
                           vocab_start:vocab_stop]
        proposal_dist = tfp.distributions.Categorical(logits=logits_3d,
                                                      dtype=idx_dtype)
        proposal = proposal_dist.sample(1)[0]  # [batch, num_3d]
        is_current = tf.range(num_3d) == position  # [num_3d]
        updated_3d = tf.where(is_current[None, :], proposal, current_3d)
        return (position + 1, updated_3d)

    def _keep_going(position, _):
        return position < num_3d

    _, sampled_3d = tf.while_loop(
        cond=_keep_going,
        body=_step,
        loop_vars=(tf.convert_to_tensor(0), initial_3d))
    return tf.reshape(sampled_3d, (batch, side, side, side))
def initialize_positional_encodings(self, nodes):
    """
    Create the trainable positional-encoding table.

    Args:
        nodes: rank-3 tensor; its middle dimension gives the number of rows.

    Sets:
        self.positional_encodings: tf.Variable of shape
            [nodes.shape[1], self.embedding_dim], initialised from a
            truncated normal.
    """
    _, num_rows, _ = get_shape(nodes)
    initial_table = tf.random.truncated_normal((num_rows, self.embedding_dim))
    self.positional_encodings = tf.Variable(initial_value=initial_table,
                                            name='positional_encodings')
def _im_to_components(self, im, positions, temperature):
    '''
    Encode an image into 2D tokens, sample 3D tokens from the autoregressive
    prior, and evaluate the reconstructed field at `positions`.

    Args:
        im: [batch=1, H, W, 2*C]
        positions: [num_positions, 3]; a tensor already broadcast to
            [num_token_samples, batch, num_positions, 3] is also accepted.
        temperature: scalar

    Returns:
        mu, log_stddev: each [num_positions, num_properties], taken from the
            first token sample and the first batch element.
    '''
    latent_logits = self.encoder_2d(im)  # [batch, H, W, num_embedding]
    [_, H, W, num_embedding] = get_shape(latent_logits)
    # NOTE(review): this reshape uses self.batch, so it presumably requires
    # self.batch to equal im's batch size (documented as 1).
    latent_logits = tf.reshape(
        latent_logits, [self.batch, H * W, num_embedding])  # [batch, H*W, num_embedding]
    # Normalise to log-probabilities (log-softmax over the embedding axis).
    latent_logits -= tf.math.reduce_logsumexp(latent_logits, axis=-1,
                                              keepdims=True)

    # token_samples: [num_token_samples, batch, H*W, embedding_dim]
    _, token_samples = self.sample_latent_2d(latent_logits, temperature,
                                             self.num_token_samples)
    [_, _, _, embedding_dim] = get_shape(token_samples)
    n_graphs = self.num_token_samples * self.batch
    # [num_token_samples * batch * H * W, embedding_dim]
    token_samples = tf.reshape(
        token_samples, [n_graphs * self.n_node_per_graph, embedding_dim])

    # One edgeless graph per (sample, batch) element.
    token_graphs = GraphsTuple(
        nodes=token_samples,
        edges=None,
        globals=None,
        senders=None,
        receivers=None,
        n_node=tf.constant(n_graphs * [self.n_node_per_graph],
                           dtype=tf.int32),
        n_edge=tf.constant(n_graphs * [0], dtype=tf.int32))

    # Only the sampled 3D tokens are needed here.
    tokens_3d, _, _, _ = self.autoregressive_prior(token_graphs, temperature)
    # [num_token_samples, batch, num_components, component_size]
    tokens_3d = tf.reshape(tokens_3d, [
        self.num_token_samples, self.batch, self.num_components,
        self.component_size
    ])

    # reconstruct_field expects [num_token_samples, batch, n_node, 3]; broadcast
    # a flat list of positions to every sample and batch element.
    positions = tf.convert_to_tensor(positions)
    if positions.shape.ndims == 2:
        positions = tf.tile(positions[None, None],
                            [self.num_token_samples, self.batch, 1, 1])

    # Each [num_token_samples, batch, num_positions, num_properties].
    mu, log_stddev = self.reconstruct_field(tokens_3d, positions)
    # Take off the fake sample and batch dims.
    return mu[0, 0, :, :], log_stddev[0, 0, :, :]
def _build(self, voxels, images):
    """
    Compute the autoregressive-prior loss for paired voxel/image inputs.

    Token indices are sampled from the image (2D) and voxel (3D) VAE
    posteriors and packed into one sequence by construct_sequence. The prior
    then scores that sequence shifted by one. Only the predictions of the 3D
    tokens enter the cross-entropy.

    Args:
        voxels: input to self.discrete_voxel_vae.compute_logits.
        images: input to self.discrete_image_vae.compute_logits.

    Returns:
        dict(loss=scalar)
    """
    idx_dtype = tf.int32
    num_embedding_2d = self.discrete_image_vae.num_embedding
    num_embedding_3d = self.discrete_voxel_vae.num_embedding

    latent_logits_2d = self.discrete_image_vae.compute_logits(images)  # [batch, H2, W2, num_embedding_2d]
    latent_logits_3d = self.discrete_voxel_vae.compute_logits(voxels)  # [batch, H3, W3, D3, num_embedding_3d]
    batch, H2, W2, _ = get_shape(latent_logits_2d)
    batch, H3, W3, D3, _ = get_shape(latent_logits_3d)
    N2 = H2 * W2
    N3 = H3 * W3 * D3
    G = self.num_token_samples * batch

    latent_logits_2d = tf.reshape(latent_logits_2d, (batch, N2, num_embedding_2d))
    latent_logits_3d = tf.reshape(latent_logits_3d, (batch, N3, num_embedding_3d))

    q_dist_2d = tfp.distributions.Categorical(logits=latent_logits_2d, dtype=idx_dtype)
    token_samples_idx_2d = q_dist_2d.sample(self.num_token_samples)  # [num_samples, batch, N2]
    token_samples_idx_2d = tf.reshape(token_samples_idx_2d, (G, N2))

    q_dist_3d = tfp.distributions.Categorical(logits=latent_logits_3d, dtype=idx_dtype)
    token_samples_idx_3d = q_dist_3d.sample(self.num_token_samples)  # [num_samples, batch, N3]
    token_samples_idx_3d = tf.reshape(token_samples_idx_3d, (G, N3))

    entropy_2d = tf.reduce_sum(q_dist_2d.entropy(), axis=-1)  # [batch]
    entropy_3d = tf.reduce_sum(q_dist_3d.entropy(), axis=-1)  # [batch]
    entropy = entropy_3d + entropy_2d  # [batch]

    # sequence: [G, 1 + N2 + 1 + N3 + 1]. The prior sees all but the last
    # token and is scored against the sequence shifted left by one.
    sequence = self.construct_sequence(token_samples_idx_2d, token_samples_idx_3d)
    input_sequence = sequence[:, :-1]
    input_graphs = self.construct_input_graph(input_sequence, N2)
    latent_logits = self.compute_logits(input_graphs)
    prior_dist = tfp.distributions.Categorical(logits=latent_logits, dtype=idx_dtype)
    output_sequence = sequence[:, 1:]
    cross_entropy = -prior_dist.log_prob(output_sequence)  # [G, N2 + 1 + N3 + 1]
    # Keep only the predictions of the 3D tokens: drop the 2D-token
    # predictions, the delimiter prediction (index N2) and the EOS prediction.
    cross_entropy = cross_entropy[:, N2 + 1:-1]
    cross_entropy = tf.reshape(tf.reduce_sum(cross_entropy, axis=-1),
                               (self.num_token_samples, batch))  # [num_samples, batch]

    # NOTE(review): KL(q || p) = cross_entropy - entropy, but this adds the
    # entropy, and it includes the 2D entropy while the cross-entropy covers
    # only the 3D tokens. Confirm the intended objective.
    kl_term = cross_entropy + entropy  # [num_samples, batch]
    kl_div = tf.reduce_mean(kl_term)  # scalar
    elbo = -kl_div  # scalar
    loss = -elbo  # scalar

    # max(1, ...) avoids a zero interval (ZeroDivisionError in `%`) when
    # num_token_samples > 128.
    log_interval = max(1, int(128 / kl_term.get_shape()[0]))
    if self.log_counter % log_interval == 0:
        # Outputs 0..N2-1 predict the 2D tokens; outputs N2+1..N2+N3 predict
        # the 3D tokens, whose ids follow the image vocabulary.
        prior_latent_logits_2d = tf.reshape(
            latent_logits[:, :N2, :num_embedding_2d],
            (self.num_token_samples, batch, H2, W2, num_embedding_2d))
        prior_latent_logits_3d = tf.reshape(
            latent_logits[:, N2 + 1:N2 + 1 + N3,
                          num_embedding_2d:num_embedding_2d + num_embedding_3d],
            (self.num_token_samples, batch, H3, W3, D3, num_embedding_3d))
        latent_logits_2d = tf.reshape(latent_logits_2d, (batch, H2, W2, num_embedding_2d))
        latent_logits_3d = tf.reshape(latent_logits_3d, (batch, H3, W3, D3, num_embedding_3d))
        self.write_summary(images, voxels, latent_logits_2d, latent_logits_3d,
                           prior_latent_logits_2d[0], prior_latent_logits_3d[0])
    return dict(loss=loss)
def _build(self, graphs, imgs, **kwargs) -> dict:
    """
    Compute the ELBO loss of the image-to-field model.

    Args:
        graphs: object whose `nodes` are [batch, n_node_per_graph, 3 + num_properties]
            (three position columns followed by the properties).
        imgs: [batch, H', W', C]
        **kwargs: unused.

    Returns:
        dict(loss=scalar, metrics=dict(var_exp=[batch], kl_term=[batch],
             mean_perplexity=scalar))
    """
    latent_logits = self.encoder_2d(imgs)  # [batch, H, W, num_embedding]
    [_, H, W, num_embedding] = get_shape(latent_logits)
    latent_logits = tf.reshape(
        latent_logits, [self.batch, H * W, num_embedding])  # [batch, H*W, num_embedding]
    # Standardise each token's logits, then normalise to log-probabilities.
    latent_logits -= tf.reduce_mean(latent_logits, axis=-1, keepdims=True)
    latent_logits /= tf.math.reduce_std(latent_logits, axis=-1, keepdims=True)
    latent_logits -= tf.math.reduce_logsumexp(latent_logits, axis=-1, keepdims=True)

    # [num_token_samples, batch, H*W, num_embedding], [num_token_samples, batch, H*W, embedding_dim]
    token_samples_onehot, token_samples = self.sample_latent_2d(
        latent_logits, self.temperature, self.num_token_samples)
    [_, _, _, embedding_dim] = get_shape(token_samples)
    n_graphs = self.num_token_samples * self.batch
    token_samples = tf.reshape(
        token_samples, [n_graphs * self.n_node_per_graph, embedding_dim
                        ])  # [num_token_samples * batch * H * W, embedding_dim]

    # One edgeless graph per (sample, batch) element.
    token_graphs = GraphsTuple(
        nodes=token_samples,
        edges=None,
        globals=None,
        senders=None,
        receivers=None,
        n_node=tf.constant(n_graphs * [self.n_node_per_graph], dtype=tf.int32),
        n_edge=tf.constant(n_graphs * [0], dtype=tf.int32))

    tokens_3d, kl_div, token_3d_samples_onehot, logits_3d = self.autoregressive_prior(
        token_graphs, self.temperature)
    tokens_3d = tf.reshape(tokens_3d, [
        self.num_token_samples, self.batch, self.num_components,
        self.component_size
    ])  # [num_token_samples, batch, num_components, component_size]
    kl_div = tf.reshape(kl_div, [self.num_token_samples, self.batch])  # [num_token_samples, batch]

    # [num_token_samples, batch, n_node_per_graph, 3+num_properties]
    properties = tf.tile(graphs.nodes[None, ...], [self.num_token_samples, 1, 1, 1])
    # field_properties: [num_token_samples, batch, n_node_per_graph, num_properties]
    # log_likelihood: [num_token_samples, batch]
    field_properties, log_likelihood = self.log_likelihood(tokens_3d, properties)
    field_properties = tf.reduce_mean(field_properties, axis=0)  # [batch, n_node_per_graph, num_properties]
    var_exp = tf.reduce_mean(log_likelihood, axis=0)  # [batch]
    kl_div = tf.reduce_mean(kl_div, axis=0)  # [batch]
    elbo = tf.reduce_mean(var_exp - self.beta * kl_div)  # scalar

    # latent_logits are log-probabilities, so this is the entropy (nats) of
    # each token's categorical distribution.
    entropy = -tf.reduce_sum(latent_logits * tf.math.exp(latent_logits), axis=-1)  # [batch, H*W]
    # Perplexity = 2**(entropy in bits) = exp(entropy) >= 1.
    perplexity = 2.**(entropy / tf.math.log(2.))
    mean_perplexity = tf.reduce_mean(perplexity)

    loss = -elbo  # maximize ELBO so minimize -ELBO
    [_, _, _, num_channels] = get_shape(imgs)

    if self.step % 10 == 0:
        tf.summary.scalar('perplexity', mean_perplexity, step=self.step)
        tf.summary.scalar('var_exp', tf.reduce_mean(var_exp), step=self.step)
        tf.summary.scalar('kl_div', tf.reduce_mean(kl_div), step=self.step)
        tf.summary.scalar('temperature', self.temperature, step=self.step)

        # Rescale to [0, 1] for image summaries.
        token_3d_samples_onehot -= tf.reduce_min(token_3d_samples_onehot)
        token_3d_samples_onehot /= tf.reduce_max(token_3d_samples_onehot)
        tf.summary.image('token_3d_samples_onehot', token_3d_samples_onehot[..., None], step=self.step)

        logits_3d -= tf.reduce_min(logits_3d)
        logits_3d /= tf.reduce_max(logits_3d)
        tf.summary.image('logits_3d', logits_3d[..., None], step=self.step)

        token_samples_onehot = token_samples_onehot[0]
        token_samples_onehot -= tf.reduce_min(token_samples_onehot)
        token_samples_onehot /= tf.reduce_max(token_samples_onehot)
        tf.summary.image('token_samples_onehot', token_samples_onehot[..., None], step=self.step)

    if self.step % 50 == 0:
        input_properties = graphs.nodes[0]
        reconstructed_properties = field_properties[0]
        # Swap the first two position columns for the histogram axes.
        pos = tf.reverse(input_properties[:, :2], [1])
        for i in range(num_channels):
            img_i = imgs[0, ..., i][None, ..., None]
            img_i = (img_i - tf.reduce_min(img_i)) / (
                tf.reduce_max(img_i) - tf.reduce_min(img_i))
            tf.summary.image(f'img_before_autoencoder_{i}', img_i, step=self.step)
        for i in range(self.num_properties):
            image_before, _ = histogramdd(pos, bins=64, weights=input_properties[:, 3 + i])
            image_before -= tf.reduce_min(image_before)
            image_before /= tf.reduce_max(image_before)
            tf.summary.image(f"{3+i}_xy_image_before_b", image_before[None, :, :, None], step=self.step)
            tf.summary.scalar(f"properties{3+i}_std_before",
                              tf.math.reduce_std(input_properties[:, 3 + i]),
                              step=self.step)
            image_after, _ = histogramdd(pos, bins=64, weights=reconstructed_properties[:, i])
            image_after -= tf.reduce_min(image_after)
            image_after /= tf.reduce_max(image_after)
            tf.summary.image(f"{3+i}_xy_image_after", image_after[None, :, :, None], step=self.step)
            tf.summary.scalar(f"properties{3+i}_std_after",
                              tf.math.reduce_std(reconstructed_properties[:, i]),
                              step=self.step)

    return dict(loss=loss,
                metrics=dict(var_exp=var_exp,
                             kl_term=kl_div,
                             mean_perplexity=mean_perplexity))
def _build(self, img, **kwargs) -> dict:
    """
    Compute the negative ELBO of the discrete (Gumbel-softmax) image VAE.

    Args:
        img: [batch, H', W', num_channel]
        **kwargs: unused.

    Returns:
        dict(loss=scalar, metrics=dict(var_exp=[batch], kl_div=[batch],
             mean_perplexity=scalar))
    """
    encoded_img_logits = self.encoder(img)  # [batch, H, W, num_embedding]
    [batch, H, W, _] = get_shape(encoded_img_logits)

    logits = tf.reshape(
        encoded_img_logits, [batch * H * W, self.num_embedding])  # [batch*H*W, num_embedding]
    # Standardise each token's logits. keepdims=True is required: without it
    # the [batch*H*W] statistics broadcast against the embedding axis.
    logits -= tf.reduce_mean(logits, axis=-1, keepdims=True)
    logits /= tf.math.reduce_std(logits, axis=-1, keepdims=True)
    # Normalise to log-probabilities (log-softmax).
    logits -= tf.math.reduce_logsumexp(logits, axis=-1, keepdims=True)  # [batch*H*W, num_embedding]

    token_distribution = tfp.distributions.RelaxedOneHotCategorical(
        self.temperature, logits=logits)
    token_samples_onehot = token_distribution.sample(
        (self.num_token_samples, ), name='token_samples')  # [S, batch*H*W, num_embedding]
    token_samples_onehot = tf.cast(token_samples_onehot, dtype=tf.float32)

    def _single_decode_part_1(token_sample_onehot):
        """Map one sample's relaxed one-hots to a latent image [batch, H, W, embedding_dim]."""
        # [batch*H*W, num_embedding] @ [num_embedding, embedding_dim]
        token_sample = tf.matmul(token_sample_onehot, self.embeddings)  # [batch*H*W, embedding_dim]
        # = z ~ q(z|x)
        return tf.reshape(token_sample, [batch, H, W, self.embedding_dim])

    def _single_decode_part_2(args):
        """Per-sample log-likelihood [batch], KL term [batch] and decoded image."""
        decoded_img, token_sample_onehot = args
        img_mu = decoded_img[..., :self.num_channels]  # [batch, H', W', C]
        img_logb = decoded_img[..., self.num_channels:]
        log_likelihood = self.log_likelihood(img, img_mu, img_logb)  # [batch, H', W']
        log_likelihood = tf.reduce_sum(log_likelihood, axis=[-2, -1])  # [batch]
        sum_selected_logits = tf.math.reduce_sum(token_sample_onehot * logits, axis=-1)  # [batch*H*W]
        sum_selected_logits = tf.reshape(sum_selected_logits, [batch, H, W])
        kl_term = tf.reduce_sum(sum_selected_logits, axis=[-2, -1])  # [batch]
        return log_likelihood, kl_term, decoded_img

    latent_imgs = tf.vectorized_map(
        _single_decode_part_1, token_samples_onehot)  # [S, batch, H, W, embedding_dim]
    # Run the decoder outside the vectorized map on merged sample/batch dims.
    latent_imgs = tf.reshape(
        latent_imgs, [self.num_token_samples * batch, H, W, self.embedding_dim])
    decoded_imgs = self.decoder(latent_imgs)  # [S * batch, H', W', C*2]
    [_, H_2, W_2, _] = get_shape(decoded_imgs)
    decoded_imgs = tf.reshape(
        decoded_imgs,
        [self.num_token_samples, batch, H_2, W_2, 2 * self.num_channels])  # [S, batch, H', W', 2*C]
    log_likelihood_samples, kl_term_samples, decoded_ims = tf.vectorized_map(
        _single_decode_part_2, (decoded_imgs, token_samples_onehot))

    var_exp = tf.reduce_mean(log_likelihood_samples, axis=0)  # [batch]
    kl_div = tf.reduce_mean(kl_term_samples, axis=0)  # [batch]
    elbo = var_exp - kl_div  # [batch]
    loss = -tf.reduce_mean(elbo)  # scalar

    latent_logits = tf.reshape(logits, [batch, H, W, self.num_embedding])
    token_samples_onehot_perp = tf.reshape(
        token_samples_onehot, [self.num_token_samples, batch, H, W, self.num_embedding])
    q_dist = tfp.distributions.RelaxedOneHotCategorical(
        self.temperature, logits=latent_logits)
    log_prob_q = q_dist.log_prob(token_samples_onehot_perp)  # [S, batch, H, W]
    # NOTE(review): RelaxedOneHotCategorical.log_prob is a density, so this is
    # not a Shannon entropy. Summing over [-1, -2, -3] also reduces the batch
    # axis, leaving [S].
    entropy = -tf.reduce_sum(tf.math.exp(log_prob_q) * log_prob_q, axis=[-1, -2, -3])  # [S]
    perplexity = 2.**(entropy / tf.math.log(2.))  # [S]
    mean_perplexity = tf.reduce_mean(perplexity)  # scalar

    if self.step % 10 == 0:
        for i in range(self.num_channels):
            img_mu_0 = tf.reduce_mean(decoded_ims, axis=0)[..., i][..., None]
            img_mu_0 -= tf.reduce_min(img_mu_0)
            img_mu_0 /= tf.reduce_max(img_mu_0)
            tf.summary.image(f'mu_{i}', img_mu_0, step=self.step)
            img_i = img[..., i][..., None]
            img_i = (img_i - tf.reduce_min(img_i)) / (
                tf.reduce_max(img_i) - tf.reduce_min(img_i))
            tf.summary.image(f'img_before_autoencoder_{i}', img_i, step=self.step)

        # Rescale to [0, 1] for image summaries (the loss is already computed).
        logits -= tf.reduce_min(logits)
        logits /= tf.reduce_max(logits)
        logits = tf.reshape(logits, [batch, H * W, self.num_embedding])  # [batch, H*W, num_embedding]
        tf.summary.image('logits', logits[:, :, :, None], step=self.step)

        token_sample_onehot = token_samples_onehot[0, ...]  # [batch*H*W, num_embedding]
        token_sample_onehot -= tf.reduce_min(token_sample_onehot)
        token_sample_onehot /= tf.reduce_max(token_sample_onehot)
        token_sample_onehot = tf.reshape(
            token_sample_onehot, [batch, H * W, self.num_embedding])  # [batch, H*W, num_embedding]
        tf.summary.image('token_sample_onehot', token_sample_onehot[:, :, :, None], step=self.step)

        latent_img = latent_imgs[:, :, :, 0]
        latent_img -= tf.reduce_min(latent_img)
        latent_img /= tf.reduce_max(latent_img)
        tf.summary.image('latent_im', latent_img[..., None], step=self.step)

        tf.summary.scalar('perplexity', mean_perplexity, step=self.step)
        tf.summary.scalar('var_exp', tf.reduce_mean(var_exp), step=self.step)
        tf.summary.scalar('kl_div', tf.reduce_mean(kl_div), step=self.step)
        tf.summary.scalar('temperature', self.temperature, step=self.step)

    return dict(loss=loss,
                metrics=dict(var_exp=var_exp,
                             kl_div=kl_div,
                             mean_perplexity=mean_perplexity))