Example #1
def test_to_pyro_module_():

    pyro.set_rng_seed(123)
    actual = nn.Sequential(
        nn.Linear(28 * 28, 200),
        nn.Sigmoid(),
        nn.Linear(200, 200),
        nn.Sigmoid(),
        nn.Linear(200, 10),
    )
    to_pyro_module_(actual)
    pyro.clear_param_store()

    pyro.set_rng_seed(123)
    expected = PyroModule[nn.Sequential](
        PyroModule[nn.Linear](28 * 28, 200),
        PyroModule[nn.Sigmoid](),
        PyroModule[nn.Linear](200, 200),
        PyroModule[nn.Sigmoid](),
        PyroModule[nn.Linear](200, 10),
    )
    pyro.clear_param_store()

    def assert_identical(a, e):
        assert type(a) is type(e)
        if isinstance(a, dict):
            assert set(a) == set(e)
            for key in a:
                assert_identical(a[key], e[key])
        elif isinstance(a, nn.Module):
            assert_identical(a.__dict__, e.__dict__)
        elif isinstance(a, (str, int, float, torch.Tensor)):
            assert_equal(a, e)

    assert_identical(actual, expected)

    # check output
    data = torch.randn(28 * 28)
    actual_out = actual(data)
    pyro.clear_param_store()
    expected_out = expected(data)
    assert_equal(actual_out, expected_out)

    # check randomization
    def randomize(model):
        for m in model.modules():
            for name, value in list(m.named_parameters(recurse=False)):
                setattr(
                    m,
                    name,
                    PyroSample(
                        prior=dist.Normal(0, 1)
                        .expand(value.shape)
                        .to_event(value.dim())
                    ),
                )

    randomize(actual)
    randomize(expected)
    assert_identical(actual, expected)
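
A minimal follow-up sketch (separate from the test above) of what the PyroSample substitution buys you: each former parameter becomes a pyro.sample site named after its attribute path, which a trace of the forward pass records. The small network here is illustrative.

import torch
import torch.nn as nn
import pyro.distributions as dist
from pyro import poutine
from pyro.nn import PyroSample
from pyro.nn.module import to_pyro_module_

net = nn.Sequential(nn.Linear(4, 2))
to_pyro_module_(net)
for m in net.modules():
    for name, value in list(m.named_parameters(recurse=False)):
        setattr(m, name, PyroSample(
            dist.Normal(0, 1).expand(value.shape).to_event(value.dim())))

# tracing the forward pass records one sample site per former parameter
trace = poutine.trace(net).get_trace(torch.randn(4))
print([n for n, s in trace.nodes.items() if s["type"] == "sample"])
# expected: ['0.weight', '0.bias']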
Example #2
def as_pyro_module(module):
    to_pyro_module_(module, recurse=True)
    for m in module.modules():
        for n, p in list(m.named_parameters(recurse=False)):
            setattr(
                m, n,
                PyroSample(
                    dist.Normal(torch.zeros_like(p),
                                torch.ones_like(p)).to_event()))
    return module
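
A short usage sketch for the helper above (the module choice is arbitrary): every parameter becomes a standard-normal sample site, so repeated forward passes on the same input draw fresh weights.

import torch
import torch.nn as nn

linear = as_pyro_module(nn.Linear(3, 1))
x = torch.randn(3)
y1 = linear(x)
y2 = linear(x)
assert not torch.allclose(y1, y2)  # weights are resampled on each call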
Example #3
def test_name_preserved_by_to_pyro_module():
    features = torch.randn(4)
    data = torch.randn(3)

    class Model(PyroModule):
        def __init__(self):
            super().__init__()
            self.scale = PyroParam(torch.ones(3), constraints.positive)
            self.loc = torch.nn.Linear(4, 3)

        def forward(self, features, data):
            loc = self.loc(features)
            scale = self.scale
            with pyro.plate("data", len(data)):
                pyro.sample("obs", dist.Normal(loc, scale), obs=data)

    model = Model()
    params = list(model.parameters())
    param_names = set()

    def optim_config(param_name):
        param_names.add(param_name)
        return {"lr": 0.0}

    # Record while model.loc is an nn.Module.
    loss = poutine.trace(model).get_trace(features, data).log_prob_sum()
    loss.backward()
    adam = optim.Adam(optim_config)
    adam(params)
    assert param_names
    expected_param_names = param_names.copy()
    del adam, loss
    param_names.clear()
    pyro.clear_param_store()

    # Record while model.loc is a PyroModule.
    to_pyro_module_(model.loc)
    loss = poutine.trace(model).get_trace(features, data).log_prob_sum()
    loss.backward()
    adam = optim.Adam(optim_config)
    adam(params)
    assert param_names
    actual_param_names = param_names.copy()
    del adam, loss
    param_names.clear()
    pyro.clear_param_store()

    assert actual_param_names == {"scale", "loc.weight", "loc.bias"}
    assert actual_param_names == expected_param_names
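
Because the names are preserved, a per-parameter optimiser configuration keyed on those names behaves identically whether model.loc is an nn.Module or a PyroModule. A hedged sketch (the learning rates are arbitrary):

from pyro import optim

def optim_config(param_name):
    # hypothetical per-parameter schedule keyed on the preserved names
    if param_name in ("loc.weight", "loc.bias"):
        return {"lr": 1e-4}
    return {"lr": 1e-2}

adam = optim.Adam(optim_config)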
Example #4
def test_bayesian_gru():
    input_size = 2
    hidden_size = 3
    batch_size = 4
    seq_len = 5

    # Construct a simple GRU.
    gru = nn.GRU(input_size, hidden_size)
    input_ = torch.randn(seq_len, batch_size, input_size)
    output, _ = gru(input_)
    assert output.shape == (seq_len, batch_size, hidden_size)
    output2, _ = gru(input_)
    assert torch.allclose(output2, output)

    # Make it Bayesian.
    to_pyro_module_(gru)
    for name, value in list(gru.named_parameters(recurse=False)):
        prior = dist.Normal(0, 1).expand(value.shape).to_event(value.dim())
        setattr(gru, name, PyroSample(prior=prior))
    output, _ = gru(input_)
    assert output.shape == (seq_len, batch_size, hidden_size)
    output2, _ = gru(input_)
    assert not torch.allclose(output2, output)
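
Continuing from the variables defined in the test above, a hedged sketch of fitting the Bayesian GRU with SVI and an AutoNormal guide; the regression likelihood on the final hidden state is illustrative, not part of the test.

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

def model(x, y):
    out, _ = gru(x)  # the forward pass samples GRU weights from their priors
    obs_scale = pyro.param("obs_scale", torch.ones(()),
                           constraint=constraints.positive)
    with pyro.plate("batch", x.size(1)):
        pyro.sample("y", dist.Normal(out[-1], obs_scale).to_event(1), obs=y)

guide = AutoNormal(model)
svi = SVI(model, guide, Adam({"lr": 1e-3}), Trace_ELBO())
y = torch.randn(batch_size, hidden_size)
for _ in range(10):
    svi.step(input_, y)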
Example #5
    def encode(self, name: str, prior: Distribution):
        """
        Apply encoder network to input data to obtain hidden layer encoding.
        Parameters
        ----------
        args
            Pyro model args
        kwargs
            Pyro model kwargs
        -------

        """
        try:
            args, kwargs = self.args_kwargs  # stored as a tuple of (tuple, dict)
            # fetch the input data for the NN from model args / kwargs
            in_names = self.amortised_plate_sites["input"]
            x_in = [
                kwargs[i] if i in kwargs.keys() else args[i] for i in in_names
            ]
            # apply data transform before passing to NN
            in_transforms = self.amortised_plate_sites["input_transform"]
            x_in = [in_transforms[i](x) for i, x in enumerate(x_in)]
            # apply learnable normalisation before passing to NN:
            input_normalisation = self.amortised_plate_sites.get(
                "input_normalisation", None)
            if input_normalisation is not None:
                for i in range(len(self.amortised_plate_sites["input"])):
                    if input_normalisation[i]:
                        x_in[i] = x_in[i] * deep_getattr(
                            self, f"input_normalisation_{i}")
            if "single" in self.encoder_mode:
                # encode with a single encoder
                res = deep_getattr(self, "one_encoder")(*x_in)
                if "multiple" in self.encoder_mode:
                    # when there is a second layer of multiple encoders, fetch the per-site encoder and encode the data
                    x_in[0] = res
                    res = deep_getattr(self.multiple_encoders, name)(*x_in)
            else:
                # when there are multiple encoders, fetch the per-site encoder and encode the data
                res = deep_getattr(self.multiple_encoders, name)(*x_in)
            return res
        except AttributeError:
            pass

        # Initialize.
        # create normalisation parameters if necessary:
        input_normalisation = self.amortised_plate_sites.get(
            "input_normalisation", None)
        if input_normalisation is not None:
            for i in range(len(self.amortised_plate_sites["input"])):
                if input_normalisation[i]:
                    deep_setattr(
                        self,
                        f"input_normalisation_{i}",
                        PyroParam(
                            torch.ones((1, self.single_n_in)).to(
                                prior.mean.device).requires_grad_(True)),
                    )
        # create encoder neural networks
        if "single" in self.encoder_mode:
            if self.encoder_instance is not None:
                # copy provided encoder instance
                one_encoder = deepcopy(self.encoder_instance).to(
                    prior.mean.device)
                # convert to pyro module
                to_pyro_module_(one_encoder)
                deep_setattr(self, "one_encoder", one_encoder)
            else:
                # create encoder instance from encoder class
                deep_setattr(
                    self,
                    "one_encoder",
                    self.encoder_class(n_in=self.single_n_in,
                                       n_out=self.n_hidden["single"],
                                       **self.encoder_kwargs).to(
                                           prior.mean.device),
                )
        if "multiple" in self.encoder_mode:
            # determine the number of hidden layers
            if name in self.n_hidden.keys():
                n_hidden = self.n_hidden[name]
            else:
                n_hidden = self.n_hidden["multiple"]
            multi_encoder_kwargs = deepcopy(self.multi_encoder_kwargs)
            multi_encoder_kwargs["n_hidden"] = n_hidden

            # create multiple encoders
            if self.encoder_instance is not None:
                # copy instances
                encoder_ = deepcopy(self.encoder_instance).to(
                    prior.mean.device)
                # convert to pyro module
                to_pyro_module_(encoder_)
                deep_setattr(
                    self,
                    "multiple_encoders." + name,
                    encoder_,
                )
            else:
                # create instances
                deep_setattr(
                    self,
                    "multiple_encoders." + name,
                    self.encoder_class(n_in=self.multiple_n_in,
                                       n_out=n_hidden,
                                       **multi_encoder_kwargs).to(
                                           prior.mean.device),
                )
        return self.encode(name, prior)

    def __init__(
        self,
        model,
        amortised_plate_sites: dict,
        n_in: int,
        n_hidden: dict = None,
        init_param=0,
        init_param_scale: float = 1 / 50,
        scales_offset: float = -2,
        encoder_class=FCLayersPyro,
        encoder_kwargs=None,
        multi_encoder_kwargs=None,
        encoder_instance: torch.nn.Module = None,
        create_plates=None,
        encoder_mode: Literal["single", "multiple", "single-multiple"] = "single",
    ):
        """

        Parameters
        ----------
        model
            Pyro model
        amortised_plate_sites
            Dictionary with amortised plate details:
             the name of observation/minibatch plate,
             indexes of model args to provide to encoder,
             variable names that belong to the observation plate
             and the number of dimensions in the non-plate axis of each variable, such as:
             {
                 "name": "obs_plate",
                 "input": [0],  # expression data + (optional) batch index ([0, 2])
                 "input_transform": [torch.log1p], # how to transform input data before passing to NN
                 "sites": {
                     "n_s_cells_per_location": 1,
                     "y_s_groups_per_location": 1,
                     "z_sr_groups_factors": self.n_groups,
                     "w_sf": self.n_factors,
                     "l_s_add": 1,
                 }
             }
        n_in
            Number of input dimensions (for encoder_class).
        n_hidden
            Number of hidden nodes in each layer, one of 3 options:
            1. Integer denoting the number of hidden nodes
            2. Dictionary such as {"single": 200, "multiple": 200} denoting the number of hidden nodes for each `encoder_mode` (see below)
            3. Dictionary allowing a different number of hidden nodes for each model site, with keys for single-encoder mode and for each model site:
            {
                     "single": 200,
                     "n_s_cells_per_location": 5,
                     "y_s_groups_per_location": 5,
                     "z_sr_groups_factors": 128,
                     "w_sf": 128,
                     "l_s_add": 5,
            }
        init_param
            Initial values for amortised variables (not implemented yet).
        init_param_scale
            Scale/normalisation for the initial values of the weights that map hidden layers to means and standard deviations.
        encoder_class
            Class for defining encoder network.
        encoder_kwargs
            Keyword arguments for encoder_class.
        multi_encoder_kwargs
            Optional separate keyword arguments for encoder_class, useful when encoder_mode == "single-multiple".
        encoder_instance
            Encoder network instance; overrides encoder_class, and the provided instance is copied with deepcopy.
        create_plates
            Function for creating plates.
        encoder_mode
            Use single encoder for all variables ("single"), one encoder per variable ("multiple")
            or a single encoder in the first step and multiple encoders in the second step ("single-multiple").
        """

        super().__init__(model, create_plates=create_plates)
        self.amortised_plate_sites = amortised_plate_sites
        self.encoder_mode = encoder_mode
        self.scales_offset = scales_offset

        self.softplus = SoftplusTransform()

        if n_hidden is None:
            n_hidden = {"single": 200, "multiple": 200}
        else:
            if isinstance(n_hidden, int):
                n_hidden = {"single": n_hidden, "multiple": n_hidden}
            elif not isinstance(n_hidden, dict):
                raise ValueError("n_hidden must be either int or dict")

        encoder_kwargs = encoder_kwargs if isinstance(encoder_kwargs, dict) else dict()
        encoder_kwargs["n_hidden"] = n_hidden["single"]
        self.encoder_kwargs = encoder_kwargs
        if multi_encoder_kwargs is None:
            multi_encoder_kwargs = deepcopy(encoder_kwargs)
        self.multi_encoder_kwargs = multi_encoder_kwargs
        if "multiple" in n_hidden.keys():
            self.multi_encoder_kwargs["n_hidden"] = n_hidden["multiple"]

        self.single_n_in = n_in
        self.multiple_n_in = n_in
        self.n_out = (
            np.sum([np.sum(amortised_plate_sites["sites"][k]) for k in amortised_plate_sites["sites"].keys()]) * 2
        )
        self.n_hidden = n_hidden
        self.encoder_class = encoder_class
        self.encoder_instance = encoder_instance
        if "single" in self.encoder_mode:
            # create a single encoder NN
            if encoder_instance is not None:
                self.one_encoder = deepcopy(encoder_instance)
                # convert to pyro module
                to_pyro_module_(self.one_encoder)
            else:
                self.one_encoder = encoder_class(
                    n_in=self.single_n_in, n_out=self.n_hidden["single"], **self.encoder_kwargs
                )
            if "multiple" in self.encoder_mode:
                self.multiple_n_in = self.n_hidden["single"]

        self.init_param_scale = init_param_scale

    def _setup_prototype(self, *args, **kwargs):

        super()._setup_prototype(*args, **kwargs)

        self._event_dims = {}
        self._cond_indep_stacks = {}
        self.hidden2locs = PyroModule()
        self.hidden2scales = PyroModule()

        if "multiple" in self.encoder_mode:
            # create module for collecting multiple encoder NN
            self.multiple_encoders = PyroModule()

        # Initialize guide params
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            # Collect unconstrained event_dims, which may differ from constrained event_dims.
            with helpful_support_errors(site):
                init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach()
            event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim()
            self._event_dims[name] = event_dim

            # Collect independence contexts.
            self._cond_indep_stacks[name] = site["cond_indep_stack"]

            # determine the number of hidden layers
            if "multiple" in self.encoder_mode:
                if "multiple" in self.n_hidden.keys():
                    n_hidden = self.n_hidden["multiple"]
                else:
                    n_hidden = self.n_hidden[name]
            elif "single" in self.encoder_mode:
                n_hidden = self.n_hidden["single"]
            # add linear layer for locs and scales
            param_dim = (n_hidden, self.amortised_plate_sites["sites"][name])
            init_param = np.random.normal(
                np.zeros(param_dim),
                (np.ones(param_dim) * self.init_param_scale) / np.sqrt(n_hidden),
            ).astype("float32")
            _deep_setattr(
                self.hidden2locs,
                name,
                PyroParam(torch.tensor(init_param, device=site["value"].device, requires_grad=True)),
            )

            init_param = np.random.normal(
                np.zeros(param_dim),
                (np.ones(param_dim) * self.init_param_scale) / np.sqrt(n_hidden),
            ).astype("float32")
            _deep_setattr(
                self.hidden2scales,
                name,
                PyroParam(torch.tensor(init_param, device=site["value"].device, requires_grad=True)),
            )

            if "multiple" in self.encoder_mode:
                # create multiple encoders
                if self.encoder_instance is not None:
                    # copy instances
                    encoder_ = deepcopy(self.encoder_instance).to(site["value"].device)
                    # convert to pyro module
                    to_pyro_module_(encoder_)
                    _deep_setattr(
                        self.multiple_encoders,
                        name,
                        encoder_,
                    )
                else:
                    # create instances
                    _deep_setattr(
                        self.multiple_encoders,
                        name,
                        self.encoder_class(n_in=self.multiple_n_in, n_out=n_hidden, **self.multi_encoder_kwargs).to(
                            site["value"].device
                        ),
                    )
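
To show how the pieces above fit together, a hedged construction sketch: the guide class name AutoAmortisedNormal is a placeholder for the class these methods belong to, model stands for any Pyro model with an "obs_plate", and the site names and n_in value mirror the docstring example rather than a real dataset.

import torch

amortised_plate_sites = {
    "name": "obs_plate",
    "input": [0],                      # feed model arg 0 to the encoder
    "input_transform": [torch.log1p],  # transform applied before the NN
    "sites": {"w_sf": 10, "l_s_add": 1},
}
guide = AutoAmortisedNormal(
    model,
    amortised_plate_sites=amortised_plate_sites,
    n_in=100,  # hypothetical input dimensionality
    n_hidden={"single": 200, "multiple": 200},
    encoder_mode="single-multiple",
)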