Example #1
def test_vmf_samples_shape():
    concentration = torch.tensor(
        [1, 5, 10, 50.0, 75, 90, 91, 92, 93, 94, 95,
         96, 97, 98, 99, 100, 1000, 5000, 10000.0]
    )
    loc = torch.randn([19, 8])
    loc /= loc.norm(dim=-1, keepdim=True)  # Normalize each mean to a unit vector.
    vmf = VonMisesFisher(loc, concentration)
    sample = vmf.rsample()

    assert sample.size() == loc.size()
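These tests assume import torch plus a project-local VonMisesFisher distribution whose implementation is not shown here. A minimal sketch of the interface the tests exercise, with names and signatures inferred from the calls above rather than taken from the actual source:

import torch

class VonMisesFisher:
    """Hypothetical interface sketch; the real class would implement
    reparameterized vMF sampling (e.g. via rejection sampling, as in
    Davidson et al., 2018)."""

    def __init__(self, loc: torch.Tensor, concentration: torch.Tensor):
        self.loc = loc  # Unit-norm mean directions, shape (..., m).
        self.concentration = concentration  # Non-negative kappa, shape (...,).

    def rsample(self) -> torch.Tensor:
        # Reparameterized unit-vector samples with the same shape as loc.
        raise NotImplementedError

    def sample(self) -> torch.Tensor:
        # Like rsample, but detached from the autograd graph.
        with torch.no_grad():
            return self.rsample()

    def kl_divergence(self) -> torch.Tensor:
        # KL(vMF(loc, kappa) || Uniform(S^{m-1})), one value per batch entry.
        raise NotImplementedError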
Example #2
def test_vmf_samples_positive_kl_divergence():
    concentration = torch.tensor(
        [1, 5, 10, 50.0, 75, 90, 91, 92, 93, 94, 95,
         96, 97, 98, 99, 100, 1000, 5000, 10000.0]
    )
    loc = torch.randn([19, 8])
    loc /= loc.norm(dim=-1, keepdim=True)
    vmf = VonMisesFisher(loc, concentration)

    kl_divergence = vmf.kl_divergence()

    assert torch.all(kl_divergence >= 0)
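The assertion holds because the KL here is measured against the uniform prior on the sphere, and KL divergence is non-negative by definition. For intuition, on the circle (m = 2) the divergence has a simple closed form, KL = kappa * I1(kappa) / I0(kappa) - log I0(kappa). A standalone sanity check of that formula, independent of the VonMisesFisher class (torch.special supplies I0 and I1; kappa values much above ~80 would overflow float32 here, which is one reason implementations work with the log of the Bessel function, as the comment in Example #4 alludes to):

import torch

def vmf_circle_kl(concentration: torch.Tensor) -> torch.Tensor:
    # Closed-form KL(vMF(mu, kappa) || Uniform(S^1)) on the circle:
    # kappa * I1(kappa) / I0(kappa) - log(I0(kappa)).
    i0 = torch.special.i0(concentration)
    i1 = torch.special.i1(concentration)
    return concentration * i1 / i0 - torch.log(i0)

kappa = torch.tensor([0.0, 1.0, 5.0, 10.0, 50.0])
assert torch.all(vmf_circle_kl(kappa) >= 0)  # KL(q || uniform) is always >= 0.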
Example #3
def test_vmf_samples_are_unit_vectors():
    concentration = torch.tensor(
        [1, 5, 10, 50.0, 75, 90, 91, 92, 93, 94, 95,
         96, 97, 98, 99, 100, 1000, 5000, 10000.0]
    )
    loc = torch.randn([19, 8])
    loc /= loc.norm(dim=-1, keepdim=True)
    vmf = VonMisesFisher(loc, concentration)
    sample = vmf.rsample()
    sample_norm = sample.norm(dim=-1)
    expected_norm = torch.ones(sample_norm.size())

    assert torch.all(torch.isclose(sample_norm, expected_norm))
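The use of rsample throughout suggests the sampling path is reparameterized. A natural companion test, not part of the original suite and assuming rsample is differentiable with respect to loc, would check that gradients actually flow back through a sample:

def test_vmf_rsample_is_differentiable():
    concentration = torch.tensor([5.0, 50.0])
    loc = torch.randn([2, 8], requires_grad=True)
    loc_unit = loc / loc.norm(dim=-1, keepdim=True)
    vmf = VonMisesFisher(loc_unit, concentration)

    sample = vmf.rsample()
    sample.sum().backward()

    # Gradients should reach the pre-normalization parameter.
    assert loc.grad is not None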
Example #4
    def forward(self, x: torch.Tensor):
        """Encode x into vMF parameters, sample a latent unit vector via the
        reparameterization trick, and decode it back to input space."""
        x_initial_projection = self.initial_latent_projection(x)

        # Project input into a mean for the vMF. Ensure each mean is a unit vector.
        mean_prime = self.mean_encoder(
            x_initial_projection)  # Shape: (batch_size, hidden_dim)
        mean = mean_prime / mean_prime.norm(
            dim=-1, keepdim=True)  # Shape: (batch_size, hidden_dim).

        # Concentration must be strictly positive for numerical stability in the
        # KL-divergence computation: the log modified Bessel function is used, and
        # kappa = 0 would give log(Iv(m/2, 0)) = log(0) = -inf. Keeping kappa away
        # from zero also prevents collapsing into the uniform prior.
        concentration = (
            self.concentration_encoder(x_initial_projection) +
            10  # TODO: Should concentration be fixed?
        )  # Shape: (batch_size,)

        vmf = VonMisesFisher(mean, concentration)

        z = vmf.rsample()

        x_prime = self.decoder(z)

        # TODO: Return loss in two parts.
        loss = vmf.kl_divergence().mean() + self.reconstruction_loss(
            x_prime, x)

        return {
            "concentration": concentration,
            "loss": loss,
            "mean": mean,
            "reconstruction": x_prime,  # TODO: Batchnorm?
            "z": z,
        }
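The forward pass references submodules defined in a constructor that is not shown. A plausible sketch based on the attribute names used above (the layer widths, activations, and the choice of MSE reconstruction loss are assumptions, not taken from the actual source):

import torch
from torch import nn

class SphericalVAE(nn.Module):
    def __init__(self, input_dim: int, projection_dim: int, hidden_dim: int):
        super().__init__()
        # Shared projection feeding both vMF parameter heads.
        self.initial_latent_projection = nn.Sequential(
            nn.Linear(input_dim, projection_dim), nn.ReLU())
        # Head producing the (pre-normalization) vMF mean direction.
        self.mean_encoder = nn.Linear(projection_dim, hidden_dim)
        # Head producing one concentration per example; Softplus keeps it
        # non-negative before the +10 offset applied in forward().
        self.concentration_encoder = nn.Sequential(
            nn.Linear(projection_dim, 1), nn.Softplus())
        self.decoder = nn.Linear(hidden_dim, input_dim)
        self.reconstruction_loss = nn.MSELoss()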
Example #5
    def plot_latent_representation_of_noisy_samples(vmf: VonMisesFisher,
                                                    model: SphericalVAE,
                                                    ax,
                                                    color: str,
                                                    num_samples=200):
        ax.set_xlim([-2, 2])
        ax.set_ylim([-2, 2])
        average_mean = torch.zeros(2).unsqueeze(0)
        average_concentration = torch.zeros(1).unsqueeze(0)
        for _ in tqdm(range(num_samples)):
            sample = vmf.sample()
            sample_transformed = noisy_nonlinear_transformation(sample)
            output = model(sample_transformed)
            # Detach before accumulating so the autograd graph from each
            # forward pass is not retained across iterations.
            average_mean += output["mean"].detach()
            average_concentration += output["concentration"].detach()
            x, y = output["z"].squeeze().tolist()
            ax.scatter(x, y, color=color, marker=".")

        average_mean /= num_samples
        average_concentration /= num_samples

        ax.text(-1.5, 1.5,
                "Average mean: {}".format(average_mean.squeeze().tolist()))
        ax.text(
            -1.5,
            -1.5,
            "Average concentration: {}".format(
                average_concentration.squeeze().tolist()),
        )
Example #6
def main():
    sns.set_theme()
    axes = plt.gca()
    axes.set_xlim([-2, 2])
    axes.set_ylim([-2, 2])
    axes.set_aspect("equal", adjustable="box")

    mean_1 = torch.tensor([math.sqrt(3) / 2, -1 / 2])
    mean_2 = torch.tensor([-math.sqrt(2) / 2, -math.sqrt(2) / 2])
    mean_3 = torch.tensor([-1 / 2, math.sqrt(3) / 2])

    vmf_1 = VonMisesFisher(mean_1, torch.tensor([5.0]))
    vmf_2 = VonMisesFisher(mean_2, torch.tensor([10.0]))
    vmf_3 = VonMisesFisher(mean_3, torch.tensor([2.0]))

    training_data = []

    for _ in range(NUM_SAMPLES):
        sample_1 = vmf_1.sample()
        x_1, y_1 = sample_1.squeeze().tolist()
        plt.scatter(x_1, y_1, color="r", marker=".")

        sample_2 = vmf_2.sample()
        x_2, y_2 = sample_2.squeeze().tolist()
        plt.scatter(x_2, y_2, color="g", marker=".")

        sample_3 = vmf_3.sample()
        x_3, y_3 = sample_3.squeeze().tolist()
        plt.scatter(x_3, y_3, color="m", marker=".")

        training_data.append(sample_1)
        training_data.append(sample_2)
        training_data.append(sample_3)

    # Built for a subsequent training step; unused in this plotting script.
    training_dataloader = DataLoader(
        training_data, batch_size=4, shuffle=True, num_workers=4
    )

    plt.show()
Example #7
def main():
    """
    A VAE with a hidden dim of 2 will be trained on 100-dimensional inputs
    created via noisy,nonlinear transformation of the 2-dimensional vMF samples.

    At evaluation time, the VAE will be given additional samples and the
    2-dimensional latent vectors will be plotted to verify that the VAE can learn
    a circular latent space.
    """
    torch.autograd.set_detect_anomaly(True)

    mean_1 = torch.tensor([math.sqrt(3) / 2, -1 / 2])
    mean_2 = torch.tensor([-math.sqrt(2) / 2, -math.sqrt(2) / 2])
    mean_3 = torch.tensor([-1 / 2, math.sqrt(3) / 2])

    vmf_1 = VonMisesFisher(mean_1, torch.tensor([5.0]))
    vmf_2 = VonMisesFisher(mean_2, torch.tensor([10.0]))
    vmf_3 = VonMisesFisher(mean_3, torch.tensor([2.0]))

    noisy_nonlinear_transformation = create_noisy_nonlinear_transformation(
        2, 100)
    training_data = []

    for _ in tqdm(range(NUM_SAMPLES)):
        sample_1 = vmf_1.sample()
        sample_2 = vmf_2.sample()
        sample_3 = vmf_3.sample()

        training_data.append(noisy_nonlinear_transformation(
            sample_1.squeeze()))
        training_data.append(noisy_nonlinear_transformation(
            sample_2.squeeze()))
        training_data.append(noisy_nonlinear_transformation(
            sample_3.squeeze()))

    model = SphericalVAE(100, 25, 2)
    training_dataloader = DataLoader(training_data,
                                     batch_size=BATCH_SIZE,
                                     shuffle=True)
    for _ in range(NUM_EPOCHS):
        training_epoch(model, training_dataloader)

    def plot_latent_representation_of_noisy_samples(vmf: VonMisesFisher,
                                                    model: SphericalVAE,
                                                    ax,
                                                    color: str,
                                                    num_samples=200):
        ax.set_xlim([-2, 2])
        ax.set_ylim([-2, 2])
        average_mean = torch.zeros(2).unsqueeze(0)
        average_concentration = torch.zeros(1).unsqueeze(0)
        for _ in tqdm(range(num_samples)):
            sample = vmf.sample()
            sample_transformed = noisy_nonlinear_transformation(sample)
            output = model(sample_transformed)
            # Detach before accumulating so the autograd graph from each
            # forward pass is not retained across iterations.
            average_mean += output["mean"].detach()
            average_concentration += output["concentration"].detach()
            x, y = output["z"].squeeze().tolist()
            ax.scatter(x, y, color=color, marker=".")

        average_mean /= num_samples
        average_concentration /= num_samples

        ax.text(-1.5, 1.5,
                "Average mean: {}".format(average_mean.squeeze().tolist()))
        ax.text(
            -1.5,
            -1.5,
            "Average concentration: {}".format(
                average_concentration.squeeze().tolist()),
        )

    sns.set_theme()
    fig = plt.figure(figsize=(10, 10))
    ax1 = fig.add_subplot(221, adjustable="box", aspect=1.0)
    ax2 = fig.add_subplot(222, adjustable="box", aspect=1.0)
    ax3 = fig.add_subplot(223, adjustable="box", aspect=1.0)
    plot_latent_representation_of_noisy_samples(vmf_1, model, ax1, "r")
    plot_latent_representation_of_noisy_samples(vmf_2, model, ax2, "g")
    plot_latent_representation_of_noisy_samples(vmf_3, model, ax3, "m")

    ax4 = fig.add_subplot(224, adjustable="box", aspect=1.0)
    ax4.set_xlim([-2, 2])
    ax4.set_ylim([-2, 2])
    for _ in tqdm(range(200)):
        sample_1 = vmf_1.sample()
        sample_2 = vmf_2.sample()
        sample_3 = vmf_3.sample()

        sample_1_transformed = noisy_nonlinear_transformation(sample_1)
        sample_2_transformed = noisy_nonlinear_transformation(sample_2)
        sample_3_transformed = noisy_nonlinear_transformation(sample_3)

        output_1 = model(sample_1_transformed)
        output_2 = model(sample_2_transformed)
        output_3 = model(sample_3_transformed)

        x_1, y_1 = output_1["z"].squeeze().tolist()
        x_2, y_2 = output_2["z"].squeeze().tolist()
        x_3, y_3 = output_3["z"].squeeze().tolist()

        ax4.scatter(x_1, y_1, color="r", marker=".")
        ax4.scatter(x_2, y_2, color="g", marker=".")
        ax4.scatter(x_3, y_3, color="m", marker=".")

    plt.show()
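main() depends on two helpers defined elsewhere: create_noisy_nonlinear_transformation and training_epoch. Hypothetical sketches consistent with how they are called above (the tanh nonlinearity, the noise scale, and the Adam optimizer with its learning rate are all assumptions):

import torch

def create_noisy_nonlinear_transformation(input_dim: int, output_dim: int,
                                          noise_std: float = 0.05):
    # A fixed random nonlinear map from input_dim to output_dim that adds
    # fresh Gaussian noise on every call.
    projection = torch.nn.Linear(input_dim, output_dim)

    def transform(x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return torch.tanh(projection(x)) + noise_std * torch.randn(output_dim)

    return transform

def training_epoch(model, dataloader, learning_rate: float = 1e-3):
    # One pass over the data, minimizing the combined KL + reconstruction
    # loss returned by the model. A real implementation would likely share
    # one optimizer across epochs instead of rebuilding it per call.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for batch in dataloader:
        optimizer.zero_grad()
        loss = model(batch)["loss"]
        loss.backward()
        optimizer.step()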