Example #1
    # Imports assumed by this snippet (Normal and Beta could equally come
    # from torch.distributions).
    import numpy as np
    import torch
    import pyro
    from pyro.distributions import Beta, Normal

    # 'model', 'guide', 'opt', 'genr', 'data' and 'device' are defined
    # elsewhere (generative model, variational guide, Pyro optimizer,
    # decoder network, data loader and torch device).
    svi = pyro.infer.SVI(model, guide, opt, loss=pyro.infer.Trace_ELBO())
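
    # A minimal sketch, assuming a guide of roughly this shape: the code below
    # reads back pyro.param('mu') and pyro.param('sd'), so the guide passed to
    # the SVI object above must register those two parameters and sample the
    # latent 'z' from them. The name 'guide_sketch' and the size 'latent_dim'
    # are illustrative only, not the definitions actually used in this example.
    from pyro.distributions import constraints

    def guide_sketch(obs, idx, latent_dim=16):
        n = obs.shape[0]
        mu = pyro.param('mu', torch.zeros(n, latent_dim, device=obs.device))
        sd = pyro.param('sd', torch.ones(n, latent_dim, device=obs.device),
                        constraint=constraints.positive)
        with pyro.plate('batch', n):
            pyro.sample('z', Normal(mu, sd).to_event(1))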

    X = 35  # Predict variable 35.
    for batch in data.batches(btchsz=256):
        b = batch.shape[0]  # Batch size.
        d = batch.shape[1]  # Dimension of the observations.
        # Pre-process the data to avoid NaNs at exact 0s and 1s.
        batch = torch.clamp(batch, min=.01, max=.99).to(device)
        # Remove variable to predict.
        idx = torch.tensor(np.delete(np.arange(d), X)).to(device)
        obs = batch[:, idx]
        pyro.clear_param_store()  # Important on every restart.
        loss = 0
        for step in range(2500):
            loss += svi.step(obs, idx)
            if (step + 1) % 10 == 0:
                # ELBO loss accumulated over the last 10 steps.
                print(loss)
                loss = 0
        # Inferred variational parameters.
        infmu = pyro.param('mu')
        infsd = pyro.param('sd')
        # Sample latent variables from the approximate posterior.
        z = Normal(infmu, infsd).sample([100]).view(100 * b, -1)
        # Propagate forward and sample the observable 'x' for variable X.
        with torch.no_grad():
            alpha, beta = genr.detfwd(z)
        xx = Beta(alpha[:, X:X + 1], beta[:, X:X + 1]).sample().view(100, -1, 1)
        # Monte Carlo estimate of the posterior-predictive mean (100 samples).
        x = xx.sum(dim=0) / 100.
        # Compare predictions with the held-out values side by side.
        print(torch.cat([x, batch[:, X:X + 1]], dim=1))
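        # Added illustration only: a single-number summary of the comparison
        # above (the evaluation loop below does the same thing more
        # thoroughly over several batches).
        mse_first = float(torch.mean((x - batch[:, X:X + 1])**2))
        print('MSE for variable', X, 'on this batch:', mse_first)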
        # Evaluate the squared reconstruction error over several batches.
        tot = 0.
        nbatch = 0
        btot = 0
        for batch in data.batches(btchsz=128):
            bsz = batch.shape[0]  # Batch size.
            d = batch.shape[1]  # Dimension of the observations.
            # Pre-process the data to avoid NaNs at exact 0s and 1s.
            batch = torch.clamp(batch, min=.01, max=.99).to(device)
            # Here only variable X itself is observed; the commented line
            # would instead condition on all the other variables.
            #idx = torch.tensor(np.delete(np.arange(d), X)).to(device)
            idx = X
            obs = batch[:, idx:idx + 1]
            pyro.clear_param_store()  # Important on every restart.
            for step in range(2000):
                svi.step(obs, idx)
            # Inferred parameters.
            infmu = pyro.param('mu')
            infsd = pyro.param('sd')
            # Sample latent variables from the approximate posterior.
            z = Normal(infmu, infsd).sample([1000]).view(1000 * bsz, -1)
            # Propagate forward and sample the observables 'x'.
            with torch.no_grad():
                alpha, beta = genr.detfwd(z)
            xx = Beta(alpha, beta).sample().view(1000, -1, d)
            #xx = Beta(alpha, beta).sample().view(1000, -1, 1)
            # Accumulate the squared error of the posterior-predictive mean.
            tot += float(torch.sum((xx.sum(dim=0) / 1000 - batch)**2))
            nbatch += 1
            btot += float(bsz)
            if nbatch >= 20:  # Evaluate on at most 20 batches.
                break
        # Mean per-example squared reconstruction error (summed over dimensions).
        print(X, tot / btot)
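
    # A minimal sketch of what 'genr.detfwd' is assumed to provide above: a
    # deterministic decoder that maps latent codes to the two positive Beta
    # parameters, one pair per observed dimension. The class 'GenSketch' and
    # the sizes 'latent_dim', 'd' and 'hidden' are illustrative assumptions,
    # not the decoder actually used in this example.
    import torch.nn as nn
    import torch.nn.functional as F

    class GenSketch(nn.Module):
        def __init__(self, latent_dim=16, d=64, hidden=128):
            super().__init__()
            self.body = nn.Sequential(nn.Linear(latent_dim, hidden), nn.ReLU())
            self.to_a = nn.Linear(hidden, d)
            self.to_b = nn.Linear(hidden, d)

        def detfwd(self, z):
            h = self.body(z)
            # Softplus keeps both Beta concentration parameters strictly positive.
            a = F.softplus(self.to_a(h)) + 1e-3
            b = F.softplus(self.to_b(h)) + 1e-3
            return a, b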