def _post_backward_callback(mixture, n_iter, log_probs, prior_crossent, q_entropy, temperature, writer):
    if n_iter % 100 == 0:
        densities = mixture.forward(torch.Tensor(z)).numpy()
        f = plt.figure(figsize=(10, 10))
        zz = np.argmax(densities, axis=1).reshape([1000, 1000])

        plt.contourf(xx, yy, zz, 50, cmap="rainbow")

        with torch.no_grad():
            for i, component in enumerate(mixture.components):
                X_k = component.sample(512)

                plt.scatter(X_k[:, 0].numpy(), X_k[:, 1].numpy(),
                            c=[colors[i]],
                            s=5)

        plt.xlim(-1.1, 1.1)
        plt.ylim(-1.1, 1.1)
        
        writer.add_image("distributions", figure2tensor(f), n_iter)
        plt.close(f)        
Example #2
0
        # Tail of a training-loop body. `opt`, `it`, `loss`, `writer`, `flow`
        # and `X` are bound earlier, outside this excerpt.
        opt.step()

        # Log the loss scalar every 2 iterations.
        if it % 2 == 0:
            writer.add_scalar("debug/loss", loss, it)
            #writer.add_scalar("debug/log_probs", base_log_probs_, it)
            #writer.add_scalar("debug/log_abs_det_jacobian", log_abs_det_jacob_, it)
        # Every 5 iterations, log two diagnostic scatter plots, computed
        # without autograd.
        if it % 5 == 0:
            with torch.no_grad():
                # Plot 1: points X (blue) vs 1000 samples drawn from the flow (red).
                Xhat = flow.sample(1000)
                f = plt.figure(figsize=(10, 10))
                #plt.xlim(-1.5, 1.5)
                #plt.ylim(-1.5, 1.5)
                plt.title(f"{it} iterations")
                plt.scatter(X[:, 0], X[:, 1], s=5, c="blue", alpha=0.5)
                plt.scatter(Xhat[:, 0], Xhat[:, 1], s=5, c="red", alpha=0.5)
                writer.add_image("debug/samples", figure2tensor(f), it)
                plt.close(f)

                # Plot 2: 1000 draws from flow.base_dist (blue) vs X mapped
                # through flow.inverse (red). The second value returned by
                # inverse is discarded (presumably a log-det-Jacobian; TODO confirm).
                f = plt.figure(figsize=(10, 10))
                plt.title(f"{it} iterations")
                #plt.xlim(-1.5, 1.5)
                #plt.ylim(-1.5, 1.5)
                z = flow.base_dist.sample((1000, ))
                zhat, _ = flow.inverse(X)
                plt.scatter(z[:, 0], z[:, 1], s=5, c="blue", alpha=0.5)
                plt.scatter(zhat[:, 0], zhat[:, 1], s=5, c="red", alpha=0.5)
                writer.add_image("debug/base_dist", figure2tensor(f), it)
                plt.close(f)

# %%
# Notebook cell: display the current value of `loss` (presumably the loss
# from the last training step run above).
loss
Example #3
0
        # Snapshot the whole flow with deepcopy so later optimizer steps do
        # not alter the saved copy. The condition guarding this line is
        # outside this excerpt (presumably a new-best-loss check).
        best_flow = deepcopy(flow)
    
    loss.backward()
    
    # Log the loss scalar every 50 iterations.
    if it % 50 == 0:
        writer.add_scalar("loss", loss, it)

    # Every 5000 iterations, log 1000 samples from the flow's final density
    # (red) against `x_samples` (blue; defined outside this excerpt) on fixed
    # axes, without autograd.
    if it % 5000 == 0:
        with torch.no_grad():
            xhat_samples = flow.final_density.sample((1000, ))
            f = plt.figure(figsize=(10, 10))
            plt.xlim(-30, 30)
            plt.ylim(-20, 20)
            plt.scatter(xhat_samples[:, 0], xhat_samples[:, 1], s=5, c="red", alpha=0.5)
            plt.scatter(x_samples[:, 0], x_samples[:, 1], s=5, c="blue", alpha=0.5)
            writer.add_image("distributions", figure2tensor(f), it)
            plt.close(f)

    # Apply the gradient update for this iteration.
    opt.step()

# %%
# Replace the current flow with the snapshot saved during training.
flow = best_flow

# %%
# Scatter 1000 samples from the restored flow's final density (red) over
# `x_samples` (blue) for a visual comparison.
# NOTE(review): unlike the in-loop plot, this sample is drawn outside
# torch.no_grad(). If it returns tensors that require grad, matplotlib's
# array conversion will fail. Wrap this in torch.no_grad() if so.
xhat_samples = flow.final_density.sample((1000, ))
plt.scatter(xhat_samples[:, 0], xhat_samples[:, 1], s=5, c="red")
plt.scatter(x_samples[:, 0], x_samples[:, 1], s=5, c="blue")
#plt.xlim(0, 60)
#plt.ylim(-15, 15)
plt.show()
Example #4
0
    def fit(self,
            X,
            dataloader,
            n_epochs=1,
            opt=None,
            temperature_schedule=None,
            clip_grad=None,
            verbose=False,
            writer=None):
        """Train the model by minimizing the negative ELBO from ``self.elbo``.

        Args:
            X: the full training set. Only ``len(X)`` is used, to derive the
                number of batches per epoch for the global iteration index.
            dataloader: iterable of batches with a ``batch_size`` attribute.
            n_epochs: number of passes over ``dataloader``.
            opt: optimizer over this model's parameters (required).
            temperature_schedule: callable ``n_iter -> temperature`` passed to
                ``self.elbo``; defaults to a constant 1.
            clip_grad: if not None, maximum total gradient norm enforced with
                ``clip_grad_norm_`` before each optimizer step.
            verbose: if True, show a ``trange`` progress bar over epochs.
            writer: optional writer. When given, ``log_probs``/``q_entropy``
                scalars are logged every 20 iterations and a component plot
                every 100 iterations.

        Returns:
            ``(best_loss, best_params)``: the lowest non-NaN loss seen (a
            Python float, ``inf`` if none) and a detached copy of
            ``self.state_dict()`` taken at that point (``{}`` if none).

        Raises:
            ValueError: if ``opt`` is None.
        """
        if opt is None:
            raise ValueError("fit() requires an optimizer via `opt`")

        best_loss = float("inf")
        best_params = {}

        if temperature_schedule is None:
            temperature_schedule = lambda t: 1

        epochs = trange(n_epochs, desc="epoch") if verbose else range(n_epochs)

        for epoch in epochs:
            for i, xb in enumerate(dataloader):
                opt.zero_grad()
                # Global iteration index: batches per epoch is ceil(len(X) / batch_size).
                n_iter = epoch * (
                    (len(X) - 1) // dataloader.batch_size + 1) + i

                log_probs, prior_crossent, q_entropy = self.elbo(
                    xb, temperature_schedule(n_iter))
                loss = -(log_probs + prior_crossent + q_entropy)

                # Skip NaN losses entirely: no backward pass, no step.
                if torch.isnan(loss):
                    continue

                loss_value = loss.item()
                if loss_value <= best_loss:
                    best_loss = loss_value
                    # state_dict() returns tensors sharing storage with the
                    # live parameters. Clone them so later optimizer steps do
                    # not overwrite the snapshot.
                    best_params = {
                        name: tensor.detach().clone()
                        for name, tensor in self.state_dict().items()
                    }

                if writer is not None and n_iter % 20 == 0:
                    writer.add_scalar('losses/log_probs', log_probs, n_iter)
                    writer.add_scalar('losses/q_entropy', q_entropy, n_iter)

                loss.backward()

                if writer is not None and n_iter % 100 == 0:
                    self._log_component_plot(writer, n_iter)

                if clip_grad is not None:
                    torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad)

                opt.step()

        return best_loss, best_params

    def _log_component_plot(self, writer, n_iter):
        """Write a figure of this model's argmax-component map and component samples.

        Relies on names defined outside this class: ``z`` (points passed to
        ``self.forward``), ``xx``/``yy`` (grid for ``contourf``; the argmax
        map is reshaped to 1000x1000 to match) and ``figure2tensor``.
        TODO confirm they are in scope where this runs.
        """
        colors = ["yellow", "green", "black", "cyan"]
        with torch.no_grad():
            densities = self.forward(torch.Tensor(z)).numpy()
            zz = np.argmax(densities, axis=1).reshape([1000, 1000])

            f = plt.figure(figsize=(10, 10))
            try:
                plt.contourf(xx, yy, zz, 50, cmap="rainbow")

                for k, component in enumerate(self.components):
                    X_k = component.sample(500)
                    # Cycle the palette so more than 4 components doesn't
                    # raise IndexError.
                    plt.scatter(X_k[:, 0].numpy(),
                                X_k[:, 1].numpy(),
                                c=colors[k % len(colors)],
                                s=5)

                plt.xlim(-1.1, 1.1)
                plt.ylim(-1.1, 1.1)
                writer.add_image("distributions", figure2tensor(f), n_iter)
            finally:
                # Always release the figure, even if plotting or logging fails.
                plt.close(f)