def history_policy_model(state, history, t, discount=1.0, discount_factor=0.95, max_depth=10):
    """Sample an action at step ``t`` conditioned on ``history``.

    ``history`` is a string; it is used as the scope prefix for every
    sample site created by this call.

    Past the depth limit (``t > max_depth``) a uniform categorical draw is
    returned directly (the sampled index, not an element of ``actions``).
    Otherwise each action is scored with ``exp(history_value_model(...))``,
    the scores are passed through ``remap(w, min, max, 0., 1.)`` and an
    element of ``actions`` is drawn from the resulting categorical.
    """
    with scope(prefix=history):
        site_name = "a%d" % t
        num_actions = len(actions)

        if t > max_depth:
            return pyro.sample(site_name,
                               dist.Categorical(torch.ones(num_actions)))

        scores = torch.zeros(num_actions)
        for idx, candidate in enumerate(actions):
            # Give each candidate's rollout its own sample-site namespace.
            with scope(prefix="%s%d" % (candidate, t)):
                scores[idx] = torch.exp(
                    history_value_model(state, candidate, history, t,
                                        discount=discount,
                                        discount_factor=discount_factor,
                                        max_depth=max_depth))

        # Rescale the scores using their own min/max as the input range.
        lowest = torch.min(scores)
        highest = torch.max(scores)
        rescaled = tensor([remap(score, lowest, highest, 0., 1.)
                           for score in scores])
        return actions[pyro.sample(site_name, dist.Categorical(rescaled))]
def policy_model(state, t, discount=1.0, discount_factor=0.95, max_depth=10, alpha=0.1):
    """Sample an action for ``state`` at step ``t`` (models Pr(a|s)).

    Each action is scored with ``exp(alpha * value_model(...))``; the scores
    are then passed through ``remap(w, min, max, 0., 1.)`` and an element of
    ``actions`` is drawn from the resulting categorical.

    Past the depth limit (``t > max_depth``) a uniform categorical over all
    actions is sampled instead.
    NOTE(review): that branch returns the sampled index, not an element of
    ``actions`` -- confirm callers expect this.
    """
    site_name = "a%d" % t
    num_actions = len(actions)

    if t > max_depth:
        # Uniform prior sized from `actions` rather than hard-coded to three
        # entries, so it stays valid if the action set changes.
        return pyro.sample(site_name,
                           dist.Categorical(torch.ones(num_actions)))

    action_weights = torch.zeros(num_actions)
    for i, action in enumerate(actions):
        # Give each candidate's rollout its own sample-site namespace.
        with scope(prefix="%s%d" % (action, t)):
            value = value_model(state, action, t,
                                discount=discount,
                                discount_factor=discount_factor,
                                max_depth=max_depth)
            action_weights[i] = torch.exp(alpha * value)

    # Rescale the weights using their own min/max as the input range.
    min_weight = torch.min(action_weights)
    max_weight = torch.max(action_weights)
    action_weights = tensor([remap(weight, min_weight, max_weight, 0., 1.)
                             for weight in action_weights])
    return actions[pyro.sample(site_name, dist.Categorical(action_weights))]
def model1(r=True):
    """Call ``model2`` at several scope depths, recursing once when ``r``.

    ``model2`` is invoked directly, then again under the ``"inter"`` prefix;
    when ``r`` is true the function also calls itself once with ``r=False``.
    A final ``model2`` call follows.
    """
    model2()
    with scope(prefix="inter"):
        model2()
        if r:
            # Single level of recursion: the inner call passes r=False.
            # NOTE(review): assumed to sit inside the "inter" scope -- confirm.
            model1(r=False)
    model2()
def __call__(self, *args, **kwargs):
    """Call the parent ``__call__`` with stochastic specs filled in.

    Every entry of ``self.stochastic_specs`` whose name is not already in
    ``kwargs`` is forwarded as a keyword argument; ``AbstractSampler``
    instances are called first and their result is forwarded, other specs
    are forwarded unchanged. Everything runs under
    ``scope(prefix=self.stochastic_name)`` when that name is truthy.
    """
    if self.stochastic_name:
        naming_ctx = scope(prefix=self.stochastic_name)
    else:
        naming_ctx = nullcontext()

    with naming_ctx:
        parent_call = super(type(self), self).__call__
        drawn = {}
        for name, spec in self.stochastic_specs.items():
            if name in kwargs:
                # Explicit keyword arguments take precedence over the specs.
                continue
            drawn[name] = spec() if isinstance(spec, AbstractSampler) else spec
        return parent_call(*args, **kwargs, **drawn)
def __call__(self, **kwargs):
    """Call ``cls.__call__`` with registered variables injected.

    When the instance has no name (``_pyrofit_instance_name is None``) the
    call is forwarded to ``cls.__call__`` unchanged. Otherwise every entry of
    ``VARIABLES[<instance name>]`` is called and its result overrides the
    keyword argument of the same name, and ``cls.__call__`` runs wrapped in
    ``scope`` with the instance name as prefix.

    :returns: whatever ``cls.__call__`` returns.
    """
    instance_name = self._pyrofit_instance_name
    if instance_name is None:
        # Return here: without it the unnamed case fell through to the
        # VARIABLES lookup below and invoked cls.__call__ a second time.
        return cls.__call__(self, **kwargs)

    kwargs.update({key: val()
                   for key, val in VARIABLES[instance_name].items()})
    return scope(cls.__call__, prefix=instance_name)(self, **kwargs)
def test_only_withs():
    """Nested ``with scope(...)`` prefixes compose into a ``/``-joined name."""

    def model1():
        with scope(prefix="a"), scope(prefix="b"):
            pyro.sample("x", dist.Bernoulli(0.5))

    # Bare model: only the two context-manager prefixes apply.
    bare_trace = poutine.trace(name_count(model1)).get_trace()
    assert "a/b/x" in bare_trace.nodes

    # Wrapping the model in an outer scope prepends that prefix as well.
    outer_scoped = scope(prefix="model1")(model1)
    scoped_trace = poutine.trace(name_count(outer_scoped)).get_trace()
    assert "model1/a/b/x" in scoped_trace.nodes
def history_policy_model_guide(state, history, t, discount=1.0, discount_factor=0.95, max_depth=10):
    """Guide for the history-based policy: draw an action from learned weights.

    Runs under ``scope(prefix=history)``. ``history_value_model`` is invoked
    once per action under the same ``"<action><t>"`` prefixes, and its result
    is discarded. The returned action is drawn from a categorical over the
    ``"action_weights"`` parameter (simplex-constrained, initialised to ones).
    """
    with scope(prefix=history):
        learned_weights = pyro.param("action_weights",
                                     torch.ones(len(actions)),
                                     constraint=dist.constraints.simplex)
        for candidate in actions:
            with scope(prefix="%s%d" % (candidate, t)):
                # Called for its sample sites only; the value is unused.
                history_value_model(state, candidate, history, t,
                                    discount=discount,
                                    discount_factor=discount_factor,
                                    max_depth=max_depth)
        choice = pyro.sample("a%d" % t, dist.Categorical(learned_weights))
        return actions[choice]
def belief_policy_model_guide(belief, t, discount=1.0, discount_factor=0.95, max_depth=10, bu_nsteps=10, bu_lr=0.1):
    """Guide for the belief-based policy.

    ``belief_value_model`` is invoked once per action under the
    ``"<action><t>"`` prefix, and its result is discarded. The action site
    ``"a<t>"`` is then sampled from a categorical over the
    ``"action_weights"`` parameter (simplex-constrained, initialised to ones).
    Nothing is returned.
    """
    learned_weights = pyro.param("action_weights",
                                 torch.ones(len(actions)),
                                 constraint=dist.constraints.simplex)
    for candidate in actions:
        with scope(prefix="%s%d" % (candidate, t)):
            # Called for its sample sites only; the value is unused.
            belief_value_model(belief, candidate, t,
                               discount=discount,
                               discount_factor=discount_factor,
                               max_depth=max_depth,
                               bu_nsteps=bu_nsteps,
                               bu_lr=bu_lr)
    pyro.sample("a%d" % t, dist.Categorical(learned_weights))
def __init__(self, name, **kwargs):
    """Initialise the wrapped class, optionally under a named scope.

    :param name: instance name, stored as ``_pyrofit_instance_name``. When it
        is ``None`` the wrapped ``cls.__init__`` is called with ``kwargs``
        unchanged. Otherwise ``SETTINGS[name]`` overrides entries of
        ``kwargs`` and ``cls.__init__`` runs wrapped in ``scope`` with
        ``name`` as prefix.
    :raises TypeError: re-raised from the scoped ``cls.__init__`` with
        ``" [<name>]"`` appended to the message.
    """
    self._pyrofit_instance_name = name
    if name is None:
        # Stop here: without it the unnamed case fell through to the SETTINGS
        # lookup below and initialised the instance a second time.
        cls.__init__(self, **kwargs)
        return

    kwargs.update(SETTINGS[name])
    try:
        scope(cls.__init__, prefix=name)(self, **kwargs)
    except TypeError as e:
        # Keep the original exception chained so its traceback is not lost.
        raise TypeError(str(e) + f" [{name}]") from e
def sample_action(state, player, max_depth=0, t=0):
    """Sample an action for ``player`` from exponentiated expected rewards.

    For each of the nine locations an ``Action`` is built, its outcome is
    simulated with ``sample_outcome`` under a per-action scope, and the
    weight for that location is ``exp(reward_probability(outcome, player))``.
    The chosen location is drawn from a categorical over those weights.
    """
    num_locations = 9
    weights = torch.zeros(num_locations)
    for loc in range(num_locations):
        candidate = Action(player, loc)
        with scope(prefix="a{}{}{}".format(state, t, candidate)):
            result = sample_outcome(state, candidate, player,
                                    max_depth=max_depth, t=t)
            weights[loc] = np.exp(reward_probability(result, player))

    chosen = pyro.sample("a{}{}".format(state, t),
                         pyro.distributions.Categorical(weights))
    return Action(player, chosen.item())
def do_sampling():
    """Grow ``tree`` breadth-first from a freshly created root node.

    The root is built from ``self.root_node_type`` and added to ``tree``.
    Nodes are then expanded in FIFO order: each one samples its children
    under a scope named after it, and every child gets its parameters set,
    is added to ``tree`` with an edge from its parent, and is queued.
    """
    root_node = self.root_node_type(tf=self.root_node_tf)
    self._set_node_parameters(root_node, detach=detach)
    tree.add_node(root_node)

    pending = [root_node]
    while pending:
        current = pending.pop(0)
        # Prefix the node's sample sites with its own name.
        with scope(prefix=current.name):
            new_children = current.sample_children()
        for new_child in new_children:
            self._set_node_parameters(new_child, detach=detach)
            tree.add_node(new_child)
            tree.add_edge(current, new_child)
            pending.append(new_child)
def _register_param(self, name):
    """
    Computes the current value of parameter ``name`` and registers it as a
    buffer under the same name.

    If a prior is set for ``name``, the value is drawn inside an autoname
    scope prefixed with ``self._get_name()``: in "model" mode it is sampled
    from the prior, otherwise it comes from ``self._sample_from_guide``.
    If only a constraint is set, the stored ``"<name>_unconstrained"``
    parameter is mapped through ``transform_to(constraint)``.

    :param str name: Name of the parameter.
    """
    if name in self._priors:
        with autoname.scope(prefix=self._get_name()):
            p = (
                pyro.sample(name, self._priors[name])
                if self.mode == "model"
                else self._sample_from_guide(name)
            )
    elif name in self._constraints:
        raw_value = self._parameters["{}_unconstrained".format(name)]
        to_constrained = transform_to(self._constraints[name])
        p = to_constrained(raw_value)
    self.register_buffer(name, p)
def sample_action_guide(state, player, max_depth=0, t=0):
    """Guide counterpart of ``sample_action``: draw an action from learned weights.

    The location is sampled from a categorical over a positive-constrained
    parameter named ``str(state) + player + "params"`` (initialised to ones).
    ``sample_outcome`` is then invoked for each of the nine candidate actions
    under a per-action scope, and its result is discarded.

    :returns: the ``Action`` built from the sampled location.
    """
    params = pyro.param(str(state) + player + "params",
                        torch.ones(9),
                        torch.distributions.constraints.positive)
    location = pyro.sample("a{}{}".format(state, t),
                           pyro.distributions.Categorical(params))
    action = Action(player, location.item())

    for i in range(9):
        # Use a separate name for the loop candidate: reusing `action` here
        # overwrote the sampled action, so Action(player, 8) was always
        # returned.
        candidate = Action(player, i)
        with scope(prefix="a{}{}{}".format(state, t, candidate)):
            sample_outcome(state, candidate, player,
                           max_depth=max_depth, t=t)
    return action
def policy_model_guide(state, t, discount=1.0, discount_factor=0.95, max_depth=10):
    """Guide for the state-based policy.

    ``value_model`` is invoked once per action under the ``"<action><t>"``
    prefix, and its result is discarded. The action site ``"a<t>"`` is then
    sampled from a categorical over the ``"action_weights"`` parameter
    (simplex-constrained, initialised to ones). Nothing is returned.
    """
    # Starting point for the learned weights is uniform (all ones).
    learned_weights = pyro.param("action_weights",
                                 torch.ones(len(actions)),
                                 constraint=dist.constraints.simplex)
    for candidate in actions:
        with scope(prefix="%s%d" % (candidate, t)):
            # Called for its sample sites only; the value is unused.
            value_model(state, candidate, t,
                        discount=discount,
                        discount_factor=discount_factor,
                        max_depth=max_depth)
    pyro.sample("a%d" % t, dist.Categorical(learned_weights))
def sample_outcome(state, action, player, max_depth=0, t=0):
    """Simulate the game after ``action`` and return its outcome.

    ``action`` is applied to a deep copy of ``state``, so the caller's state
    is not modified. If the resulting state already has an outcome, it is
    returned. Otherwise, once ``t >= max_depth``, ``"o"`` is returned; below
    that depth the other player's move is sampled with ``sample_action`` and
    the simulation recurses with ``t + 1``.
    """
    next_state = copy.deepcopy(state)
    apply_transistion(next_state, action)
    if next_state.outcome is not None:
        return next_state.outcome

    # Just assume we've lost if we haven't won yet. This is not very
    # sophisticated and restricts us to working with x as the main player.
    if t >= max_depth:
        return "o"

    with scope(prefix="o{}{}".format(next_state, t)):
        opponent = other_player(player)
        opponent_action = sample_action(next_state, opponent,
                                        max_depth=max_depth, t=t + 1)
        return sample_outcome(next_state, opponent_action, opponent,
                              max_depth=max_depth, t=t + 1)
def belief_policy_model(belief, t, discount=1.0, discount_factor=0.95, max_depth=10, bu_nsteps=10, bu_lr=0.1):
    """Sample an action at step ``t`` given the current ``belief``.

    Each action is scored with ``exp(belief_value_model(...))``; the scores
    are then passed through ``remap(w, min, max, 0., 1.)`` and an element of
    ``actions`` is drawn from the resulting categorical.

    Past the depth limit (``t > max_depth``) a uniform categorical over all
    actions is sampled instead.
    NOTE(review): that branch returns the sampled index, not an element of
    ``actions`` -- confirm callers expect this.
    """
    site_name = "a%d" % t
    num_actions = len(actions)

    if t > max_depth:
        # Uniform prior sized from `actions` rather than hard-coded to three
        # entries, so it stays valid if the action set changes.
        return pyro.sample(site_name,
                           dist.Categorical(torch.ones(num_actions)))

    action_weights = torch.zeros(num_actions)
    for i, action in enumerate(actions):
        # Give each candidate's rollout its own sample-site namespace.
        with scope(prefix="%s%d" % (action, t)):
            value = belief_value_model(belief, action, t,
                                       discount=discount,
                                       discount_factor=discount_factor,
                                       max_depth=max_depth,
                                       bu_nsteps=bu_nsteps,
                                       bu_lr=bu_lr)
            action_weights[i] = torch.exp(value)

    # Rescale the weights using their own min/max as the input range.
    min_weight = torch.min(action_weights)
    max_weight = torch.max(action_weights)
    action_weights = tensor([remap(weight, min_weight, max_weight, 0., 1.)
                             for weight in action_weights])
    return actions[pyro.sample(site_name, dist.Categorical(action_weights))]
def model():
    """Resample the tree's continuous variables and add soft constraints.

    Walks the scene tree breadth-first from ``root``, letting each rule
    sample its child under a scope named after the parent. When
    ``fix_observeds`` is set, the translation and rotation of every observed
    node are observed under Normals centred on the matching node of
    ``original_tree``. Finally, each constraint's violation, clamped to be
    non-negative, is observed under a ``Normal(0, 0.001)``.
    """
    # Breadth-first pass over the tree: resample continuous structure.
    pending = [root]
    while pending:
        parent_node = pending.pop(0)
        kids, kid_rules = scene_tree.get_children_and_rules(parent_node)
        for kid, kid_rule in zip(kids, kid_rules):
            with scope(prefix=parent_node.name):
                kid_rule.sample_child(parent_node, kid)
            pending.append(kid)

    # Tie observed nodes to their poses in the original tree.
    if fix_observeds:
        xyz_var = 1E-2
        rot_var = 1E-2
        for current, reference in zip(scene_tree.nodes, original_tree.nodes):
            if not current.observed:
                continue
            xyz_dist = dist.Normal(reference.translation, xyz_var)
            rot_dist = dist.Normal(reference.rotation, rot_var)
            pyro.sample("%s_xyz_observed" % current.name, xyz_dist,
                        obs=current.translation)
            pyro.sample("%s_rotation_observed" % current.name, rot_dist,
                        obs=current.rotation)

    # Penalise positive constraint violations.
    for index, item in enumerate(constraints):
        error_dist = dist.Normal(0., 0.001)
        violation, _, _ = item.eval_violation(scene_tree)
        clamped = torch.clamp(violation, 0., np.inf)
        pyro.sample("%s_%d_err" % (type(item).__name__, index),
                    error_dist, obs=clamped)
def _reg_fn(fn):
    """Wrap ``fn`` so its sample sites are prefixed and its variables injected.

    ``fn`` is wrapped in ``scope`` with its ``__qualname__`` as prefix. The
    returned wrapper accepts keyword arguments only, checks them against the
    parsed signature, checks the registered ``VARIABLES[name]`` against the
    signature's yaml entries, then calls every registered variable and passes
    the results to ``fn`` (overriding keyword arguments of the same name).

    :returns: the keyword-only wrapper function.
    :raises AssertionError: from the wrapper, when the keyword arguments or
        the registered variables do not match the signature.
    """
    # Prefix sample sites
    name = fn.__qualname__
    fn = scope(fn, prefix=name)

    # Inspect function signature
    sig = _parse_signature(fn)

    def wrapped_fn(**kwargs):
        # Checking consistency of kwargs
        assert sorted(kwargs) == sig['params'], (
            " '%s': keyword arguments %s expected, but %s given"
            % (name, str(sig['params']), str(list(kwargs))))
        # Report the same list that is compared (sig['yaml']); the message
        # previously read sig['yaml_var'], a different key from the check.
        assert sorted(VARIABLES[name]) == sig['yaml'], (
            " '%s': yaml variables %s expected, but %s given"
            % (name, str(sig['yaml']), str(VARIABLES[name])))

        # Update kwargs and run wrapped function
        kwargs.update({key: val() for key, val in VARIABLES[name].items()})
        return fn(**kwargs)

    return wrapped_fn
def model(self, x, zs):
    """Generative model for the image and the per-spot expression counts.

    Decodes ``zs``, computes the metagene mixing tensor ``rim`` (softmax over
    per-metagene decoders, multiplied by a decoded ``scale``), samples the
    metagene and effect parameters, samples the image via
    ``self._sample_image`` and observes the concatenated count data at site
    ``"xsg"`` under a ``NegativeBinomial``.

    Reads ``x["data"]``, ``x["label"]`` and ``x["effects"]``.

    :returns: tuple ``(image_distr, expression_distr)``.
    """
    # pylint: disable=too-many-locals
    def _compute_rim(decoded):
        # Shared 1x1 conv / batch-norm / leaky-ReLU block, fetched (or
        # created by the lambda) under the name "metagene_shared".
        shared_representation = get_module(
            "metagene_shared",
            lambda: torch.nn.Sequential(
                torch.nn.Conv2d(
                    decoded.shape[1], decoded.shape[1], kernel_size=1),
                torch.nn.BatchNorm2d(decoded.shape[1], momentum=0.05),
                torch.nn.LeakyReLU(0.2, inplace=True),
            ),
        )(decoded)
        # One decoder module per metagene, concatenated along dim 1.
        rim = torch.cat(
            [
                get_module(
                    f"decoder_{_encode_metagene_name(n)}",
                    partial(self._create_metagene_decoder,
                            decoded.shape[1], n),
                )(shared_representation)
                for n in self.metagenes
            ],
            dim=1,
        )
        # Normalise across metagenes.
        rim = torch.nn.functional.softmax(rim, dim=1)
        return rim

    num_genes = x["data"][0].shape[1]
    decoded = self._decode(zs)
    # Crop the label map to the last two dims of the decoded tensor.
    label = center_crop(x["label"], [None, *decoded.shape[-2:]])

    rim = checkpoint(_compute_rim, decoded)
    rim = center_crop(rim, [None, None, *label.shape[-2:]])
    # Delta sites: deterministic values registered under a site name.
    rim = p.sample("rim", Delta(rim))

    scale = p.sample(
        "scale",
        Delta(
            center_crop(
                self._get_scale_decoder(decoded.shape[1])(decoded),
                [None, None, *label.shape[-2:]],
            )),
    )
    rim = scale * rim

    # NOTE(review): the sites below are assumed to all sit inside this
    # poutine.scale block (scaled by batch size over self.n) -- confirm.
    with p.poutine.scale(scale=len(x["data"]) / self.n):
        rate_mg_prior = Normal(
            0.0,
            1e-8 + get_param(
                "rate_mg_sd",
                lambda: torch.ones(num_genes),
                constraint=constraints.positive,
            ),
        )
        # One sample site per metagene, stacked into a single tensor.
        rate_mg = torch.stack([
            p.sample(_encode_metagene_name(n), rate_mg_prior)
            for n in self.metagenes
        ])
        rate_mg = p.sample("rate_mg", Delta(rate_mg))

        rate_g_effects_baseline = get_param(
            "rate_g_effects_baseline",
            lambda: self.__init_rate_baseline().log(),
            lr_multiplier=5.0,
        )
        logits_g_effects_baseline = get_param(
            "logits_g_effects_baseline",
            # pylint: disable=unnecessary-lambda
            self.__init_logits_baseline,
            lr_multiplier=5.0,
        )

        rate_g_effects_prior = Normal(
            0.0,
            1e-8 + get_param(
                "rate_g_effects_sd",
                lambda: torch.ones(num_genes),
                constraint=constraints.positive,
            ),
        )
        rate_g_effects = p.sample("rate_g_effects", rate_g_effects_prior)
        # The baseline becomes the first row, ahead of the sampled effects.
        rate_g_effects = torch.cat(
            [rate_g_effects_baseline.unsqueeze(0), rate_g_effects])

        logits_g_effects_prior = Normal(
            0.0,
            1e-8 + get_param(
                "logits_g_effects_sd",
                lambda: torch.ones(num_genes),
                constraint=constraints.positive,
            ),
        )
        logits_g_effects = p.sample(
            "logits_g_effects",
            logits_g_effects_prior,
        )
        logits_g_effects = torch.cat(
            [logits_g_effects_baseline.unsqueeze(0), logits_g_effects])

    # Prepend a column of ones so the baseline row above always contributes.
    effects = torch.cat(
        [
            torch.ones(x["effects"].shape[0], 1).to(x["effects"]),
            x["effects"],
        ],
        1,
    ).float()

    logits_g = effects @ logits_g_effects
    rate_g = effects @ rate_g_effects
    rate_mg = rate_g[:, None] + rate_mg

    # NOTE(review): everything from here to the "xsg" site is assumed to run
    # inside the experiment's tag scope -- confirm against the guide.
    with scope(prefix=self.tag):
        image_distr = self._sample_image(x, decoded)

        def _compute_sample_params(data, label, rim, rate_mg, logits_g):
            # Label 0 is treated as missing.
            nonmissing = label != 0
            # Labels of rows whose total count is zero (label = row + 1,
            # matching the `- 1` applied to the labels below).
            zero_count_spots = 1 + torch.where(data.sum(1) == 0)[0]
            nonpartial = binary_fill_holes(
                np.isin(label.cpu(), [0, *zero_count_spots.cpu()]))
            nonpartial = torch.as_tensor(nonpartial).to(nonmissing)
            mask = nonpartial & nonmissing

            if not mask.any():
                # Nothing selected: return empty tensors for this sample.
                return (
                    data[[]],
                    torch.zeros(0, num_genes).to(rim),
                    logits_g.expand(0, -1),
                )

            label = label[mask] - 1
            # Keep only the data rows whose label occurs under the mask and
            # re-index the labels to 0..k-1.
            idxs, label = torch.unique(label, return_inverse=True)
            data = data[idxs]

            rim = rim[:, mask]
            # Sum the masked pixels' rim values per label.
            labelonehot = sparseonehot(label)
            rim = torch.sparse.mm(labelonehot.t().float(), rim.t())
            rgs = rim @ rate_mg.exp()

            return data, rgs, logits_g.expand(len(rgs), -1)

        # Apply the helper per sample and regroup the outputs by kind.
        data, rgs, logits_g = zip(*it.starmap(
            _compute_sample_params,
            zip(x["data"], label, rim, rate_mg, logits_g),
        ))
        expression_distr = NegativeBinomial(
            total_count=1e-8 + torch.cat(rgs),
            logits=torch.cat(logits_g),
        )
        p.sample("xsg", expression_distr, obs=torch.cat(data))

    return image_distr, expression_distr
def model1():
    """Sample Bernoulli(0.5) site ``"x"`` under nested prefixes ``"a"`` then ``"b"``."""
    with scope(prefix="a"), scope(prefix="b"):
        pyro.sample("x", dist.Bernoulli(0.5))
def model(self, x, zs):
    """Generative model for the image and the per-slide expression counts.

    Decodes ``zs``, computes the metagene mixing tensor ``rim`` (softmax over
    per-metagene decoders, multiplied by a decoded ``scale``), samples the
    metagene rates and the per-covariate condition effects, samples the image
    via ``self._sample_image`` and, for every batch element with a non-empty
    mask, observes its count data at site ``"xsg-<i>"`` under a
    ``NegativeBinomial``.

    Reads ``x["data"]``, ``x["label"]``, ``x["slide"]`` and
    ``x["covariates"]``. Returns nothing.
    """
    # pylint: disable=too-many-locals, too-many-statements
    dataset = require("dataloader").dataset

    def _compute_rim(decoded):
        # Shared 1x1 conv / batch-norm / leaky-ReLU block, fetched (or
        # created by the lambda) under the name "metagene_shared".
        shared_representation = get_module(
            "metagene_shared",
            lambda: torch.nn.Sequential(
                torch.nn.Conv2d(
                    decoded.shape[1], decoded.shape[1], kernel_size=1
                ),
                torch.nn.BatchNorm2d(decoded.shape[1], momentum=0.05),
                torch.nn.LeakyReLU(0.2, inplace=True),
            ),
        )(decoded)
        # One decoder module per metagene, concatenated along dim 1.
        rim = torch.cat(
            [
                get_module(
                    f"decoder_{_encode_metagene_name(n)}",
                    partial(
                        self._create_metagene_decoder, decoded.shape[1], n
                    ),
                )(shared_representation)
                for n in self.metagenes
            ],
            dim=1,
        )
        # Normalise across metagenes.
        rim = torch.nn.functional.softmax(rim, dim=1)
        return rim

    decoded = self._decode(zs)
    # Crop the label map to the last two dims of the decoded tensor.
    label = center_crop(x["label"], [None, *decoded.shape[-2:]])

    rim = checkpoint(_compute_rim, decoded)
    rim = center_crop(rim, [None, None, *label.shape[-2:]])
    # Delta sites: deterministic values registered under a site name.
    rim = pyro.sample("rim", Delta(rim))

    scale = pyro.sample(
        "scale",
        Delta(
            center_crop(
                self._get_scale_decoder(decoded.shape[1])(decoded),
                [None, None, *label.shape[-2:]],
            )
        ),
    )
    rim = scale * rim

    rate_mg_prior = Normal(
        0.0,
        1e-8
        + get_param(
            "rate_mg_prior_sd",
            lambda: torch.ones(len(self._allocated_genes)),
            constraint=constraints.positive,
        ),
    )
    # Scale the metagene sites by batch size over dataset size.
    with pyro.poutine.scale(scale=len(x["data"]) / dataset.size()):
        rate_mg = torch.stack(
            [
                pyro.sample(
                    _encode_metagene_name(n),
                    rate_mg_prior,
                    infer={"is_global": True},
                )
                for n in self.metagenes
            ]
        )
    rate_mg = pyro.sample("rate_mg", Delta(rate_mg))

    rate_g_conditions_prior = Normal(
        0.0,
        1e-8
        + get_param(
            "rate_g_conditions_prior_sd",
            lambda: torch.ones(len(self._allocated_genes)),
            constraint=constraints.positive,
        ),
    )
    logits_g_conditions_prior = Normal(
        0.0,
        1e-8
        + get_param(
            "logits_g_conditions_prior_sd",
            lambda: torch.ones(len(self._allocated_genes)),
            constraint=constraints.positive,
        ),
    )

    # Per batch element: start from the baseline parameters and add one
    # sampled effect per known covariate.
    rate_g, logits_g = [], []
    for batch_idx, (slide, covariates) in enumerate(
        zip(x["slide"], x["covariates"])
    ):
        rate_g_slide = get_param(
            "rate_g_condition_baseline",
            lambda: self.__init_rate_baseline().log(),
            lr_multiplier=5.0,
        )
        logits_g_slide = get_param(
            "logits_g_condition_baseline",
            self.__init_logits_baseline,
            lr_multiplier=5.0,
        )

        for covariate, condition in covariates.items():
            try:
                conditions = get("covariates")[covariate]
            except KeyError:
                # Covariate is not registered: skip it.
                continue

            if pd.isna(condition):
                with pyro.poutine.scale(
                    scale=1.0 / dataset.size(slide=slide)
                ):
                    pyro.sample(
                        f"condition-{covariate}-{batch_idx}",
                        OneHotCategorical(
                            to_device(torch.ones(len(conditions)))
                            / len(conditions)
                        ),
                        infer={"is_global": True},
                    )
                    # ^ NOTE 1: This statement affects the ELBO but not its
                    #           gradient. The pmf is non-differentiable but
                    #           it doesn't matter---our prior over the
                    #           conditions is uniform; even if a gradient
                    #           existed, it would always be zero.
                    # ^ NOTE 2: The result is used to index the effect of
                    #           the condition. However, this takes place in
                    #           the guide to avoid sampling effects that are
                    #           not used in the current minibatch,
                    #           potentially (?) reducing noise in the
                    #           learning signal. Therefore, the result here
                    #           is discarded.
                condition_scale = 1e-99
                # ^ HACK: Pyro requires scale > 0
            else:
                condition_scale = 1.0 / dataset.size(
                    covariate=covariate, condition=condition
                )

            with pyro.poutine.scale(scale=condition_scale):
                rate_g_slide = rate_g_slide + pyro.sample(
                    f"rate_g_condition-{covariate}-{batch_idx}",
                    rate_g_conditions_prior,
                    infer={"is_global": True},
                )
                logits_g_slide = logits_g_slide + pyro.sample(
                    f"logits_g_condition-{covariate}-{batch_idx}",
                    logits_g_conditions_prior,
                    infer={"is_global": True},
                )

        rate_g.append(rate_g_slide)
        logits_g.append(logits_g_slide)

    # Select the columns listed in self._gene_indices.
    logits_g = torch.stack(logits_g)[:, self._gene_indices]
    rate_g = torch.stack(rate_g)[:, self._gene_indices]
    rate_mg = rate_g.unsqueeze(1) + rate_mg[:, self._gene_indices]

    # NOTE(review): the per-slide loop is assumed to run inside the
    # experiment's tag scope (sites named "<tag>/idx-<i>", "<tag>/xsg-<i>")
    # -- confirm against the guide.
    with scope(prefix=self.tag):
        self._sample_image(x, decoded)

        # The loop variables shadow the batched tensors of the same names;
        # the zip arguments are evaluated once, before the first iteration.
        for i, (data, label, rim, rate_mg, logits_g) in enumerate(
            zip(x["data"], label, rim, rate_mg, logits_g)
        ):
            # Labels of rows whose total count is zero (label = row + 1,
            # matching the `idxs - 1` indexing below).
            zero_count_idxs = 1 + torch.where(data.sum(1) == 0)[0]
            # Labels that touch the border of the label map, except the
            # zero-count ones.
            partial_idxs = np.unique(
                torch.cat([label[0], label[-1], label[:, 0], label[:, -1]])
                .cpu()
                .numpy()
            )
            partial_idxs = np.setdiff1d(
                partial_idxs, zero_count_idxs.cpu().numpy()
            )
            # Keep pixels whose label is neither 0 nor a border label.
            mask = np.invert(
                np.isin(label.cpu().numpy(), [0, *partial_idxs])
            )
            mask = torch.as_tensor(mask, device=label.device)

            if not mask.any():
                continue

            label = label[mask]
            idxs, label = torch.unique(label, return_inverse=True)
            data = data[idxs - 1]
            # Record which labels were kept for this batch element.
            pyro.sample(f"idx-{i}", Delta(idxs.float()))

            rim = rim[:, mask]
            # Sum the masked pixels' rim values per label.
            labelonehot = sparseonehot(label)
            rim = torch.sparse.mm(labelonehot.t().float(), rim.t())
            rsg = rim @ rate_mg.exp()

            expression_distr = NegativeBinomial(
                total_count=1e-8 + rsg, logits=logits_g
            )
            pyro.sample(f"xsg-{i}", expression_distr, obs=data)
def rejection_sample_structure_to_feasibility(tree, constraints=[], max_n_iters=100, do_forward_sim=False, timestep=0.001, T=1.):
    """Resample the tree's children until all constraints are satisfied.

    A MultibodyPlant is compiled from ``tree`` once, and an
    ``InverseKinematics`` program is set up with a minimum-distance
    constraint plus every entry of ``constraints``. Up to ``max_n_iters``
    times, every child in the tree is resampled by its rule, the node poses
    are written into the plant, and the program's constraints are checked at
    the resulting configuration.

    :param tree: scene tree to resample in place.
    :param constraints: extra constraints; each must provide
        ``add_to_ik_prog``. The default list is never mutated here.
    :param max_n_iters: maximum number of resampling attempts.
    :param do_forward_sim: not used by this function.
    :param timestep: passed to ``compile_scene_tree_to_mbp_and_sg``.
    :param T: not used by this function.
    :returns: ``(tree, True)`` as soon as a sample satisfies every
        constraint; otherwise ``(tree, False)`` after the least-violating
        configuration seen has been loaded into the tree's floating-base
        nodes.
    """
    # Pre-build prog to check ourselves against
    builder, mbp, sg, node_to_free_body_ids_map, body_id_to_node_map = \
        compile_scene_tree_to_mbp_and_sg(tree, timestep=timestep)
    mbp.Finalize()
    floating_base_bodies = mbp.GetFloatingBaseBodies()
    diagram = builder.Build()
    diagram_context = diagram.CreateDefaultContext()
    mbp_context = diagram.GetMutableSubsystemContext(mbp, diagram_context)
    q0 = mbp.GetPositions(mbp_context)

    # Set up projection NLP.
    ik = InverseKinematics(mbp, mbp_context)
    prog = ik.prog()
    # Nonpenetration constraint.
    ik.AddMinimumDistanceConstraint(0.001)
    # Other requested constraints.
    for constraint in constraints:
        constraint.add_to_ik_prog(tree, ik, mbp, mbp_context,
                                  node_to_free_body_ids_map)

    from pyro.contrib.autoname import scope

    # Start from the default configuration so best_q is always defined,
    # including when max_n_iters is 0.
    best_q = q0
    best_violation = np.inf
    for k in range(max_n_iters):
        # Breadth-first resampling of every child in the tree.
        node_queue = [tree.get_root()]
        while len(node_queue) > 0:
            parent = node_queue.pop(0)
            children, rules = tree.get_children_and_rules(parent)
            for child, rule in zip(children, rules):
                with scope(prefix=parent.name):
                    rule.sample_child(parent, child)
                node_queue.append(child)

        # Push the sampled node poses into the plant.
        for node, body_ids in node_to_free_body_ids_map.items():
            for body_id in body_ids:
                mbp.SetFreeBodyPose(mbp_context, mbp.get_body(body_id),
                                    torch_tf_to_drake_tf(node.tf))
        q = mbp.GetPositions(mbp_context)

        all_bindings = prog.GetAllConstraints()
        satisfied = prog.CheckSatisfied(all_bindings, q)
        if satisfied:
            return tree, True

        # Otherwise compute violation
        evals = np.concatenate([
            binding.evaluator().Eval(q).flatten()
            for binding in all_bindings
        ])
        lbs = np.concatenate([
            binding.evaluator().lower_bound().flatten()
            for binding in all_bindings
        ])
        ubs = np.concatenate([
            binding.evaluator().upper_bound().flatten()
            for binding in all_bindings
        ])
        viols = np.maximum(np.clip(lbs - evals, 0., np.inf),
                           np.clip(evals - ubs, 0., np.inf))
        total_violation = np.sum(viols)
        if total_violation < best_violation:
            # Report the new best value, not the one being replaced.
            print("Updating best viol to ", total_violation)
            best_violation = total_violation
            # Snapshot so later iterations cannot alias the stored best.
            best_q = np.array(q)

    # Load best q into tree (previously the last sampled q was loaded).
    mbp.SetPositions(mbp_context, best_q)
    for body_id, node in body_id_to_node_map.items():
        if body_id in floating_base_bodies:
            node.tf = drake_tf_to_torch_tf(
                mbp.GetFreeBodyPose(mbp_context, mbp.get_body(body_id)))
    return tree, False