Example #1
    def __call__(self, *args, **kwargs):
        key = _hashable_args_kwargs(args, kwargs)

        # if first time
        if key not in self.compiled:
            # param capture
            with poutine.block():
                with poutine.trace(param_only=True) as first_param_capture:
                    self.fn(*args, **kwargs)

            self._param_names = list(
                set(first_param_capture.trace.nodes.keys()))
            unconstrained_params = tuple(
                pyro.param(name).unconstrained() for name in self._param_names)
            params_and_args = unconstrained_params + args
            weakself = weakref.ref(self)

            def compiled(*params_and_args):
                self = weakself()
                unconstrained_params = params_and_args[:len(self._param_names)]
                args = params_and_args[len(self._param_names):]
                constrained_params = {}
                for name, unconstrained_param in zip(self._param_names,
                                                     unconstrained_params):
                    # assume param has been initialized
                    constrained_param = pyro.param(name)
                    assert constrained_param.unconstrained() is unconstrained_param
                    constrained_params[name] = constrained_param
                return poutine.replay(self.fn, params=constrained_params)(*args, **kwargs)

            if self.ignore_warnings:
                compiled = ignore_jit_warnings()(compiled)
            with pyro.validation_enabled(False):
                time_compilation = self.jit_options.pop(
                    "time_compilation", False)
                with optional(timed(), time_compilation) as t:
                    self.compiled[key] = torch.jit.trace(
                        compiled, params_and_args, **self.jit_options)
                if time_compilation:
                    self.compile_time = t.elapsed
        else:
            unconstrained_params = [
                pyro.param(name).unconstrained() for name in self._param_names
            ]
            params_and_args = unconstrained_params + list(args)

        with poutine.block(hide=self._param_names):
            with poutine.trace(param_only=True) as param_capture:
                ret = self.compiled[key](*params_and_args)

        for name in param_capture.trace.nodes.keys():
            if name not in self._param_names:
                raise NotImplementedError(
                    "pyro.ops.jit.trace assumes all params are created on "
                    "first invocation, but found new param: {}".format(name))

        return ret
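The wrapper above caches one torch.jit.trace per argument signature and replays captured params on later calls. As a hedged illustration of the same machinery through Pyro's public API, here is a minimal sketch (the toy model, guide, and data below are assumptions, not part of the original snippet) using JitTrace_ELBO, which relies on this JIT tracing utility in recent Pyro versions:

# Minimal sketch: JIT-compiled ELBO on a toy Normal model (assumed example).
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTrace_ELBO
from pyro.optim import Adam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

def guide(data):
    loc_q = pyro.param("loc_q", torch.tensor(0.))
    pyro.sample("loc", dist.Normal(loc_q, 1.))

data = torch.randn(100) + 3.0
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=JitTrace_ELBO())
for _ in range(10):
    svi.step(data)  # first step compiles; later steps reuse the cached trace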
Example #2
def test_distribution_validate_args(dist_class, args, validate_args):
    with pyro.validation_enabled(validate_args):
        if not validate_args:
            dist_class(**args)
        else:
            with pytest.raises(ValueError):
                dist_class(**args)
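The test above is parametrized over distribution classes and deliberately invalid constructor args: construction should raise only when validation is on. A hedged concrete instance (Bernoulli with an out-of-range probability is an assumed illustration, not the test's actual parametrization):

# Sketch of one concrete case: invalid probs raise only under validation.
import pytest
import pyro
import pyro.distributions as dist

@pytest.mark.parametrize("validate_args", [True, False])
def test_bernoulli_validate_args(validate_args):
    with pyro.validation_enabled(validate_args):
        if not validate_args:
            dist.Bernoulli(probs=1.5)  # no arg checking, constructs silently
        else:
            with pytest.raises(ValueError):
                dist.Bernoulli(probs=1.5)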
Example #3
    def _lp_fn_jit(self, skip_jit_warnings, jit_options, params):
        if not params:
            return self._lp_fn(params)
        names, vals = zip(*sorted(params.items()))

        if self._compiled_fn:
            return self._compiled_fn(*vals)

        with pyro.validation_enabled(False):
            tmp = []
            for _, v in pyro.get_param_store().named_parameters():
                if v.requires_grad:
                    v.requires_grad_(False)
                    tmp.append(v)

            def _lp_jit(*zi):
                params = dict(zip(names, zi))
                return self._lp_fn(params)

            if skip_jit_warnings:
                _lp_jit = ignore_jit_warnings()(_lp_jit)
            self._compiled_fn = torch.jit.trace(_lp_jit, vals, **jit_options)

            for v in tmp:
                v.requires_grad_(True)
            return self._compiled_fn(*vals)
Example #4
    def _potential_energy_jit(self, z):
        names, vals = zip(*sorted(z.items()))
        if self._compiled_potential_fn:
            return self._compiled_potential_fn(*vals)

        def compiled(*zi):
            z_constrained = list(zi)
            # transform to constrained space.
            for i, name in enumerate(names):
                if name in self.transforms:
                    transform = self.transforms[name]
                    z_constrained[i] = transform.inv(z_constrained[i])
            z_constrained = dict(zip(names, z_constrained))
            trace = self._get_trace(z_constrained)
            potential_energy = -self._compute_trace_log_prob(trace)
            # adjust by the jacobian for this transformation.
            for i, name in enumerate(names):
                if name in self.transforms:
                    transform = self.transforms[name]
                    potential_energy += transform.log_abs_det_jacobian(
                        z_constrained[name], zi[i]).sum()
            return potential_energy

        with pyro.validation_enabled(False), optional(
                ignore_jit_warnings(), self._ignore_jit_warnings):
            self._compiled_potential_fn = torch.jit.trace(compiled,
                                                          vals,
                                                          check_trace=False)
        return self._compiled_potential_fn(*vals)
Example #5
def test_distribution_validate_args(dist_class, args, validate_args):
    with pyro.validation_enabled(validate_args):
        if not validate_args:
            dist_class(**args)
        else:
            with pytest.raises(ValueError):
                dist_class(**args)
Example #6
 def test_random_module(self):
     pyro.clear_param_store()
     with pyro.validation_enabled():
         lifted_tr = poutine.trace(pyro.random_module("name", self.model, prior=self.prior)).get_trace()
     for name in lifted_tr.nodes.keys():
         if lifted_tr.nodes[name]["type"] == "param":
             assert lifted_tr.nodes[name]["type"] == "sample"
             assert not lifted_tr.nodes[name]["is_observed"]
Example #7
 def reset_step_size_adaptation(self, z):
     r"""
     Finds a reasonable step size and resets step size adaptation scheme.
     """
     if self._find_reasonable_step_size is not None:
         with pyro.validation_enabled(False):
             self.step_size = self._find_reasonable_step_size(z)
     self._step_size_adapt_scheme.prox_center = math.log(10 * self.step_size)
     self._step_size_adapt_scheme.reset()
Example #8
 def test_random_module_warn(self):
     pyro.clear_param_store()
     bad_prior = {'foo': None}
     with warnings.catch_warnings(record=True) as w:
         warnings.simplefilter("always")
         with pyro.validation_enabled():
             poutine.trace(pyro.random_module("name", self.model, prior=bad_prior)).get_trace()
         assert len(w), 'No warnings were raised'
         for warning in w:
             logger.info(warning)
Example #9
def initialize_model(param, ind2val, itemattr, dataloaders):
    pyro.clear_param_store()
    pyro.enable_validation(False)  # global toggle; a bare validation_enabled(False) call is a no-op without `with`
    torch.manual_seed(param['train_seed'])

    dummybatch = next(iter(dataloaders['train']))
    dummybatch['phase_mask'] = (dummybatch['mask_type'] == 1).float()
    dummybatch = {
        key: val.long().to(param.get("device"))
        for key, val in dummybatch.items()
    }
    if param.get('remove_item_group'):
        itemattr['category'] = itemattr['category'] * 0
    model = models.PyroRecommender(**param,
                                   item_group=torch.tensor(
                                       itemattr['category']).long())
    guide = models.MeanFieldGuide(model=model, batch=dummybatch, **param)

    return model, guide
Example #10
def test_iarange_broadcast_error(Elbo, is_validate):
    def model():
        p = torch.tensor(0.5, requires_grad=True)
        with pyro.iarange("iarange", 10, 5):
            pyro.sample("x", dist.Bernoulli(p).expand_by([1]))

    with pyro.validation_enabled(is_validate):
        if is_validate:
            assert_error(model, model, Elbo())
        else:
            assert_ok(model, model, Elbo())
Example #11
    def sample(self, params):
        z, potential_energy, z_grads = self._fetch_from_cache()
        # recompute PE when cache is cleared
        if z is None:
            z = params
            z_grads, potential_energy = potential_grad(self.potential_fn, z)
            self._cache(z, potential_energy, z_grads)
        # return early if no sample sites
        elif len(z) == 0:
            self._t += 1
            self._mean_accept_prob = 1.
            if self._t > self._warmup_steps:
                self._accept_cnt += 1
            return params
        r, r_unscaled = self._sample_r(name="r_t={}".format(self._t))
        energy_current = self._kinetic_energy(r_unscaled) + potential_energy

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation
        with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
            z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
                z, r, self.potential_fn, self.mass_matrix_adapter.kinetic_grad,
                self.step_size, self.num_steps, z_grads=z_grads)
            # apply Metropolis correction.
            r_new_unscaled = self.mass_matrix_adapter.unscale(r_new)
            energy_proposal = self._kinetic_energy(r_new_unscaled) + potential_energy_new
        delta_energy = energy_proposal - energy_current
        # handle the NaN case which may be the case for a diverging trajectory
        # when using a large step size.
        delta_energy = scalar_like(delta_energy, float("inf")) if torch_isnan(delta_energy) else delta_energy
        if delta_energy > self._max_sliced_energy and self._t >= self._warmup_steps:
            self._divergences.append(self._t - self._warmup_steps)

        accept_prob = (-delta_energy).exp().clamp(max=1.)
        rand = pyro.sample("rand_t={}".format(self._t), dist.Uniform(scalar_like(accept_prob, 0.),
                                                                     scalar_like(accept_prob, 1.)))
        accepted = False
        if rand < accept_prob:
            accepted = True
            z = z_new
            z_grads = z_grads_new
            self._cache(z, potential_energy_new, z_grads)

        self._t += 1
        if self._t > self._warmup_steps:
            n = self._t - self._warmup_steps
            if accepted:
                self._accept_cnt += 1
        else:
            n = self._t
            self._adapter.step(self._t, z, accept_prob, z_grads)

        self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n
        return z.copy()
Example #12
    def run(self, *args, **kwargs):
        self._args, self._kwargs = args, kwargs
        num_samples = [0] * self.num_chains
        z_flat_acc = [[] for _ in range(self.num_chains)]
        with pyro.validation_enabled(not self.disable_validation):
            for x, chain_id in self.sampler.run(*args, **kwargs):
                if num_samples[chain_id] == 0:
                    num_samples[chain_id] += 1
                    z_structure = x
                elif num_samples[chain_id] == self.num_samples + 1:
                    self._diagnostics[chain_id] = x
                else:
                    num_samples[chain_id] += 1
                    if self.num_chains > 1:
                        x_cloned = x.clone()
                        del x
                    else:
                        x_cloned = x
                    z_flat_acc[chain_id].append(x_cloned)

        z_flat_acc = torch.stack([torch.stack(l) for l in z_flat_acc])

        # unpack latent
        pos = 0
        z_acc = z_structure.copy()
        for k in sorted(z_structure):
            shape = z_structure[k]
            next_pos = pos + shape.numel()
            z_acc[k] = z_flat_acc[:, :,
                                  pos:next_pos].reshape((self.num_chains,
                                                         self.num_samples) +
                                                        shape)
            pos = next_pos
        assert pos == z_flat_acc.shape[-1]

        # If transforms is not explicitly provided, infer automatically using
        # model args, kwargs.
        if self.transforms is None:
            if hasattr(self.kernel, 'transforms'):
                if self.kernel.transforms is not None:
                    self.transforms = self.kernel.transforms
            elif self.kernel.model:
                _, _, self.transforms, _ = initialize_model(
                    self.kernel.model, model_args=args, model_kwargs=kwargs)
            else:
                self.transforms = {}

        # transform samples back to constrained space
        for name, transform in self.transforms.items():
            z_acc[name] = transform.inv(z_acc[name])
        self._samples = z_acc

        # terminate the sampler (shut down worker processes)
        self.sampler.terminate(True)
Example #13
def test_iarange_broadcast_error(Elbo, is_validate):

    def model():
        p = torch.tensor(0.5, requires_grad=True)
        with pyro.iarange("iarange", 10, 5):
            pyro.sample("x", dist.Bernoulli(p).expand_by([1]))

    with pyro.validation_enabled(is_validate):
        if is_validate:
            assert_error(model, model, Elbo())
        else:
            assert_ok(model, model, Elbo())
Example #14
    def run(self, *args, **kwargs):
        """
        Run StreamingMCMC to compute required `self._statistics`.
        """
        self._args, self._kwargs = args, kwargs
        num_samples = [0] * self.num_chains

        with optional(
                pyro.validation_enabled(not self.disable_validation),
                self.disable_validation is not None,
        ):
            args = [
                arg.detach() if torch.is_tensor(arg) else arg for arg in args
            ]
            for x, chain_id in self.sampler.run(*args, **kwargs):
                if num_samples[chain_id] == 0:
                    # If transforms is not explicitly provided, infer automatically using
                    # model args, kwargs.
                    if self.transforms is None:
                        self._set_transforms(*args, **kwargs)
                    num_samples[chain_id] += 1
                    z_structure = x
                elif num_samples[chain_id] == self.num_samples + 1:
                    self._diagnostics[chain_id] = x
                else:
                    num_samples[chain_id] += 1
                    if self.num_chains > 1:
                        x_cloned = x.clone()
                        del x
                    else:
                        x_cloned = x

                    # unpack latent
                    pos = 0
                    z_acc = z_structure.copy()
                    for k in sorted(z_structure):
                        shape = z_structure[k]
                        next_pos = pos + shape.numel()
                        z_acc[k] = x_cloned[pos:next_pos].reshape(shape)
                        pos = next_pos

                    for name, z in z_acc.items():
                        if name in self.transforms:
                            z_acc[name] = self.transforms[name].inv(z)

                    self._statistics.update({
                        (chain_id, name): transformed_sample
                        for name, transformed_sample in z_acc.items()
                    })

        # terminate the sampler (shut down worker processes)
        self.sampler.terminate(True)
Example #15
    def _configure_adaptation(self):
        initial_step_size = None
        if self.adapt_step_size:
            z = {
                name: node["value"].detach()
                for name, node in self._iter_latent_nodes(self.initial_trace)
            }
            for name, transform in self.transforms.items():
                z[name] = transform(z[name])
            with pyro.validation_enabled(False):
                initial_step_size = self._find_reasonable_step_size(z)

        self._adapter.configure(self._warmup_steps, initial_step_size)
Example #16
    def sample(self, trace):
        z = {
            name: node["value"].detach()
            for name, node in self._iter_latent_nodes(trace)
        }
        # automatically transform `z` to unconstrained space, if needed.
        for name, transform in self.transforms.items():
            z[name] = transform(z[name])

        r, _ = self._sample_r(name="r_t={}".format(self._t))

        potential_energy, z_grads = self._fetch_from_cache()
        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation
        with optional(pyro.validation_enabled(False),
                      self._t < self._warmup_steps):
            z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
                z,
                r,
                self._potential_energy,
                self.inverse_mass_matrix,
                self.step_size,
                self.num_steps,
                z_grads=z_grads)
            # apply Metropolis correction.
            energy_proposal = self._kinetic_energy(
                r_new) + potential_energy_new
            energy_current = self._kinetic_energy(r) + potential_energy if potential_energy is not None \
                else self._energy(z, r)
        delta_energy = energy_proposal - energy_current
        # Set accept prob to 0.0 if delta_energy is `NaN` which may be
        # the case for a diverging trajectory when using a large step size.
        if torch_isnan(delta_energy):
            accept_prob = delta_energy.new_tensor(0.0)
        else:
            accept_prob = (-delta_energy).exp().clamp(max=1.)
        rand = pyro.sample("rand_t={}".format(self._t),
                           dist.Uniform(torch.zeros(1), torch.ones(1)))
        if rand < accept_prob:
            self._accept_cnt += 1
            z = z_new

        if self._t < self._warmup_steps:
            self._adapter.step(self._t, z, accept_prob)

        self._t += 1

        # get trace with the constrained values for `z`.
        for name, transform in self.transforms.items():
            z[name] = transform.inv(z[name])
        return self._get_trace(z)
Example #17
def test_enum_discrete_iarange_dependency_warning(enumerate_, is_validate):

    def model():
        pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'})
        with pyro.iarange("iarange", 10, 5):
            x = pyro.sample("x", dist.Bernoulli(0.5).expand_by([5]),
                            infer={'enumerate': enumerate_})
        pyro.sample("y", dist.Bernoulli(x.mean()))  # user should move this line up

    with pyro.validation_enabled(is_validate):
        if enumerate_ and is_validate:
            assert_warning(model, model, TraceEnum_ELBO(max_iarange_nesting=1))
        else:
            assert_ok(model, model, TraceEnum_ELBO(max_iarange_nesting=1))
Example #18
    def to_script_module(self):
        """
        Compile this module using :func:`torch.jit.trace_module` ,
        assuming self has already been fit to data.

        :return: A traced version of self with an :meth:`ite` method.
        :rtype: torch.jit.ScriptModule
        """
        self.train(False)
        fake_x = torch.randn(2, self.feature_dim)
        with pyro.validation_enabled(False):
            # Disable check_trace due to nondeterministic nodes.
            result = torch.jit.trace_module(self, {"ite": (fake_x,)}, check_trace=False)
        return result
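A hedged usage sketch for the traced module returned above (the `ite_model` variable and file name are assumptions; any fitted instance of the class defining to_script_module would do):

# Hypothetical usage: serialize and reload the traced module, then query ite().
scripted = ite_model.to_script_module()          # torch.jit.ScriptModule with an .ite() method
torch.jit.save(scripted, "ite_model_traced.pt")  # serialize for deployment
loaded = torch.jit.load("ite_model_traced.pt")
x = torch.randn(16, ite_model.feature_dim)
effects = loaded.ite(x)                          # individual treatment effects for a batch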
Example #19
def test_enum_discrete_iarange_dependency_warning(enumerate_, is_validate):
    def model():
        pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'})
        with pyro.iarange("iarange", 10, 5):
            x = pyro.sample("x",
                            dist.Bernoulli(0.5).expand_by([5]),
                            infer={'enumerate': enumerate_})
        pyro.sample("y",
                    dist.Bernoulli(x.mean()))  # user should move this line up

    with pyro.validation_enabled(is_validate):
        if enumerate_ and is_validate:
            assert_warning(model, model, TraceEnum_ELBO(max_iarange_nesting=1))
        else:
            assert_ok(model, model, TraceEnum_ELBO(max_iarange_nesting=1))
Example #20
def test_irange_in_guide_not_model_error(subsample_size, Elbo, is_validate):
    def model():
        p = torch.tensor(0.5)
        pyro.sample("x", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        for i in pyro.irange("irange", 10, subsample_size):
            pass
        pyro.sample("x", dist.Bernoulli(p))

    with pyro.validation_enabled(is_validate):
        if is_validate:
            assert_error(model, guide, Elbo())
        else:
            assert_ok(model, guide, Elbo())
Example #21
def test_irange_in_guide_not_model_error(subsample_size, Elbo, is_validate):

    def model():
        p = torch.tensor(0.5)
        pyro.sample("x", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        for i in pyro.irange("irange", 10, subsample_size):
            pass
        pyro.sample("x", dist.Bernoulli(p))

    with pyro.validation_enabled(is_validate):
        if is_validate:
            assert_error(model, guide, Elbo())
        else:
            assert_ok(model, guide, Elbo())
Example #22
def test_enum_discrete_non_enumerated_iarange_ok(enumerate_):

    def model():
        pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'})

        with pyro.iarange("non_enum", 2):
            a = pyro.sample("a", dist.Bernoulli(0.5).expand_by([2]),
                            infer={'enumerate': None})

        p = (1.0 + a.sum(-1)) / (2.0 + a.size(0))  # introduce dependency of b on a

        with pyro.iarange("enum_1", 3):
            pyro.sample("b", dist.Bernoulli(p).expand_by([3]),
                        infer={'enumerate': enumerate_})

    with pyro.validation_enabled():
        assert_ok(model, model, TraceEnum_ELBO(max_iarange_nesting=1))
Example #23
    def _potential_fn_jit(self, skip_jit_warnings, jit_options, params):
        if not params:
            return self._potential_fn(params)
        names, vals = zip(*sorted(params.items()))

        if self._compiled_fn:
            return self._compiled_fn(*vals)

        with pyro.validation_enabled(False):

            def _pe_jit(*zi):
                params = dict(zip(names, zi))
                return self._potential_fn(params)

            if skip_jit_warnings:
                _pe_jit = ignore_jit_warnings()(_pe_jit)
            self._compiled_fn = torch.jit.trace(_pe_jit, vals, **jit_options)
            return self._compiled_fn(*vals)
Example #24
def test_enum_discrete_iranges_iarange_dependency_warning(enumerate_, is_validate):

    def model():
        pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'})
        inner_iarange = pyro.iarange("iarange", 10, 5)

        for i in pyro.irange("irange1", 2):
            with inner_iarange:
                pyro.sample("x_{}".format(i), dist.Bernoulli(0.5).expand_by([5]),
                            infer={'enumerate': enumerate_})

        for i in pyro.irange("irange2", 2):
            pyro.sample("y_{}".format(i), dist.Bernoulli(0.5))

    with pyro.validation_enabled(is_validate):
        if enumerate_ and is_validate:
            assert_warning(model, model, TraceEnum_ELBO(max_iarange_nesting=1))
        else:
            assert_ok(model, model, TraceEnum_ELBO(max_iarange_nesting=1))
Example #25
def test_enum_discrete_non_enumerated_iarange_ok(enumerate_):
    def model():
        pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'})

        with pyro.iarange("non_enum", 2):
            a = pyro.sample("a",
                            dist.Bernoulli(0.5).expand_by([2]),
                            infer={'enumerate': None})

        p = (1.0 + a.sum(-1)) / (2.0 + a.size(0))  # introduce dependency of b on a

        with pyro.iarange("enum_1", 3):
            pyro.sample("b",
                        dist.Bernoulli(p).expand_by([3]),
                        infer={'enumerate': enumerate_})

    with pyro.validation_enabled():
        assert_ok(model, model, TraceEnum_ELBO(max_iarange_nesting=1))
Example #26
def test_enum_discrete_iranges_iarange_dependency_warning(
        enumerate_, is_validate):
    def model():
        pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'})
        inner_iarange = pyro.iarange("iarange", 10, 5)

        for i in pyro.irange("irange1", 2):
            with inner_iarange:
                pyro.sample("x_{}".format(i),
                            dist.Bernoulli(0.5).expand_by([5]),
                            infer={'enumerate': enumerate_})

        for i in pyro.irange("irange2", 2):
            pyro.sample("y_{}".format(i), dist.Bernoulli(0.5))

    with pyro.validation_enabled(is_validate):
        if enumerate_ and is_validate:
            assert_warning(model, model, TraceEnum_ELBO(max_iarange_nesting=1))
        else:
            assert_ok(model, model, TraceEnum_ELBO(max_iarange_nesting=1))
Example #27
    def sample(self, trace):
        z = {
            name: node["value"].detach()
            for name, node in self._iter_latent_nodes(trace)
        }
        potential_energy, z_grads = self._fetch_from_cache()
        # automatically transform `z` to unconstrained space, if needed.
        for name, transform in self.transforms.items():
            z[name] = transform(z[name])
        r, r_flat = self._sample_r(name="r_t={}".format(self._t))
        energy_current = self._kinetic_energy(r) + potential_energy if potential_energy is not None \
            else self._energy(z, r)

        # Ideally, following a symplectic integrator trajectory, the energy is constant.
        # In that case, we can sample the proposal uniformly, and there is no need to use "slice".
        # However, it is not the case for real situation: there are errors during the computation.
        # To deal with that problem, as in [1], we introduce an auxiliary "slice" variable (denoted
        # by u).
        # The sampling process goes as follows:
        #   first sampling u from initial state (z_0, r_0) according to
        #     u ~ Uniform(0, p(z_0, r_0)),
        #   then sampling state (z, r) from the integrator trajectory according to
        #     (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}).
        #
        # For more information about slice sampling method, see [3].
        # For another version of NUTS which uses multinomial sampling instead of slice sampling,
        # see [2].

        if self.use_multinomial_sampling:
            log_slice = -energy_current
        else:
            # Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can
            # sample log_slice directly using `energy`, so as to avoid potential underflow or
            # overflow issues ([2]).
            slice_exp_term = pyro.sample(
                "slicevar_exp_t={}".format(self._t),
                dist.Exponential(energy_current.new_tensor(1.)))
            log_slice = -energy_current - slice_exp_term

        z_left = z_right = z
        r_left = r_right = r
        z_left_grads = z_right_grads = z_grads
        accepted = False
        r_sum = r_flat
        if self.use_multinomial_sampling:
            tree_weight = energy_current.new_zeros(())
        else:
            tree_weight = energy_current.new_ones(())

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation.
        with optional(pyro.validation_enabled(False),
                      self._t < self._warmup_steps):
            # doubling process, stop when turning or diverging
            for tree_depth in range(self._max_tree_depth + 1):
                direction = pyro.sample(
                    "direction_t={}_treedepth={}".format(self._t, tree_depth),
                    dist.Bernoulli(probs=torch.ones(1) * 0.5))
                direction = int(direction.item())
                if direction == 1:  # go to the right, start from the right leaf of current tree
                    new_tree = self._build_tree(z_right, r_right,
                                                z_right_grads, log_slice,
                                                direction, tree_depth,
                                                energy_current)
                    # update leaf for the next doubling process
                    z_right = new_tree.z_right
                    r_right = new_tree.r_right
                    z_right_grads = new_tree.z_right_grads
                else:  # go to the left, start from the left leaf of current tree
                    new_tree = self._build_tree(z_left, r_left, z_left_grads,
                                                log_slice, direction,
                                                tree_depth, energy_current)
                    z_left = new_tree.z_left
                    r_left = new_tree.r_left
                    z_left_grads = new_tree.z_left_grads

                if new_tree.turning or new_tree.diverging:  # stop doubling
                    break

                if self.use_multinomial_sampling:
                    new_tree_prob = (new_tree.weight - tree_weight).exp()
                else:
                    new_tree_prob = new_tree.weight / tree_weight
                rand = pyro.sample(
                    "rand_t={}_treedepth={}".format(self._t, tree_depth),
                    dist.Uniform(new_tree_prob.new_tensor(0.),
                                 new_tree_prob.new_tensor(1.)))
                if rand < new_tree_prob:
                    accepted = True
                    z = new_tree.z_proposal
                    self._cache(new_tree.z_proposal_pe,
                                new_tree.z_proposal_grads)

                r_sum = r_sum + new_tree.r_sum
                if self._is_turning(r_left, r_right, r_sum):  # stop doubling
                    break
                else:  # update tree_weight
                    if self.use_multinomial_sampling:
                        tree_weight = logsumexp(torch.stack(
                            [tree_weight, new_tree.weight]),
                                                dim=0)
                    else:
                        tree_weight = tree_weight + new_tree.weight

        if self._t < self._warmup_steps:
            accept_prob = new_tree.sum_accept_probs / new_tree.num_proposals
            self._adapter.step(self._t, z, accept_prob)

        if accepted:
            self._accept_cnt += 1

        self._t += 1
        # get trace with the constrained values for `z`.
        for name, transform in self.transforms.items():
            z[name] = transform.inv(z[name])
        return self._get_trace(z)
Example #28
    def run(self, *args, **kwargs):
        """
        Run MCMC to generate samples and populate `self._samples`.

        Example usage:

        .. code-block:: python

            def model(data):
                ...

            nuts_kernel = NUTS(model)
            mcmc = MCMC(nuts_kernel, num_samples=500)
            mcmc.run(data)
            samples = mcmc.get_samples()

        :param args: optional arguments taken by
            :meth:`MCMCKernel.setup <pyro.infer.mcmc.mcmc_kernel.MCMCKernel.setup>`.
        :param kwargs: optional keywords arguments taken by
            :meth:`MCMCKernel.setup <pyro.infer.mcmc.mcmc_kernel.MCMCKernel.setup>`.
        """
        self._args, self._kwargs = args, kwargs
        num_samples = [0] * self.num_chains
        z_flat_acc = [[] for _ in range(self.num_chains)]
        with optional(pyro.validation_enabled(not self.disable_validation),
                      self.disable_validation is not None):
            for x, chain_id in self.sampler.run(*args, **kwargs):
                if num_samples[chain_id] == 0:
                    num_samples[chain_id] += 1
                    z_structure = x
                elif num_samples[chain_id] == self.num_samples + 1:
                    self._diagnostics[chain_id] = x
                else:
                    num_samples[chain_id] += 1
                    if self.num_chains > 1:
                        x_cloned = x.clone()
                        del x
                    else:
                        x_cloned = x
                    z_flat_acc[chain_id].append(x_cloned)

        z_flat_acc = torch.stack([torch.stack(l) for l in z_flat_acc])

        # unpack latent
        pos = 0
        z_acc = z_structure.copy()
        for k in sorted(z_structure):
            shape = z_structure[k]
            next_pos = pos + shape.numel()
            z_acc[k] = z_flat_acc[:, :,
                                  pos:next_pos].reshape((self.num_chains,
                                                         self.num_samples) +
                                                        shape)
            pos = next_pos
        assert pos == z_flat_acc.shape[-1]

        # If transforms is not explicitly provided, infer automatically using
        # model args, kwargs.
        if self.transforms is None:
            # Use `kernel.transforms` when available
            if getattr(self.kernel, "transforms", None) is not None:
                self.transforms = self.kernel.transforms
            # Else, get transforms from model (e.g. in multiprocessing).
            elif self.kernel.model:
                warmup_steps = 0
                self.kernel.setup(warmup_steps, *args, **kwargs)
                self.transforms = self.kernel.transforms
            # Assign default value
            else:
                self.transforms = {}

        # transform samples back to constrained space
        for name, transform in self.transforms.items():
            z_acc[name] = transform.inv(z_acc[name])
        self._samples = z_acc

        # terminate the sampler (shut down worker processes)
        self.sampler.terminate(True)
Example #29
    def sample(self, params):
        z, potential_energy, z_grads = self._fetch_from_cache()
        # recompute PE when cache is cleared
        if z is None:
            z = params
            potential_energy = self.potential_fn(z)
            self._cache(z, potential_energy)
        # return early if no sample sites
        elif len(z) == 0:
            self._t += 1
            self._mean_accept_prob = 1.
            if self._t > self._warmup_steps:
                self._accept_cnt += 1
            return z
        r, r_flat = self._sample_r(name="r_t={}".format(self._t))
        energy_current = self._kinetic_energy(r) + potential_energy

        # Ideally, following a symplectic integrator trajectory, the energy is constant.
        # In that case, we can sample the proposal uniformly, and there is no need to use "slice".
        # However, it is not the case for real situation: there are errors during the computation.
        # To deal with that problem, as in [1], we introduce an auxiliary "slice" variable (denoted
        # by u).
        # The sampling process goes as follows:
        #   first sampling u from initial state (z_0, r_0) according to
        #     u ~ Uniform(0, p(z_0, r_0)),
        #   then sampling state (z, r) from the integrator trajectory according to
        #     (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}).
        #
        # For more information about slice sampling method, see [3].
        # For another version of NUTS which uses multinomial sampling instead of slice sampling,
        # see [2].

        if self.use_multinomial_sampling:
            log_slice = -energy_current
        else:
            # Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can
            # sample log_slice directly using `energy`, so as to avoid potential underflow or
            # overflow issues ([2]).
            slice_exp_term = pyro.sample("slicevar_exp_t={}".format(self._t),
                                         dist.Exponential(scalar_like(energy_current, 1.)))
            log_slice = -energy_current - slice_exp_term

        z_left = z_right = z
        r_left = r_right = r
        z_left_grads = z_right_grads = z_grads
        accepted = False
        r_sum = r_flat
        sum_accept_probs = 0.
        num_proposals = 0
        tree_weight = scalar_like(energy_current, 0. if self.use_multinomial_sampling else 1.)

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation.
        with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
            # doubling process, stop when turning or diverging
            tree_depth = 0
            while tree_depth < self._max_tree_depth:
                direction = pyro.sample("direction_t={}_treedepth={}".format(self._t, tree_depth),
                                        dist.Bernoulli(probs=scalar_like(tree_weight, 0.5)))
                direction = int(direction.item())
                if direction == 1:  # go to the right, start from the right leaf of current tree
                    new_tree = self._build_tree(z_right, r_right, z_right_grads, log_slice,
                                                direction, tree_depth, energy_current)
                    # update leaf for the next doubling process
                    z_right = new_tree.z_right
                    r_right = new_tree.r_right
                    z_right_grads = new_tree.z_right_grads
                else:  # go to the left, start from the left leaf of current tree
                    new_tree = self._build_tree(z_left, r_left, z_left_grads, log_slice,
                                                direction, tree_depth, energy_current)
                    z_left = new_tree.z_left
                    r_left = new_tree.r_left
                    z_left_grads = new_tree.z_left_grads

                sum_accept_probs = sum_accept_probs + new_tree.sum_accept_probs
                num_proposals = num_proposals + new_tree.num_proposals

                # stop doubling
                if new_tree.diverging:
                    if self._t >= self._warmup_steps:
                        self._divergences.append(self._t - self._warmup_steps)
                    break

                if new_tree.turning:
                    break

                tree_depth += 1

                if self.use_multinomial_sampling:
                    new_tree_prob = (new_tree.weight - tree_weight).exp()
                else:
                    new_tree_prob = new_tree.weight / tree_weight
                rand = pyro.sample("rand_t={}_treedepth={}".format(self._t, tree_depth),
                                   dist.Uniform(scalar_like(new_tree_prob, 0.),
                                                scalar_like(new_tree_prob, 1.)))
                if rand < new_tree_prob:
                    accepted = True
                    z = new_tree.z_proposal
                    self._cache(z, new_tree.z_proposal_pe, new_tree.z_proposal_grads)

                r_sum = r_sum + new_tree.r_sum
                if self._is_turning(r_left, r_right, r_sum):  # stop doubling
                    break
                else:  # update tree_weight
                    if self.use_multinomial_sampling:
                        tree_weight = _logaddexp(tree_weight, new_tree.weight)
                    else:
                        tree_weight = tree_weight + new_tree.weight

        accept_prob = sum_accept_probs / num_proposals

        self._t += 1
        if self._t > self._warmup_steps:
            n = self._t - self._warmup_steps
            if accepted:
                self._accept_cnt += 1
        else:
            n = self._t
            self._adapter.step(self._t, z, accept_prob)
        self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n

        return z.copy()
Example #30
def get_model_relations(
    model: Callable,
    model_args: Optional[tuple] = None,
    model_kwargs: Optional[dict] = None,
):
    """
    Infer relations of RVs and plates from given model and optionally data.
    See https://github.com/pyro-ppl/pyro/issues/949 for more details.

    This returns a dictionary with keys:

    -  "sample_sample" map each downstream sample site to a list of the upstream
       sample sites on which it depend;
    -  "sample_dist" maps each sample site to the name of the distribution at
       that site;
    -  "plate_sample" maps each plate name to a list of the sample sites within
       that plate; and
    -  "observe" is a list of observed sample sites.

    For example for the model::

        def model(data):
            m = pyro.sample('m', dist.Normal(0, 1))
            sd = pyro.sample('sd', dist.LogNormal(m, 1))
            with pyro.plate('N', len(data)):
                pyro.sample('obs', dist.Normal(m, sd), obs=data)

    the relation is::

        {'sample_sample': {'m': [], 'sd': ['m'], 'obs': ['m', 'sd']},
         'sample_dist': {'m': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'},
         'plate_sample': {'N': ['obs']},
         'observed': ['obs']}

    :param callable model: A model to inspect.
    :param model_args: Optional tuple of model args.
    :param model_kwargs: Optional dict of model kwargs.
    :rtype: dict
    """
    if model_args is None:
        model_args = ()
    if model_kwargs is None:
        model_kwargs = {}

    with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(
            False):
        with TrackProvenance():
            trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)

    sample_sample = {}
    sample_param = {}
    sample_dist = {}
    param_constraint = {}
    plate_sample = defaultdict(list)
    observed = []

    def _get_type_from_frozenname(frozen_name):
        return trace.nodes[frozen_name]["type"]

    for name, site in trace.nodes.items():
        if site["type"] == "param":
            param_constraint[name] = str(site["kwargs"]["constraint"])

        if site["type"] != "sample" or site_is_subsample(site):
            continue

        sample_sample[name] = [
            upstream
            for upstream in get_provenance(site["fn"].log_prob(site["value"]))
            if upstream != name
            and _get_type_from_frozenname(upstream) == "sample"
        ]

        sample_param[name] = [
            upstream
            for upstream in get_provenance(site["fn"].log_prob(site["value"]))
            if upstream != name
            and _get_type_from_frozenname(upstream) == "param"
        ]

        sample_dist[name] = _get_dist_name(site["fn"])
        for frame in site["cond_indep_stack"]:
            plate_sample[frame.name].append(name)
        if site["is_observed"]:
            observed.append(name)

    def _resolve_plate_samples(plate_samples):
        for p, pv in plate_samples.items():
            pv = set(pv)
            for q, qv in plate_samples.items():
                qv = set(qv)
                if len(pv & qv) > 0 and len(pv - qv) > 0 and len(qv - pv) > 0:
                    plate_samples_ = plate_samples.copy()
                    plate_samples_[q] = pv & qv
                    plate_samples_[q + "__CLONE"] = qv - pv
                    return _resolve_plate_samples(plate_samples_)
        return plate_samples

    plate_sample = _resolve_plate_samples(plate_sample)
    # convert set to list to keep order of variables
    plate_sample = {
        k: [name for name in trace.nodes if name in v]
        for k, v in plate_sample.items()
    }

    return {
        "sample_sample": sample_sample,
        "sample_param": sample_param,
        "sample_dist": sample_dist,
        "param_constraint": param_constraint,
        "plate_sample": dict(plate_sample),
        "observed": observed,
    }
Example #31
    def run(self, *args, **kwargs):
        """
        Run MCMC to generate samples and populate `self._samples`.

        Example usage:

        .. code-block:: python

            def model(data):
                ...

            nuts_kernel = NUTS(model)
            mcmc = MCMC(nuts_kernel, num_samples=500)
            mcmc.run(data)
            samples = mcmc.get_samples()

        :param args: optional arguments taken by
            :meth:`MCMCKernel.setup <pyro.infer.mcmc.mcmc_kernel.MCMCKernel.setup>`.
        :param kwargs: optional keywords arguments taken by
            :meth:`MCMCKernel.setup <pyro.infer.mcmc.mcmc_kernel.MCMCKernel.setup>`.
        """
        self._args, self._kwargs = args, kwargs
        num_samples = [0] * self.num_chains
        z_flat_acc = [[] for _ in range(self.num_chains)]
        with optional(pyro.validation_enabled(not self.disable_validation),
                      self.disable_validation is not None):
            # XXX we clone CUDA tensor args to resolve the issue "Invalid device pointer"
            # at https://github.com/pytorch/pytorch/issues/10375
            # This also resolves "RuntimeError: Cowardly refusing to serialize non-leaf tensor which
            # requires_grad", which happens with `jit_compile` under PyTorch 1.7
            args = [
                arg.detach() if torch.is_tensor(arg) else arg for arg in args
            ]
            for x, chain_id in self.sampler.run(*args, **kwargs):
                if num_samples[chain_id] == 0:
                    num_samples[chain_id] += 1
                    z_structure = x
                elif num_samples[chain_id] == self.num_samples + 1:
                    self._diagnostics[chain_id] = x
                else:
                    num_samples[chain_id] += 1
                    if self.num_chains > 1:
                        x_cloned = x.clone()
                        del x
                    else:
                        x_cloned = x
                    z_flat_acc[chain_id].append(x_cloned)

        z_flat_acc = torch.stack([torch.stack(l) for l in z_flat_acc])

        # unpack latent
        pos = 0
        z_acc = z_structure.copy()
        for k in sorted(z_structure):
            shape = z_structure[k]
            next_pos = pos + shape.numel()
            z_acc[k] = z_flat_acc[:, :,
                                  pos:next_pos].reshape((self.num_chains,
                                                         self.num_samples) +
                                                        shape)
            pos = next_pos
        assert pos == z_flat_acc.shape[-1]

        # If transforms is not explicitly provided, infer automatically using
        # model args, kwargs.
        if self.transforms is None:
            # Use `kernel.transforms` when available
            if getattr(self.kernel, "transforms", None) is not None:
                self.transforms = self.kernel.transforms
            # Else, get transforms from model (e.g. in multiprocessing).
            elif self.kernel.model:
                warmup_steps = 0
                self.kernel.setup(warmup_steps, *args, **kwargs)
                self.transforms = self.kernel.transforms
            # Assign default value
            else:
                self.transforms = {}

        # transform samples back to constrained space
        for name, transform in self.transforms.items():
            z_acc[name] = transform.inv(z_acc[name])
        self._samples = z_acc

        # terminate the sampler (shut down worker processes)
        self.sampler.terminate(True)
Example #32
def get_dependencies(
    model: Callable,
    model_args: Optional[tuple] = None,
    model_kwargs: Optional[dict] = None,
) -> Dict[str, object]:
    r"""
    Infers dependency structure about a conditioned model.

    This returns a nested dictionary with structure like::

        {
            "prior_dependencies": {
                "variable1": {"variable1": set()},
                "variable2": {"variable1": set(), "variable2": set()},
                ...
            },
            "posterior_dependencies": {
                "variable1": {"variable1": {"plate1"}, "variable2": set()},
                ...
            },
        }

    where

    -   `prior_dependencies` is a dict mapping downstream latent and observed
        variables to dictionaries mapping upstream latent variables on which
        they depend to sets of plates inducing full dependencies.
        That is, included plates introduce quadratically many dependencies as
        in complete-bipartite graphs, whereas excluded plates introduce only
        linearly many dependencies as in independent sets of parallel edges.
        Prior dependencies follow the original model order.
    -   `posterior_dependencies` is a similar dict, but mapping latent
        variables to the latent or observed sites on which they depend in the
        posterior. Posterior dependencies are reversed from the model order.

    Dependencies elide ``pyro.deterministic`` sites and ``pyro.sample(...,
    Delta(...))`` sites.

    **Examples**

    Here is a simple example with no plates. We see every node depends on
    itself, and only the latent variables appear in the posterior::

        def model_1():
            a = pyro.sample("a", dist.Normal(0, 1))
            pyro.sample("b", dist.Normal(a, 1), obs=torch.tensor(0.0))

        assert get_dependencies(model_1) == {
            "prior_dependencies": {
                "a": {"a": set()},
                "b": {"a": set(), "b": set()},
            },
            "posterior_dependencies": {
                "a": {"a": set(), "b": set()},
            },
        }

    Here is an example where two variables ``a`` and ``b`` start out
    conditionally independent in the prior, but become conditionally dependent
    in the posterior due to the so-called collider variable ``c`` on which they
    both depend. This is called "moralization" in the graphical model
    literature::

        def model_2():
            a = pyro.sample("a", dist.Normal(0, 1))
            b = pyro.sample("b", dist.LogNormal(0, 1))
            c = pyro.sample("c", dist.Normal(a, b))
            pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.))

        assert get_dependencies(model_2) == {
            "prior_dependencies": {
                "a": {"a": set()},
                "b": {"b": set()},
                "c": {"a": set(), "b": set(), "c": set()},
                "d": {"c": set(), "d": set()},
            },
            "posterior_dependencies": {
                "a": {"a": set(), "b": set(), "c": set()},
                "b": {"b": set(), "c": set()},
                "c": {"c": set(), "d": set()},
            },
        }

    Dependencies can be more complex in the presence of plates. So far all the
    dict values have been empty sets of plates, but in the following posterior
    we see that ``c`` depends on itself across the plate ``p``. This means
    that, among the elements of ``c``, e.g. ``c[0]`` depends on ``c[1]`` (this
    is why we explicitly allow variables to depend on themselves)::

        def model_3():
            with pyro.plate("p", 5):
                a = pyro.sample("a", dist.Normal(0, 1))
            pyro.sample("b", dist.Normal(a.sum(), 1), obs=torch.tensor(0.0))

        assert get_dependencies(model_3) == {
            "prior_dependencies": {
                "a": {"a": set()},
                "b": {"a": set(), "b": set()},
            },
            "posterior_dependencies": {
                "a": {"a": {"p"}, "b": set()},
            },
        }

    [1] S.Webb, A.Goliński, R.Zinkov, N.Siddharth, T.Rainforth, Y.W.Teh, F.Wood (2018)
        "Faithful inversion of generative models for effective amortized inference"
        https://dl.acm.org/doi/10.5555/3327144.3327229

    :param callable model: A model.
    :param tuple model_args: Optional tuple of model args.
    :param dict model_kwargs: Optional dict of model kwargs.
    :returns: A dictionary of metadata (see above).
    :rtype: dict
    """
    if model_args is None:
        model_args = ()
    if model_kwargs is None:
        model_kwargs = {}

    # Collect sites with tracked provenance.
    with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(
            False):
        with TrackProvenance():
            trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
    sample_sites = [msg for msg in trace.nodes.values() if is_sample_site(msg)]

    # Collect observations.
    observed = {msg["name"] for msg in sample_sites if msg["is_observed"]}
    plates = {
        msg["name"]: {f.name
                      for f in msg["cond_indep_stack"] if f.vectorized}
        for msg in sample_sites
    }

    # Find direct prior dependencies among latent and observed sites.
    prior_dependencies = {n: {n: set()} for n in plates}  # no deps yet
    for i, downstream in enumerate(sample_sites):
        upstreams = [
            u for u in sample_sites[:i] if not u["is_observed"]
            if u["value"].numel()
        ]
        if not upstreams:
            continue
        log_prob = downstream["fn"].log_prob(downstream["value"])
        provenance = get_provenance(log_prob)
        for upstream in upstreams:
            u = upstream["name"]
            if u in provenance:
                d = downstream["name"]
                prior_dependencies[d][u] = set()

    # Next reverse dependencies and restrict downstream nodes to latent sites.
    posterior_dependencies = {n: {} for n in plates if n not in observed}
    for d, upstreams in prior_dependencies.items():
        for u, p in upstreams.items():
            if u not in observed:
                # Note the following reverses:
                # u is henceforth downstream and d is henceforth upstream.
                posterior_dependencies[u][d] = p.copy()

    # Moralize: add dependencies among latent variables in each Markov blanket.
    # This assumes all latents are eventually observed, at least indirectly.
    order = {msg["name"]: i for i, msg in enumerate(reversed(sample_sites))}
    for d, upstreams in prior_dependencies.items():
        upstreams = {u: p for u, p in upstreams.items() if u not in observed}
        for u1, p1 in upstreams.items():
            for u2, p2 in upstreams.items():
                if order[u1] <= order[u2]:
                    p12 = posterior_dependencies[u2].setdefault(u1, set())
                    p12 |= plates[u1] & plates[u2] - plates[d]
                    p12 |= plates[u2] & p1
                    p12 |= plates[u1] & p2

    return {
        "prior_dependencies": prior_dependencies,
        "posterior_dependencies": posterior_dependencies,
    }
Example #33
"""
import time as time
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from agents import Informed
from inference import Inferrer
from helpers import offer_state_mapping

zeros = torch.zeros
ones = torch.ones

import pyro
pyro.enable_validation()  # global toggle; validation_enabled() is a context manager and would need `with`
#%% Define submodel and saving location
n_subjects = 89 #number of subjects
#model = 'theta'
#model = 'beta'
#model = 'theta_beta'
#model = 'theta_beta_gamma'
#model = 'theta_beta_kappa'
model = 'theta_beta_gamma_kappa'

if model == 'theta':
    vals = torch.ones(n_subjects, 5)
    vals[:, 0] = -100.
    vals[:, 1] = -100.
    vals[:, 2] = 100.
    vals[:, 3] = 10.
Example #34
    def sample(self, trace):
        z, potential_energy, z_grads = self._fetch_from_cache()
        r, _ = self._sample_r(name="r_t={}".format(self._t))
        energy_current = self._kinetic_energy(r) + potential_energy

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation
        with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
            z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(z, r, self._potential_energy,
                                                                              self.inverse_mass_matrix,
                                                                              self.step_size,
                                                                              self.batch_size,
                                                                              self.num_steps,
                                                                              z_grads=z_grads)
            # apply Metropolis correction.
            energy_proposal = self._kinetic_energy(r_new) + potential_energy_new
        delta_energy = energy_proposal - energy_current
        # Set accept prob to 0.0 if delta_energy is `NaN` which may be
        # the case for a diverging trajectory when using a large step size.
        if torch_isnan(delta_energy):
            accept_prob = delta_energy.new_tensor(0.0)
        else:
            accept_prob = (-delta_energy).exp().clamp(max=1.)
        rand = torch.rand(self.batch_size)
        accepted = rand < accept_prob
        self._accept_cnt += accepted.sum()/self.batch_size

        # select accepted zs to get z_new
        transitioned_z = {}
        transitioned_grads = {}
        for name in z:
            assert len(z_grads[name].shape) == 2
            assert z_grads[name].shape[0] == self.batch_size
            assert len(z[name].shape) == 2
            assert z[name].shape[0] == self.batch_size
            old_val = z[name]
            old_grad = z_grads[name]
            new_val = z_new[name]
            new_grad = z_grads_new[name]
            val_dim = old_val.shape[1]
            accept_val = accepted.view(self.batch_size, 1).repeat(1, val_dim)
            transitioned_z[name] = torch.where(accept_val, new_val, old_val)
            transitioned_grads[name] = torch.where(accept_val, new_grad, old_grad)

        self._cache(transitioned_z,
                    potential_energy,
                    transitioned_grads)

        if self._t < self._warmup_steps:
            self._adapter.step(self._t, transitioned_z, accept_prob)

        self._t += 1

        # get trace with the constrained values for `z`.
        z = transitioned_z.copy()
        for name, transform in self.transforms.items():
            z[name] = transform.inv(z[name])
        return self._get_trace(z)
Example #35
def main(**kwargs):
    param = utils.load_param()

    # Overwrite param with whatever is in kwargs:
    try:
        for key, val in kwargs.items():
            logging.info(f"Overwriting parameter {key} to {val}.")
            param[key] = val
    except Exception:
        logging.info("ERROR: Did not overwrite default param.")

    if param['device'] == "cuda":
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    if param.get('real_data'):
        logging.info("Loading real data")

        ind2val, itemattr, dataloaders = prepare.load_dataloaders(
            data_dir="data_real",
            data_type="lake-noclickrate-0.2",
            batch_size=param['batch_size'],
            split_trainvalid=param['split_trainvalid'],
            num_workers=0,
            override_candidate_sampler="actual",
            t_testsplit=param['t_testsplit'])

    else:
        sim_param = utils.load_sim_param()
        #%% Place all items in a group:
        item_group = 1 + (torch.arange(sim_param['num_items']) //
                          (sim_param['num_items'] /
                           (sim_param['num_groups'] - 1))).long()
        item_group[:3] = 0  # first three items are special group
        itemattr = {'category': item_group.cpu().numpy()}

        # %% TRAIN: MODEL+CALLBACKS+TRAINER
        pyro.clear_param_store()
        env = models.PyroRecommender(**sim_param,
                                     item_group=torch.tensor(
                                         itemattr['category']))
        env.init_set_of_real_parameters()
        sim = simulator.Simulator(**sim_param, env=env)
        ind2val, itemattr, dataloaders, sim = simulator.collect_simulated_data(
            sim,
            policy_epsilon=sim_param['collect_data_randompct'],
            **sim_param)

    param['num_items'] = len(ind2val['itemId'])
    param['num_groups'] = len(np.unique(itemattr['category']))
    param['num_users'], param['maxlen_time'], _ = dataloaders[
        'train'].dataset.data['action'].size()
    param['num_users'] = param['num_users'] + 1
    #param['num_displayTypes'] = 3
    # Move data to device
    #for key, val in dataloaders['train'].dataset.data.items():
    #    dataloaders['train'].dataset.data[key] = val.to(param['device'])
    #%%
    pyro.clear_param_store()
    pyro.enable_validation(True)  # global toggle; a bare validation_enabled(True) call is a no-op without `with`
    torch.manual_seed(param['train_seed'])
    import pyrotrainer
    dummybatch = next(iter(dataloaders['train']))
    dummybatch['phase_mask'] = dummybatch['mask_train']
    dummybatch = {
        key: val.long().to(param.get("device"))
        for key, val in dummybatch.items()
    }

    model = models.PyroRecommender(**param,
                                   item_group=torch.tensor(
                                       itemattr['category']).long())
    guide = models.MeanFieldGuide(model=model, batch=dummybatch, **param)

    #%% START WITH TRUE PARAMETERS IF THIS IS TRUE:
    if param.get("start_true"):
        logging.info(f"Starting in true mean parameters...:")
        pyro.clear_param_store()
        for key, val in env.par_real.items():
            pyro.param(f"{key}-mean", val)
            pyro.param(f"{key}-scale", torch.zeros_like(val) + 1e-5)
            print(key)

    #%% Define callbacks:

    # Common callbacks:
    optim = pyrotrainer.SviStep(model=model, guide=guide, **param)

    step_callbacks = [optim, pyrotrainer.calc_batch_stats]

    phase_end_callbacks = [
        pyrotrainer.report_phase_end,
        pyrotrainer.ReportPyroParameters(),
        pyrotrainer.EarlyStoppingAndCheckpoint(
            stopping_criteria=param['stopping_criteria'],
            patience=param['patience'],
            name=param['name'])
    ]

    after_training_callbacks = []

    if param['real_data']:
        plot_finn_ads = pyrotrainer.PlotFinnAdsRecommended(ind2val,
                                                           epoch_interval=3)
        phase_end_callbacks.append(plot_finn_ads)
        after_training_callbacks.append(pyrotrainer.VisualizeEmbeddings())
    else:
        test_sim = simulator.Simulator(**param, env=env)
        step_callbacks.append(pyrotrainer.Simulator_batch_stats(test_sim))
        after_training_callbacks.append(
            pyrotrainer.VisualizeEmbeddings(sim=test_sim))
        after_training_callbacks.append(
            pyrotrainer.RewardComputation(param, test_sim))

    after_training_callbacks.append(pyrotrainer.ReportHparam(param))
    #%%
    trainer = pyrotrainer.PyroTrainer(
        model,
        guide,
        dataloaders,
        before_training_callbacks=[pyrotrainer.checksum_data],
        after_training_callbacks=after_training_callbacks,
        step_callbacks=step_callbacks,
        phase_end_callbacks=phase_end_callbacks,
        max_epoch=param['max_epochs'],
        **param)
    return param, ind2val, trainer