Example #1
def evaluate_auto_regressive_prior(autoregressive_prior: AutoRegressivePrior,
                                   output_dir):
    dataset = build_dataset()

    tf_grid_graphs = tf.function(lambda graphs: grid_graphs(
        graphs, autoregressive_prior.discrete_voxel_vae.voxels_per_dimension))
    fig_dir = os.path.join(output_dir, 'evaluated_pairs')
    os.makedirs(fig_dir, exist_ok=True)
    pair_idx = 0
    for (graphs, images) in iter(dataset):
        actual_voxels = tf_grid_graphs(graphs)
        actual_voxels = actual_voxels.numpy()

        # batch, H,W,D,C
        mu_3d, b_3d = autoregressive_prior._deproject_images(images[0:1])
        mu_3d = mu_3d.numpy()
        b_3d = b_3d.numpy()

        for _actual_voxels, _mu_3d, _b_3d, _image in zip(
                actual_voxels, mu_3d, b_3d, images):
            np.savez(os.path.join(fig_dir,
                                  "evaluation_{:05}.npz".format(pair_idx)),
                     actual_voxels=_actual_voxels,
                     mu_3d=_mu_3d,
                     b_3d=_b_3d,
                     image=_image)
            # plot_voxel(_image, _mu_3d, _actual_voxels)

            pair_idx += 1
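A minimal sketch of reloading one of the saved pairs for offline inspection, assuming the evaluated_pairs layout and array keys produced by the np.savez call above:

import numpy as np

data = np.load('evaluated_pairs/evaluation_00000.npz')
actual_voxels = data['actual_voxels']  # voxel grid produced by grid_graphs
mu_3d = data['mu_3d']                  # deprojected location field
b_3d = data['b_3d']                    # deprojected scale field
image = data['image']                  # the input image that was deprojected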
Example #2
def train_discrete_voxel_vae(config, kwargs):
    # with strategy.scope():
    train_one_epoch = build_training(**config, **kwargs)

    dataset = build_example_dataset(100,
                                    batch_size=2,
                                    num_blobs=3,
                                    num_nodes=64**3,
                                    image_dim=256)

    # the model will call grid_graphs internally to learn the 3D autoencoder.
    # we show here what that produces from a batch of graphs.
    for graphs, image in iter(dataset):
        assert image.numpy().shape == (2, 256, 256, 1)
        plt.imshow(image[0, ..., 0].numpy())  # drop the trailing channel axis for imshow
        plt.colorbar()
        plt.show()
        voxels = grid_graphs(graphs, 64)
        assert voxels.numpy().shape == (2, 64, 64, 64, 1)
        plt.imshow(tf.reduce_mean(voxels[0, ..., 0], axis=-1))  # mean over depth; channel axis dropped for imshow
        plt.colorbar()
        plt.show()
        break

    # drop the image as the model expects only graphs
    dataset = dataset.map(lambda graphs, images: (graphs, ))

    # run on first input to set variable shapes
    for batch in iter(dataset):
        train_one_epoch.model(*batch)
        break

    log_dir = build_log_dir('log_dir', config)
    checkpoint_dir = build_checkpoint_dir('checkpointing', config)

    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(os.path.join(checkpoint_dir, 'config.json'), 'w') as f:
        json.dump(config, f)

    vanilla_training_loop(train_one_epoch=train_one_epoch,
                          training_dataset=dataset,
                          num_epochs=1,
                          early_stop_patience=5,
                          checkpoint_dir=checkpoint_dir,
                          log_dir=log_dir,
                          save_model_dir=checkpoint_dir,
                          debug=False)
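A minimal sketch of reading the dumped config.json back to rebuild the same training step, assuming checkpoint_dir points at the directory written above and kwargs holds the same extra arguments:

import json
import os

with open(os.path.join(checkpoint_dir, 'config.json'), 'r') as f:
    config = json.load(f)

# build_training is the same factory used at the top of train_discrete_voxel_vae.
train_one_epoch = build_training(**config, **kwargs)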
Example #3
    def compute_logits(self, graphs):
        """
        Computes normalised logits representing the variational posterior, q(z | voxels).

        Args:
            graphs: GraphsTuple

        Returns:
            logits: [batch, H, W, D, num_embeddings]
        """
        #[batch, H', W', D', num_channels]
        img = grid_graphs(graphs,
                          voxels_per_dimension=self.voxels_per_dimension)
        # number of channels must be known for conv3d
        img.set_shape([None, None, None, None, self.num_channels])
        logits = self._encoder(img)
        logits /= 1e-6 + tf.math.reduce_std(logits, axis=-1, keepdims=True)
        logits -= tf.reduce_logsumexp(logits, axis=-1, keepdims=True)
        return logits
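The last two lines standardise the logits per voxel and then subtract their logsumexp, so the returned tensor contains normalised log-probabilities over the codebook. A minimal NumPy sketch of that normalisation, using random stand-in logits:

import numpy as np

logits = np.random.randn(2, 4, 4, 4, 8)                       # [batch, H, W, D, num_embeddings]
logits /= 1e-6 + logits.std(axis=-1, keepdims=True)           # standardise per voxel
logits -= np.log(np.exp(logits).sum(axis=-1, keepdims=True))  # subtract logsumexp
assert np.allclose(np.exp(logits).sum(axis=-1), 1.0)          # proper probabilities per voxel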
Example #4
    def log_likelihood(self, graphs: GraphsTuple, mu, logb):
        """
        Log-Laplace distribution.

        The pdf of log-Laplace is,

            P(x | mu, b) = 1 / (2 * x * b) * exp(-|log(x) - mu| / b)

        Args:
            graphs: GraphsTuple in standard form. Assumes properties are of the form log(maximum(1e-5, properties))
            mu: [num_samples, batch, H', W', D', channels]
            logb: [num_samples, batch, H', W', D', channels]

        Returns:
            log_prob [num_samples, batch]
        """
        properties = grid_graphs(
            graphs, self.voxels_per_dimension
        )  # [batch, H', W', D', num_properties]
        log_prob = - tf.math.abs(properties - mu) / tf.math.exp(logb) \
                   - tf.math.log(2.) - properties - logb # [num_samples, batch, H', W', D', num_properties]
        #num_samples, batch
        return tf.reduce_sum(log_prob, axis=[-1, -2, -3, -4])
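For reference, writing y = log(x) for the gridded properties (which are already stored in log form) and b = exp(logb), the per-voxel term summed in the code above is the log of the pdf in the docstring:

    \log p(x \mid \mu, b) = -\frac{|y - \mu|}{b} - \log 2 - y - \log b, \qquad y = \log x

The reduce_sum over the H', W', D' and channel axes then yields one log-likelihood per sample and batch element.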
Example #5
    def _build(self, graphs, **kwargs) -> dict:
        """
        Args:
            graphs: GraphsTuple in standard form.
            **kwargs:

        Returns:
            dict with `loss` (scalar, the negative mean ELBO) and `metrics`
            (var_exp, kl_div, mean_perplexity).
        """
        latent_logits = self.compute_logits(
            graphs)  # [batch, H, W, D, num_embeddings]
        log_token_samples_onehot, token_samples_onehot, latent_tokens = self.sample_latent(
            latent_logits, self.temperature, self.num_token_samples
        )  # [num_samples, batch, H, W, D, num_embeddings], [num_samples, batch, H, W, D, embedding_size]
        mu, logb = self.compute_likelihood_parameters(
            latent_tokens
        )  # [num_samples, batch, H', W', D', C], [num_samples, batch, H', W', D', C]
        log_likelihood = self.log_likelihood(graphs, mu,
                                             logb)  # [num_samples, batch]
        kl_term = self.kl_term(
            latent_logits, log_token_samples_onehot)  # [num_samples, batch]

        var_exp = tf.reduce_mean(log_likelihood, axis=0)  # [batch]
        kl_div = tf.reduce_mean(kl_term, axis=0)  # [batch]
        elbo = var_exp - self.beta * kl_div  # batch
        loss = -tf.reduce_mean(elbo)  # scalar

        entropy = -tf.reduce_sum(tf.math.exp(latent_logits) * latent_logits,
                                 axis=[-1])  # [batch, H, W, D]
        perplexity = 2.**(entropy / tf.math.log(2.))  # [batch, H, W, D]
        mean_perplexity = tf.reduce_mean(perplexity)  # scalar

        if self.log_counter % int(128 / mu.get_shape()[0]) == 0:
            tf.summary.scalar('perplexity', mean_perplexity, step=self.step)
            tf.summary.scalar('var_exp',
                              tf.reduce_mean(var_exp),
                              step=self.step)
            tf.summary.scalar('kl_div', tf.reduce_mean(kl_div), step=self.step)
            tf.summary.scalar('temperature', self.temperature, step=self.step)
            tf.summary.scalar('beta', self.beta, step=self.step)

            projected_mu = tf.reduce_sum(mu[0], axis=-2)  #[batch, H', W', C]
            voxels = grid_graphs(
                graphs, self.voxels_per_dimension)  #[batch, H', W', D', C]
            projected_img = tf.reduce_sum(voxels, axis=-2)  #[batch, H', W', C]
            for i in range(self.num_channels):
                vmin = tf.reduce_min(projected_mu[..., i])
                vmax = tf.reduce_max(projected_mu[..., i])
                _projected_mu = (projected_mu[..., i:i + 1] - vmin) / (
                    vmax - vmin)  #batch, H', W', 1
                _projected_mu = tf.clip_by_value(_projected_mu, 0., 1.)
                vmin = tf.reduce_min(projected_img[..., i])
                vmax = tf.reduce_max(projected_img[..., i])
                _projected_img = (projected_img[..., i:i + 1] - vmin) / (
                    vmax - vmin)  #batch, H', W', 1
                _projected_img = tf.clip_by_value(_projected_img, 0., 1.)

                tf.summary.image(f'voxels_predict[{i}]',
                                 _projected_mu,
                                 step=self.step)
                tf.summary.image(f'voxels_actual[{i}]',
                                 _projected_img,
                                 step=self.step)

            batch, H, W, D, _ = get_shape(latent_logits)
            _latent_logits = latent_logits  # [batch, H, W, D, num_embeddings]
            _latent_logits -= tf.reduce_min(_latent_logits,
                                            axis=-1,
                                            keepdims=True)
            _latent_logits /= tf.reduce_max(_latent_logits,
                                            axis=-1,
                                            keepdims=True)
            _latent_logits = tf.reshape(
                _latent_logits, [batch, H * W * D, self.num_embedding, 1
                                 ])  # [batch, H*W*D, num_embedding, 1]
            tf.summary.image('latent_logits', _latent_logits, step=self.step)

            token_sample_onehot = token_samples_onehot[
                0]  # [batch, H, W, D, num_embeddings]
            token_sample_onehot = tf.reshape(
                token_sample_onehot, [batch, H * W * D, self.num_embedding, 1
                                      ])  # [batch, H*W*D, num_embedding, 1]
            tf.summary.image('latent_samples_onehot',
                             token_sample_onehot,
                             step=self.step)

        return dict(loss=loss,
                    metrics=dict(var_exp=var_exp,
                                 kl_div=kl_div,
                                 mean_perplexity=mean_perplexity))
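
    # For reference: `loss` above is the negative ELBO averaged over the batch,
    #     loss = -mean_b( var_exp_b - beta * kl_div_b ),
    # where var_exp is a Monte-Carlo estimate of E_q[log p(x | z)] over
    # num_token_samples posterior samples and kl_div is the matching KL term.
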
    def write_summary(self, images, graphs, latent_logits_2d, latent_logits_3d,
                      prior_latent_logits_2d, prior_latent_logits_3d):

        dist_2d = tfp.distributions.OneHotCategorical(
            logits=latent_logits_2d, dtype=latent_logits_2d.dtype)
        dist_3d = tfp.distributions.OneHotCategorical(
            logits=latent_logits_3d, dtype=latent_logits_3d.dtype)
        token_samples_onehot_2d = dist_2d.sample(1)[0]
        token_samples_onehot_3d = dist_3d.sample(1)[0]

        dist_2d_prior = tfp.distributions.OneHotCategorical(
            logits=prior_latent_logits_2d, dtype=prior_latent_logits_2d.dtype)
        dist_3d_prior = tfp.distributions.OneHotCategorical(
            logits=prior_latent_logits_3d, dtype=prior_latent_logits_3d.dtype)
        prior_token_samples_onehot_2d = dist_2d_prior.sample(1)[0]
        prior_token_samples_onehot_3d = dist_3d_prior.sample(1)[0]

        kl_div_2d = tf.reduce_mean(
            tf.reduce_sum(dist_2d.kl_divergence(dist_2d_prior), axis=[-1, -2]))
        kl_div_3d = tf.reduce_mean(
            tf.reduce_sum(dist_3d.kl_divergence(dist_3d_prior),
                          axis=[-1, -2, -3]))
        tf.summary.scalar('kl_div_2d', kl_div_2d, step=self.step)
        tf.summary.scalar('kl_div_3d', kl_div_3d, step=self.step)
        tf.summary.scalar('kl_div', kl_div_2d + kl_div_3d, step=self.step)

        perplexity_2d = 2.**(dist_2d_prior.entropy() / tf.math.log(2.))  # [batch, H, W]
        mean_perplexity_2d = tf.reduce_mean(perplexity_2d)  # scalar

        perplexity_3d = 2.**(dist_3d_prior.entropy() / tf.math.log(2.))  # [batch, H, W, D]
        mean_perplexity_3d = tf.reduce_mean(perplexity_3d)  # scalar

        tf.summary.scalar('perplexity_2d_prior',
                          mean_perplexity_2d,
                          step=self.step)
        tf.summary.scalar('perplexity_3d_prior',
                          mean_perplexity_3d,
                          step=self.step)

        prior_latent_tokens_2d = tf.einsum('sbhwd,de->sbhwe',
                                           prior_token_samples_onehot_2d[None],
                                           self.discrete_image_vae.embeddings)
        prior_latent_tokens_3d = tf.einsum('sbhwdn,ne->sbhwde',
                                           prior_token_samples_onehot_3d[None],
                                           self.discrete_voxel_vae.embeddings)
        mu_2d, logb_2d = self.discrete_image_vae.compute_likelihood_parameters(
            prior_latent_tokens_2d
        )  # [num_samples, batch, H', W', C], [num_samples, batch, H', W', C]
        log_likelihood_2d = self.discrete_image_vae.log_likelihood(
            images, mu_2d, logb_2d)  # [num_samples, batch]
        var_exp_2d = tf.reduce_mean(log_likelihood_2d)  # [scalar]
        mu_3d, logb_3d = self.discrete_voxel_vae.compute_likelihood_parameters(
            prior_latent_tokens_3d
        )  # [num_samples, batch, H', W', D', C], [num_samples, batch, H', W', D', C]
        log_likelihood_3d = self.discrete_voxel_vae.log_likelihood(
            graphs, mu_3d, logb_3d)  # [num_samples, batch]
        var_exp_3d = tf.reduce_mean(log_likelihood_3d)  # [scalar]
        var_exp = log_likelihood_2d + log_likelihood_3d

        tf.summary.scalar('var_exp_3d',
                          tf.reduce_mean(var_exp_3d),
                          step=self.step)
        tf.summary.scalar('var_exp_2d',
                          tf.reduce_mean(var_exp_2d),
                          step=self.step)
        tf.summary.scalar('var_exp', tf.reduce_mean(var_exp), step=self.step)

        projected_mu = tf.reduce_sum(mu_3d[0], axis=-2)  # [batch, H', W', C]
        voxels = grid_graphs(graphs, self.discrete_voxel_vae.
                             voxels_per_dimension)  # [batch, H', W', D', C]
        projected_img = tf.reduce_sum(voxels, axis=-2)  # [batch, H', W', C]
        for i in range(self.discrete_voxel_vae.num_channels):
            vmin = tf.reduce_min(projected_mu[..., i])
            vmax = tf.reduce_max(projected_mu[..., i])
            _projected_mu = (projected_mu[..., i:i + 1] - vmin) / (
                vmax - vmin)  # batch, H', W', 1
            _projected_mu = tf.clip_by_value(_projected_mu, 0., 1.)

            vmin = tf.reduce_min(projected_img[..., i])
            vmax = tf.reduce_max(projected_img[..., i])
            _projected_img = (projected_img[..., i:i + 1] - vmin) / (
                vmax - vmin)  # batch, H', W', 1
            _projected_img = tf.clip_by_value(_projected_img, 0., 1.)

            tf.summary.image(f'voxels_predict_prior[{i}]',
                             _projected_mu,
                             step=self.step)
            tf.summary.image(f'voxels_actual[{i}]',
                             _projected_img,
                             step=self.step)

        for name, _latent_logits_3d, _tokens_onehot_3d in zip(
            ['', '_prior'], [latent_logits_3d, prior_latent_logits_3d],
            [token_samples_onehot_3d, prior_token_samples_onehot_3d]):
            batch, H3, W3, D3, _ = get_shape(_latent_logits_3d)
            _latent_logits_3d -= tf.reduce_min(_latent_logits_3d,
                                               axis=-1,
                                               keepdims=True)
            _latent_logits_3d /= tf.reduce_max(_latent_logits_3d,
                                               axis=-1,
                                               keepdims=True)
            _latent_logits_3d = tf.reshape(_latent_logits_3d, [
                batch, H3 * W3 * D3, self.discrete_voxel_vae.num_embedding, 1
            ])  # [batch, H*W*D, num_embedding, 1]
            tf.summary.image(f"latent_logits_3d{name}",
                             _latent_logits_3d,
                             step=self.step)

            _tokens_onehot_3d = tf.reshape(_tokens_onehot_3d, [
                batch, H3 * W3 * D3, self.discrete_voxel_vae.num_embedding, 1
            ])  # [batch, H*W*D, num_embedding, 1]
            tf.summary.image(f'latent_samples_onehot_3d{name}',
                             _tokens_onehot_3d,
                             step=self.step)

        _mu = mu_2d[0]  # [batch, H', W', C]
        _img = images  # [batch, H', W', C]
        for i in range(self.discrete_image_vae.num_channels):
            vmin = tf.reduce_min(_mu[..., i])
            vmax = tf.reduce_max(_mu[..., i])
            _projected_mu = (_mu[..., i:i + 1] - vmin) / (vmax - vmin
                                                          )  # batch, H', W', 1
            _projected_mu = tf.clip_by_value(_projected_mu, 0., 1.)

            vmin = tf.reduce_min(_img[..., i])
            vmax = tf.reduce_max(_img[..., i])
            _projected_img = (_img[..., i:i + 1] - vmin) / (
                vmax - vmin)  # batch, H', W', 1
            _projected_img = tf.clip_by_value(_projected_img, 0., 1.)

            tf.summary.image(f'image_predict_prior[{i}]',
                             _projected_mu,
                             step=self.step)
            tf.summary.image(f'image_actual[{i}]',
                             _projected_img,
                             step=self.step)

        for name, _latent_logits_2d, _tokens_onehot_2d in zip(
            ['', '_prior'], [latent_logits_2d, prior_latent_logits_2d],
            [token_samples_onehot_2d, prior_token_samples_onehot_2d]):
            batch, H2, W2, _ = get_shape(_latent_logits_2d)
            _latent_logits_2d -= tf.reduce_min(_latent_logits_2d,
                                               axis=-1,
                                               keepdims=True)
            _latent_logits_2d /= tf.reduce_max(_latent_logits_2d,
                                               axis=-1,
                                               keepdims=True)
            _latent_logits_2d = tf.reshape(
                _latent_logits_2d,
                [batch, H2 * W2, self.discrete_image_vae.num_embedding, 1
                 ])  # [batch, H*W, num_embedding, 1]
            tf.summary.image(f"latent_logits_2d{name}",
                             _latent_logits_2d,
                             step=self.step)

            _tokens_onehot_2d = tf.reshape(
                _tokens_onehot_2d,
                [batch, H2 * W2, self.discrete_image_vae.num_embedding, 1
                 ])  # [batch, H*W, num_embedding, 1]
            tf.summary.image(f'latent_samples_onehot_2d{name}',
                             _tokens_onehot_2d,
                             step=self.step)