def _build(self, latent):
        """Apply one multi-head self-attention block to the nodes of ``latent``.

        Builds fresh sub-modules via ``self.initialize(latent)``, projects the
        nodes to per-head values/keys/queries, layer-normalises keys and
        queries, scales queries by ``1 / sqrt(d_k)``, and attends over
        ``latent``. The heads are flattened, linearly projected, added back to
        the input nodes and layer-normalised (``ln1``), then passed through
        the feed-forward network and ``ln2``.

        Args:
            latent: graph whose ``nodes`` are [n_node, F].

        Returns:
            ``latent`` with its nodes replaced by the block output.
        """
        self.initialize(latent)
        nodes = latent.nodes
        num_nodes, _ = get_shape(nodes)

        # Per-head projections, each [n_node, num_heads, F] (module call order kept).
        values = self.v_linear(nodes)
        keys = self.k_linear(nodes)
        queries = self.q_linear(nodes)

        keys = self.ln_keys(keys)
        queries = self.ln_queries(queries)

        # Scaled dot-product attention: divide queries by sqrt of the key width.
        _, _, key_dim = get_shape(keys)
        queries = queries / tf.math.sqrt(tf.cast(key_dim, queries.dtype))

        attended = self.self_attention(node_values=values,
                                       node_keys=keys,
                                       node_queries=queries,
                                       attention_graph=latent)

        # [n_node, num_heads, F] -> [n_node, num_heads * F]
        concatenated_heads = tf.reshape(
            attended.nodes, (num_nodes, self.num_heads * self.input_node_size))
        after_attention = self.ln1(self.output_linear(concatenated_heads) + nodes)
        # NOTE(review): unlike the attention sub-layer there is no residual
        # around the FFN here (post-LN transformers usually use
        # ln2(FFN(x) + x)); confirm this is intended.
        after_ffn = self.ln2(self.FFN(after_attention))
        return latent.replace(nodes=after_ffn)
Ejemplo n.º 2
0
    def construct_sequence(self, token_samples_idx_2d, token_samples_idx_3d):
        """Concatenate 2D and 3D token ids into one delimited sequence per row.

        Row layout: ``[start, 2d tokens..., del, 3d tokens + offset..., eos]``.
        The special ids are the last three ids below ``self.num_embedding``:
        start = num_embedding - 3, del = num_embedding - 2,
        eos = num_embedding - 1. The 3D ids are shifted right by
        ``self.discrete_image_vae.num_embedding``. Presumably this keeps the 2D
        and 3D vocabularies disjoint; confirm against the embedding table.

        Args:
            token_samples_idx_2d: [G, H2*W2] integer token ids.
            token_samples_idx_3d: [G, H3*W3*D3] token ids with the same dtype.

        Returns:
            sequence: [G, 1 + H2*W2 + 1 + H3*W3*D3 + 1]
        """
        idx_dtype = token_samples_idx_2d.dtype
        # Only the row count G is needed; the per-modality lengths are unused.
        G, _ = get_shape(token_samples_idx_3d)

        def _special_column(token_id):
            # [G, 1] column filled with a single special-token id.
            return tf.fill((G, 1), tf.constant(token_id, dtype=idx_dtype))

        start_token = _special_column(self.num_embedding - 3)
        del_token = _special_column(self.num_embedding - 2)
        eos_token = _special_column(self.num_embedding - 1)

        return tf.concat([
            start_token,
            token_samples_idx_2d,
            del_token,
            token_samples_idx_3d + self.discrete_image_vae.num_embedding,  # shift to right
            eos_token,
        ], axis=-1)
    def compute_likelihood_parameters(self, latent_tokens):
        """Decode embedded latent tokens into likelihood parameters.

        All samples are folded into the batch axis, decoded once with
        ``self._decoder``, and unfolded again. The decoder's last axis
        (2 * num_channels wide) is then split in half.

        Args:
            latent_tokens: [num_samples, batch, H, W, embedding_dim]

        Returns:
            mu, logb: each [num_samples, batch, H', W', num_channels]. mu is the
            first half of the decoder channels and logb the second half.
        """
        num_samples, batch, height, width, _ = get_shape(latent_tokens)
        folded_tokens = tf.reshape(
            latent_tokens,
            [num_samples * batch, height, width, self.embedding_dim])
        decoded = self._decoder(folded_tokens)  # [num_samples * batch, H', W', 2 * C]
        decoded.set_shape([None, None, None, self.num_channels * 2])
        _, out_height, out_width, _ = get_shape(decoded)
        decoded = tf.reshape(
            decoded,
            [num_samples, batch, out_height, out_width, 2 * self.num_channels])
        split_at = self.num_channels
        mu = decoded[..., :split_at]
        logb = decoded[..., split_at:]
        return mu, logb
Ejemplo n.º 4
0
    def sample_decoder(self, img_logits, temperature, num_samples):
        """Decode ``num_samples`` relaxed token samples and average the outputs.

        Args:
            img_logits: [batch, H, W, num_embedding] token logits.
            temperature: temperature of the RelaxedOneHotCategorical.
            num_samples: number of samples S to draw per position.

        Returns:
            Mean over the S decoder outputs. Per the original comments this is
            [batch, H', W', C*2].
        """
        batch, height, width, _ = get_shape(img_logits)

        flat_logits = tf.reshape(img_logits,
                                 [batch * height * width, self.num_embedding])
        # Normalise each row to log-probabilities by subtracting its log-sum-exp.
        flat_logits -= tf.math.reduce_logsumexp(flat_logits, axis=-1, keepdims=True)

        relaxed = tfp.distributions.RelaxedOneHotCategorical(temperature,
                                                             logits=flat_logits)
        samples = relaxed.sample((num_samples, ),
                                 name='token_samples')  # [S, batch*H*W, num_embedding]

        def _decode_one(sample_onehot):
            # [batch*H*W, num_embedding] @ [num_embedding, embedding_dim]
            embedded = tf.matmul(sample_onehot, self.embeddings)
            grid = tf.reshape(embedded,
                              [batch, height, width, self.embedding_dim])
            return self.decoder(grid)

        per_sample = tf.vectorized_map(_decode_one, samples)  # [S, ...decoder output]
        return tf.reduce_mean(per_sample, axis=0)
Ejemplo n.º 5
0
    def construct_input_graph(self, input_sequence, N2):
        """Embed ``input_sequence`` and connect it into a batch of graphs.

        Each row becomes one graph. Its nodes are the token embeddings plus
        ``self.positional_encodings``. Edges are added by
        ``connect_graph_dynamic`` with ``edge_connect_rule``:

        - Positions ``< N2 + 1`` (presumably the start token and the 2D
          tokens) are fully connected among themselves, except the
          ``s -> s + 1`` edge.
        - Every position ``>= N2 + 1`` receives from all positions at or
          before it, self-loops included.

        Args:
            input_sequence: [G, N] token ids.
            N2: number of 2D tokens.

        Returns:
            GraphsTuple holding ``G`` graphs with ``N`` nodes each.
        """
        num_graphs, seq_len = get_shape(input_sequence)
        # [G, N, embedding_dim]
        token_embeddings = tf.nn.embedding_lookup(self.embeddings, input_sequence)
        self.initialize_positional_encodings(token_embeddings)
        n_node = tf.fill([num_graphs], seq_len)
        batched_graphs = GraphsTuple(nodes=token_embeddings + self.positional_encodings,
                                     edges=None,
                                     senders=None,
                                     receivers=None,
                                     globals=None,
                                     n_node=n_node,
                                     n_edge=tf.zeros_like(n_node))
        # One node row per token: [G * N, embedding_dim]
        flat_graphs = graph_unbatch_reshape(batched_graphs)

        first_non_2d = N2 + 1  # first position outside the start/2D segment

        def edge_connect_rule(sender, receiver):
            in_2d_segment = (sender < first_non_2d) & (receiver < first_non_2d)
            # Drop the edge into the immediate right neighbour "so it doesn't learn copy".
            # NOTE(review): which neighbour this hides depends on
            # connect_graph_dynamic's sender/receiver convention; confirm.
            complete_2d = in_2d_segment & (sender + 1 != receiver)
            # Causal edges into every position past the 2D segment, self-loops included.
            auto_regressive_3d = (sender <= receiver) & (receiver >= first_non_2d)
            return complete_2d | auto_regressive_3d

        return connect_graph_dynamic(flat_graphs, edge_connect_rule)
    def _sample_decoder(self, img_logits, temperature, num_samples):
        """Sample relaxed tokens from standardised logits and decode every sample.

        Args:
            img_logits: [batch, H, W, num_embedding] token logits.
            temperature: temperature of the RelaxedOneHotCategorical.
            num_samples: number of samples S to draw per position.

        Returns:
            [S, batch, H', W', num_channels]: the first ``num_channels``
            channels of each decoded sample.
        """
        [batch, H, W, _] = get_shape(img_logits)

        logits = tf.reshape(
            img_logits,
            [batch * H * W, self.num_embedding])  # [batch*H*W, num_embedding]
        # Standardise each row. divide_no_nan turns a constant row (std == 0)
        # into all-zero, i.e. uniform, logits instead of NaN.
        logits -= tf.reduce_mean(logits, axis=-1, keepdims=True)
        logits = tf.math.divide_no_nan(
            logits, tf.math.reduce_std(logits, axis=-1, keepdims=True))
        # Normalise to log-probabilities.
        logits -= tf.math.reduce_logsumexp(logits, axis=-1,
                                           keepdims=True)  # [batch*H*W, num_embedding]
        token_distribution = tfp.distributions.RelaxedOneHotCategorical(
            temperature, logits=logits)
        token_samples_onehot = token_distribution.sample(
            (num_samples, ),
            name='token_samples')  # [S, batch*H*W, num_embedding]

        # Embed all samples in one contraction (no per-sample map needed):
        # [S, batch*H*W, num_embedding] @ [num_embedding, embedding_dim]
        token_samples = tf.einsum('sne,ed->snd', token_samples_onehot,
                                  self.embeddings)  # [S, batch*H*W, embedding_dim]
        latent_imgs = tf.reshape(
            token_samples, [num_samples * batch, H, W, self.embedding_dim
                            ])  # [S * batch, H, W, embedding_dim]
        decoded_imgs = self.decoder(latent_imgs)  # [S * batch, H', W', C*2]
        [_, H_2, W_2, _] = get_shape(decoded_imgs)
        decoded_imgs = tf.reshape(
            decoded_imgs,
            [num_samples, batch, H_2, W_2, 2 * self.num_channels
             ])  # [S, batch, H', W', 2 * C]
        return decoded_imgs[..., :self.num_channels]  # [S, batch, H', W', C]
Ejemplo n.º 7
0
    def initialize(self, graphs):
        """Create the attention sub-modules, sized to the node feature width.

        Sets ``self.input_node_size`` to the trailing dimension of
        ``graphs.nodes``. It then builds:

        - value/key/query ``MultiHeadLinear`` projections with
          ``self.num_heads`` heads;
        - a two-layer MLP (``self.FFN``);
        - a linear output layer.

        All of these have output width ``input_node_size``. Calling this again
        replaces the previously created modules (and therefore their
        variables).
        """
        width = get_shape(graphs.nodes)[-1]
        self.input_node_size = width

        def _multi_head(name):
            return MultiHeadLinear(output_size=width, num_heads=self.num_heads, name=name)

        self.v_linear = _multi_head('mhl1')  # values
        self.k_linear = _multi_head('mhl2')  # keys
        self.q_linear = _multi_head('mhl3')  # queries

        # Feed-forward network: two linear layers, no activation after the last.
        self.FFN = snt.nets.MLP([width, width], activate_final=False, name='ffn')
        self.output_linear = snt.Linear(output_size=width, name='output_linear')
    def log_prob_q(self, latent_logits, log_token_samples_onehot):
        """Log-density of log-space relaxed one-hot samples under q(z|x).

        Args:
            latent_logits: [batch, H, W, num_embeddings] (normalised) logits of q.
            log_token_samples_onehot: [num_samples, batch, H, W, num_embeddings]
                samples in log space, as consumed by ExpRelaxedOneHotCategorical.

        Returns:
            [num_samples, batch, H, W] log-probabilities under an
            ``ExpRelaxedOneHotCategorical`` with temperature ``self.temperature``.
        """
        # The sizes are not used. The 4-way unpack only fails fast on a
        # non-rank-4 input (assuming get_shape returns one entry per dimension).
        _, _, _, _ = get_shape(latent_logits)
        posterior = tfp.distributions.ExpRelaxedOneHotCategorical(
            self.temperature, logits=latent_logits)
        return posterior.log_prob(log_token_samples_onehot)
Ejemplo n.º 9
0
    def initialize(self, graphs: GraphsTuple):
        """Build the edge and node blocks of one message-passing step.

        Both blocks use two-layer ReLU MLPs with the final activation enabled.
        Their hidden and output widths equal the node feature width of
        ``graphs``.
        """
        width = get_shape(graphs.nodes)[-1]

        def node_model_fn():
            return snt.nets.MLP([width, width],
                                activate_final=True,
                                activation=tf.nn.relu,
                                name='node_fn')

        def edge_model_fn():
            return snt.nets.MLP([width, width],
                                activate_final=True,
                                activation=tf.nn.relu,
                                name='edge_fn')

        # Edge updates see sender nodes (not receivers), plus edges/globals per config.
        self.edge_block = blocks.EdgeBlock(edge_model_fn,
                                           use_edges=self.use_edges,
                                           use_receiver_nodes=False,
                                           use_sender_nodes=True,
                                           use_globals=self.use_globals)

        # Node updates aggregate received edges only, plus own features and optional globals.
        self.node_block = blocks.NodeBlock(node_model_fn,
                                           use_received_edges=True,
                                           use_sent_edges=False,
                                           use_nodes=True,
                                           use_globals=self.use_globals)
Ejemplo n.º 10
0
    def write_summary(self, images, voxels,
                      latent_logits_2d,
                      latent_logits_3d,
                      prior_latent_logits_2d,
                      prior_latent_logits_3d):
        """Write scalar and image summaries comparing posterior and prior latents.

        Scalars written:
        - KL(posterior || prior) for the 2D and 3D latents, and their sum;
        - mean prior perplexity, exp(entropy);
        - expected log-likelihood of ``images`` / ``voxels`` decoded from one
          prior sample.

        Images written:
        - per-channel min-max scaled predictions and targets;
        - per-position scaled logits and sampled one-hot tokens, for both the
          posterior and the prior.

        Args:
            images: target images, [batch, H', W', C].
            voxels: target voxels, presumably [batch, H', W', D', C].
            latent_logits_2d: posterior logits, [batch, H2, W2, num_embedding].
            latent_logits_3d: posterior logits, [batch, H3, W3, D3, num_embedding].
            prior_latent_logits_2d: prior logits, same shape as latent_logits_2d.
            prior_latent_logits_3d: prior logits, same shape as latent_logits_3d.
        """

        def _one_hot(logits):
            return tfp.distributions.OneHotCategorical(logits=logits, dtype=logits.dtype)

        def _min_max_scale(x):
            # Scale x to [0, 1] using its global min/max. A constant x maps to
            # 0 instead of NaN (divide_no_nan).
            vmin = tf.reduce_min(x)
            vmax = tf.reduce_max(x)
            return tf.clip_by_value(tf.math.divide_no_nan(x - vmin, vmax - vmin), 0., 1.)

        def _row_scale(logits):
            # Per row (last axis): shift so the min is 0, then divide by the
            # shifted max. Constant rows map to 0 instead of NaN.
            shifted = logits - tf.reduce_min(logits, axis=-1, keepdims=True)
            return tf.math.divide_no_nan(shifted,
                                         tf.reduce_max(shifted, axis=-1, keepdims=True))

        dist_2d = _one_hot(latent_logits_2d)
        dist_3d = _one_hot(latent_logits_3d)
        token_samples_onehot_2d = dist_2d.sample(1)[0]
        token_samples_onehot_3d = dist_3d.sample(1)[0]

        dist_2d_prior = _one_hot(prior_latent_logits_2d)
        dist_3d_prior = _one_hot(prior_latent_logits_3d)
        prior_token_samples_onehot_2d = dist_2d_prior.sample(1)[0]
        prior_token_samples_onehot_3d = dist_3d_prior.sample(1)[0]

        # Per-position KL, summed over the spatial axes, averaged over the batch.
        kl_div_2d = tf.reduce_mean(tf.reduce_sum(dist_2d.kl_divergence(dist_2d_prior), axis=[-1, -2]))
        kl_div_3d = tf.reduce_mean(tf.reduce_sum(dist_3d.kl_divergence(dist_3d_prior), axis=[-1, -2, -3]))
        tf.summary.scalar('kl_div_2d', kl_div_2d, step=self.step)
        tf.summary.scalar('kl_div_3d', kl_div_3d, step=self.step)
        tf.summary.scalar('kl_div', kl_div_2d + kl_div_3d, step=self.step)

        # 2 ** (entropy_nats / ln 2) == exp(entropy): perplexity of the prior.
        mean_perplexity_2d = tf.reduce_mean(2. ** (dist_2d_prior.entropy() / tf.math.log(2.)))
        mean_perplexity_3d = tf.reduce_mean(2. ** (dist_3d_prior.entropy() / tf.math.log(2.)))
        tf.summary.scalar('perplexity_2d_prior', mean_perplexity_2d, step=self.step)
        tf.summary.scalar('perplexity_3d_prior', mean_perplexity_3d, step=self.step)

        # Embed the single prior sample (leading sample axis of size 1) and decode it.
        prior_latent_tokens_2d = tf.einsum('sbhwd,de->sbhwe', prior_token_samples_onehot_2d[None],
                                           self.discrete_image_vae.embeddings)
        prior_latent_tokens_3d = tf.einsum('sbhwdn,ne->sbhwde', prior_token_samples_onehot_3d[None],
                                           self.discrete_voxel_vae.embeddings)
        mu_2d, logb_2d = self.discrete_image_vae.compute_likelihood_parameters(
            prior_latent_tokens_2d)  # [num_samples, batch, H', W', C] each
        log_likelihood_2d = self.discrete_image_vae.log_likelihood(images, mu_2d, logb_2d)  # [num_samples, batch]
        var_exp_2d = tf.reduce_mean(log_likelihood_2d)  # scalar
        mu_3d, logb_3d = self.discrete_voxel_vae.compute_likelihood_parameters(
            prior_latent_tokens_3d)  # [num_samples, batch, H', W', D', C] each
        log_likelihood_3d = self.discrete_voxel_vae.log_likelihood(voxels, mu_3d, logb_3d)  # [num_samples, batch]
        var_exp_3d = tf.reduce_mean(log_likelihood_3d)  # scalar
        var_exp = log_likelihood_2d + log_likelihood_3d

        tf.summary.scalar('var_exp_3d', var_exp_3d, step=self.step)
        tf.summary.scalar('var_exp_2d', var_exp_2d, step=self.step)
        tf.summary.scalar('var_exp', tf.reduce_mean(var_exp), step=self.step)

        # Voxels are summed along axis -2 (D') to get 2D projections.
        projected_mu = tf.reduce_sum(mu_3d[0], axis=-2)  # [batch, H', W', C]
        projected_img = tf.reduce_sum(voxels, axis=-2)  # [batch, H', W', C]
        for i in range(self.discrete_voxel_vae.num_channels):
            _projected_mu = _min_max_scale(projected_mu[..., i:i + 1])  # [batch, H', W', 1]
            _projected_img = _min_max_scale(projected_img[..., i:i + 1])  # [batch, H', W', 1]
            tf.summary.image(f'voxels_predict_prior[{i}]', _projected_mu, step=self.step)
            tf.summary.image(f'voxels_actual[{i}]', _projected_img, step=self.step)

        for name, _latent_logits_3d, _tokens_onehot_3d in zip(
                ['', '_prior'],
                [latent_logits_3d, prior_latent_logits_3d],
                [token_samples_onehot_3d, prior_token_samples_onehot_3d]):
            batch, H3, W3, D3, _ = get_shape(_latent_logits_3d)
            image_shape = [batch, H3 * W3 * D3, self.discrete_voxel_vae.num_embedding, 1]
            tf.summary.image(f"latent_logits_3d{name}",
                             tf.reshape(_row_scale(_latent_logits_3d), image_shape),
                             step=self.step)
            tf.summary.image(f'latent_samples_onehot_3d{name}',
                             tf.reshape(_tokens_onehot_3d, image_shape),
                             step=self.step)

        _mu = mu_2d[0]  # [batch, H', W', C]
        _img = images  # [batch, H', W', C]
        for i in range(self.discrete_image_vae.num_channels):
            _projected_mu = _min_max_scale(_mu[..., i:i + 1])  # [batch, H', W', 1]
            _projected_img = _min_max_scale(_img[..., i:i + 1])  # [batch, H', W', 1]
            tf.summary.image(f'image_predict_prior[{i}]', _projected_mu, step=self.step)
            tf.summary.image(f'image_actual[{i}]', _projected_img, step=self.step)

        for name, _latent_logits_2d, _tokens_onehot_2d in zip(
                ['', '_prior'],
                [latent_logits_2d, prior_latent_logits_2d],
                [token_samples_onehot_2d, prior_token_samples_onehot_2d]):
            batch, H2, W2, _ = get_shape(_latent_logits_2d)
            image_shape = [batch, H2 * W2, self.discrete_image_vae.num_embedding, 1]  # [batch, H*W, num_embedding, 1]
            tf.summary.image(f"latent_logits_2d{name}",
                             tf.reshape(_row_scale(_latent_logits_2d), image_shape),
                             step=self.step)
            tf.summary.image(f'latent_samples_onehot_2d{name}',
                             tf.reshape(_tokens_onehot_2d, image_shape),
                             step=self.step)
Ejemplo n.º 11
0
    def _build(self, img, **kwargs) -> dict:
        """Compute the relaxed-categorical VAE loss for a batch of images.

        The method:
        1. Encodes ``img`` to per-position token logits and normalises them to
           log-probabilities.
        2. Draws ``self.num_token_samples`` relaxed one-hot samples.
        3. Decodes each sample and scores ``img`` under the decoder output.
        4. Forms ``elbo = var_exp - kl_div``, where ``kl_div`` is the sum of the
           sample-weighted log-probabilities.

        Args:
            img: [batch, H', W', num_channel]
            **kwargs: unused.

        Returns:
            dict with ``loss`` (scalar, -mean ELBO) and ``metrics`` holding
            ``var_exp`` [batch], ``kl_div`` [batch] and ``mean_perplexity``
            (scalar).
        """
        encoded_img_logits = self.encoder(img)  # [batch, H, W, num_embedding]
        [batch, H, W, _] = get_shape(encoded_img_logits)

        logits = tf.reshape(
            encoded_img_logits,
            [batch * H * W, self.num_embedding])  # [batch*H*W, num_embeddings]
        # Normalise each row to log-probabilities.
        logits -= tf.math.reduce_logsumexp(logits, axis=-1,
                                           keepdims=True)  # [batch*H*W, num_embeddings]

        # NOTE(review): this schedule rises from 0.1 (up to step ~111) towards
        # 1.0 as step grows, the reverse of the usual annealing; confirm.
        temperature = tf.maximum(
            0.1, tf.cast(1. - 0.1 / (self.step / 1000), tf.float32))
        token_distribution = tfp.distributions.RelaxedOneHotCategorical(
            temperature, logits=logits)
        token_samples_onehot = token_distribution.sample(
            (self.num_token_samples, ),
            name='token_samples')  # [S, batch*H*W, num_embeddings]

        def _single_decode(token_sample_onehot):
            # [batch*H*W, num_embeddings] @ [num_embeddings, embedding_dim]
            token_sample = tf.matmul(
                token_sample_onehot,
                self.embeddings)  # [batch*H*W, embedding_dim]  # = z ~ q(z|x)
            latent_img = tf.reshape(token_sample,
                                    [batch, H, W, self.embedding_dim
                                     ])  # [batch, H, W, embedding_dim]
            decoded_img = self.decoder(latent_img)  # [batch, H', W', C*2]
            img_mu = decoded_img[..., :self.num_channels]  # [batch, H', W', C]
            img_logb = decoded_img[..., self.num_channels:]
            log_likelihood = self.log_likelihood(img, img_mu,
                                                 img_logb)  # [batch, H', W', C]
            log_likelihood = tf.reduce_sum(log_likelihood,
                                           axis=[-3, -2, -1])  # [batch]
            sum_selected_logits = tf.math.reduce_sum(token_sample_onehot *
                                                     logits,
                                                     axis=-1)  # [batch*H*W]
            sum_selected_logits = tf.reshape(sum_selected_logits,
                                             [batch, H, W])
            kl_term = tf.reduce_sum(sum_selected_logits, axis=[-2,
                                                               -1])  # [batch]
            return log_likelihood, kl_term, decoded_img

        # [S, batch], [S, batch], [S, batch, H', W', C*2]
        log_likelihood_samples, kl_term_samples, decoded_ims = tf.vectorized_map(
            _single_decode, token_samples_onehot)

        if self.step % 50 == 0:
            img_mu_0 = tf.reduce_mean(decoded_ims,
                                      axis=0)[..., :self.num_channels]
            img_mu_0 -= tf.reduce_min(img_mu_0)
            img_mu_0 /= tf.reduce_max(img_mu_0)
            tf.summary.image('mu', img_mu_0, step=self.step)

            # NOTE(review): takes the channels AFTER num_channels. This assumes
            # img carries extra (smoothed?) channels; with exactly num_channels
            # channels the slice is empty.
            smoothed_img = img[..., self.num_channels:]
            smoothed_img = (smoothed_img - tf.reduce_min(smoothed_img)) / (
                tf.reduce_max(smoothed_img) - tf.reduce_min(smoothed_img))
            tf.summary.image(f'img_before_autoencoder',
                             smoothed_img,
                             step=self.step)

        var_exp = tf.reduce_mean(log_likelihood_samples, axis=0)  # [batch]
        kl_div = tf.reduce_mean(kl_term_samples, axis=0)  # [batch]
        elbo = var_exp - kl_div  # [batch]
        loss = -tf.reduce_mean(elbo)  # scalar

        # Entropy in nats; logits are log-probabilities at this point.
        entropy = -tf.reduce_sum(logits * tf.math.exp(logits),
                                 axis=-1)  # [batch*H*W]
        # Perplexity = 2 ** (entropy in bits) = exp(entropy), in [1, num_embedding].
        # (Negating the exponent would give exp(-entropy), its reciprocal.)
        perplexity = 2.**(entropy / tf.math.log(2.))  # [batch*H*W]
        mean_perplexity = tf.reduce_mean(perplexity)  # scalar

        if self.step % 2 == 0:
            probs = tf.nn.softmax(logits,
                                  axis=-1)  # [batch*H*W, num_embedding]
            probs -= tf.reduce_min(probs)
            probs /= tf.reduce_max(probs)
            probs = tf.reshape(
                probs,
                [batch, H * W, self.num_embedding])[0]  # [H*W, num_embedding]
            tf.summary.image('logits',
                             probs[None, :, :, None],
                             step=self.step)
            tf.summary.scalar('perplexity', mean_perplexity, step=self.step)
            tf.summary.scalar('var_exp',
                              tf.reduce_mean(var_exp),
                              step=self.step)
            tf.summary.scalar('kl_div', tf.reduce_mean(kl_div), step=self.step)

        return dict(loss=loss,
                    metrics=dict(var_exp=var_exp,
                                 kl_div=kl_div,
                                 mean_perplexity=mean_perplexity))
    def _build(self, img, **kwargs) -> dict:
        """Compute the beta-weighted ELBO loss for a batch of images and log summaries.

        Args:
            img: [batch, H, W, num_channels]
            **kwargs: unused.

        Returns:
            dict with ``loss`` (scalar, -mean ELBO) and ``metrics`` holding
            ``var_exp`` [batch], ``kl_div`` [batch] and ``mean_perplexity``
            (scalar).
        """
        latent_logits = self.compute_logits(
            img)  # [batch, H, W, num_embeddings]
        log_token_samples_onehot, token_samples_onehot, latent_tokens = self.sample_latent(
            latent_logits, self.temperature, self.num_token_samples
        )  # [num_samples, batch, H, W, num_embeddings] (x2), [num_samples, batch, H, W, embedding_size]
        mu, logb = self.compute_likelihood_parameters(
            latent_tokens
        )  # [num_samples, batch, H', W', C], [num_samples, batch, H', W', C]
        log_likelihood = self.log_likelihood(img, mu,
                                             logb)  # [num_samples, batch]
        kl_term = self.kl_term(
            latent_logits, log_token_samples_onehot)  # [num_samples, batch]

        var_exp = tf.reduce_mean(log_likelihood, axis=0)  # [batch]
        kl_div = tf.reduce_mean(kl_term, axis=0)  # [batch]
        elbo = var_exp - self.beta * kl_div  # [batch]
        loss = -tf.reduce_mean(elbo)  # scalar

        # latent_logits are treated as normalised log-probabilities here.
        entropy = -tf.reduce_sum(tf.math.exp(latent_logits) * latent_logits,
                                 axis=[-1])  # [batch, H, W]
        perplexity = 2.**(entropy / tf.math.log(2.))  # [batch, H, W], == exp(entropy)
        mean_perplexity = tf.reduce_mean(perplexity)  # scalar

        # Log once every `log_every` counter ticks (`== 0`, not a bare modulo,
        # which would fire on every tick EXCEPT multiples). max(1, ...) avoids
        # a modulo-by-zero when the leading dim of mu exceeds 128.
        # NOTE(review): mu.get_shape()[0] is num_samples, not batch; confirm
        # which one the 128 budget was meant for.
        log_every = max(1, int(128 / mu.get_shape()[0]))
        if self.log_counter % log_every == 0:
            tf.summary.scalar('perplexity', mean_perplexity, step=self.step)
            tf.summary.scalar('var_exp',
                              tf.reduce_mean(var_exp),
                              step=self.step)
            tf.summary.scalar('kl_div', tf.reduce_mean(kl_div), step=self.step)
            tf.summary.scalar('temperature', self.temperature, step=self.step)
            tf.summary.scalar('beta', self.beta, step=self.step)

            _mu = mu[0]  # [batch, H', W', C]
            _img = img  # [batch, H', W', C]
            for i in range(self.num_channels):
                vmin = tf.reduce_min(_mu[..., i])
                vmax = tf.reduce_max(_mu[..., i])
                _projected_mu = (_mu[..., i:i + 1] - vmin) / (
                    vmax - vmin)  # [batch, H', W', 1]
                _projected_mu = tf.clip_by_value(_projected_mu, 0., 1.)
                vmin = tf.reduce_min(_img[..., i])
                vmax = tf.reduce_max(_img[..., i])
                _projected_img = (_img[..., i:i + 1] - vmin) / (
                    vmax - vmin)  # [batch, H', W', 1]
                _projected_img = tf.clip_by_value(_projected_img, 0., 1.)

                tf.summary.image(f'image_predict[{i}]',
                                 _projected_mu,
                                 step=self.step)
                tf.summary.image(f'image_actual[{i}]',
                                 _projected_img,
                                 step=self.step)

            batch, H, W, _ = get_shape(latent_logits)
            # Per-row scaling to [0, 1] for display (tensors are immutable, so
            # latent_logits itself is not modified).
            _latent_logits = latent_logits  # [batch, H, W, num_embeddings]
            _latent_logits -= tf.reduce_min(_latent_logits,
                                            axis=-1,
                                            keepdims=True)
            _latent_logits /= tf.reduce_max(_latent_logits,
                                            axis=-1,
                                            keepdims=True)
            _latent_logits = tf.reshape(_latent_logits,
                                        [batch, H * W, self.num_embedding, 1
                                         ])  # [batch, H*W, num_embedding, 1]
            tf.summary.image('latent_logits', _latent_logits, step=self.step)

            token_sample_onehot = token_samples_onehot[
                0]  # [batch, H, W, num_embeddings]
            token_sample_onehot = tf.reshape(
                token_sample_onehot, [batch, H * W, self.num_embedding, 1
                                      ])  # [batch, H*W, num_embedding, 1]
            tf.summary.image('latent_samples_onehot',
                             token_sample_onehot,
                             step=self.step)

        return dict(loss=loss,
                    metrics=dict(var_exp=var_exp,
                                 kl_div=kl_div,
                                 mean_perplexity=mean_perplexity))
    def _build(self, graphs, temperature):
        """
        Adds another set of nodes to each graph. Autoregressively links all nodes in a graph.

        Args:
            graphs: batched GraphsTuple, node_shape = [n_graphs * num_input, input_embedding_size]
            temperature: scalar > 0
        Returns:
            #todo: give shapes to returns
            token_node
             kl_div
             token_3d_samples_onehot
             basis_weights

        """
        # give graphs edges and new node dimension (linear transformation)
        graphs = self.projection_node_block(
            graphs)  # nodes = [n_graphs * num_input, embedding_size]
        batched_graphs = graph_batch_reshape(
            graphs)  # nodes = [n_graphs, num_input, embedding_size]
        [n_graphs, n_node_per_graph_before_concat,
         _] = get_shape(batched_graphs.nodes)

        concat_nodes = tf.concat(
            [
                batched_graphs.nodes,
                tf.tile(self.starting_node_variable[None, :], [n_graphs, 1, 1])
            ],
            axis=-2)  # [n_graphs, num_input + num_output, embedding_size]
        batched_graphs = batched_graphs.replace(
            nodes=concat_nodes,
            globals=tf.tile(self.starting_global_variable[None, :],
                            [n_graphs, 1]),
            n_node=tf.fill([n_graphs],
                           n_node_per_graph_before_concat + self.num_output))
        concat_graphs = graph_unbatch_reshape(
            batched_graphs
        )  # [n_graphs * (num_input + num_output), embedding_size]

        # nodes, senders, receivers, globals
        concat_graphs = autoregressive_connect_graph_dynamic(
            concat_graphs
        )  # exclude self edges because 3d tokens orginally placeholder?

        # todo: this only works if exclude_self_edges=False
        n_edge = n_graphs * (
            (n_node_per_graph_before_concat + self.num_output) *
            (n_node_per_graph_before_concat + self.num_output - 1) // 2 +
            (n_node_per_graph_before_concat + self.num_output))

        latent_graphs = concat_graphs.replace(edges=tf.tile(
            tf.constant(self.edge_size * [0.])[None, :], [n_edge, 1]))
        latent_graphs.receivers.set_shape([n_edge])
        latent_graphs.senders.set_shape([n_edge])

        def _core(output_token_idx, latent_graphs, prev_kl_term,
                  prev_token_3d_samples_onehot, prev_logits_3d):
            batched_latent_graphs = graph_batch_reshape(latent_graphs)
            batched_input_nodes = batched_latent_graphs.nodes  # [n_graphs, num_input + num_output, embedding_size]

            # todo: use self-attention
            latent_graphs = self.selfattention_core(latent_graphs)
            latent_graphs = self.edge_block(latent_graphs)
            latent_graphs = self.node_block(latent_graphs)

            batched_latent_graphs = graph_batch_reshape(latent_graphs)

            token_3d_logits = batched_latent_graphs.nodes[:, -self.
                                                          num_output:, :]  # n_graphs, num_output, num_embedding

            token_3d_logits -= tf.reduce_mean(token_3d_logits,
                                              axis=-1,
                                              keepdims=True)
            token_3d_logits /= tf.math.reduce_std(token_3d_logits,
                                                  axis=-1,
                                                  keepdims=True)
            reduce_logsumexp = tf.math.reduce_logsumexp(
                token_3d_logits, axis=-1)  # [n_graphs, num_output]
            reduce_logsumexp = tf.tile(
                reduce_logsumexp[..., None],
                [1, 1, self.num_embedding
                 ])  # [ n_graphs, num_output, num_embedding]
            token_3d_logits -= reduce_logsumexp

            token_distribution = tfp.distributions.RelaxedOneHotCategorical(
                temperature, logits=token_3d_logits)
            token_3d_samples_onehot = token_distribution.sample(
                (1, ), name='token_samples'
            )  # [1, n_graphs, num_output, num_embedding]
            token_3d_samples_onehot = token_3d_samples_onehot[
                0]  # [n_graphs, num_output, num_embedding]
            # token_3d_samples_max_index = tf.math.argmax(token_3d_logits, axis=-1, output_type=tf.int32)
            # token_3d_samples_onehot = tf.cast(tf.tile(tf.range(self.num_embedding)[None, None, :], [n_graphs, self.num_output, 1]) ==
            #                                   token_3d_samples_max_index[:,:,None], tf.float32)  # [n_graphs, num_output, num_embedding]
            token_3d_samples = tf.einsum(
                'goe,ed->god', token_3d_samples_onehot,
                self.embeddings)  # [n_graphs, num_ouput, embedding_dim]
            _mask = tf.range(
                self.num_output) == output_token_idx  # [num_output]
            mask = tf.concat([
                tf.zeros(n_node_per_graph_before_concat, dtype=tf.bool), _mask
            ],
                             axis=0)  # num_input + num_output
            mask = tf.tile(
                mask[None, :, None],
                [n_graphs, 1, self.embedding_size
                 ])  # [n_graphs, num_input + num_output, embedding_size]

            kl_term = tf.reduce_sum(
                (token_3d_samples_onehot * token_3d_logits),
                axis=-1)  # [n_graphs, num_output]
            kl_term = tf.reduce_sum(tf.cast(_mask, tf.float32) * kl_term,
                                    axis=-1)  # [n_graphs]
            kl_term += prev_kl_term

            # n_graphs, n_node+num_output, embedding_size
            output_nodes = tf.where(
                mask,
                tf.concat([
                    tf.zeros([
                        n_graphs, n_node_per_graph_before_concat,
                        self.embedding_size
                    ]), token_3d_samples
                ],
                          axis=1), batched_input_nodes)
            batched_latent_graphs = batched_latent_graphs.replace(
                nodes=output_nodes)
            latent_graphs = graph_unbatch_reshape(batched_latent_graphs)

            return (output_token_idx + 1, latent_graphs, kl_term,
                    token_3d_samples_onehot, token_3d_logits)

        _, latent_graphs, kl_div, token_3d_samples_onehot, logits_3d = tf.while_loop(
            cond=lambda output_token_idx, state, _, __, ___: output_token_idx <
            self.num_output,
            body=lambda output_token_idx, state, prev_kl_term,
            prev_token_3d_samples_onehot, prev_logits_3d: _core(
                output_token_idx, state, prev_kl_term,
                prev_token_3d_samples_onehot, prev_logits_3d),
            loop_vars=(tf.constant([0]), latent_graphs,
                       tf.zeros((n_graphs, ), dtype=tf.float32),
                       tf.zeros(
                           (n_graphs, self.num_output, self.num_embedding),
                           dtype=tf.float32),
                       tf.zeros(
                           (n_graphs, self.num_output, self.num_embedding),
                           dtype=tf.float32)))

        latent_graphs = graph_batch_reshape(latent_graphs)
        token_nodes = latent_graphs.nodes[:, -self.num_output:, :]

        if self.do_basis_weight:
            # compute weights for how much each basis function will contribute, forcing later ones to contribute less.
            #todo: use self-attention
            basis_weight_graphs = self.selfattention_weights(latent_graphs)
            basis_weight_graphs = self.basis_weight_node_block(
                self.basis_weight_edge_block(basis_weight_graphs))
            basis_weight_graphs = graph_batch_reshape(basis_weight_graphs)
            #[n_graphs, num_output]
            basis_weights = basis_weight_graphs.nodes[:, -self.num_output:, 0]
            #make the weights shrink with increasing component
            basis_weights = tf.math.cumprod(tf.nn.sigmoid(basis_weights),
                                            axis=-1)
            return token_nodes, kl_div, token_3d_samples_onehot, basis_weights, logits_3d
        else:
            return token_nodes, kl_div, token_3d_samples_onehot, logits_3d
 def _sample_encoder(self, img):
     """Run ``self.encoder`` on ``img`` after pinning its channel count.

     The input is reshaped to [batch, H, W, self.num_channels] so that the
     channel dimension is statically known, which conv2d layers require.
     """
     batch, height, width, _ = get_shape(img)
     # Same data, but with a static channel dimension for conv2d.
     img_with_channels = tf.reshape(
         img, [batch, height, width, self.num_channels])
     return self.encoder(img_with_channels)
Ejemplo n.º 15
0
    def reconstruct_field(self, field_component_tokens, positions,
                          basis_weights):
        """
        Decode per-component field values at the requested positions.

        Every component token is concatenated with every node position and fed
        through ``self.field_reconstruction``. The first ``num_properties``
        output channels of a component are scaled by that component's basis
        weight. The remaining ``num_properties`` channels are multiplied by one,
        i.e. left unchanged. Components are NOT summed here.

        Args:
            field_component_tokens: [num_token_samples, batch, output_size, embedding_dim_3d]
            positions: [num_token_samples, batch, n_node, 3]
            basis_weights: [num_token_samples, batch, output_size]

        Returns:
            (mu, log_stddev): the per-component reconstruction split in half
            along its last axis. Given the shapes above, each is presumably
            [num_token_samples, batch, num_components, n_node, num_properties].
        """
        n_node = get_shape(positions)[2]

        def _per_sample(sample_args):
            """Handle one token sample by vectorizing over the batch axis."""
            sample_tokens, sample_positions, sample_weights = sample_args

            def _per_batch_element(batch_args):
                """Handle one batch element by vectorizing over its components.

                Returns [num_components, n_node, num_properties*2].
                """
                element_tokens, element_positions, element_weights = batch_args
                # [output_size, 2*num_properties]: weight the first (mu) half
                # by the basis weight, leave the second half at one.
                mu_scale = tf.tile(element_weights[:, None],
                                   [1, self.num_properties])
                keep_scale = tf.ones(
                    [self.num_components, self.num_properties])
                channel_scale = tf.concat([mu_scale, keep_scale], axis=-1)

                def _per_component(token):
                    """Decode one component token at every node position."""
                    repeated_token = tf.tile(token[None, :], [n_node, 1])
                    # [n_node, 3 + embedding_dim_3d]
                    features = tf.concat([element_positions, repeated_token],
                                         axis=-1)
                    return self.field_reconstruction(features)

                decoded = tf.vectorized_map(_per_component, element_tokens)
                return channel_scale[:, None, :] * decoded

            return tf.vectorized_map(
                _per_batch_element,
                (sample_tokens, sample_positions, sample_weights))

        per_component = tf.vectorized_map(
            _per_sample, (field_component_tokens, positions, basis_weights))
        mu, log_stddev = tf.split(per_component, num_or_size_splits=2, axis=-1)
        return mu, log_stddev
Ejemplo n.º 16
0
    def reconstruct_field(self, field_component_tokens, positions):
        """
        Decode a latent voxel grid and sample it at the given positions.

        The component tokens are reshaped into a cubic latent grid and decoded
        by ``self.decoder_3d``. For every (sample, batch) pair, each position
        coordinate is mapped to a fractional voxel index. The lattice used runs
        linearly from that pair's per-axis minimum to maximum position over
        indices 0..N-1. The decoded channels are then linearly interpolated
        with ``map_coordinates``.

        Args:
            field_component_tokens: [num_token_samples, batch, num_components, embedding_dim_3d].
                The reshape requires num_components * embedding_dim_3d ==
                voxel_per_dimension**3 * component_size.
            positions: [num_token_samples, batch, n_node, 3]

        Returns:
            (mu, log_stddev), each [num_token_samples, batch, n_node, num_properties]:
            the first and second halves of the decoder's 2*num_properties channels.
        """
        _, batch, n_node, _ = get_shape(positions)
        n_examples = self.num_token_samples * batch
        v = self.voxel_per_dimension

        # [num_token_samples*batch, V, V, V, component_size]
        latent_voxels = tf.reshape(field_component_tokens,
                                   [n_examples, v, v, v, self.component_size])
        # [num_token_samples*batch, N, N, N, num_properties*2]
        decoded_voxels = self.decoder_3d(latent_voxels)
        _, n_grid, _, _, _ = get_shape(decoded_voxels)

        # [num_token_samples*batch, n_node, 3]
        flat_positions = tf.reshape(positions, (n_examples, n_node, 3))

        def _fractional_index(query, knots, knot_values):
            """
            Piecewise-linearly map ``query`` from ``knots`` onto ``knot_values``.

            Queries are located with a right-sided search clipped to the last
            knot. Where two neighbouring knots coincide, the upper knot value is
            used. The result is scaled by 0.99999 (presumably to keep indices
            strictly below the last voxel).
            """
            upper = tf.clip_by_value(
                tf.searchsorted(knots, query, side='right'), 0,
                knots.shape[0] - 1)
            lower = upper - 1
            value_step = (tf.gather(knot_values, upper) -
                          tf.gather(knot_values, lower))
            knot_step = tf.gather(knots, upper) - tf.gather(knots, lower)
            offset = query - tf.gather(knots, lower)
            index = tf.where(
                (knot_step == 0),
                tf.cast(tf.gather(knot_values, upper), tf.float32),
                tf.cast(tf.gather(knot_values, lower), tf.float32) +
                (offset / knot_step) * tf.cast(value_step, tf.float32))
            return index * 0.99999

        def _sample_example(args):
            """
            Interpolate one example's decoded grid onto its positions.

            Args:
                args: (voxels [N,N,N,num_properties*2], positions [n_node, 3])

            Returns: [n_node, num_properties*2]
            """
            example_voxels, example_positions = args
            lower_corner = tf.reduce_min(example_positions, axis=0)
            upper_corner = tf.reduce_max(example_positions, axis=0)
            voxel_indices = tf.range(n_grid)

            fractional_coordinates = [
                _fractional_index(
                    example_positions[:, axis],
                    tf.linspace(lower_corner[axis], upper_corner[axis],
                                n_grid), voxel_indices)
                for axis in range(3)
            ]

            # [num_properties*2, N, N, N]: one channel per map element.
            channels = tf.transpose(example_voxels, (3, 0, 1, 2))
            # [num_properties*2, n_node]
            sampled = tf.vectorized_map(
                lambda channel: map_coordinates(
                    channel, fractional_coordinates, order=1), channels)
            return tf.transpose(sampled, (1, 0))  # [n_node, num_properties*2]

        sampled_values = tf.vectorized_map(_sample_example,
                                           (decoded_voxels, flat_positions))
        sampled_values = tf.reshape(
            sampled_values,
            (self.num_token_samples, batch, n_node, 2 * self.num_properties))

        mu = sampled_values[..., :self.num_properties]
        log_stddev = sampled_values[..., self.num_properties:]
        return mu, log_stddev
Ejemplo n.º 17
0
    def _incrementally_decode(self, token_samples_idx_2d):
        """
        Autoregressively sample 3D tokens conditioned on the given 2D tokens.

        Decoding starts from an all-zero 3D token sequence. Each loop iteration
        rebuilds the full input sequence and runs the prior. It then samples
        every 3D position from the prior's 3D-vocabulary logits, but keeps only
        the sample at the current position. After H3*W3*D3 iterations every
        position has been filled, left to right.

        Args:
            token_samples_idx_2d: [batch, H2, W2] token indices.

        Returns:
            token_samples_idx_3d: [batch, H3, W3, D3], same dtype as the input.
        """
        idx_dtype = token_samples_idx_2d.dtype

        batch, H2, W2 = get_shape(token_samples_idx_2d)
        n_tokens_2d = H2 * W2
        voxel_vae = self.discrete_voxel_vae
        H3 = W3 = D3 = voxel_vae.voxels_per_dimension // voxel_vae.shrink_factor
        n_tokens_3d = H3 * W3 * D3

        # Columns of the joint vocabulary holding the 3D tokens.
        vocab_start_3d = self.discrete_image_vae.num_embedding
        vocab_stop_3d = vocab_start_3d + voxel_vae.num_embedding

        flat_tokens_2d = tf.reshape(token_samples_idx_2d, (batch, n_tokens_2d))
        initial_tokens_3d = tf.zeros((batch, n_tokens_3d), dtype=idx_dtype)
        positions_3d = tf.range(n_tokens_3d)

        def _step(position, tokens_3d):
            """Replace entry ``position`` of tokens_3d [batch, H3*W3*D3] with a prior sample."""
            # [batch, 1 + H2*W2 + 1 + H3*W3*D3 + 1]
            sequence = self.construct_sequence(flat_tokens_2d, tokens_3d)
            # The prior sees all but the last token; logits at step t score
            # sequence[t + 1].
            input_graphs = self.construct_input_graph(sequence[:, :-1],
                                                      n_tokens_2d)
            latent_logits = self.compute_logits(input_graphs)

            # Logits whose targets are the 3D tokens start right after the 2D
            # tokens and the single separating token:
            # [batch, H3*W3*D3, num_embedding_3d]
            logits_3d = latent_logits[:, n_tokens_2d + 1:n_tokens_2d + 1 +
                                      n_tokens_3d, vocab_start_3d:vocab_stop_3d]

            prior = tfp.distributions.Categorical(logits=logits_3d,
                                                  dtype=idx_dtype)
            sampled = prior.sample(1)[0]  # [batch, H3*W3*D3]

            is_current = (positions_3d == position)[None, :]
            updated = tf.where(is_current, sampled, tokens_3d)
            return position + 1, updated

        _, tokens_3d = tf.while_loop(
            cond=lambda position, _: position < n_tokens_3d,
            body=_step,
            loop_vars=(tf.convert_to_tensor(0), initial_tokens_3d))

        return tf.reshape(tokens_3d, (batch, H3, W3, D3))
Ejemplo n.º 18
0
 def initialize_positional_encodings(self, nodes):
     """Create the trainable ``self.positional_encodings`` table.

     Args:
         nodes: rank-3 tensor; its second dimension sets the number of rows.

     The table has shape [n_node, self.embedding_dim] and is initialised from
     a truncated normal distribution.
     """
     (_, n_node, _) = get_shape(nodes)
     initial_table = tf.random.truncated_normal((n_node, self.embedding_dim))
     self.positional_encodings = tf.Variable(initial_value=initial_table,
                                             name='positional_encodings')
Ejemplo n.º 19
0
    def _im_to_components(self, im, positions, temperature):
        '''
        Reconstruct the field at ``positions`` from a single image.

        This follows the same encoding path as the training ``_build``. The image
        is encoded to per-pixel token logits, which are standardized and
        log-softmax normalized, then sampled. The samples are passed through
        ``autoregressive_prior`` and ``reconstruct_field``.

        Args:
            im: [batch=1, H, W, 2*C]
            positions: passed unchanged to ``reconstruct_field``.
                NOTE(review): documented as [num_positions, 3]; confirm it has
                the rank reconstruct_field expects.
            temperature: scalar temperature, passed to ``sample_latent_2d`` and
                ``autoregressive_prior``.

        Returns:
            (mu, log_stddev) for the first token sample and first batch element,
            each [n_node_per_graph, num_properties].
        '''
        latent_logits = self.encoder_2d(im)  # batch, H, W, num_embeddings

        [_, H, W, num_embedding] = get_shape(latent_logits)

        latent_logits = tf.reshape(
            latent_logits,
            [self.batch, H * W, num_embedding])  # [batch, H*W, num_embeddings]
        # Standardize the logits the same way the training ``_build`` does.
        # Dividing by the std acts like a temperature, so skipping it here would
        # sample tokens from a different distribution than the one trained on.
        latent_logits -= tf.reduce_mean(latent_logits, axis=-1, keepdims=True)
        latent_logits /= tf.math.reduce_std(latent_logits,
                                            axis=-1,
                                            keepdims=True)
        # Log-softmax normalization: [batch, H*W, num_embeddings]
        latent_logits -= tf.math.reduce_logsumexp(latent_logits,
                                                  axis=-1,
                                                  keepdims=True)

        _, token_samples = self.sample_latent_2d(
            latent_logits, temperature, self.num_token_samples
        )  # [num_token_samples, batch, H*W, embedding_dim]

        [_, _, _, embedding_dim] = get_shape(token_samples)
        n_graphs = self.num_token_samples * self.batch

        token_samples = tf.reshape(token_samples, [
            n_graphs * self.n_node_per_graph, embedding_dim
        ])  # [num_token_samples * batch * H * W, embedding_dim]

        token_graphs = GraphsTuple(
            nodes=token_samples,
            edges=None,
            globals=None,
            senders=None,
            receivers=None,
            n_node=tf.constant(n_graphs * [self.n_node_per_graph],
                               dtype=tf.int32),
            n_edge=tf.constant(n_graphs * [0], dtype=tf.int32)
        )  # [n_node, embedding_dim], n_node = n_graphs * n_node_per_graph

        # Only the decoded 3D tokens are needed; the KL term, samples and
        # logits returned by the prior are discarded.
        tokens_3d, _, _, _ = self.autoregressive_prior(
            token_graphs, temperature
        )  # [n_graphs, num_output, embedding_dim_3d]
        tokens_3d = tf.reshape(tokens_3d, [
            self.num_token_samples, self.batch, self.num_components,
            self.component_size
        ])  # [num_token_samples, batch, num_output, embedding_dim_3d]
        mu, log_stddev = self.reconstruct_field(
            tokens_3d, positions
        )  # [num_token_samples, batch, n_node_per_graph, num_properties]
        # Take off the fake sample and batch dims.
        return mu[0, 0, :, :], log_stddev[
            0, 0, :, :]  # [n_node_per_graph, num_properties]
Ejemplo n.º 20
0
    def _build(self, voxels, images):
        """
        Compute the autoregressive-prior training loss for images and voxels.

        Tokens are sampled from the 2D and 3D discrete-VAE posteriors and joined
        by ``construct_sequence``. According to its docstring, the sequence has
        length 1 + H2*W2 + 1 + H3*W3*D3 + 1. The sequence is scored by the
        prior with next-token targets. The per-sample KL estimate is the prior's
        cross entropy on the 3D-token targets minus the summed posterior
        entropies: KL(q || p) = E_q[-log p] - H(q).

        Args:
            voxels: input to ``discrete_voxel_vae.compute_logits``.
            images: input to ``discrete_image_vae.compute_logits``.

        Returns:
            dict(loss=scalar), the mean KL estimate over samples and batch.
        """
        idx_dtype = tf.int32

        latent_logits_2d = self.discrete_image_vae.compute_logits(images)  # [batch, H, W, num_embeddings]
        latent_logits_3d = self.discrete_voxel_vae.compute_logits(voxels)  # [batch, H, W, D, num_embeddings]

        batch, H2, W2, _ = get_shape(latent_logits_2d)
        batch, H3, W3, D3, _ = get_shape(latent_logits_3d)
        G = self.num_token_samples * batch

        latent_logits_2d = tf.reshape(latent_logits_2d, (batch, H2 * W2, self.discrete_image_vae.num_embedding))
        latent_logits_3d = tf.reshape(latent_logits_3d, (batch, H3 * W3 * D3, self.discrete_voxel_vae.num_embedding))

        q_dist_2d = tfp.distributions.Categorical(logits=latent_logits_2d, dtype=idx_dtype)
        token_samples_idx_2d = q_dist_2d.sample(self.num_token_samples)  # [num_samples, batch, H2*W2]
        token_samples_idx_2d = tf.reshape(token_samples_idx_2d, (G, H2*W2))

        q_dist_3d = tfp.distributions.Categorical(logits=latent_logits_3d, dtype=idx_dtype)
        token_samples_idx_3d = q_dist_3d.sample(
            self.num_token_samples)  # [num_samples, batch, H3*W3*D3]
        token_samples_idx_3d = tf.reshape(token_samples_idx_3d, (G, H3*W3*D3))

        entropy_2d = tf.reduce_sum(q_dist_2d.entropy(), axis=-1)  # [batch]
        entropy_3d = tf.reduce_sum(q_dist_3d.entropy(), axis=-1)  # [batch]
        entropy = entropy_3d + entropy_2d  # [batch]

        # Create the sequence. The prior sees all but the last token and is
        # scored against the sequence shifted left by one.
        sequence = self.construct_sequence(token_samples_idx_2d, token_samples_idx_3d)
        input_sequence = sequence[:, :-1]
        input_graphs = self.construct_input_graph(input_sequence, H2*W2)
        latent_logits = self.compute_logits(input_graphs)

        prior_dist = tfp.distributions.Categorical(logits=latent_logits, dtype=idx_dtype)
        output_sequence = sequence[:, 1:]
        cross_entropy = -prior_dist.log_prob(output_sequence)  # [num_samples*batch, H2*W2+1+H3*W3*D3+1]
        # Keep only targets that are 3D tokens: drop the 2D targets, the
        # separator target and the final target.
        cross_entropy = cross_entropy[:, H2*W2+1:-1]
        cross_entropy = tf.reshape(tf.reduce_sum(cross_entropy, axis=-1), (self.num_token_samples, batch))  # num_samples,batch
        # KL(q || p) = E_q[-log p] - H(q): the entropy is subtracted, not added.
        # NOTE(review): entropy includes the 2D posterior, while the 2D-token
        # cross entropy is excluded above. Confirm that this is intended.
        kl_term = cross_entropy - entropy  # [num_samples, batch]

        kl_div = tf.reduce_mean(kl_term)  # scalar
        elbo = - kl_div  # scalar

        loss = - elbo  # scalar

        # Guard against a zero modulus when there are more than 128 samples.
        log_every = max(1, int(128 / kl_term.get_shape()[0]))
        if self.log_counter % log_every == 0:
            prior_latent_logits_2d = tf.reshape(latent_logits[:, :H2*W2, :self.discrete_image_vae.num_embedding],
                                                (self.num_token_samples, batch, H2, W2, self.discrete_image_vae.num_embedding))
            prior_latent_logits_3d = tf.reshape(latent_logits[:, H2*W2+1:H2*W2+1+H3*W3*D3,
                                     self.discrete_image_vae.num_embedding:self.discrete_image_vae.num_embedding+self.discrete_voxel_vae.num_embedding],
                                                (self.num_token_samples, batch, H3, W3, D3, self.discrete_voxel_vae.num_embedding))

            latent_logits_2d = tf.reshape(latent_logits_2d,
                                          (batch, H2, W2, self.discrete_image_vae.num_embedding))
            latent_logits_3d = tf.reshape(latent_logits_3d,
                                          (batch, H3, W3, D3, self.discrete_voxel_vae.num_embedding))

            self.write_summary(images, voxels,
                               latent_logits_2d,
                               latent_logits_3d,
                               prior_latent_logits_2d[0],
                               prior_latent_logits_3d[0])

        return dict(loss=loss)
Ejemplo n.º 21
0
    def _build(self, graphs, imgs, **kwargs) -> dict:
        """
        Compute the ELBO loss for reconstructing node properties from images.

        The images are encoded to per-pixel token logits, which are standardized
        and log-softmax normalized. Relaxed token samples are drawn and mapped by
        ``autoregressive_prior`` to 3D component tokens. Those tokens are scored
        against the graph node properties by ``log_likelihood``. Every 10 steps
        scalar and image summaries are written; every 50 steps, input and
        reconstruction histogram images are written as well.

        Args:
            graphs: object with ``nodes`` of shape
                [batch, n_node_per_graph, 3+num_properties].
            imgs: [batch, H', W', C]
            **kwargs: unused.

        Returns:
            dict with ``loss`` (scalar, -ELBO) and ``metrics`` containing
            ``var_exp`` [batch], ``kl_term`` [batch] and ``mean_perplexity``
            (scalar).
        """
        latent_logits = self.encoder_2d(imgs)  # [batch, H, W, num_embeddings]

        [_, H, W, num_embedding] = get_shape(latent_logits)

        latent_logits = tf.reshape(
            latent_logits,
            [self.batch, H * W, num_embedding])  # [batch, H*W, num_embeddings]
        latent_logits -= tf.reduce_mean(latent_logits, axis=-1, keepdims=True)
        latent_logits /= tf.math.reduce_std(latent_logits,
                                            axis=-1,
                                            keepdims=True)
        # Log-softmax normalization: [batch, H*W, num_embeddings]
        latent_logits -= tf.math.reduce_logsumexp(latent_logits,
                                                  axis=-1,
                                                  keepdims=True)

        token_samples_onehot, token_samples = self.sample_latent_2d(
            latent_logits, self.temperature, self.num_token_samples
        )  # [num_token_samples, batch, H*W, num_embedding / embedding_dim]

        [_, _, _, embedding_dim] = get_shape(token_samples)
        n_graphs = self.num_token_samples * self.batch

        token_samples = tf.reshape(
            token_samples,
            [n_graphs * self.n_node_per_graph, embedding_dim
             ])  # [num_token_samples * batch * H * W, embedding_dim]

        token_graphs = GraphsTuple(
            nodes=token_samples,
            edges=None,
            globals=None,
            senders=None,
            receivers=None,
            n_node=tf.constant(n_graphs * [self.n_node_per_graph],
                               dtype=tf.int32),
            n_edge=tf.constant(n_graphs * [0], dtype=tf.int32)
        )  # [n_node, embedding_dim], n_node = n_graphs * n_node_per_graph

        tokens_3d, kl_div, token_3d_samples_onehot, logits_3d = self.autoregressive_prior(
            token_graphs, self.temperature
        )  # [n_graphs, num_output, embedding_dim_3d], [n_graphs], [n_graphs, num_output, num_embedding_3d], [n_graphs, num_output]

        tokens_3d = tf.reshape(tokens_3d, [
            self.num_token_samples, self.batch, self.num_components,
            self.component_size
        ])  # [num_token_samples, batch, num_output, embedding_dim_3d]

        kl_div = tf.reshape(
            kl_div,
            [self.num_token_samples, self.batch])  # [num_token_samples, batch]

        # properties: [num_token_samples, batch, n_node_per_graph, 3+num_properties]
        properties = tf.tile(graphs.nodes[None, ...],
                             [self.num_token_samples, 1, 1, 1])
        field_properties, log_likelihood = self.log_likelihood(
            tokens_3d, properties)
        # field_properties: [num_token_samples, batch, n_node_per_graph, num_properties]
        # log_likelihood: [num_token_samples, batch]

        field_properties = tf.reduce_mean(
            field_properties,
            axis=0)  # [batch, n_node_per_graph, num_properties]
        var_exp = tf.reduce_mean(log_likelihood, axis=0)  # [batch]
        kl_div = tf.reduce_mean(kl_div, axis=0)  # [batch]
        elbo = tf.reduce_mean(var_exp - self.beta * kl_div)  # scalar

        # latent_logits are log-probabilities, so this is the per-position
        # entropy (in nats) of the token posterior: [batch, H*W]
        entropy = -tf.reduce_sum(latent_logits * tf.math.exp(latent_logits),
                                 axis=-1)
        # Perplexity is exp(H) (== 2 ** (H / ln 2)) and is always >= 1.
        perplexity = tf.math.exp(entropy)
        mean_perplexity = tf.reduce_mean(perplexity)

        loss = -elbo  # maximize ELBO so minimize -ELBO

        [_, _, _, num_channels] = get_shape(imgs)

        def _rescale_to_unit(x):
            """Min-max rescale ``x`` to [0, 1] for image summaries."""
            shifted = x - tf.reduce_min(x)
            return shifted / tf.reduce_max(shifted)

        if self.step % 10 == 0:
            tf.summary.scalar('perplexity', mean_perplexity, step=self.step)
            tf.summary.scalar('var_exp',
                              tf.reduce_mean(var_exp),
                              step=self.step)
            tf.summary.scalar('kl_div', tf.reduce_mean(kl_div), step=self.step)
            tf.summary.scalar('temperature', self.temperature, step=self.step)

            token_3d_samples_onehot = _rescale_to_unit(token_3d_samples_onehot)
            tf.summary.image('token_3d_samples_onehot',
                             token_3d_samples_onehot[..., None],
                             step=self.step)

            logits_3d = _rescale_to_unit(logits_3d)
            tf.summary.image('logits_3d', logits_3d[..., None], step=self.step)

            token_samples_onehot = _rescale_to_unit(token_samples_onehot[0])
            tf.summary.image('token_samples_onehot',
                             token_samples_onehot[..., None],
                             step=self.step)

        if self.step % 50 == 0:
            input_properties = graphs.nodes[0]
            reconstructed_properties = field_properties[0]
            pos = tf.reverse(input_properties[:, :2], [1])

            for i in range(num_channels):
                img_i = _rescale_to_unit(imgs[0, ..., i][None, ..., None])
                tf.summary.image(f'img_before_autoencoder_{i}',
                                 img_i,
                                 step=self.step)

            for i in range(self.num_properties):
                image_before, _ = histogramdd(pos,
                                              bins=64,
                                              weights=input_properties[:,
                                                                       3 + i])
                image_before = _rescale_to_unit(image_before)
                tf.summary.image(f"{3+i}_xy_image_before_b",
                                 image_before[None, :, :, None],
                                 step=self.step)
                tf.summary.scalar(f"properties{3+i}_std_before",
                                  tf.math.reduce_std(input_properties[:,
                                                                      3 + i]),
                                  step=self.step)

                image_after, _ = histogramdd(
                    pos, bins=64, weights=reconstructed_properties[:, i])
                image_after = _rescale_to_unit(image_after)
                tf.summary.image(f"{3+i}_xy_image_after",
                                 image_after[None, :, :, None],
                                 step=self.step)
                tf.summary.scalar(f"properties{3+i}_std_after",
                                  tf.math.reduce_std(
                                      reconstructed_properties[:, i]),
                                  step=self.step)

        return dict(loss=loss,
                    metrics=dict(var_exp=var_exp,
                                 kl_term=kl_div,
                                 mean_perplexity=mean_perplexity))
    def _build(self, img, **kwargs) -> dict:
        """

        Args:
            img: [batch, H', W', num_channel]
            **kwargs:

        Returns:

        """
        encoded_img_logits = self.encoder(img)  # [batch, H, W, num_embedding]
        [batch, H, W, _] = get_shape(encoded_img_logits)

        logits = tf.reshape(
            encoded_img_logits,
            [batch * H * W, self.num_embedding])  # [batch*H*W, num_embeddings]
        logits -= tf.reduce_mean(logits, axis=-1)
        logits /= tf.math.reduce_std(logits, axis=-1)
        reduce_logsumexp = tf.math.reduce_logsumexp(logits,
                                                    axis=-1)  # [batch*H*W]
        reduce_logsumexp = tf.tile(
            reduce_logsumexp[:, None],
            [1, self.num_embedding])  # [batch*H*W, num_embedding]
        logits -= reduce_logsumexp  # [batch*H*W, num_embeddings]

        # temperature = tf.maximum(0.1, tf.cast(10. - 0.1 * (self.step / 1000), tf.float32))
        token_distribution = tfp.distributions.RelaxedOneHotCategorical(
            self.temperature, logits=logits)
        # token_distribution = tfp.distributions.OneHotCategorical(logits=logits)
        token_samples_onehot = token_distribution.sample(
            (self.num_token_samples, ),
            name='token_samples')  # [S, batch*H*W, num_embeddings]
        token_samples_onehot = tf.cast(token_samples_onehot, dtype=tf.float32)

        def _single_decode_part_1(token_sample_onehot):
            #[batch*H*W, num_embeddings] @ [num_embeddings, embedding_dim]
            token_sample = tf.matmul(
                token_sample_onehot,
                self.embeddings)  # [batch*H*W, embedding_dim]  # = z ~ q(z|x)
            latent_img = tf.reshape(token_sample,
                                    [batch, H, W, self.embedding_dim
                                     ])  # [batch, H, W, embedding_dim]
            return latent_img

        def _single_decode_part_2(args):
            """Score one decoded sample: reconstruction log-likelihood and KL term.

            Intended to be mapped over the leading sample axis via
            ``tf.vectorized_map``.

            Args:
                args: pair ``(decoded_img, token_sample_onehot)``. The first
                    ``self.num_channels`` channels of ``decoded_img`` are the
                    predicted means and the remaining channels are passed as the
                    log-scale argument of ``self.log_likelihood``.
                    ``token_sample_onehot`` is the matching code sample, whose
                    rows line up with ``logits`` from the enclosing scope.

            Returns:
                Tuple ``(log_likelihood, kl_term, decoded_img)``. The first two
                are summed over the last two axes (one value per batch element).
                ``decoded_img`` is passed through unchanged.
            """
            decoded, onehot = args

            # Split the decoder output into location and log-scale parts.
            mu = decoded[..., :self.num_channels]
            log_scale = decoded[..., self.num_channels:]

            # NOTE(review): assumes self.log_likelihood returns [batch, H', W']
            # (channels already reduced) — confirm against its definition.
            pixel_log_likelihood = self.log_likelihood(img, mu, log_scale)
            image_log_likelihood = tf.reduce_sum(pixel_log_likelihood,
                                                 axis=[-2, -1])

            # Pick out the logit of the sampled code at every position. Since
            # the enclosing method normalises `logits` by their logsumexp, this
            # is log q(z|x) per position.
            # NOTE(review): no log-prior term is added; under a uniform prior
            # that term would be a constant offset — confirm this is intended.
            picked_logits = tf.reshape(
                tf.math.reduce_sum(onehot * logits, axis=-1), [batch, H, W])
            kl_estimate = tf.reduce_sum(picked_logits, axis=[-2, -1])

            return image_log_likelihood, kl_estimate, decoded

        latent_imgs = tf.vectorized_map(
            _single_decode_part_1,
            token_samples_onehot)  # [S, batch, H, W, embedding_dim]

        # decoder outside the vectorized map, first merge batch and sample dimension
        latent_imgs = tf.reshape(
            latent_imgs,
            [self.num_token_samples * batch, H, W, self.embedding_dim
             ])  # [S * batch, H, W, embedding_dim]
        decoded_imgs = self.decoder(latent_imgs)  # [S * batch, H', W', C*2]

        # reshape again for input of the second vectorized map
        [_, H_2, W_2, _] = get_shape(decoded_imgs)
        decoded_imgs = tf.reshape(
            decoded_imgs,
            [self.num_token_samples, batch, H_2, W_2, 2 * self.num_channels
             ])  # [S, batch, H, W, embedding_dim]
        log_likelihood_samples, kl_term_samples, decoded_ims = tf.vectorized_map(
            _single_decode_part_2, (decoded_imgs, token_samples_onehot))

        var_exp = tf.reduce_mean(log_likelihood_samples, axis=0)  # [batch]
        kl_div = tf.reduce_mean(kl_term_samples, axis=0)  # [batch]
        elbo = var_exp - kl_div  # batch
        loss = -tf.reduce_mean(elbo)  # scalar

        latent_logits = tf.reshape(logits, [batch, H, W, self.num_embedding])
        token_samples_onehot_perp = tf.reshape(
            token_samples_onehot,
            [self.num_token_samples, batch, H, W, self.num_embedding])

        q_dist = tfp.distributions.RelaxedOneHotCategorical(
            self.temperature, logits=latent_logits)
        log_prob_q = q_dist.log_prob(
            token_samples_onehot_perp)  # num_samples, batch, H, W
        entropy = -tf.reduce_sum(tf.math.exp(log_prob_q) * log_prob_q,
                                 axis=[-1, -2, -3])  # [S, batch]
        perplexity = 2.**(entropy / tf.math.log(2.))  # [S, batch, H, W]
        mean_perplexity = tf.reduce_mean(perplexity)  # scalar

        # entropy = -tf.reduce_sum(logits * token_samples_onehot, axis=-1)  # [S, batch*H*W]
        # perplexity = 2. ** (-entropy / tf.math.log(2.))  # [S, batch*H*W]
        # mean_perplexity = tf.reduce_mean(perplexity)  # scalar

        if self.step % 10 == 0:
            for i in range(self.num_channels):
                img_mu_0 = tf.reduce_mean(decoded_ims, axis=0)[..., i][...,
                                                                       None]
                img_mu_0 -= tf.reduce_min(img_mu_0)
                img_mu_0 /= tf.reduce_max(img_mu_0)
                tf.summary.image(f'mu_{i}', img_mu_0, step=self.step)

                img_i = img[..., i][..., None]
                img_i = (img_i - tf.reduce_min(img_i)) / (
                    tf.reduce_max(img_i) - tf.reduce_min(img_i))
                tf.summary.image(f'img_before_autoencoder_{i}',
                                 img_i,
                                 step=self.step)

            # logits = tf.nn.softmax(logits, axis=-1)  # [batch*H*W, num_embedding]
            logits -= tf.reduce_min(logits)
            logits /= tf.reduce_max(logits)
            logits = tf.reshape(logits, [batch, H * W, self.num_embedding
                                         ])  # [batch, H*W, num_embedding]
            tf.summary.image('logits', logits[:, :, :, None], step=self.step)

            token_sample_onehot = token_samples_onehot[0, ...]
            token_sample_onehot -= tf.reduce_min(
                token_sample_onehot)  # [batch*H*W, num_embeddings]
            token_sample_onehot /= tf.reduce_max(token_sample_onehot)
            token_sample_onehot = tf.reshape(
                token_sample_onehot, [batch, H * W, self.num_embedding
                                      ])  # [batch, H*W, num_embedding]
            tf.summary.image('token_sample_onehot',
                             token_sample_onehot[:, :, :, None],
                             step=self.step)

            latent_img = latent_imgs[:, :, :, 0]
            latent_img -= tf.reduce_min(latent_img)
            latent_img /= tf.reduce_max(latent_img)
            tf.summary.image('latent_im',
                             latent_img[..., None],
                             step=self.step)

        if self.step % 10 == 0:
            tf.summary.scalar('perplexity', mean_perplexity, step=self.step)
            tf.summary.scalar('var_exp',
                              tf.reduce_mean(var_exp),
                              step=self.step)
            tf.summary.scalar('kl_div', tf.reduce_mean(kl_div), step=self.step)
            tf.summary.scalar('temperature', self.temperature, step=self.step)

        return dict(loss=loss,
                    metrics=dict(var_exp=var_exp,
                                 kl_div=kl_div,
                                 mean_perplexity=mean_perplexity))