Example #1
0
    def __init__(
        self,
        n_input: int,
        n_conditions: int,
        lam_scale: float,
        bias_scale: float,
        alpha: float = 1.0,
    ):
        """Register priors and variational (guide) parameters.

        :param n_input: number of input features (length of the weight vector).
        :param n_conditions: number of conditions (length of the bias vector).
        :param lam_scale: scale of the prior/guide over the weights.
        :param bias_scale: scale of the prior/guide over the biases.
        :param alpha: weight on the monotonic constraint.
        """
        super().__init__()

        self.n_input = n_input
        self.n_conditions = n_conditions

        # Non-trainable scalars are registered as buffers so they follow the
        # module across device/dtype moves: the monotonic-constraint weight
        # and the prior scales for weights and bias.
        self.register_buffer("alpha", torch.as_tensor(alpha))
        self.register_buffer("lam_p_scale", torch.as_tensor(lam_scale))
        self.register_buffer("bias_p_scale", torch.as_tensor(bias_scale))

        # Guide parameters: normal-initialized locations plus
        # positive-constrained scales for both weights and biases.
        self.weight_loc = nn.Parameter(
            nn.init.normal_(torch.Tensor(n_input)))
        self.weight_scale = PyroParam(
            torch.full((n_input, ), lam_scale),
            constraint=constraints.positive)
        self.bias_loc = nn.Parameter(
            nn.init.normal_(torch.Tensor(n_conditions)))
        self.bias_scale = PyroParam(
            torch.full((n_conditions, ), bias_scale),
            constraint=constraints.positive)
Example #2
0
File: guides.py  Project: yufengwa/pyro
    def _setup_prototype(self, *args, **kwargs):
        """Trace the model once, then create one (loc, scale) PyroParam pair
        per latent sample site, parameterized in each site's unconstrained
        space.
        """
        super()._setup_prototype(*args, **kwargs)

        # Per-site metadata populated in the loop below.
        self._event_dims = {}
        self._cond_indep_stacks = {}
        self.locs = PyroModule()
        self.scales = PyroModule()

        # Initialize guide params
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            # Collect unconstrained event_dims, which may differ from constrained event_dims.
            with helpful_support_errors(site):
                init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach()
            # The inverse bijection can change the number of dims; correct
            # event_dim by the observed difference.
            event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim()
            self._event_dims[name] = event_dim

            # Collect independence contexts.
            self._cond_indep_stacks[name] = site["cond_indep_stack"]

            # If subsampling, repeat init_value to full size.
            for frame in site["cond_indep_stack"]:
                full_size = getattr(frame, "full_size", frame.size)
                if full_size != frame.size:
                    dim = frame.dim - event_dim
                    init_loc = periodic_repeat(init_loc, full_size, dim).contiguous()
            init_scale = torch.full_like(init_loc, self._init_scale)

            # Store params under dotted attribute paths mirroring site names.
            _deep_setattr(self.locs, name, PyroParam(init_loc, constraints.real, event_dim))
            _deep_setattr(self.scales, name,
                          PyroParam(init_scale, self.scale_constraint, event_dim))
Example #3
0
    def __init__(self,
                 nu=1.5,
                 num_gps=1,
                 length_scale_init=None,
                 kernel_scale_init=None):
        """Matern kernel over a batch of independent GPs.

        :param nu: smoothness parameter; one of 0.5, 1.5, 2.5.
        :param num_gps: number of independent GPs.
        :param length_scale_init: optional (num_gps,) initial length scales.
        :param kernel_scale_init: optional (num_gps,) initial kernel scales.
        :raises NotImplementedError: for unsupported values of nu.
        """
        if nu not in [0.5, 1.5, 2.5]:
            raise NotImplementedError(
                "The only supported values of nu are 0.5, 1.5 and 2.5")
        self.nu = nu
        # The state-space dimension of the representation grows with nu.
        self.state_dim = {0.5: 1, 1.5: 2, 2.5: 3}[nu]
        self.num_gps = num_gps

        # Default both scale initializations to ones, then sanity-check shapes.
        if length_scale_init is None:
            length_scale_init = torch.ones(num_gps)
        if kernel_scale_init is None:
            kernel_scale_init = torch.ones(num_gps)
        assert length_scale_init.shape == (num_gps, )
        assert kernel_scale_init.shape == (num_gps, )

        super().__init__()

        self.length_scale = PyroParam(length_scale_init,
                                      constraint=constraints.positive)
        self.kernel_scale = PyroParam(kernel_scale_init,
                                      constraint=constraints.positive)

        # One single-entry indicator matrix per (row, col) position of the
        # state; registered as buffers so they follow the module's device.
        if self.state_dim > 1:
            for row in range(self.state_dim):
                for col in range(self.state_dim):
                    indicator = torch.zeros(self.state_dim, self.state_dim)
                    indicator[row, col] = 1.0
                    self.register_buffer("mask{}{}".format(row, col),
                                         indicator)
Example #4
0
    def __init__(self,
                 obs_dim=1,
                 state_dim=2,
                 obs_noise_scale_init=None,
                 learnable_observation_loc=False):
        """Linear-Gaussian state-space model parameters.

        :param obs_dim: dimension of the observations.
        :param state_dim: dimension of the latent state.
        :param obs_noise_scale_init: optional (obs_dim,) initial observation
            noise scales; defaults to 0.2 everywhere.
        :param learnable_observation_loc: whether the observation offset is
            a trainable parameter (otherwise a fixed zero buffer).
        """
        self.obs_dim = obs_dim
        self.state_dim = state_dim

        # Default the observation noise, then check the expected shape.
        if obs_noise_scale_init is None:
            obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
        assert obs_noise_scale_init.shape == (obs_dim, )

        super().__init__()

        self.obs_noise_scale = PyroParam(
            obs_noise_scale_init, constraint=constraints.positive)
        self.trans_noise_scale_sq = PyroParam(
            torch.ones(state_dim), constraint=constraints.positive)
        # Transition matrix starts as a small random perturbation of identity.
        init_trans = torch.eye(state_dim) + 0.03 * torch.randn(
            state_dim, state_dim)
        self.trans_matrix = nn.Parameter(init_trans)
        self.obs_matrix = nn.Parameter(0.3 * torch.randn(state_dim, obs_dim))
        self.init_noise_scale_sq = PyroParam(
            torch.ones(state_dim), constraint=constraints.positive)

        # The observation offset is either learned or a fixed zero buffer.
        if learnable_observation_loc:
            self.obs_loc = nn.Parameter(torch.zeros(obs_dim))
        else:
            self.register_buffer('obs_loc', torch.zeros(obs_dim))
Example #5
0
    def init_mvn_guide(self):
        """Initialize a multivariate-normal guide: zero mean and a
        0.1-scaled identity lower-Cholesky factor over all parameters."""
        loc_init = torch.full((self.n_params, ), 0.0)
        scale_tril_init = eye_like(loc_init, self.n_params) * 0.1

        _deep_setattr(self, "mvn.loc", PyroParam(loc_init, constraints.real))
        _deep_setattr(self, "mvn.scale_tril",
                      PyroParam(scale_tril_init, constraints.lower_cholesky))
Example #6
0
    def __init__(
        self,
        obs_dim=1,
        state_dim=2,
        nu=1.5,
        obs_noise_scale_init=None,
        length_scale_init=None,
        kernel_scale_init=None,
        learnable_observation_loc=False,
    ):
        """State-space model combining a linear latent block with one
        independent Matern GP per observed dimension.

        :param obs_dim: dimension of the observations.
        :param state_dim: dimension of the linear latent state.
        :param nu: Matern smoothness passed to the kernel.
        :param obs_noise_scale_init: optional (obs_dim,) initial observation
            noise scales; defaults to 0.2 everywhere.
        :param length_scale_init: optional initial kernel length scales.
        :param kernel_scale_init: optional initial kernel scales.
        :param learnable_observation_loc: whether the observation offset is
            trainable (otherwise a fixed zero buffer).
        """
        self.obs_dim = obs_dim
        self.state_dim = state_dim
        self.nu = nu

        # Default the observation noise, then check the expected shape.
        if obs_noise_scale_init is None:
            obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
        assert obs_noise_scale_init.shape == (obs_dim, )

        super().__init__()

        self.kernel = MaternKernel(
            nu=nu,
            num_gps=obs_dim,
            length_scale_init=length_scale_init,
            kernel_scale_init=kernel_scale_init,
        )
        self.dt = 1.0
        # Full state = one kernel state per observed dim plus the linear z.
        self.full_state_dim = self.kernel.state_dim * obs_dim + state_dim
        self.full_gp_state_dim = self.kernel.state_dim * obs_dim

        self.obs_noise_scale = PyroParam(
            obs_noise_scale_init, constraint=constraints.positive)
        self.trans_noise_scale_sq = PyroParam(
            torch.ones(state_dim), constraint=constraints.positive)
        # z transition starts as a small perturbation of the identity.
        self.z_trans_matrix = nn.Parameter(
            torch.eye(state_dim) + 0.03 * torch.randn(state_dim, state_dim))
        self.z_obs_matrix = nn.Parameter(0.3 * torch.randn(state_dim, obs_dim))
        self.init_noise_scale_sq = PyroParam(
            torch.ones(state_dim), constraint=constraints.positive)

        # Selector mapping each GP's function value into observation space:
        # a one-hot row at the start of each GP's state block.
        gp_obs = torch.zeros(self.kernel.state_dim * obs_dim, obs_dim)
        for gp in range(obs_dim):
            gp_obs[self.kernel.state_dim * gp, gp] = 1.0
        self.register_buffer("gp_obs_matrix", gp_obs)

        # Indices of each GP's function value within the packed GP state.
        self.obs_selector = torch.tensor(
            [self.kernel.state_dim * gp for gp in range(obs_dim)],
            dtype=torch.long)

        # The observation offset is either learned or a fixed zero buffer.
        if learnable_observation_loc:
            self.obs_loc = nn.Parameter(torch.zeros(obs_dim))
        else:
            self.register_buffer("obs_loc", torch.zeros(obs_dim))
Example #7
0
 def _setup_prototype(self, *args, **kwargs):
     """Create mean-field guide parameters: a learnable location vector
     and a positive per-coordinate scale over the packed latent space."""
     super()._setup_prototype(*args, **kwargs)
     self.loc = nn.Parameter(self._init_loc())
     init_scale = self.loc.new_full((self.latent_dim, ), self._init_scale)
     self.scale = PyroParam(init_scale, constraints.positive)
Example #8
0
    def __init__(
        self,
        nu=1.5,
        dt=1.0,
        obs_dim=1,
        length_scale_init=None,
        kernel_scale_init=None,
        obs_noise_scale_init=None,
    ):
        """Independent Matern GPs observed through diagonal Gaussian noise.

        :param nu: Matern smoothness passed to the kernel.
        :param dt: time step between observations.
        :param obs_dim: dimension of the observations (one GP each).
        :param length_scale_init: optional initial kernel length scales.
        :param kernel_scale_init: optional initial kernel scales.
        :param obs_noise_scale_init: optional (obs_dim,) initial observation
            noise scales; defaults to 0.2 everywhere.
        """
        self.nu = nu
        self.dt = dt
        self.obs_dim = obs_dim

        # Default the observation noise, then check the expected shape.
        if obs_noise_scale_init is None:
            obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
        assert obs_noise_scale_init.shape == (obs_dim, )

        super().__init__()

        self.kernel = MaternKernel(
            nu=nu,
            num_gps=obs_dim,
            length_scale_init=length_scale_init,
            kernel_scale_init=kernel_scale_init,
        )

        self.obs_noise_scale = PyroParam(
            obs_noise_scale_init, constraint=constraints.positive)

        # Each GP is observed through its function value, i.e. the first
        # coordinate of its state vector.
        selector = [1.0] + [0.0] * (self.kernel.state_dim - 1)
        self.register_buffer("obs_matrix",
                             torch.tensor(selector).unsqueeze(-1))
Example #9
0
    def _setup_prototype(self, *args, **kwargs):
        """For every latent site, expose its traced initial value as a
        PyroParam constrained to the site's support."""
        super()._setup_prototype(*args, **kwargs)

        for name, site in self.prototype_trace.iter_stochastic_nodes():
            init_value = site["value"].detach()
            param = PyroParam(init_value, constraint=site["fn"].support)
            _deep_setattr(self, name, param)
Example #10
0
    def __init__(self, nu=1.5, dt=1.0, obs_dim=2, num_gps=1,
                 length_scale_init=None, kernel_scale_init=None,
                 obs_noise_scale_init=None):
        """Linearly-coupled Matern GPs: `num_gps` latent GPs mixed into
        `obs_dim` observed dimensions through a learnable matrix A.

        :param nu: Matern smoothness passed to the kernel.
        :param dt: time step between observations.
        :param obs_dim: dimension of the observations; must be > 1.
        :param num_gps: number of latent GPs.
        :param length_scale_init: optional initial kernel length scales.
        :param kernel_scale_init: optional initial kernel scales.
        :param obs_noise_scale_init: optional (obs_dim,) initial observation
            noise scales; defaults to 0.2 everywhere.
        """
        # Fix: the original assigned self.dt, self.obs_dim and self.num_gps
        # twice; the redundant second set of assignments has been removed.
        self.nu = nu
        self.dt = dt
        assert obs_dim > 1, "If obs_dim==1 you should use IndependentMaternGP"
        self.obs_dim = obs_dim
        self.num_gps = num_gps

        # Default the observation noise, then check the expected shape.
        if obs_noise_scale_init is None:
            obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
        assert obs_noise_scale_init.shape == (obs_dim,)

        super().__init__()

        self.kernel = MaternKernel(nu=nu, num_gps=num_gps,
                                   length_scale_init=length_scale_init,
                                   kernel_scale_init=kernel_scale_init)
        self.full_state_dim = num_gps * self.kernel.state_dim

        self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                         constraint=constraints.positive)
        # Mixing matrix from latent GPs to observed dimensions.
        self.A = nn.Parameter(0.3 * torch.randn(self.num_gps, self.obs_dim))
Example #11
0
 def _setup_prototype(self, *args, **kwargs):
     """Create multivariate-normal guide parameters: a learnable location
     and a lower-Cholesky scale_tril initialized to a scaled identity."""
     super()._setup_prototype(*args, **kwargs)
     self.loc = nn.Parameter(self._init_loc())
     init_tril = eye_like(self.loc, self.latent_dim) * self._init_scale
     self.scale_tril = PyroParam(init_tril, constraints.lower_cholesky)
Example #12
0
    def __init__(
        self,
        nu=1.5,
        dt=1.0,
        obs_dim=1,
        linearly_coupled=False,
        length_scale_init=None,
        obs_noise_scale_init=None,
    ):
        """Matern GPs coupled through a shared Wiener noise process.

        :param nu: Matern smoothness; only 1.5 is supported.
        :param dt: time step between observations.
        :param obs_dim: dimension of the observations.
        :param linearly_coupled: if True, observations mix the GP function
            values through a learnable matrix; otherwise each observed
            dimension reads its own GP directly.
        :param length_scale_init: optional initial kernel length scales.
        :param obs_noise_scale_init: optional (obs_dim,) initial observation
            noise scales; defaults to 0.2 everywhere.
        :raises NotImplementedError: if nu is not 1.5.
        """

        if nu != 1.5:
            raise NotImplementedError("The only supported value of nu is 1.5")

        self.dt = dt
        self.obs_dim = obs_dim

        # Default the observation noise, then check the expected shape.
        if obs_noise_scale_init is None:
            obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
        assert obs_noise_scale_init.shape == (obs_dim, )

        super().__init__()

        self.kernel = MaternKernel(nu=nu,
                                   num_gps=obs_dim,
                                   length_scale_init=length_scale_init)
        self.full_state_dim = self.kernel.state_dim * obs_dim

        # we demote self.kernel.kernel_scale from being a nn.Parameter
        # since the relevant scales are now encoded in the wiener noise matrix
        del self.kernel.kernel_scale
        self.kernel.register_buffer("kernel_scale", torch.ones(obs_dim))

        self.obs_noise_scale = PyroParam(
            obs_noise_scale_init, constraint=constraints.positive)
        # Lower-triangular Wiener noise factor, initialized near identity
        # with small random strictly-lower-triangular entries.
        init_tril = (torch.eye(obs_dim)
                     + 0.03 * torch.randn(obs_dim, obs_dim).tril(-1))
        self.wiener_noise_tril = PyroParam(
            init_tril, constraint=constraints.lower_cholesky)

        if linearly_coupled:
            self.obs_matrix = nn.Parameter(
                0.3 * torch.randn(self.obs_dim, self.obs_dim))
        else:
            # One-hot selector picking each GP's function value (the first
            # coordinate of its state block) out of the full state vector.
            selector = torch.zeros(self.full_state_dim, obs_dim)
            for gp in range(obs_dim):
                selector[self.kernel.state_dim * gp, gp] = 1.0
            self.register_buffer("obs_matrix", selector)
Example #13
0
 def _setup_prototype(self, *args, **kwargs):
     """Create low-rank multivariate-normal guide parameters: a location,
     a positive diagonal scale, and a rank-`rank` covariance factor."""
     super()._setup_prototype(*args, **kwargs)
     self.loc = nn.Parameter(self._init_loc())
     # Default the rank to roughly sqrt of the latent dimension.
     if self.rank is None:
         self.rank = int(round(self.latent_dim ** 0.5))
     diag_init = 0.5 ** 0.5 * self._init_scale
     self.scale = PyroParam(
         self.loc.new_full((self.latent_dim,), diag_init),
         constraint=constraints.positive)
     # Factor entries scaled so each adds ~1/rank of unit variance.
     self.cov_factor = nn.Parameter(
         self.loc.new_empty(self.latent_dim,
                            self.rank).normal_(0, 1 / self.rank ** 0.5))
Example #14
0
    def __init__(self,
                 sites,
                 name='',
                 init_scale: tp.Union[torch.Tensor, float] = 1.,
                 *args,
                 **kwargs):
        """Diagonal-normal guide: after the base class builds `self.loc`,
        derive a positive per-coordinate scale from `init_scale` and the
        Jacobian at the current location.

        :param sites: sample sites handled by this guide.
        :param name: optional name prefix for the guide.
        :param init_scale: initial scale, scalar or tensor.
        """
        super().__init__(sites, name, *args, **kwargs)

        jac = self.jacobian(self.loc)
        self.scale = PyroParam(self._scale_diagonal(init_scale, jac),
                               event_dim=1,
                               constraint=constraints.positive)
Example #15
0
    def __init__(self,
                 sites,
                 name='',
                 diag=_nomatch,
                 init_scale_full: tp.Union[torch.Tensor, float] = 1.,
                 init_scale_diag: tp.Union[torch.Tensor, float] = 1.,
                 *args,
                 **kwargs):
        """Structured-normal guide: sites whose names match `diag` get a
        per-coordinate diagonal scale, the remaining sites share a full
        lower-Cholesky scale, and a learnable cross block couples the two.

        :param sites: iterable of site dicts (each carrying a 'name' key).
        :param name: optional name prefix passed to the base class.
        :param diag: regex pattern; matching site names use the diagonal
            block (default `_nomatch` puts everything in the full block).
        :param init_scale_full: initial scale for the full block.
        :param init_scale_diag: initial scale for the diagonal block.
        """
        self.diag_pattern = re.compile(diag)
        # Split sites into two name->site dicts keyed by whether the name
        # matches diag_pattern.
        # NOTE(review): assumes `partition` yields the non-matching group
        # first (-> sites_full) and the matching group second (-> sites_diag)
        # — confirm against the helper's definition.
        self.sites_full, self.sites_diag = ({
            site['name']: site
            for site in _
        } for _ in partition(lambda _: self.diag_pattern.match(_['name']),
                             sites))

        # Base class sees the union, full-block sites first.
        super().__init__(
            dict_union(self.sites_full, self.sites_diag).values(), name, *args,
            **kwargs)

        # Packed sizes of the two groups within the flattened latent vector.
        self.size_full, self.size_diag = (sum(self.sizes[site] for site in _)
                                          for _ in (self.sites_full,
                                                    self.sites_diag))

        jac = self.jacobian(self.loc)

        # Full block: lower-Cholesky factor over the first size_full coords.
        self.scale_full = PyroParam(self._scale_matrix(init_scale_full,
                                                       jac[:self.size_full]),
                                    event_dim=2,
                                    constraint=constraints.lower_cholesky)
        # Cross block: zero-initialized (no coupling at the start).
        self.scale_cross = PyroParam(self.loc.new_zeros(
            torch.Size((self.size_diag, self.size_full))),
                                     event_dim=2)
        # Diagonal block: positive per-coordinate scales.
        self.scale_diag = PyroParam(self._scale_diagonal(
            init_scale_diag, jac[self.size_full:]),
                                    event_dim=1,
                                    constraint=constraints.positive)

        # Auxiliary standard-normal noise over the full event shape,
        # reparameterized downstream by the scale blocks above.
        self.guide_z_aux = PyroSample(
            dist.Normal(self.loc.new_zeros(()),
                        1.).expand(self.event_shape).to_event(1))
Example #16
0
    def __init__(self,
                 sites,
                 name='',
                 init_scale: tp.Union[torch.Tensor, float] = 1.,
                 *args,
                 **kwargs):
        """Multivariate-normal guide: after the base class builds
        `self.loc`, derive a lower-Cholesky scale matrix from `init_scale`
        and the Jacobian at the current location.

        :param sites: sample sites handled by this guide.
        :param name: optional name prefix for the guide.
        :param init_scale: initial scale, scalar or tensor.
        """
        super().__init__(sites, name, *args, **kwargs)

        jac = self.jacobian(self.loc)
        self.scale_tril = PyroParam(self._scale_matrix(init_scale, jac),
                                    event_dim=2,
                                    constraint=constraints.lower_cholesky)
Example #17
0
    def _setup_prototype(self, *args, **kwargs):
        """Trace the model, require every sample site to be discrete with
        parallel enumeration, and create one PyroParam per distribution
        parameter (currently only `probs`).

        :raises NotImplementedError: for non-enumerated sites, unsupported
            distribution types, or sequential (non-vectorized) plates.
        """
        # run the model so we can inspect its structure
        model = config_enumerate(self.model)
        self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
            *args, **kwargs)
        self.prototype_trace = prune_subsample_sites(self.prototype_trace)
        # NOTE(review): self.master appears to be a callable (presumably a
        # weakref to the owning guide) — confirm before relying on this.
        if self.master is not None:
            self.master()._check_prototype(self.prototype_trace)

        self._discrete_sites = []
        self._cond_indep_stacks = {}
        self._plates = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            if site["infer"].get("enumerate") != "parallel":
                raise NotImplementedError(
                    'Expected sample site "{}" to be discrete and '
                    'configured for parallel enumeration'.format(name))

            # collect discrete sample sites
            fn = site["fn"]
            Dist = type(fn)
            if Dist in (dist.Bernoulli, dist.Categorical,
                        dist.OneHotCategorical):
                # Clone so the guide's init does not alias the model's tensor.
                params = [("probs", fn.probs.detach().clone(),
                           fn.arg_constraints["probs"])]
            else:
                raise NotImplementedError("{} is not supported".format(
                    Dist.__name__))
            self._discrete_sites.append((site, Dist, params))

            # collect independence contexts
            self._cond_indep_stacks[name] = site["cond_indep_stack"]
            for frame in site["cond_indep_stack"]:
                if frame.vectorized:
                    self._plates[frame.name] = frame
                else:
                    raise NotImplementedError(
                        "AutoDiscreteParallel does not support sequential pyro.plate"
                    )
        # Initialize guide params
        for site, Dist, param_spec in self._discrete_sites:
            name = site["name"]
            for param_name, param_init, param_constraint in param_spec:
                # Params are stored as "<site>_<param>" attributes.
                _deep_setattr(
                    self, "{}_{}".format(name, param_name),
                    PyroParam(param_init, constraint=param_constraint))
Example #18
0
File: guides.py  Project: ucals/pyro
    def _setup_prototype(self, *args, **kwargs):
        """For every latent site, register a support-constrained PyroParam
        holding the site's initial value, expanded to full plate size when
        the trace was subsampled."""
        super()._setup_prototype(*args, **kwargs)

        for name, site in self.prototype_trace.iter_stochastic_nodes():
            init_value = site["value"].detach()
            event_dim = site["fn"].event_dim

            # Under subsampling the traced value only covers a minibatch;
            # tile it out to each plate's full size.
            for frame in site["cond_indep_stack"]:
                full_size = getattr(frame, "full_size", frame.size)
                if full_size != frame.size:
                    batch_dim = frame.dim - event_dim
                    init_value = periodic_repeat(init_value, full_size,
                                                 batch_dim).contiguous()

            _deep_setattr(self, name,
                          PyroParam(init_value, site["fn"].support, event_dim))
Example #19
0
 def __init__(self, create_plates=None):
     """Manually declared guide parameters.

     :param create_plates: optional callable creating the guide's plates.
     """
     super().__init__()
     # we define parameters here; make sure that the shape is aligned
     # with the shapes of sample sites in model.
     ma_shape = (10, 1, 1, 2, 3, 7)
     snap_shape = (10, 1, 1, 2, 7)
     seasonal_shape = (10, 1, 7, 2, 7)
     self.ma_weight_loc = PyroParam(torch.zeros(*ma_shape), event_dim=3)
     self.ma_weight_scale = PyroParam(0.1 * torch.ones(*ma_shape),
                                      dist.constraints.positive,
                                      event_dim=3)
     self.snap_weight_loc = PyroParam(torch.zeros(*snap_shape), event_dim=2)
     self.snap_weight_scale = PyroParam(0.1 * torch.ones(*snap_shape),
                                        dist.constraints.positive,
                                        event_dim=2)
     self.seasonal_loc = PyroParam(torch.zeros(*seasonal_shape), event_dim=2)
     self.seasonal_scale = PyroParam(0.1 * torch.ones(*seasonal_shape),
                                     dist.constraints.positive,
                                     event_dim=2)
     self.create_plates = create_plates
Example #20
0
 def __init__(self, sites, name='', *args, **kwargs):
     """Delegate setup to the base class, then expose the masked initial
     values as the guide's learnable location."""
     super().__init__(sites, name, *args, **kwargs)
     masked_init = self.init[self.mask]
     self.loc = PyroParam(masked_init, event_dim=1)
Example #21
0
    def _setup_prototype(self, *args, **kwargs):
        """Build amortised guide parameters: per-site linear heads mapping
        encoder hidden activations to locs and scales and, in "multiple"
        encoder mode, one encoder network per site.
        """

        super()._setup_prototype(*args, **kwargs)

        # Per-site metadata populated in the loop below.
        self._event_dims = {}
        self._cond_indep_stacks = {}
        self.hidden2locs = PyroModule()
        self.hidden2scales = PyroModule()

        if "multiple" in self.encoder_mode:
            # create module for collecting multiple encoder NN
            self.multiple_encoders = PyroModule()

        # Initialize guide params
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            # Collect unconstrained event_dims, which may differ from constrained event_dims.
            with helpful_support_errors(site):
                init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach()
            event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim()
            self._event_dims[name] = event_dim

            # Collect independence contexts.
            self._cond_indep_stacks[name] = site["cond_indep_stack"]

            # determine the number of hidden layers
            # NOTE(review): if self.encoder_mode contains neither "multiple"
            # nor "single", n_hidden is never bound and the param_dim line
            # below raises NameError — confirm upstream validation.
            if "multiple" in self.encoder_mode:
                if "multiple" in self.n_hidden.keys():
                    n_hidden = self.n_hidden["multiple"]
                else:
                    n_hidden = self.n_hidden[name]
            elif "single" in self.encoder_mode:
                n_hidden = self.n_hidden["single"]
            # add linear layer for locs and scales
            param_dim = (n_hidden, self.amortised_plate_sites["sites"][name])
            # Gaussian init scaled by init_param_scale / sqrt(n_hidden).
            init_param = np.random.normal(
                np.zeros(param_dim),
                (np.ones(param_dim) * self.init_param_scale) / np.sqrt(n_hidden),
            ).astype("float32")
            _deep_setattr(
                self.hidden2locs,
                name,
                PyroParam(torch.tensor(init_param, device=site["value"].device, requires_grad=True)),
            )

            # Independent initialization for the scales head.
            init_param = np.random.normal(
                np.zeros(param_dim),
                (np.ones(param_dim) * self.init_param_scale) / np.sqrt(n_hidden),
            ).astype("float32")
            _deep_setattr(
                self.hidden2scales,
                name,
                PyroParam(torch.tensor(init_param, device=site["value"].device, requires_grad=True)),
            )

            if "multiple" in self.encoder_mode:
                # create multiple encoders
                if self.encoder_instance is not None:
                    # copy instances
                    encoder_ = deepcopy(self.encoder_instance).to(site["value"].device)
                    # convert to pyro module
                    to_pyro_module_(encoder_)
                    _deep_setattr(
                        self.multiple_encoders,
                        name,
                        encoder_,
                    )
                else:
                    # create instances
                    _deep_setattr(
                        self.multiple_encoders,
                        name,
                        self.encoder_class(n_in=self.multiple_n_in, n_out=n_hidden, **self.multi_encoder_kwargs).to(
                            site["value"].device
                        ),
                    )