示例#1
0
def history_policy_model(state,
                         history,
                         t,
                         discount=1.0,
                         discount_factor=0.95,
                         max_depth=10):
    """Sample an action for ``state`` given ``history`` at time step ``t``.

    Each action in the module-level ``actions`` list is weighted by
    ``exp(value)``, where ``value`` comes from ``history_value_model`` run
    under the scopes ``history`` and ``"<action><t>"``. The weights are then
    min-max rescaled to [0, 1] with ``remap`` and used as unnormalized
    ``Categorical`` probabilities at sample site ``"a<t>"``.

    Args:
        state: Current state, forwarded to ``history_value_model``.
        history: History string; used as the outer ``scope`` prefix.
        t: Current time step; part of every scope and sample-site name.
        discount, discount_factor, max_depth: Forwarded to
            ``history_value_model``.

    Returns:
        An element of ``actions``.
    """
    with scope(prefix=history):
        if t > max_depth:
            # Past the planning horizon: choose uniformly, and return the
            # action itself (not its index) so both exits agree.
            idx = pyro.sample("a%d" % t,
                              dist.Categorical(torch.ones(len(actions))))
            return actions[idx]
        action_weights = torch.zeros(len(actions))
        for i, action in enumerate(actions):
            with scope(prefix="%s%d" % (action, t)):
                value = history_value_model(state,
                                            action,
                                            history,
                                            t,
                                            discount=discount,
                                            discount_factor=discount_factor,
                                            max_depth=max_depth)
                action_weights[i] = torch.exp(value)
    # Min-max rescale the (already positive) weights into [0, 1] with remap.
    # NOTE(review): if remap is a linear rescale, the lowest-weighted action
    # gets probability 0 -- confirm this is intended.
    min_weight = torch.min(action_weights)
    max_weight = torch.max(action_weights)
    if max_weight > min_weight:
        action_weights = tensor([
            remap(weight, min_weight, max_weight, 0., 1.)
            for weight in action_weights
        ])
    else:
        # Degenerate range (all weights equal): the rescale is undefined, and
        # every action is equally good, so sample uniformly.
        action_weights = torch.ones(len(actions))
    return actions[pyro.sample("a%d" % t, dist.Categorical(action_weights))]
示例#2
0
def policy_model(state,
                 t,
                 discount=1.0,
                 discount_factor=0.95,
                 max_depth=10,
                 alpha=0.1):
    """Returns Pr(a|s) by sampling an action for ``state`` at time ``t``.

    Each action in the module-level ``actions`` list is weighted by
    ``exp(alpha * value)``, where ``value`` comes from ``value_model`` run
    under the scope ``"<action><t>"``. The weights are min-max rescaled to
    [0, 1] with ``remap`` and used as unnormalized ``Categorical``
    probabilities at sample site ``"a<t>"``.

    Args:
        state: Current state, forwarded to ``value_model``.
        t: Current time step; part of every scope and sample-site name.
        discount, discount_factor, max_depth: Forwarded to ``value_model``.
        alpha: Multiplier applied to each value before exponentiation.

    Returns:
        An element of ``actions``.
    """
    if t > max_depth:
        # Past the planning horizon: choose uniformly over every action (sized
        # by len(actions), not a fixed count) and return the action itself.
        idx = pyro.sample("a%d" % t, dist.Categorical(torch.ones(len(actions))))
        return actions[idx]
    action_weights = torch.zeros(len(actions))
    for i, action in enumerate(actions):
        with scope(prefix="%s%d" % (action, t)):
            value = value_model(state,
                                action,
                                t,
                                discount=discount,
                                discount_factor=discount_factor,
                                max_depth=max_depth)
            # exp(alpha * value): positive weight, sharper for larger alpha.
            action_weights[i] = torch.exp(alpha * value)
    # Min-max rescale the (already positive) weights into [0, 1] with remap.
    # NOTE(review): if remap is a linear rescale, the lowest-weighted action
    # gets probability 0 -- confirm this is intended.
    min_weight = torch.min(action_weights)
    max_weight = torch.max(action_weights)
    if max_weight > min_weight:
        action_weights = tensor([
            remap(weight, min_weight, max_weight, 0., 1.)
            for weight in action_weights
        ])
    else:
        # Degenerate range (all weights equal): the rescale is undefined, and
        # every action is equally good, so sample uniformly.
        action_weights = torch.ones(len(actions))
    return actions[pyro.sample("a%d" % t, dist.Categorical(action_weights))]
示例#3
0
 def model1(r=True):
     """Call ``model2`` before, inside and after an ``"inter"`` scope.

     When ``r`` is true, ``model1`` also recurses once (with ``r=False``)
     inside the ``"inter"`` scope, so any sample sites created by that nested
     call carry the ``"inter"`` prefix. The exact call order matters because
     it determines the generated sample-site names.
     """
     model2()
     with scope(prefix="inter"):
         model2()
         if r:
             # Single level of recursion: the inner call passes r=False.
             model1(r=False)
     model2()
示例#4
0
 def __call__(self, *args, **kwargs):
     """Forward to the parent ``__call__`` with stochastic kwargs filled in.

     Every entry of ``self.stochastic_specs`` whose name the caller did not
     pass explicitly is supplied: ``AbstractSampler`` specs are called to
     draw a value, anything else is passed through as-is. When
     ``self.stochastic_name`` is truthy, the whole call (including the
     sampler calls) runs under a pyro ``scope`` with that prefix.
     """
     # NOTE(review): super(type(self), self) recurses forever when called on
     # an instance of a subclass of this class -- confirm it is never subclassed.
     context = (scope(prefix=self.stochastic_name)
                if self.stochastic_name else nullcontext())
     with context:
         filled = {}
         for spec_name, spec in self.stochastic_specs.items():
             if spec_name in kwargs:
                 continue
             filled[spec_name] = (spec() if isinstance(spec, AbstractSampler)
                                  else spec)
         return super(type(self), self).__call__(*args, **kwargs, **filled)
示例#5
0
 def __call__(self, **kwargs):
     """Run ``cls.__call__`` with this instance's pyrofit variables injected.

     Without an instance name the wrapped ``cls.__call__`` runs unchanged and
     its result is returned. Otherwise every entry of
     ``VARIABLES[self._pyrofit_instance_name]`` is called, its result
     overrides the matching keyword argument, and ``cls.__call__`` runs under
     a pyro ``scope`` prefixed with the instance name.
     """
     if self._pyrofit_instance_name is None:
         # Unnamed instance: nothing to inject or scope. Return directly
         # instead of falling through to a VARIABLES[None] lookup.
         return cls.__call__(self, **kwargs)
     updates_var = {
         key: val()
         for key, val in VARIABLES[self._pyrofit_instance_name].items()
     }
     kwargs.update(updates_var)
     return scope(cls.__call__,
                  prefix=self._pyrofit_instance_name)(self, **kwargs)
示例#6
0
def test_only_withs():
    """Nested autoname scopes join their prefixes with '/' in site names."""

    def model1():
        # Two stacked scopes in one statement: same as nesting two `with`s.
        with scope(prefix="a"), scope(prefix="b"):
            pyro.sample("x", dist.Bernoulli(0.5))

    plain_trace = poutine.trace(name_count(model1)).get_trace()
    assert "a/b/x" in plain_trace.nodes

    prefixed_model = scope(prefix="model1")(model1)
    outer_trace = poutine.trace(name_count(prefixed_model)).get_trace()
    assert "model1/a/b/x" in outer_trace.nodes
示例#7
0
def history_policy_model_guide(state,
                               history,
                               t,
                               discount=1.0,
                               discount_factor=0.95,
                               max_depth=10):
    """Guide paired with ``history_policy_model``.

    Under the ``history`` scope it registers the learnable simplex parameter
    ``"action_weights"`` and runs ``history_value_model`` for every candidate
    action under the scope ``"<action><t>"`` (the returned values are not
    used). The action is then drawn from the learned weights at site
    ``"a<t>"``, outside the ``history`` scope.

    Returns:
        The sampled element of ``actions``.
    """
    with scope(prefix=history):
        probs = pyro.param("action_weights",
                           torch.ones(len(actions)),
                           constraint=dist.constraints.simplex)
        for candidate in actions:
            with scope(prefix="%s%d" % (candidate, t)):
                # Presumably called for the sample sites it creates; the
                # value itself is discarded.
                history_value_model(state,
                                    candidate,
                                    history,
                                    t,
                                    discount=discount,
                                    discount_factor=discount_factor,
                                    max_depth=max_depth)
    chosen = pyro.sample("a%d" % t, dist.Categorical(probs))
    return actions[chosen]
示例#8
0
def belief_policy_model_guide(belief, t, discount=1.0, discount_factor=0.95, max_depth=10,
                              bu_nsteps=10, bu_lr=0.1):
    """Guide paired with ``belief_policy_model``.

    Registers the learnable simplex parameter ``"action_weights"``, runs
    ``belief_value_model`` for each candidate action under the scope
    ``"<action><t>"``, and samples site ``"a<t>"`` from the learned weights.

    Returns:
        The sampled element of ``actions``, matching what
        ``belief_policy_model`` returns, so the guide can also be used to draw
        actions directly.
    """
    weights = pyro.param("action_weights", torch.ones(len(actions)),
                         constraint=dist.constraints.simplex)
    for action in actions:
        with scope(prefix="%s%d" % (action, t)):
            # Presumably called for the sample sites it creates; the value
            # itself is not needed in the guide.
            belief_value_model(belief, action, t,
                               discount=discount, discount_factor=discount_factor,
                               max_depth=max_depth,
                               bu_nsteps=bu_nsteps,
                               bu_lr=bu_lr)
    action_idx = pyro.sample("a%d" % t, dist.Categorical(weights))
    return actions[action_idx]
示例#9
0
 def __init__(self, name, **kwargs):
     """Initialise via ``cls.__init__`` using pyrofit settings for ``name``.

     ``name`` is stored as ``self._pyrofit_instance_name``. If it is ``None``,
     ``cls.__init__`` runs with ``kwargs`` unchanged. Otherwise the entries of
     ``SETTINGS[name]`` override matching keyword arguments, and
     ``cls.__init__`` runs under a pyro ``scope`` prefixed with ``name``.

     Raises:
         TypeError: re-raised from ``cls.__init__`` with ``" [name]"``
             appended to identify the failing instance.
     """
     self._pyrofit_instance_name = name
     if self._pyrofit_instance_name is None:
         # Unnamed instance: no settings or scope to apply. Stop here rather
         # than falling through to SETTINGS[None] and a second initialisation.
         cls.__init__(self, **kwargs)
         return
     updates_set = {key: val for key, val in SETTINGS[name].items()}
     kwargs.update(updates_set)
     try:
         scope(cls.__init__,
               prefix=self._pyrofit_instance_name)(self, **kwargs)
     except TypeError as e:
         raise TypeError(str(e) + f" [{name}]") from e
示例#10
0
def sample_action(state, player, max_depth=0, t=0):
    """Sample a move for ``player`` among the 9 cells (indices 0-8).

    For each cell, a rollout outcome is sampled with ``sample_outcome``
    under a scope keyed on ``(state, t, move)``. The move is weighted by
    ``exp(reward_probability(outcome, player))``. The chosen cell is drawn
    from a ``Categorical`` over those weights at site ``"a<state><t>"``.

    Returns:
        ``Action(player, cell)`` for the sampled cell.
    """
    weights = torch.zeros(9)
    for cell in range(9):
        move = Action(player, cell)
        scope_name = "a{}{}{}".format(state, t, move)
        with scope(prefix=scope_name):
            result = sample_outcome(state, move, player,
                                    max_depth=max_depth, t=t)
        weights[cell] = np.exp(reward_probability(result, player))

    site_name = "a{}{}".format(state, t)
    cell_idx = pyro.sample(site_name,
                           pyro.distributions.Categorical(weights))
    return Action(player, cell_idx.item())
示例#11
0
 def do_sampling():
     """Breadth-first sample ``tree``, starting from a freshly built root.

     The root is built with ``self.root_node_type(tf=self.root_node_tf)``.
     Each dequeued node then samples its children under a pyro ``scope``
     prefixed with the node's name. Every child (and the parent->child edge)
     is added to ``tree`` and queued in turn. ``self._set_node_parameters``
     is applied with ``detach`` to each node as it is created. ``self``,
     ``tree`` and ``detach`` come from the enclosing scope.
     """
     from collections import deque

     root = self.root_node_type(tf=self.root_node_tf)
     self._set_node_parameters(root, detach=detach)
     tree.add_node(root)
     # deque gives O(1) pops from the left, unlike list.pop(0).
     node_queue = deque([root])
     while node_queue:
         parent = node_queue.popleft()
         # Ask node to sample its children; its sample sites are prefixed
         # with the parent's name.
         with scope(prefix=parent.name):
             children = parent.sample_children()
         for child in children:
             self._set_node_parameters(child, detach=detach)
             tree.add_node(child)
             tree.add_edge(parent, child)
             node_queue.append(child)
示例#12
0
    def _register_param(self, name):
        """
        In "model" mode, lifts the parameter with name ``name`` to a random
        sample using a predefined prior (from :meth:`set_prior` method). In
        any other mode, the value comes from :meth:`_sample_from_guide`.

        A parameter without a prior but with a constraint is instead rebuilt
        from its ``"<name>_unconstrained"`` raw parameter, transformed into
        the constraint's support. Either way, the result is registered as a
        buffer named ``name``.

        :param str name: Name of the parameter.
        :raises ValueError: if ``name`` has neither a prior nor a constraint.
        """
        if name in self._priors:
            # Prefix the sample site with this module's name.
            with autoname.scope(prefix=self._get_name()):
                if self.mode == "model":
                    p = pyro.sample(name, self._priors[name])
                else:
                    p = self._sample_from_guide(name)
        elif name in self._constraints:
            p_unconstrained = self._parameters["{}_unconstrained".format(name)]
            p = transform_to(self._constraints[name])(p_unconstrained)
        else:
            # Without this branch ``p`` would be unbound below.
            raise ValueError(
                "Parameter '{}' has neither a prior nor a constraint to "
                "register.".format(name))
        self.register_buffer(name, p)
示例#13
0
def sample_action_guide(state, player, max_depth=0, t=0):
    """Guide for ``sample_action``: sample a move from learned weights.

    Nine positive weights live in the pyro param
    ``str(state) + player + "params"`` and parameterise the ``Categorical``
    at site ``"a<state><t>"``. The rollout of every candidate move is then
    replayed with ``sample_outcome`` under the ``"a<state><t><move>"``
    scopes.

    Returns:
        The sampled ``Action``.
    """
    params = pyro.param(
        str(state) + player + "params", torch.ones(9),
        torch.distributions.constraints.positive)
    location = pyro.sample("a{}{}".format(state, t),
                           pyro.distributions.Categorical(params))
    action = Action(player, location.item())

    # Use a separate loop variable: reusing ``action`` here would overwrite
    # the sampled move, making the guide always return Action(player, 8).
    for i in range(9):
        candidate = Action(player, i)
        with scope(prefix="a{}{}{}".format(state, t, candidate)):
            sample_outcome(state,
                           candidate,
                           player,
                           max_depth=max_depth,
                           t=t)

    return action
示例#14
0
def policy_model_guide(state,
                       t,
                       discount=1.0,
                       discount_factor=0.95,
                       max_depth=10):
    """Guide paired with ``policy_model``.

    Registers the learnable simplex parameter ``"action_weights"`` (uniform
    initial value), runs ``value_model`` for every candidate action under the
    scope ``"<action><t>"``, and samples site ``"a<t>"`` from the learned
    weights.

    Returns:
        The sampled element of ``actions``, matching what ``policy_model``
        returns, so the guide can also be used to draw actions directly.
    """
    # You must reproduce the same structure in model...
    # prior weights is uniform
    weights = pyro.param("action_weights",
                         torch.ones(len(actions)),
                         constraint=dist.constraints.simplex)
    for action in actions:
        with scope(prefix="%s%d" % (action, t)):
            # Presumably called for the sample sites it creates; the value
            # itself is not needed in the guide.
            value_model(state,
                        action,
                        t,
                        discount=discount,
                        discount_factor=discount_factor,
                        max_depth=max_depth)
    action_idx = pyro.sample("a%d" % t, dist.Categorical(weights))
    return actions[action_idx]
示例#15
0
def sample_outcome(state, action, player, max_depth=0, t=0):
    """Roll the game forward from ``state`` after ``player`` plays ``action``.

    The move is applied to a deep copy of ``state``. If that produces an
    outcome, it is returned. Otherwise, once ``t >= max_depth`` the rollout
    stops and returns ``"o"``. Below that depth, the opponent's reply is
    sampled with ``sample_action`` and the rollout recurses (at ``t + 1``)
    under a scope keyed on the new state and ``t``.
    """
    next_state = copy.deepcopy(state)
    apply_transistion(next_state, action)
    if next_state.outcome is not None:
        return next_state.outcome

    # Just assume we've lost if we haven't won yet. This is not very
    # sophisticated and restricts us to working with "x" as the main player.
    if t >= max_depth:
        return "o"

    with scope(prefix="o{}{}".format(next_state, t)):
        opponent = other_player(player)
        return sample_outcome(next_state,
                              sample_action(next_state,
                                            opponent,
                                            max_depth=max_depth,
                                            t=t + 1),
                              opponent,
                              max_depth=max_depth,
                              t=t + 1)
示例#16
0
def belief_policy_model(belief, t, discount=1.0, discount_factor=0.95, max_depth=10,
                        bu_nsteps=10, bu_lr=0.1):
    """Sample an action for ``belief`` at time step ``t``.

    Each action in the module-level ``actions`` list is weighted by
    ``exp(value)``, where ``value`` comes from ``belief_value_model`` run
    under the scope ``"<action><t>"``. The weights are min-max rescaled to
    [0, 1] with ``remap`` and used as unnormalized ``Categorical``
    probabilities at sample site ``"a<t>"``.

    Args:
        belief: Current belief, forwarded to ``belief_value_model``.
        t: Current time step; part of every scope and sample-site name.
        discount, discount_factor, max_depth, bu_nsteps, bu_lr: Forwarded to
            ``belief_value_model``.

    Returns:
        An element of ``actions``.
    """
    if t > max_depth:
        # Past the planning horizon: choose uniformly over every action (sized
        # by len(actions), not a fixed count) and return the action itself.
        idx = pyro.sample("a%d" % t, dist.Categorical(torch.ones(len(actions))))
        return actions[idx]
    action_weights = torch.zeros(len(actions))
    for i, action in enumerate(actions):
        with scope(prefix="%s%d" % (action,t)):
            value = belief_value_model(belief, action, t,
                                       discount=discount,
                                       discount_factor=discount_factor,
                                       max_depth=max_depth,
                                       bu_nsteps=bu_nsteps,
                                       bu_lr=bu_lr)
            action_weights[i] = torch.exp(value)  # positive weight per action
    # Min-max rescale the (already positive) weights into [0, 1] with remap.
    # NOTE(review): if remap is a linear rescale, the lowest-weighted action
    # gets probability 0 -- confirm this is intended.
    min_weight = torch.min(action_weights)
    max_weight = torch.max(action_weights)
    if max_weight > min_weight:
        action_weights = tensor([remap(weight, min_weight, max_weight, 0., 1.)
                                 for weight in action_weights])
    else:
        # Degenerate range (all weights equal): the rescale is undefined, and
        # every action is equally good, so sample uniformly.
        action_weights = torch.ones(len(actions))
    return actions[pyro.sample("a%d" % t, dist.Categorical(action_weights))]
示例#17
0
    def model():
        """Resample ``scene_tree``'s continuous structure and score it.

        The tree is walked breadth-first from ``root``; each child is
        re-sampled through its rule under a scope named after its parent.
        If ``fix_observeds`` is set, each observed node's translation and
        rotation are scored against ``original_tree``'s values with Normal
        likelihoods. Finally, each constraint's violation, clamped to be
        non-negative, is observed under Normal(0, 0.001), so larger
        violations lower the likelihood.
        """
        # Resample the continuous structure of the tree, breadth-first.
        pending = [root]
        cursor = 0
        while cursor < len(pending):
            parent = pending[cursor]
            cursor += 1
            children, rules = scene_tree.get_children_and_rules(parent)
            for child, rule in zip(children, rules):
                with scope(prefix=parent.name):
                    rule.sample_child(parent, child)
                pending.append(child)

        # Implement observation constraints.
        if fix_observeds:
            # These values are passed to dist.Normal as its scale (std dev).
            xyz_scale = 1E-2
            rot_scale = 1E-2
            node_pairs = zip(scene_tree.nodes, original_tree.nodes)
            for node, original_node in node_pairs:
                if not node.observed:
                    continue
                xyz_dist = dist.Normal(original_node.translation, xyz_scale)
                rot_dist = dist.Normal(original_node.rotation, rot_scale)
                pyro.sample("%s_xyz_observed" % node.name,
                            xyz_dist,
                            obs=node.translation)
                pyro.sample("%s_rotation_observed" % node.name,
                            rot_dist,
                            obs=node.rotation)

        # Observe each constraint's non-negative violation near zero.
        error_dist = dist.Normal(0., 0.001)
        for idx, constraint in enumerate(constraints):
            violation, _, _ = constraint.eval_violation(scene_tree)
            pyro.sample("%s_%d_err" % (type(constraint).__name__, idx),
                        error_dist,
                        obs=torch.clamp(violation, 0., np.inf))
示例#18
0
def _reg_fn(fn):
    """Wrap ``fn`` so its yaml-configured variables are injected per call.

    Sample sites inside ``fn`` are prefixed with its ``__qualname__`` via
    ``scope``. The returned wrapper accepts keyword arguments only. It checks
    those arguments, and the keys of ``VARIABLES[name]``, against the
    signature info from ``_parse_signature``. It then calls each
    ``VARIABLES[name]`` entry and passes the results to ``fn`` together with
    the caller's arguments.
    """
    # Prefix sample sites
    name = fn.__qualname__
    fn = scope(fn, prefix=name)

    # Inspect function signature
    sig = _parse_signature(fn)

    def wrapped_fn(**kwargs):
        # Checking consistency of kwargs
        assert sorted(kwargs) == sig['params'], """
        '%s': keyword arguments %s expected, but %s given""" % (
            name, str(sig['params']), str(list(kwargs)))
        # Report sig['yaml'], the value actually compared against; the
        # unrelated sig['yaml_var'] may be missing and raise KeyError here.
        assert sorted(VARIABLES[name]) == sig['yaml'], """
        '%s': yaml variables %s expected, but %s given""" % (
            name, str(sig['yaml']), str(VARIABLES[name]))

        updates = {key: val() for key, val in VARIABLES[name].items()}

        # Update kwargs and run wrapped function
        kwargs.update(updates)
        return fn(**kwargs)

    return wrapped_fn
示例#19
0
    def model(self, x, zs):
        """Generative model over an image batch and its spot-level counts.

        ``zs`` is decoded with ``self._decode``. From the decoded features the
        model builds:

        (a) per-metagene activation maps ("rim"), softmax-normalised along
            dim 1 and multiplied by a decoded "scale" map;
        (b) metagene and per-sample gene rate/logit effects;
        (c) under the ``self.tag`` scope, the image likelihood from
            ``self._sample_image`` and a NegativeBinomial likelihood
            observing the counts in ``x["data"]`` at site ``"xsg"``.

        Args:
            x: Batch dict. The keys read here are ``"data"``, ``"label"``
                and ``"effects"``; ``self._sample_image`` may read more.
            zs: Latent inputs passed to ``self._decode``.

        Returns:
            ``(image_distr, expression_distr)``.
        """
        # pylint: disable=too-many-locals
        def _compute_rim(decoded):
            # Shared 1x1 conv/batchnorm/LeakyReLU trunk, then one decoder per
            # metagene. Outputs are concatenated along dim 1 and softmaxed
            # over that dim.
            shared_representation = get_module(
                "metagene_shared",
                lambda: torch.nn.Sequential(
                    torch.nn.Conv2d(
                        decoded.shape[1], decoded.shape[1], kernel_size=1),
                    torch.nn.BatchNorm2d(decoded.shape[1], momentum=0.05),
                    torch.nn.LeakyReLU(0.2, inplace=True),
                ),
            )(decoded)
            rim = torch.cat(
                [
                    get_module(
                        f"decoder_{_encode_metagene_name(n)}",
                        partial(self._create_metagene_decoder,
                                decoded.shape[1], n),
                    )(shared_representation) for n in self.metagenes
                ],
                dim=1,
            )
            rim = torch.nn.functional.softmax(rim, dim=1)
            return rim

        num_genes = x["data"][0].shape[1]
        decoded = self._decode(zs)
        # Crop the label image to the decoded output's spatial size.
        label = center_crop(x["label"], [None, *decoded.shape[-2:]])

        # NOTE(review): presumably torch.utils.checkpoint, which recomputes
        # _compute_rim during backward to save memory -- confirm the import.
        rim = checkpoint(_compute_rim, decoded)
        rim = center_crop(rim, [None, None, *label.shape[-2:]])
        rim = p.sample("rim", Delta(rim))

        scale = p.sample(
            "scale",
            Delta(
                center_crop(
                    self._get_scale_decoder(decoded.shape[1])(decoded),
                    [None, None, *label.shape[-2:]],
                )),
        )
        rim = scale * rim

        # Scale the global sites by len(x["data"]) / self.n; presumably
        # self.n is the dataset size, giving this batch's share -- confirm.
        with p.poutine.scale(scale=len(x["data"]) / self.n):
            rate_mg_prior = Normal(
                0.0,
                1e-8 + get_param(
                    "rate_mg_sd",
                    lambda: torch.ones(num_genes),
                    constraint=constraints.positive,
                ),
            )
            # One sample site per metagene, stacked into a single tensor.
            rate_mg = torch.stack([
                p.sample(_encode_metagene_name(n), rate_mg_prior)
                for n in self.metagenes
            ])
            rate_mg = p.sample("rate_mg", Delta(rate_mg))

            rate_g_effects_baseline = get_param(
                "rate_g_effects_baseline",
                lambda: self.__init_rate_baseline().log(),
                lr_multiplier=5.0,
            )
            logits_g_effects_baseline = get_param(
                "logits_g_effects_baseline",
                # pylint: disable=unnecessary-lambda
                self.__init_logits_baseline,
                lr_multiplier=5.0,
            )
            rate_g_effects_prior = Normal(
                0.0,
                1e-8 + get_param(
                    "rate_g_effects_sd",
                    lambda: torch.ones(num_genes),
                    constraint=constraints.positive,
                ),
            )
            rate_g_effects = p.sample("rate_g_effects", rate_g_effects_prior)
            # Row 0 of the effect matrix is the learned baseline.
            rate_g_effects = torch.cat(
                [rate_g_effects_baseline.unsqueeze(0), rate_g_effects])
            logits_g_effects_prior = Normal(
                0.0,
                1e-8 + get_param(
                    "logits_g_effects_sd",
                    lambda: torch.ones(num_genes),
                    constraint=constraints.positive,
                ),
            )
            logits_g_effects = p.sample(
                "logits_g_effects",
                logits_g_effects_prior,
            )
            logits_g_effects = torch.cat(
                [logits_g_effects_baseline.unsqueeze(0), logits_g_effects])

            # Design matrix: a leading column of ones selects the baseline
            # row, followed by the per-sample effect indicators.
            effects = torch.cat(
                [
                    torch.ones(x["effects"].shape[0], 1).to(x["effects"]),
                    x["effects"],
                ],
                1,
            ).float()

            logits_g = effects @ logits_g_effects
            rate_g = effects @ rate_g_effects
            rate_mg = rate_g[:, None] + rate_mg

        with scope(prefix=self.tag):
            image_distr = self._sample_image(x, decoded)

            def _compute_sample_params(data, label, rim, rate_mg, logits_g):
                # Per sample: select usable spots, aggregate the metagene
                # activations over each spot's pixels, and return
                # (observed counts, NB total_count rates, NB logits).
                nonmissing = label != 0
                # Labels (1-based) of spots whose total count is zero.
                zero_count_spots = 1 + torch.where(data.sum(1) == 0)[0]
                # Fill holes in the background-or-zero-count region, so that
                # labelled spots fully enclosed by it become True.
                # NOTE(review): presumably this drops spots cut off at the
                # crop border -- confirm.
                nonpartial = binary_fill_holes(
                    np.isin(label.cpu(), [0, *zero_count_spots.cpu()]))
                nonpartial = torch.as_tensor(nonpartial).to(nonmissing)
                mask = nonpartial & nonmissing

                if not mask.any():
                    # No usable pixels: return empty tensors whose trailing
                    # dimensions still match, so the torch.cat calls below
                    # keep working.
                    return (
                        data[[]],
                        torch.zeros(0, num_genes).to(rim),
                        logits_g.expand(0, -1),
                    )

                # Shift labels so they index rows of ``data``.
                label = label[mask] - 1
                idxs, label = torch.unique(label, return_inverse=True)
                data = data[idxs]

                # Sum activations per spot; presumably sparseonehot gives a
                # (pixels, spots) one-hot matrix -- confirm.
                rim = rim[:, mask]
                labelonehot = sparseonehot(label)
                rim = torch.sparse.mm(labelonehot.t().float(), rim.t())

                rgs = rim @ rate_mg.exp()

                return data, rgs, logits_g.expand(len(rgs), -1)

            data, rgs, logits_g = zip(*it.starmap(
                _compute_sample_params,
                zip(x["data"], label, rim, rate_mg, logits_g),
            ))

            # One likelihood over all samples' selected spots.
            expression_distr = NegativeBinomial(
                total_count=1e-8 + torch.cat(rgs),
                logits=torch.cat(logits_g),
            )
            p.sample("xsg", expression_distr, obs=torch.cat(data))

        return image_distr, expression_distr
示例#20
0
 def model1():
     """Sample ``x`` ~ Bernoulli(0.5) inside scope "a", then nested scope "b"."""
     with scope(prefix="a"), scope(prefix="b"):
         pyro.sample("x", dist.Bernoulli(0.5))
示例#21
0
文件: st.py 项目: ludvb/xfuse
    def model(self, x, zs):
        """Generative model over an image batch and its spot-level counts.

        ``zs`` is decoded with ``self._decode``. From the decoded features the
        model builds:

        (a) per-metagene activation maps ("rim"), softmax-normalised along
            dim 1 and multiplied by a decoded "scale" map;
        (b) metagene rates, plus per-sample gene rates/logits. These start
            from shared baseline params and add one sampled effect per
            covariate condition found in ``get("covariates")``;
        (c) under the ``self.tag`` scope, the image likelihood from
            ``self._sample_image`` and, for each sample, a NegativeBinomial
            likelihood observing ``x["data"]`` at site ``"xsg-<i>"``.

        Args:
            x: Batch dict. The keys read here are ``"data"``, ``"label"``,
                ``"slide"`` and ``"covariates"``; ``self._sample_image`` may
                read more.
            zs: Latent inputs passed to ``self._decode``.

        Returns:
            None.
        """
        # pylint: disable=too-many-locals, too-many-statements

        dataset = require("dataloader").dataset

        def _compute_rim(decoded):
            # Shared 1x1 conv/batchnorm/LeakyReLU trunk, then one decoder per
            # metagene. Outputs are concatenated along dim 1 and softmaxed
            # over that dim.
            shared_representation = get_module(
                "metagene_shared",
                lambda: torch.nn.Sequential(
                    torch.nn.Conv2d(
                        decoded.shape[1], decoded.shape[1], kernel_size=1
                    ),
                    torch.nn.BatchNorm2d(decoded.shape[1], momentum=0.05),
                    torch.nn.LeakyReLU(0.2, inplace=True),
                ),
            )(decoded)
            rim = torch.cat(
                [
                    get_module(
                        f"decoder_{_encode_metagene_name(n)}",
                        partial(
                            self._create_metagene_decoder, decoded.shape[1], n
                        ),
                    )(shared_representation)
                    for n in self.metagenes
                ],
                dim=1,
            )
            rim = torch.nn.functional.softmax(rim, dim=1)
            return rim

        decoded = self._decode(zs)
        # Crop the label image to the decoded output's spatial size.
        label = center_crop(x["label"], [None, *decoded.shape[-2:]])

        # NOTE(review): presumably torch.utils.checkpoint, which recomputes
        # _compute_rim during backward to save memory -- confirm the import.
        rim = checkpoint(_compute_rim, decoded)
        rim = center_crop(rim, [None, None, *label.shape[-2:]])
        rim = pyro.sample("rim", Delta(rim))

        scale = pyro.sample(
            "scale",
            Delta(
                center_crop(
                    self._get_scale_decoder(decoded.shape[1])(decoded),
                    [None, None, *label.shape[-2:]],
                )
            ),
        )
        rim = scale * rim

        rate_mg_prior = Normal(
            0.0,
            1e-8
            + get_param(
                "rate_mg_prior_sd",
                lambda: torch.ones(len(self._allocated_genes)),
                constraint=constraints.positive,
            ),
        )
        # Scale the metagene sites by len(x["data"]) / dataset.size();
        # presumably this batch's share of the dataset -- confirm.
        with pyro.poutine.scale(scale=len(x["data"]) / dataset.size()):
            rate_mg = torch.stack(
                [
                    pyro.sample(
                        _encode_metagene_name(n),
                        rate_mg_prior,
                        infer={"is_global": True},
                    )
                    for n in self.metagenes
                ]
            )
        rate_mg = pyro.sample("rate_mg", Delta(rate_mg))

        rate_g_conditions_prior = Normal(
            0.0,
            1e-8
            + get_param(
                "rate_g_conditions_prior_sd",
                lambda: torch.ones(len(self._allocated_genes)),
                constraint=constraints.positive,
            ),
        )
        logits_g_conditions_prior = Normal(
            0.0,
            1e-8
            + get_param(
                "logits_g_conditions_prior_sd",
                lambda: torch.ones(len(self._allocated_genes)),
                constraint=constraints.positive,
            ),
        )

        rate_g, logits_g = [], []

        # For each batch element: start from the shared baseline params and
        # add one sampled effect per covariate condition.
        for batch_idx, (slide, covariates) in enumerate(
            zip(x["slide"], x["covariates"])
        ):
            rate_g_slide = get_param(
                "rate_g_condition_baseline",
                lambda: self.__init_rate_baseline().log(),
                lr_multiplier=5.0,
            )
            logits_g_slide = get_param(
                "logits_g_condition_baseline",
                self.__init_logits_baseline,
                lr_multiplier=5.0,
            )

            for covariate, condition in covariates.items():
                try:
                    conditions = get("covariates")[covariate]
                except KeyError:
                    # Covariates missing from get("covariates") are skipped.
                    continue

                if pd.isna(condition):
                    with pyro.poutine.scale(
                        scale=1.0 / dataset.size(slide=slide)
                    ):
                        pyro.sample(
                            f"condition-{covariate}-{batch_idx}",
                            OneHotCategorical(
                                to_device(torch.ones(len(conditions)))
                                / len(conditions)
                            ),
                            infer={"is_global": True},
                        )
                        # ^ NOTE 1: This statement affects the ELBO but not its
                        #           gradient. The pmf is non-differentiable but
                        #           it doesn't matter---our prior over the
                        #           conditions is uniform; even if a gradient
                        #           existed, it would always be zero.
                        # ^ NOTE 2: The result is used to index the effect of
                        #           the condition. However, this takes place in
                        #           the guide to avoid sampling effets that are
                        #           not used in the current minibatch,
                        #           potentially (?) reducing noise in the
                        #           learning signal. Therefore, the result here
                        #           is discarded.
                    condition_scale = 1e-99
                    # ^ HACK: Pyro requires scale > 0
                else:
                    # Inverse of dataset.size for this covariate/condition;
                    # presumably the number of matching samples -- confirm.
                    condition_scale = 1.0 / dataset.size(
                        covariate=covariate, condition=condition
                    )

                with pyro.poutine.scale(scale=condition_scale):
                    rate_g_slide = rate_g_slide + pyro.sample(
                        f"rate_g_condition-{covariate}-{batch_idx}",
                        rate_g_conditions_prior,
                        infer={"is_global": True},
                    )
                    logits_g_slide = logits_g_slide + pyro.sample(
                        f"logits_g_condition-{covariate}-{batch_idx}",
                        logits_g_conditions_prior,
                        infer={"is_global": True},
                    )

            rate_g.append(rate_g_slide)
            logits_g.append(logits_g_slide)

        # Restrict gene-level tensors to self._gene_indices and add each
        # sample's gene rate (broadcast via unsqueeze(1)) to the metagene rates.
        logits_g = torch.stack(logits_g)[:, self._gene_indices]
        rate_g = torch.stack(rate_g)[:, self._gene_indices]
        rate_mg = rate_g.unsqueeze(1) + rate_mg[:, self._gene_indices]

        with scope(prefix=self.tag):
            self._sample_image(x, decoded)

            for i, (data, label, rim, rate_mg, logits_g) in enumerate(
                zip(x["data"], label, rim, rate_mg, logits_g)
            ):
                # Labels (1-based) of spots whose total count is zero.
                zero_count_idxs = 1 + torch.where(data.sum(1) == 0)[0]
                # Labels on the first/last row/column of the cropped label
                # image count as partially visible, except zero-count spots.
                partial_idxs = np.unique(
                    torch.cat([label[0], label[-1], label[:, 0], label[:, -1]])
                    .cpu()
                    .numpy()
                )
                partial_idxs = np.setdiff1d(
                    partial_idxs, zero_count_idxs.cpu().numpy()
                )
                # Keep pixels that are neither background (label 0) nor part
                # of a partially visible spot.
                mask = np.invert(
                    np.isin(label.cpu().numpy(), [0, *partial_idxs])
                )
                mask = torch.as_tensor(mask, device=label.device)

                if not mask.any():
                    continue

                # Compact the remaining labels. ``idxs`` holds the original
                # (1-based) labels, so ``idxs - 1`` indexes rows of ``data``.
                label = label[mask]
                idxs, label = torch.unique(label, return_inverse=True)
                data = data[idxs - 1]
                pyro.sample(f"idx-{i}", Delta(idxs.float()))

                # Sum activations per spot; presumably sparseonehot gives a
                # (pixels, spots) one-hot matrix -- confirm.
                rim = rim[:, mask]
                labelonehot = sparseonehot(label)
                rim = torch.sparse.mm(labelonehot.t().float(), rim.t())
                rsg = rim @ rate_mg.exp()

                expression_distr = NegativeBinomial(
                    total_count=1e-8 + rsg, logits=logits_g
                )
                pyro.sample(f"xsg-{i}", expression_distr, obs=data)
def rejection_sample_structure_to_feasibility(tree,
                                              constraints=None,
                                              max_n_iters=100,
                                              do_forward_sim=False,
                                              timestep=0.001,
                                              T=1.):
    """Rejection-sample ``tree``'s continuous structure toward feasibility.

    The tree is compiled to a MultibodyPlant/SceneGraph diagram. An
    ``InverseKinematics`` program is built on it with a minimum-distance
    (non-penetration) constraint of 0.001, plus every entry of
    ``constraints``. Each iteration resamples the tree breadth-first, pushes
    the sampled free-body poses into the plant, and checks them against
    every constraint of that program.

    Args:
        tree: Scene tree, resampled in place.
        constraints: Extra constraints; each must provide ``add_to_ik_prog``.
            ``None`` (the default) means no extra constraints.
        max_n_iters: Maximum number of resampling attempts.
        do_forward_sim: Unused by this function.
        timestep: Passed to ``compile_scene_tree_to_mbp_and_sg``.
        T: Unused by this function.

    Returns:
        ``(tree, True)`` as soon as a sample satisfies every constraint.
        Otherwise ``(tree, False)`` after ``max_n_iters`` attempts. In that
        case, nodes backed by floating-base bodies take their poses from the
        least-violating sample, or from the initial positions if no attempt
        was made. Other nodes keep their last sampled values.
    """
    from collections import deque

    if constraints is None:
        constraints = []

    # Pre-build prog to check ourselves against
    builder, mbp, sg, node_to_free_body_ids_map, body_id_to_node_map = \
        compile_scene_tree_to_mbp_and_sg(tree, timestep=timestep)
    mbp.Finalize()
    floating_base_bodies = mbp.GetFloatingBaseBodies()
    diagram = builder.Build()
    diagram_context = diagram.CreateDefaultContext()
    mbp_context = diagram.GetMutableSubsystemContext(mbp, diagram_context)
    q0 = mbp.GetPositions(mbp_context)

    # Set up projection NLP.
    ik = InverseKinematics(mbp, mbp_context)
    prog = ik.prog()
    # Nonpenetration constraint.
    ik.AddMinimumDistanceConstraint(0.001)
    # Other requested constraints.
    for constraint in constraints:
        constraint.add_to_ik_prog(tree, ik, mbp, mbp_context,
                                  node_to_free_body_ids_map)

    from pyro.contrib.autoname import scope
    best_q = q0
    best_violation = np.inf
    for _ in range(max_n_iters):
        # Resample the tree breadth-first; each child's sample sites are
        # prefixed with its parent's name.
        node_queue = deque([tree.get_root()])
        while node_queue:
            parent = node_queue.popleft()
            children, rules = tree.get_children_and_rules(parent)
            for child, rule in zip(children, rules):
                with scope(prefix=parent.name):
                    rule.sample_child(parent, child)
                node_queue.append(child)
        # Push the sampled poses into the plant.
        for node, body_ids in node_to_free_body_ids_map.items():
            for body_id in body_ids:
                mbp.SetFreeBodyPose(mbp_context, mbp.get_body(body_id),
                                    torch_tf_to_drake_tf(node.tf))
        q = mbp.GetPositions(mbp_context)
        all_bindings = prog.GetAllConstraints()
        satisfied = prog.CheckSatisfied(all_bindings, q)
        if satisfied:
            return tree, True
        # Otherwise compute violation: total distance outside [lb, ub].
        evals = np.concatenate([
            binding.evaluator().Eval(q).flatten() for binding in all_bindings
        ])
        lbs = np.concatenate([
            binding.evaluator().lower_bound().flatten()
            for binding in all_bindings
        ])
        ubs = np.concatenate([
            binding.evaluator().upper_bound().flatten()
            for binding in all_bindings
        ])
        viols = np.maximum(np.clip(lbs - evals, 0., np.inf),
                           np.clip(evals - ubs, 0., np.inf))
        total_violation = np.sum(viols)
        if total_violation < best_violation:
            best_violation = total_violation
            # Defensive copy in case q aliases plant state.
            best_q = np.copy(q)
            print("Updating best viol to ", best_violation)
    # Load the best (not the last) q into the tree.
    mbp.SetPositions(mbp_context, best_q)
    for body_id, node in body_id_to_node_map.items():
        if body_id in floating_base_bodies:
            node.tf = drake_tf_to_torch_tf(
                mbp.GetFreeBodyPose(mbp_context, mbp.get_body(body_id)))
    return tree, False