Example #1
File: test_optim.py Project: pyro-ppl/pyro
    def do_test_per_param_optim(self, fixed_param, free_param):
        pyro.clear_param_store()

        def model():
            prior_dist = Normal(self.loc0, torch.pow(self.lam0, -0.5))
            loc_latent = pyro.sample("loc_latent", prior_dist)
            x_dist = Normal(loc_latent, torch.pow(self.lam, -0.5))
            pyro.sample("obs", x_dist, obs=self.data)
            return loc_latent

        def guide():
            loc_q = pyro.param("loc_q", torch.zeros(1, requires_grad=True))
            log_sig_q = pyro.param("log_sig_q",
                                   torch.zeros(1, requires_grad=True))
            sig_q = torch.exp(log_sig_q)
            pyro.sample("loc_latent", Normal(loc_q, sig_q))

        def optim_params(param_name):
            if param_name == fixed_param:
                return {"lr": 0.00}
            elif param_name == free_param:
                return {"lr": 0.01}

        adam = optim.Adam(optim_params)
        adam2 = optim.Adam(optim_params)
        svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
        svi2 = SVI(model, guide, adam2, loss=TraceGraph_ELBO())

        svi.step()
        adam_initial_step_count = list(
            adam.get_state()["loc_q"]["state"].items())[0][1]["step"]
        with TemporaryDirectory() as tempdir:
            filename = os.path.join(tempdir, "optimizer_state.pt")
            adam.save(filename)
            svi.step()
            adam_final_step_count = list(
                adam.get_state()["loc_q"]["state"].items())[0][1]["step"]
            adam2.load(filename)
        svi2.step()
        adam2_step_count_after_load_and_step = list(
            adam2.get_state()["loc_q"]["state"].items())[0][1]["step"]

        assert adam_initial_step_count == 1
        assert adam_final_step_count == 2
        assert adam2_step_count_after_load_and_step == 2

        free_param_unchanged = torch.equal(
            pyro.param(free_param).data, torch.zeros(1))
        fixed_param_unchanged = torch.equal(
            pyro.param(fixed_param).data, torch.zeros(1))
        assert fixed_param_unchanged and not free_param_unchanged
Example #2
    def do_test_per_param_optim(self, fixed_param, free_param):
        pyro.clear_param_store()

        def model():
            prior_dist = Normal(self.loc0, torch.pow(self.lam0, -0.5))
            loc_latent = pyro.sample("loc_latent", prior_dist)
            x_dist = Normal(loc_latent, torch.pow(self.lam, -0.5))
            pyro.sample("obs", x_dist, obs=self.data)
            return loc_latent

        def guide():
            loc_q = pyro.param("loc_q", torch.zeros(1, requires_grad=True))
            log_sig_q = pyro.param("log_sig_q",
                                   torch.zeros(1, requires_grad=True))
            sig_q = torch.exp(log_sig_q)
            pyro.sample("loc_latent", Normal(loc_q, sig_q))

        def optim_params(module_name, param_name):
            if param_name == fixed_param:
                return {'lr': 0.00}
            elif param_name == free_param:
                return {'lr': 0.01}

        adam = optim.Adam(optim_params)
        adam2 = optim.Adam(optim_params)
        svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
        svi2 = SVI(model, guide, adam2, loss=TraceGraph_ELBO())

        svi.step()
        adam_initial_step_count = list(
            adam.get_state()['loc_q']['state'].items())[0][1]['step']
        adam.save('adam.unittest.save')
        svi.step()
        adam_final_step_count = list(
            adam.get_state()['loc_q']['state'].items())[0][1]['step']
        adam2.load('adam.unittest.save')
        svi2.step()
        adam2_step_count_after_load_and_step = list(
            adam2.get_state()['loc_q']['state'].items())[0][1]['step']

        assert adam_initial_step_count == 1
        assert adam_final_step_count == 2
        assert adam2_step_count_after_load_and_step == 2

        free_param_unchanged = torch.equal(
            pyro.param(free_param).data, torch.zeros(1))
        fixed_param_unchanged = torch.equal(
            pyro.param(fixed_param).data, torch.zeros(1))
        assert fixed_param_unchanged and not free_param_unchanged
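Examples #1 and #2 above rely on pyro.optim.Adam accepting either a fixed options dict or a callable that returns per-parameter options; the callable's signature differs across Pyro versions (newer releases pass only the parameter name, older ones also pass the owning module's name). A minimal sketch of both forms, with hypothetical parameter names and learning rates:

from pyro import optim

# Fixed options: the same settings apply to every parameter in the store.
adam_fixed = optim.Adam({"lr": 0.01, "betas": (0.9, 0.999)})

# Per-parameter options: recent Pyro passes the parameter name to the callable.
def per_param_args(param_name):
    # Hypothetical rule: freeze "log_sig_q" by giving it a zero learning rate.
    return {"lr": 0.0} if param_name == "log_sig_q" else {"lr": 0.01}

adam_per_param = optim.Adam(per_param_args)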
Example #3
    def do_elbo_test(self, reparameterized, n_steps, beta1, lr):
        logger.info(" - - - - - DO BETA-BERNOULLI ELBO TEST [repa = %s] - - - - - " % reparameterized)
        pyro.clear_param_store()
        Beta = dist.Beta if reparameterized else fakes.NonreparameterizedBeta

        def model():
            p_latent = pyro.sample("p_latent", Beta(self.alpha0, self.beta0))
            with pyro.plate("data", len(self.data)):
                pyro.sample("obs", dist.Bernoulli(p_latent), obs=self.data)
            return p_latent

        def guide():
            alpha_q_log = pyro.param("alpha_q_log",
                                     self.log_alpha_n + 0.17)
            beta_q_log = pyro.param("beta_q_log",
                                    self.log_beta_n - 0.143)
            alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
            p_latent = pyro.sample("p_latent", Beta(alpha_q, beta_q),
                                   infer=dict(baseline=dict(use_decaying_avg_baseline=True)))
            with pyro.plate("data", len(self.data)):
                pass
            return p_latent

        adam = optim.Adam({"lr": lr, "betas": (beta1, 0.999)})
        svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

        for k in range(n_steps):
            svi.step()
            alpha_error = param_abs_error("alpha_q_log", self.log_alpha_n)
            beta_error = param_abs_error("beta_q_log", self.log_beta_n)
            if k % 500 == 0:
                logger.debug("alpha_error, beta_error: %.4f, %.4f" % (alpha_error, beta_error))

        assert_equal(0.0, alpha_error, prec=0.03)
        assert_equal(0.0, beta_error, prec=0.04)
Example #4
def test_dynamic_lr(scheduler, num_steps):
    pyro.clear_param_store()

    def model():
        sample = pyro.sample('latent',
                             Normal(torch.tensor(0.), torch.tensor(0.3)))
        return pyro.sample('obs',
                           Normal(sample, torch.tensor(0.2)),
                           obs=torch.tensor(0.1))

    def guide():
        loc = pyro.param('loc', torch.tensor(0.))
        scale = pyro.param('scale',
                           torch.tensor(0.5),
                           constraint=constraints.positive)
        pyro.sample('latent', Normal(loc, scale))

    svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO())
    for epoch in range(2):
        scheduler.set_epoch(epoch)
        for _ in range(num_steps):
            svi.step()
        if epoch == 1:
            loc = pyro.param('loc').unconstrained()
            scale = pyro.param('scale').unconstrained()
            opt = scheduler.optim_objs[loc].optimizer
            assert opt.state_dict()['param_groups'][0]['lr'] == 0.02
            assert opt.state_dict()['param_groups'][0]['initial_lr'] == 0.01
            opt = scheduler.optim_objs[scale].optimizer
            assert opt.state_dict()['param_groups'][0]['lr'] == 0.02
            assert opt.state_dict()['param_groups'][0]['initial_lr'] == 0.01
            assert abs(pyro.param('loc').item()) > 1e-5
            assert abs(pyro.param('scale').item() - 0.5) > 1e-5
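The scheduler passed to test_dynamic_lr above comes from a pytest fixture that is not shown here. A sketch of one scheduler that would satisfy the assertions (initial_lr of 0.01, doubled after the first epoch), assuming Pyro's StepLR wrapper around torch.optim.Adam:

import torch
from pyro import optim

# Hypothetical fixture: a Pyro-wrapped torch StepLR scheduler that starts at
# lr=0.01 and multiplies the learning rate by 2 every epoch.
scheduler = optim.StepLR({
    "optimizer": torch.optim.Adam,
    "optim_args": {"lr": 0.01},
    "gamma": 2.0,
    "step_size": 1,
})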
Example #5
def test_three_indep_iarange_at_different_depths_ok():
    """
      /\
     /\ ia
    ia ia
    """
    def model():
        p = torch.tensor(0.5)
        inner_iarange = pyro.iarange("iarange1", 10, 5)
        for i in pyro.irange("irange0", 2):
            pyro.sample("x_%d" % i, dist.Bernoulli(p))
            if i == 0:
                for j in pyro.irange("irange1", 2):
                    with inner_iarange as ind:
                        pyro.sample("y_%d" % j,
                                    dist.Bernoulli(p).expand_by([len(ind)]))
            elif i == 1:
                with inner_iarange as ind:
                    pyro.sample("z", dist.Bernoulli(p).expand_by([len(ind)]))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        inner_iarange = pyro.iarange("iarange1", 10, 5)
        for i in pyro.irange("irange0", 2):
            pyro.sample("x_%d" % i, dist.Bernoulli(p))
            if i == 0:
                for j in pyro.irange("irange1", 2):
                    with inner_iarange as ind:
                        pyro.sample("y_%d" % j,
                                    dist.Bernoulli(p).expand_by([len(ind)]))
            elif i == 1:
                with inner_iarange as ind:
                    pyro.sample("z", dist.Bernoulli(p).expand_by([len(ind)]))

    assert_ok(model, guide, TraceGraph_ELBO())
Example #6
    def do_inference(self, use_decaying_avg_baseline, tolerance=0.8):
        pyro.clear_param_store()
        optimizer_params = {"lr": 0.0005, "betas": (0.93, 0.999)}
        #optimizer = optim.Adam(optimizer_params)
        optimizer = optim.Adam(self.per_param_args)
        svi = SVI(self.model, self.nnguide, optimizer, loss=TraceGraph_ELBO())
        print("Doing inference with use_decaying_avg_baseline=%s" %
              use_decaying_avg_baseline)

        ae, be = [], []
        # do up to this many steps of inference
        for k in range(self.max_steps):
            svi.step(use_decaying_avg_baseline, torch.tensor([1.0]))
            if k % 100 == 0:
                print('.', end='')
                sys.stdout.flush()

            # compute the distance to the parameters of the true posterior
            alpha_error = param_abs_error("alpha_q", self.alpha_n)
            beta_error = param_abs_error("beta_q", self.beta_n)

            ae.append(alpha_error)
            be.append(beta_error)
            # stop inference early if we're close to the true posterior
            if alpha_error < tolerance and beta_error < tolerance:
                print("Stopped early at step: {}".format(k))
                break
        return (ae, be)
Example #7
def main(model, guide, args):
    # init
    if args.seed is not None: pyro.set_rng_seed(args.seed)
    logger = get_logger(args.log, __name__)
    logger.info(args)

    # load data
    X, true_counts = load_data()
    X_size = X.size(0)

    # setup svi
    def per_param_optim_args(module_name, param_name):
        def isBaselineParam(module_name, param_name):
            return 'bl_' in module_name or 'bl_' in param_name
        lr = args.baseline_learning_rate if isBaselineParam(module_name, param_name)\
             else args.learning_rate
        return {'lr': lr}

    opt = optim.Adam(per_param_optim_args)
    svi = SVI(model.main, guide.main, opt, loss=TraceGraph_ELBO())

    # train
    times = [time.time()]
    logger.info(f"\nstep\t" + "epoch\t" + "elbo\t" + "time(sec)")

    for i in range(1, args.num_steps + 1):
        loss = svi.step(X)

        if (args.eval_frequency > 0
                and i % args.eval_frequency == 0) or (i == 1):
            times.append(time.time())
            logger.info(f"{i:06d}\t"
                        f"{(i * args.batch_size) / X_size:.3f}\t"
                        f"{-loss / X_size:.4f}\t"
                        f"{times[-1]-times[-2]:.3f}")
Example #8
def main(**kwargs):
    args = argparse.Namespace(**kwargs)
    args.batch_size = 64
    
    pyro.set_rng_seed(args.seed)

    X, true_counts = load_data()
    X_size = X.size(0)

    def per_param_optim_args(module_name, param_name):
        def isBaselineParam(module_name, param_name):
            return 'bl_' in module_name or 'bl_' in param_name
        lr = args.baseline_learning_rate if isBaselineParam(module_name, param_name)\
             else args.learning_rate
        return {'lr': lr}

    adam = optim.Adam(per_param_optim_args)
    elbo = TraceGraph_ELBO()
    svi = SVI(air.model, air.guide, adam, loss=elbo) # wy

    t0 = time.time()
    for i in range(1, args.num_steps + 1):
        loss = svi.step(X) # wy

        if args.progress_every > 0 and i % args.progress_every == 0:
            print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
                i,
                (i * args.batch_size) / X_size,
                (time.time() - t0) / 3600,
                loss / X_size))

        if args.eval_every > 0 and i % args.eval_every == 0:
            acc, counts, error_z, error_ix = count_accuracy(X, true_counts, air, 1000)
            print('i={}, accuracy={}, counts={}'.format(i, acc, counts.numpy().tolist()))
Example #9
    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        elbo = TraceGraph_ELBO(vectorize_particles=False, num_particles=4)
        svi = SVI(self.model.model, self.model.guide, self.optimizer, loss=elbo)
        imps = ImportanceSampler(self.model.model, self.model.guide,
                                 num_samples=4)

        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                loss = svi.evaluate_loss(observations=data) / data.shape[0]
                imps.sample(observations=data)
                log_likelihood = imps.get_log_likelihood().item() / data.shape[0]
                log_marginal = imps.get_log_normalizer().item() / data.shape[0]

                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss)
                self.valid_metrics.update('log_likelihood', log_likelihood)
                self.valid_metrics.update('log_marginal', log_marginal)

                for met in self.metric_ftns:
                    metric_val = met(self.model.model, self.model.guide, data, target, 4)
                    self.valid_metrics.update(met.__name__, metric_val)

                if self.log_images:
                    self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

        return self.valid_metrics.result()
Example #10
    def test_nested_irange_in_elbo(self, n_steps=4000):
        pyro.clear_param_store()

        def model():
            loc_latent = pyro.sample(
                "loc_latent",
                fakes.NonreparameterizedNormal(self.loc0,
                                               torch.pow(self.lam0,
                                                         -0.5)).independent(1))
            for i in pyro.irange("outer", self.n_outer):
                for j in pyro.irange("inner_%d" % i, self.n_inner):
                    pyro.sample("obs_%d_%d" % (i, j),
                                dist.Normal(loc_latent,
                                            torch.pow(self.lam,
                                                      -0.5)).independent(1),
                                obs=self.data[i][j])

        def guide():
            loc_q = pyro.param(
                "loc_q",
                torch.tensor(self.analytic_loc_n.expand(2) + 0.234,
                             requires_grad=True))
            log_sig_q = pyro.param(
                "log_sig_q",
                torch.tensor(self.analytic_log_sig_n.expand(2) - 0.27,
                             requires_grad=True))
            sig_q = torch.exp(log_sig_q)
            pyro.sample(
                "loc_latent",
                fakes.NonreparameterizedNormal(loc_q, sig_q).independent(1),
                infer=dict(baseline=dict(use_decaying_avg_baseline=True)))

            for i in pyro.irange("outer", self.n_outer):
                for j in pyro.irange("inner_%d" % i, self.n_inner):
                    pass

        guide_trace = pyro.poutine.trace(guide, graph_type="dense").get_trace()
        model_trace = pyro.poutine.trace(pyro.poutine.replay(
            model, trace=guide_trace),
                                         graph_type="dense").get_trace()
        assert len(model_trace.edges()) == 27
        assert len(model_trace.nodes()) == 16
        assert len(guide_trace.edges()) == 0
        assert len(guide_trace.nodes()) == 9

        adam = optim.Adam({"lr": 0.0008, "betas": (0.96, 0.999)})
        svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

        for k in range(n_steps):
            svi.step()
            loc_error = param_mse("loc_q", self.analytic_loc_n)
            log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n)
            if k % 500 == 0:
                logger.debug("loc error, log(scale) error:  %.4f, %.4f" %
                             (loc_error, log_sig_error))

        assert_equal(0.0, loc_error, prec=0.04)
        assert_equal(0.0, log_sig_error, prec=0.04)
Example #11
def main(**kwargs):
    args = argparse.Namespace(**kwargs)
    args.batch_size = 64

    # WL: edited to fix a bug. =====
    #pyro.set_rng_seed(args.seed)
    if args.seed is not None:
        pyro.set_rng_seed(args.seed)
    # ==============================

    X, true_counts = load_data()
    X_size = X.size(0)

    def per_param_optim_args(module_name, param_name):
        def isBaselineParam(module_name, param_name):
            return 'bl_' in module_name or 'bl_' in param_name
        lr = args.baseline_learning_rate if isBaselineParam(module_name, param_name)\
             else args.learning_rate
        return {'lr': lr}

    adam = optim.Adam(per_param_optim_args)
    elbo = TraceGraph_ELBO()
    svi = SVI(air.model, air.guide, adam, loss=elbo)  # wy

    # WL: added for test. =====
    print(args)
    times = [time.time()]
    print("\nstep\t" + "epoch\t" + "elbo\t" + "time(sec)", flush=True)
    #==========================

    t0 = time.time()
    for i in range(1, args.num_steps + 1):
        loss = svi.step(X)  # wy

        if args.progress_every > 0 and i % args.progress_every == 0:
            # # WL: edited for test. =====
            # print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
            #     i,
            #     (i * args.batch_size) / X_size,
            #     (time.time() - t0) / 3600,
            #     loss / X_size))
            times.append(time.time())
            print(
                f"{i:06d}\t"
                f"{(i * args.batch_size) / X_size:.3f}\t"
                f"{-loss / X_size:.4f}\t"
                f"{times[-1]-times[-2]:.3f}",
                flush=True)
            # ===========================

        if args.eval_every > 0 and i % args.eval_every == 0:
            acc, counts, error_z, error_ix = count_accuracy(
                X, true_counts, air, 1000)
            print('i={}, accuracy={}, counts={}'.format(
                i, acc,
                counts.numpy().tolist()))
Example #12
def test_iarange_wrong_size_error():
    def model():
        p = torch.tensor(0.5)
        with pyro.iarange("iarange", 10, 5) as ind:
            pyro.sample("x", dist.Bernoulli(p).expand_by([1 + len(ind)]))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        with pyro.iarange("iarange", 10, 5) as ind:
            pyro.sample("x", dist.Bernoulli(p).expand_by([1 + len(ind)]))

    assert_error(model, guide, TraceGraph_ELBO())
Example #13
File: test_optim.py Project: pyro-ppl/pyro
def test_dynamic_lr(scheduler):
    pyro.clear_param_store()

    def model():
        sample = pyro.sample("latent",
                             Normal(torch.tensor(0.0), torch.tensor(0.3)))
        return pyro.sample("obs",
                           Normal(sample, torch.tensor(0.2)),
                           obs=torch.tensor(0.1))

    def guide():
        loc = pyro.param("loc", torch.tensor(0.0))
        scale = pyro.param("scale",
                           torch.tensor(0.5),
                           constraint=constraints.positive)
        pyro.sample("latent", Normal(loc, scale))

    svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO())
    for epoch in range(4):
        svi.step()
        svi.step()
        loc = pyro.param("loc").unconstrained()
        opt_loc = scheduler.optim_objs[loc].optimizer
        opt_scale = scheduler.optim_objs[loc].optimizer
        if issubclass(
                scheduler.pt_scheduler_constructor,
                torch.optim.lr_scheduler.ReduceLROnPlateau,
        ):
            scheduler.step(1.0)
            if epoch == 2:
                assert opt_loc.state_dict()["param_groups"][0]["lr"] == 0.1
                assert opt_scale.state_dict()["param_groups"][0]["lr"] == 0.1
            if epoch == 4:
                assert opt_loc.state_dict()["param_groups"][0]["lr"] == 0.01
                assert opt_scale.state_dict()["param_groups"][0]["lr"] == 0.01
            continue
        assert opt_loc.state_dict()["param_groups"][0]["initial_lr"] == 0.01
        assert opt_scale.state_dict()["param_groups"][0]["initial_lr"] == 0.01
        if epoch == 0:
            scheduler.step()
            assert opt_loc.state_dict()["param_groups"][0]["lr"] == 0.02
            assert opt_scale.state_dict()["param_groups"][0]["lr"] == 0.02
            assert abs(pyro.param("loc").item()) > 1e-5
            assert abs(pyro.param("scale").item() - 0.5) > 1e-5
        if epoch == 2:
            scheduler.step()
            assert opt_loc.state_dict()["param_groups"][0]["lr"] == 0.04
            assert opt_scale.state_dict()["param_groups"][0]["lr"] == 0.04
Example #14
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        if self.jit:
            elbo = JitTraceGraph_ELBO(vectorize_particles=False,
                                      num_particles=self.num_particles)
        else:
            elbo = TraceGraph_ELBO(vectorize_particles=False,
                                   num_particles=self.num_particles)
        svi = SVI(self.model.model,
                  self.model.guide,
                  self.optimizer,
                  loss=elbo)

        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            loss = svi.step(observations=data)

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(target))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))
                self.writer.add_image(
                    'input', make_grid(data.cpu(), nrow=8, normalize=True))

            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log
Example #15
    def do_elbo_test(self, reparameterized, n_steps, prec):
        logger.info(
            " - - - - - DO NORMALNORMAL ELBO TEST  [reparameterized = %s] - - - - - "
            % reparameterized)
        pyro.clear_param_store()
        Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal

        def model():
            with pyro.iarange("iarange", 2):
                loc_latent = pyro.sample(
                    "loc_latent", Normal(self.loc0, torch.pow(self.lam0,
                                                              -0.5)))
                for i, x in enumerate(self.data):
                    pyro.sample("obs_%d" % i,
                                dist.Normal(loc_latent,
                                            torch.pow(self.lam, -0.5)),
                                obs=x)
            return loc_latent

        def guide():
            loc_q = pyro.param(
                "loc_q",
                torch.tensor(self.analytic_loc_n.expand(2) + 0.334,
                             requires_grad=True))
            log_sig_q = pyro.param(
                "log_sig_q",
                torch.tensor(self.analytic_log_sig_n.expand(2) - 0.29,
                             requires_grad=True))
            sig_q = torch.exp(log_sig_q)
            with pyro.iarange("iarange", 2):
                loc_latent = pyro.sample("loc_latent", Normal(loc_q, sig_q))
            return loc_latent

        adam = optim.Adam({"lr": .0015, "betas": (0.97, 0.999)})
        svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

        for k in range(n_steps):
            svi.step()

            loc_error = param_mse("loc_q", self.analytic_loc_n)
            log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n)
            if k % 250 == 0:
                logger.debug("loc error, log(scale) error:  %.4f, %.4f" %
                             (loc_error, log_sig_error))

        assert_equal(0.0, loc_error, prec=prec)
        assert_equal(0.0, log_sig_error, prec=prec)
Example #16
    def on_epoch_start(self):
        total_epochs = self._n_epochs
        current_epoch = self._current_epoch

        if current_epoch > total_epochs * self.emb_train_epochs:
            if not self.start_finetune:
                pyro.clear_param_store()
                print('Switching to pyro')
                self.start_finetune = True
                self.optimizer = Adam({"lr": 1e-4})
                self.svi = SVI(self.model_wrapper.pyro_model,
                               self.model_wrapper.pyro_guide,
                               self.optimizer,
                               loss=TraceGraph_ELBO())
        else:
            if self.optimizer is None:
                self.optimizer = torch.optim.Adam(
                    self.model_wrapper.model.parameters())
Example #17
    def do_elbo_test(self, reparameterized, n_steps, beta1, lr):
        logger.info(
            " - - - - - DO EXPONENTIAL-GAMMA ELBO TEST [repa = %s] - - - - - "
            % reparameterized)
        pyro.clear_param_store()
        Gamma = dist.Gamma if reparameterized else fakes.NonreparameterizedGamma

        def model():
            lambda_latent = pyro.sample("lambda_latent",
                                        Gamma(self.alpha0, self.beta0))
            with pyro.iarange("data", len(self.data)):
                pyro.sample("obs",
                            dist.Exponential(lambda_latent),
                            obs=self.data)
            return lambda_latent

        def guide():
            alpha_q_log = pyro.param(
                "alpha_q_log",
                torch.tensor(self.log_alpha_n + 0.17, requires_grad=True))
            beta_q_log = pyro.param(
                "beta_q_log",
                torch.tensor(self.log_beta_n - 0.143, requires_grad=True))
            alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
            pyro.sample(
                "lambda_latent",
                Gamma(alpha_q, beta_q),
                infer=dict(baseline=dict(use_decaying_avg_baseline=True)))
            with pyro.iarange("data", len(self.data)):
                pass

        adam = optim.Adam({"lr": lr, "betas": (beta1, 0.999)})
        svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

        for k in range(n_steps):
            svi.step()
            alpha_error = param_abs_error("alpha_q_log", self.log_alpha_n)
            beta_error = param_abs_error("beta_q_log", self.log_beta_n)
            if k % 500 == 0:
                logger.debug("alpha_error, beta_error: %.4f, %.4f" %
                             (alpha_error, beta_error))

        assert_equal(0.0, alpha_error, prec=0.04)
        assert_equal(0.0, beta_error, prec=0.04)
Example #18
    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        if self.jit:
            elbo = JitTraceGraph_ELBO(vectorize_particles=False,
                                      num_particles=self.num_particles)
        else:
            elbo = TraceGraph_ELBO(vectorize_particles=False,
                                   num_particles=self.num_particles)
        svi = SVI(self.model.model,
                  self.model.guide,
                  self.optimizer,
                  loss=elbo)

        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                loss = svi.evaluate_loss(observations=data)

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(target))
                self.writer.add_image(
                    'input', make_grid(data.cpu(), nrow=8, normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()
Example #19
    def do_elbo_test(
        self,
        repa1,
        repa2,
        n_steps,
        prec,
        lr,
        use_nn_baseline,
        use_decaying_avg_baseline,
    ):
        logger.info(" - - - - - DO NORMALNORMALNORMAL ELBO TEST - - - - - -")
        logger.info(
            "[reparameterized = %s, %s; nn_baseline = %s, decaying_baseline = %s]"
            % (repa1, repa2, use_nn_baseline, use_decaying_avg_baseline))
        pyro.clear_param_store()
        Normal1 = dist.Normal if repa1 else fakes.NonreparameterizedNormal
        Normal2 = dist.Normal if repa2 else fakes.NonreparameterizedNormal

        if use_nn_baseline:

            class VanillaBaselineNN(nn.Module):
                def __init__(self, dim_input, dim_h):
                    super().__init__()
                    self.lin1 = nn.Linear(dim_input, dim_h)
                    self.lin2 = nn.Linear(dim_h, 2)
                    self.sigmoid = nn.Sigmoid()

                def forward(self, x):
                    h = self.sigmoid(self.lin1(x))
                    return self.lin2(h)

            loc_prime_baseline = pyro.module("loc_prime_baseline",
                                             VanillaBaselineNN(2, 5))
        else:
            loc_prime_baseline = None

        def model():
            with pyro.plate("plate", 2):
                loc_latent_prime = pyro.sample(
                    "loc_latent_prime",
                    Normal1(self.loc0, torch.pow(self.lam0, -0.5)))
                loc_latent = pyro.sample(
                    "loc_latent",
                    Normal2(loc_latent_prime, torch.pow(self.lam0, -0.5)))
                with pyro.plate("data", len(self.data)):
                    pyro.sample(
                        "obs",
                        dist.Normal(loc_latent, torch.pow(
                            self.lam, -0.5)).expand_by(self.data.shape[:1]),
                        obs=self.data,
                    )
            return loc_latent

        # note that the exact posterior is not mean field!
        def guide():
            loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.334)
            log_sig_q = pyro.param("log_sig_q",
                                   self.analytic_log_sig_n.expand(2) - 0.29)
            loc_q_prime = pyro.param("loc_q_prime", torch.tensor([-0.34,
                                                                  0.52]))
            kappa_q = pyro.param("kappa_q", torch.tensor([0.74]))
            log_sig_q_prime = pyro.param("log_sig_q_prime",
                                         -0.5 * torch.log(1.2 * self.lam0))
            sig_q, sig_q_prime = torch.exp(log_sig_q), torch.exp(
                log_sig_q_prime)
            with pyro.plate("plate", 2):
                loc_latent = pyro.sample(
                    "loc_latent",
                    Normal2(loc_q, sig_q),
                    infer=dict(baseline=dict(
                        use_decaying_avg_baseline=use_decaying_avg_baseline)),
                )
                pyro.sample(
                    "loc_latent_prime",
                    Normal1(
                        kappa_q.expand_as(loc_latent) * loc_latent +
                        loc_q_prime,
                        sig_q_prime,
                    ),
                    infer=dict(baseline=dict(
                        nn_baseline=loc_prime_baseline,
                        nn_baseline_input=loc_latent,
                        use_decaying_avg_baseline=use_decaying_avg_baseline,
                    )),
                )
                with pyro.plate("data", len(self.data)):
                    pass

            return loc_latent

        adam = optim.Adam({"lr": 0.0015, "betas": (0.97, 0.999)})
        svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

        for k in range(n_steps):
            svi.step()

            loc_error = param_mse("loc_q", self.analytic_loc_n)
            log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n)
            loc_prime_error = param_mse("loc_q_prime", 0.5 * self.loc0)
            kappa_error = param_mse("kappa_q", 0.5 * torch.ones(1))
            log_sig_prime_error = param_mse("log_sig_q_prime",
                                            -0.5 * torch.log(2.0 * self.lam0))

            if k % 500 == 0:
                logger.debug("errors:  %.4f, %.4f" %
                             (loc_error, log_sig_error))
                logger.debug(", %.4f, %.4f" %
                             (loc_prime_error, log_sig_prime_error))
                logger.debug(", %.4f" % kappa_error)

        assert_equal(0.0, loc_error, prec=prec)
        assert_equal(0.0, log_sig_error, prec=prec)
        assert_equal(0.0, loc_prime_error, prec=prec)
        assert_equal(0.0, log_sig_prime_error, prec=prec)
        assert_equal(0.0, kappa_error, prec=prec)
Example #20
            global init_state
            init_state = reset_init_state()
            survive = guide(500)
            results.append(survive)
        self.echo = False
        if np.mean(results) > imme_time * 0.9 and imme_time < MAXTIME:
            imme_time = imme_time * 2
            print("update training max_time to", imme_time)


agent = AgentModel()
guide = agent.guide
model = agent.model
learning_rate = 2e-5
optimizer = optim.Adam({"lr": learning_rate})
svi = SVI(model, guide, optimizer, loss=TraceGraph_ELBO())


def optimize():
    global imme_time
    loss = 0
    print("Optimizing...")
    for t in range(num_steps):
        global init_state
        init_state = reset_init_state()
        loss += svi.step(imme_time)
        if (t % 1000 == 0) and (t > 0):
            print("at {} step loss is {}".format(t, loss / t))


def train(epoch=2, batch_size=10):
Example #21
    def do_elbo_test(self,
                     reparameterized,
                     n_steps,
                     lr,
                     prec,
                     beta1,
                     difficulty=1.0,
                     model_permutation=False):
        n_repa_nodes = torch.sum(self.which_nodes_reparam) if not reparameterized \
            else len(self.q_topo_sort)
        logger.info((
            " - - - DO GAUSSIAN %d-LAYERED PYRAMID ELBO TEST " +
            "(with a total of %d RVs) [reparameterized=%s; %d/%d; perm=%s] - - -"
        ) % (self.N, (2**self.N) - 1, reparameterized, n_repa_nodes,
             len(self.q_topo_sort), model_permutation))
        pyro.clear_param_store()
        # check graph structure is as expected but only for N=2
        if self.N == 2:
            guide_trace = pyro.poutine.trace(
                self.guide, graph_type="dense").get_trace(
                    reparameterized=reparameterized,
                    model_permutation=model_permutation,
                    difficulty=difficulty)
            expected_nodes = set([
                'log_sig_1R', 'kappa_1_1L', '_INPUT',
                'constant_term_loc_latent_1R', '_RETURN', 'loc_latent_1R',
                'loc_latent_1', 'constant_term_loc_latent_1', 'loc_latent_1L',
                'constant_term_loc_latent_1L', 'log_sig_1L', 'kappa_1_1R',
                'kappa_1R_1L', 'log_sig_1'
            ])
            expected_edges = set([('loc_latent_1R', 'loc_latent_1'),
                                  ('loc_latent_1L', 'loc_latent_1R'),
                                  ('loc_latent_1L', 'loc_latent_1')])
            assert expected_nodes == set(guide_trace.nodes)
            assert expected_edges == set(guide_trace.edges)

        adam = optim.Adam({"lr": lr, "betas": (beta1, 0.999)})
        svi = SVI(self.model, self.guide, adam, loss=TraceGraph_ELBO())

        for step in range(n_steps):
            t0 = time.time()
            svi.step(reparameterized=reparameterized,
                     model_permutation=model_permutation,
                     difficulty=difficulty)

            if step % 50 == 0 or step == n_steps - 1:
                log_sig_errors = []
                for node in self.target_lambdas:
                    target_log_sig = -0.5 * torch.log(
                        self.target_lambdas[node])
                    log_sig_error = param_mse('log_sig_' + node,
                                              target_log_sig)
                    log_sig_errors.append(log_sig_error)
                max_log_sig_error = np.max(log_sig_errors)
                min_log_sig_error = np.min(log_sig_errors)
                mean_log_sig_error = np.mean(log_sig_errors)
                leftmost_node = self.q_topo_sort[0]
                leftmost_constant_error = param_mse(
                    'constant_term_' + leftmost_node,
                    self.target_leftmost_constant)
                almost_leftmost_constant_error = param_mse(
                    'constant_term_' + leftmost_node[:-1] + 'R',
                    self.target_almost_leftmost_constant)

                logger.debug(
                    "[mean function constant errors (partial)]   %.4f  %.4f" %
                    (leftmost_constant_error, almost_leftmost_constant_error))
                logger.debug(
                    "[min/mean/max log(scale) errors]   %.4f  %.4f   %.4f" %
                    (min_log_sig_error, mean_log_sig_error, max_log_sig_error))
                logger.debug("[step time = %.3f;  N = %d;  step = %d]\n" %
                             (time.time() - t0, self.N, step))

        assert_equal(0.0, max_log_sig_error, prec=prec)
        assert_equal(0.0, leftmost_constant_error, prec=prec)
        assert_equal(0.0, almost_leftmost_constant_error, prec=prec)
Example #22
    def do_elbo_test(self, reparameterized, n_steps, lr, prec, difficulty=1.0):
        n_repa_nodes = torch.sum(
            self.which_nodes_reparam) if not reparameterized else self.N
        logger.info(
            " - - - - - DO GAUSSIAN %d-CHAIN ELBO TEST  [reparameterized = %s; %d/%d] - - - - - "
            % (self.N, reparameterized, n_repa_nodes, self.N))
        if self.N < 0:

            def array_to_string(y):
                return str(
                    map(lambda x: "%.3f" % x.detach().cpu().numpy()[0], y))

            logger.debug("lambdas: " + array_to_string(self.lambdas))
            logger.debug("target_mus: " + array_to_string(self.target_mus[1:]))
            logger.debug("target_kappas: "******"lambda_posts: " +
                         array_to_string(self.lambda_posts[1:]))
            logger.debug("lambda_tilde_posts: " +
                         array_to_string(self.lambda_tilde_posts))
            pyro.clear_param_store()

        adam = optim.Adam({"lr": lr, "betas": (0.95, 0.999)})
        elbo = TraceGraph_ELBO()
        loss_and_grads = elbo.loss_and_grads
        # loss_and_grads = elbo.jit_loss_and_grads  # This fails.
        svi = SVI(self.model,
                  self.guide,
                  adam,
                  loss=elbo.loss,
                  loss_and_grads=loss_and_grads)

        for step in range(n_steps):
            t0 = time.time()
            svi.step(reparameterized=reparameterized, difficulty=difficulty)

            if step % 5000 == 0 or step == n_steps - 1:
                kappa_errors, log_sig_errors, loc_errors = [], [], []
                for k in range(1, self.N + 1):
                    if k != self.N:
                        kappa_error = param_mse("kappa_q_%d" % k,
                                                self.target_kappas[k])
                        kappa_errors.append(kappa_error)

                    loc_errors.append(
                        param_mse("loc_q_%d" % k, self.target_mus[k]))
                    log_sig_error = param_mse(
                        "log_sig_q_%d" % k,
                        -0.5 * torch.log(self.lambda_posts[k]))
                    log_sig_errors.append(log_sig_error)

                max_errors = (np.max(loc_errors), np.max(log_sig_errors),
                              np.max(kappa_errors))
                min_errors = (np.min(loc_errors), np.min(log_sig_errors),
                              np.min(kappa_errors))
                mean_errors = (np.mean(loc_errors), np.mean(log_sig_errors),
                               np.mean(kappa_errors))
                logger.debug(
                    "[max errors]   (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)"
                    % max_errors)
                logger.debug(
                    "[min errors]   (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)"
                    % min_errors)
                logger.debug(
                    "[mean errors]  (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)"
                    % mean_errors)
                logger.debug("[step time = %.3f;  N = %d;  step = %d]\n" %
                             (time.time() - t0, self.N, step))

        assert_equal(0.0, max_errors[0], prec=prec)
        assert_equal(0.0, max_errors[1], prec=prec)
        assert_equal(0.0, max_errors[2], prec=prec)
Example #23
def main(**kwargs):

    args = argparse.Namespace(**kwargs)

    if 'save' in args:
        if os.path.exists(args.save):
            raise RuntimeError('Output file "{}" already exists.'.format(
                args.save))

    if args.seed is not None:
        pyro.set_rng_seed(args.seed)

    X, true_counts = load_data()
    X_size = X.size(0)
    if args.cuda:
        X = X.cuda()

    # Build a function to compute z_pres prior probabilities.
    if args.z_pres_prior_raw:

        def base_z_pres_prior_p(t):
            return args.z_pres_prior
    else:
        base_z_pres_prior_p = make_prior(args.z_pres_prior)

    # Wrap with logic to apply any annealing.
    def z_pres_prior_p(opt_step, time_step):
        p = base_z_pres_prior_p(time_step)
        if args.anneal_prior == 'none':
            return p
        else:
            decay = dict(lin=lin_decay, exp=exp_decay)[args.anneal_prior]
            return decay(p, args.anneal_prior_to, args.anneal_prior_begin,
                         args.anneal_prior_duration, opt_step)

    model_arg_keys = [
        'window_size', 'rnn_hidden_size', 'decoder_output_bias',
        'decoder_output_use_sigmoid', 'baseline_scalar', 'encoder_net',
        'decoder_net', 'predict_net', 'embed_net', 'bl_predict_net',
        'non_linearity', 'pos_prior_mean', 'pos_prior_sd', 'scale_prior_mean',
        'scale_prior_sd'
    ]
    model_args = {
        key: getattr(args, key)
        for key in model_arg_keys if key in args
    }
    air = AIR(num_steps=args.model_steps,
              x_size=50,
              use_masking=not args.no_masking,
              use_baselines=not args.no_baselines,
              z_what_size=args.encoder_latent_size,
              use_cuda=args.cuda,
              **model_args)

    if args.verbose:
        print(air)
        print(args)

    if 'load' in args:
        print('Loading parameters...')
        air.load_state_dict(torch.load(args.load))

    # Viz sample from prior.
    if args.viz:
        vis = visdom.Visdom(env=args.visdom_env)
        z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
        vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z))))

    def isBaselineParam(module_name, param_name):
        return 'bl_' in module_name or 'bl_' in param_name

    def per_param_optim_args(module_name, param_name):
        lr = args.baseline_learning_rate if isBaselineParam(
            module_name, param_name) else args.learning_rate
        return {'lr': lr}

    adam = optim.Adam(per_param_optim_args)
    elbo = JitTraceGraph_ELBO() if args.jit else TraceGraph_ELBO()
    svi = SVI(air.model, air.guide, adam, loss=elbo)

    # Do inference.
    t0 = time.time()
    examples_to_viz = X[5:10]

    for i in range(1, args.num_steps + 1):

        loss = svi.step(X,
                        batch_size=args.batch_size,
                        z_pres_prior_p=partial(z_pres_prior_p, i))

        if args.progress_every > 0 and i % args.progress_every == 0:
            print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
                i, (i * args.batch_size) / X_size, (time.time() - t0) / 3600,
                loss / X_size))

        if args.viz and i % args.viz_every == 0:
            trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
            z, recons = poutine.replay(air.prior,
                                       trace=trace)(examples_to_viz.size(0))
            z_wheres = tensor_to_objs(latents_to_tensor(z))

            # Show data with inferred object positions.
            vis.images(draw_many(examples_to_viz, z_wheres))
            # Show reconstructions of data.
            vis.images(draw_many(recons, z_wheres))

        if args.eval_every > 0 and i % args.eval_every == 0:
            # Measure accuracy on subset of training data.
            acc, counts, error_z, error_ix = count_accuracy(
                X, true_counts, air, 1000)
            print('i={}, accuracy={}, counts={}'.format(
                i, acc,
                counts.numpy().tolist()))
            if args.viz and error_ix.size(0) > 0:
                vis.images(draw_many(X[error_ix[0:5]],
                                     tensor_to_objs(error_z[0:5])),
                           opts=dict(caption='errors ({})'.format(i)))

        if 'save' in args and i % args.save_every == 0:
            print('Saving parameters...')
            torch.save(air.state_dict(), args.save)
Example #24
def embed(data, embed_args_dict, sample_args_dict):
    r"""Fits a combined matrix factorization and linear regression model on
    the observed network data.

    Args:
        data: This is expected to be the data object returned from
            bitcoin.BitcoinOTC(), and this script probably won't work
            otherwise.

        embed_args_dict (dict): A dictionary which is expected to contain
            the following keys:

            embed_dim (int): An integer specifying the embedding dimension.
            omega_model_scale (float): A positive float which specifies the
                prior variance on the embedding vectors.
            obs_scale (float): A positive float which specifies the variance
                for the observed logit scaled ratings.
            krein (bool): Specifies whether to use a Krein style inner
                product (a difference of two inner products; krein = True)
                or a regular inner product (krein = False) between
                embedding vectors for the matrix factorization part
                of the model. This implicitly assumes that the embedding
                dimension is even; if not, then an error will likely pop up
                somewhere.
            learning_rate (float): Learning rate for the ADAM optimizer.
            num_iters (int): Number of iterations to perform SVI.
            logging (bool): If True, then ELBO updates and the time elapsed
                are output every 500 iterations, unless num_iters <= 1000
                in which case it is set to 100.

        sample_args_dict (dict): A dictionary that handles the subsampling
            procedure used on the data object. Check the docstring for the
            data object for the relevant details.

    Returns:
        vp_dict (dict): The fitted variational parameters of the guide
            ('mu_loc', 'beta_loc', 'omega_loc', 'mu_scale', 'beta_scale',
            'omega_scale'). If logging is enabled, the list of recorded
            ELBO values is returned as well.
    """
    # Extract keys from embed_args_dict
    embed_dim = embed_args_dict['embed_dim']
    omega_model_scale = embed_args_dict['omega_model_scale']
    obs_scale = embed_args_dict['obs_scale']
    learning_rate = embed_args_dict['learning_rate']
    num_iters = embed_args_dict['num_iters']
    logging_ind = embed_args_dict['logging']

    # Set logging printing rate
    if (num_iters > 1000):
        log_update = 500
    else:
        log_update = 100

    # Make sure logging actually works
    logging.basicConfig(format='%(message)s', level=logging.INFO)

    # Define guide object for embedding model
    def guide(data, node_ind, edge_ind, edge_list):
        r"""Defines a variational family to use to fit an approximate posterior
        distribution for the probability model defined in model."""
        # Deleting arguments not used in the guide for linting purposes
        del edge_ind, edge_list

        # Parameters governing the priors on the embedding vectors
        # omega_loc should have shape [embed_dim, data.num_nodes]
        omega_loc = pyro.param(
            'omega_loc', lambda: torch.randn(embed_dim, data.num_nodes) / np.
            sqrt(embed_dim))
        # omega_scale should be a single positive tensor
        omega_scale = pyro.param('omega_scale',
                                 torch.tensor(1.0),
                                 constraint=constraints.positive)

        # Parameters governing the prior for the linear regression
        # beta_loc should be of shape [embed_dim]
        beta_loc = pyro.param('beta_loc', 0.5 * torch.randn(embed_dim))
        # beta_scale should be a single positive tensor
        beta_scale = pyro.param('beta_scale',
                                torch.tensor(1.0),
                                constraint=constraints.positive)
        # mu_loc should be a single tensor
        mu_loc = pyro.param('mu_loc', torch.tensor([0.0]))
        # mu_scale should be a single positive tensor
        mu_scale = pyro.param('mu_scale',
                              torch.tensor(1.0),
                              constraint=constraints.positive)

        # Sample the coefficient vector and intercept for linear regression
        beta = pyro.sample(
            'beta',
            dist.Normal(loc=beta_loc,
                        scale=beta_scale * torch.ones(embed_dim)).to_event(1))
        mu = pyro.sample('mu', dist.Normal(mu_loc, mu_scale).to_event(1))

        # Handle the subsampling of the embedding vectors
        with poutine.scale(scale=data.num_nodes / len(node_ind)):
            omega = pyro.sample(
                'omega',
                dist.Normal(loc=omega_loc[:, node_ind],
                            scale=omega_scale).to_event(2))

        return beta, mu, omega

    # Defines the model to use for SVI when using the usual inner product
    def model_ip(data, node_ind, edge_ind, edge_list):
        r"""Defines a probabilistic model for the observed network data."""
        # Define priors on the regression coefficients
        mu = pyro.sample(
            'mu',
            dist.Normal(torch.tensor([0.0]), torch.tensor([2.0])).to_event(1))

        beta = pyro.sample(
            'beta',
            dist.Normal(loc=torch.zeros(embed_dim),
                        scale=torch.tensor(2.0)).to_event(1))

        # Define prior on the embedding vectors, with subsampling
        with poutine.scale(scale=data.num_nodes / len(node_ind)):
            omega = pyro.sample(
                'omega',
                dist.Normal(loc=torch.zeros(embed_dim, len(node_ind)),
                            scale=omega_model_scale).to_event(2))

        # Before proceeding further, define a list t which acts as the
        # inverse function of node_ind - i.e it takes a number in node_ind
        # to its index location
        t = torch.zeros(node_ind.max() + 1, dtype=torch.long)
        t[node_ind] = torch.arange(len(node_ind))

        # Create mask corresponding to entries of ind which lie within the
        # training set (i.e. data.nodes_train)
        gt_data = data.gt[node_ind]
        obs_mask = np.isin(node_ind, data.nodes_train).tolist()
        gt_data[gt_data != gt_data] = 0.0
        obs_mask = torch.tensor(obs_mask, dtype=torch.bool)

        # Compute logits, compute relevant parts of sample
        if sum(obs_mask) != 0:
            logit_prob = mu + torch.mv(omega.t(), beta)
            with poutine.scale(scale=len(data.nodes_train) / sum(obs_mask)):
                pyro.sample(
                    'trust',
                    dist.Bernoulli(logits=logit_prob[obs_mask]).independent(1),
                    obs=gt_data[obs_mask])

        # Begin extracting the relevant components of the gram matrix
        # formed by omega. Note that to extract the relevant indices,
        # we need to account for the change in indexing induced by
        # subsampling omega
        gram = torch.mm(omega.t(), omega)
        gram_sample = gram[t[edge_list[0, :]], t[edge_list[1, :]]]

        # Finally draw terms corresponding to the edges
        with poutine.scale(scale=data.num_edges / len(edge_ind)):
            pyro.sample('a',
                        dist.Normal(loc=gram_sample,
                                    scale=obs_scale).to_event(1),
                        obs=data.edge_weight_logit[edge_ind])

    # Defines the model to use for SVI when using the Krein inner product
    def model_krein(data, node_ind, edge_ind, edge_list):
        r"""Defines a probabilistic model for the observed network data."""
        # Define priors on the regression coefficients
        mu = pyro.sample(
            'mu',
            dist.Normal(torch.tensor([0.0]), torch.tensor([2.0])).to_event(1))

        beta = pyro.sample(
            'beta',
            dist.Normal(loc=torch.zeros(embed_dim),
                        scale=torch.tensor(2.0)).to_event(1))

        # Define prior on the embedding vectors, with subsampling
        with poutine.scale(scale=data.num_nodes / len(node_ind)):
            omega = pyro.sample(
                'omega',
                dist.Normal(loc=torch.zeros(embed_dim, len(node_ind)),
                            scale=omega_model_scale).to_event(2))

        # Before proceeding further, define a list t which acts as the
        # inverse function of node_ind - i.e it takes a number in node_ind
        # to its index location
        t = torch.zeros(node_ind.max() + 1, dtype=torch.long)
        t[node_ind] = torch.arange(len(node_ind))

        # Create mask corresponding to entries of ind which lie within the
        # training set (i.e. data.nodes_train)
        gt_data = data.gt[node_ind]
        obs_mask = np.isin(node_ind, data.nodes_train).tolist()
        gt_data[gt_data != gt_data] = 0.0
        obs_mask = torch.tensor(obs_mask, dtype=torch.bool)

        # Compute logits, compute relevant parts of sample
        if sum(obs_mask) != 0:
            logit_prob = mu + torch.mv(omega.t(), beta)
            with poutine.scale(scale=len(data.nodes_train) / sum(obs_mask)):
                pyro.sample(
                    'trust',
                    dist.Bernoulli(logits=logit_prob[obs_mask]).independent(1),
                    obs=gt_data[obs_mask])

        # Begin extracting the relevant components of the gram matrix
        # formed by omega. Note that to extract the relevant indices,
        # we need to account for the change in indexing induced by
        # subsampling omega
        gram_pos = torch.mm(omega[:int(embed_dim / 2), :].t(),
                            omega[:int(embed_dim / 2), :])
        gram_neg = torch.mm(omega[int(embed_dim / 2):, :].t(),
                            omega[int(embed_dim / 2):, :])
        gram = gram_pos - gram_neg
        gram_sample = gram[t[edge_list[0, :]], t[edge_list[1, :]]]

        # Finally draw terms corresponding to the edges
        with poutine.scale(scale=data.num_edges / len(edge_ind)):
            pyro.sample('a',
                        dist.Normal(loc=gram_sample,
                                    scale=obs_scale).to_event(1),
                        obs=data.edge_weight_logit[edge_ind])

    # Define the SVI object depending on whether we're using a positive
    # definite bilinear form on the embedding vectors or the Krein inner
    # product
    if embed_args_dict['krein']:
        svi = SVI(model_krein,
                  guide,
                  optim.Adam({"lr": learning_rate}),
                  loss=TraceGraph_ELBO())
    else:
        svi = SVI(model_ip,
                  guide,
                  optim.Adam({"lr": learning_rate}),
                  loss=TraceGraph_ELBO())

    # Begin optimization, tracking wall-clock time and the ELBO if logging
    # is enabled
    if logging_ind:
        time_store = []
        t0 = time.time()
        elbo = []

    pyro.clear_param_store()
    for i in range(num_iters):
        # The random-walk subsampling occasionally raises an IndexError;
        # retry a bounded number of times before moving on
        for _ in range(20):
            try:
                subsample_dict = data.subsample(**sample_args_dict)
                break
            except IndexError:
                continue

        elbo_val = svi.step(data, **subsample_dict)
        if logging_ind and i > 0 and i % log_update == 0:
            elbo.append(elbo_val)
            t1 = time.time()
            time_store.append(t1 - t0)
            logging.info('Elbo loss: {}'.format(elbo_val))
            logging.info('Expected completion time: {}s'.format(
                int(np.average(time_store) * (num_iters - i) / log_update)))
            t0 = time.time()

    # Extract the variational parameters and return them
    vp_dict = {}
    vp_dict['mu_loc'] = pyro.param('mu_loc')
    vp_dict['beta_loc'] = pyro.param('beta_loc')
    vp_dict['omega_loc'] = pyro.param('omega_loc')
    vp_dict['mu_scale'] = pyro.param('mu_scale')
    vp_dict['beta_scale'] = pyro.param('beta_scale')
    vp_dict['omega_scale'] = pyro.param('omega_scale')

    if logging_ind:
        return vp_dict, elbo
    else:
        return vp_dict
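
The guide passed to SVI above is not shown in this excerpt. As a rough,
non-authoritative sketch, a mean-field guide consistent with the variational
parameters read back at the end (mu_loc/mu_scale, beta_loc/beta_scale,
omega_loc/omega_scale) could look like the following, assuming it is defined
alongside model_ip/model_krein so that embed_dim is in scope, and that
constraints is imported from pyro.distributions as in Example #25:

def guide(data, node_ind, edge_ind, edge_list):
    r"""Mean-field sketch matching the model's sample sites (assumed)."""
    mu_loc = pyro.param('mu_loc', torch.zeros(1))
    mu_scale = pyro.param('mu_scale', torch.ones(1),
                          constraint=constraints.positive)
    pyro.sample('mu', dist.Normal(mu_loc, mu_scale).to_event(1))

    beta_loc = pyro.param('beta_loc', torch.zeros(embed_dim))
    beta_scale = pyro.param('beta_scale', torch.ones(embed_dim),
                            constraint=constraints.positive)
    pyro.sample('beta', dist.Normal(beta_loc, beta_scale).to_event(1))

    # One variational factor per node; only the subsampled columns are used,
    # with the same poutine.scale correction as in the model
    omega_loc = pyro.param('omega_loc',
                           torch.zeros(embed_dim, data.num_nodes))
    omega_scale = pyro.param('omega_scale',
                             torch.ones(embed_dim, data.num_nodes),
                             constraint=constraints.positive)
    with poutine.scale(scale=data.num_nodes / len(node_ind)):
        pyro.sample('omega',
                    dist.Normal(omega_loc[:, node_ind],
                                omega_scale[:, node_ind]).to_event(2))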
Example #25
0
def do_inference(use_baseline=True):
    def guide(data):

        a = pyro.param("alpha_q",
                       torch.tensor(15.0),
                       constraint=constraints.positive)
        b = pyro.param("beta_q",
                       torch.tensor(15.0),
                       constraint=constraints.positive)

        baseline = dict(baseline={
            'use_decaying_avg_baseline': use_baseline,
            'baseline_beta': 0.9
        })
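        # Passing this dict via infer= tells Pyro to use a decaying-average
        # baseline for this non-reparameterized site: a running average of
        # recent ELBO terms (decay rate baseline_beta) is subtracted from the
        # score-function gradient estimator to reduce its variance.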
        pyro.sample("z_fairness", NonreparameterizedBeta(a, b), infer=baseline)

    # Inference
    pyro.clear_param_store()

    adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
    optimizer = Adam(adam_params)

    alpha0, beta0 = 10.0, 10.0
    svi = SVI(model(alpha0, beta0), guide, optimizer, loss=TraceGraph_ELBO())

    NPos = 10
    NNeg = 5
    data = torch.tensor(NPos * [1.0] + NNeg * [0.0])

    def param_abs_error(name, target):
        return torch.sum(torch.abs(target - pyro.param(name))).item()

    # True parameters

    true_alpha = data.sum() + alpha0
    true_beta = len(data) - data.sum() + beta0

    # Run

    n_steps = 10000
    for step in tqdm(range(n_steps)):
        svi.step(data)

        # compute the distance to the parameters of the true posterior
        alpha_error = param_abs_error("alpha_q", true_alpha)
        beta_error = param_abs_error("beta_q", true_beta)

        # stop inference early if we're close to the true posterior
        if alpha_error < 0.8 and beta_error < 0.8:
            break

    alpha_q = pyro.param("alpha_q").item()
    beta_q = pyro.param("beta_q").item()

    inferred_mean = alpha_q / (alpha_q + beta_q)
    factor = beta_q / (alpha_q * (1.0 + alpha_q + beta_q))
    inferred_std = inferred_mean * math.sqrt(factor)

    print("Parameters after {} steps: {} {} {} {}".format(
        step, true_alpha, true_beta, alpha_q, beta_q))
    print("\nbased on the data and our prior belief, the fairness " +
          "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std))
    print("True posterior based on counting real and pseudoflips: {}".format(
        (NPos + 10.0) / (NPos + NNeg + 20.0)))
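
The model factory model(alpha0, beta0) passed to SVI above is not shown in
this snippet. A minimal sketch consistent with the guide's site name
"z_fairness" and the Beta(alpha0, beta0) prior used for true_alpha and
true_beta, assuming pyro.distributions is imported as dist, might be:

def model(alpha0, beta0):
    def _model(data):
        # Beta prior over the coin's fairness; the site name matches the guide
        f = pyro.sample("z_fairness",
                        dist.Beta(torch.tensor(alpha0), torch.tensor(beta0)))
        # Each flip is an independent Bernoulli draw given the fairness
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Bernoulli(f), obs=data)
    return _model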
Example #26
0
    def _test_plate_in_elbo(self,
                            n_superfluous_top,
                            n_superfluous_bottom,
                            n_steps,
                            lr=0.0012):
        pyro.clear_param_store()
        self.data_tensor = torch.zeros(9, 2)
        for _out in range(self.n_outer):
            for _in in range(self.n_inner):
                self.data_tensor[3 * _out + _in, :] = self.data[_out][_in]
        self.data_as_list = [
            self.data_tensor[0:4, :],
            self.data_tensor[4:7, :],
            self.data_tensor[7:9, :],
        ]

        def model():
            loc_latent = pyro.sample(
                "loc_latent",
                fakes.NonreparameterizedNormal(self.loc0,
                                               torch.pow(self.lam0,
                                                         -0.5)).to_event(1),
            )

            for i in pyro.plate("outer", 3):
                x_i = self.data_as_list[i]
                with pyro.plate("inner_%d" % i, x_i.size(0)):
                    for k in range(n_superfluous_top):
                        z_i_k = pyro.sample(
                            "z_%d_%d" % (i, k),
                            fakes.NonreparameterizedNormal(0, 1).expand_by(
                                [4 - i]),
                        )
                        assert z_i_k.shape == (4 - i, )
                    obs_i = pyro.sample(
                        "obs_%d" % i,
                        dist.Normal(loc_latent, torch.pow(self.lam,
                                                          -0.5)).to_event(1),
                        obs=x_i,
                    )
                    assert obs_i.shape == (4 - i, 2)
                    for k in range(n_superfluous_top,
                                   n_superfluous_top + n_superfluous_bottom):
                        z_i_k = pyro.sample(
                            "z_%d_%d" % (i, k),
                            fakes.NonreparameterizedNormal(0, 1).expand_by(
                                [4 - i]),
                        )
                        assert z_i_k.shape == (4 - i, )

        pt_loc_baseline = torch.nn.Linear(1, 1)
        pt_superfluous_baselines = []
        for k in range(n_superfluous_top + n_superfluous_bottom):
            pt_superfluous_baselines.extend([
                torch.nn.Linear(2, 4),
                torch.nn.Linear(2, 3),
                torch.nn.Linear(2, 2)
            ])

        def guide():
            loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.094)
            log_sig_q = pyro.param("log_sig_q",
                                   self.analytic_log_sig_n.expand(2) - 0.07)
            sig_q = torch.exp(log_sig_q)
            trivial_baseline = pyro.module("loc_baseline", pt_loc_baseline)
            baseline_value = trivial_baseline(torch.ones(1)).squeeze()
            loc_latent = pyro.sample(
                "loc_latent",
                fakes.NonreparameterizedNormal(loc_q, sig_q).to_event(1),
                infer=dict(baseline=dict(baseline_value=baseline_value)),
            )

            for i in pyro.plate("outer", 3):
                with pyro.plate("inner_%d" % i, 4 - i):
                    for k in range(n_superfluous_top + n_superfluous_bottom):
                        z_baseline = pyro.module(
                            "z_baseline_%d_%d" % (i, k),
                            pt_superfluous_baselines[3 * k + i],
                        )
                        baseline_value = z_baseline(loc_latent.detach())
                        mean_i = pyro.param("mean_%d_%d" % (i, k),
                                            0.5 * torch.ones(4 - i))
                        z_i_k = pyro.sample(
                            "z_%d_%d" % (i, k),
                            fakes.NonreparameterizedNormal(mean_i, 1),
                            infer=dict(baseline=dict(
                                baseline_value=baseline_value)),
                        )
                        assert z_i_k.shape == (4 - i, )

        def per_param_callable(param_name):
            if "baseline" in param_name:
                return {"lr": 0.010, "betas": (0.95, 0.999)}
            else:
                return {"lr": lr, "betas": (0.95, 0.999)}

        adam = optim.Adam(per_param_callable)
        svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

        for step in range(n_steps):
            svi.step()

            loc_error = param_abs_error("loc_q", self.analytic_loc_n)
            log_sig_error = param_abs_error("log_sig_q",
                                            self.analytic_log_sig_n)

            if n_superfluous_top > 0 or n_superfluous_bottom > 0:
                superfluous_errors = []
                for k in range(n_superfluous_top + n_superfluous_bottom):
                    mean_0_error = torch.sum(
                        torch.pow(pyro.param("mean_0_%d" % k), 2.0))
                    mean_1_error = torch.sum(
                        torch.pow(pyro.param("mean_1_%d" % k), 2.0))
                    mean_2_error = torch.sum(
                        torch.pow(pyro.param("mean_2_%d" % k), 2.0))
                    superfluous_error = torch.max(
                        torch.max(mean_0_error, mean_1_error), mean_2_error)
                    superfluous_errors.append(
                        superfluous_error.detach().cpu().numpy())

            if step % 500 == 0:
                logger.debug("loc error, log(scale) error:  %.4f, %.4f" %
                             (loc_error, log_sig_error))
                if n_superfluous_top > 0 or n_superfluous_bottom > 0:
                    logger.debug("superfluous error: %.4f" %
                                 np.max(superfluous_errors))

        assert_equal(0.0, loc_error, prec=0.04)
        assert_equal(0.0, log_sig_error, prec=0.05)
        if n_superfluous_top > 0 or n_superfluous_bottom > 0:
            assert_equal(0.0, np.max(superfluous_errors), prec=0.04)
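
The helper param_abs_error used in the training loop above is not defined in
this excerpt; a minimal version, matching the one defined inside do_inference
in Example #25, would be:

def param_abs_error(name, target):
    # Sum of absolute differences between a Pyro param and a target tensor
    return torch.sum(torch.abs(target - pyro.param(name))).item()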
Example #27
0
def test_elbo_mapdata(map_type, batch_size, n_steps, lr):
    # normal-normal: known covariance
    lam0 = torch.tensor([0.1, 0.1])  # precision of prior
    loc0 = torch.tensor([0.0, 0.5])  # prior mean
    # known precision of observation noise
    lam = torch.tensor([6.0, 4.0])
    data = []
    sum_data = torch.zeros(2)

    def add_data_point(x, y):
        data.append(torch.tensor([x, y]))
        sum_data.data.add_(data[-1].data)

    add_data_point(0.1, 0.21)
    add_data_point(0.16, 0.11)
    add_data_point(0.06, 0.31)
    add_data_point(-0.01, 0.07)
    add_data_point(0.23, 0.25)
    add_data_point(0.19, 0.18)
    add_data_point(0.09, 0.41)
    add_data_point(-0.04, 0.17)

    data = torch.stack(data)
    n_data = torch.tensor([float(len(data))])
    analytic_lam_n = lam0 + n_data.expand_as(lam) * lam
    analytic_log_sig_n = -0.5 * torch.log(analytic_lam_n)
    analytic_loc_n = sum_data * (lam / analytic_lam_n) + loc0 * (
        lam0 / analytic_lam_n)
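    # These are the conjugate normal-normal updates with known precision:
    # lam_n = lam0 + n * lam, loc_n = (lam * sum(x) + lam0 * loc0) / lam_n
    # (applied per dimension), and log_sig_n = -0.5 * log(lam_n) since the
    # posterior standard deviation is lam_n ** -0.5.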

    logger.debug("DOING ELBO TEST [bs = {}, map_type = {}]".format(
        batch_size, map_type))
    pyro.clear_param_store()

    def model():
        loc_latent = pyro.sample(
            "loc_latent",
            dist.Normal(loc0, torch.pow(lam0, -0.5)).to_event(1))
        if map_type == "iplate":
            for i in pyro.plate("aaa", len(data), batch_size):
                pyro.sample(
                    "obs_%d" % i,
                    dist.Normal(loc_latent, torch.pow(lam, -0.5)).to_event(1),
                    obs=data[i],
                )
        elif map_type == "plate":
            with pyro.plate("aaa", len(data), batch_size) as ind:
                pyro.sample(
                    "obs",
                    dist.Normal(loc_latent, torch.pow(lam, -0.5)).to_event(1),
                    obs=data[ind],
                )
        else:
            for i, x in enumerate(data):
                pyro.sample(
                    "obs_%d" % i,
                    dist.Normal(loc_latent, torch.pow(lam, -0.5)).to_event(1),
                    obs=x,
                )
        return loc_latent

    def guide():
        loc_q = pyro.param(
            "loc_q",
            analytic_loc_n.detach().clone() + torch.tensor([-0.18, 0.23]))
        log_sig_q = pyro.param(
            "log_sig_q",
            analytic_log_sig_n.detach().clone() - torch.tensor([-0.18, 0.23]),
        )
        sig_q = torch.exp(log_sig_q)
        pyro.sample("loc_latent", dist.Normal(loc_q, sig_q).to_event(1))
        if map_type == "iplate" or map_type is None:
            for i in pyro.plate("aaa", len(data), batch_size):
                pass
        elif map_type == "plate":
            # dummy plate to do subsampling for observe
            with pyro.plate("aaa", len(data), batch_size):
                pass
        else:
            pass

    adam = optim.Adam({"lr": lr})
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

    for k in range(n_steps):
        svi.step()

        loc_error = torch.sum(
            torch.pow(analytic_loc_n - pyro.param("loc_q"), 2.0))
        log_sig_error = torch.sum(
            torch.pow(analytic_log_sig_n - pyro.param("log_sig_q"), 2.0))

        if k % 500 == 0:
            logger.debug("errors - {}, {}".format(loc_error, log_sig_error))

    assert_equal(loc_error.item(), 0, prec=0.05)
    assert_equal(log_sig_error.item(), 0, prec=0.06)
print('15')

def isBaselineParam(module_name, param_name):
    return 'bl_' in module_name or 'bl_' in param_name

print('16')

def per_param_optim_args(module_name, param_name):
    lr = (args.baseline_learning_rate
          if isBaselineParam(module_name, param_name)
          else args.learning_rate)
    return {'lr': lr}

print('17')

adam = optim.Adam(per_param_optim_args)
elbo = JitTraceGraph_ELBO() if args.jit else TraceGraph_ELBO()
svi = SVI(air.model, air.guide, adam, loss=elbo)

print('18')

t0 = time.time()
examples_to_viz = X[5:10]
print('19')
print('Epochs starting')

for i in range(1, args.num_steps + 1):
    print("Epoch: {}".format(i))
    loss = svi.step(X, batch_size=args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i))

    if args.progress_every > 0 and i % args.progress_every == 0:
        print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(