Example #1
def experiment(method):
    optim.zero_grad()
    # Three Bernoulli variables sharing the same logit eta
    b = Bernoulli(logits=eta.repeat(3))
    # Sample from b using the given storch gradient estimation method
    x = method(b)
    # Squared distance between the sample and the target p
    cost = torch.sum((x - p)**2, -1)
    storch.add_cost(cost, "cost")
    storch.backward()
    # Return a copy of the gradient estimate for eta
    return eta.grad.clone()
Example #2
def experiment(method):
    # Optimize eta for 2000 steps using the given gradient estimation method
    for i in range(2000):
        optim.zero_grad()
        b = Bernoulli(logits=eta)
        x = method(b)
        cost = torch.sum((x - p)**2, -1)
        storch.add_cost(cost, "cost")
        storch.backward()
        optim.step()
        # Periodically print the current logits
        if i % 100 == 0:
            print(eta)
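
Note: Examples #1 and #2 assume that eta, p, optim and a storch estimation method are already defined. A minimal setup sketch with illustrative values (the variable names follow the examples; the concrete values and the choice of estimator are assumptions):

import storch
import torch
from torch.distributions import Bernoulli

eta = torch.tensor(0.0, requires_grad=True)   # learnable logit, hypothetical initial value
p = torch.tensor([0.2, 0.6, 0.8])             # hypothetical target values
optim = torch.optim.SGD([eta], lr=0.1)
method = storch.method.ScoreFunction("x", n_samples=10)  # any storch estimator works here

experiment(method)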
Example #3
    def forward(
        self, x: storch.Tensor
    ) -> (storch.Tensor, storch.Tensor, storch.Tensor):
        # Encode the input into the parameters of the variational posterior
        logits = self.encode(x)
        params = self.logits_to_params(logits, self.latents)
        var_posterior = self.variational_posterior(params)
        prior = self.prior(var_posterior)

        # Register the KL divergence between posterior and prior as a cost node
        KLD = self.KLD(var_posterior, prior)
        storch.add_cost(KLD, "KL-divergence")
        # Sample the latent z with the configured storch sampling method and decode it
        z = self.sampling_method(var_posterior)
        return self.decode(z), KLD, z
Example #4
def estimate_variance(method):
    # Draw 1000 independent gradient estimates of c for the given method
    gradient_samples = []
    for i in range(1000):
        f, c = compute_f(method)
        storch.add_cost(f, "f")
        storch.backward()
        gradient_samples.append(c.grad)
    # Gather the samples into a new independent plate named "gradients"
    gradients = storch.gather_samples(gradient_samples, "gradients")
    # print(gradients)
    print("variance", storch.variance(gradients, "gradients"))
    print("mean", storch.reduce_plates(gradients, "gradients"))
    print("st dev", torch.sqrt(storch.variance(gradients, "gradients")))

    print(type(gradients))
    print(gradients.shape)
    print(gradients.plates)
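
Note: estimate_variance relies on a compute_f helper that builds a stochastic cost f and returns it together with the parameter c whose gradient is being sampled. Example #11 shows only the tail of that helper; a minimal sketch consistent with it (the values of a and b are assumptions):

import torch
from torch.distributions import Normal

def compute_f(method):
    a = torch.tensor(1.0, requires_grad=True)   # assumed value
    b = torch.tensor(0.5, requires_grad=True)   # assumed value
    c = torch.tensor(0.23, requires_grad=True)
    d = a + b
    # Sample e from a normal distribution using the given estimation method
    e = method(Normal(b + c, 1))
    f = d * e * e
    return f, c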
Example #5
def generative_story(
    method: storch.method.Method, model: DiscreteVAE, data: torch.Tensor
):
    x = storch.denote_independent(data.view(-1, 784), 0, "data")

    # Encode data. Shape: (data, 2 * 10)
    q_logits = model.encode(x)
    # Shape: (data, 2, 10)
    q_logits = q_logits.reshape(-1, 2, 10)
    q = OneHotCategorical(probs=q_logits.softmax(dim=-1))
    # Sample from variational posterior
    z = method(q)

    prior = OneHotCategorical(probs=torch.ones_like(q.probs) / 10.0)
    # Shape: (data)
    KL_div = torch.distributions.kl_divergence(q, prior).sum(-1)
    storch.add_cost(KL_div, "kl-div")

    z_in = z.reshape(z.shape[:-2] + (2 * 10,))
    reconstruction = model.decode(z_in)
    bce = torch.nn.BCELoss(reduction="none")(reconstruction, x).sum(-1)
    # bce = torch.nn.BCELoss(reduction="sum")(reconstruction, x)
    storch.add_cost(bce, "reconstruction")
    return z
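
Note: a possible way to drive generative_story from a training loop, following the pattern of Example #12 (the choice of estimator, the DiscreteVAE construction, the optimizer and the data loader are assumptions):

method = storch.method.ScoreFunction("z", n_samples=8)   # any storch estimator
model = DiscreteVAE()                                     # hypothetical constructor
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for data, _ in train_loader:   # assumed DataLoader over 784-dimensional inputs
    optimizer.zero_grad()
    storch.reset()
    generative_story(method, model, data)
    storch.backward()
    optimizer.step()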
Example #6
import storch
import torch
from torch.distributions import Bernoulli
from storch.method import GumbelSoftmax

torch.manual_seed(0)

p = torch.tensor(0.5, requires_grad=True)

# Create a separate GumbelSoftmax sample and cost node in each of 10000 iterations
for i in range(10000):
    sample = GumbelSoftmax(f"sample_{i}")(Bernoulli(p))
    storch.add_cost(sample, f"cost_{i}")

storch.backward()
print("Finished")
Example #7
if isinstance(swr_method.sampling_method, storch.sampling.SampleWithoutReplacement):
    assert z_5.shape == (plt_n1, plt_n2, k, d_yv) or z_5.shape == (
        plt_n2,
        plt_n1,
        k,
        d_yv,
    )
else:
    assert z_5.shape == (plt_n2, k, d_yv) or z_5.shape == (k, plt_n2, d_yv)

d6 = OneHotCategorical(logits=z_3 + z_4.unsqueeze(-2) - z_5.unsqueeze(-2))

z_6 = swr_method.sample(d6)

print("z6", z_6)

assert z_6.shape == (plt_n1, plt_n2, k, event, d_yv) or z_6.shape == (
    plt_n2,
    plt_n1,
    k,
    event,
    d_yv,
)
# Sum the amount of event 1's being true.
cost = torch.sum(z_6[..., 0], -1)
storch.add_cost(cost, "cost")
storch.backward()

# Print what values of z1 are selected in the final sample step. As it has very low entropy, this should be all [0,0,1,0]
# print("final z1", z_6.plates[2].on_unwrap_tensor(z_1))
Example #8
score_method = storch.method.ScoreFunction("white_noise_1", n_samples=2)
infer_method = storch.method.Infer("white_noise_2", Normal)


def loss(v):
    return torch.nn.MSELoss(reduction="none")(v, theta).mean(dim=-1)


mu = lax_method(Normal(mu_prior, 1))
k = expect(
    Categorical(probs=torch.tensor([[0.1, 0.3, 0.6], [0.1, 0.8, 0.1]],
                                   requires_grad=True)), )

agg_v = 0.0
s1 = 1.0
for i in range(2):
    k1, k2 = 0, 0
    if i == 1:
        k1 = k[:, 0]
        k2 = k[:, 1]
    s1 = score_method(Normal(mu + k1, 1))
    aaa = -mu + s1 * k2
    s2 = infer_method(Normal(-mu + s1 * k2, 1))
    # plus = lambda a, b: a + b
    # plus = storch.deterministic(plus)
    agg_v = agg_v + s1 + s2 * mu
    print(isinstance(agg_v, Iterable))
    storch.add_cost(loss(agg_v), "loss")

storch.backward(debug=False, print_costs=True)
Example #9
import storch
import torch
from torch.distributions import Bernoulli, OneHotCategorical
from storch.method import RELAX, REBAR, ARM

torch.manual_seed(0)

p = torch.tensor(0.5, requires_grad=True)
d = Bernoulli(p)
sample = RELAX("sample", in_dim=1)(d)
# sample = ARM('sample', n_samples=10)(d)
storch.add_cost(sample, "cost")
storch.backward()

method = REBAR("test", n_samples=1)
x = torch.Tensor([[0.2, 0.4, 0.4], [0.5, 0.1, 0.4], [0.2, 0.2, 0.6],
                  [0.15, 0.15, 0.7]])
qx = OneHotCategorical(x)
print(method(qx))
Example #10
import storch
import torch
from torch.distributions import Bernoulli, OneHotCategorical

expect = storch.method.Expect("x")
probs = torch.tensor([0.95, 0.01, 0.01, 0.01, 0.01, 0.01], requires_grad=True)
indices = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
b = OneHotCategorical(probs=probs)
z = expect.sample(b)
c = (2.4 * z * indices).sum(-1)
storch.add_cost(c, "no_baseline_cost")

storch.backward()

expect_grad = z.grad["probs"].clone()


def eval(grads):
    print("----------------------------------")
    # Gather the gradient samples into a new independent plate named "variance"
    grad_samples = storch.gather_samples(grads, "variance")
    mean = storch.reduce_plates(grad_samples, plates=["variance"])
    print("mean grad", mean)
    print("expected grad", expect_grad)
    print("specific_diffs", (mean - expect_grad)**2)
    # Mean squared error of the sampled gradients w.r.t. the exact gradient
    mse = storch.reduce_plates((grad_samples - expect_grad)**2).sum()
    print("MSE", mse)
    # Squared bias: distance between the mean sampled gradient and the exact gradient
    bias = (storch.reduce_plates((mean - expect_grad)**2)).sum()
    print("bias", bias)
    return bias
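
Note: a hedged usage sketch for eval: draw repeated Monte Carlo gradient estimates with a score-function estimator and compare them against the exact gradient expect_grad computed above (the estimator settings, sample count and cost name are assumptions):

score_method = storch.method.ScoreFunction("z", n_samples=2)
grads = []
for _ in range(100):
    z = score_method(OneHotCategorical(probs=probs))
    c = (2.4 * z * indices).sum(-1)
    storch.add_cost(c, "score_cost")
    storch.backward()
    # Reduce any sample plates and detach, mirroring the pattern in Example #12
    grads.append(storch.reduce_plates(z.grad["probs"]).detach_tensor())
eval(grads)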

Example #11
    # Tail of compute_f; the beginning of the function is not shown in this excerpt
    c = torch.tensor(0.23, requires_grad=True)
    d = a + b

    # Sample e from a normal distribution using reparameterization
    normal_distribution = Normal(b + c, 1)
    e = method(normal_distribution)

    f = d * e * e
    return f, c


# e*e follows a noncentral chi-squared distribution https://en.wikipedia.org/wiki/Noncentral_chi-squared_distribution
# exp_f = d * (1 + mu * mu)
repar = Reparameterization("e", n_samples=1)
f, c = compute_f(repar)
storch.add_cost(f, "f")
print(storch.backward())

print("first derivative estimate", c.grad)

f, c = compute_f(repar)
storch.add_cost(f, "f")
print(storch.backward())

print("second derivative estimate", c.grad)


def estimate_variance(method):
    gradient_samples = []
    for i in range(1000):
        f, c = compute_f(method)
Example #12
def train(epoch, model, train_loader, device, optimizer, args, writer):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        storch.reset()

        # Denote the minibatch dimension as being independent
        data = storch.denote_independent(data.view(-1, 784), 0, "data")
        recon_batch, KLD, z = model(data)
        storch.add_cost(loss_function(recon_batch, data), "reconstruction")
        cost = backward()
        train_loss += cost.item()

        optimizer.step()

        cond_log = batch_idx % args.log_interval == 0

        if cond_log:
            step = 100.0 * batch_idx / len(train_loader)
            global_step = 100 * (epoch - 1) + step

            # Variance of expect method is 0 by definition.
            variances = {}
            if args.method != "expect" and args.variance_samples > 1:
                _consider_param = "probs"
                if args.latents < 3:
                    old_method = model.sampling_method
                    model.sampling_method = Expect("z")
                    optimizer.zero_grad()
                    recon_batch, _, z = model(data)
                    storch.add_cost(loss_function(recon_batch, data),
                                    "reconstruction")
                    backward()
                    expect_grad = storch.reduce_plates(
                        z.grad[_consider_param]).detach_tensor()

                    optimizer.zero_grad()
                    model.sampling_method = old_method
                grads = {n: [] for n in z.grad}

                for i in range(args.variance_samples):
                    optimizer.zero_grad()
                    recon_batch, _, z = model(data)
                    storch.add_cost(loss_function(recon_batch, data),
                                    "reconstruction")
                    backward()

                    for param_name, grad in z.grad.items():
                        # Make sure to reduce the data dimension and detach, for memory reasons.
                        grads[param_name].append(
                            storch.reduce_plates(grad).detach_tensor())

                variances = {}
                for param_name, gradz in grads.items():
                    # Create a new independent dimension for the different gradient samples
                    grad_samples = storch.gather_samples(gradz, "variance")
                    # Compute the variance over this independent dimension
                    variances[param_name] = storch.variance(
                        grad_samples, "variance")._tensor
                    if param_name == _consider_param and args.latents < 3:
                        mean = storch.reduce_plates(grad_samples, "variance")
                        mse = storch.reduce_plates(
                            (grad_samples - expect_grad)**2).sum()
                        bias = (storch.reduce_plates(
                            (mean - expect_grad)**2)).sum()
                        print("mse", mse._tensor.item())
                        # Should approach 0 when increasing variance_samples for unbiased estimators.
                        print("bias", bias._tensor.item())
                        writer.add_scalar("train/probs_bias", bias._tensor,
                                          global_step)
                        writer.add_scalar("train/probs_mse", mse._tensor,
                                          global_step)

            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tCost: {:.6f}\t Logits var {}"
                .format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    step,
                    cost.item(),
                    variances,
                ))
            writer.add_scalar("train/ELBO", cost, global_step)
            for param_name, var in variances.items():
                writer.add_scalar("train/variance/" + param_name, var,
                                  global_step)
    avg_train_loss = train_loss / (batch_idx + 1)
    print("====> Epoch: {} Average loss: {:.4f}".format(epoch, avg_train_loss))
    return avg_train_loss