def step(self, *args, **kwargs):
    """
    :returns: estimate of the loss
    :rtype: float

    Perform one optimization step: evaluate the loss and its gradients via
    ``self.loss_and_grads``, pass the active parameters to ``self.optim``,
    then zero their gradients and mark them inactive in the param store.
    All positional and keyword arguments are forwarded to ``loss_and_grads``
    after the model and guide.
    """
    # evaluate the loss; gradients are accumulated as a side effect
    loss_estimate = self.loss_and_grads(self.model, self.guide, *args, **kwargs)

    store = pyro.get_param_store()
    active = store.get_active_params()

    # the optimizer wrapper lazily creates torch.optim objects for unseen params
    self.optim(active)

    # reset gradients and deactivate the params for the next step
    pyro.util.zero_grads(active)
    store.mark_params_inactive(active)

    return loss_estimate
def test_hmc_conjugate_gaussian(fixture, num_samples, warmup_steps, hmc_params,
                                expected_means, expected_precs, mean_tol, std_tol):
    """Compare HMC posterior mean/std of every ``loc_i`` site with the expected values."""
    pyro.get_param_store().clear()
    kernel = HMC(fixture.model, **hmc_params)
    posterior = MCMC(kernel, num_samples, warmup_steps).run(fixture.data)

    for offset in range(fixture.chain_len):
        # sites are named loc_1 ... loc_<chain_len>; expectations are 0-indexed
        param_name = 'loc_' + str(offset + 1)
        site_marginal = EmpiricalMarginal(posterior, sites=param_name)
        actual_mean = site_marginal.mean
        actual_std = site_marginal.variance.sqrt()
        target_mean = torch.ones(fixture.dim) * expected_means[offset]
        target_std = 1 / torch.sqrt(torch.ones(fixture.dim) * expected_precs[offset])

        # Actual vs expected posterior means for the latents
        logger.info('Posterior mean (actual) - {}'.format(param_name))
        logger.info(actual_mean)
        logger.info('Posterior mean (expected) - {}'.format(param_name))
        logger.info(target_mean)
        mean_error = rmse(actual_mean, target_mean).item()
        assert_equal(mean_error, 0.0, prec=mean_tol)

        # Actual vs expected posterior standard deviations for the latents
        logger.info('Posterior std (actual) - {}'.format(param_name))
        logger.info(actual_std)
        logger.info('Posterior std (expected) - {}'.format(param_name))
        logger.info(target_std)
        std_error = rmse(actual_std, target_std).item()
        assert_equal(std_error, 0.0, prec=std_tol)
def test_module_nn(nn_module):
    """Registering a module must expose only names present in its ``state_dict``."""
    pyro.clear_param_store()
    instance = nn_module()
    # the store must be empty before the module is registered
    assert pyro.get_param_store()._params == {}

    pyro.module("module", instance)
    registered_names = pyro.get_param_store().get_all_param_names()
    for full_name in registered_names:
        short_name = pyro.params.user_param_name(full_name)
        assert short_name in instance.state_dict().keys()
def test_bern_elbo_gradient(enum_discrete, trace_graph):
    """Surrogate-loss gradients of a Bernoulli guide must match finite differences."""
    pyro.clear_param_store()
    num_particles = 2000

    def model():
        prior_p = Variable(torch.Tensor([0.25]))
        pyro.sample("z", dist.Bernoulli(prior_p))

    def guide():
        guide_p = pyro.param("p", Variable(torch.Tensor([0.5]), requires_grad=True))
        pyro.sample("z", dist.Bernoulli(guide_p))

    print("Computing gradients using surrogate loss")
    if trace_graph:
        Elbo = TraceGraph_ELBO
    else:
        Elbo = Trace_ELBO
    # a single particle is used when discrete sites are enumerated
    surrogate_particles = 1 if enum_discrete else num_particles
    elbo = Elbo(enum_discrete=enum_discrete, num_particles=surrogate_particles)
    with xfail_if_not_implemented():
        elbo.loss_and_grads(model, guide)
    params = sorted(pyro.get_param_store().get_all_param_names())
    assert params, "no params found"
    actual_grads = {}
    for name in params:
        actual_grads[name] = pyro.param(name).grad.clone()

    print("Computing gradients using finite difference")
    elbo = Trace_ELBO(num_particles=num_particles)

    def estimate_loss():
        return elbo.loss(model, guide)

    expected_grads = finite_difference(estimate_loss)

    for name in params:
        print("{} {}{}{}".format(name, "-" * 30, actual_grads[name].data,
                                 expected_grads[name].data))
    assert_equal(actual_grads, expected_grads, prec=0.1)
def __call__(self, params, *args, **kwargs):
    """
    :param params: a list of parameters
    :type params: an iterable of strings

    Run one optimizer step per entry of ``params``. The first time a param
    is encountered, a dedicated optimizer is built for it and, if a saved
    state for its name is waiting in ``_state_waiting_to_be_consumed``,
    that state is loaded into the new optimizer.
    """
    pending_states = self._state_waiting_to_be_consumed
    for param in params:
        if param not in self.optim_objs:
            # first sighting: build one optimizer dedicated to this param
            constructor_kwargs = self._get_optim_args(param)
            fresh_optimizer = self.pt_optim_constructor([param], **constructor_kwargs)
            self.optim_objs[param] = fresh_optimizer

            # consume any state that was stored under this param's name
            name = pyro.get_param_store().param_name(param)
            if name in pending_states:
                fresh_optimizer.load_state_dict(pending_states.pop(name))

        # take the actual optimization step
        self.optim_objs[param].step(*args, **kwargs)
def main(args):
    """Run SVI on a small nested data structure and print losses and fitted params."""
    pyro.set_rng_seed(0)
    pyro.enable_validation()

    svi = SVI(model, guide, Adam({"lr": 0.1}), loss=Trace_ELBO())

    # Data is an arbitrary json-like structure with tensors at leaves.
    one = torch.tensor(1.0)
    nouns = {
        "concrete": 4 * one,
        "abstract": 6 * one,
    }
    data = {
        "foo": one,
        "bar": [0 * one, 1 * one, 2 * one],
        "baz": {
            "noun": nouns,
            "verb": 2 * one,
        },
    }

    print('Step\tLoss')
    running_loss = 0.0
    for step in range(args.num_epochs):
        running_loss += svi.step(data)
        # report the accumulated loss every tenth step (never at step 0)
        if step > 0 and step % 10 == 0:
            print('{}\t{:0.5g}'.format(step, running_loss))
            running_loss = 0.0

    print('Parameters:')
    param_names = sorted(pyro.get_param_store().get_all_param_names())
    for name in param_names:
        value = pyro.param(name).detach().cpu().numpy()
        print('{} = {}'.format(name, value))
def _compute_elbo_non_reparam(guide_trace, non_reparam_nodes, downstream_costs):
    """
    Build the score-function (REINFORCE-like) surrogate ELBO term and the
    baseline loss for the given non-reparameterized guide sites.

    :param guide_trace: trace whose ``nodes[name]`` entries provide the
        ``'value'`` and ``'score_parts'`` of each site
    :param non_reparam_nodes: iterable of site names to process
    :param downstream_costs: mapping from site name to its downstream cost
    :returns: tuple ``(surrogate_elbo, baseline_loss)``
    :raises ValueError: if a baseline is in use at a site and its shape
        differs from the shape of that site's downstream cost
    """
    # construct all the reinforce-like terms.
    # we include only downstream costs to reduce variance
    # optionally include baselines to further reduce variance
    # XXX should the average baseline be in the param store as below?
    surrogate_elbo = 0.0
    baseline_loss = 0.0
    for node in non_reparam_nodes:
        guide_site = guide_trace.nodes[node]
        downstream_cost = downstream_costs[node]
        baseline = 0.0
        (nn_baseline, nn_baseline_input, use_decaying_avg_baseline, baseline_beta,
            baseline_value) = _get_baseline_options(guide_site)
        use_nn_baseline = nn_baseline is not None
        use_baseline_value = baseline_value is not None
        # the two explicit baseline kinds are mutually exclusive
        assert(not (use_nn_baseline and use_baseline_value)), \
            "cannot use baseline_value and nn_baseline simultaneously"
        if use_decaying_avg_baseline:
            dc_shape = downstream_cost.shape
            param_name = "__baseline_avg_downstream_cost_" + node
            # the running average is updated without recording gradients
            with torch.no_grad():
                avg_downstream_cost_old = pyro.param(param_name,
                                                     guide_site['value'].new_zeros(dc_shape))
                avg_downstream_cost_new = (1 - baseline_beta) * downstream_cost + \
                    baseline_beta * avg_downstream_cost_old
            pyro.get_param_store().replace_param(param_name, avg_downstream_cost_new,
                                                 avg_downstream_cost_old)
            # the baseline uses the average from before this update
            baseline += avg_downstream_cost_old
        if use_nn_baseline:
            # block nn_baseline_input gradients except in baseline loss
            baseline += nn_baseline(detach_iterable(nn_baseline_input))
        elif use_baseline_value:
            # it's on the user to make sure baseline_value tape only points to baseline params
            baseline += baseline_value
        if use_nn_baseline or use_baseline_value:
            # accumulate baseline loss
            baseline_loss += torch.pow(downstream_cost.detach() - baseline, 2.0).sum()

        score_function_term = guide_site["score_parts"].score_function
        if use_nn_baseline or use_decaying_avg_baseline or use_baseline_value:
            if downstream_cost.shape != baseline.shape:
                raise ValueError("Expected baseline at site {} to be {} instead got {}".format(
                    node, downstream_cost.shape, baseline.shape))
            downstream_cost = downstream_cost - baseline
        # the cost is detached so gradients flow only through the score function
        surrogate_elbo += (score_function_term * downstream_cost.detach()).sum()

    return surrogate_elbo, baseline_loss
def test_elbo_hmm_in_guide(enumerate1, num_steps):
    """TraceEnum_ELBO gradients for an HMM with enumeration in the guide match golden values."""
    pyro.clear_param_store()
    data = torch.ones(num_steps)
    init_probs = torch.tensor([0.5, 0.5])

    def model(data):
        transition_probs = pyro.param("transition_probs",
                                      torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                      constraint=constraints.simplex)
        emission_probs = pyro.param("emission_probs",
                                    torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                    constraint=constraints.simplex)
        state = None
        for t, observation in enumerate(data):
            if state is None:
                probs = init_probs
            else:
                probs = transition_probs[state]
            state = pyro.sample("x_{}".format(t), dist.Categorical(probs))
            pyro.sample("y_{}".format(t), dist.Categorical(emission_probs[state]),
                        obs=observation)

    @config_enumerate(default=enumerate1)
    def guide(data):
        transition_probs = pyro.param("transition_probs",
                                      torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                      constraint=constraints.simplex)
        state = None
        for t in range(len(data)):
            if state is None:
                probs = init_probs
            else:
                probs = transition_probs[state]
            state = pyro.sample("x_{}".format(t), dist.Categorical(probs))

    elbo = TraceEnum_ELBO(max_iarange_nesting=0)
    elbo.loss_and_grads(model, guide, data)

    # These golden values simply test agreement between parallel and sequential.
    expected_grads = {
        2: {
            "transition_probs": [[0.1029949, -0.1029949], [0.1029949, -0.1029949]],
            "emission_probs": [[0.75, -0.75], [0.25, -0.25]],
        },
        3: {
            "transition_probs": [[0.25748726, -0.25748726], [0.25748726, -0.25748726]],
            "emission_probs": [[1.125, -1.125], [0.375, -0.375]],
        },
        10: {
            "transition_probs": [[1.64832076, -1.64832076], [1.64832076, -1.64832076]],
            "emission_probs": [[3.75, -3.75], [1.25, -1.25]],
        },
        20: {
            "transition_probs": [[3.70781687, -3.70781687], [3.70781687, -3.70781687]],
            "emission_probs": [[7.5, -7.5], [2.5, -2.5]],
        },
    }
    golden = expected_grads[num_steps]

    for name, param in pyro.get_param_store().named_parameters():
        actual = param.grad
        expected = torch.tensor(golden[name])
        expected_line = '\nexpected {}.grad = {}'.format(name, expected.cpu().numpy())
        actual_line = '\n actual {}.grad = {}'.format(name, actual.detach().cpu().numpy())
        assert_equal(actual, expected, msg=expected_line + actual_line)
def get_state(self):
    """
    Collect the state of every per-parameter optimizer.

    :returns: dictionary mapping each parameter's name in the param store
        to the ``state_dict()`` of the optimizer that handles it
    """
    store = pyro.get_param_store()
    return {
        store.param_name(param): optimizer.state_dict()
        for param, optimizer in self.optim_objs.items()
    }
def test_save_and_load(self):
    """Round-trip the param store through ``save``/``load`` and verify modules and params."""
    linear = pyro.module("mymodule", self.linear_module)
    pyro.module("mymodule2", self.linear_module2)
    inputs = torch.randn(1, 3)
    myparam = pyro.param("myparam", torch.tensor(1.234 * torch.ones(1), requires_grad=True))

    # one Adam step so that the stored values differ from their initial values
    cost = torch.sum(torch.pow(linear(inputs), 2.0)) * torch.pow(myparam, 4.0)
    cost.backward()
    trainable = list(self.linear_module.parameters()) + [myparam]
    adam = torch.optim.Adam(trainable, lr=.01)
    myparam_before_step = copy(pyro.param("myparam").detach().cpu().numpy())

    adam.step()

    myparam_after_step = copy(pyro.param("myparam").detach().cpu().numpy())
    saved_params = copy(pyro.get_param_store()._params)
    saved_param_to_name = copy(pyro.get_param_store()._param_to_name)
    assert len(list(saved_params.keys())) == 5
    assert len(list(saved_param_to_name.values())) == 5

    # save, wipe, and reload the store
    pyro.get_param_store().save('paramstore.unittest.out')
    pyro.clear_param_store()
    assert len(list(pyro.get_param_store()._params)) == 0
    assert len(list(pyro.get_param_store()._param_to_name)) == 0
    pyro.get_param_store().load('paramstore.unittest.out')

    def as_numpy(tensor):
        return tensor.detach().cpu().numpy()

    def modules_are_equal():
        weight_gap = np.sum(np.fabs(as_numpy(self.linear_module3.weight) -
                                    as_numpy(self.linear_module.weight)))
        bias_gap = np.sum(np.fabs(as_numpy(self.linear_module3.bias) -
                                  as_numpy(self.linear_module.bias)))
        weights_equal = weight_gap == 0.0
        bias_equal = bias_gap == 0.0
        return (weights_equal and bias_equal)

    assert not modules_are_equal()
    # without update_module_params the module keeps its own tensors
    pyro.module("mymodule", self.linear_module3, update_module_params=False)
    assert id(self.linear_module3.weight) != id(pyro.param('mymodule$$$weight'))
    assert not modules_are_equal()
    # with update_module_params the module adopts the stored tensors
    pyro.module("mymodule", self.linear_module3, update_module_params=True)
    assert id(self.linear_module3.weight) == id(pyro.param('mymodule$$$weight'))
    assert modules_are_equal()

    reloaded = pyro.param("myparam")
    store = pyro.get_param_store()
    assert myparam_before_step != reloaded.detach().cpu().numpy()
    assert myparam_after_step == reloaded.detach().cpu().numpy()
    assert sorted(saved_params.keys()) == sorted(store._params.keys())
    assert sorted(saved_param_to_name.values()) == sorted(store._param_to_name.values())
    assert sorted(store._params.keys()) == sorted(store._param_to_name.values())
def _get_optim_args(self, param):
    """
    Return the optimizer constructor arguments to use for ``param``.

    When ``self.pt_optim_args`` is not callable it is returned as is.
    Otherwise it is called with ``(module name, param name, tags)``,
    e.g. ``('mymodule', 'bias', 'baseline')``, and must return a dict.
    """
    optim_args = self.pt_optim_args
    if not callable(optim_args):
        # a fixed set of arguments shared by all params
        return optim_args

    # resolve the naming information the user callable expects
    store = pyro.get_param_store()
    full_name = store.param_name(param)
    module_name = module_from_param_with_module_name(full_name)
    short_name = user_param_name(full_name)
    tags = store.get_param_tags(full_name)

    per_param_args = optim_args(module_name, short_name, tags)
    assert isinstance(per_param_args, dict), "per-param optim arg must return defaults dictionary"
    return per_param_args
def finite_difference(eval_loss, delta=0.1):
    """
    Compute a central finite-difference approximation of the gradient of
    ``eval_loss`` with respect to every entry of every parameter in the
    param store.

    Each entry is perturbed in place by ``+delta`` and ``-delta``,
    ``eval_loss()`` is evaluated at both points, and the entry is restored
    before moving on.

    :param eval_loss: zero-argument callable returning the current loss
    :param delta: half-width of the finite-difference stencil
    :returns: dict mapping each parameter name to a tensor of the
        parameter's size holding the estimated gradients
    """
    params = pyro.get_param_store().get_all_param_names()
    assert params, "no params found"
    grads = {name: Variable(torch.zeros(pyro.param(name).size())) for name in params}
    for name in sorted(params):
        value = pyro.param(name).data
        for index in itertools.product(*map(range, value.size())):
            # Read the entry as a plain Python float. Indexing may return a
            # 0-dim tensor that shares storage with ``value``; the in-place
            # writes below would then silently change ``center`` as well,
            # corrupting both the "minus" evaluation and the final restore.
            center = float(value[index])
            value[index] = center + delta
            pos = eval_loss()
            value[index] = center - delta
            neg = eval_loss()
            value[index] = center
            grads[name][index] = (pos - neg) / (2 * delta)
    return grads
def test_iarange(Elbo, reparameterized):
    """ELBO gradients with reordered/reused iarange contexts match analytic values."""
    pyro.clear_param_store()
    data = torch.tensor([-0.5, 2.0])
    num_particles = 20000
    precision = 0.06
    if reparameterized:
        Normal = dist.Normal
    else:
        Normal = fakes.NonreparameterizedNormal

    @poutine.broadcast
    def model():
        particles_iarange = pyro.iarange("particles", num_particles, dim=-2)
        data_iarange = pyro.iarange("data", len(data), dim=-1)

        pyro.sample("nuisance_a", Normal(0, 1))
        with particles_iarange, data_iarange:
            z = pyro.sample("z", Normal(0, 1))
        pyro.sample("nuisance_b", Normal(2, 3))
        # same two contexts, entered in the opposite order
        with data_iarange, particles_iarange:
            pyro.sample("x", Normal(z, 1), obs=data)
        pyro.sample("nuisance_c", Normal(4, 5))

    @poutine.broadcast
    def guide():
        loc = pyro.param("loc", torch.zeros(len(data)))
        scale = pyro.param("scale", torch.tensor([1.]))

        pyro.sample("nuisance_c", Normal(4, 5))
        with pyro.iarange("particles", num_particles, dim=-2):
            with pyro.iarange("data", len(data), dim=-1):
                pyro.sample("z", Normal(loc, scale))
        pyro.sample("nuisance_b", Normal(2, 3))
        pyro.sample("nuisance_a", Normal(0, 1))

    adam = Adam({"lr": 0.1})
    svi = SVI(model, guide, adam, loss=Elbo(strict_enumeration_warning=False))
    svi.loss_and_grads(model, guide)

    params = dict(pyro.get_param_store().named_parameters())
    # gradients are summed over particles, so divide to get the per-particle value
    actual_grads = {}
    for name, param in params.items():
        actual_grads[name] = param.grad.detach().cpu().numpy() / num_particles

    expected_grads = {'loc': np.array([0.5, -2.0]), 'scale': np.array([2.0])}
    for name in sorted(params):
        logger.info('expected {} = {}'.format(name, expected_grads[name]))
        logger.info('actual {} = {}'.format(name, actual_grads[name]))
    assert_equal(actual_grads, expected_grads, prec=precision)
def test_elbo_hmm_in_model(enumerate1, num_steps):
    """TraceEnum_ELBO gradients for an HMM model with a mean-field guide match analytic values."""
    pyro.clear_param_store()
    data = torch.ones(num_steps)
    init_probs = torch.tensor([0.5, 0.5])

    def model(data):
        transition_probs = pyro.param("transition_probs",
                                      torch.tensor([[0.9, 0.1], [0.1, 0.9]]),
                                      constraint=constraints.simplex)
        locs = pyro.param("obs_locs", torch.tensor([-1.0, 1.0]))
        scale = pyro.param("obs_scale", torch.tensor(1.0),
                           constraint=constraints.positive)

        state = None
        for t, observation in enumerate(data):
            if state is None:
                probs = init_probs
            else:
                probs = transition_probs[state]
            state = pyro.sample("x_{}".format(t), dist.Categorical(probs))
            pyro.sample("y_{}".format(t), dist.Normal(locs[state], scale), obs=observation)

    @config_enumerate(default=enumerate1)
    def guide(data):
        mean_field_probs = pyro.param("mean_field_probs", torch.ones(num_steps, 2) / 2,
                                      constraint=constraints.simplex)
        for t in range(num_steps):
            pyro.sample("x_{}".format(t), dist.Categorical(mean_field_probs[t]))

    elbo = TraceEnum_ELBO(max_iarange_nesting=0)
    elbo.loss_and_grads(model, guide, data)

    expected_unconstrained_grads = {
        "transition_probs": torch.tensor([[0.2, -0.2], [-0.2, 0.2]]) * (num_steps - 1),
        "obs_locs": torch.tensor([-num_steps, 0]),
        "obs_scale": torch.tensor(-num_steps),
        "mean_field_probs": torch.tensor([[0.5, -0.5]] * num_steps),
    }

    for name, param in pyro.get_param_store().named_parameters():
        actual = param.grad
        expected = expected_unconstrained_grads[name]
        expected_line = '\nexpected {}.grad = {}'.format(name, expected.cpu().numpy())
        actual_line = '\n actual {}.grad = {}'.format(name, actual.detach().cpu().numpy())
        assert_equal(actual, expected, msg=expected_line + actual_line)
def main(args):
    """Run SVI on a small 1-D data set with ``k = 2`` and print losses and fitted params."""
    pyro.set_rng_seed(0)
    pyro.enable_validation()

    svi = SVI(model, guide, Adam({"lr": 0.1}), loss=Trace_ELBO())
    data = torch.tensor([0.0, 1.0, 2.0, 20.0, 30.0, 40.0])
    k = 2

    print('Step\tLoss')
    running_loss = 0.0
    for step in range(args.num_epochs):
        # report the loss accumulated over the previous ten steps
        if step > 0 and step % 10 == 0:
            print('{}\t{:0.5g}'.format(step, running_loss))
            running_loss = 0.0
        running_loss += svi.step(data, k)

    print('Parameters:')
    param_names = sorted(pyro.get_param_store().get_all_param_names())
    for name in param_names:
        value = pyro.param(name).detach().cpu().numpy()
        print('{} = {}'.format(name, value))
def test_subsample_gradient(trace_graph, reparameterized):
    """Gradients computed with subsampling must agree with those computed on all data."""
    pyro.clear_param_store()
    data_size = 2
    subsample_size = 1
    num_particles = 1000
    precision = 0.333
    data = dist.normal(ng_zeros(data_size), ng_ones(data_size))

    def model(subsample_size):
        with pyro.iarange("data", len(data), subsample_size) as ind:
            x = data[ind]
            prior = dist.Normal(ng_zeros(len(x)), ng_ones(len(x)),
                                reparameterized=reparameterized)
            z = pyro.sample("z", prior)
            likelihood = dist.Normal(z, ng_ones(len(x)), reparameterized=reparameterized)
            pyro.observe("x", likelihood, x)

    def guide(subsample_size):
        mu = pyro.param("mu", lambda: Variable(torch.zeros(len(data)), requires_grad=True))
        sigma = pyro.param("sigma", lambda: Variable(torch.ones(1), requires_grad=True))
        with pyro.iarange("data", len(data), subsample_size) as ind:
            mu = mu[ind]
            sigma = sigma.expand(subsample_size)
            pyro.sample("z", dist.Normal(mu, sigma, reparameterized=reparameterized))

    adam = Adam({"lr": 0.1})
    svi = SVI(model, guide, adam, loss="ELBO",
              trace_graph=trace_graph, num_particles=num_particles)

    def snapshot_grads(named_params):
        return {name: param.grad.data.clone() for name, param in named_params.items()}

    # Compute gradients without subsampling.
    svi.loss_and_grads(model, guide, subsample_size=data_size)
    params = dict(pyro.get_param_store().named_parameters())
    expected_grads = snapshot_grads(params)
    zero_grads(params.values())

    # Compute gradients with subsampling.
    svi.loss_and_grads(model, guide, subsample_size=subsample_size)
    actual_grads = snapshot_grads(params)

    for name in sorted(params):
        print('\nexpected {} = {}'.format(name, expected_grads[name].cpu().numpy()))
        print('actual {} = {}'.format(name, actual_grads[name].cpu().numpy()))
    assert_equal(actual_grads, expected_grads, prec=precision)
def test_subsample_gradient(Elbo, reparameterized, subsample):
    """Accumulated subsampled gradients must match the analytic full-data gradients."""
    pyro.clear_param_store()
    data = torch.tensor([-0.5, 2.0])
    subsample_size = 1 if subsample else len(data)
    num_particles = 50000
    precision = 0.06
    if reparameterized:
        Normal = dist.Normal
    else:
        Normal = fakes.NonreparameterizedNormal

    def model(subsample):
        with pyro.iarange("particles", num_particles):
            with pyro.iarange("data", len(data), subsample_size, subsample) as ind:
                x = data[ind].unsqueeze(-1).expand(-1, num_particles)
                z = pyro.sample("z", Normal(0, 1).expand_by(x.shape))
                pyro.sample("x", Normal(z, 1), obs=x)

    def guide(subsample):
        loc = pyro.param("loc", lambda: torch.zeros(len(data), requires_grad=True))
        scale = pyro.param("scale", lambda: torch.tensor([1.0], requires_grad=True))
        with pyro.iarange("particles", num_particles):
            with pyro.iarange("data", len(data), subsample_size, subsample) as ind:
                loc_ind = loc[ind].unsqueeze(-1).expand(-1, num_particles)
                pyro.sample("z", Normal(loc_ind, scale))

    adam = Adam({"lr": 0.1})
    svi = SVI(model, guide, adam, loss=Elbo(strict_enumeration_warning=False))

    # either visit each data point in its own batch or use one full batch
    if subsample_size == 1:
        index_batches = [[0], [1]]
    else:
        index_batches = [[0, 1]]
    for batch in index_batches:
        svi.loss_and_grads(model, guide, subsample=torch.LongTensor(batch))

    params = dict(pyro.get_param_store().named_parameters())
    normalizer = 2 * num_particles / subsample_size
    actual_grads = {}
    for name, param in params.items():
        actual_grads[name] = param.grad.detach().cpu().numpy() / normalizer

    expected_grads = {'loc': np.array([0.5, -2.0]), 'scale': np.array([2.0])}
    for name in sorted(params):
        logger.info('expected {} = {}'.format(name, expected_grads[name]))
        logger.info('actual {} = {}'.format(name, actual_grads[name]))
    assert_equal(actual_grads, expected_grads, prec=precision)
def train(args, dataset):
    """
    Train a model and guide to fit a dataset.

    Every 20 steps the param store, the training metadata and a fresh
    ``Forecaster`` are written to the filenames given in ``args``.

    :param args: configuration namespace; this function reads
        ``batch_size``, ``learning_rate``, ``num_steps``, ``truncate``,
        ``device``, ``analytic_kl`` and the ``*_filename`` attributes
    :param dataset: mapping providing ``"counts"`` and ``"stations"``
    :returns: the most recently saved ``Forecaster``, or ``None`` if
        ``args.num_steps`` is 0
    """
    counts = dataset["counts"]
    num_stations = len(dataset["stations"])
    logging.info(
        "Training on {} stations over {} hours, {} batches/epoch".format(
            num_stations, len(counts),
            int(math.ceil(len(counts) / args.batch_size))))
    time_features = make_time_features(args, 0, len(counts))
    control_features = (counts.max(1)[0] + counts.max(2)[0]).clamp(max=1)
    logging.info(
        "On average {:0.1f}/{} stations are open at any one time".format(
            control_features.sum(-1).mean(), num_stations))
    features = torch.cat([time_features, control_features], -1)
    feature_dim = features.size(-1)
    logging.info("feature_dim = {}".format(feature_dim))
    metadata = {"args": args, "losses": [], "control": control_features}
    torch.save(metadata, args.training_filename)

    def optim_config(module_name, param_name):
        # NOTE(review): 0.01 ** (1 / num_steps) looks like a per-step decay
        # factor; confirm that "weight_decay" is the intended key here.
        config = {
            "lr": args.learning_rate,
            "betas": (0.8, 0.99),
            "weight_decay": 0.01**(1 / args.num_steps),
        }
        if param_name == "init_scale":
            config["lr"] *= 0.1  # init_dist sees much less data per minibatch
        return config

    training_counts = counts[:args.truncate] if args.truncate else counts
    data_size = len(training_counts)
    model = Model(args, features, training_counts).to(device=args.device)
    guide = Guide(args, features, training_counts).to(device=args.device)
    elbo = (TraceMeanField_ELBO if args.analytic_kl else Trace_ELBO)()
    optim = ClippedAdam(optim_config)
    svi = SVI(model, guide, optim, elbo)
    losses = []
    forecaster = None
    for step in range(args.num_steps):
        # draw a random contiguous window of at most batch_size time steps
        begin_time = torch.randint(max(1, data_size - args.batch_size), ()).item()
        end_time = min(data_size, begin_time + args.batch_size)
        feature_batch = features[begin_time:end_time].to(device=args.device)
        counts_batch = counts[begin_time:end_time].to(device=args.device)
        loss = svi.step(feature_batch, counts_batch) / counts_batch.numel()
        assert math.isfinite(loss), loss
        losses.append(loss)
        logging.debug("step {} loss = {:0.4g}".format(step, loss))

        if step % 20 == 0:
            # Save state every few steps.
            pyro.get_param_store().save(args.param_store_filename)
            metadata = {
                "args": args,
                "losses": losses,
                "control": control_features
            }
            torch.save(metadata, args.training_filename)
            forecaster = Forecaster(args, dataset, features, model, guide)
            torch.save(forecaster, args.forecaster_filename)

            # Only compute the diagnostics when they will actually be
            # emitted. ``logging.debug`` goes through the root logger, so
            # that is the logger to query; a freshly constructed
            # ``logging.Logger(None)`` has no parent and reports DEBUG as
            # enabled regardless of the configured level.
            if logging.getLogger().isEnabledFor(logging.DEBUG):
                init_scale = pyro.param("init_scale").data
                trans_scale = pyro.param("trans_scale").data
                trans_matrix = pyro.param("trans_matrix").data
                eigs = trans_matrix.eig()[0].norm(dim=-1).sort(
                    descending=True).values
                logging.debug("guide.diag_part = {}".format(
                    guide.diag_part.data.squeeze()))
                logging.debug(
                    "init scale min/mean/max: {:0.3g} {:0.3g} {:0.3g}".format(
                        init_scale.min(), init_scale.mean(), init_scale.max()))
                logging.debug(
                    "trans scale min/mean/max: {:0.3g} {:0.3g} {:0.3g}".format(
                        trans_scale.min(), trans_scale.mean(), trans_scale.max()))
                logging.debug("trans mat eig:\n{}".format(eigs))

    return forecaster
def transform(self, X: np.ndarray, num_samples: int = 1000, random_state: int = None,
              mean_estimate: bool = False) -> np.ndarray:
    """
    After model calibration, this function is used to get calibrated outputs of uncalibrated
    confidence estimates.

    Parameters
    ----------
    X : np.ndarray, shape=(n_samples, [n_classes]) or (n_samples, [n_box_features])
        NumPy array with confidence values for each prediction on classification with shapes
        1-D for binary classification, 2-D for multi class (softmax).
        On detection, this array must have 2 dimensions with number of additional box features in last dim.
    num_samples : int, optional, default: 1000
        Number of samples generated on MCMC sampling or Variational Inference.
    random_state : int, optional, default: None
        Fix the random seed for the random number generator used while sampling.
    mean_estimate : bool, optional, default: False
        If True, directly return the mean on probabilistic methods like MCMC or VI instead of the full
        distribution. This parameter has no effect on MLE.

    Returns
    -------
    np.ndarray, shape=(n_samples, [n_classes]) on MLE or on MCMC/VI if 'mean_estimate' is True
    or shape=(n_parameters, n_samples, [n_classes]) on VI, MCMC if 'mean_estimate' is False
        On MLE without uncertainty, return NumPy array with calibrated confidence estimates.
        1-D for binary classification, 2-D for multi class (softmax).
        On VI or MCMC, return NumPy array with leading dimension as the number of sampled parameters from the
        log regression parameter distribution obtained by VI or MCMC.

    Raises
    ------
    ValueError
        If the method is 'variational' or 'mcmc' but neither an MCMC nor a variational model is set.
    """

    def process_model(weights: dict) -> torch.Tensor:
        """
        Fix model weights to the weight vector given as the parameter and return calibrated data.
        Reads ``data`` from the enclosing scope, so it must only be called after ``data`` is set.
        """

        # model will return pytorch tensor
        model = pyro.condition(self.model, data=weights)
        logit = model(data)

        # distinguish between detection, binary and multiclass classification
        if self.detection or self._is_binary_classification():
            calibrated = torch.sigmoid(logit)
        else:
            calibrated = torch.softmax(logit, dim=1)

        return calibrated

    # prepare input data
    X = super().transform(X)
    self.to(self._device)

    # convert input data and weights to torch (and possibly to CUDA)
    data = self.prepare(X).float().to(self._device)

    # if weights is 2-D matrix, we are in sampling mode
    # treat each row as a separate weights vector
    if self.method in ['variational', 'mcmc']:

        if mean_estimate:
            weights = {}

            # on MCMC sampling, use mean over all weights as mean weight estimate
            # TODO: we need to find another way since the parameters are conditionally dependent
            # TODO: revise!!! We often have log-normals instead of normal distributions,
            # thus the mean will be a different
            if self.mcmc_model is not None:
                # NOTE(review): np.mean is called without an axis here; confirm that its
                # result is something torch.from_numpy accepts for the stored samples
                for name, site in self._sites.items():
                    weights[name] = torch.from_numpy(
                        np.mean(self.mcmc_model[name])).to(self._device)

            # on variational inference, use mean of the variational distribution for inference
            elif self.vi_model is not None:
                for name, site in self._sites.items():
                    weights[name] = torch.from_numpy(
                        self.vi_model['params']['%s_mean' % name]).to(
                            self._device)

            else:
                raise ValueError(
                    "Internal error: neither MCMC nor variational model given."
                )

            # evaluate the model once with the mean weights
            calibrated = process_model(weights).cpu().numpy()
            calibrated = self.squeeze_generic(calibrated, axes_to_keep=0)
        else:

            parameter = []
            if self.mcmc_model is not None:

                # draw num_samples random indices into the stored MCMC samples
                with manual_seed(seed=random_state):
                    idxs = torch.randint(0,
                                         self.mcmc_steps,
                                         size=(num_samples, ),
                                         device=self._device)
                    samples = {
                        k: v.index_select(0, idxs)
                        for k, v in self.mcmc_model.items()
                    }

            elif self.vi_model is not None:

                # restore state of global parameter store of pyro and use this parameter store for the predictive
                pyro.get_param_store().set_state(self.vi_model)
                predictive = Predictive(self.model,
                                        guide=self.guide,
                                        num_samples=num_samples,
                                        return_sites=tuple(
                                            self._sites.keys()))

                with manual_seed(seed=random_state):
                    samples = predictive(data)

            else:
                raise ValueError(
                    "Internal error: neither MCMC nor variational model given."
                )

            # remove unnecessary dims that possibly occur on MCMC or VI
            samples = {
                k: torch.reshape(v, (num_samples, -1))
                for k, v in samples.items()
            }

            # iterate over all parameter sets
            for i in range(num_samples):
                param_dict = {}

                # iterate over all sites and store into parameter dict
                for site in self._sites.keys():
                    param_dict[site] = samples[site][i].detach().to(
                        self._device)

                parameter.append(param_dict)

            calibrated = []

            # iterate over all parameter collections and compute calibration mapping
            for param_dict in parameter:
                cal = process_model(param_dict)
                calibrated.append(cal)

            # stack all calibrated estimates along axis 0 (one entry per sampled parameter set)
            calibrated = torch.stack(calibrated, dim=0).cpu().numpy()
            calibrated = self.squeeze_generic(calibrated,
                                              axes_to_keep=(0, 1))
    else:

        # extract all weight values of sites and store into single dict
        weights = {}
        for name, site in self._sites.items():
            weights[name] = torch.from_numpy(site['values']).to(
                self._device)

        # on MLE without uncertainty, only return the single model estimate
        calibrated = process_model(weights).cpu().numpy()
        calibrated = self.squeeze_generic(calibrated, axes_to_keep=0)

    # delete torch data tensor
    del data

    # if device is cuda, empty GPU cache to free memory
    if self._device.type == 'cuda':
        with torch.cuda.device(self._device):
            torch.cuda.empty_cache()

    return calibrated
def save_model(self):
    """Write the Pyro param store to ``data/saved_models/<id>_bnn_params.pr``.

    The parameters are taken from the Pyro param store rather than from a
    PyTorch module. The target directory is created if it is missing, and
    the file name is built from ``self.config.id`` using the ``02`` format
    spec.
    """
    target_dir = Path("data/saved_models/")
    target_dir.mkdir(exist_ok=True, parents=True)
    target_file = target_dir / f"{self.config.id:02}_bnn_params.pr"
    pyro.get_param_store().save(target_file)
# Trace one execution of the model on the training data and print site shapes.
trace = poutine.trace(pyromodel).get_trace(torch.tensor(x), torch.tensor(y))
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())

ys = []
# NOTE(review): amp and sig are not used anywhere in this fragment.
amp = 1.
sig = 1.0
#xs = np.linspace(0, 5, 500, dtype='float32')
xs = torch.tensor(xtest_pca.astype('float32'))
# Call the guide 50 times and evaluate each returned callable on the test inputs.
for i in range(50):
    sampled_model = guide(None, None)
    ys += [sampled_model(xs).cpu().detach().numpy().flatten()]
# One column per guide call.
ys = np.stack(ys).T

# Print the name and shape of every parameter in the param store.
for name, value in pyro.get_param_store().items():
    print(name, pyro.param(name).shape)

# The string literal below is disabled plotting code kept as a no-op expression.
"""
plt.figure()
plt.yscale('linear')
plt.title("Training Data")
plt.xlabel("hu (mu = <10> is Y for NN)")
plt.ylabel("Intensity (X for NN)")
plt.plot(hu,x[::10,:].T)
"""

# NOTE(review): nothing is plotted in this fragment before plt.legend() is
# called — confirm the plot calls live elsewhere in the file.
plt.figure()
plt.yscale('linear')
plt.title("Fit to mu")
plt.xlabel("1st PCA component")
plt.ylabel("mu")
plt.legend()
def boosting_bbvi():
    """
    Run two rounds of boosting black-box variational inference and plot the results.

    Each round fits a new guide component with SVI against the ``relbo`` loss,
    plots its loss curve, appends it to the mixture and re-weights the
    existing components. Afterwards the recorded gradient norms and the two
    fitted components are plotted. Reads the module-level ``data``,
    ``n_steps``, ``model``, ``guide``, ``approximation``, ``relbo`` and
    ``PRINT_INTERMEDIATE_LATENT_VALUES``.
    """
    n_iterations = 2
    initial_approximation = dummy_approximation
    components = [initial_approximation]
    weights = torch.tensor([1.])
    wrapped_approximation = partial(approximation, components=components, weights=weights)

    # index 0 holds a placeholder for the dummy initial approximation
    locs = [0]
    scales = [0]

    gradient_norms = defaultdict(list)
    for t in range(1, n_iterations + 1):
        # setup the inference algorithm
        wrapped_guide = partial(guide, index=t)
        # do gradient steps
        losses = []
        # Register hooks to monitor gradient norms.
        # The guide is run once first so that its params exist in the param store.
        wrapped_guide(data)
        print(pyro.get_param_store().named_parameters())

        adam_params = {"lr": 0.002, "betas": (0.90, 0.999)}
        optimizer = Adam(adam_params)
        for name, value in pyro.get_param_store().named_parameters():
            # only params with no recorded gradient norms get a new hook
            if not name in gradient_norms:
                # name is bound as a default argument to avoid late binding in the lambda
                value.register_hook(lambda g, name=name: gradient_norms[name].
                                    append(g.norm().item()))

        svi = SVI(model, wrapped_guide, optimizer, loss=relbo)
        for step in range(n_steps):
            loss = svi.step(data, approximation=wrapped_approximation)
            losses.append(loss)

            if PRINT_INTERMEDIATE_LATENT_VALUES:
                print('Loss: {}'.format(loss))
                variance = pyro.param("variance_{}".format(t)).item()
                mu = pyro.param("mu_{}".format(t)).item()
                print('mu = {}'.format(mu))
                print('variance = {}'.format(variance))

            if step % 100 == 0:
                print('.', end=' ')

        pyplot.plot(range(len(losses)), losses)
        pyplot.xlabel('Update Steps')
        pyplot.ylabel('-ELBO')
        pyplot.title('-ELBO against time for component {}'.format(t))
        pyplot.show()

        components.append(wrapped_guide)
        new_weight = 2 / (t + 1)

        # shrink the existing weights and append the new component's weight
        weights = weights * (1 - new_weight)
        weights = torch.cat((weights, torch.tensor([new_weight])))

        wrapped_approximation = partial(approximation, components=components, weights=weights)

        # NOTE(review): the param named "variance_<t>" is stored in ``scales`` and later
        # used as a scale divisor — confirm whether it is a variance or a std
        scale = pyro.param("variance_{}".format(t)).item()
        scales.append(scale)
        loc = pyro.param("mu_{}".format(t)).item()
        locs.append(loc)
        print('mu = {}'.format(loc))
        print('variance = {}'.format(scale))

    pyplot.figure(figsize=(10, 4), dpi=100).set_facecolor('white')
    for name, grad_norms in gradient_norms.items():
        pyplot.plot(grad_norms, label=name)
    pyplot.xlabel('iters')
    pyplot.ylabel('gradient norm')
    # pyplot.yscale('log')
    pyplot.legend(loc='best')
    pyplot.title('Gradient norms during SVI')
    pyplot.show()

    print(weights)
    print(locs)
    print(scales)

    # plot the two fitted components (indices 1 and 2) and their sum
    X = np.arange(-10, 10, 0.1)
    Y1 = weights[1].item() * scipy.stats.norm.pdf((X - locs[1]) / scales[1])
    Y2 = weights[2].item() * scipy.stats.norm.pdf((X - locs[2]) / scales[2])

    pyplot.figure(figsize=(10, 4), dpi=100).set_facecolor('white')
    pyplot.plot(X, Y1, 'r-')
    pyplot.plot(X, Y2, 'b-')
    pyplot.plot(X, Y1 + Y2, 'k--')
    pyplot.plot(data.data.numpy(), np.zeros(len(data)), 'k*')
    pyplot.title('Approximation of posterior over mu')
    pyplot.ylabel('probability density')
    pyplot.show()
def run_svi(beta_hat, obs_error, K, true_beta):
    """Fit a K-component boosted variational approximation and return its mean.

    Component 0 is trained with ``Trace_ELBO``; components ``1 .. K-1`` are
    trained one at a time against the ``relbo`` loss, each conditioned on the
    mixture built so far. Every 100 steps a progress line (step, mean of the
    last 100 losses) and a diagnostic line (correlation and mean squared error
    of the current estimate against ``true_beta``) are printed.

    :param beta_hat: observed effect estimates, converted with ``torch.tensor``.
    :param obs_error: observation errors, converted with ``torch.tensor``.
    :param K: number of mixture components; also divides ``TOTAL_ITS`` to give
        the number of SVI steps per component.
    :param true_beta: reference values used only for the printed diagnostics.
    :returns: the ``var_mean_*`` parameters of all K components combined with
        the final mixture weights (a NumPy array).
    """
    steps_per_level = TOTAL_ITS // K
    t_begin = time()
    pyro.clear_param_store()
    pyro.enable_validation(True)

    def my_model():
        return prs_model(torch.tensor(beta_hat), torch.tensor(obs_error))

    def _param_array(template, index):
        # Detached NumPy view of the parameter named ``template.format(index)``.
        store = pyro.get_param_store()
        return store.get_param(template.format(index)).detach().numpy()

    def _print_fit(estimate):
        # Correlation and MSE of the current estimate against the truth.
        print('\t\t',
              np.corrcoef(true_beta, estimate)[0, 1],
              np.mean((true_beta - estimate)**2))

    components = [partial(prs_guide, index=0)]
    weights = torch.tensor([1.])
    wrapped_approximation = partial(approximation, components=components, weights=weights)

    # ---- level 0: plain ELBO on the first component ----
    base_svi = pyro.infer.SVI(
        my_model,
        partial(prs_guide, index=0),
        pyro.optim.Adam({'lr': LR}),
        loss=pyro.infer.Trace_ELBO(num_particles=NUM_PARTICLES),
    )
    history = []
    for step in range(steps_per_level):
        history.append(base_svi.step())
        if step % 100 != 0:
            continue
        print('\t', step, np.mean(history[-100:]))
        estimate = _param_array('var_mean_{}', 0)
        estimate = estimate * _param_array('var_psi_causal_{}', 0)
        _print_fit(estimate)

    # ---- levels 1 .. K-1: boosting with the RELBO ----
    for t in range(1, K):
        print('Boost level', t)
        level_guide = partial(prs_guide, index=t)
        history = []
        level_svi = pyro.infer.SVI(
            my_model, level_guide, pyro.optim.Adam({'lr': LR}), loss=relbo)
        new_weight = 2 / ((t+1) + 2)
        new_weights = torch.cat((weights * (1-new_weight), torch.tensor([new_weight])))

        for step in range(steps_per_level):
            history.append(level_svi.step(approximation=wrapped_approximation))
            if step % 100 != 0:
                continue
            print('\t', step, np.mean(history[-100:]))
            level_means = [_param_array('var_mean_{}', s) for s in range(t+1)]
            # NOTE(review): component 0's psi is used for every component here
            # while the means are indexed by component -- looks like a
            # copy/paste slip, but the guide is not visible; confirm.
            level_psis = [_param_array('var_psi_causal_{}', 0) for _ in range(t+1)]
            scaled_means = np.array(level_means) * np.array(level_psis)
            _print_fit(new_weights.detach().numpy().dot(scaled_means))

        components.append(level_guide)
        weights = new_weights
        wrapped_approximation = partial(approximation, components=components, weights=weights)

    print('BBBVI ran in', time() - t_begin)

    final_means = [_param_array('var_mean_{}', s) for s in range(K)]
    return weights.detach().numpy().dot(np.array(final_means))
def backtest(data, covariates, model_fn, *,
             forecaster_fn=Forecaster,
             metrics=None,
             transform=None,
             train_window=None,
             min_train_window=1,
             test_window=None,
             min_test_window=1,
             stride=1,
             seed=1234567890,
             num_samples=100,
             batch_size=None,
             forecaster_options={}):
    """
    Backtest a forecasting model on a moving window of (train,test) data.

    :param data: A tensor dataset with time dimension -2.
    :type data: ~torch.Tensor
    :param covariates: A tensor of covariates with time dimension -2.
        For models not using covariates, pass a shaped empty tensor
        ``torch.empty(duration, 0)``.
    :type covariates: ~torch.Tensor
    :param callable model_fn: Function that returns an
        :class:`~pyro.contrib.forecast.forecaster.ForecastingModel` object.
    :param callable forecaster_fn: Function that returns a forecaster object
        (for example, :class:`~pyro.contrib.forecast.forecaster.Forecaster`
        or :class:`~pyro.contrib.forecast.forecaster.HMCForecaster`)
        given arguments model, training data, training covariates and
        keyword arguments defined in `forecaster_options`.
    :param dict metrics: A dictionary mapping metric name to metric function.
        The metric function should input a forecast ``pred`` and ground
        ``truth`` and can output anything, often a number. Example metrics
        include: :func:`eval_mae`, :func:`eval_rmse`, and :func:`eval_crps`.
    :param callable transform: An optional transform to apply before computing
        metrics. If provided this will be applied as
        ``pred, truth = transform(pred, truth)``.
    :param int train_window: Size of the training window. By default trains
        from beginning of data. This must be None if forecaster is
        :class:`~pyro.contrib.forecast.forecaster.Forecaster` and
        ``forecaster_options["warm_start"]`` is true.
    :param int min_train_window: If ``train_window`` is None, this specifies
        the min training window size. Defaults to 1.
    :param int test_window: Size of the test window. By default forecasts to
        end of data.
    :param int min_test_window: If ``test_window`` is None, this specifies
        the min test window size. Defaults to 1.
    :param int stride: Optional stride for test/train split. Defaults to 1.
    :param int seed: Random number seed.
    :param int num_samples: Number of samples for forecast. Defaults to 100.
    :param int batch_size: Batch size for forecast sampling. Defaults to
        ``num_samples``. If forecasting runs out of memory the batch size is
        halved (starting from ``num_samples`` when left as None) and the
        reduced value is reused for all later windows.
    :param forecaster_options: Options dict to pass to forecaster, or callable
        inputting time window ``t0,t1,t2`` and returning such a dict. See
        :class:`~pyro.contrib.forecast.forecaster.Forecaster` for details.
    :type forecaster_options: dict or callable

    :returns: A list of dictionaries of evaluation data. Caller is responsible
        for aggregating the per-window metrics. Dictionary keys include: train
        begin time "t0", train/test split time "t1", test end time "t2",
        "seed", "num_samples", "train_walltime", "test_walltime", "params",
        and one key for each metric.
    :rtype: list
    :raises ValueError: If ``train_window`` is given together with a truthy
        ``warm_start`` forecaster option.
    """
    assert data.size(-2) == covariates.size(-2)
    assert isinstance(min_train_window, int) and min_train_window >= 1
    assert isinstance(min_test_window, int) and min_test_window >= 1
    if metrics is None:
        metrics = DEFAULT_METRICS
    assert metrics, "no metrics specified"

    # Normalise ``forecaster_options`` to a callable. The shared ``{}`` default
    # is only ever read (``.get`` / ``**`` unpacking), never mutated.
    if callable(forecaster_options):
        forecaster_options_fn = forecaster_options
    else:
        def forecaster_options_fn(*args, **kwargs):
            return forecaster_options
    if train_window is not None and forecaster_options_fn().get("warm_start"):
        raise ValueError("Cannot warm start with moving training window; "
                         "either set warm_start=False or train_window=None")

    duration = data.size(-2)
    if test_window is None:
        stop = duration - min_test_window + 1
    else:
        stop = duration - test_window + 1
    if train_window is None:
        start = min_train_window
    else:
        start = train_window

    pyro.clear_param_store()
    results = []
    for t1 in range(start, stop, stride):
        t0 = 0 if train_window is None else t1 - train_window
        t2 = duration if test_window is None else t1 + test_window
        assert 0 <= t0 < t1 < t2 <= duration
        logger.info("Training on window [{t0}:{t1}], testing on window [{t1}:{t2}]"
                    .format(t0=t0, t1=t1, t2=t2))

        # Train a forecaster on the training window.
        pyro.set_rng_seed(seed)
        window_options = forecaster_options_fn(t0=t0, t1=t1, t2=t2)
        if not window_options.get("warm_start"):
            pyro.clear_param_store()
        train_data = data[..., t0:t1, :]
        train_covariates = covariates[..., t0:t1, :]
        start_time = default_timer()
        model = model_fn()
        forecaster = forecaster_fn(model, train_data, train_covariates,
                                   **window_options)
        train_walltime = default_timer() - start_time

        # Forecast forward to testing window.
        test_covariates = covariates[..., t0:t2, :]
        start_time = default_timer()
        # Gradually reduce batch_size to avoid OOM errors.
        while True:
            try:
                pred = forecaster(train_data, test_covariates, num_samples=num_samples,
                                  batch_size=batch_size)
                break
            except RuntimeError as e:
                # ``batch_size=None`` means one batch of ``num_samples``;
                # resolve it before comparing so the OOM fallback works with
                # the default instead of failing on ``None > 1``.
                current_batch = num_samples if batch_size is None else batch_size
                if "out of memory" in str(e) and current_batch > 1:
                    batch_size = (current_batch + 1) // 2
                    warnings.warn(
                        "out of memory, decreasing batch_size to {}".format(batch_size),
                        RuntimeWarning)
                else:
                    raise
        test_walltime = default_timer() - start_time
        truth = data[..., t1:t2, :]

        # We aggressively garbage collect because Monte Carlo forecast are memory intensive.
        del forecaster

        # Evaluate the forecasts.
        if transform is not None:
            pred, truth = transform(pred, truth)
        result = {
            "t0": t0,
            "t1": t1,
            "t2": t2,
            "seed": seed,
            "num_samples": num_samples,
            "train_walltime": train_walltime,
            "test_walltime": test_walltime,
            "params": {},
        }
        results.append(result)
        for name, fn in metrics.items():
            result[name] = fn(pred, truth)
        # Record a snapshot of the param store; single-element parameters are
        # stored as plain Python numbers.
        for name, value in pyro.get_param_store().items():
            if value.numel() == 1:
                value = value.cpu().item()
            result["params"][name] = value
        for dct in (result, result["params"]):
            for key, value in sorted(dct.items()):
                if isinstance(value, (int, float)):
                    logger.debug("{} = {:0.6g}".format(key, value))

        del pred

    return results
def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
    """
    Compute the loss for a single (model_trace, guide_trace) pair and
    backpropagate the corresponding surrogate loss.

    The ELBO estimate is the sum of all cost nodes (model log pdfs minus guide
    log pdfs). The surrogate ELBO, which is what gets differentiated, keeps
    only the cost nodes flagged as having non-zero expectation and, when the
    guide has non-reparameterizable sample sites, adds REINFORCE-like terms
    built from downstream costs with optional baselines.

    :param weight: multiplier applied to the differentiated loss and to the
        returned loss.
    :param model_trace: trace of the model; its ``graph`` entry
        ``"vectorized_map_data_info"`` and its sample/param nodes are read.
    :param guide_trace: trace of the guide, read in the same way; also supplies
        ``nonreparam_stochastic_nodes``.
    :returns: ``weight * (-elbo)``.
    :raises ValueError: if a baseline's size differs from the size of the
        downstream cost at its site.

    Side effects: may emit a warning, may update ``__baseline_avg_downstream_cost_*``
    parameters, calls ``torch_backward`` and marks the trainable parameters as
    active in the param store when any param sites are present.
    """
    # get info regarding rao-blackwellization of vectorized map_data
    guide_vec_md_info = guide_trace.graph["vectorized_map_data_info"]
    model_vec_md_info = model_trace.graph["vectorized_map_data_info"]
    guide_vec_md_condition = guide_vec_md_info['rao-blackwellization-condition']
    model_vec_md_condition = model_vec_md_info['rao-blackwellization-condition']
    do_vec_rb = guide_vec_md_condition and model_vec_md_condition
    if not do_vec_rb:
        warnings.warn(
            "Unable to do fully-vectorized Rao-Blackwellization in TraceGraph_ELBO. "
            "Falling back to higher-variance gradient estimator. "
            "Try to avoid these issues in your model and guide:\n{}".format("\n".join(
                guide_vec_md_info["warnings"] | model_vec_md_info["warnings"])))
    # without vectorized rao-blackwellization no node is treated as batched
    guide_vec_md_nodes = guide_vec_md_info['nodes'] if do_vec_rb else set()
    model_vec_md_nodes = model_vec_md_info['nodes'] if do_vec_rb else set()

    # have the trace compute all the individual (batch) log pdf terms
    # so that they are available below
    guide_trace.compute_batch_log_pdf(site_filter=lambda name, site: name in guide_vec_md_nodes)
    guide_trace.log_pdf()
    model_trace.compute_batch_log_pdf(site_filter=lambda name, site: name in model_vec_md_nodes)
    model_trace.log_pdf()

    # prepare a list of all the cost nodes, each of which is +- log_pdf
    cost_nodes = []
    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    for name, model_site in model_trace.nodes.items():
        if model_site["type"] == "sample":
            if model_site["is_observed"]:
                cost_nodes.append(CostNode(model_site["log_pdf"], True))
            else:
                # cost node from model sample
                cost_nodes.append(CostNode(model_site["log_pdf"], True))
                # cost node from guide sample
                guide_site = guide_trace.nodes[name]
                zero_expectation = name in non_reparam_nodes
                cost_nodes.append(CostNode(-guide_site["log_pdf"], not zero_expectation))

    # compute the elbo; if all stochastic nodes are reparameterizable, we're done
    # this bit is never differentiated: it's here for getting an estimate of the elbo itself
    elbo = torch_data_sum(sum(c.cost for c in cost_nodes))

    # compute the surrogate elbo, removing terms whose gradient is zero
    # this is the bit that's actually differentiated
    # XXX should the user be able to control if these terms are included?
    surrogate_elbo = sum(c.cost for c in cost_nodes if c.nonzero_expectation)

    # the following computations are only necessary if we have non-reparameterizable nodes
    baseline_loss = 0.0
    if non_reparam_nodes:

        # recursively compute downstream cost nodes for all sample sites in model and guide
        # (even though ultimately just need for non-reparameterizable sample sites)
        # 1. downstream costs used for rao-blackwellization
        # 2. model observe sites (as well as terms that arise from the model and guide having different
        # dependency structures) are taken care of via 'children_in_model' below
        topo_sort_guide_nodes = list(reversed(list(networkx.topological_sort(guide_trace))))
        topo_sort_guide_nodes = [x for x in topo_sort_guide_nodes
                                 if guide_trace.nodes[x]["type"] == "sample"]
        downstream_guide_cost_nodes = {}
        downstream_costs = {}

        # reversed topological order: every child is processed before its parents,
        # so the children's downstream costs already exist when a node is visited
        for node in topo_sort_guide_nodes:
            node_log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
            downstream_costs[node] = model_trace.nodes[node][node_log_pdf_key] - \
                guide_trace.nodes[node][node_log_pdf_key]
            nodes_included_in_sum = set([node])
            downstream_guide_cost_nodes[node] = set([node])
            for child in guide_trace.successors(node):
                child_cost_nodes = downstream_guide_cost_nodes[child]
                downstream_guide_cost_nodes[node].update(child_cost_nodes)
                if nodes_included_in_sum.isdisjoint(child_cost_nodes):  # avoid duplicates
                    if node_log_pdf_key == 'log_pdf':
                        downstream_costs[node] += downstream_costs[child].sum()
                    else:
                        downstream_costs[node] += downstream_costs[child]
                    nodes_included_in_sum.update(child_cost_nodes)
            missing_downstream_costs = downstream_guide_cost_nodes[node] - nodes_included_in_sum
            # include terms we missed because we had to avoid duplicates
            for missing_node in missing_downstream_costs:
                mn_log_pdf_key = 'batch_log_pdf' if missing_node in guide_vec_md_nodes else 'log_pdf'
                if node_log_pdf_key == 'log_pdf':
                    downstream_costs[node] += (model_trace.nodes[missing_node][mn_log_pdf_key] -
                                               guide_trace.nodes[missing_node][mn_log_pdf_key]).sum()
                else:
                    downstream_costs[node] += model_trace.nodes[missing_node][mn_log_pdf_key] - \
                        guide_trace.nodes[missing_node][mn_log_pdf_key]

        # finish assembling complete downstream costs
        # (the above computation may be missing terms from model)
        # XXX can we cache some of the sums over children_in_model to make things more efficient?
        for site in non_reparam_nodes:
            children_in_model = set()
            for node in downstream_guide_cost_nodes[site]:
                children_in_model.update(model_trace.successors(node))
            # remove terms accounted for above
            children_in_model.difference_update(downstream_guide_cost_nodes[site])
            for child in children_in_model:
                child_log_pdf_key = 'batch_log_pdf' if child in model_vec_md_nodes else 'log_pdf'
                site_log_pdf_key = 'batch_log_pdf' if site in guide_vec_md_nodes else 'log_pdf'
                assert (model_trace.nodes[child]["type"] == "sample")
                if site_log_pdf_key == 'log_pdf':
                    downstream_costs[site] += model_trace.nodes[child][child_log_pdf_key].sum()
                else:
                    downstream_costs[site] += model_trace.nodes[child][child_log_pdf_key]

        # construct all the reinforce-like terms.
        # we include only downstream costs to reduce variance
        # optionally include baselines to further reduce variance
        # XXX should the average baseline be in the param store as below?
        elbo_reinforce_terms = 0.0
        for node in non_reparam_nodes:
            guide_site = guide_trace.nodes[node]
            log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
            downstream_cost = downstream_costs[node]
            baseline = 0.0
            (nn_baseline, nn_baseline_input, use_decaying_avg_baseline, baseline_beta,
                baseline_value) = _get_baseline_options(guide_site)
            use_nn_baseline = nn_baseline is not None
            use_baseline_value = baseline_value is not None
            assert(not (use_nn_baseline and use_baseline_value)), \
                "cannot use baseline_value and nn_baseline simultaneously"
            if use_decaying_avg_baseline:
                # the running average is kept in the param store under a per-site name
                avg_downstream_cost_old = pyro.param("__baseline_avg_downstream_cost_" + node,
                                                     ng_zeros(1), tags="__tracegraph_elbo_internal_tag")
                avg_downstream_cost_new = (1 - baseline_beta) * downstream_cost + \
                    baseline_beta * avg_downstream_cost_old
                avg_downstream_cost_old.data = avg_downstream_cost_new.data  # XXX copy_() ?
                baseline += avg_downstream_cost_old
            if use_nn_baseline:
                # block nn_baseline_input gradients except in baseline loss
                baseline += nn_baseline(detach_iterable(nn_baseline_input))
            elif use_baseline_value:
                # it's on the user to make sure baseline_value tape only points to baseline params
                baseline += baseline_value
            if use_nn_baseline or use_baseline_value:
                # accumulate baseline loss
                baseline_loss += torch.pow(downstream_cost.detach() - baseline, 2.0).sum()

            guide_log_pdf = guide_site[log_pdf_key] / guide_site["scale"]  # not scaled by subsampling
            if use_nn_baseline or use_decaying_avg_baseline or use_baseline_value:
                if downstream_cost.size() != baseline.size():
                    raise ValueError("Expected baseline at site {} to be {} instead got {}".format(
                        node, downstream_cost.size(), baseline.size()))
                downstream_cost = downstream_cost - baseline
            # the downstream cost is detached so that only the guide log pdf is differentiated
            elbo_reinforce_terms += (guide_log_pdf * downstream_cost.detach()).sum()

        surrogate_elbo += elbo_reinforce_terms

    # collect parameters to train from model and guide
    trainable_params = set(site["value"]
                           for trace in (model_trace, guide_trace)
                           for site in trace.nodes.values()
                           if site["type"] == "param")

    if trainable_params:
        surrogate_loss = -surrogate_elbo
        torch_backward(weight * (surrogate_loss + baseline_loss))
        pyro.get_param_store().mark_params_active(trainable_params)

    loss = -elbo

    return weight * loss
def get_encodings(model: VariationalInferenceModel,
                  dataset_obj,
                  cells_only: bool = True) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Encode a dataset with a trained model's encoder.

    The count matrix is pushed through ``model.encoder`` in chunks of 200 rows
    and the per-row results are gathered into NumPy arrays.

    Args:
        model: A trained cellbender.model.VariationalInferenceModel whose
            encoder produces the encodings.
        dataset_obj: The dataset to be encoded.
        cells_only: If True, encode ``dataset_obj.get_count_matrix()``;
            otherwise encode ``dataset_obj.get_count_matrix_all_barcodes()``.

    Returns:
        z: Latent embedding of gene expression, one row per barcode
            (``enc['z']['loc']``).
        d: Scale factor for each barcode, not in log space: the LogNormal mean
            ``exp(d_loc + d_cell_scale^2 / 2)``.
        p: Sigmoid of the encoder's ``p_y`` output, or ``None`` when the
            encoder output has no ``'p_y'`` entry.

    """
    logging.info("Encoding data according to model.")

    # Pick the count matrix (genes already trimmed) to encode.
    dataset = (dataset_obj.get_count_matrix() if cells_only
               else dataset_obj.get_count_matrix_all_barcodes())
    n_rows = dataset.shape[0]

    # Output placeholders, filled chunk by chunk.
    z = np.zeros((n_rows, model.z_dim))
    d = np.zeros((n_rows))
    p = np.zeros((n_rows))

    # Ambient expression is optional; move it to the model's device if present.
    chi_ambient = get_ambient_expression()
    if chi_ambient is not None:
        chi_ambient = torch.Tensor(chi_ambient).to(device=model.device)

    chunk_size = 200
    for begin in np.arange(0, n_rows, chunk_size):
        end = min(n_rows, begin + chunk_size)

        # Densify this chunk of the sparse matrix and wrap it in a tensor.
        dense_chunk = np.array(dataset[begin:end, :].todense(), dtype=int).squeeze()
        x = torch.Tensor(dense_chunk).to(device=model.device)

        enc = model.encoder.forward(x, chi_ambient)

        # d_cell_scale comes from the fitted parameters in the param store.
        d_sig = pyro.get_param_store().get_param('d_cell_scale').detach().cpu().numpy()

        z[begin:end, :] = enc['z']['loc'].detach().cpu().numpy()
        d_loc = enc['d_loc'].detach().cpu().numpy()
        d[begin:end] = np.exp(d_loc + d_sig.item()**2 / 2)

        # p is not always available: it depends which model was used.
        try:
            p[begin:end] = enc['p_y'].detach().sigmoid().cpu().numpy()
        except KeyError:
            p = None  # Simple model gets None for p.

    return z, d, p
def main(args):
    """Train the topic/category model with SVI and save its results.

    Steps: seed and reset Pyro, draw one synthetic data set from the model,
    load the preprocessed corpora, derive the size-related fields on ``args``,
    run ``args.num_steps`` SVI steps, report the final loss and the learned
    posterior parameters, plot the loss curve to a PNG and save the param
    store to ``mymodelparams.pt``.

    :param args: namespace providing ``jit``, ``learning_rate``, ``num_steps``
        and ``batch_size``; the fields ``num_words_per_doc``, ``num_words``,
        ``num_docs``, ``num_categories`` and ``num_topics`` are overwritten.
    """
    logging.info(f"CUDA available: {torch.cuda.is_available()}")

    logging.info('Generating data')
    pyro.set_rng_seed(0)
    pyro.clear_param_store()
    pyro.enable_validation(True)

    # Debugging the trace of the model. For showing the shapes of the tensors through the model
    # tracemodel = functools.partial(model, args=args)
    # trace = poutine.trace(tracemodel).get_trace()
    # trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
    # print(trace.format_shapes())

    # Calling the model directly yields synthetic data; it is only looked up
    # here, the training below uses the loaded corpora.
    synthetic = model(args=args)
    gen_doc_word_data = synthetic["doc_word_data"]
    gen_doc_category_data = synthetic["doc_category_data"]

    # Load the preprocessed corpora and turn them into index tensors,
    # dropping the -1 entries returned by doc2idx.
    corpora = prepro_file_load("corpora")
    documents = list(prepro_file_load("id2pre_text").values())
    category_list = [[cat] for cat in prepro_file_load("id2category").values()]
    category_corpora = prepro_file_load("category_corpora")

    doc_word_data = []
    for doc in documents:
        word_ids = [idx for idx in corpora.doc2idx(doc) if idx != -1]
        doc_word_data.append(torch.tensor(word_ids, dtype=torch.int64))

    doc_category_data = []
    for cat in category_list:
        first_id = next(idx for idx in category_corpora.doc2idx(cat) if idx != -1)
        doc_category_data.append(torch.tensor(first_id, dtype=torch.int64))
    # TODO X check if there are differences in this date and model generated data

    # Optionally restrict training to the first n documents.
    data_slice = None
    if data_slice is not None:
        doc_word_data = doc_word_data[:data_slice]
        doc_category_data = doc_category_data[:data_slice]

    # Derive the size-related settings from the loaded data.
    args.num_words_per_doc = [len(words) for words in doc_word_data]
    args.num_words = len(corpora)
    args.num_docs = len(doc_word_data)
    args.num_categories = len(category_corpora)
    # TODO X test different amounts of topics
    args.num_topics = args.num_categories * 2

    # We'll train using SVI.
    logging.info('-' * 40)
    logging.info('Training on {} documents'.format(args.num_docs))
    # TODO X test TraceEnum_ vs Trace_ vs TraceMeanField_
    if args.jit:
        Elbo = JitTraceEnum_ELBO
    else:
        Elbo = Trace_ELBO
    # TODO Changing the max plate nesting value might be worth looking at
    elbo = Elbo(max_plate_nesting=2)
    # TODO X try different learning rates
    # TODO try something other than ClippedAdam or changing its parameters
    optim = ClippedAdam({'lr': args.learning_rate})
    svi = SVI(model, parametrized_guide, optim, elbo)

    # Keyword arguments shared by every step and by the final loss evaluation.
    step_kwargs = dict(doc_word_data=doc_word_data,
                       category_data=doc_category_data,
                       args=args,
                       batch_size=args.batch_size)

    # Training for num_steps iterations
    losses = []
    logging.info('Step\tLoss')
    for step in tqdm(range(args.num_steps)):
        loss = svi.step(**step_kwargs)
        losses.append(loss)
        if step % 10 == 0:
            logging.info('{: >5d}\t{}'.format(step, loss))
    loss = elbo.loss(model, parametrized_guide, **step_kwargs)
    logging.info('final loss = {}'.format(loss))

    # Print params after training
    print('topic_weights_posterior = ', pyro.param("topic_weights_posterior"))
    print('topic_words_posterior = ', pyro.param("topic_words_posterior"))
    print('category_weights_posterior = ', pyro.param("category_weights_posterior"))
    print('category_topics_posterior = ', pyro.param("category_topics_posterior"))
    print('doc_category_posterior = ', pyro.param("doc_category_posterior"))

    # Plot loss over iterations
    plt.plot(losses)
    plt.title("ELBO")
    plt.xlabel("step")
    plt.ylabel("loss")
    plot_file_name = "".join((
        "../loss-2017_categories-", str(args.num_categories),
        "_topics-", str(args.num_topics),
        "_batch-", str(args.batch_size),
        "_lr-", str(args.learning_rate),
        "_data-size-", str(data_slice),
        ".png",
    ))
    plt.savefig(plot_file_name)
    plt.show()

    # save model
    param_store = pyro.get_param_store()
    param_store.save("mymodelparams.pt")
def main(args):
    """Fit a FactorMuE model to biosequence data, then plot and save results.

    The dataset is either generated (``args.test``) or read from a FASTA file.
    When ``args.split`` is positive, a held-out subset is carved off using a
    dedicated data seed so that splits are comparable across models. The model
    is trained with SVI under a MultiStepLR schedule, evaluated on train and
    held-out data, and the latent embedding is computed. Depending on
    ``args.no_plots`` / ``args.no_save``, plots, the param store, evaluation
    numbers, embeddings and the effective arguments are written to
    ``args.out_folder`` with a shared time stamp.

    :param args: parsed command-line namespace. ``batch_size`` is clamped to
        the dataset size, and ``latent_seq_length`` / ``latent_alphabet`` are
        overwritten with the model's values when results are saved.
    """
    # Local import: the only name this function needs beyond what the file
    # already imports.
    from itertools import accumulate

    # Load dataset.
    if args.cpu_data or not args.cuda:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
    if args.test:
        dataset = generate_data(args.small, args.include_stop, device)
    else:
        dataset = BiosequenceDataset(
            args.file,
            "fasta",
            args.alphabet,
            include_stop=args.include_stop,
            device=device,
        )
    args.batch_size = min(dataset.data_size, args.batch_size)
    if args.split > 0.0:
        # Train test split.
        heldout_num = int(np.ceil(args.split * len(dataset)))
        data_lengths = [len(dataset) - heldout_num, heldout_num]
        # Specific data split seed, for comparability across models and
        # parameter initializations.
        pyro.set_rng_seed(args.rng_data_seed)
        indices = torch.randperm(sum(data_lengths), device=device).tolist()
        # Running sums of the lengths give each subset's end offset; the
        # standard-library accumulate replaces torch's private helper.
        dataset_train, dataset_test = [
            torch.utils.data.Subset(dataset, indices[(offset - length):offset])
            for offset, length in zip(accumulate(data_lengths), data_lengths)
        ]
    else:
        dataset_train = dataset
        dataset_test = None

    # Training seed.
    pyro.set_rng_seed(args.rng_seed)

    # Construct model.
    model = FactorMuE(
        dataset.max_length,
        dataset.alphabet_length,
        args.z_dim,
        batch_size=args.batch_size,
        latent_seq_length=args.latent_seq_length,
        indel_factor_dependence=args.indel_factor,
        indel_prior_scale=args.indel_prior_scale,
        indel_prior_bias=args.indel_prior_bias,
        inverse_temp_prior=args.inverse_temp_prior,
        weights_prior_scale=args.weights_prior_scale,
        offset_prior_scale=args.offset_prior_scale,
        z_prior_distribution=args.z_prior,
        ARD_prior=args.ARD_prior,
        substitution_matrix=(not args.no_substitution_matrix),
        substitution_prior_scale=args.substitution_prior_scale,
        latent_alphabet_length=args.latent_alphabet,
        cuda=args.cuda,
        pin_memory=args.pin_mem,
    )

    # Infer with SVI.
    scheduler = MultiStepLR({
        "optimizer": Adam,
        "optim_args": {
            "lr": args.learning_rate
        },
        "milestones": json.loads(args.milestones),
        "gamma": args.learning_gamma,
    })
    n_epochs = args.n_epochs
    losses = model.fit_svi(
        dataset_train,
        n_epochs,
        args.anneal,
        args.batch_size,
        scheduler,
        args.jit,
    )

    # Evaluate.
    train_lp, test_lp, train_perplex, test_perplex = model.evaluate(
        dataset_train, dataset_test, args.jit)
    print("train logp: {} perplex: {}".format(train_lp, train_perplex))
    print("test logp: {} perplex: {}".format(test_lp, test_perplex))

    # Get latent space embedding.
    z_locs, z_scales = model.embed(dataset)

    # Plot and save. One time stamp ties together all files of this run.
    time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    if not args.no_plots:
        plt.figure(figsize=(6, 6))
        plt.plot(losses)
        plt.xlabel("step")
        plt.ylabel("loss")
        if not args.no_save:
            plt.savefig(
                os.path.join(args.out_folder,
                             "FactorMuE_plot.loss_{}.pdf".format(time_stamp)))

        # Only the first two latent dimensions are plotted.
        plt.figure(figsize=(6, 6))
        plt.scatter(z_locs[:, 0], z_locs[:, 1])
        plt.xlabel(r"$z_1$")
        plt.ylabel(r"$z_2$")
        if not args.no_save:
            plt.savefig(
                os.path.join(
                    args.out_folder,
                    "FactorMuE_plot.latent_{}.pdf".format(time_stamp)))

        if not args.indel_factor:
            # Plot indel parameters. See statearrangers.py for details on the
            # r and u parameters.
            plt.figure(figsize=(6, 6))
            insert = pyro.param("insert_q_mn").detach()
            # Normalise over the last axis (softmax written with logsumexp).
            insert_expect = torch.exp(insert - insert.logsumexp(-1, True))
            plt.plot(insert_expect[:, :, 1].cpu().numpy())
            plt.xlabel("position")
            plt.ylabel("probability of insert")
            plt.legend([r"$r_0$", r"$r_1$", r"$r_2$"])
            if not args.no_save:
                plt.savefig(
                    os.path.join(
                        args.out_folder,
                        "FactorMuE_plot.insert_prob_{}.pdf".format(time_stamp),
                    ))
            plt.figure(figsize=(6, 6))
            delete = pyro.param("delete_q_mn").detach()
            delete_expect = torch.exp(delete - delete.logsumexp(-1, True))
            plt.plot(delete_expect[:, :, 1].cpu().numpy())
            plt.xlabel("position")
            plt.ylabel("probability of delete")
            plt.legend([r"$u_0$", r"$u_1$", r"$u_2$"])
            if not args.no_save:
                plt.savefig(
                    os.path.join(
                        args.out_folder,
                        "FactorMuE_plot.delete_prob_{}.pdf".format(time_stamp),
                    ))

    if not args.no_save:
        pyro.get_param_store().save(
            os.path.join(args.out_folder,
                         "FactorMuE_results.params_{}.out".format(time_stamp)))
        with open(
                os.path.join(
                    args.out_folder,
                    "FactorMuE_results.evaluation_{}.txt".format(time_stamp),
                ),
                "w",
        ) as ow:
            ow.write("train_lp,test_lp,train_perplex,test_perplex\n")
            ow.write("{},{},{},{}\n".format(train_lp, test_lp, train_perplex,
                                            test_perplex))
        np.savetxt(
            os.path.join(
                args.out_folder,
                "FactorMuE_results.embed_loc_{}.txt".format(time_stamp)),
            z_locs.cpu().numpy(),
        )
        np.savetxt(
            os.path.join(
                args.out_folder,
                "FactorMuE_results.embed_scale_{}.txt".format(time_stamp),
            ),
            z_scales.cpu().numpy(),
        )
        with open(
                os.path.join(
                    args.out_folder,
                    "FactorMuE_results.input_{}.txt".format(time_stamp),
                ),
                "w",
        ) as ow:
            ow.write("[args]\n")
            # Record the values the model actually used for these two fields.
            args.latent_seq_length = model.latent_seq_length
            args.latent_alphabet = model.latent_alphabet_length
            for elem, value in vars(args).items():
                ow.write("{} = {}\n".format(elem, value))
            ow.write("alphabet_str = {}\n".format("".join(dataset.alphabet)))
            ow.write("max_length = {}\n".format(dataset.max_length))
def loss_and_grads(self, model, guide, *args, **kwargs):
    """
    :returns: the loss, i.e. the negated ELBO estimate accumulated over all
        traces
    :rtype: float (as documented by the original author; the accumulated
        value comes from ``torch_data_sum``)

    For every ``(weight, model_trace, guide_trace, log_r)`` produced by
    ``self._get_traces`` this accumulates the ELBO estimate and builds the
    surrogate ELBO used as the gradient estimator, then runs backward on the
    negated surrogate and marks the involved parameters as active.
    A warning is issued when the resulting loss is NaN.
    """
    elbo = 0.0
    for weight, model_trace, guide_trace, log_r in self._get_traces(model, guide, *args, **kwargs):
        elbo_term = weight * 0
        surrogate_term = weight * 0

        # batched log pdfs are only used with discrete enumeration and more
        # than one weight entry
        use_batch = self.enum_discrete and weight.size(0) > 1
        key = "batch_log_pdf" if use_batch else "log_pdf"

        for name, model_site in model_trace.nodes.items():
            if model_site["type"] != "sample":
                continue

            if model_site["is_observed"]:
                # observed sites contribute their model log pdf to both sums
                elbo_term += model_site[key]
                surrogate_term += model_site[key]
                continue

            guide_site = guide_trace.nodes[name]
            lp_lq = model_site[key] - guide_site[key]
            elbo_term += lp_lq
            if guide_site["fn"].reparameterized:
                surrogate_term += lp_lq
            else:
                # XXX should the user be able to control inclusion of the -logq term below?
                # undo the subsampling scale on the guide log pdf
                unscaled_guide_lp = guide_site[key] / guide_site["scale"]
                surrogate_term += model_site[key] + log_r.detach() * unscaled_guide_lp

        # terms with zero weight are zeroed explicitly to avoid nans
        if not isinstance(weight, numbers.Number):
            zero_mask = (weight == 0)
            elbo_term[zero_mask] = 0.0
            surrogate_term[zero_mask] = 0.0
        elif weight == 0.0:
            elbo_term = torch_zeros_like(elbo_term)
            surrogate_term = torch_zeros_like(surrogate_term)

        elbo += torch_data_sum(weight * elbo_term)
        surrogate_term = torch_sum(weight * surrogate_term)

        # every param site of either trace is a candidate for training
        trainable_params = {
            site["value"]
            for trace in (model_trace, guide_trace)
            for site in trace.nodes.values()
            if site["type"] == "param"
        }

        if trainable_params:
            torch_backward(-surrogate_term)
            pyro.get_param_store().mark_params_active(trainable_params)

    loss = -elbo
    if np.isnan(loss):
        warnings.warn('Encountered NAN loss')
    return loss
def save_model(self):
    """Write the global Pyro param store to the file ``gp_adf_rtss.save``."""
    param_store = pyro.get_param_store()
    param_store.save('gp_adf_rtss.save')
def bayesian_regression(x_data, y_data, num_iterations):
    """Fit a Bayesian linear regression with SVI and summarise predictions.

    A single-output linear model with Normal priors on weight and bias and a
    Uniform(0, 10) prior on the noise scale ``sigma`` is fitted with an
    ``AutoDiagonalNormal`` guide. Progress is printed every 100 iterations,
    the learned parameters are printed after training, and 800 posterior
    predictive samples are drawn and summarised.

    :param x_data: feature tensor; its last dimension is the number of
        features and its first dimension the number of observations.
    :param y_data: observed targets, passed to the model as ``y``.
    :param int num_iterations: number of SVI steps to take.
    :returns: dict mapping each returned site (``"linear.weight"``, ``"obs"``,
        ``"_RETURN"``) to its ``"mean"``, ``"std"``, ``"5%"`` and ``"95%"``
        statistics over the posterior samples.
    """

    # BAYESIAN REGRESSION WITH SVI
    class BayesianRegression(PyroModule):
        def __init__(self, in_features, out_features):
            super().__init__()
            self.linear = PyroModule[nn.Linear](in_features, out_features)
            self.linear.weight = PyroSample(
                dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
            self.linear.bias = PyroSample(
                dist.Normal(0., 10.).expand([out_features]).to_event(1))

        def forward(self, x, y=None):
            # forward() specifies the data generating process
            # sigma is the error term (typically called epsilon in regression equations)
            sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
            mean = self.linear(x).squeeze(-1)
            with pyro.plate("data", x.shape[0]):
                pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
            return mean

    # Guides -- posterior distribution classes.
    # The guide determines a family of distributions, and SVI aims to find an
    # approximate posterior distribution from this family that has the lowest
    # KL divergence from the true posterior.

    # The feature count is taken from the data instead of being fixed at 3.
    in_features = x_data.shape[-1]
    model = BayesianRegression(in_features, 1)

    # Under the hood, this defines a guide that uses a Normal distribution with
    # learnable parameters corresponding to each sample statement in the model:
    # one component per regression coefficient plus one each for the intercept
    # term and sigma (e.g. size (5,) for 3 features).
    guide = AutoDiagonalNormal(model)

    adam = pyro.optim.Adam({"lr": 0.03})  # note this is from Pyro's optim module, not PyTorch's
    svi = SVI(model, guide, adam, loss=Trace_ELBO())

    # We do not need to pass in learnable parameters to the optimizer since
    # that is determined by the guide code and happens behind the scenes within
    # the SVI class automatically. To take an ELBO gradient step we simply call
    # the step method of SVI. The data argument we pass to SVI.step will be
    # passed to both model() and guide().
    pyro.clear_param_store()
    for j in range(num_iterations):
        # calculate the loss and take a gradient step
        loss = svi.step(x_data, y_data)
        if (j + 1) % 100 == 0:
            # Normalise by the number of observations actually passed in; the
            # previous ``len(data)`` referred to a name not defined here.
            print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(x_data)))

    # We can examine the optimized parameter values by fetching from Pyro's param store.
    guide.requires_grad_(False)  # NOTE(review): presumably freezes the guide's parameters -- confirm
    for name, _ in pyro.get_param_store().items():
        print(name, pyro.param(name))

    # This gets us quantiles from the posterior distribution.
    # NOTE(review): the result is not used; kept as in the original walkthrough.
    guide.quantiles([0.25, 0.5, 0.75])

    # Since Bayesian models give you a posterior distribution, model evaluation
    # needs to be a combination of sampling the posterior and running the
    # samples through the model. We generate 800 samples from our trained
    # model. Internally, this is done by first generating samples for the
    # unobserved sites in the guide, and then running the model forward by
    # conditioning the sites to values sampled from the guide.
    def summary(samples):
        # Per-site mean, std and 5% / 95% order statistics along the sample axis.
        site_stats = {}
        for k, v in samples.items():
            site_stats[k] = {
                "mean": torch.mean(v, 0),
                "std": torch.std(v, 0),
                "5%": v.kthvalue(int(len(v) * 0.05), dim=0)[0],
                "95%": v.kthvalue(int(len(v) * 0.95), dim=0)[0],
            }
        return site_stats

    # In return_sites we specify both the outcome ("obs" site) and the return
    # value of the model ("_RETURN"), which captures the regression line, plus
    # the regression coefficients ("linear.weight") for further analysis.
    predictive = Predictive(model, guide=guide, num_samples=800,
                            return_sites=("linear.weight", "obs", "_RETURN"))
    samples = predictive(x_data)
    pred_summary = summary(samples)
    # Hand the summary back instead of discarding it.
    return pred_summary
def main():
    """Run inference for the text SS-VAE with hard-coded settings.

    Builds a ``TextSSVAE``, optionally restores a previously saved param
    store and model state, trains for 200 epochs with the enumerated ELBO
    (plus the auxiliary classification loss), and after every epoch logs
    losses and accuracies to ``./tmp.log`` and checkpoints the model. Every
    tenth epoch it also generates sentences and writes them to CSV files.

    :return: None
    """
    pyro.set_rng_seed(12345)
    cuda = True

    ss_vae = TextSSVAE(embed_dim=300,
                       z_dim=300,
                       kernels=[3, 4, 5],
                       filters=[100, 100, 100],
                       hidden_size=300,
                       num_rnn_layers=1,
                       config_enum="parallel",
                       use_cuda=cuda,
                       aux_loss_multiplier=46)
    if cuda:
        ss_vae = ss_vae.cuda()

    # Restoring earlier state is deliberately best-effort: any failure just
    # means training starts from scratch.
    try:
        pyro.get_param_store().load('pyro_param_store.store')
        print(
            'successfully loaded param store, remove file from directory if undesired'
        )
    except Exception:
        print("failed to load param store, starting over")
    try:
        ss_vae.load_state_dict(torch.load('ss_vae_model.pth'))
        print(
            'successfully loaded model parameters, remove file from directory if undesired'
        )
    except Exception:
        print("failed to load model parameters")

    # setup the optimizer
    adam_params = {"lr": 1e-4, "betas": (0.9, 0.999), "weight_decay": 0.01}
    optimizer = Adam(adam_params)

    # set up the loss(es) for inference. wrapping the guide in config_enumerate
    # builds the loss as a sum by enumerating each class label for the sampled
    # discrete categorical distribution in the model
    jit = False
    guide = config_enumerate(ss_vae.guide, "parallel", expand=True)
    elbo = (JitTraceEnum_ELBO if jit else TraceEnum_ELBO)()
    loss_basic = SVI(ss_vae.model, guide, optimizer, loss=elbo)

    # build a list of all losses considered
    losses = [loss_basic]

    # aux_loss: whether to use the auxiliary loss from NIPS 14 paper (Kingma et al)
    aux_loss = True
    if aux_loss:
        elbo = JitTrace_ELBO() if jit else Trace_ELBO()
        loss_aux = SVI(ss_vae.model_classify,
                       ss_vae.guide_classify,
                       optimizer,
                       loss=elbo)
        losses.append(loss_aux)

    # batch_size: number of examples (and labels) considered in a batch
    batch_size = 32
    valid_num = 100
    train_data_size = 3409
    sup_num = 1163

    # The context manager closes the log file on every exit path. The
    # original closed it in a `finally` that raised NameError whenever
    # `open` itself had failed.
    with open('./tmp.log', "w") as logger:
        # NOTE(review): sup_num=valid_num is kept from the original; confirm
        # that passing the validation size here is intended.
        data_loaders = setup_data_loaders(IMDBCached,
                                          cuda,
                                          batch_size=batch_size,
                                          sup_num=valid_num)

        # how often would a supervised batch be encountered during inference
        # e.g. if sup_num is 3000, we would have every 16th = int(50000/3000)
        # batch supervised until we have traversed through the all supervised
        # batches
        periodic_interval_batches = int(train_data_size / (1.0 * sup_num))

        # number of unsupervised examples
        unsup_num = train_data_size - sup_num

        # best validation accuracy seen across epochs and the test accuracy
        # measured in that same epoch
        best_valid_acc, corresponding_test_acc = 0.0, 0.0

        # run inference for a certain number of epochs
        num_epochs = 200
        sup_loss_log = []
        unsup_loss_log = []
        for i in range(num_epochs):
            # get the losses for an epoch
            epoch_losses_sup, epoch_losses_unsup = \
                run_inference_for_epoch(data_loaders, losses, periodic_interval_batches)

            # compute average epoch losses i.e. losses per example.
            # These must be real lists: a lazy `map` would be exhausted by
            # the string join below and leave nothing for np.save.
            avg_epoch_losses_sup = [v / sup_num for v in epoch_losses_sup]
            avg_epoch_losses_unsup = [v / unsup_num for v in epoch_losses_unsup]
            sup_loss_log.append(avg_epoch_losses_sup)
            unsup_loss_log.append(avg_epoch_losses_unsup)

            # store the loss and validation/testing accuracies in the logfile
            str_loss_sup = " ".join(map(str, avg_epoch_losses_sup))
            str_loss_unsup = " ".join(map(str, avg_epoch_losses_unsup))
            str_print = "{} epoch: avg losses {}".format(
                i, "{} {}".format(str_loss_sup, str_loss_unsup))

            ss_vae.eval()
            validation_accuracy = get_accuracy(data_loaders["valid"],
                                               ss_vae.classifier, batch_size)
            str_print += " validation accuracy {}".format(validation_accuracy)

            # this test accuracy is only for logging, this is not used
            # to make any decisions during training
            test_accuracy = get_accuracy(data_loaders["test"],
                                         ss_vae.classifier, batch_size)
            str_print += " test accuracy {}".format(test_accuracy)
            ss_vae.train()

            # checkpoint after every epoch
            torch.save(ss_vae.state_dict(), 'ss_vae_model.pth')
            pyro.get_param_store().save('pyro_param_store.store')

            # update the best validation accuracy and the corresponding
            # testing accuracy
            if best_valid_acc < validation_accuracy:
                best_valid_acc = validation_accuracy
                corresponding_test_acc = test_accuracy

            if i % 10 == 0:
                neg_sentences, neg_bleu = generateSentences(
                    data_loaders["test"],
                    ss_vae.model,
                    ss_vae.w2v_model,
                    sentiment=0)
                pos_sentences, pos_bleu = generateSentences(
                    data_loaders["test"],
                    ss_vae.model,
                    ss_vae.w2v_model,
                    sentiment=1)
                str_print += " neg_bleu {}".format(neg_bleu)
                str_print += " pos_bleu {}".format(pos_bleu)
                pd.DataFrame.from_dict(pos_sentences).to_csv(
                    'positive_sentences.csv', encoding='utf-8')
                pd.DataFrame.from_dict(neg_sentences).to_csv(
                    'negative_sentences.csv', encoding='utf-8')

                cond_neg_sentences, neg_bleu = generateSentences(
                    data_loaders["test"],
                    ss_vae.conditioned_generation,
                    ss_vae.w2v_model,
                    sentiment=0)
                cond_pos_sentences, pos_bleu = generateSentences(
                    data_loaders["test"],
                    ss_vae.conditioned_generation,
                    ss_vae.w2v_model,
                    sentiment=1)
                pd.DataFrame.from_dict(cond_pos_sentences).to_csv(
                    'cond_positive_sentences.csv', encoding='utf-8')
                pd.DataFrame.from_dict(cond_neg_sentences).to_csv(
                    'cond_negative_sentences.csv', encoding='utf-8')
                str_print += " cond_neg_bleu {}".format(neg_bleu)
                str_print += " cond_pos_bleu {}".format(pos_bleu)

            print_and_log(logger, str_print)

            # Persist the loss curves each epoch so they survive an
            # interrupted run; the final files equal a single save at the end.
            np.save("avg_loss_sup", np.asarray(sup_loss_log))
            np.save("avg_loss_unsup", np.asarray(unsup_loss_log))

        ss_vae.eval()
        final_test_accuracy = get_accuracy(data_loaders["test"],
                                           ss_vae.classifier, batch_size)
        print_and_log(
            logger,
            "best validation accuracy {} corresponding testing accuracy {} "
            "last testing accuracy {}".format(best_valid_acc,
                                              corresponding_test_acc,
                                              final_test_accuracy))
def variational(self, data: torch.Tensor, y: torch.Tensor, tensorboard: bool,
                log_dir: str):
    """
    Perform variational inference using the guide.

    Runs ``self.vi_epochs`` epochs of SVI (Adam, lr=0.01, ``Trace_ELBO``)
    over mini-batches of 1024 rows and stores the resulting param-store
    state in ``self.vi_model``.

    Parameters
    ----------
    data : torch.Tensor
        Input data; it is cast to float and its first dimension is used as
        the number of samples when averaging the epoch loss.
    y : torch.Tensor
        Ground truth labels; cast to float and paired row-wise with ``data``.
    tensorboard : bool
        If true, the epoch loss and, for every key in ``self._sites``, the
        entries of the ``<key>_mean`` / ``<key>_scale`` parameters are
        written to TensorBoard.
    log_dir : str
        Directory handed to ``SummaryWriter``; only used when
        ``tensorboard`` is true.
    """
    # explicitly define datatype
    data = data.float()
    y = y.float()
    num_samples = data.shape[0]

    # create dataset
    lr_dataset = torch.utils.data.TensorDataset(data, y)
    data_loader = DataLoader(dataset=lr_dataset,
                             batch_size=1024,
                             pin_memory=False)

    # define optimizer
    optim = Adam({'lr': 0.01})
    svi = SVI(self.model, self.guide, optim, loss=Trace_ELBO())

    # add tensorboard writer if requested
    writer = SummaryWriter(log_dir=log_dir) if tensorboard else None

    try:
        # start variational process
        with tqdm(total=self.vi_epochs) as pbar:
            for epoch in range(self.vi_epochs):
                epoch_loss = 0.
                # distinct batch names so the `y` argument is not rebound
                for batch_x, batch_y in data_loader:
                    epoch_loss += svi.step(batch_x, batch_y)

                # get loss of complete epoch
                epoch_loss = epoch_loss / num_samples

                # logging stuff
                if writer is not None:
                    # add loss to logging
                    writer.add_scalar("SVI loss", epoch_loss, epoch)

                    # log current state of parameter store
                    param_store = pyro.get_param_store()
                    for key in self._sites.keys():
                        pairs = zip(param_store["%s_mean" % key],
                                    param_store["%s_scale" % key])
                        for d, (loc, scale) in enumerate(pairs):
                            writer.add_scalar("%s_mean_%d" % (key, d), loc,
                                              epoch)
                            writer.add_scalar("%s_scale_%d" % (key, d),
                                              scale, epoch)

                            # also represent the weights as distributions
                            density = np.random.normal(
                                loc=loc.detach().cpu().numpy(),
                                scale=scale.detach().cpu().numpy(),
                                size=1000)
                            writer.add_histogram(
                                "histogram_%s_%d" % (key, d), density, epoch)

                # update progress bar
                pbar.set_description("SVI Loss: %.5f" % epoch_loss)
                pbar.update(1)

        self.vi_model = pyro.get_param_store().get_state()
    finally:
        # Close the writer even when a step raises, so the event file is
        # flushed and the handle is not leaked.
        if writer is not None:
            writer.close()
def test_subsample_gradient(Elbo, reparameterized, has_rsample, subsample,
                            local_samples, scale):
    """Compare ELBO gradients under plate subsampling with fixed expected values.

    Gradients of the ``loc`` and ``scale`` params are accumulated by calling
    ``loss_and_grads`` either once on the full index set or once per
    single-element subsample, then compared with hard-coded expected
    gradients multiplied by ``scale``.
    """
    pyro.clear_param_store()
    data = torch.tensor([-0.5, 2.0])
    if subsample:
        subsample_size = 1
    else:
        subsample_size = len(data)
    if reparameterized:
        Normal = dist.Normal
    else:
        Normal = fakes.NonreparameterizedNormal

    def model(subsample):
        with pyro.plate("data", len(data), subsample_size, subsample) as idx:
            observed = data[idx]
            latent = pyro.sample("z", Normal(0, 1))
            pyro.sample("x", Normal(latent, 1), obs=observed)

    def guide(subsample):
        q_scale = pyro.param("scale", lambda: torch.tensor([1.0]))
        with pyro.plate("data", len(data), subsample_size, subsample):
            q_loc = pyro.param("loc", lambda: torch.zeros(len(data)),
                               event_dim=0)
            q_z = Normal(q_loc, q_scale)
            if has_rsample is not None:
                q_z.has_rsample_(has_rsample)
            pyro.sample("z", q_z)

    if scale != 1.0:
        model = poutine.scale(model, scale=scale)
        guide = poutine.scale(guide, scale=scale)

    # With local sampling the particles are drawn inside the guide instead.
    num_particles = 50000
    if local_samples:
        guide = config_enumerate(guide, num_samples=num_particles)
        num_particles = 1

    optim = Adam({"lr": 0.1})
    elbo = Elbo(
        max_plate_nesting=1,  # set this to ensure rng agrees across runs
        num_particles=num_particles,
        vectorize_particles=True,
        strict_enumeration_warning=False,
    )
    inference = SVI(model, guide, optim, loss=elbo)

    # One call per index set; gradients accumulate on the params.
    if subsample_size == 1:
        index_sets = [[0], [1]]
    else:
        index_sets = [[0, 1]]
    with xfail_if_not_implemented():
        for indices in index_sets:
            inference.loss_and_grads(
                model, guide,
                subsample=torch.tensor(indices, dtype=torch.long))

    params = dict(pyro.get_param_store().named_parameters())
    normalizer = 2 if subsample else 1
    actual_grads = {}
    for name, param in params.items():
        actual_grads[name] = param.grad.detach().cpu().numpy() / normalizer

    expected_grads = {
        "loc": scale * np.array([0.5, -2.0]),
        "scale": scale * np.array([2.0]),
    }
    for name in sorted(params):
        logger.info("expected {} = {}".format(name, expected_grads[name]))
        logger.info("actual   {} = {}".format(name, actual_grads[name]))

    precision = 0.06 * scale
    assert_equal(actual_grads, expected_grads, prec=precision)
df["obs_perc_5"], df["obs_perc_95"], color='C1', alpha=0.5) plt.legend() if __name__ == '__main__': svi, model, guide = get_pyro_model(return_all=True) saved_param_files = glob.glob(MODEL_FILES) saved_param_files.sort(key=os.path.getmtime, reverse=True) print(*saved_param_files, sep='\n') idx = int(input("file? (0 for most recent exp) > ")) pyro.get_param_store().load(saved_param_files[idx]) saved_data_files = glob.glob(DATA_FILES) saved_data_files.sort(key=os.path.getmtime, reverse=True) print(*saved_data_files, sep='\n') idx = int(input("file? (0 for most recent data) > ")) training_generator = iter( get_dataset(batch_size=1000, data_file=saved_data_files[idx])) x_data, y_data = next(training_generator) for name, value in pyro.get_param_store().items(): print(name, pyro.param(name)) trace_summary(svi, x_data, y_data) guide_summary(guide, x_data, y_data)
def run_GMM(data, K):
    """Fit a K-component 1-D Gaussian mixture by MAP and return assignments.

    The global variables (weights, scale, locs) are point-estimated with an
    ``AutoDelta`` guide while the per-datum assignment is enumerated out.
    The best of 100 seeded initializations is trained with SVI, the learned
    estimates are printed, and the MAP assignment of every datum is plotted
    and returned.
    """

    @config_enumerate
    def model(data):
        # Global variables.
        weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
        scale = pyro.sample('scale', dist.LogNormal(4., 2.))
        with pyro.plate('components', K):
            locs = pyro.sample('locs', dist.Normal(0., 10.))

        with pyro.plate('data', len(data)):
            # Local variables.
            assignment = pyro.sample('assignment', dist.Categorical(weights))
            pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

    optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
    elbo = TraceEnum_ELBO(max_plate_nesting=1)

    def init_loc_fn(site):
        # Lazy initializers keyed by site name; only the requested one runs.
        initializers = {
            # Initialize weights to uniform.
            "weights": lambda: torch.ones(K) / K,
            "scale": lambda: (data.var() / 2).sqrt(),
            "locs": lambda: data[torch.multinomial(
                torch.ones(len(data)) / len(data), K)],
        }
        site_name = site["name"]
        if site_name not in initializers:
            raise ValueError(site_name)
        return initializers[site_name]()

    def initialize(seed):
        # The guide and SVI object are published as module globals; the
        # training code below reads them from there.
        global global_guide, svi
        pyro.set_rng_seed(seed)
        pyro.clear_param_store()
        exposed_model = poutine.block(model,
                                      expose=['weights', 'locs', 'scale'])
        global_guide = AutoDelta(exposed_model, init_loc_fn=init_loc_fn)
        svi = SVI(model, global_guide, optim, loss=elbo)
        return svi.loss(model, global_guide, data)

    # Choose the best among 100 random initializations.
    best = None
    for candidate in range(100):
        trial = (initialize(candidate), candidate)
        if best is None or trial < best:
            best = trial
    loss, seed = best
    initialize(seed)
    print('seed = {}, initial_loss = {}'.format(seed, loss))

    # Register hooks to monitor gradient norms.
    gradient_norms = defaultdict(list)
    for name, value in pyro.get_param_store().named_parameters():
        def record_norm(g, name=name):
            return gradient_norms[name].append(g.norm().item())
        value.register_hook(record_norm)

    losses = []
    num_steps = 2 if smoke_test else 200
    for i in range(num_steps):
        loss = svi.step(data)
        losses.append(loss)
        print('.' if i % 100 else '\n', end='')
    print()

    map_estimates = global_guide(data)
    weights = map_estimates['weights']
    locs = map_estimates['locs']
    scale = map_estimates['scale']
    print('weights = {}'.format(weights.data.numpy()))
    print('locs = {}'.format(locs.data.numpy()))
    print('scale = {}'.format(scale.data.numpy()))

    # record the globals, then replay them into the model
    guide_trace = poutine.trace(global_guide).get_trace(data)
    trained_model = poutine.replay(model, trace=guide_trace)

    def classifier(data, temperature=0):
        # first_available_dim=-2 avoids conflict with the data plate
        inferred_model = infer_discrete(trained_model,
                                        temperature=temperature,
                                        first_available_dim=-2)
        trace = poutine.trace(inferred_model).get_trace(data)
        return trace.nodes["assignment"]["value"]

    assignment = classifier(data)

    figure = pyplot.figure(figsize=(8, 2), dpi=100)
    figure.set_facecolor('white')
    pyplot.plot(data.numpy(), assignment.numpy(), 'bx')
    pyplot.title('MAP assignment')
    pyplot.xlabel('Latent posterior sample value')
    pyplot.ylabel('class assignment')
    return assignment
def test_particle_gradient(Elbo, reparameterized, has_rsample):
    """Check single-particle ELBO gradients against hand-derived expressions.

    The gradients produced by ``Elbo.loss_and_grads`` for the ``loc`` and
    ``scale`` params are compared with closed-form expressions evaluated at
    the same sample of ``z``. The same sample is obtained by re-seeding the
    RNG with 0 before tracing the guide a second time.

    :param Elbo: ELBO class under test; the score-function branch only
        defines expected gradients for ``TraceEnum_ELBO`` and ``Trace_ELBO``.
    :param reparameterized: if true use ``dist.Normal``, otherwise
        ``fakes.NonreparameterizedNormal``.
    :param has_rsample: if not ``None``, passed to ``has_rsample_`` on the
        guide distribution for ``z``.
    """
    pyro.clear_param_store()
    data = torch.tensor([-0.5, 2.0])
    Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal

    def model():
        with pyro.plate("data", len(data)) as ind:
            x = data[ind]
            z = pyro.sample("z", Normal(0, 1))
            pyro.sample("x", Normal(z, 1), obs=x)

    def guide():
        scale = pyro.param("scale", lambda: torch.tensor([1.0]))
        with pyro.plate("data", len(data)):
            loc = pyro.param("loc", lambda: torch.zeros(len(data)), event_dim=0)
            z_dist = Normal(loc, scale)
            if has_rsample is not None:
                z_dist.has_rsample_(has_rsample)
            pyro.sample("z", z_dist)

    elbo = Elbo(
        max_plate_nesting=1,  # set this to ensure rng agrees across runs
        num_particles=1,
        strict_enumeration_warning=False,
    )

    # Elbo gradient estimator
    pyro.set_rng_seed(0)
    elbo.loss_and_grads(model, guide)
    params = dict(pyro.get_param_store().named_parameters())
    actual_grads = {
        name: param.grad.detach().cpu()
        for name, param in params.items()
    }

    # capture sample values and log_probs
    # (same seed as above, so the guide trace is meant to reproduce the
    # sample that was used inside loss_and_grads)
    pyro.set_rng_seed(0)
    guide_tr = poutine.trace(guide).get_trace()
    model_tr = poutine.trace(poutine.replay(model, guide_tr)).get_trace()
    guide_tr.compute_log_prob()
    model_tr.compute_log_prob()
    x = data
    z = guide_tr.nodes["z"]["value"].data
    loc = pyro.param("loc").data
    scale = pyro.param("scale").data

    # expected grads
    if reparameterized and has_rsample is not False:
        # pathwise gradient estimator
        expected_grads = {
            "scale":
            -(-z * (z - loc) + (x - z) * (z - loc) + 1).sum(0, keepdim=True) /
            scale,
            "loc":
            -(-z + (x - z)),
        }
    else:
        # score function gradient estimator
        # NOTE: `elbo` is rebound here from the Elbo instance to the
        # per-datum elbo values; the instance is not used after this point.
        elbo = (model_tr.nodes["x"]["log_prob"].data +
                model_tr.nodes["z"]["log_prob"].data -
                guide_tr.nodes["z"]["log_prob"].data)
        # derivatives of the guide's Normal log-density w.r.t. loc and scale
        dlogq_dloc = (z - loc) / scale**2
        dlogq_dscale = (z - loc)**2 / scale**3 - 1 / scale
        if Elbo is TraceEnum_ELBO:
            expected_grads = {
                "scale":
                -(dlogq_dscale * elbo - dlogq_dscale).sum(0, keepdim=True),
                "loc": -(dlogq_dloc * elbo - dlogq_dloc),
            }
        elif Elbo is Trace_ELBO:
            # expected value of dlogq_dscale and dlogq_dloc is zero
            expected_grads = {
                "scale": -(dlogq_dscale * elbo).sum(0, keepdim=True),
                "loc": -(dlogq_dloc * elbo),
            }
        # NOTE(review): any other Elbo class leaves expected_grads unbound
        # and the logging below would raise NameError.

    for name in sorted(params):
        logger.info("expected {} = {}".format(name, expected_grads[name]))
        logger.info("actual   {} = {}".format(name, actual_grads[name]))

    assert_equal(actual_grads, expected_grads, prec=1e-4)
percentile = pyro.sample("var3", dist.Uniform(0, 1)) if (percentile > 0.95): GPA = 4 else: GPA = pyro.sample("var4", dist.Normal(2.75, 0.5)) if (GPA == 4): Interviews = dist.Binomial(Recruiters, 0.9).sample() if (GPA < 4): Interviews = dist.Binomial(Recruiters, 0.6).sample() for n in range(1, 2): with pyro.iarange("data"): pyro.sample("obs", dist.Binomial(Interviews, 0.4), obs=data['offers'][n]) guide = ag.AutoDiagonalNormal(model) pyro.clear_param_store() optim = Adam({'lr': 0.01}) svi = SVI(model, guide, optim, loss=Trace_ELBO()) for i in range(1000): loss = svi.step(data) if ((i % 100) == 0): print(loss) for name in pyro.get_param_store().get_all_param_names(): print(name, pyro.param(name).data.numpy())
def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
    """Compute the loss for one particle and backpropagate its surrogate.

    Builds a cost node from the log-pdf of every sample site in the model
    trace (and the matching guide site for unobserved ones), sums them into
    an ELBO estimate, and forms a surrogate objective. For guide sites in
    ``guide_trace.nonreparam_stochastic_nodes`` the surrogate additionally
    gets score-function terms: the guide log-pdf times the detached
    downstream cost, minus any baseline configured for the site. If either
    trace contains param sites, ``torch_backward`` is called on
    ``weight * (surrogate_loss + baseline_loss)`` and those params are
    marked active in the param store.

    :param weight: multiplier applied to the backpropagated objective and to
        the returned loss (presumably 1 / num_particles; confirm at caller).
    :param model_trace: model trace; its graph must carry
        ``"vectorized_map_data_info"``.
    :param guide_trace: guide trace with the same graph metadata.
    :returns: ``weight * (-elbo)``.
    :raises ValueError: if a site's baseline size differs from the size of
        its downstream cost.
    """
    # get info regarding rao-blackwellization of vectorized map_data
    guide_vec_md_info = guide_trace.graph["vectorized_map_data_info"]
    model_vec_md_info = model_trace.graph["vectorized_map_data_info"]
    guide_vec_md_condition = guide_vec_md_info[
        'rao-blackwellization-condition']
    model_vec_md_condition = model_vec_md_info[
        'rao-blackwellization-condition']
    # vectorized Rao-Blackwellization is only used when both traces allow it
    do_vec_rb = guide_vec_md_condition and model_vec_md_condition
    if not do_vec_rb:
        warnings.warn(
            "Unable to do fully-vectorized Rao-Blackwellization in TraceGraph_ELBO. "
            "Falling back to higher-variance gradient estimator. "
            "Try to avoid these issues in your model and guide:\n{}".
            format("\n".join(guide_vec_md_info["warnings"]
                             | model_vec_md_info["warnings"])))
    guide_vec_md_nodes = guide_vec_md_info['nodes'] if do_vec_rb else set()
    model_vec_md_nodes = model_vec_md_info['nodes'] if do_vec_rb else set()

    # have the trace compute all the individual (batch) log pdf terms
    # so that they are available below
    guide_trace.compute_batch_log_pdf(
        site_filter=lambda name, site: name in guide_vec_md_nodes)
    guide_trace.log_pdf()
    model_trace.compute_batch_log_pdf(
        site_filter=lambda name, site: name in model_vec_md_nodes)
    model_trace.log_pdf()

    # prepare a list of all the cost nodes, each of which is +- log_pdf
    cost_nodes = []
    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    for name, model_site in model_trace.nodes.items():
        if model_site["type"] == "sample":
            if model_site["is_observed"]:
                cost_nodes.append(CostNode(model_site["log_pdf"], True))
            else:
                # cost node from model sample
                cost_nodes.append(CostNode(model_site["log_pdf"], True))
                # cost node from guide sample
                guide_site = guide_trace.nodes[name]
                zero_expectation = name in non_reparam_nodes
                cost_nodes.append(
                    CostNode(-guide_site["log_pdf"], not zero_expectation))

    # compute the elbo; if all stochastic nodes are reparameterizable, we're done
    # this bit is never differentiated: it's here for getting an estimate of the elbo itself
    elbo = torch_data_sum(sum(c.cost for c in cost_nodes))

    # compute the surrogate elbo, removing terms whose gradient is zero
    # this is the bit that's actually differentiated
    # XXX should the user be able to control if these terms are included?
    surrogate_elbo = sum(c.cost for c in cost_nodes if c.nonzero_expectation)

    # the following computations are only necessary if we have non-reparameterizable nodes
    baseline_loss = 0.0
    if non_reparam_nodes:

        # recursively compute downstream cost nodes for all sample sites in model and guide
        # (even though ultimately just need for non-reparameterizable sample sites)
        # 1. downstream costs used for rao-blackwellization
        # 2. model observe sites (as well as terms that arise from the model and guide having different
        # dependency structures) are taken care of via 'children_in_model' below
        # reversed topological order: every successor is visited before its parent
        topo_sort_guide_nodes = list(
            reversed(list(networkx.topological_sort(guide_trace))))
        topo_sort_guide_nodes = [
            x for x in topo_sort_guide_nodes
            if guide_trace.nodes[x]["type"] == "sample"
        ]
        downstream_guide_cost_nodes = {}
        downstream_costs = {}

        for node in topo_sort_guide_nodes:
            node_log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
            downstream_costs[node] = model_trace.nodes[node][node_log_pdf_key] - \
                guide_trace.nodes[node][node_log_pdf_key]
            nodes_included_in_sum = set([node])
            downstream_guide_cost_nodes[node] = set([node])
            for child in guide_trace.successors(node):
                child_cost_nodes = downstream_guide_cost_nodes[child]
                downstream_guide_cost_nodes[node].update(child_cost_nodes)
                if nodes_included_in_sum.isdisjoint(
                        child_cost_nodes):  # avoid duplicates
                    # a non-batched parent takes the child's cost summed to a scalar
                    if node_log_pdf_key == 'log_pdf':
                        downstream_costs[node] += downstream_costs[
                            child].sum()
                    else:
                        downstream_costs[node] += downstream_costs[child]
                    nodes_included_in_sum.update(child_cost_nodes)
            missing_downstream_costs = downstream_guide_cost_nodes[
                node] - nodes_included_in_sum
            # include terms we missed because we had to avoid duplicates
            for missing_node in missing_downstream_costs:
                mn_log_pdf_key = 'batch_log_pdf' if missing_node in guide_vec_md_nodes else 'log_pdf'
                if node_log_pdf_key == 'log_pdf':
                    downstream_costs[node] += (
                        model_trace.nodes[missing_node][mn_log_pdf_key] -
                        guide_trace.nodes[missing_node][mn_log_pdf_key]
                    ).sum()
                else:
                    downstream_costs[node] += model_trace.nodes[missing_node][mn_log_pdf_key] - \
                        guide_trace.nodes[missing_node][mn_log_pdf_key]

        # finish assembling complete downstream costs
        # (the above computation may be missing terms from model)
        # XXX can we cache some of the sums over children_in_model to make things more efficient?
        for site in non_reparam_nodes:
            children_in_model = set()
            for node in downstream_guide_cost_nodes[site]:
                children_in_model.update(model_trace.successors(node))
            # remove terms accounted for above
            children_in_model.difference_update(
                downstream_guide_cost_nodes[site])
            for child in children_in_model:
                child_log_pdf_key = 'batch_log_pdf' if child in model_vec_md_nodes else 'log_pdf'
                site_log_pdf_key = 'batch_log_pdf' if site in guide_vec_md_nodes else 'log_pdf'
                assert (model_trace.nodes[child]["type"] == "sample")
                if site_log_pdf_key == 'log_pdf':
                    downstream_costs[site] += model_trace.nodes[child][
                        child_log_pdf_key].sum()
                else:
                    downstream_costs[site] += model_trace.nodes[child][
                        child_log_pdf_key]

        # construct all the reinforce-like terms.
        # we include only downstream costs to reduce variance
        # optionally include baselines to further reduce variance
        # XXX should the average baseline be in the param store as below?
        elbo_reinforce_terms = 0.0
        for node in non_reparam_nodes:
            guide_site = guide_trace.nodes[node]
            log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
            downstream_cost = downstream_costs[node]
            baseline = 0.0
            (nn_baseline, nn_baseline_input, use_decaying_avg_baseline,
             baseline_beta,
             baseline_value) = _get_baseline_options(guide_site)
            use_nn_baseline = nn_baseline is not None
            use_baseline_value = baseline_value is not None
            assert(not (use_nn_baseline and use_baseline_value)), \
                "cannot use baseline_value and nn_baseline simultaneously"
            if use_decaying_avg_baseline:
                # running average of the downstream cost, kept as a param
                avg_downstream_cost_old = pyro.param(
                    "__baseline_avg_downstream_cost_" + node,
                    ng_zeros(1),
                    tags="__tracegraph_elbo_internal_tag")
                avg_downstream_cost_new = (1 - baseline_beta) * downstream_cost + \
                    baseline_beta * avg_downstream_cost_old
                avg_downstream_cost_old.data = avg_downstream_cost_new.data  # XXX copy_() ?
                baseline += avg_downstream_cost_old
            if use_nn_baseline:
                # block nn_baseline_input gradients except in baseline loss
                baseline += nn_baseline(detach_iterable(nn_baseline_input))
            elif use_baseline_value:
                # it's on the user to make sure baseline_value tape only points to baseline params
                baseline += baseline_value
            if use_nn_baseline or use_baseline_value:
                # accumulate baseline loss
                baseline_loss += torch.pow(
                    downstream_cost.detach() - baseline, 2.0).sum()

            guide_log_pdf = guide_site[log_pdf_key] / guide_site[
                "scale"]  # not scaled by subsampling
            if use_nn_baseline or use_decaying_avg_baseline or use_baseline_value:
                if downstream_cost.size() != baseline.size():
                    raise ValueError(
                        "Expected baseline at site {} to be {} instead got {}"
                        .format(node, downstream_cost.size(),
                                baseline.size()))
                downstream_cost = downstream_cost - baseline
            # the cost is detached so gradients flow only through the guide log-pdf
            elbo_reinforce_terms += (guide_log_pdf *
                                     downstream_cost.detach()).sum()

        surrogate_elbo += elbo_reinforce_terms

    # collect parameters to train from model and guide
    trainable_params = set(site["value"]
                           for trace in (model_trace, guide_trace)
                           for site in trace.nodes.values()
                           if site["type"] == "param")

    if trainable_params:
        surrogate_loss = -surrogate_elbo
        torch_backward(weight * (surrogate_loss + baseline_loss))
        pyro.get_param_store().mark_params_active(trainable_params)

    loss = -elbo

    return weight * loss
print("Saving") save_path = "../raw-results/" #save_path = "/afs/cs.stanford.edu/u/mhahn/scr/deps/" with open( save_path + "/manual_output_ground_coarse/" + args.language + "_" + __file__ + "_model_" + str(myID) + ".tsv", "w") as outFile: print("\t".join( list( map(str, [ "Counter", "Document", "DH_Mean_NoPunct", "DH_Sigma_NoPunct", "Distance_Mean_NoPunct", "Distance_Sigma_NoPunct", "Dependency" ]))), file=outFile) dh_numpy = pyro.get_param_store().get_param("mu_DH").data.numpy() dh_sigma_numpy = pyro.get_param_store().get_param( "sigma_DH").data.numpy() dist_numpy = pyro.get_param_store().get_param( "mu_Dist").data.numpy() dist_sigma_numpy = pyro.get_param_store().get_param( "sigma_Dist").data.numpy() for i in range(len(itos_deps)): key = itos_deps[i] dependency = key for doc in range(len(itos_docs)): print("\t".join( list( map(str, [ counter, itos_docs[doc], dh_numpy[doc, i],
def main(args):
    """Train one of the HMM variants on JSB chorales and report its losses.

    The model selected by ``args.model`` is fitted by MAP estimation of the
    ``probs_*`` sites with an ``AutoDelta`` guide, while the remaining
    discrete state is handled by an enumerating (or TMC) ELBO. Per-step
    loss, final training loss, test loss and the parameter count are logged.
    """
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    logging.info('Loading data')
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info('-' * 40)
    model = models[args.model]
    logging.info('Training {} on {} sequences'.format(
        model.__name__, len(data['train']['sequences'])))
    train_split = data['train']
    sequences = train_split['sequences']
    lengths = train_split['sequence_lengths']

    # find all the notes that are present at least once in the training set
    note_counts = (sequences == 1).sum(0).sum(0)
    present_notes = (note_counts > 0)
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    def is_probs_site(msg):
        return msg["name"].startswith("probs_")

    guide = AutoDelta(poutine.block(model, expose_fn=is_probs_site))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = poutine.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        enumerated = poutine.enum(model, first_available_dim)
        replayed = poutine.replay(enumerated, guide_trace)
        model_trace = poutine.trace(replayed).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    optim = Adam({'lr': args.learning_rate})
    plate_nesting = 1 if model is model_0 else 2
    if args.tmc:
        if args.jit:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        elbo = TraceTMC_ELBO(max_plate_nesting=plate_nesting)

        def tmc_config(msg):
            # only sites marked for parallel enumeration are sampled by TMC
            if msg["infer"].get("enumerate", None) == "parallel":
                return {
                    "num_samples": args.tmc_num_samples,
                    "expand": False
                }
            return {}

        tmc_model = poutine.infer_config(model, tmc_config)
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        if args.jit:
            Elbo = JitTraceEnum_ELBO
        else:
            Elbo = TraceEnum_ELBO
        elbo = Elbo(max_plate_nesting=plate_nesting,
                    strict_enumeration_warning=(model is not model_7),
                    jit_options={"time_compilation": args.time_compilation})
        svi = SVI(model, guide, optim, elbo)

    # We'll train on small minibatches.
    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        loss = svi.step(sequences,
                        lengths,
                        args=args,
                        batch_size=args.batch_size)
        logging.info('{: >5d}\t{}'.format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug('time to compile: {} s.'.format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model,
                           guide,
                           sequences,
                           lengths,
                           args,
                           include_prior=False)
    logging.info('training loss = {}'.format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info('-' * 40)
    logging.info('Evaluating on {} test sequences'.format(
        len(data['test']['sequences'])))
    test_split = data['test']
    sequences = test_split['sequences'][..., present_notes]
    lengths = test_split['sequence_lengths']
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit easier and for
    # numerical stability) this test loss may not be directly comparable to numbers
    # reported on this dataset elsewhere.
    test_loss = elbo.loss(model,
                          guide,
                          sequences,
                          lengths,
                          args=args,
                          include_prior=False)
    logging.info('test loss = {}'.format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = 0
    for value in pyro.get_param_store().values():
        capacity += value.reshape(-1).size(0)
    logging.info('{} capacity = {} parameters'.format(model.__name__,
                                                      capacity))
data, labels = create_simple_classification_dataset(num_schedules) schedule_starts = np.linspace(0, 20 * (num_schedules-1), num=num_schedules) not_first_time = False distributions = [np.array([.5, .1], dtype=float) for _ in range(num_schedules)] # each one is mean, sigma print('Inference') for epoch in range(num_epochs): # for j, (imgs, lbls) in enumerate(train_loader, 0): # loss = inference.step(imgs.to(device), lbls.to(device)) for _ in range(num_schedules): x_data = [] y_data = [] chosen_schedule_start = int(np.random.choice(schedule_starts)) schedule_num = int(chosen_schedule_start / 20) if not_first_time: pyro.get_param_store().get_state()['params']['emm']=Variable(torch.Tensor([distributions[schedule_num][0]]),requires_grad=True) pyro.get_param_store().get_state()['params']['ems']=Variable(torch.Tensor([distributions[schedule_num][1]]),requires_grad=True) # print(pyro.get_param_store().get_state()['params']['emm']) else: not_first_time = True for each_t in range(chosen_schedule_start, chosen_schedule_start + 20): x = data[each_t][2:] x_data.append(x) # noinspection PyArgumentList x = torch.Tensor([x]).reshape((2)) label = labels[each_t] y_data.append(label) # noinspection PyArgumentList label = torch.Tensor([label]).reshape(1) label = Variable(label).long()
def load_model(self, filename):
    """Load the contents of ``filename`` into Pyro's global parameter store."""
    param_store = pyro.get_param_store()
    param_store.load(filename)
# In[20]:

# Optimizer settings for Pyro's Adam wrapper.
adam_params = dict(lr=0.001, betas=(0.90, 0.999))
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# do gradient steps
n_steps = 1000
for step in range(n_steps):
    svi.step(x, y)

# In[21]:

# Print every learned parameter as "<name>:<value>".
for name in pyro.get_param_store():
    print(name + ':{}'.format(pyro.param(name)))

# In[22]:

# Posterior-predictive sampler for the "y" site.
y_pred = Predictive(model=model,
                    guide=guide,
                    num_samples=1000,
                    return_sites=["y"])

# In[23]:

# Evaluate the predictive on an evenly spaced grid over [-2, 2].
x_ = torch.tensor(np.linspace(-2, 2, 100))
y_ = y_pred.get_samples(x_, None)

# In[ ]:
def forward(self, inputs, n_samples=10, avg_posterior=False, seeds=None):
    """Return predictions averaged over posterior samples.

    For ``self.inference == "svi"`` the guide is traced once per sample (or
    once per seed, re-seeding Pyro's RNG each time) and the ``_RETURN``
    values are averaged. With ``avg_posterior=True`` the ``<key>_loc``
    values from one guide trace are loaded into ``self.basenet`` and a
    single deterministic prediction is made instead. For
    ``self.inference == "hmc"`` the networks stored in
    ``self.posterior_predictive`` are evaluated and averaged.

    :param inputs: batch passed to the guide / networks.
    :param n_samples: number of posterior samples to average.
    :param avg_posterior: SVI only; must be exactly ``True`` to take effect.
    :param seeds: optional sequence of ``n_samples`` seeds. An empty or
        ``None`` value means "no seeds" in both inference modes.
    :returns: mean over the stacked per-sample predictions.
    :raises ValueError: if ``seeds`` is given with a length other than
        ``n_samples``, or if ``self.inference`` is not "svi" or "hmc".
    """
    if seeds and len(seeds) != n_samples:
        raise ValueError("Number of seeds should match number of samples.")

    if self.inference == "svi":
        if avg_posterior is True:
            # Build a state dict from the "<key>_loc" sites of one trace.
            guide_trace = poutine.trace(self.guide).get_trace(inputs)
            avg_state_dict = {}
            for key in self.basenet.state_dict().keys():
                avg_weights = guide_trace.nodes[str(key) + "_loc"]['value']
                avg_state_dict[str(key)] = avg_weights
            self.basenet.load_state_dict(avg_state_dict)
            preds = [self.basenet.model(inputs)]
        else:
            preds = []
            if seeds:
                for seed in seeds:
                    pyro.set_rng_seed(seed)
                    guide_trace = poutine.trace(self.guide).get_trace(inputs)
                    preds.append(guide_trace.nodes['_RETURN']['value'])
            else:
                for _ in range(n_samples):
                    guide_trace = poutine.trace(self.guide).get_trace(inputs)
                    preds.append(guide_trace.nodes['_RETURN']['value'])

        if DEBUG:
            # Debug dump; it reads the last guide trace produced above.
            print("\nlearned variational params:\n")
            print(pyro.get_param_store().get_all_param_names())
            print(
                list(
                    poutine.trace(
                        self.guide).get_trace(inputs).nodes.keys()))
            print("\n",
                  pyro.get_param_store()["model.0.weight_loc"][0][:5])
            print(guide_trace.nodes['module$$$model.0.weight']
                  ["fn"].loc[0][:5])
            print(
                "posterior sample: ",
                guide_trace.nodes['module$$$model.0.weight']['value']
                [5][0][0])

    elif self.inference == "hmc":
        preds = []
        posterior_predictive = list(self.posterior_predictive.values())
        # Treat an empty seed list like None, matching the truthiness test
        # used everywhere else in this method. The original checked
        # `is None`, so seeds=[] produced no predictions and crashed in
        # torch.stack.
        if not seeds:
            seeds = range(n_samples)
        # NOTE(review): each seed is used as an index into the stored
        # posterior networks, not as an RNG seed; confirm this is intended.
        for seed in seeds:
            net = posterior_predictive[seed]
            preds.append(net.forward(inputs))

    else:
        # Fail with a clear message rather than an UnboundLocalError on
        # `preds` further down.
        raise ValueError(
            "Unsupported inference method: {!r}".format(self.inference))

    output_probs = torch.stack(preds).mean(0)
    return output_probs
def get_count_matrix_from_encodings(z: np.ndarray,
                                    d: np.ndarray,
                                    p: Union[np.ndarray, None],
                                    model: VariationalInferenceModel,
                                    dataset_obj,
                                    cells_only: bool = True) -> sp.csc_matrix:
    """Make point estimate of the ambient-background-subtracted UMI count matrix.

    Sample counts by maximizing the model posterior based on learned latent
    variables.  The output matrix is in sparse form.

    Args:
        z: Latent variable embedding of gene expression in a low-dimensional
            space.
        d: Latent variable scale factor for the number of UMI counts coming
            from each real cell.
        p: Latent variable denoting probability that each barcode contains a
            real cell.  If None, every barcode is given probability 1.
        model: Model with latent variables already inferred.  Its `decoder`
            and `device` attributes are used.
        dataset_obj: Input dataset.
        cells_only: If True, only barcodes with p > 0.5 contribute counts.
            If False, all barcodes are processed, with `d` and `z` set to
            zero where p < 0.5.

    Returns:
        inferred_count_matrix: Sparse matrix with the shape of
            dataset_obj.data['matrix'], where the UMI counts have had
            ambient-background subtracted.

    Note:
        This currently uses the MAP estimate of draws from a Poisson (or a
        negative binomial with zero overdispersion).

        Fractional count estimates are rounded up or down at random using
        np.random.binomial, so the output depends on NumPy's global random
        state.

        The input arrays are modified in place: NaNs in `p` are overwritten
        with 0, and when `cells_only` is False the entries of `d` and rows of
        `z` with p < 0.5 are overwritten with 0.

    """

    # If simple model was used, then p = None.  Here set it to 1.
    if p is None:
        p = np.ones_like(d)

    # Get the count matrix with genes trimmed.  Only its number of rows is
    # used below.
    if cells_only:
        count_matrix = dataset_obj.get_count_matrix()
    else:
        count_matrix = dataset_obj.get_count_matrix_all_barcodes()

    logging.info("Getting ambient-background-subtracted UMI count matrix.")

    # Ensure there are no nans in p (there shouldn't be).  p_no_nans is an
    # alias of p, so this writes into the caller's array.
    p_no_nans = p
    p_no_nans[np.isnan(p)] = 0

    # Trim everything down to the barcodes we are interested in.
    if cells_only:
        is_cell = p_no_nans > 0.5
        d = d[is_cell]
        z = z[is_cell, :]
        barcode_inds = dataset_obj.analyzed_barcode_inds[is_cell]
    else:
        # Set cell size factors equal to zero where cell probability < 0.5.
        not_cell = p_no_nans < 0.5
        d[not_cell] = 0.
        z[not_cell, :] = 0.
        barcode_inds = np.arange(0, count_matrix.shape[0])  # All barcodes

    # Get mean of the inferred posterior for the overdispersion, phi.
    phi = pyro.get_param_store().get_param("phi_loc").detach().cpu().numpy().item()

    # Get the gene expression vectors by sending latent z through the decoder,
    # in chunks of `s` barcodes.  Results are kept as one compact array per
    # chunk: extending Python lists with individual NumPy scalars costs far
    # more memory per nonzero entry.
    barcode_chunks = []
    gene_chunks = []
    count_chunks = []
    s = 200
    for i in range(0, barcode_inds.size, s):

        last_ind_this_chunk = min(count_matrix.shape[0], i + s)

        # Decode gene expression for a chunk of barcodes.
        decoded = model.decoder(torch.Tensor(
            z[i:last_ind_this_chunk]).to(device=model.device))
        chi = decoded.detach().cpu().numpy()

        # Estimate counts for the chunk of barcodes.
        chunk_dense_counts = estimate_counts(chi,
                                             d[i:last_ind_this_chunk],
                                             phi)

        # Turn the floating point count estimates into integers: keep the
        # integer part and add 1 with probability equal to the fractional part.
        decimal_values, _ = np.modf(chunk_dense_counts)  # Stuff after decimal.
        roundoff_counts = np.random.binomial(1, p=decimal_values)  # Bernoulli.
        chunk_dense_counts = np.floor(chunk_dense_counts).astype(dtype=int)
        chunk_dense_counts += roundoff_counts

        # Find all the nonzero counts in this dense matrix chunk.
        nonzero_barcode_inds_this_chunk, nonzero_genes_trimmed = \
            np.nonzero(chunk_dense_counts)
        nonzero_counts = chunk_dense_counts[nonzero_barcode_inds_this_chunk,
                                            nonzero_genes_trimmed]

        # Get the original gene index from gene index in the trimmed dataset.
        nonzero_genes = dataset_obj.analyzed_gene_inds[nonzero_genes_trimmed]

        # Get the actual barcode values.
        nonzero_barcodes = barcode_inds[nonzero_barcode_inds_this_chunk + i]

        # Gene indices are stored as uint32: a uint16 cast would silently
        # wrap for any index above 65535.
        barcode_chunks.append(nonzero_barcodes.astype(dtype=np.uint32))
        gene_chunks.append(nonzero_genes.astype(dtype=np.uint32))
        count_chunks.append(nonzero_counts.astype(dtype=np.uint32))

    # Join the per-chunk arrays (np.concatenate rejects an empty list).
    if count_chunks:
        counts = np.concatenate(count_chunks)
        barcodes = np.concatenate(barcode_chunks)
        genes = np.concatenate(gene_chunks)
    else:
        counts = np.array([], dtype=np.uint32)
        barcodes = np.array([], dtype=np.uint32)
        genes = np.array([], dtype=np.uint32)

    # Put the counts into a sparse csc_matrix.
    inferred_count_matrix = sp.csc_matrix((counts, (barcodes, genes)),
                                          shape=dataset_obj.data['matrix'].shape)

    return inferred_count_matrix
def get_param(name):
    """Return the entry stored under ``name`` in Pyro's parameter store.

    :param name: key used to index the parameter store.
    :returns: whatever the store's item lookup yields for ``name``.
    """
    param_store = pyro.get_param_store()
    return param_store[name]
def loss_and_grads(self, model, guide, *args, **kwargs):
    """
    :returns: returns an estimate of the ELBO
    :rtype: float

    Accumulates, over every trace yielded by ``self._get_traces``, both the
    ELBO and a surrogate ELBO whose gradient is used as the estimator.
    ``torch_backward`` is called on the negated surrogate when at least one
    ``param`` site was seen, and those parameters are marked active in the
    param store. The value returned is the negated (non-surrogate) ELBO.
    """
    total_elbo = 0.0
    total_surrogate = 0.0
    params_to_train = set()

    traces = self._get_traces(model, guide, *args, **kwargs)
    for weight, model_trace, guide_trace, log_r in traces:
        # Start both accumulators from zeros shaped like ``weight``.
        particle = weight * 0
        surrogate_particle = weight * 0

        # Batched log-densities are used only under discrete enumeration
        # with more than one weight entry.
        if self.enum_discrete and weight.size(0) > 1:
            log_pdf = "batch_log_pdf"
        else:
            log_pdf = "log_pdf"

        for name in model_trace.nodes.keys():
            model_site = model_trace.nodes[name]
            if model_site["type"] != "sample":
                continue

            if model_site["is_observed"]:
                # Observed sites contribute log p only.
                particle += model_site[log_pdf]
                surrogate_particle += model_site[log_pdf]
                continue

            guide_site = guide_trace.nodes[name]
            lp_lq = model_site[log_pdf] - guide_site[log_pdf]
            particle += lp_lq
            if guide_site["fn"].reparameterized:
                surrogate_particle += lp_lq
            else:
                # XXX should the user be able to control inclusion of the -logq term below?
                surrogate_particle += model_site[log_pdf] + \
                    log_r.detach() * guide_site[log_pdf]

        # Drop terms of weight zero to avoid nans.
        if not isinstance(weight, numbers.Number):
            zero_mask = (weight == 0)
            particle[zero_mask] = 0.0
            surrogate_particle[zero_mask] = 0.0
        elif weight == 0.0:
            particle = torch_zeros_like(particle)
            surrogate_particle = torch_zeros_like(surrogate_particle)

        total_elbo += torch_data_sum(weight * particle)
        total_surrogate += torch_sum(weight * surrogate_particle)

        # Collect the values of all ``param`` sites, model first, then guide.
        for trace in (model_trace, guide_trace):
            for name in trace.nodes.keys():
                site = trace.nodes[name]
                if site["type"] == "param":
                    params_to_train.add(site["value"])

    loss = -total_elbo
    surrogate_loss = -total_surrogate

    if params_to_train:
        torch_backward(surrogate_loss)
        pyro.get_param_store().mark_params_active(params_to_train)

    return loss
def loss_and_grads(model, guide, *args, **kwargs):
    """Evaluate ``self._loss``, backpropagate it, and return the loss.

    ``self`` is not a parameter here; it is resolved from the enclosing
    scope. After the backward pass every parameter name currently in the
    param store is marked active.
    """
    loss_value = self._loss(model, guide, *args, **kwargs)
    loss_value.backward()

    param_store = pyro.get_param_store()
    all_names = pyro.get_param_store().get_all_param_names()
    param_store.mark_params_active(all_names)

    return loss_value