Example #1
    def step(self, *args, **kwargs):
        """
        :returns: estimate of the loss
        :rtype: float

        Take a gradient step on the loss function (and any auxiliary loss functions
        generated under the hood by `loss_and_grads`).
        Any args or kwargs are passed to the model and guide
        """
        # get loss and compute gradients
        loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)

        # get active params
        params = pyro.get_param_store().get_active_params()

        # actually perform gradient steps
        # torch.optim objects get instantiated for any params that haven't been seen yet
        self.optim(params)

        # zero gradients
        pyro.util.zero_grads(params)

        # mark parameters in the param store as inactive
        pyro.get_param_store().mark_params_inactive(params)

        return loss
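
This step() method is typically driven through a standard SVI training loop. Below is a minimal, hedged sketch of that loop assuming a recent Pyro release; the model and guide are illustrative, not taken from the source above.

# Hedged usage sketch of the step() loop (model/guide are illustrative):
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

def guide(data):
    loc_q = pyro.param("loc_q", torch.tensor(0.))
    pyro.sample("loc", dist.Normal(loc_q, 1.))

svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
data = torch.randn(10)
for _ in range(100):
    loss = svi.step(data)  # args/kwargs are forwarded to model and guide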
Example #2
def test_hmc_conjugate_gaussian(fixture,
                                num_samples,
                                warmup_steps,
                                hmc_params,
                                expected_means,
                                expected_precs,
                                mean_tol,
                                std_tol):
    pyro.get_param_store().clear()
    hmc_kernel = HMC(fixture.model, **hmc_params)
    mcmc_run = MCMC(hmc_kernel, num_samples, warmup_steps).run(fixture.data)
    for i in range(1, fixture.chain_len + 1):
        param_name = 'loc_' + str(i)
        marginal = EmpiricalMarginal(mcmc_run, sites=param_name)
        latent_loc = marginal.mean
        latent_std = marginal.variance.sqrt()
        expected_mean = torch.ones(fixture.dim) * expected_means[i - 1]
        expected_std = 1 / torch.sqrt(torch.ones(fixture.dim) * expected_precs[i - 1])

        # Actual vs expected posterior means for the latents
        logger.info('Posterior mean (actual) - {}'.format(param_name))
        logger.info(latent_loc)
        logger.info('Posterior mean (expected) - {}'.format(param_name))
        logger.info(expected_mean)
        assert_equal(rmse(latent_loc, expected_mean).item(), 0.0, prec=mean_tol)

        # Actual vs expected posterior standard deviations for the latents
        logger.info('Posterior std (actual) - {}'.format(param_name))
        logger.info(latent_std)
        logger.info('Posterior std (expected) - {}'.format(param_name))
        logger.info(expected_std)
        assert_equal(rmse(latent_std, expected_std).item(), 0.0, prec=std_tol)
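
The test above uses the older EmpiricalMarginal wrapper. For reference, here is a hedged sketch of the same posterior-mean check against the newer MCMC interface; the model and numbers are illustrative.

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 10.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

mcmc = MCMC(NUTS(model), num_samples=200, warmup_steps=100)
mcmc.run(torch.randn(5) + 3.)
samples = mcmc.get_samples()
print(samples["loc"].mean(), samples["loc"].std())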
Example #3
def test_module_nn(nn_module):
    pyro.clear_param_store()
    nn_module = nn_module()
    assert pyro.get_param_store()._params == {}
    pyro.module("module", nn_module)
    for name in pyro.get_param_store().get_all_param_names():
        assert pyro.params.user_param_name(name) in nn_module.state_dict().keys()
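
What this test checks is that pyro.module() mirrors every nn.Module parameter into the param store under a prefixed name. A hedged sketch:

import torch.nn as nn
import pyro

pyro.clear_param_store()
net = nn.Linear(3, 1)
pyro.module("net", net)  # registers net's parameters in the param store
for name, value in pyro.get_param_store().items():
    print(name, tuple(value.shape))  # e.g. 'net$$$weight', 'net$$$bias'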
Example #4
def test_bern_elbo_gradient(enum_discrete, trace_graph):
    pyro.clear_param_store()
    num_particles = 2000

    def model():
        p = Variable(torch.Tensor([0.25]))
        pyro.sample("z", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", Variable(torch.Tensor([0.5]), requires_grad=True))
        pyro.sample("z", dist.Bernoulli(p))

    print("Computing gradients using surrogate loss")
    Elbo = TraceGraph_ELBO if trace_graph else Trace_ELBO
    elbo = Elbo(enum_discrete=enum_discrete,
                num_particles=(1 if enum_discrete else num_particles))
    with xfail_if_not_implemented():
        elbo.loss_and_grads(model, guide)
    params = sorted(pyro.get_param_store().get_all_param_names())
    assert params, "no params found"
    actual_grads = {name: pyro.param(name).grad.clone() for name in params}

    print("Computing gradients using finite difference")
    elbo = Trace_ELBO(num_particles=num_particles)
    expected_grads = finite_difference(lambda: elbo.loss(model, guide))

    for name in params:
        print("{} {}{}{}".format(name, "-" * 30, actual_grads[name].data,
                                 expected_grads[name].data))
    assert_equal(actual_grads, expected_grads, prec=0.1)
Example #5
    def __call__(self, params, *args, **kwargs):
        """
        :param params: a list of parameters
        :type params: an iterable of strings

        Do an optimization step for each param in params. If a given param has never been seen before,
        initialize an optimizer for it.
        """

        for p in params:
            # if we have not seen this param before, we instantiate an optim object to deal with it
            if p not in self.optim_objs:
                # get our constructor arguments
                def_optim_dict = self._get_optim_args(p)
                # create a single optim object for that param
                self.optim_objs[p] = self.pt_optim_constructor([p], **def_optim_dict)

                # set state from _state_waiting_to_be_consumed if present
                param_name = pyro.get_param_store().param_name(p)
                if param_name in self._state_waiting_to_be_consumed:
                    state = self._state_waiting_to_be_consumed.pop(param_name)
                    self.optim_objs[p].load_state_dict(state)

            # actually perform the step for the optim object
            self.optim_objs[p].step(*args, **kwargs)
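
From user code, this per-parameter dispatch is driven by passing a callable instead of a dict when constructing the optimizer. A hedged sketch; note that newer Pyro releases pass (module_name, param_name) to the callable rather than the (module name, param name, tags) triple used above.

from pyro.optim import Adam

def per_param_args(module_name, param_name):
    # "my_special_parameter" is a hypothetical parameter name
    if param_name == "my_special_parameter":
        return {"lr": 0.001}
    return {"lr": 0.01}

optim = Adam(per_param_args)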
Example #6
def main(args):
    pyro.set_rng_seed(0)
    pyro.enable_validation()

    optim = Adam({"lr": 0.1})
    inference = SVI(model, guide, optim, loss=Trace_ELBO())

    # Data is an arbitrary json-like structure with tensors at leaves.
    one = torch.tensor(1.0)
    data = {
        "foo": one,
        "bar": [0 * one, 1 * one, 2 * one],
        "baz": {
            "noun": {
                "concrete": 4 * one,
                "abstract": 6 * one,
            },
            "verb": 2 * one,
        },
    }

    print('Step\tLoss')
    loss = 0.0
    for step in range(args.num_epochs):
        loss += inference.step(data)
        if step and step % 10 == 0:
            print('{}\t{:0.5g}'.format(step, loss))
            loss = 0.0

    print('Parameters:')
    for name in sorted(pyro.get_param_store().get_all_param_names()):
        print('{} = {}'.format(name, pyro.param(name).detach().cpu().numpy()))
Example #7
def _compute_elbo_non_reparam(guide_trace, non_reparam_nodes, downstream_costs):
    # construct all the reinforce-like terms.
    # we include only downstream costs to reduce variance
    # optionally include baselines to further reduce variance
    # XXX should the average baseline be in the param store as below?
    surrogate_elbo = 0.0
    baseline_loss = 0.0
    for node in non_reparam_nodes:
        guide_site = guide_trace.nodes[node]
        downstream_cost = downstream_costs[node]
        baseline = 0.0
        (nn_baseline, nn_baseline_input, use_decaying_avg_baseline, baseline_beta,
            baseline_value) = _get_baseline_options(guide_site)
        use_nn_baseline = nn_baseline is not None
        use_baseline_value = baseline_value is not None
        assert(not (use_nn_baseline and use_baseline_value)), \
            "cannot use baseline_value and nn_baseline simultaneously"
        if use_decaying_avg_baseline:
            dc_shape = downstream_cost.shape
            param_name = "__baseline_avg_downstream_cost_" + node
            with torch.no_grad():
                avg_downstream_cost_old = pyro.param(param_name,
                                                     guide_site['value'].new_zeros(dc_shape))
                avg_downstream_cost_new = (1 - baseline_beta) * downstream_cost + \
                    baseline_beta * avg_downstream_cost_old
            pyro.get_param_store().replace_param(param_name, avg_downstream_cost_new,
                                                 avg_downstream_cost_old)
            baseline += avg_downstream_cost_old
        if use_nn_baseline:
            # block nn_baseline_input gradients except in baseline loss
            baseline += nn_baseline(detach_iterable(nn_baseline_input))
        elif use_baseline_value:
            # it's on the user to make sure baseline_value tape only points to baseline params
            baseline += baseline_value
        if use_nn_baseline or use_baseline_value:
            # accumulate baseline loss
            baseline_loss += torch.pow(downstream_cost.detach() - baseline, 2.0).sum()

        score_function_term = guide_site["score_parts"].score_function
        if use_nn_baseline or use_decaying_avg_baseline or use_baseline_value:
            if downstream_cost.shape != baseline.shape:
                raise ValueError("Expected baseline at site {} to be {} instead got {}".format(
                    node, downstream_cost.shape, baseline.shape))
            downstream_cost = downstream_cost - baseline
        surrogate_elbo += (score_function_term * downstream_cost.detach()).sum()

    return surrogate_elbo, baseline_loss
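
These baseline options are declared at the sample site in the guide via the infer dict, which is where _get_baseline_options() reads them from. A hedged sketch with illustrative values:

import torch
import pyro
import pyro.distributions as dist

def guide():
    p = pyro.param("p", torch.tensor(0.5))
    pyro.sample("z", dist.Bernoulli(p),
                infer=dict(baseline={"use_decaying_avg_baseline": True,
                                     "baseline_beta": 0.90}))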
Example #8
def test_elbo_hmm_in_guide(enumerate1, num_steps):
    pyro.clear_param_store()
    data = torch.ones(num_steps)
    init_probs = torch.tensor([0.5, 0.5])

    def model(data):
        transition_probs = pyro.param("transition_probs",
                                      torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                      constraint=constraints.simplex)
        emission_probs = pyro.param("emission_probs",
                                    torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                    constraint=constraints.simplex)

        x = None
        for i, y in enumerate(data):
            probs = init_probs if x is None else transition_probs[x]
            x = pyro.sample("x_{}".format(i), dist.Categorical(probs))
            pyro.sample("y_{}".format(i), dist.Categorical(emission_probs[x]), obs=y)

    @config_enumerate(default=enumerate1)
    def guide(data):
        transition_probs = pyro.param("transition_probs",
                                      torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                      constraint=constraints.simplex)
        x = None
        for i, y in enumerate(data):
            probs = init_probs if x is None else transition_probs[x]
            x = pyro.sample("x_{}".format(i), dist.Categorical(probs))

    elbo = TraceEnum_ELBO(max_iarange_nesting=0)
    elbo.loss_and_grads(model, guide, data)

    # These golden values simply test agreement between parallel and sequential.
    expected_grads = {
        2: {
            "transition_probs": [[0.1029949, -0.1029949], [0.1029949, -0.1029949]],
            "emission_probs": [[0.75, -0.75], [0.25, -0.25]],
        },
        3: {
            "transition_probs": [[0.25748726, -0.25748726], [0.25748726, -0.25748726]],
            "emission_probs": [[1.125, -1.125], [0.375, -0.375]],
        },
        10: {
            "transition_probs": [[1.64832076, -1.64832076], [1.64832076, -1.64832076]],
            "emission_probs": [[3.75, -3.75], [1.25, -1.25]],
        },
        20: {
            "transition_probs": [[3.70781687, -3.70781687], [3.70781687, -3.70781687]],
            "emission_probs": [[7.5, -7.5], [2.5, -2.5]],
        },
    }

    for name, value in pyro.get_param_store().named_parameters():
        actual = value.grad
        expected = torch.tensor(expected_grads[num_steps][name])
        assert_equal(actual, expected, msg=''.join([
            '\nexpected {}.grad = {}'.format(name, expected.cpu().numpy()),
            '\n  actual {}.grad = {}'.format(name, actual.detach().cpu().numpy()),
        ]))
Example #9
    def get_state(self):
        """
        Get state associated with all the optimizers in the form of a dictionary with
        key-value pairs (parameter name, optim state dicts)
        """
        state_dict = {}
        for param in self.optim_objs:
            param_name = pyro.get_param_store().param_name(param)
            state_dict[param_name] = self.optim_objs[param].state_dict()
        return state_dict
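
This state dict is what PyroOptim's checkpointing helpers serialize. A hedged sketch, assuming the save()/load() wrappers present in recent Pyro releases:

# Hedged checkpointing sketch; save()/load() wrap get_state()/set_state()
# with torch.save()/torch.load().
from pyro.optim import Adam

optim = Adam({"lr": 0.01})
# ... after some svi.step(...) calls with this optimizer ...
optim.save("optimizer_state.pt")  # serializes optim.get_state()
optim.load("optimizer_state.pt")  # restores it, e.g. in a fresh process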
Example #10
    def test_save_and_load(self):
        lin = pyro.module("mymodule", self.linear_module)
        pyro.module("mymodule2", self.linear_module2)
        x = torch.randn(1, 3)
        myparam = pyro.param("myparam", torch.tensor(1.234 * torch.ones(1), requires_grad=True))

        cost = torch.sum(torch.pow(lin(x), 2.0)) * torch.pow(myparam, 4.0)
        cost.backward()
        params = list(self.linear_module.parameters()) + [myparam]
        optim = torch.optim.Adam(params, lr=.01)
        myparam_copy_stale = copy(pyro.param("myparam").detach().cpu().numpy())

        optim.step()

        myparam_copy = copy(pyro.param("myparam").detach().cpu().numpy())
        param_store_params = copy(pyro.get_param_store()._params)
        param_store_param_to_name = copy(pyro.get_param_store()._param_to_name)
        assert len(list(param_store_params.keys())) == 5
        assert len(list(param_store_param_to_name.values())) == 5

        pyro.get_param_store().save('paramstore.unittest.out')
        pyro.clear_param_store()
        assert len(list(pyro.get_param_store()._params)) == 0
        assert len(list(pyro.get_param_store()._param_to_name)) == 0
        pyro.get_param_store().load('paramstore.unittest.out')

        def modules_are_equal():
            weights_equal = np.sum(np.fabs(self.linear_module3.weight.detach().cpu().numpy() -
                                   self.linear_module.weight.detach().cpu().numpy())) == 0.0
            bias_equal = np.sum(np.fabs(self.linear_module3.bias.detach().cpu().numpy() -
                                self.linear_module.bias.detach().cpu().numpy())) == 0.0
            return (weights_equal and bias_equal)

        assert not modules_are_equal()
        pyro.module("mymodule", self.linear_module3, update_module_params=False)
        assert id(self.linear_module3.weight) != id(pyro.param('mymodule$$$weight'))
        assert not modules_are_equal()
        pyro.module("mymodule", self.linear_module3, update_module_params=True)
        assert id(self.linear_module3.weight) == id(pyro.param('mymodule$$$weight'))
        assert modules_are_equal()

        myparam = pyro.param("myparam")
        store = pyro.get_param_store()
        assert myparam_copy_stale != myparam.detach().cpu().numpy()
        assert myparam_copy == myparam.detach().cpu().numpy()
        assert sorted(param_store_params.keys()) == sorted(store._params.keys())
        assert sorted(param_store_param_to_name.values()) == sorted(store._param_to_name.values())
        assert sorted(store._params.keys()) == sorted(store._param_to_name.values())
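
Stripped of the module bookkeeping, the save/load round trip exercised by this test reduces to the following hedged sketch:

import torch
import pyro

pyro.clear_param_store()
pyro.param("myparam", torch.ones(2))
pyro.get_param_store().save("paramstore.out")
pyro.clear_param_store()
pyro.get_param_store().load("paramstore.out")
print(pyro.param("myparam"))  # restored value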
Example #11
    def _get_optim_args(self, param):
        # if we were passed a callable, we call it with param info
        # arguments are (module name, param name, tags) e.g. ('mymodule', 'bias', 'baseline')
        if callable(self.pt_optim_args):

            # get param name
            param_name = pyro.get_param_store().param_name(param)
            module_name = module_from_param_with_module_name(param_name)
            stripped_param_name = user_param_name(param_name)

            # get tags
            tags = pyro.get_param_store().get_param_tags(param_name)

            # invoke the user-provided callable
            opt_dict = self.pt_optim_args(module_name, stripped_param_name, tags)

            # must be dictionary
            assert isinstance(opt_dict, dict), "per-param optim arg must return defaults dictionary"
            return opt_dict
        else:
            return self.pt_optim_args
Example #12
def finite_difference(eval_loss, delta=0.1):
    """
    Computes a central finite-difference approximation of the gradient with respect to all parameters.
    """
    params = pyro.get_param_store().get_all_param_names()
    assert params, "no params found"
    grads = {name: Variable(torch.zeros(pyro.param(name).size())) for name in params}
    for name in sorted(params):
        value = pyro.param(name).data
        for index in itertools.product(*map(range, value.size())):
            center = value[index]
            value[index] = center + delta
            pos = eval_loss()
            value[index] = center - delta
            neg = eval_loss()
            value[index] = center
            grads[name][index] = (pos - neg) / (2 * delta)
    return grads
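
A quick way to sanity-check finite_difference() is against a loss with a known gradient; a hedged sketch using an illustrative quadratic loss:

import torch
import pyro

pyro.clear_param_store()
pyro.param("p", torch.tensor([2.0]))

def quadratic_loss():
    return float((pyro.param("p") ** 2).sum())  # d/dp p^2 = 2p

grads = finite_difference(quadratic_loss, delta=1e-3)
print(grads["p"])  # approximately tensor([4.]) at p = 2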
Example #13
def test_iarange(Elbo, reparameterized):
    pyro.clear_param_store()
    data = torch.tensor([-0.5, 2.0])
    num_particles = 20000
    precision = 0.06
    Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal

    @poutine.broadcast
    def model():
        particles_iarange = pyro.iarange("particles", num_particles, dim=-2)
        data_iarange = pyro.iarange("data", len(data), dim=-1)

        pyro.sample("nuisance_a", Normal(0, 1))
        with particles_iarange, data_iarange:
            z = pyro.sample("z", Normal(0, 1))
        pyro.sample("nuisance_b", Normal(2, 3))
        with data_iarange, particles_iarange:
            pyro.sample("x", Normal(z, 1), obs=data)
        pyro.sample("nuisance_c", Normal(4, 5))

    @poutine.broadcast
    def guide():
        loc = pyro.param("loc", torch.zeros(len(data)))
        scale = pyro.param("scale", torch.tensor([1.]))

        pyro.sample("nuisance_c", Normal(4, 5))
        with pyro.iarange("particles", num_particles, dim=-2):
            with pyro.iarange("data", len(data), dim=-1):
                pyro.sample("z", Normal(loc, scale))
        pyro.sample("nuisance_b", Normal(2, 3))
        pyro.sample("nuisance_a", Normal(0, 1))

    optim = Adam({"lr": 0.1})
    inference = SVI(model, guide, optim, loss=Elbo(strict_enumeration_warning=False))
    inference.loss_and_grads(model, guide)
    params = dict(pyro.get_param_store().named_parameters())
    actual_grads = {name: param.grad.detach().cpu().numpy() / num_particles
                    for name, param in params.items()}

    expected_grads = {'loc': np.array([0.5, -2.0]), 'scale': np.array([2.0])}
    for name in sorted(params):
        logger.info('expected {} = {}'.format(name, expected_grads[name]))
        logger.info('actual   {} = {}'.format(name, actual_grads[name]))
    assert_equal(actual_grads, expected_grads, prec=precision)
Example #14
def test_elbo_hmm_in_model(enumerate1, num_steps):
    pyro.clear_param_store()
    data = torch.ones(num_steps)
    init_probs = torch.tensor([0.5, 0.5])

    def model(data):
        transition_probs = pyro.param("transition_probs",
                                      torch.tensor([[0.9, 0.1], [0.1, 0.9]]),
                                      constraint=constraints.simplex)
        locs = pyro.param("obs_locs", torch.tensor([-1.0, 1.0]))
        scale = pyro.param("obs_scale", torch.tensor(1.0),
                           constraint=constraints.positive)

        x = None
        for i, y in enumerate(data):
            probs = init_probs if x is None else transition_probs[x]
            x = pyro.sample("x_{}".format(i), dist.Categorical(probs))
            pyro.sample("y_{}".format(i), dist.Normal(locs[x], scale), obs=y)

    @config_enumerate(default=enumerate1)
    def guide(data):
        mean_field_probs = pyro.param("mean_field_probs", torch.ones(num_steps, 2) / 2,
                                      constraint=constraints.simplex)
        for i in range(num_steps):
            pyro.sample("x_{}".format(i), dist.Categorical(mean_field_probs[i]))

    elbo = TraceEnum_ELBO(max_iarange_nesting=0)
    elbo.loss_and_grads(model, guide, data)

    expected_unconstrained_grads = {
        "transition_probs": torch.tensor([[0.2, -0.2], [-0.2, 0.2]]) * (num_steps - 1),
        "obs_locs": torch.tensor([-num_steps, 0]),
        "obs_scale": torch.tensor(-num_steps),
        "mean_field_probs": torch.tensor([[0.5, -0.5]] * num_steps),
    }

    for name, value in pyro.get_param_store().named_parameters():
        actual = value.grad
        expected = expected_unconstrained_grads[name]
        assert_equal(actual, expected, msg=''.join([
            '\nexpected {}.grad = {}'.format(name, expected.cpu().numpy()),
            '\n  actual {}.grad = {}'.format(name, actual.detach().cpu().numpy()),
        ]))
Example #15
def main(args):
    pyro.set_rng_seed(0)
    pyro.enable_validation()

    optim = Adam({"lr": 0.1})
    inference = SVI(model, guide, optim, loss=Trace_ELBO())
    data = torch.tensor([0.0, 1.0, 2.0, 20.0, 30.0, 40.0])
    k = 2

    print('Step\tLoss')
    loss = 0.0
    for step in range(args.num_epochs):
        if step and step % 10 == 0:
            print('{}\t{:0.5g}'.format(step, loss))
            loss = 0.0
        loss += inference.step(data, k)

    print('Parameters:')
    for name in sorted(pyro.get_param_store().get_all_param_names()):
        print('{} = {}'.format(name, pyro.param(name).detach().cpu().numpy()))
Example #16
def test_subsample_gradient(trace_graph, reparameterized):
    pyro.clear_param_store()
    data_size = 2
    subsample_size = 1
    num_particles = 1000
    precision = 0.333
    data = dist.normal(ng_zeros(data_size), ng_ones(data_size))

    def model(subsample_size):
        with pyro.iarange("data", len(data), subsample_size) as ind:
            x = data[ind]
            z = pyro.sample("z", dist.Normal(ng_zeros(len(x)), ng_ones(len(x)),
                                             reparameterized=reparameterized))
            pyro.observe("x", dist.Normal(z, ng_ones(len(x)), reparameterized=reparameterized), x)

    def guide(subsample_size):
        mu = pyro.param("mu", lambda: Variable(torch.zeros(len(data)), requires_grad=True))
        sigma = pyro.param("sigma", lambda: Variable(torch.ones(1), requires_grad=True))
        with pyro.iarange("data", len(data), subsample_size) as ind:
            mu = mu[ind]
            sigma = sigma.expand(subsample_size)
            pyro.sample("z", dist.Normal(mu, sigma, reparameterized=reparameterized))

    optim = Adam({"lr": 0.1})
    inference = SVI(model, guide, optim, loss="ELBO",
                    trace_graph=trace_graph, num_particles=num_particles)

    # Compute gradients without subsampling.
    inference.loss_and_grads(model, guide, subsample_size=data_size)
    params = dict(pyro.get_param_store().named_parameters())
    expected_grads = {name: param.grad.data.clone() for name, param in params.items()}
    zero_grads(params.values())

    # Compute gradients with subsampling.
    inference.loss_and_grads(model, guide, subsample_size=subsample_size)
    actual_grads = {name: param.grad.data.clone() for name, param in params.items()}

    for name in sorted(params):
        print('\nexpected {} = {}'.format(name, expected_grads[name].cpu().numpy()))
        print('actual   {} = {}'.format(name, actual_grads[name].cpu().numpy()))
    assert_equal(actual_grads, expected_grads, prec=precision)
Example #17
def test_subsample_gradient(Elbo, reparameterized, subsample):
    pyro.clear_param_store()
    data = torch.tensor([-0.5, 2.0])
    subsample_size = 1 if subsample else len(data)
    num_particles = 50000
    precision = 0.06
    Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal

    def model(subsample):
        with pyro.iarange("particles", num_particles):
            with pyro.iarange("data", len(data), subsample_size, subsample) as ind:
                x = data[ind].unsqueeze(-1).expand(-1, num_particles)
                z = pyro.sample("z", Normal(0, 1).expand_by(x.shape))
                pyro.sample("x", Normal(z, 1), obs=x)

    def guide(subsample):
        loc = pyro.param("loc", lambda: torch.zeros(len(data), requires_grad=True))
        scale = pyro.param("scale", lambda: torch.tensor([1.0], requires_grad=True))
        with pyro.iarange("particles", num_particles):
            with pyro.iarange("data", len(data), subsample_size, subsample) as ind:
                loc_ind = loc[ind].unsqueeze(-1).expand(-1, num_particles)
                pyro.sample("z", Normal(loc_ind, scale))

    optim = Adam({"lr": 0.1})
    elbo = Elbo(strict_enumeration_warning=False)
    inference = SVI(model, guide, optim, loss=elbo)
    if subsample_size == 1:
        inference.loss_and_grads(model, guide, subsample=torch.LongTensor([0]))
        inference.loss_and_grads(model, guide, subsample=torch.LongTensor([1]))
    else:
        inference.loss_and_grads(model, guide, subsample=torch.LongTensor([0, 1]))
    params = dict(pyro.get_param_store().named_parameters())
    normalizer = 2 * num_particles / subsample_size
    actual_grads = {name: param.grad.detach().cpu().numpy() / normalizer for name, param in params.items()}

    expected_grads = {'loc': np.array([0.5, -2.0]), 'scale': np.array([2.0])}
    for name in sorted(params):
        logger.info('expected {} = {}'.format(name, expected_grads[name]))
        logger.info('actual   {} = {}'.format(name, actual_grads[name]))
    assert_equal(actual_grads, expected_grads, prec=precision)
Example #18
def train(args, dataset):
    """
    Train a model and guide to fit a dataset.
    """
    counts = dataset["counts"]
    num_stations = len(dataset["stations"])
    logging.info(
        "Training on {} stations over {} hours, {} batches/epoch".format(
            num_stations, len(counts),
            int(math.ceil(len(counts) / args.batch_size))))
    time_features = make_time_features(args, 0, len(counts))
    control_features = (counts.max(1)[0] + counts.max(2)[0]).clamp(max=1)
    logging.info(
        "On average {:0.1f}/{} stations are open at any one time".format(
            control_features.sum(-1).mean(), num_stations))
    features = torch.cat([time_features, control_features], -1)
    feature_dim = features.size(-1)
    logging.info("feature_dim = {}".format(feature_dim))
    metadata = {"args": args, "losses": [], "control": control_features}
    torch.save(metadata, args.training_filename)

    def optim_config(module_name, param_name):
        config = {
            "lr": args.learning_rate,
            "betas": (0.8, 0.99),
            "weight_decay": 0.01**(1 / args.num_steps),
        }
        if param_name == "init_scale":
            config["lr"] *= 0.1  # init_dist sees much less data per minibatch
        return config

    training_counts = counts[:args.truncate] if args.truncate else counts
    data_size = len(training_counts)
    model = Model(args, features, training_counts).to(device=args.device)
    guide = Guide(args, features, training_counts).to(device=args.device)
    elbo = (TraceMeanField_ELBO if args.analytic_kl else Trace_ELBO)()
    optim = ClippedAdam(optim_config)
    svi = SVI(model, guide, optim, elbo)
    losses = []
    forecaster = None
    for step in range(args.num_steps):
        begin_time = torch.randint(max(1, data_size - args.batch_size),
                                   ()).item()
        end_time = min(data_size, begin_time + args.batch_size)
        feature_batch = features[begin_time:end_time].to(device=args.device)
        counts_batch = counts[begin_time:end_time].to(device=args.device)
        loss = svi.step(feature_batch, counts_batch) / counts_batch.numel()
        assert math.isfinite(loss), loss
        losses.append(loss)
        logging.debug("step {} loss = {:0.4g}".format(step, loss))

        if step % 20 == 0:
            # Save state every few steps.
            pyro.get_param_store().save(args.param_store_filename)
            metadata = {
                "args": args,
                "losses": losses,
                "control": control_features
            }
            torch.save(metadata, args.training_filename)
            forecaster = Forecaster(args, dataset, features, model, guide)
            torch.save(forecaster, args.forecaster_filename)

            if logging.Logger(None).isEnabledFor(logging.DEBUG):
                init_scale = pyro.param("init_scale").data
                trans_scale = pyro.param("trans_scale").data
                trans_matrix = pyro.param("trans_matrix").data
                eigs = trans_matrix.eig()[0].norm(dim=-1).sort(
                    descending=True).values
                logging.debug("guide.diag_part = {}".format(
                    guide.diag_part.data.squeeze()))
                logging.debug(
                    "init scale min/mean/max: {:0.3g} {:0.3g} {:0.3g}".format(
                        init_scale.min(), init_scale.mean(), init_scale.max()))
                logging.debug(
                    "trans scale min/mean/max: {:0.3g} {:0.3g} {:0.3g}".format(
                        trans_scale.min(), trans_scale.mean(),
                        trans_scale.max()))
                logging.debug("trans mat eig:\n{}".format(eigs))

    return forecaster
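
Restoring one of these checkpoints in a fresh process is the mirror image of the saves above; a hedged sketch where the literal filenames stand in for the args fields used in train():

import torch
import pyro

pyro.clear_param_store()
pyro.get_param_store().load("param_store.pt")  # args.param_store_filename
metadata = torch.load("metadata.pt")           # args.training_filename
forecaster = torch.load("forecaster.pt")       # args.forecaster_filename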
Example #19
    def transform(self,
                  X: np.ndarray,
                  num_samples: int = 1000,
                  random_state: int = None,
                  mean_estimate: bool = False) -> np.ndarray:
        """
        After model calibration, this function is used to get calibrated outputs of uncalibrated
        confidence estimates.

        Parameters
        ----------
        X : np.ndarray, shape=(n_samples, [n_classes]) or (n_samples, [n_box_features])
            NumPy array with confidence values for each prediction on classification with shapes
            1-D for binary classification, 2-D for multi class (softmax).
            On detection, this array must have 2 dimensions with number of additional box features in last dim.
        num_samples : int, optional, default: 1000
            Number of samples generated on MCMC sampling or Variational Inference.
        random_state : int, optional, default: None
            Fix the random seed for the random number generator.
        mean_estimate : bool, optional, default: False
            If True, directly return the mean on probabilistic methods like MCMC or VI instead of the full
            distribution. This parameter has no effect on MLE.

        Returns
        -------
        np.ndarray, shape=(n_samples, [n_classes]) on MLE or on MCMC/VI if 'mean_estimate' is True
        or shape=(n_parameters, n_samples, [n_classes]) on VI, MCMC if 'mean_estimate' is False
            On MLE without uncertainty, return NumPy array with calibrated confidence estimates.
            1-D for binary classification, 2-D for multi class (softmax).
            On VI or MCMC, return NumPy array with leading dimension as the number of sampled parameters from the
            logistic regression parameter distribution obtained by VI or MCMC.
        """
        def process_model(weights: dict) -> torch.Tensor:
            """ Fix model weights to the weight vector given as the parameter and return calibrated data. """

            # model will return pytorch tensor
            model = pyro.condition(self.model, data=weights)
            logit = model(data)

            # distinguish between detection, binary and multiclass classification
            if self.detection or self._is_binary_classification():
                calibrated = torch.sigmoid(logit)
            else:
                calibrated = torch.softmax(logit, dim=1)

            return calibrated

        # prepare input data
        X = super().transform(X)
        self.to(self._device)

        # convert input data and weights to torch (and possibly to CUDA)
        data = self.prepare(X).float().to(self._device)

        # if weights is 2-D matrix, we are in sampling mode
        # treat each row as a separate weights vector
        if self.method in ['variational', 'mcmc']:

            if mean_estimate:
                weights = {}

                # on MCMC sampling, use mean over all weights as mean weight estimate
                # TODO: we need to find another way since the parameters are conditionally dependent
                # TODO: revise!!! We often have log-normals instead of normal distributions,
                #  thus the mean will be different
                if self.mcmc_model is not None:
                    for name, site in self._sites.items():
                        # average over posterior samples (axis 0) so that an
                        # ndarray, not a scalar, is passed to torch.from_numpy
                        weights[name] = torch.from_numpy(
                            np.mean(self.mcmc_model[name], axis=0)).to(self._device)

                # on variational inference, use mean of the variational distribution for inference
                elif self.vi_model is not None:
                    for name, site in self._sites.items():
                        weights[name] = torch.from_numpy(
                            self.vi_model['params']['%s_mean' % name]).to(
                                self._device)

                else:
                    raise ValueError(
                        "Internal error: neither MCMC nor variational model given."
                    )

                # on MLE without uncertainty, only return the single model estimate
                calibrated = process_model(weights).cpu().numpy()
                calibrated = self.squeeze_generic(calibrated, axes_to_keep=0)
            else:

                parameter = []
                if self.mcmc_model is not None:

                    with manual_seed(seed=random_state):
                        idxs = torch.randint(0,
                                             self.mcmc_steps,
                                             size=(num_samples, ),
                                             device=self._device)
                        samples = {
                            k: v.index_select(0, idxs)
                            for k, v in self.mcmc_model.items()
                        }

                elif self.vi_model is not None:

                    # restore state of global parameter store of pyro and use this parameter store for the predictive
                    pyro.get_param_store().set_state(self.vi_model)
                    predictive = Predictive(self.model,
                                            guide=self.guide,
                                            num_samples=num_samples,
                                            return_sites=tuple(
                                                self._sites.keys()))

                    with manual_seed(seed=random_state):
                        samples = predictive(data)

                else:
                    raise ValueError(
                        "Internal error: neither MCMC nor variational model given."
                    )

                # remove unnecessary dims that possibly occur on MCMC or VI
                samples = {
                    k: torch.reshape(v, (num_samples, -1))
                    for k, v in samples.items()
                }

                # iterate over all parameter sets
                for i in range(num_samples):
                    param_dict = {}

                    # iterate over all sites and store into parameter dict
                    for site in self._sites.keys():
                        param_dict[site] = samples[site][i].detach().to(
                            self._device)

                    parameter.append(param_dict)

                calibrated = []

                # iterate over all parameter collections and compute calibration mapping
                for param_dict in parameter:
                    cal = process_model(param_dict)
                    calibrated.append(cal)

                # stack all calibrated estimates along axis 0 and calculate stddev as well as mean
                calibrated = torch.stack(calibrated, dim=0).cpu().numpy()
                calibrated = self.squeeze_generic(calibrated,
                                                  axes_to_keep=(0, 1))
        else:

            # extract all weight values of sites and store into single dict
            weights = {}
            for name, site in self._sites.items():
                weights[name] = torch.from_numpy(site['values']).to(
                    self._device)

            # on MLE without uncertainty, only return the single model estimate
            calibrated = process_model(weights).cpu().numpy()
            calibrated = self.squeeze_generic(calibrated, axes_to_keep=0)

        # delete torch data tensor
        del data

        # if device is cuda, empty GPU cache to free memory
        if self._device.type == 'cuda':
            with torch.cuda.device(self._device):
                torch.cuda.empty_cache()

        return calibrated
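
The set_state() call above relies on the param store's snapshot API; a minimal hedged sketch of the get_state()/set_state() round trip:

import torch
import pyro

pyro.param("w", torch.zeros(3))
state = pyro.get_param_store().get_state()  # picklable snapshot of all params
pyro.clear_param_store()
pyro.get_param_store().set_state(state)     # parameters are restored
assert "w" in pyro.get_param_store().keys()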
Example #20
    def save_model(self):
        # save parameters via the pyro param store rather than a pytorch state dict
        save_path = Path("data/saved_models/")
        save_path.mkdir(exist_ok=True, parents=True)
        pyro.get_param_store().save(
            save_path.joinpath(f"{self.config.id:02}_bnn_params.pr"))
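
A matching loader would mirror the same path convention; a hedged sketch (load_model is hypothetical and assumes the same Path and pyro imports as save_model):

    def load_model(self):
        # hypothetical counterpart to save_model (not in the original source)
        load_path = Path("data/saved_models/")
        pyro.get_param_store().load(
            load_path.joinpath(f"{self.config.id:02}_bnn_params.pr"))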
Example #21
trace = poutine.trace(pyromodel).get_trace(torch.tensor(x), torch.tensor(y))
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())

ys = []
amp = 1.
sig = 1.0
#xs = np.linspace(0, 5, 500, dtype='float32')
xs = torch.tensor(xtest_pca.astype('float32'))
for i in range(50):
    sampled_model = guide(None, None)
    ys += [sampled_model(xs).cpu().detach().numpy().flatten()]
ys = np.stack(ys).T

for name, value in pyro.get_param_store().items():
    print(name, pyro.param(name).shape)
"""
plt.figure()
plt.yscale('linear')
plt.title("Training Data")
plt.xlabel("hu (mu = <10> is Y for NN)")
plt.ylabel("Intensity (X for NN)")
plt.plot(hu,x[::10,:].T)
"""
plt.figure()
plt.yscale('linear')
plt.title("Fit to mu")
plt.xlabel("1st PCA component")
plt.ylabel("mu")
plt.legend()
def boosting_bbvi():
    n_iterations = 2

    initial_approximation = dummy_approximation
    components = [initial_approximation]
    weights = torch.tensor([1.])
    wrapped_approximation = partial(approximation,
                                    components=components,
                                    weights=weights)

    locs = [0]
    scales = [0]

    gradient_norms = defaultdict(list)
    for t in range(1, n_iterations + 1):
        # setup the inference algorithm
        wrapped_guide = partial(guide, index=t)
        # do gradient steps
        losses = []
        # Register hooks to monitor gradient norms.
        wrapped_guide(data)
        print(pyro.get_param_store().named_parameters())

        adam_params = {"lr": 0.002, "betas": (0.90, 0.999)}
        optimizer = Adam(adam_params)
        for name, value in pyro.get_param_store().named_parameters():
            if name not in gradient_norms:
                value.register_hook(
                    lambda g, name=name: gradient_norms[name].append(g.norm().item()))

        svi = SVI(model, wrapped_guide, optimizer, loss=relbo)
        for step in range(n_steps):
            loss = svi.step(data, approximation=wrapped_approximation)
            losses.append(loss)

            if PRINT_INTERMEDIATE_LATENT_VALUES:
                print('Loss: {}'.format(loss))
                variance = pyro.param("variance_{}".format(t)).item()
                mu = pyro.param("mu_{}".format(t)).item()
                print('mu = {}'.format(mu))
                print('variance = {}'.format(variance))

            if step % 100 == 0:
                print('.', end=' ')

        pyplot.plot(range(len(losses)), losses)
        pyplot.xlabel('Update Steps')
        pyplot.ylabel('-ELBO')
        pyplot.title('-ELBO against time for component {}'.format(t))
        pyplot.show()

        components.append(wrapped_guide)
        new_weight = 2 / (t + 1)

        weights = weights * (1 - new_weight)
        weights = torch.cat((weights, torch.tensor([new_weight])))

        wrapped_approximation = partial(approximation,
                                        components=components,
                                        weights=weights)

        scale = pyro.param("variance_{}".format(t)).item()
        scales.append(scale)
        loc = pyro.param("mu_{}".format(t)).item()
        locs.append(loc)
        print('mu = {}'.format(loc))
        print('variance = {}'.format(scale))

    pyplot.figure(figsize=(10, 4), dpi=100).set_facecolor('white')
    for name, grad_norms in gradient_norms.items():
        pyplot.plot(grad_norms, label=name)
        pyplot.xlabel('iters')
        pyplot.ylabel('gradient norm')
        # pyplot.yscale('log')
        pyplot.legend(loc='best')
        pyplot.title('Gradient norms during SVI')
    pyplot.show()

    print(weights)
    print(locs)
    print(scales)

    X = np.arange(-10, 10, 0.1)
    Y1 = weights[1].item() * scipy.stats.norm.pdf((X - locs[1]) / scales[1])
    Y2 = weights[2].item() * scipy.stats.norm.pdf((X - locs[2]) / scales[2])

    pyplot.figure(figsize=(10, 4), dpi=100).set_facecolor('white')
    pyplot.plot(X, Y1, 'r-')
    pyplot.plot(X, Y2, 'b-')
    pyplot.plot(X, Y1 + Y2, 'k--')
    pyplot.plot(data.data.numpy(), np.zeros(len(data)), 'k*')
    pyplot.title('Approximation of posterior over mu')
    pyplot.ylabel('probability density')
    pyplot.show()
Example #23
def run_svi(beta_hat, obs_error, K, true_beta):
    num_steps = TOTAL_ITS//K
    start = time()
    pyro.clear_param_store()
    pyro.enable_validation(True)

    def my_model():
        return prs_model(torch.tensor(beta_hat),
                         torch.tensor(obs_error))

    initial_approximation = partial(prs_guide, index=0)
    components = [initial_approximation]
    weights = torch.tensor([1.])
    wrapped_approximation = partial(approximation,
                                    components=components,
                                    weights=weights)
    optimizer = pyro.optim.Adam({'lr': LR})
    losses = []
    wrapped_guide = partial(prs_guide, index=0)
    svi = pyro.infer.SVI(
        my_model,
        wrapped_guide,
        optimizer,
        loss=pyro.infer.Trace_ELBO(num_particles=NUM_PARTICLES)
    )
    for step in range(num_steps):
        loss = svi.step()
        losses.append(loss)
        if step % 100 == 0:
            print('\t', step, np.mean(losses[-100:]))
        if step % 100 == 0:
            pstore = pyro.get_param_store()
            curr_mean = pstore.get_param(
                'var_mean_{}'.format(0)).detach().numpy()
            curr_psis = pstore.get_param(
                'var_psi_causal_{}'.format(0)).detach().numpy()
            curr_mean = curr_mean * curr_psis
            print('\t\t', np.corrcoef(true_beta, curr_mean)[0, 1],
                  np.mean((true_beta - curr_mean)**2))
    pstore = pyro.get_param_store()
    for t in range(1, K):
        print('Boost level', t)
        wrapped_guide = partial(prs_guide, index=t)
        losses = []
        optimizer = pyro.optim.Adam({'lr': LR})

        svi = pyro.infer.SVI(my_model, wrapped_guide, optimizer, loss=relbo)
        new_weight = 2 / ((t+1) + 2)
        new_weights = torch.cat((weights * (1-new_weight),
                                 torch.tensor([new_weight])))
        for step in range(num_steps):
            loss = svi.step(approximation=wrapped_approximation)
            losses.append(loss)
            if step % 100 == 0:
                print('\t', step, np.mean(losses[-100:]))
            if step % 100 == 0:
                pstore = pyro.get_param_store()
                curr_means = [
                    pstore.get_param(
                        'var_mean_{}'.format(s)).detach().numpy()
                    for s in range(t+1)
                ]
                curr_psis = [
                    pstore.get_param(
                        'var_psi_causal_{}'.format(s)).detach().numpy()
                    for s in range(t+1)
                ]
                curr_means = np.array(curr_means) * np.array(curr_psis)
                curr_mean = new_weights.detach().numpy().dot(curr_means)
                print('\t\t', np.corrcoef(true_beta, curr_mean)[0, 1],
                      np.mean((true_beta - curr_mean)**2))

        components.append(wrapped_guide)
        weights = new_weights
        wrapped_approximation = partial(approximation,
                                        components=components,
                                        weights=weights)
        # scales.append(
        #     pstore.get_param('var_mean_{}'.format(t)).detach().numpy()
        # )
    print('BBBVI ran in', time() - start)
    pstore = pyro.get_param_store()
    curr_means = [
        pstore.get_param(
            'var_mean_{}'.format(s)).detach().numpy()
        for s in range(K)
    ]
    return weights.detach().numpy().dot(np.array(curr_means))
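
Note that pstore.get_param(name) returns the same constrained tensor that pyro.param(name) would; a hedged sanity check:

import torch
import pyro

pyro.clear_param_store()
pyro.param("var_mean_0", torch.zeros(5))
pstore = pyro.get_param_store()
assert torch.equal(pstore.get_param("var_mean_0"), pyro.param("var_mean_0"))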
Example #24
def backtest(data,
             covariates,
             model_fn,
             *,
             forecaster_fn=Forecaster,
             metrics=None,
             transform=None,
             train_window=None,
             min_train_window=1,
             test_window=None,
             min_test_window=1,
             stride=1,
             seed=1234567890,
             num_samples=100,
             batch_size=None,
             forecaster_options={}):
    """
    Backtest a forecasting model on a moving window of (train,test) data.

    :param data: A tensor dataset with time dimension -2.
    :type data: ~torch.Tensor
    :param covariates: A tensor of covariates with time dimension -2.
        For models not using covariates, pass a shaped empty tensor
        ``torch.empty(duration, 0)``.
    :type covariates: ~torch.Tensor
    :param callable model_fn: Function that returns an
        :class:`~pyro.contrib.forecast.forecaster.ForecastingModel` object.
    :param callable forecaster_fn: Function that returns a forecaster object
        (for example, :class:`~pyro.contrib.forecast.forecaster.Forecaster`
        or :class:`~pyro.contrib.forecast.forecaster.HMCForecaster`)
        given arguments model, training data, training covariates and
        keyword arguments defined in `forecaster_options`.
    :param dict metrics: A dictionary mapping metric name to metric function.
        The metric function should input a forecast ``pred`` and ground
        ``truth`` and can output anything, often a number. Example metrics
        include: :func:`eval_mae`, :func:`eval_rmse`, and :func:`eval_crps`.
    :param callable transform: An optional transform to apply before computing
        metrics. If provided this will be applied as
        ``pred, truth = transform(pred, truth)``.
    :param int train_window: Size of the training window. By default trains
        from the beginning of data. This must be None if forecaster is
        :class:`~pyro.contrib.forecast.forecaster.Forecaster` and
        ``forecaster_options["warm_start"]`` is true.
    :param int min_train_window: If ``train_window`` is None, this specifies
        the min training window size. Defaults to 1.
    :param int test_window: Size of the test window. By default forecasts to
        end of data.
    :param int min_test_window: If ``test_window`` is None, this specifies
        the min test window size. Defaults to 1.
    :param int stride: Optional stride for test/train split. Defaults to 1.
    :param int seed: Random number seed.
    :param int num_samples: Number of samples for forecast. Defaults to 100.
    :param int batch_size: Batch size for forecast sampling. Defaults to
        ``num_samples``.
    :param forecaster_options: Options dict to pass to forecaster, or callable
        inputting time window ``t0,t1,t2`` and returning such a dict. See
        :class:`~pyro.contrib.forecaster.Forecaster` for details.
    :type forecaster_options: dict or callable

    :returns: A list of dictionaries of evaluation data. Caller is responsible
        for aggregating the per-window metrics. Dictionary keys include: train
        begin time "t0", train/test split time "t1", test end  time "t2",
        "seed", "num_samples", "train_walltime", "test_walltime", and one key
        for each metric.
    :rtype: list
    """
    assert data.size(-2) == covariates.size(-2)
    assert isinstance(min_train_window, int) and min_train_window >= 1
    assert isinstance(min_test_window, int) and min_test_window >= 1
    if metrics is None:
        metrics = DEFAULT_METRICS
    assert metrics, "no metrics specified"

    if callable(forecaster_options):
        forecaster_options_fn = forecaster_options
    else:

        def forecaster_options_fn(*args, **kwargs):
            return forecaster_options

    if train_window is not None and forecaster_options_fn().get("warm_start"):
        raise ValueError("Cannot warm start with moving training window; "
                         "either set warm_start=False or train_window=None")

    duration = data.size(-2)
    if test_window is None:
        stop = duration - min_test_window + 1
    else:
        stop = duration - test_window + 1
    if train_window is None:
        start = min_train_window
    else:
        start = train_window

    pyro.clear_param_store()
    results = []
    for t1 in range(start, stop, stride):
        t0 = 0 if train_window is None else t1 - train_window
        t2 = duration if test_window is None else t1 + test_window
        assert 0 <= t0 < t1 < t2 <= duration
        logger.info(
            "Training on window [{t0}:{t1}], testing on window [{t1}:{t2}]".
            format(t0=t0, t1=t1, t2=t2))

        # Train a forecaster on the training window.
        pyro.set_rng_seed(seed)
        forecaster_options = forecaster_options_fn(t0=t0, t1=t1, t2=t2)
        if not forecaster_options.get("warm_start"):
            pyro.clear_param_store()
        train_data = data[..., t0:t1, :]
        train_covariates = covariates[..., t0:t1, :]
        start_time = default_timer()
        model = model_fn()
        forecaster = forecaster_fn(model, train_data, train_covariates,
                                   **forecaster_options)
        train_walltime = default_timer() - start_time

        # Forecast forward to testing window.
        test_covariates = covariates[..., t0:t2, :]
        start_time = default_timer()
        # Gradually reduce batch_size to avoid OOM errors.
        while True:
            try:
                pred = forecaster(train_data,
                                  test_covariates,
                                  num_samples=num_samples,
                                  batch_size=batch_size)
                break
            except RuntimeError as e:
                if "out of memory" in str(e) and batch_size > 1:
                    batch_size = (batch_size + 1) // 2
                    warnings.warn(
                        "out of memory, decreasing batch_size to {}".format(
                            batch_size), RuntimeWarning)
                else:
                    raise
        test_walltime = default_timer() - start_time
        truth = data[..., t1:t2, :]

        # We aggressively garbage collect because Monte Carlo forecasts are memory intensive.
        del forecaster

        # Evaluate the forecasts.
        if transform is not None:
            pred, truth = transform(pred, truth)
        result = {
            "t0": t0,
            "t1": t1,
            "t2": t2,
            "seed": seed,
            "num_samples": num_samples,
            "train_walltime": train_walltime,
            "test_walltime": test_walltime,
            "params": {},
        }
        results.append(result)
        for name, fn in metrics.items():
            result[name] = fn(pred, truth)
        for name, value in pyro.get_param_store().items():
            if value.numel() == 1:
                value = value.cpu().item()
                result["params"][name] = value
        for dct in (result, result["params"]):
            for key, value in sorted(dct.items()):
                if isinstance(value, (int, float)):
                    logger.debug("{} = {:0.6g}".format(key, value))

        del pred

    return results
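
A hedged end-to-end sketch of calling backtest(); the trivial model below follows the pattern from Pyro's forecasting tutorials and is illustrative only, and the "crps" key assumes the library's default metrics:

import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.forecast import ForecastingModel, backtest

class TrivialModel(ForecastingModel):
    def model(self, zero_data, covariates):
        # constant-mean model with unit Gaussian noise (illustrative)
        loc = pyro.sample("loc",
                          dist.Normal(0, 10).expand([zero_data.size(-1)]).to_event(1))
        noise_dist = dist.Normal(0, 1)
        self.predict(noise_dist, loc * torch.ones_like(zero_data))

data = torch.randn(200, 1)        # time dimension is -2
covariates = torch.empty(200, 0)  # this model uses no covariates
results = backtest(data, covariates, TrivialModel,
                   min_train_window=50, test_window=50, stride=50,
                   forecaster_options={"num_steps": 100})
for window in results:
    print(window["t1"], window["crps"])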
Example #25
    def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
        # get info regarding rao-blackwellization of vectorized map_data
        guide_vec_md_info = guide_trace.graph["vectorized_map_data_info"]
        model_vec_md_info = model_trace.graph["vectorized_map_data_info"]
        guide_vec_md_condition = guide_vec_md_info['rao-blackwellization-condition']
        model_vec_md_condition = model_vec_md_info['rao-blackwellization-condition']
        do_vec_rb = guide_vec_md_condition and model_vec_md_condition
        if not do_vec_rb:
            warnings.warn(
                "Unable to do fully-vectorized Rao-Blackwellization in TraceGraph_ELBO. "
                "Falling back to higher-variance gradient estimator. "
                "Try to avoid these issues in your model and guide:\n{}".format("\n".join(
                    guide_vec_md_info["warnings"] | model_vec_md_info["warnings"])))
        guide_vec_md_nodes = guide_vec_md_info['nodes'] if do_vec_rb else set()
        model_vec_md_nodes = model_vec_md_info['nodes'] if do_vec_rb else set()

        # have the trace compute all the individual (batch) log pdf terms
        # so that they are available below
        guide_trace.compute_batch_log_pdf(site_filter=lambda name, site: name in guide_vec_md_nodes)
        guide_trace.log_pdf()
        model_trace.compute_batch_log_pdf(site_filter=lambda name, site: name in model_vec_md_nodes)
        model_trace.log_pdf()

        # prepare a list of all the cost nodes, each of which is +- log_pdf
        cost_nodes = []
        non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
        for name, model_site in model_trace.nodes.items():
            if model_site["type"] == "sample":
                if model_site["is_observed"]:
                    cost_nodes.append(CostNode(model_site["log_pdf"], True))
                else:
                    # cost node from model sample
                    cost_nodes.append(CostNode(model_site["log_pdf"], True))
                    # cost node from guide sample
                    guide_site = guide_trace.nodes[name]
                    zero_expectation = name in non_reparam_nodes
                    cost_nodes.append(CostNode(-guide_site["log_pdf"], not zero_expectation))

        # compute the elbo; if all stochastic nodes are reparameterizable, we're done
        # this bit is never differentiated: it's here for getting an estimate of the elbo itself
        elbo = torch_data_sum(sum(c.cost for c in cost_nodes))

        # compute the surrogate elbo, removing terms whose gradient is zero
        # this is the bit that's actually differentiated
        # XXX should the user be able to control if these terms are included?
        surrogate_elbo = sum(c.cost for c in cost_nodes if c.nonzero_expectation)

        # the following computations are only necessary if we have non-reparameterizable nodes
        baseline_loss = 0.0
        if non_reparam_nodes:

            # recursively compute downstream cost nodes for all sample sites in model and guide
            # (even though ultimately just need for non-reparameterizable sample sites)
            # 1. downstream costs used for rao-blackwellization
            # 2. model observe sites (as well as terms that arise from the model and guide having different
            # dependency structures) are taken care of via 'children_in_model' below
            topo_sort_guide_nodes = list(reversed(list(networkx.topological_sort(guide_trace))))
            topo_sort_guide_nodes = [x for x in topo_sort_guide_nodes
                                     if guide_trace.nodes[x]["type"] == "sample"]
            downstream_guide_cost_nodes = {}
            downstream_costs = {}

            for node in topo_sort_guide_nodes:
                node_log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
                downstream_costs[node] = model_trace.nodes[node][node_log_pdf_key] - \
                    guide_trace.nodes[node][node_log_pdf_key]
                nodes_included_in_sum = set([node])
                downstream_guide_cost_nodes[node] = set([node])
                for child in guide_trace.successors(node):
                    child_cost_nodes = downstream_guide_cost_nodes[child]
                    downstream_guide_cost_nodes[node].update(child_cost_nodes)
                    if nodes_included_in_sum.isdisjoint(child_cost_nodes):  # avoid duplicates
                        if node_log_pdf_key == 'log_pdf':
                            downstream_costs[node] += downstream_costs[child].sum()
                        else:
                            downstream_costs[node] += downstream_costs[child]
                        nodes_included_in_sum.update(child_cost_nodes)
                missing_downstream_costs = downstream_guide_cost_nodes[node] - nodes_included_in_sum
                # include terms we missed because we had to avoid duplicates
                for missing_node in missing_downstream_costs:
                    mn_log_pdf_key = 'batch_log_pdf' if missing_node in guide_vec_md_nodes else 'log_pdf'
                    if node_log_pdf_key == 'log_pdf':
                        downstream_costs[node] += (model_trace.nodes[missing_node][mn_log_pdf_key] -
                                                   guide_trace.nodes[missing_node][mn_log_pdf_key]).sum()
                    else:
                        downstream_costs[node] += model_trace.nodes[missing_node][mn_log_pdf_key] - \
                                                  guide_trace.nodes[missing_node][mn_log_pdf_key]

            # finish assembling complete downstream costs
            # (the above computation may be missing terms from model)
            # XXX can we cache some of the sums over children_in_model to make things more efficient?
            for site in non_reparam_nodes:
                children_in_model = set()
                for node in downstream_guide_cost_nodes[site]:
                    children_in_model.update(model_trace.successors(node))
                # remove terms accounted for above
                children_in_model.difference_update(downstream_guide_cost_nodes[site])
                for child in children_in_model:
                    child_log_pdf_key = 'batch_log_pdf' if child in model_vec_md_nodes else 'log_pdf'
                    site_log_pdf_key = 'batch_log_pdf' if site in guide_vec_md_nodes else 'log_pdf'
                    assert (model_trace.nodes[child]["type"] == "sample")
                    if site_log_pdf_key == 'log_pdf':
                        downstream_costs[site] += model_trace.nodes[child][child_log_pdf_key].sum()
                    else:
                        downstream_costs[site] += model_trace.nodes[child][child_log_pdf_key]

            # construct all the reinforce-like terms.
            # we include only downstream costs to reduce variance
            # optionally include baselines to further reduce variance
            # XXX should the average baseline be in the param store as below?
            elbo_reinforce_terms = 0.0
            for node in non_reparam_nodes:
                guide_site = guide_trace.nodes[node]
                log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
                downstream_cost = downstream_costs[node]
                baseline = 0.0
                (nn_baseline, nn_baseline_input, use_decaying_avg_baseline, baseline_beta,
                    baseline_value) = _get_baseline_options(guide_site)
                use_nn_baseline = nn_baseline is not None
                use_baseline_value = baseline_value is not None
                assert(not (use_nn_baseline and use_baseline_value)), \
                    "cannot use baseline_value and nn_baseline simultaneously"
                if use_decaying_avg_baseline:
                    avg_downstream_cost_old = pyro.param("__baseline_avg_downstream_cost_" + node,
                                                         ng_zeros(1), tags="__tracegraph_elbo_internal_tag")
                    avg_downstream_cost_new = (1 - baseline_beta) * downstream_cost + \
                        baseline_beta * avg_downstream_cost_old
                    avg_downstream_cost_old.data = avg_downstream_cost_new.data  # XXX copy_() ?
                    baseline += avg_downstream_cost_old
                if use_nn_baseline:
                    # block nn_baseline_input gradients except in baseline loss
                    baseline += nn_baseline(detach_iterable(nn_baseline_input))
                elif use_baseline_value:
                    # it's on the user to make sure baseline_value tape only points to baseline params
                    baseline += baseline_value
                if use_nn_baseline or use_baseline_value:
                    # accumulate baseline loss
                    baseline_loss += torch.pow(downstream_cost.detach() - baseline, 2.0).sum()

                guide_log_pdf = guide_site[log_pdf_key] / guide_site["scale"]  # not scaled by subsampling
                if use_nn_baseline or use_decaying_avg_baseline or use_baseline_value:
                    if downstream_cost.size() != baseline.size():
                        raise ValueError("Expected baseline at site {} to be {} instead got {}".format(
                            node, downstream_cost.size(), baseline.size()))
                    downstream_cost = downstream_cost - baseline
                elbo_reinforce_terms += (guide_log_pdf * downstream_cost.detach()).sum()

            surrogate_elbo += elbo_reinforce_terms

        # collect parameters to train from model and guide
        trainable_params = set(site["value"]
                               for trace in (model_trace, guide_trace)
                               for site in trace.nodes.values()
                               if site["type"] == "param")

        if trainable_params:
            surrogate_loss = -surrogate_elbo
            torch_backward(weight * (surrogate_loss + baseline_loss))
            pyro.get_param_store().mark_params_active(trainable_params)

        loss = -elbo
        return weight * loss
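For reference, the decaying-average baseline used above is an exponential moving average of the detached downstream cost. A minimal standalone sketch of that update rule (names here are illustrative, not taken from the code above):

import torch

def update_avg_baseline(avg_old, downstream_cost, beta):
    # beta close to 1.0 gives a slowly decaying, long-memory average
    return (1 - beta) * downstream_cost.detach() + beta * avg_old

avg = torch.zeros(1)
for cost in [torch.tensor([2.0]), torch.tensor([1.0]), torch.tensor([0.5])]:
    avg = update_avg_baseline(avg, cost, beta=0.9)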
Ejemplo n.º 26
0
def get_encodings(model: VariationalInferenceModel,
                  dataset_obj,
                  cells_only: bool = True) -> Tuple[np.ndarray,
                                                    np.ndarray,
                                                    np.ndarray]:
    """Get inferred quantities from a trained model.

    Run a dataset through the model's trained encoder and return the inferred
    quantities.

    Args:
        model: A trained cellbender.model.VariationalInferenceModel, which will be
            used to generate the encodings from data.
        dataset_obj: The dataset to be encoded.
        cells_only: If True, only returns the encodings of barcodes that are
            determined to contain cells.

    Returns:
        z: Latent variable embedding of gene expression in a low-dimensional
            space.
        d: Latent variable scale factor for the number of UMI counts coming
            from each real cell.  Not in log space, but actual size.  This is
            not just the encoded d, but the mean of the LogNormal distribution,
            which is exp(mean + sigma^2 / 2).
        p: Latent variable denoting probability that each barcode contains a
            real cell.

    """

    logging.info("Encoding data according to model.")

    # Get the count matrix with genes trimmed.
    if cells_only:
        dataset = dataset_obj.get_count_matrix()
    else:
        dataset = dataset_obj.get_count_matrix_all_barcodes()

    # Initialize numpy arrays as placeholders.
    z = np.zeros((dataset.shape[0], model.z_dim))
    d = np.zeros((dataset.shape[0]))
    p = np.zeros((dataset.shape[0]))

    # Get chi ambient, if it was part of the model.
    chi_ambient = get_ambient_expression()
    if chi_ambient is not None:
        chi_ambient = torch.Tensor(chi_ambient).to(device=model.device)

    # Get d_cell_scale from the fit model once (it is constant across chunks).
    d_sig = \
        pyro.get_param_store().get_param('d_cell_scale').detach().cpu().numpy()

    # Send dataset through the learned encoder in chunks.
    s = 200
    for i in np.arange(0, dataset.shape[0], s):

        # Put chunk of data into a torch.Tensor.
        x = torch.Tensor(np.array(
            dataset[i:min(dataset.shape[0], i + s), :].todense(),
            dtype=int).squeeze()).to(device=model.device)

        # Send data chunk through encoder.
        enc = model.encoder.forward(x, chi_ambient)

        # Put the resulting encodings into the appropriate numpy arrays.
        z[i:min(dataset.shape[0], i + s), :] = \
            enc['z']['loc'].detach().cpu().numpy()
        d[i:min(dataset.shape[0], i + s)] = \
            np.exp(enc['d_loc'].detach().cpu().numpy() + d_sig.item()**2 / 2)
        try:  # p is not always available: it depends which model was used.
            p[i:min(dataset.shape[0], i + s)] = \
                enc['p_y'].detach().sigmoid().cpu().numpy()
        except KeyError:
            p = None  # Simple model gets None for p.

    return z, d, p
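The d returned above is the mean of a LogNormal distribution, exp(mu + sigma^2 / 2), not the raw encoded location. A quick numeric check of that identity (values here are illustrative):

import numpy as np
import torch

mu, sigma = 1.2, 0.3
samples = torch.distributions.LogNormal(mu, sigma).sample((1_000_000,))
print(samples.mean().item())        # ~3.473, up to Monte Carlo error
print(np.exp(mu + sigma ** 2 / 2))  # 3.473...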
Ejemplo n.º 27
0
def main(args):
    logging.info(f"CUDA available: {torch.cuda.is_available()}")
    logging.info('Generating data')
    pyro.set_rng_seed(0)
    pyro.clear_param_store()
    pyro.enable_validation(True)

    # Debugging the trace of the model. For showing the shapes of the tensors through the model
    # tracemodel = functools.partial(model, args=args)
    # trace = poutine.trace(tracemodel).get_trace()
    # trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
    # print(trace.format_shapes())

    # We can generate synthetic data directly by calling the model.
    data = model(args=args)

    gen_doc_word_data = data["doc_word_data"]
    gen_doc_category_data = data["doc_category_data"]

    # Loading data
    corpora = prepro_file_load("corpora")
    documents = list(prepro_file_load("id2pre_text").values())
    category_list = [[cat]
                     for cat in list(prepro_file_load("id2category").values())]
    category_corpora = prepro_file_load("category_corpora")

    doc_word_data = [
        torch.tensor(list(filter(lambda a: a != -1, corpora.doc2idx(doc))),
                     dtype=torch.int64) for doc in documents
    ]
    doc_category_data = [
        torch.tensor(next(
            filter(lambda a: a != -1, category_corpora.doc2idx(cat))),
                     dtype=torch.int64) for cat in category_list
    ]
    # TODO X check if there are differences between this data and the model-generated data

    # Slice data to only use data from the first n documents
    data_slice = None
    if data_slice is not None:
        doc_word_data = doc_word_data[:data_slice]
        doc_category_data = doc_category_data[:data_slice]

    # Setting the new args
    args.num_words_per_doc = list(map(len, doc_word_data))
    args.num_words = len(corpora)
    args.num_docs = len(doc_word_data)
    args.num_categories = len(category_corpora)
    args.num_topics = args.num_categories * 2  # TODO X test different numbers of topics

    # We'll train using SVI.
    logging.info('-' * 40)
    logging.info('Training on {} documents'.format(args.num_docs))
    Elbo = JitTraceEnum_ELBO if args.jit else Trace_ELBO  # TODO X test TraceEnum_ vs Trace_ vs TraceMeanField_
    elbo = Elbo(
        max_plate_nesting=2
    )  # TODO Changing the max plate nesting value might be worth looking at
    optim = ClippedAdam({'lr': args.learning_rate
                         })  # TODO X try different learning rates
    # TODO try something other than ClippedAdam or changing its parameters
    svi = SVI(model, parametrized_guide, optim, elbo)

    losses = []

    # Training for num_steps iterations
    logging.info('Step\tLoss')
    for step in tqdm(range(args.num_steps)):
        loss = svi.step(doc_word_data=doc_word_data,
                        category_data=doc_category_data,
                        args=args,
                        batch_size=args.batch_size)
        losses.append(loss)
        if step % 10 == 0:
            logging.info('{: >5d}\t{}'.format(step, loss))
    loss = elbo.loss(model,
                     parametrized_guide,
                     doc_word_data=doc_word_data,
                     category_data=doc_category_data,
                     args=args,
                     batch_size=args.batch_size)
    logging.info('final loss = {}'.format(loss))

    # Print params after training
    print('topic_weights_posterior = ', pyro.param("topic_weights_posterior"))
    print('topic_words_posterior = ', pyro.param("topic_words_posterior"))
    print('category_weights_posterior = ',
          pyro.param("category_weights_posterior"))
    print('category_topics_posterior = ',
          pyro.param("category_topics_posterior"))
    print('doc_category_posterior = ', pyro.param("doc_category_posterior"))

    # Plot loss over iterations
    plt.plot(losses)
    plt.title("ELBO")
    plt.xlabel("step")
    plt.ylabel("loss")
    plot_file_name = "../loss-2017_categories-" + str(args.num_categories) + \
                     "_topics-" + str(args.num_topics) + \
                     "_batch-" + str(args.batch_size) + \
                     "_lr-" + str(args.learning_rate) + \
                     "_data-size-" + str(data_slice) + \
                     ".png"
    plt.savefig(plot_file_name)
    plt.show()

    # save model
    pyro.get_param_store().save("mymodelparams.pt")
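The saved parameters can be restored into a fresh session via the param store's load method; a short sketch, assuming the file written above exists:

import pyro

pyro.clear_param_store()
pyro.get_param_store().load("mymodelparams.pt")
print(pyro.param("topic_words_posterior"))  # learned values are again available by name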
Ejemplo n.º 28
0
def main(args):

    # Load dataset.
    if args.cpu_data or not args.cuda:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
    if args.test:
        dataset = generate_data(args.small, args.include_stop, device)
    else:
        dataset = BiosequenceDataset(
            args.file,
            "fasta",
            args.alphabet,
            include_stop=args.include_stop,
            device=device,
        )
    args.batch_size = min([dataset.data_size, args.batch_size])
    if args.split > 0.0:
        # Train test split.
        heldout_num = int(np.ceil(args.split * len(dataset)))
        data_lengths = [len(dataset) - heldout_num, heldout_num]
        # Specific data split seed, for comparability across models and
        # parameter initializations.
        pyro.set_rng_seed(args.rng_data_seed)
        indices = torch.randperm(sum(data_lengths), device=device).tolist()
        dataset_train, dataset_test = [
            torch.utils.data.Subset(dataset, indices[(offset - length):offset])
            for offset, length in zip(torch._utils._accumulate(data_lengths),
                                      data_lengths)
        ]
    else:
        dataset_train = dataset
        dataset_test = None

    # Training seed.
    pyro.set_rng_seed(args.rng_seed)

    # Construct model.
    model = FactorMuE(
        dataset.max_length,
        dataset.alphabet_length,
        args.z_dim,
        batch_size=args.batch_size,
        latent_seq_length=args.latent_seq_length,
        indel_factor_dependence=args.indel_factor,
        indel_prior_scale=args.indel_prior_scale,
        indel_prior_bias=args.indel_prior_bias,
        inverse_temp_prior=args.inverse_temp_prior,
        weights_prior_scale=args.weights_prior_scale,
        offset_prior_scale=args.offset_prior_scale,
        z_prior_distribution=args.z_prior,
        ARD_prior=args.ARD_prior,
        substitution_matrix=(not args.no_substitution_matrix),
        substitution_prior_scale=args.substitution_prior_scale,
        latent_alphabet_length=args.latent_alphabet,
        cuda=args.cuda,
        pin_memory=args.pin_mem,
    )

    # Infer with SVI.
    scheduler = MultiStepLR({
        "optimizer": Adam,
        "optim_args": {
            "lr": args.learning_rate
        },
        "milestones": json.loads(args.milestones),
        "gamma": args.learning_gamma,
    })
    n_epochs = args.n_epochs
    losses = model.fit_svi(
        dataset_train,
        n_epochs,
        args.anneal,
        args.batch_size,
        scheduler,
        args.jit,
    )

    # Evaluate.
    train_lp, test_lp, train_perplex, test_perplex = model.evaluate(
        dataset_train, dataset_test, args.jit)
    print("train logp: {} perplex: {}".format(train_lp, train_perplex))
    print("test logp: {} perplex: {}".format(test_lp, test_perplex))

    # Get latent space embedding.
    z_locs, z_scales = model.embed(dataset)

    # Plot and save.
    time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    if not args.no_plots:
        plt.figure(figsize=(6, 6))
        plt.plot(losses)
        plt.xlabel("step")
        plt.ylabel("loss")
        if not args.no_save:
            plt.savefig(
                os.path.join(args.out_folder,
                             "FactorMuE_plot.loss_{}.pdf".format(time_stamp)))

        plt.figure(figsize=(6, 6))
        plt.scatter(z_locs[:, 0], z_locs[:, 1])
        plt.xlabel(r"$z_1$")
        plt.ylabel(r"$z_2$")
        if not args.no_save:
            plt.savefig(
                os.path.join(
                    args.out_folder,
                    "FactorMuE_plot.latent_{}.pdf".format(time_stamp)))

        if not args.indel_factor:
            # Plot indel parameters. See statearrangers.py for details on the
            # r and u parameters.
            plt.figure(figsize=(6, 6))
            insert = pyro.param("insert_q_mn").detach()
            insert_expect = torch.exp(insert - insert.logsumexp(-1, True))
            plt.plot(insert_expect[:, :, 1].cpu().numpy())
            plt.xlabel("position")
            plt.ylabel("probability of insert")
            plt.legend([r"$r_0$", r"$r_1$", r"$r_2$"])
            if not args.no_save:
                plt.savefig(
                    os.path.join(
                        args.out_folder,
                        "FactorMuE_plot.insert_prob_{}.pdf".format(time_stamp),
                    ))
            plt.figure(figsize=(6, 6))
            delete = pyro.param("delete_q_mn").detach()
            delete_expect = torch.exp(delete - delete.logsumexp(-1, True))
            plt.plot(delete_expect[:, :, 1].cpu().numpy())
            plt.xlabel("position")
            plt.ylabel("probability of delete")
            plt.legend([r"$u_0$", r"$u_1$", r"$u_2$"])
            if not args.no_save:
                plt.savefig(
                    os.path.join(
                        args.out_folder,
                        "FactorMuE_plot.delete_prob_{}.pdf".format(time_stamp),
                    ))

    if not args.no_save:
        pyro.get_param_store().save(
            os.path.join(args.out_folder,
                         "FactorMuE_results.params_{}.out".format(time_stamp)))
        with open(
                os.path.join(
                    args.out_folder,
                    "FactorMuE_results.evaluation_{}.txt".format(time_stamp),
                ),
                "w",
        ) as ow:
            ow.write("train_lp,test_lp,train_perplex,test_perplex\n")
            ow.write("{},{},{},{}\n".format(train_lp, test_lp, train_perplex,
                                            test_perplex))
        np.savetxt(
            os.path.join(
                args.out_folder,
                "FactorMuE_results.embed_loc_{}.txt".format(time_stamp)),
            z_locs.cpu().numpy(),
        )
        np.savetxt(
            os.path.join(
                args.out_folder,
                "FactorMuE_results.embed_scale_{}.txt".format(time_stamp),
            ),
            z_scales.cpu().numpy(),
        )
        with open(
                os.path.join(
                    args.out_folder,
                    "FactorMuE_results.input_{}.txt".format(time_stamp),
                ),
                "w",
        ) as ow:
            ow.write("[args]\n")
            args.latent_seq_length = model.latent_seq_length
            args.latent_alphabet = model.latent_alphabet_length
            for elem in list(args.__dict__.keys()):
                ow.write("{} = {}\n".format(elem, args.__getattribute__(elem)))
            ow.write("alphabet_str = {}\n".format("".join(dataset.alphabet)))
            ow.write("max_length = {}\n".format(dataset.max_length))
Ejemplo n.º 29
0
    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator.
        Performs backward on the latter. num_particles samples are used to form the estimators.
        """
        elbo = 0.0
        # grab a trace from the generator
        for weight, model_trace, guide_trace, log_r in self._get_traces(
                model, guide, *args, **kwargs):
            elbo_particle = weight * 0
            surrogate_elbo_particle = weight * 0
            # compute elbo and surrogate elbo
            log_pdf = "batch_log_pdf" if (
                self.enum_discrete and weight.size(0) > 1) else "log_pdf"
            for name, model_site in model_trace.nodes.items():
                if model_site["type"] == "sample":
                    if model_site["is_observed"]:
                        elbo_particle += model_site[log_pdf]
                        surrogate_elbo_particle += model_site[log_pdf]
                    else:
                        guide_site = guide_trace.nodes[name]
                        lp_lq = model_site[log_pdf] - guide_site[log_pdf]
                        elbo_particle += lp_lq
                        if guide_site["fn"].reparameterized:
                            surrogate_elbo_particle += lp_lq
                        else:
                            # XXX should the user be able to control inclusion of the -logq term below?
                            guide_log_pdf = guide_site[log_pdf] / guide_site[
                                "scale"]  # not scaled by subsampling
                            surrogate_elbo_particle += model_site[
                                log_pdf] + log_r.detach() * guide_log_pdf

            # drop terms of weight zero to avoid nans
            if isinstance(weight, numbers.Number):
                if weight == 0.0:
                    elbo_particle = torch_zeros_like(elbo_particle)
                    surrogate_elbo_particle = torch_zeros_like(
                        surrogate_elbo_particle)
            else:
                weight_eq_zero = (weight == 0)
                elbo_particle[weight_eq_zero] = 0.0
                surrogate_elbo_particle[weight_eq_zero] = 0.0

            elbo += torch_data_sum(weight * elbo_particle)
            surrogate_elbo_particle = torch_sum(weight *
                                                surrogate_elbo_particle)

            # collect parameters to train from model and guide
            trainable_params = set(site["value"]
                                   for trace in (model_trace, guide_trace)
                                   for site in trace.nodes.values()
                                   if site["type"] == "param")

            if trainable_params:
                surrogate_loss_particle = -surrogate_elbo_particle
                torch_backward(surrogate_loss_particle)
                pyro.get_param_store().mark_params_active(trainable_params)

        loss = -elbo
        if np.isnan(loss):
            warnings.warn('Encountered NAN loss')
        return loss
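The non-reparameterized branch above uses the score-function (REINFORCE) identity: since grad E_q[f] = E_q[f * grad log q], adding a detached-reward-times-log-prob term to the surrogate makes autograd produce that estimator. A toy standalone demonstration of the trick (illustrative, not part of the class above):

import torch

p = torch.tensor(0.3, requires_grad=True)
q = torch.distributions.Bernoulli(probs=p)
z = q.sample((200000,))
f = 3.0 * z  # E[f] = 3p, so the true gradient dE[f]/dp is 3
surrogate = (f.detach() * q.log_prob(z)).mean()
surrogate.backward()
print(p.grad)  # approximately 3, up to Monte Carlo error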
Ejemplo n.º 30
0
    def save_model(self):
        pyro.get_param_store().save('gp_adf_rtss.save')
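A matching load method (a sketch, reusing the same assumed filename) would restore the saved state:

    def load_model(self):
        pyro.get_param_store().load('gp_adf_rtss.save')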
Ejemplo n.º 31
0
def bayesian_regression(x_data, y_data, num_iterations):
    # BAYESIAN REGRESSION WITH SVI

    class BayesianRegression(PyroModule):
        def __init__(self, in_features, out_features):
            super().__init__()
            self.linear = PyroModule[nn.Linear](in_features, out_features)
            self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
            self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))

        def forward(self, x, y=None):
            # forward() specifies the data generating process
            sigma = pyro.sample("sigma", dist.Uniform(0., 10.)) # scale of the observation noise (the std dev of the epsilon term in regression equations)
            mean = self.linear(x).squeeze(-1)
            with pyro.plate("data", x.shape[0]):
                obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
            return mean

    """
    Guides -- posterior distribution classes

    The guide determines a family of distributions, and SVI aims to find an 
    approximate posterior distribution from this family that has the lowest
    KL divergence from the true posterior.
    """

    model = BayesianRegression(3, 1)

    """
    Under the hood, this defines a guide that uses a Normal distribution with
    learnable parameters corresponding to each sample statement in the model.
    In our case, this distribution has a size of (5,), corresponding to the
    3 regression coefficients, plus one component each for the intercept
    term and sigma in the model.
    """

    guide = AutoDiagonalNormal(model)

    adam = pyro.optim.Adam({"lr": 0.03}) # note this is from Pyro's optim module, not PyTorch's 
    svi = SVI(model, guide, adam, loss=Trace_ELBO())

    """
    We do not need to pass in learnable parameters to the optimizer
    (unlike the PyTorch example above) since that is determined by the guide
    code and happens behind the scenes within the SVI class automatically.
    To take an ELBO gradient step we simply call the step method of SVI.
    The data argument we pass to SVI.step will be passed to both
    model() and guide().
    """

    pyro.clear_param_store()
    for j in range(num_iterations):
        # calculate the loss and take a gradient step
        loss = svi.step(x_data, y_data)
        if (j+1) % 100 == 0:
            print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))


    # We can examine the optimized parameter values by fetching from Pyro’s param store.

    guide.requires_grad_(False) # freeze the guide's learned parameters; no gradients are needed for evaluation

    for name, value in pyro.get_param_store().items():
        print(name, pyro.param(name))


    # This gets us quantiles from the posterior distribution
    guide.quantiles([0.25, 0.5, 0.75])

    """
    Since Bayesian models give you a posterior distribution, 
    model evaluation needs to be a combination of sampling the posterior and
    running the samples through the model.

    We generate 800 samples from our trained model. Internally, this is done
    by first generating samples for the unobserved sites in the guide, and
    then running the model forward by conditioning the sites to values sampled
    from the guide. Refer to the Model Serving section for insight on how the
    Predictive class works.
    """

    def summary(samples):
        site_stats = {}
        for k, v in samples.items():
            site_stats[k] = {
                "mean": torch.mean(v, 0),
                "std": torch.std(v, 0),
                "5%": v.kthvalue(int(len(v) * 0.05), dim=0)[0],
                "95%": v.kthvalue(int(len(v) * 0.95), dim=0)[0],
            }
        return site_stats

    """
    Note that in return_sites, we specify both the outcome ("obs" site) as
    well as the return value of the model ("_RETURN") which captures the
    regression line. Additionally, we would also like to capture the regression
    coefficients (given by "linear.weight") for further analysis.
    """

    predictive = Predictive(model, guide=guide, num_samples=800,
                            return_sites=("linear.weight", "obs", "_RETURN"))
    samples = predictive(x_data)
    pred_summary = summary(samples)
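A hypothetical driver for the function above, using synthetic data (three features, a linear response, Gaussian noise); it assumes the imports used by the snippet are in scope:

import torch

torch.manual_seed(0)
x_data = torch.randn(100, 3)
true_w = torch.tensor([1.5, -0.8, 0.3])
y_data = x_data @ true_w + 0.1 * torch.randn(100)
bayesian_regression(x_data, y_data, num_iterations=1000)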
Ejemplo n.º 32
0
def main():
    """
    run inference for SS-VAE
    :return: None
    """
    pyro.set_rng_seed(12345)
    cuda = True
    # batch_size: number of images (and labels) to be considered in a batch
    ss_vae = TextSSVAE(embed_dim=300,
                       z_dim=300,
                       kernels=[3, 4, 5],
                       filters=[100, 100, 100],
                       hidden_size=300,
                       num_rnn_layers=1,
                       config_enum="parallel",
                       use_cuda=cuda,
                       aux_loss_multiplier=46)

    ss_vae = ss_vae.cuda()

    try:
        pyro.get_param_store().load('pyro_param_store.store')
        print(
            'successfully loaded param store, remove file from directory if undesired'
        )
    except Exception:
        print("failed to load param store, starting over")

    try:
        ss_vae.load_state_dict(torch.load('ss_vae_model.pth'))
        print(
            'successfully loaded model parameters, remove file from directory if undesired'
        )
    except Exception:
        print("failed to load model parameters")

    # setup the optimizer
    adam_params = {"lr": 1e-4, "betas": (0.9, 0.999), "weight_decay": 0.01}
    optimizer = Adam(adam_params)

    # set up the loss(es) for inference. wrapping the guide in config_enumerate builds the loss as a sum
    # by enumerating each class label for the sampled discrete categorical distribution in the model
    jit = False
    guide = config_enumerate(ss_vae.guide, "parallel", expand=True)
    elbo = (JitTraceEnum_ELBO if jit else TraceEnum_ELBO)()

    loss_basic = SVI(ss_vae.model, guide, optimizer, loss=elbo)

    # build a list of all losses considered
    losses = [loss_basic]

    # aux_loss: whether to use the auxiliary loss from NIPS 14 paper (Kingma et al)
    aux_loss = True
    if aux_loss:
        elbo = JitTrace_ELBO() if jit else Trace_ELBO()
        loss_aux = SVI(ss_vae.model_classify,
                       ss_vae.guide_classify,
                       optimizer,
                       loss=elbo)
        losses.append(loss_aux)

    batch_size = 32
    valid_num = 100
    train_data_size = 3409
    sup_num = 1163
    try:
        # setup the logger; the path is hard-coded, so we always open it
        logger = open('./tmp.log', "w")
        data_loaders = setup_data_loaders(IMDBCached,
                                          cuda,
                                          batch_size=32,
                                          sup_num=valid_num)

        # how often would a supervised batch be encountered during inference
        # e.g. if sup_num is 3000, we would have every 16th = int(50000/3000) batch supervised
        # until we have traversed through all the supervised batches
        periodic_interval_batches = int(train_data_size / (1.0 * sup_num))

        # number of unsupervised examples
        unsup_num = train_data_size - sup_num

        # initializing local variables to maintain the best validation accuracy
        # seen across epochs over the supervised training set
        # and the corresponding testing set and the state of the networks
        best_valid_acc, corresponding_test_acc = 0.0, 0.0

        # run inference for a certain number of epochs
        num_epochs = 200
        sup_loss_log = []
        unsup_loss_log = []

        for i in range(0, num_epochs):
            # get the losses for an epoch
            epoch_losses_sup, epoch_losses_unsup = \
                run_inference_for_epoch(data_loaders, losses, periodic_interval_batches)

            # compute average epoch losses i.e. losses per example
            # (materialize as lists: a lazy map iterator would be exhausted
            # by the string join below, corrupting the saved loss logs)
            avg_epoch_losses_sup = [v / sup_num for v in epoch_losses_sup]
            avg_epoch_losses_unsup = [v / unsup_num for v in epoch_losses_unsup]

            sup_loss_log.append(avg_epoch_losses_sup)
            unsup_loss_log.append(avg_epoch_losses_unsup)

            # store the loss and validation/testing accuracies in the logfile
            str_loss_sup = " ".join(map(str, avg_epoch_losses_sup))
            str_loss_unsup = " ".join(map(str, avg_epoch_losses_unsup))

            str_print = "{} epoch: avg losses {}".format(
                i, "{} {}".format(str_loss_sup, str_loss_unsup))
            ss_vae.eval()
            validation_accuracy = get_accuracy(data_loaders["valid"],
                                               ss_vae.classifier, batch_size)
            str_print += " validation accuracy {}".format(validation_accuracy)

            # this test accuracy is only for logging, this is not used
            # to make any decisions during training
            test_accuracy = get_accuracy(data_loaders["test"],
                                         ss_vae.classifier, batch_size)
            str_print += " test accuracy {}".format(test_accuracy)
            ss_vae.train()
            torch.save(ss_vae.state_dict(), 'ss_vae_model.pth')
            pyro.get_param_store().save('pyro_param_store.store')

            # update the best validation accuracy and the corresponding
            # testing accuracy and the state of the parent module (including the networks)
            if best_valid_acc < validation_accuracy:
                best_valid_acc = validation_accuracy
                corresponding_test_acc = test_accuracy
            if i % 10 == 0:
                neg_sentences, neg_bleu = generateSentences(
                    data_loaders["test"],
                    ss_vae.model,
                    ss_vae.w2v_model,
                    sentiment=0)
                pos_sentences, pos_bleu = generateSentences(
                    data_loaders["test"],
                    ss_vae.model,
                    ss_vae.w2v_model,
                    sentiment=1)
                str_print += " neg_bleu {}".format(neg_bleu)
                str_print += " pos_bleu {}".format(pos_bleu)
                pd.DataFrame.from_dict(pos_sentences).to_csv(
                    'positive_sentences.csv', encoding='utf-8')
                pd.DataFrame.from_dict(neg_sentences).to_csv(
                    'negative_sentences.csv', encoding='utf-8')

                cond_neg_sentences, neg_bleu = generateSentences(
                    data_loaders["test"],
                    ss_vae.conditioned_generation,
                    ss_vae.w2v_model,
                    sentiment=0)
                cond_pos_sentences, pos_bleu = generateSentences(
                    data_loaders["test"],
                    ss_vae.conditioned_generation,
                    ss_vae.w2v_model,
                    sentiment=1)
                pd.DataFrame.from_dict(cond_pos_sentences).to_csv(
                    'cond_positive_sentences.csv', encoding='utf-8')
                pd.DataFrame.from_dict(cond_neg_sentences).to_csv(
                    'cond_negative_sentences.csv', encoding='utf-8')
                str_print += " cond_neg_bleu {}".format(neg_bleu)
                str_print += " cond_pos_bleu {}".format(pos_bleu)

            print_and_log(logger, str_print)

        np.save("avg_loss_sup", np.asarray(sup_loss_log))
        np.save("avg_loss_unsup", np.asarray(unsup_loss_log))
        ss_vae.eval()
        final_test_accuracy = get_accuracy(data_loaders["test"],
                                           ss_vae.classifier, batch_size)
        print_and_log(
            logger,
            "best validation accuracy {} corresponding testing accuracy {} "
            "last testing accuracy {}".format(best_valid_acc,
                                              corresponding_test_acc,
                                              final_test_accuracy))

    finally:
        # close the logger file object if we opened it earlier
        if logger:
            logger.close()
Ejemplo n.º 33
0
    def variational(self, data: torch.Tensor, y: torch.Tensor,
                    tensorboard: bool, log_dir: str):
        """
        Perform variational inference using the guide.

        Parameters
        ----------
        data_input : np.ndarray, shape=(n_samples, n_features)
            NumPy 2-D array with data input.
        y : np.ndarray, shape=(n_samples,)
            NumPy array with ground truth labels as 1-D vector (binary).
        """

        # explicitly define datatype
        data = data.float()
        y = y.float()

        num_samples = data.shape[0]

        # create dataset
        lr_dataset = torch.utils.data.TensorDataset(data, y)
        data_loader = DataLoader(dataset=lr_dataset,
                                 batch_size=1024,
                                 pin_memory=False)

        # define optimizer
        optim = Adam({'lr': 0.01})
        svi = SVI(self.model, self.guide, optim, loss=Trace_ELBO())

        # add tensorboard writer if requested
        if tensorboard:
            writer = SummaryWriter(log_dir=log_dir)

        # start variational process
        with tqdm(total=self.vi_epochs) as pbar:
            for epoch in range(self.vi_epochs):
                epoch_loss = 0.
                for i, (x, y) in enumerate(data_loader):
                    epoch_loss += svi.step(x, y)

                # get loss of complete epoch
                epoch_loss = epoch_loss / num_samples

                # logging stuff
                if tensorboard:

                    # add loss to logging
                    writer.add_scalar("SVI loss", epoch_loss, epoch)

                    # get param store and log current state of parameter store
                    param_store = pyro.get_param_store()
                    for key in self._sites.keys():
                        for d, (loc, scale) in enumerate(
                                zip(param_store["%s_mean" % key],
                                    param_store["%s_scale" % key])):
                            writer.add_scalar("%s_mean_%d" % (key, d), loc,
                                              epoch)
                            writer.add_scalar("%s_scale_%d" % (key, d), scale,
                                              epoch)

                            # also represent the weights as distributions
                            density = np.random.normal(
                                loc=loc.detach().cpu().numpy(),
                                scale=scale.detach().cpu().numpy(),
                                size=1000)
                            writer.add_histogram("histogram_%s_%d" % (key, d),
                                                 density, epoch)

                # update progress bar
                pbar.set_description("SVI Loss: %.5f" % epoch_loss)
                pbar.update(1)

        self.vi_model = pyro.get_param_store().get_state()

        if tensorboard:
            writer.close()
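The state captured in self.vi_model can later be pushed back into the global param store with set_state, e.g. before prediction in a new session; a sketch, assuming the attribute is still populated:

    def restore_vi_state(self):
        pyro.clear_param_store()
        pyro.get_param_store().set_state(self.vi_model)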
Ejemplo n.º 34
0
def test_subsample_gradient(Elbo, reparameterized, has_rsample, subsample,
                            local_samples, scale):
    pyro.clear_param_store()
    data = torch.tensor([-0.5, 2.0])
    subsample_size = 1 if subsample else len(data)
    precision = 0.06 * scale
    Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal

    def model(subsample):
        with pyro.plate("data", len(data), subsample_size, subsample) as ind:
            x = data[ind]
            z = pyro.sample("z", Normal(0, 1))
            pyro.sample("x", Normal(z, 1), obs=x)

    def guide(subsample):
        scale = pyro.param("scale", lambda: torch.tensor([1.0]))
        with pyro.plate("data", len(data), subsample_size, subsample):
            loc = pyro.param("loc",
                             lambda: torch.zeros(len(data)),
                             event_dim=0)
            z_dist = Normal(loc, scale)
            if has_rsample is not None:
                z_dist.has_rsample_(has_rsample)
            pyro.sample("z", z_dist)

    if scale != 1.0:
        model = poutine.scale(model, scale=scale)
        guide = poutine.scale(guide, scale=scale)

    num_particles = 50000
    if local_samples:
        guide = config_enumerate(guide, num_samples=num_particles)
        num_particles = 1

    optim = Adam({"lr": 0.1})
    elbo = Elbo(
        max_plate_nesting=1,  # set this to ensure rng agrees across runs
        num_particles=num_particles,
        vectorize_particles=True,
        strict_enumeration_warning=False,
    )
    inference = SVI(model, guide, optim, loss=elbo)
    with xfail_if_not_implemented():
        if subsample_size == 1:
            inference.loss_and_grads(model,
                                     guide,
                                     subsample=torch.tensor([0],
                                                            dtype=torch.long))
            inference.loss_and_grads(model,
                                     guide,
                                     subsample=torch.tensor([1],
                                                            dtype=torch.long))
        else:
            inference.loss_and_grads(model,
                                     guide,
                                     subsample=torch.tensor([0, 1],
                                                            dtype=torch.long))
    params = dict(pyro.get_param_store().named_parameters())
    normalizer = 2 if subsample else 1
    actual_grads = {
        name: param.grad.detach().cpu().numpy() / normalizer
        for name, param in params.items()
    }

    expected_grads = {
        "loc": scale * np.array([0.5, -2.0]),
        "scale": scale * np.array([2.0]),
    }
    for name in sorted(params):
        logger.info("expected {} = {}".format(name, expected_grads[name]))
        logger.info("actual   {} = {}".format(name, actual_grads[name]))
    assert_equal(actual_grads, expected_grads, prec=precision)
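Where those expected gradients come from: with guide q(z_i) = Normal(loc_i, scale), the per-datum ELBO has the closed form (up to constants) -0.5*loc_i**2 - 0.5*(x_i - loc_i)**2 - scale**2 + log(scale), so at the initial values loc = 0, scale = 1 the loss gradients equal the arrays asserted above. A quick autograd check of that closed form (illustrative, separate from the test):

import torch

data = torch.tensor([-0.5, 2.0])
loc = torch.zeros(2, requires_grad=True)
s = torch.ones(1, requires_grad=True)
elbo = (-0.5 * loc ** 2 - 0.5 * (data - loc) ** 2 - s ** 2 + torch.log(s)).sum()
(-elbo).backward()
print(loc.grad)  # tensor([ 0.5000, -2.0000])
print(s.grad)    # tensor([2.])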
Ejemplo n.º 35
0
                     df["obs_perc_5"],
                     df["obs_perc_95"],
                     color='C1',
                     alpha=0.5)
    plt.legend()


if __name__ == '__main__':

    svi, model, guide = get_pyro_model(return_all=True)

    saved_param_files = glob.glob(MODEL_FILES)
    saved_param_files.sort(key=os.path.getmtime, reverse=True)
    print(*saved_param_files, sep='\n')
    idx = int(input("file? (0 for most recent exp) > "))
    pyro.get_param_store().load(saved_param_files[idx])

    saved_data_files = glob.glob(DATA_FILES)
    saved_data_files.sort(key=os.path.getmtime, reverse=True)
    print(*saved_data_files, sep='\n')
    idx = int(input("file? (0 for most recent data) > "))
    training_generator = iter(
        get_dataset(batch_size=1000, data_file=saved_data_files[idx]))
    x_data, y_data = next(training_generator)

    for name, value in pyro.get_param_store().items():
        print(name, pyro.param(name))

    trace_summary(svi, x_data, y_data)
    guide_summary(guide, x_data, y_data)
Ejemplo n.º 36
0
def run_GMM(data, K):

    @config_enumerate
    def model(data):
        # Global variables.
        weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
        scale = pyro.sample('scale', dist.LogNormal(4., 2.))
        with pyro.plate('components', K):
            locs = pyro.sample('locs', dist.Normal(0., 10.))

        with pyro.plate('data', len(data)):
            # Local variables.
            assignment = pyro.sample('assignment', dist.Categorical(weights))
            pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

    optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
    elbo = TraceEnum_ELBO(max_plate_nesting=1)

    def init_loc_fn(site):
        if site["name"] == "weights":
            # Initialize weights to uniform.
            return torch.ones(K) / K
        if site["name"] == "scale":
            return (data.var() / 2).sqrt()
        if site["name"] == "locs":
            return data[torch.multinomial(
                torch.ones(len(data)) / len(data), K)]
        raise ValueError(site["name"])

    def initialize(seed):
        global global_guide, svi
        pyro.set_rng_seed(seed)
        pyro.clear_param_store()
        global_guide = AutoDelta(
            poutine.block(
                model,
                expose=[
                    'weights',
                    'locs',
                    'scale']),
            init_loc_fn=init_loc_fn)
        svi = SVI(model, global_guide, optim, loss=elbo)
        return svi.loss(model, global_guide, data)

    # Choose the best among 100 random initializations.
    loss, seed = min((initialize(seed), seed) for seed in range(100))
    initialize(seed)
    print('seed = {}, initial_loss = {}'.format(seed, loss))

    # Register hooks to monitor gradient norms.
    gradient_norms = defaultdict(list)
    for name, value in pyro.get_param_store().named_parameters():
        value.register_hook(
            lambda g,
            name=name: gradient_norms[name].append(
                g.norm().item()))

    losses = []
    for i in range(200 if not smoke_test else 2):
        loss = svi.step(data)
        losses.append(loss)
        print('.' if i % 100 else '\n', end='')
    print()

    map_estimates = global_guide(data)
    weights = map_estimates['weights']
    locs = map_estimates['locs']
    scale = map_estimates['scale']
    print('weights = {}'.format(weights.data.numpy()))
    print('locs = {}'.format(locs.data.numpy()))
    print('scale = {}'.format(scale.data.numpy()))

    guide_trace = poutine.trace(global_guide).get_trace(
        data)  # record the globals
    trained_model = poutine.replay(
        model, trace=guide_trace)  # replay the globals

    def classifier(data, temperature=0):
        inferred_model = infer_discrete(
            trained_model,
            temperature=temperature,
            first_available_dim=-2)  # avoid conflict with data plate
        trace = poutine.trace(inferred_model).get_trace(data)
        return trace.nodes["assignment"]["value"]

    assignment = classifier(data)

    pyplot.figure(figsize=(8, 2), dpi=100).set_facecolor('white')
    pyplot.plot(data.numpy(), assignment.numpy(), 'bx')
    pyplot.title('MAP assignment')
    pyplot.xlabel('Latent posterior sample value')
    pyplot.ylabel('class assignment')

    return assignment
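A hypothetical driver for run_GMM on synthetic 1-D data with two well-separated modes; it defines the smoke_test global the function expects and assumes the imports used above are in scope:

import torch

smoke_test = False
data = torch.cat([2.0 + 0.5 * torch.randn(100),
                  -2.0 + 0.5 * torch.randn(100)])
assignment = run_GMM(data, K=2)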
Ejemplo n.º 37
0
def test_particle_gradient(Elbo, reparameterized, has_rsample):
    pyro.clear_param_store()
    data = torch.tensor([-0.5, 2.0])
    Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal

    def model():
        with pyro.plate("data", len(data)) as ind:
            x = data[ind]
            z = pyro.sample("z", Normal(0, 1))
            pyro.sample("x", Normal(z, 1), obs=x)

    def guide():
        scale = pyro.param("scale", lambda: torch.tensor([1.0]))
        with pyro.plate("data", len(data)):
            loc = pyro.param("loc",
                             lambda: torch.zeros(len(data)),
                             event_dim=0)
            z_dist = Normal(loc, scale)
            if has_rsample is not None:
                z_dist.has_rsample_(has_rsample)
            pyro.sample("z", z_dist)

    elbo = Elbo(
        max_plate_nesting=1,  # set this to ensure rng agrees across runs
        num_particles=1,
        strict_enumeration_warning=False,
    )

    # Elbo gradient estimator
    pyro.set_rng_seed(0)
    elbo.loss_and_grads(model, guide)
    params = dict(pyro.get_param_store().named_parameters())
    actual_grads = {
        name: param.grad.detach().cpu()
        for name, param in params.items()
    }

    # capture sample values and log_probs
    pyro.set_rng_seed(0)
    guide_tr = poutine.trace(guide).get_trace()
    model_tr = poutine.trace(poutine.replay(model, guide_tr)).get_trace()
    guide_tr.compute_log_prob()
    model_tr.compute_log_prob()
    x = data
    z = guide_tr.nodes["z"]["value"].data
    loc = pyro.param("loc").data
    scale = pyro.param("scale").data

    # expected grads
    if reparameterized and has_rsample is not False:
        # pathwise gradient estimator
        expected_grads = {
            "scale":
            -(-z * (z - loc) + (x - z) * (z - loc) + 1).sum(0, keepdim=True) /
            scale,
            "loc":
            -(-z + (x - z)),
        }
    else:
        # score function gradient estimator
        elbo = (model_tr.nodes["x"]["log_prob"].data +
                model_tr.nodes["z"]["log_prob"].data -
                guide_tr.nodes["z"]["log_prob"].data)
        dlogq_dloc = (z - loc) / scale**2
        dlogq_dscale = (z - loc)**2 / scale**3 - 1 / scale
        if Elbo is TraceEnum_ELBO:
            expected_grads = {
                "scale":
                -(dlogq_dscale * elbo - dlogq_dscale).sum(0, keepdim=True),
                "loc": -(dlogq_dloc * elbo - dlogq_dloc),
            }
        elif Elbo is Trace_ELBO:
            # expected value of dlogq_dscale and dlogq_dloc is zero
            expected_grads = {
                "scale": -(dlogq_dscale * elbo).sum(0, keepdim=True),
                "loc": -(dlogq_dloc * elbo),
            }

    for name in sorted(params):
        logger.info("expected {} = {}".format(name, expected_grads[name]))
        logger.info("actual   {} = {}".format(name, actual_grads[name]))

    assert_equal(actual_grads, expected_grads, prec=1e-4)
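The analytic dlogq_dloc and dlogq_dscale above follow from log q(z) = -(z - loc)**2 / (2 * scale**2) - log(scale) + const for a Normal guide. A quick torch.autograd cross-check (illustrative values, separate from the test):

import torch

loc = torch.tensor([0.3, -0.7], requires_grad=True)
scale = torch.tensor([1.5], requires_grad=True)
z = torch.tensor([0.1, 2.2])
torch.distributions.Normal(loc, scale).log_prob(z).sum().backward()
dlogq_dloc = (z - loc.detach()) / scale.detach() ** 2
dlogq_dscale = ((z - loc.detach()) ** 2 / scale.detach() ** 3
                - 1 / scale.detach()).sum(0, keepdim=True)
print(torch.allclose(loc.grad, dlogq_dloc))      # True
print(torch.allclose(scale.grad, dlogq_dscale))  # True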
Ejemplo n.º 38
0
    percentile = pyro.sample("var3", dist.Uniform(0, 1))
    if (percentile > 0.95):
        GPA = 4

    else:
        GPA = pyro.sample("var4", dist.Normal(2.75, 0.5))

    if (GPA == 4):
        Interviews = dist.Binomial(Recruiters, 0.9).sample()
    else:
        # covers GPA != 4, including the rare Normal draws above 4 that
        # would otherwise leave Interviews unbound
        Interviews = dist.Binomial(Recruiters, 0.6).sample()

    for n in range(1, 2):
        with pyro.iarange("data"):
            pyro.sample("obs",
                        dist.Binomial(Interviews, 0.4),
                        obs=data['offers'][n])


guide = ag.AutoDiagonalNormal(model)
pyro.clear_param_store()
optim = Adam({'lr': 0.01})
svi = SVI(model, guide, optim, loss=Trace_ELBO())
for i in range(1000):
    loss = svi.step(data)
    if ((i % 100) == 0):
        print(loss)
for name in pyro.get_param_store().get_all_param_names():
    print(name, pyro.param(name).data.numpy())
Ejemplo n.º 39
0
    def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
        # get info regarding rao-blackwellization of vectorized map_data
        guide_vec_md_info = guide_trace.graph["vectorized_map_data_info"]
        model_vec_md_info = model_trace.graph["vectorized_map_data_info"]
        guide_vec_md_condition = guide_vec_md_info[
            'rao-blackwellization-condition']
        model_vec_md_condition = model_vec_md_info[
            'rao-blackwellization-condition']
        do_vec_rb = guide_vec_md_condition and model_vec_md_condition
        if not do_vec_rb:
            warnings.warn(
                "Unable to do fully-vectorized Rao-Blackwellization in TraceGraph_ELBO. "
                "Falling back to higher-variance gradient estimator. "
                "Try to avoid these issues in your model and guide:\n{}".
                format("\n".join(guide_vec_md_info["warnings"]
                                 | model_vec_md_info["warnings"])))
        guide_vec_md_nodes = guide_vec_md_info['nodes'] if do_vec_rb else set()
        model_vec_md_nodes = model_vec_md_info['nodes'] if do_vec_rb else set()

        # have the trace compute all the individual (batch) log pdf terms
        # so that they are available below
        guide_trace.compute_batch_log_pdf(
            site_filter=lambda name, site: name in guide_vec_md_nodes)
        guide_trace.log_pdf()
        model_trace.compute_batch_log_pdf(
            site_filter=lambda name, site: name in model_vec_md_nodes)
        model_trace.log_pdf()

        # prepare a list of all the cost nodes, each of which is +- log_pdf
        cost_nodes = []
        non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
        for name, model_site in model_trace.nodes.items():
            if model_site["type"] == "sample":
                if model_site["is_observed"]:
                    cost_nodes.append(CostNode(model_site["log_pdf"], True))
                else:
                    # cost node from model sample
                    cost_nodes.append(CostNode(model_site["log_pdf"], True))
                    # cost node from guide sample
                    guide_site = guide_trace.nodes[name]
                    zero_expectation = name in non_reparam_nodes
                    cost_nodes.append(
                        CostNode(-guide_site["log_pdf"], not zero_expectation))

        # compute the elbo; if all stochastic nodes are reparameterizable, we're done
        # this bit is never differentiated: it's here for getting an estimate of the elbo itself
        elbo = torch_data_sum(sum(c.cost for c in cost_nodes))

        # compute the surrogate elbo, removing terms whose gradient is zero
        # this is the bit that's actually differentiated
        # XXX should the user be able to control if these terms are included?
        surrogate_elbo = sum(c.cost for c in cost_nodes
                             if c.nonzero_expectation)

        # the following computations are only necessary if we have non-reparameterizable nodes
        baseline_loss = 0.0
        if non_reparam_nodes:

            # recursively compute downstream cost nodes for all sample sites in model and guide
            # (even though ultimately just need for non-reparameterizable sample sites)
            # 1. downstream costs used for rao-blackwellization
            # 2. model observe sites (as well as terms that arise from the model and guide having different
            # dependency structures) are taken care of via 'children_in_model' below
            topo_sort_guide_nodes = list(
                reversed(list(networkx.topological_sort(guide_trace))))
            topo_sort_guide_nodes = [
                x for x in topo_sort_guide_nodes
                if guide_trace.nodes[x]["type"] == "sample"
            ]
            downstream_guide_cost_nodes = {}
            downstream_costs = {}

            for node in topo_sort_guide_nodes:
                node_log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
                downstream_costs[node] = model_trace.nodes[node][node_log_pdf_key] - \
                    guide_trace.nodes[node][node_log_pdf_key]
                nodes_included_in_sum = set([node])
                downstream_guide_cost_nodes[node] = set([node])
                for child in guide_trace.successors(node):
                    child_cost_nodes = downstream_guide_cost_nodes[child]
                    downstream_guide_cost_nodes[node].update(child_cost_nodes)
                    if nodes_included_in_sum.isdisjoint(
                            child_cost_nodes):  # avoid duplicates
                        if node_log_pdf_key == 'log_pdf':
                            downstream_costs[node] += downstream_costs[
                                child].sum()
                        else:
                            downstream_costs[node] += downstream_costs[child]
                        nodes_included_in_sum.update(child_cost_nodes)
                missing_downstream_costs = downstream_guide_cost_nodes[
                    node] - nodes_included_in_sum
                # include terms we missed because we had to avoid duplicates
                for missing_node in missing_downstream_costs:
                    mn_log_pdf_key = 'batch_log_pdf' if missing_node in guide_vec_md_nodes else 'log_pdf'
                    if node_log_pdf_key == 'log_pdf':
                        downstream_costs[node] += (
                            model_trace.nodes[missing_node][mn_log_pdf_key] -
                            guide_trace.nodes[missing_node][mn_log_pdf_key]
                        ).sum()
                    else:
                        downstream_costs[node] += model_trace.nodes[missing_node][mn_log_pdf_key] - \
                                                  guide_trace.nodes[missing_node][mn_log_pdf_key]

            # finish assembling complete downstream costs
            # (the above computation may be missing terms from model)
            # XXX can we cache some of the sums over children_in_model to make things more efficient?
            for site in non_reparam_nodes:
                children_in_model = set()
                for node in downstream_guide_cost_nodes[site]:
                    children_in_model.update(model_trace.successors(node))
                # remove terms accounted for above
                children_in_model.difference_update(
                    downstream_guide_cost_nodes[site])
                for child in children_in_model:
                    child_log_pdf_key = 'batch_log_pdf' if child in model_vec_md_nodes else 'log_pdf'
                    site_log_pdf_key = 'batch_log_pdf' if site in guide_vec_md_nodes else 'log_pdf'
                    assert (model_trace.nodes[child]["type"] == "sample")
                    if site_log_pdf_key == 'log_pdf':
                        downstream_costs[site] += model_trace.nodes[child][
                            child_log_pdf_key].sum()
                    else:
                        downstream_costs[site] += model_trace.nodes[child][
                            child_log_pdf_key]

            # construct all the reinforce-like terms.
            # we include only downstream costs to reduce variance
            # optionally include baselines to further reduce variance
            # XXX should the average baseline be in the param store as below?
            elbo_reinforce_terms = 0.0
            for node in non_reparam_nodes:
                guide_site = guide_trace.nodes[node]
                log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
                downstream_cost = downstream_costs[node]
                baseline = 0.0
                (nn_baseline, nn_baseline_input, use_decaying_avg_baseline,
                 baseline_beta,
                 baseline_value) = _get_baseline_options(guide_site)
                use_nn_baseline = nn_baseline is not None
                use_baseline_value = baseline_value is not None
                assert(not (use_nn_baseline and use_baseline_value)), \
                    "cannot use baseline_value and nn_baseline simultaneously"
                if use_decaying_avg_baseline:
                    avg_downstream_cost_old = pyro.param(
                        "__baseline_avg_downstream_cost_" + node,
                        ng_zeros(1),
                        tags="__tracegraph_elbo_internal_tag")
                    avg_downstream_cost_new = (1 - baseline_beta) * downstream_cost + \
                        baseline_beta * avg_downstream_cost_old
                    avg_downstream_cost_old.data = avg_downstream_cost_new.data  # XXX copy_() ?
                    baseline += avg_downstream_cost_old
                if use_nn_baseline:
                    # block nn_baseline_input gradients except in baseline loss
                    baseline += nn_baseline(detach_iterable(nn_baseline_input))
                elif use_baseline_value:
                    # it's on the user to make sure baseline_value tape only points to baseline params
                    baseline += baseline_value
                if use_nn_baseline or use_baseline_value:
                    # accumulate baseline loss
                    baseline_loss += torch.pow(
                        downstream_cost.detach() - baseline, 2.0).sum()

                guide_log_pdf = guide_site[log_pdf_key] / guide_site[
                    "scale"]  # not scaled by subsampling
                if use_nn_baseline or use_decaying_avg_baseline or use_baseline_value:
                    if downstream_cost.size() != baseline.size():
                        raise ValueError(
                            "Expected baseline at site {} to have shape {}, "
                            "but got {} instead".format(
                                node, downstream_cost.size(), baseline.size()))
                    downstream_cost = downstream_cost - baseline
                elbo_reinforce_terms += (guide_log_pdf *
                                         downstream_cost.detach()).sum()

            surrogate_elbo += elbo_reinforce_terms

        # collect parameters to train from model and guide
        trainable_params = set(site["value"]
                               for trace in (model_trace, guide_trace)
                               for site in trace.nodes.values()
                               if site["type"] == "param")

        if trainable_params:
            surrogate_loss = -surrogate_elbo
            torch_backward(weight * (surrogate_loss + baseline_loss))
            pyro.get_param_store().mark_params_active(trainable_params)

        loss = -elbo
        return weight * loss
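
# A minimal, self-contained sketch (illustrative, not part of the example
# above) of the same REINFORCE-with-baseline idea: a decaying-average
# baseline is subtracted from the downstream cost before it multiplies the
# guide's score function. All names and constants here are made up.
import torch

p = torch.tensor(0.5, requires_grad=True)  # "guide" parameter
avg_cost, beta = 0.0, 0.90                 # decaying-average baseline state
for step in range(100):
    z = torch.bernoulli(p.detach())        # non-reparameterized sample
    cost = (z - 0.25) ** 2                 # downstream cost; lower when z == 0
    avg_cost = beta * avg_cost + (1 - beta) * cost.item()
    log_q = z * torch.log(p) + (1 - z) * torch.log(1 - p)
    surrogate = log_q * (cost - avg_cost)  # baseline only reduces variance
    surrogate.backward()                   # score-function gradient of E[cost]
    with torch.no_grad():
        p -= 0.05 * p.grad                 # SGD step drives p toward 0
        p.grad = None
        p.clamp_(1e-3, 1 - 1e-3)           # keep p a valid probability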
Ejemplo n.º 40
0
        print("Saving")
        save_path = "../raw-results/"
        #save_path = "/afs/cs.stanford.edu/u/mhahn/scr/deps/"
        with open(
                save_path + "/manual_output_ground_coarse/" + args.language +
                "_" + __file__ + "_model_" + str(myID) + ".tsv",
                "w") as outFile:
            print("\t".join(
                list(
                    map(str, [
                        "Counter", "Document", "DH_Mean_NoPunct",
                        "DH_Sigma_NoPunct", "Distance_Mean_NoPunct",
                        "Distance_Sigma_NoPunct", "Dependency"
                    ]))),
                  file=outFile)
            dh_numpy = pyro.get_param_store().get_param("mu_DH").data.numpy()
            dh_sigma_numpy = pyro.get_param_store().get_param(
                "sigma_DH").data.numpy()
            dist_numpy = pyro.get_param_store().get_param(
                "mu_Dist").data.numpy()
            dist_sigma_numpy = pyro.get_param_store().get_param(
                "sigma_Dist").data.numpy()

            for i in range(len(itos_deps)):
                key = itos_deps[i]
                dependency = key
                for doc in range(len(itos_docs)):
                    print("\t".join(
                        list(
                            map(str, [
                                counter, itos_docs[doc], dh_numpy[doc, i],
Ejemplo n.º 41
0
def main(args):
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    logging.info('Loading data')
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info('-' * 40)
    model = models[args.model]
    logging.info('Training {} on {} sequences'.format(
        model.__name__, len(data['train']['sequences'])))
    sequences = data['train']['sequences']
    lengths = data['train']['sequence_lengths']

    # find all the notes that are present at least once in the training set
    present_notes = ((sequences == 1).sum(0).sum(0) > 0)
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        poutine.block(model,
                      expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = poutine.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = poutine.trace(
            poutine.replay(poutine.enum(model, first_available_dim),
                           guide_trace)).get_trace(sequences,
                                                   lengths,
                                                   args=args,
                                                   batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    optim = Adam({'lr': args.learning_rate})
    if args.tmc:
        if args.jit:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        elbo = TraceTMC_ELBO(max_plate_nesting=1 if model is model_0 else 2)
        tmc_model = poutine.infer_config(
            model,
            lambda msg: ({"num_samples": args.tmc_num_samples, "expand": False}
                         if msg["infer"].get("enumerate", None) == "parallel"
                         else {}))
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
        elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2,
                    strict_enumeration_warning=(model is not model_7),
                    jit_options={"time_compilation": args.time_compilation})
        svi = SVI(model, guide, optim, elbo)

    # We'll train on small minibatches.
    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        loss = svi.step(sequences,
                        lengths,
                        args=args,
                        batch_size=args.batch_size)
        logging.info('{: >5d}\t{}'.format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug('time to compile: {} s.'.format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model,
                           guide,
                           sequences,
                           lengths,
                           args,
                           include_prior=False)
    logging.info('training loss = {}'.format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info('-' * 40)
    logging.info('Evaluating on {} test sequences'.format(
        len(data['test']['sequences'])))
    sequences = data['test']['sequences'][..., present_notes]
    lengths = data['test']['sequence_lengths']
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit easier and for
    # numerical stability) this test loss may not be directly comparable to numbers
    # reported on this dataset elsewhere.
    test_loss = elbo.loss(model,
                          guide,
                          sequences,
                          lengths,
                          args=args,
                          include_prior=False)
    logging.info('test loss = {}'.format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info('{} capacity = {} parameters'.format(model.__name__,
                                                      capacity))
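
# A minimal sketch of the same training recipe on a simpler model: MAP
# estimation of the "probs_*" sites with AutoDelta while TraceEnum_ELBO
# enumerates out the discrete assignment z. The two-component mixture, the
# synthetic data, and the learning rate below are illustrative.
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

@config_enumerate
def mix_model(data):
    weights = pyro.sample("probs_weights", dist.Dirichlet(torch.ones(2)))
    locs = pyro.sample("probs_locs",
                       dist.Normal(torch.zeros(2), 10.).to_event(1))
    with pyro.plate("data", len(data)):
        z = pyro.sample("z", dist.Categorical(weights))  # enumerated out
        pyro.sample("obs", dist.Normal(locs[z], 1.), obs=data)

mix_guide = AutoDelta(
    poutine.block(mix_model,
                  expose_fn=lambda msg: msg["name"].startswith("probs_")))
mix_svi = SVI(mix_model, mix_guide, Adam({"lr": 0.05}),
              TraceEnum_ELBO(max_plate_nesting=1))
mix_data = torch.cat([torch.randn(50) - 2., torch.randn(50) + 2.])
for _ in range(200):
    mix_svi.step(mix_data)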
Ejemplo n.º 42
0
data, labels = create_simple_classification_dataset(num_schedules)
schedule_starts = np.linspace(0, 20 * (num_schedules-1), num=num_schedules)
not_first_time = False
distributions = [np.array([.5, .1], dtype=float) for _ in range(num_schedules)]  # each one is mean, sigma

print('Inference')
for epoch in range(num_epochs):
    # for j, (imgs, lbls) in enumerate(train_loader, 0):
    #     loss = inference.step(imgs.to(device), lbls.to(device))
    for _ in range(num_schedules):
        x_data = []
        y_data = []
        chosen_schedule_start = int(np.random.choice(schedule_starts))
        schedule_num = int(chosen_schedule_start / 20)
        if not_first_time:
            # Directly overwrite the stored variational parameters with this
            # schedule's running estimates (see the supported alternative below).
            pyro.get_param_store().get_state()['params']['emm'] = Variable(
                torch.Tensor([distributions[schedule_num][0]]), requires_grad=True)
            pyro.get_param_store().get_state()['params']['ems'] = Variable(
                torch.Tensor([distributions[schedule_num][1]]), requires_grad=True)
            # print(pyro.get_param_store().get_state()['params']['emm'])
        else:
            not_first_time = True
        for each_t in range(chosen_schedule_start, chosen_schedule_start + 20):
            x = data[each_t][2:]
            x_data.append(x)
            # noinspection PyArgumentList
            x = torch.Tensor([x]).reshape((2))

            label = labels[each_t]
            y_data.append(label)
            # noinspection PyArgumentList
            label = torch.Tensor([label]).reshape(1)
            label = Variable(label).long()
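
# Note: mutating get_param_store().get_state()['params'] as above bypasses
# Pyro's constraint handling. A sketch of the supported alternative in recent
# Pyro releases (reusing this example's names): the param store supports item
# assignment, which re-applies each parameter's constraint transform.
param_store = pyro.get_param_store()
param_store["emm"] = torch.Tensor([distributions[schedule_num][0]])
param_store["ems"] = torch.Tensor([distributions[schedule_num][1]])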
Ejemplo n.º 43
0
    def load_model(self, filename):
        pyro.get_param_store().load(filename)
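
# A minimal sketch of the matching save/load round trip for the param store;
# "saved_params.pt" is an illustrative filename.
import pyro

pyro.get_param_store().save("saved_params.pt")
pyro.clear_param_store()
pyro.get_param_store().load("saved_params.pt")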
Ejemplo n.º 44
0
adam_params = {"lr": 0.001, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

n_steps = 1000
# do gradient steps
for step in range(n_steps):
    svi.step(x, y)

for name in pyro.get_param_store():
    print(name + ':{}'.format(pyro.param(name)))

y_pred = Predictive(model=model,
                    guide=guide,
                    num_samples=1000,
                    return_sites=["y"])

x_ = torch.tensor(np.linspace(-2, 2, 100), dtype=torch.float)
y_ = y_pred(x_, None)  # returns a dict: site name -> stacked sample tensor

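# Predictive returns a dict mapping each requested site name to a tensor of
# stacked samples; a short sketch of summarizing the draws above (the 90%
# band is illustrative):
y_samples = y_["y"]                           # one draw per posterior sample
y_mean = y_samples.mean(0)                    # posterior predictive mean
y_lo = torch.quantile(y_samples, 0.05, dim=0)
y_hi = torch.quantile(y_samples, 0.95, dim=0)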
Ejemplo n.º 45
0
    def forward(self, inputs, n_samples=10, avg_posterior=False, seeds=None):

        if seeds:
            if len(seeds) != n_samples:
                raise ValueError(
                    "Number of seeds should match number of samples.")

        if self.inference == "svi":

            if avg_posterior:

                guide_trace = poutine.trace(self.guide).get_trace(inputs)

                avg_state_dict = {}
                for key in self.basenet.state_dict().keys():
                    avg_weights = guide_trace.nodes[str(key) + "_loc"]['value']
                    avg_state_dict.update({str(key): avg_weights})

                self.basenet.load_state_dict(avg_state_dict)
                preds = [self.basenet.model(inputs)]

            else:

                preds = []

                if seeds:
                    for seed in seeds:
                        pyro.set_rng_seed(seed)
                        guide_trace = poutine.trace(
                            self.guide).get_trace(inputs)
                        preds.append(guide_trace.nodes['_RETURN']['value'])

                else:

                    for _ in range(n_samples):
                        guide_trace = poutine.trace(
                            self.guide).get_trace(inputs)
                        preds.append(guide_trace.nodes['_RETURN']['value'])

                if DEBUG:
                    print("\nlearned variational params:\n")
                    print(pyro.get_param_store().get_all_param_names())
                    print(
                        list(
                            poutine.trace(
                                self.guide).get_trace(inputs).nodes.keys()))
                    print("\n",
                          pyro.get_param_store()["model.0.weight_loc"][0][:5])
                    print(guide_trace.nodes['module$$$model.0.weight']
                          ["fn"].loc[0][:5])
                    print(
                        "posterior sample: ",
                        guide_trace.nodes['module$$$model.0.weight']['value']
                        [5][0][0])

        elif self.inference == "hmc":

            preds = []
            posterior_predictive = list(self.posterior_predictive.values())

            if seeds is None:
                seeds = range(n_samples)

            for seed in seeds:
                net = posterior_predictive[seed]
                preds.append(net.forward(inputs))

        output_probs = torch.stack(preds).mean(0)
        return output_probs
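
# A follow-up sketch: the averaged probabilities returned above can be reduced
# to a per-input predictive-entropy uncertainty score. `bnn` and `inputs` are
# illustrative stand-ins for an instance of this class and a batch of data.
probs = bnn.forward(inputs, n_samples=50)
entropy = -(probs * probs.clamp_min(1e-12).log()).sum(-1)
uncertain = entropy > entropy.median()        # illustrative threshold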
Ejemplo n.º 46
0
def get_count_matrix_from_encodings(z: np.ndarray,
                                    d: np.ndarray,
                                    p: Union[np.ndarray, None],
                                    model: VariationalInferenceModel,
                                    dataset_obj,
                                    cells_only: bool = True) -> sp.csc.csc_matrix:
    """Make point estimate of the ambient-background-subtracted UMI count matrix.

    Sample counts by maximizing the model posterior based on learned latent
    variables.  The output matrix is in sparse form.

    Args:
        z: Latent variable embedding of gene expression in a low-dimensional
            space.
        d: Latent variable scale factor for the number of UMI counts coming
            from each real cell.
        p: Latent variable denoting probability that each barcode contains a
            real cell.
        model: Model with latent variables already inferred.
        dataset_obj: Input dataset.
        cells_only: If True, only returns the encodings of barcodes that are
            determined to contain cells.

    Returns:
        inferred_count_matrix: Matrix of the same dimensions as the input
            matrix, but where the UMI counts have had ambient-background
            subtracted.

    Note:
        This currently uses the MAP estimate of draws from a Poisson (or a
        negative binomial with zero overdispersion).

    """

    # If simple model was used, then p = None.  Here set it to 1.
    if p is None:
        p = np.ones_like(d)

    # Get the count matrix with genes trimmed.
    if cells_only:
        count_matrix = dataset_obj.get_count_matrix()
    else:
        count_matrix = dataset_obj.get_count_matrix_all_barcodes()

    logging.info("Getting ambient-background-subtracted UMI count matrix.")

    # Ensure there are no NaNs in p (there shouldn't be); work on a copy so
    # the caller's array is not mutated.
    p_no_nans = p.copy()
    p_no_nans[np.isnan(p)] = 0

    # Trim everything down to the barcodes we are interested in (just cells?).
    if cells_only:
        d = d[p_no_nans > 0.5]
        z = z[p_no_nans > 0.5, :]
        barcode_inds = dataset_obj.analyzed_barcode_inds[p_no_nans > 0.5]
    else:
        # Set cell size factors equal to zero where cell probability < 0.5.
        d[p_no_nans < 0.5] = 0.
        z[p_no_nans < 0.5, :] = 0.
        barcode_inds = np.arange(0, count_matrix.shape[0])  # All barcodes

    # Get mean of the inferred posterior for the overdispersion, phi.
    phi = pyro.get_param_store().get_param("phi_loc").detach().cpu().numpy().item()

    # Get the gene expression vectors by sending latent z through the decoder.
    # Send dataset through the learned encoder in chunks.
    barcodes = []
    genes = []
    counts = []
    s = 200
    for i in np.arange(0, barcode_inds.size, s):

        # TODO: for 117000 cells, this routine overflows (~15GB) memory

        last_ind_this_chunk = min(count_matrix.shape[0], i+s)

        # Decode gene expression for a chunk of barcodes.
        decoded = model.decoder(torch.Tensor(
            z[i:last_ind_this_chunk]).to(device=model.device))
        chi = decoded.detach().cpu().numpy()

        # Estimate counts for the chunk of barcodes.
        chunk_dense_counts = estimate_counts(chi,
                                             d[i:last_ind_this_chunk],
                                             phi)

        # Turn the floating point count estimates into integers.
        decimal_values, _ = np.modf(chunk_dense_counts)  # Stuff after decimal.
        roundoff_counts = np.random.binomial(1, p=decimal_values)  # Bernoulli.
        chunk_dense_counts = np.floor(chunk_dense_counts).astype(dtype=int)
        chunk_dense_counts += roundoff_counts

        # Find all the nonzero counts in this dense matrix chunk.
        nonzero_barcode_inds_this_chunk, nonzero_genes_trimmed = \
            np.nonzero(chunk_dense_counts)
        nonzero_counts = \
            chunk_dense_counts[nonzero_barcode_inds_this_chunk,
                               nonzero_genes_trimmed].flatten(order='C')

        # Get the original gene index from gene index in the trimmed dataset.
        nonzero_genes = dataset_obj.analyzed_gene_inds[nonzero_genes_trimmed]

        # Get the actual barcode values.
        nonzero_barcode_inds = nonzero_barcode_inds_this_chunk + i
        nonzero_barcodes = barcode_inds[nonzero_barcode_inds]

        # Append these to their lists.
        barcodes.extend(nonzero_barcodes.astype(dtype=np.uint32))
        genes.extend(nonzero_genes.astype(dtype=np.uint16))
        counts.extend(nonzero_counts.astype(dtype=np.uint32))

    # Convert the lists to numpy arrays.
    counts = np.array(counts, dtype=np.uint32)
    barcodes = np.array(barcodes, dtype=np.uint32)
    genes = np.array(genes, dtype=np.uint16)

    # Put the counts into a sparse csc_matrix.
    inferred_count_matrix = sp.csc_matrix((counts, (barcodes, genes)),
                                          shape=dataset_obj.data['matrix'].shape)

    return inferred_count_matrix
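
# The modf/binomial step above implements unbiased stochastic rounding: each
# value rounds up with probability equal to its fractional part, so expected
# counts are preserved. A standalone sketch for non-negative inputs:
import numpy as np

def stochastic_round(x: np.ndarray) -> np.ndarray:
    """Round non-negative floats to ints, unbiased in expectation."""
    frac, _ = np.modf(x)
    return np.floor(x).astype(int) + np.random.binomial(1, p=frac)

print(stochastic_round(np.array([0.2, 1.5, 2.9])))  # e.g. [0 2 3]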
Ejemplo n.º 47
0
def get_param(name):
    return pyro.get_param_store()[name]
Ejemplo n.º 48
0
    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Computes the ELBO as well as the surrogate ELBO that is used to form
        the gradient estimator, and performs backward on the latter.
        ``num_particles`` many samples are used to form the estimators.
        """
        elbo = 0.0
        surrogate_elbo = 0.0
        trainable_params = set()
        # grab a trace from the generator
        for weight, model_trace, guide_trace, log_r in self._get_traces(
                model, guide, *args, **kwargs):
            elbo_particle = weight * 0
            surrogate_elbo_particle = weight * 0
            # compute elbo and surrogate elbo
            log_pdf = "batch_log_pdf" if (
                self.enum_discrete and weight.size(0) > 1) else "log_pdf"
            for name in model_trace.nodes.keys():
                if model_trace.nodes[name]["type"] == "sample":
                    if model_trace.nodes[name]["is_observed"]:
                        elbo_particle += model_trace.nodes[name][log_pdf]
                        surrogate_elbo_particle += model_trace.nodes[name][
                            log_pdf]
                    else:
                        lp_lq = model_trace.nodes[name][
                            log_pdf] - guide_trace.nodes[name][log_pdf]
                        elbo_particle += lp_lq
                        if guide_trace.nodes[name]["fn"].reparameterized:
                            surrogate_elbo_particle += lp_lq
                        else:
                            # XXX should the user be able to control inclusion of the -logq term below?
                            surrogate_elbo_particle += model_trace.nodes[name][log_pdf] + \
                                log_r.detach() * guide_trace.nodes[name][log_pdf]

            # drop terms of weight zero to avoid nans
            if isinstance(weight, numbers.Number):
                if weight == 0.0:
                    elbo_particle = torch_zeros_like(elbo_particle)
                    surrogate_elbo_particle = torch_zeros_like(
                        surrogate_elbo_particle)
            else:
                weight_eq_zero = (weight == 0)
                elbo_particle[weight_eq_zero] = 0.0
                surrogate_elbo_particle[weight_eq_zero] = 0.0

            elbo += torch_data_sum(weight * elbo_particle)
            surrogate_elbo += torch_sum(weight * surrogate_elbo_particle)

            # grab model parameters to train
            for name in model_trace.nodes.keys():
                if model_trace.nodes[name]["type"] == "param":
                    trainable_params.add(model_trace.nodes[name]["value"])

            # grab guide parameters to train
            for name in guide_trace.nodes.keys():
                if guide_trace.nodes[name]["type"] == "param":
                    trainable_params.add(guide_trace.nodes[name]["value"])

        loss = -elbo
        surrogate_loss = -surrogate_elbo
        if trainable_params:
            torch_backward(surrogate_loss)
            pyro.get_param_store().mark_params_active(trainable_params)

        return loss
Ejemplo n.º 49
0
    def loss_and_grads(model, guide, *args, **kwargs):
        # compute the loss, backprop, and mark every parameter in the store
        # active so the optimizer will take a step on all of them
        _loss = self._loss(model, guide, *args, **kwargs)
        _loss.backward()
        pyro.get_param_store().mark_params_active(
            pyro.get_param_store().get_all_param_names())
        return _loss