# Example No. 1
# 0
# --- run configuration ----------------------------------------------
# Training hyper-parameters: number of epochs and mini-batch size.
epochs, batch_size = 300, 256
# Learning rate for the Adam optimiser.
lr = 0.0002
# Presumably how many epochs pass between plots -- TODO confirm at the use site.
plot_interval = 1
# Working directory for this project; edit it to point at your own checkout.
path = r"/Users/edvardhulten/real_nvp_2d/"
# Name of the target distribution (not referenced in the visible fragment).
distr_name = "two_moons"
# Presumably a per-frame duration for the generated GIFs -- TODO confirm.
duration = 0.1
# ------------------------------------

# Create the `results` and `gifs` output folders (relative to the current
# working directory). exist_ok=True makes this idempotent and avoids the
# check-then-create race of `os.path.exists` followed by `os.makedirs`.
os.makedirs("results", exist_ok=True)
os.makedirs("gifs", exist_ok=True)

# Draw the training set from the target density.
data = Dataset(density)
x = data.generate_data(n_samples)
# Dimensionality of each sample; shared by the model and the plotting prior
# so the two can never disagree.
data_dim = x.shape[1]

model = OldRealNVP(data_dim=data_dim,
                   n_c_layers=n_c_layers,
                   n_hidden=200,
                   hidden_dims=1)
# model = RealNVP(2, n_c_layers=n_c_layers, n_hidden=200, hidden_dims=2, bn=True)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

loss_log_det_J = NegLogLik(model, out="mean")
train_loader = DataLoader(dataset=x, batch_size=batch_size, shuffle=True)

# For plotting: 5000 draws from a standard normal of the same dimensionality
# as the data (presumably the flow's base distribution -- TODO confirm).
z_norm = np.random.multivariate_normal(np.zeros(data_dim), np.eye(data_dim),
                                       5000).astype(np.float32)

# Wall-clock reference point, presumably for timing the training run.
start = time.time()
# Example No. 2
# 0
path = r"/Users/edvardhulten/real_nvp_2d/"  # change to your own path (unless your name is Edvard Hultén too)
# ------------------------------------

# Create the `evals` output folder (relative to the current working
# directory). exist_ok=True is idempotent and avoids a check-then-create race.
os.makedirs("evals", exist_ok=True)

# Load the two saved models from `path`. os.path.join keeps this working
# even if `path` is edited to drop its trailing slash.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load trusted checkpoints. These appear to be whole pickled
# modules; on PyTorch >= 2.6 torch.load defaults to weights_only=True and may
# refuse them, in which case pass weights_only=False explicitly.
model_ntnu = torch.load(os.path.join(path, "model_ntnu.pt"))
model = torch.load(os.path.join(path, "model.pt"))

# Figure layout: one row of 7 axes. The 0.1-width columns are presumably
# narrow spacers between panel groups -- TODO confirm.
gridspec = {"wspace": 0, "width_ratios": [1, 0.1, 1, 1, 0.1, 1, 1]}
fig, axes = plt.subplots(1, 7, figsize=(12, 3), gridspec_kw=gridspec)

# Panel 0: scatter of 10,000 samples from the target density.
data = Dataset(density)
x = data.generate_data(n_samples=10000)
target_ax = axes[0]
target_ax.scatter(x[:, 0], x[:, 1], s=6, color="darkblue")
# Fixed axis limits for the known densities; any other density keeps the
# matplotlib autoscaled limits.
if density == "ntnu":
    target_ax.set_xlim(2.1, 8.2)
    target_ax.set_ylim(2.1, 8.2)
elif density == "moons":
    target_ax.set_xlim(-1.7, 2.6)
    target_ax.set_ylim(-2.1 + 0.225, 2.65 - 0.225)
target_ax.set_aspect(1)  # equal scaling on both axes

# Panel 2: a smaller (2,000-point) sample from the same density, via seaborn.
x = data.generate_data(n_samples=2000)
# x/y must be keywords: since seaborn 0.12 they are keyword-only (the sole
# positional parameter is `data`), so the positional form raises TypeError.
sns.scatterplot(x=x[:, 0], y=x[:, 1], s=6, ax=axes[2], color="darkblue")
if density == "ntnu":
    axes[2].set_xlim(2.1, 8.2)
    axes[2].set_ylim(2.1, 8.2)
elif density == "moons":