Example #1
0
    def _setup_prototype(self, *args, **kwargs):
        """Register a learnable ``loc`` and ``scale`` for every stochastic site.

        Runs the parent ``_setup_prototype`` first, which is expected to
        populate ``self.prototype_trace``. Then, for each stochastic site of
        that trace:

        * the site's value is mapped to unconstrained space via
          ``biject_to(support).inv`` and used as the initial ``loc``
          (``constraints.real``), stored in ``self.locs`` under the site name;
        * a same-shaped tensor filled with ``self._init_scale`` becomes the
          initial ``scale`` (``constraints.positive``), stored in ``self.scales``;
        * the unconstrained event dim goes into ``self._event_dims`` and the
          site's ``cond_indep_stack`` into ``self._cond_indep_stacks``.
        """
        super()._setup_prototype(*args, **kwargs)

        self._event_dims = {}
        self._cond_indep_stacks = {}
        self.locs = PyroModule()
        self.scales = PyroModule()

        for site_name, site in self.prototype_trace.iter_stochastic_nodes():
            # The bijection to unconstrained space can change the tensor rank,
            # so the unconstrained event_dim is corrected by the rank difference.
            to_unconstrained = biject_to(site["fn"].support).inv
            constrained_value = site["value"].detach()
            loc_init = to_unconstrained(constrained_value).detach()
            unconstrained_event_dim = (
                site["fn"].event_dim + loc_init.dim() - site["value"].dim()
            )
            self._event_dims[site_name] = unconstrained_event_dim
            self._cond_indep_stacks[site_name] = site["cond_indep_stack"]

            # Under subsampling the prototype value only spans a minibatch of a
            # plate; tile it out so the parameter has the plate's full size.
            for plate_frame in site["cond_indep_stack"]:
                plate_full_size = getattr(plate_frame, "full_size", plate_frame.size)
                if plate_full_size == plate_frame.size:
                    continue
                batch_dim = plate_frame.dim - unconstrained_event_dim
                loc_init = periodic_repeat(
                    loc_init, plate_full_size, batch_dim
                ).contiguous()

            scale_init = torch.full_like(loc_init, self._init_scale)
            loc_param = PyroParam(loc_init, constraints.real, unconstrained_event_dim)
            scale_param = PyroParam(
                scale_init, constraints.positive, unconstrained_event_dim
            )
            _deep_setattr(self.locs, site_name, loc_param)
            _deep_setattr(self.scales, site_name, scale_param)
Example #2
0
 def _getattr(obj, attr):
     obj_next = getattr(obj, attr, None)
     if obj_next is not None:
         return obj_next
     setattr(obj, attr, PyroModule())
     return getattr(obj, attr)
    def _setup_prototype(self, *args, **kwargs):
        """Create the per-site output layers (and optional encoders) of the guide.

        Runs the parent ``_setup_prototype`` first, which is expected to
        populate ``self.prototype_trace``. Then, for every stochastic site
        ``name`` in that trace, this method:

        * records the site's unconstrained event dim in ``self._event_dims``
          and its ``cond_indep_stack`` in ``self._cond_indep_stacks``;
        * registers two float32 weight matrices, ``self.hidden2locs.<name>``
          and ``self.hidden2scales.<name>``:
          - shape ``(n_hidden, self.amortised_plate_sites["sites"][name])``;
          - drawn from ``N(0, self.init_param_scale / sqrt(n_hidden))``;
          - placed on the site value's device;
        * when ``"multiple"`` is in ``self.encoder_mode``, registers an encoder
          at ``self.multiple_encoders.<name>``:
          - a deep copy of ``self.encoder_instance``, converted with
            ``to_pyro_module_``, if an instance was supplied;
          - otherwise a fresh ``self.encoder_class(n_in=self.multiple_n_in,
            n_out=n_hidden, **self.multi_encoder_kwargs)``.

        Raises:
            ValueError: if a site is processed while ``self.encoder_mode``
                contains neither ``"multiple"`` nor ``"single"``. Previously
                this surfaced as an ``UnboundLocalError`` on ``n_hidden``.
        """
        super()._setup_prototype(*args, **kwargs)

        self._event_dims = {}
        self._cond_indep_stacks = {}
        self.hidden2locs = PyroModule()
        self.hidden2scales = PyroModule()

        multiple_mode = "multiple" in self.encoder_mode
        if multiple_mode:
            # create module for collecting multiple encoder NN
            self.multiple_encoders = PyroModule()

        def _n_hidden_for(site_name):
            # Width of the hidden representation feeding this site's layers.
            if multiple_mode:
                if "multiple" in self.n_hidden:
                    return self.n_hidden["multiple"]
                return self.n_hidden[site_name]
            if "single" in self.encoder_mode:
                return self.n_hidden["single"]
            raise ValueError(
                f"Unsupported encoder_mode {self.encoder_mode!r}: "
                "expected it to contain 'single' or 'multiple'."
            )

        def _init_hidden_weight(param_dim, n_hidden, device):
            # Scaled Gaussian init. The numpy call and shape are unchanged so a
            # fixed numpy seed yields the same weights as before.
            init_param = np.random.normal(
                np.zeros(param_dim),
                (np.ones(param_dim) * self.init_param_scale) / np.sqrt(n_hidden),
            ).astype("float32")
            return PyroParam(torch.tensor(init_param, device=device, requires_grad=True))

        # Initialize guide params
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            # Collect unconstrained event_dims, which may differ from constrained event_dims.
            with helpful_support_errors(site):
                init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach()
            event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim()
            self._event_dims[name] = event_dim

            # Collect independence contexts.
            self._cond_indep_stacks[name] = site["cond_indep_stack"]

            device = site["value"].device
            n_hidden = _n_hidden_for(name)

            # Linear layers mapping the hidden representation to this site's
            # locs and scales. Locs are drawn before scales, matching the
            # original RNG consumption order.
            param_dim = (n_hidden, self.amortised_plate_sites["sites"][name])
            _deep_setattr(self.hidden2locs, name, _init_hidden_weight(param_dim, n_hidden, device))
            _deep_setattr(self.hidden2scales, name, _init_hidden_weight(param_dim, n_hidden, device))

            if multiple_mode:
                if self.encoder_instance is not None:
                    # copy the supplied instance and convert it to a pyro module
                    encoder = deepcopy(self.encoder_instance).to(device)
                    to_pyro_module_(encoder)
                else:
                    # create a fresh instance
                    encoder = self.encoder_class(
                        n_in=self.multiple_n_in, n_out=n_hidden, **self.multi_encoder_kwargs
                    ).to(device)
                _deep_setattr(self.multiple_encoders, name, encoder)