Пример #1
0
 def std_v1x(self, x):
     """Return the std tensor for ``v``, computed from ``x``.

     If ``self.learn_std_v1x`` is truthy, the learned head ``self.f_std_v``
     is applied to ``self._get_prev(x)``. Otherwise the fixed value
     ``self.std_v1x_val`` is converted with ``tensorify`` on ``x``'s device
     (first element taken) and broadcast via ``expand`` to
     ``x.shape[:-1] + (self.dim_v,)``.
     """
     if not self.learn_std_v1x:
         # Constant std: keep every leading dim of x, swap the last for dim_v.
         fixed_std = tensorify(x.device, self.std_v1x_val)[0]
         return fixed_std.expand(x.shape[:-1] + (self.dim_v, ))
     prev_features = self._get_prev(x)
     return self.f_std_v(prev_features)
Пример #2
0
 def std_s1vx(self, v, x):
     """Return the std tensor for ``s``, computed from ``v`` and ``x``.

     If ``self.learn_std_s1vx`` is truthy, ``v`` is concatenated (last axis)
     with ``self._get_parav(x)`` and fed to the learned head
     ``self.f_std_s``. Otherwise the fixed value ``self.std_s1vx_val`` is
     converted with ``tensorify`` on ``x``'s device (first element taken)
     and broadcast via ``expand`` to ``x.shape[:-1] + (self.dim_s,)``;
     ``v`` is unused in that case.
     """
     if not self.learn_std_s1vx:
         fixed_std = tensorify(x.device, self.std_s1vx_val)[0]
         return fixed_std.expand(x.shape[:-1] + (self.dim_s, ))
     # The learned head sees v alongside the x-derived features.
     joint_input = tc.cat([v, self._get_parav(x)], dim=-1)
     return self.f_std_s(joint_input)
Пример #3
0
 def std_s1vx(self, v, x):
     """Return the std tensor for ``s``, computed from ``v`` and ``x``.

     With ``self.learn_std_s1vx`` falsy, the fixed value
     ``self.std_s1vx_val`` is converted with ``tensorify`` on ``x``'s device
     (first element taken) and broadcast via ``expand`` to
     ``x.shape[:-3] + (self.dim_s,)``; the last three dims of ``x`` are
     dropped (NOTE(review): presumably an image-like trailing C,H,W —
     confirm).

     With it truthy:
       * if ``self.vbranch`` is truthy, delegate to ``self.std_s1x(x)``
         (``v`` is ignored);
       * otherwise take ``self._get_bn(x)``, replace its first
         ``self.dim_v`` entries on the last axis with ``v``, and feed the
         result to ``self.f_std_s``.
     """
     if not self.learn_std_s1vx:
         fixed_std = tensorify(x.device, self.std_s1vx_val)[0]
         return fixed_std.expand(x.shape[:-3] + (self.dim_s, ))
     if self.vbranch:
         return self.std_s1x(x)
     features = self._get_bn(x)
     # Splice the supplied v in place of the leading dim_v feature entries.
     spliced = tc.cat([v, features[..., self.dim_v:]], dim=-1)
     return self.f_std_s(spliced)
Пример #4
0
    def _get_priors(mean_s,
                    std_s,
                    shape_s,
                    mvn_prior: bool = False,
                    device=None):
        """Build a prior distribution over ``s`` plus any tensors it owns.

        NOTE(review): this takes no ``self`` yet is called as
        ``self._get_priors(...)`` in this file; presumably it is decorated
        with ``@staticmethod`` on a line outside this view — confirm.

        Args:
            mean_s: Prior mean; either a callable or a value accepted by
                ``ds.tensorify``.
            std_s: Prior std; either a callable or a value accepted by
                ``ds.tensorify``.
            shape_s: Shape of ``s``. Must have length 1 when ``mvn_prior``.
            mvn_prior: If False, wrap ``mean_s``/``std_s`` unchanged in a
                ``ds.Normal``. If True, build a ``ds.MVNormal`` whose mean
                and lower-triangular ``std_tril`` come from freshly created
                tensors (returned so the caller can make them learnable).
            device: Device for the freshly created tensors.

        Returns:
            ``(p_s, params)``. ``params`` is ``[]`` for the ``ds.Normal``
            case. For the MVN case it is ``[mean, log_diag, offdiag]``:

            * ``mean`` is zeros if ``mean_s`` is callable, otherwise a
              detached copy of ``mean_s`` broadcast to ``shape_s``;
            * ``log_diag`` is zeros if ``std_s`` is callable, otherwise
              ``log`` of ``std_s`` broadcast to ``shape_s``;
            * ``offdiag`` is a ``(dim_s, dim_s)`` zero matrix.

            Only the strict lower triangle of ``offdiag`` is used.

        Raises:
            RuntimeError: If ``mvn_prior`` is set and ``len(shape_s) != 1``.
        """
        if not mvn_prior:
            # Diagonal prior: no new tensors are created here.
            return ds.Normal('s', mean=mean_s, std=std_s, shape=shape_s), []

        if len(shape_s) != 1:
            raise RuntimeError(
                "only 1-dim vector is supported for `s` in `mvn_prior` mode"
            )
        dim_s = shape_s[0]

        if callable(mean_s):
            mean_param = tc.zeros(shape_s, device=device)
        else:
            mean_param = ds.tensorify(
                device, mean_s)[0].expand(shape_s).clone().detach()

        # Strict lower triangle of L_ss; the diagonal lives in log_diag_param.
        offdiag_param = tc.zeros((dim_s, dim_s), device=device)

        # Diagonal of L_ss is stored in log space so exp() keeps it positive.
        if callable(std_s):
            log_diag_param = tc.zeros(shape_s, device=device)
        else:
            log_diag_param = ds.tensorify(
                device, std_s)[0].expand(shape_s).log().clone().detach()

        def std_s_tril():  # L_ss
            # Closes over the parameter tensors themselves, so later updates
            # to them are reflected whenever this is evaluated.
            return offdiag_param.tril(-1) + log_diag_param.exp().diagflat()

        p_s = ds.MVNormal('s',
                          mean=mean_param,
                          std_tril=std_s_tril,
                          shape=shape_s)
        return p_s, [mean_param, log_diag_param, offdiag_param]
Пример #5
0
 def std_s1x(self, x):
     """Return the std tensor for ``s``, computed from ``x`` alone.

     Shares its switch (``self.learn_std_s1vx``) and fixed value
     (``self.std_s1vx_val``) with ``std_s1vx``. When learned, the head
     ``self.f_std_s`` is applied to ``self._get_bn(x)``. Otherwise the fixed
     value is converted with ``tensorify`` on ``x``'s device (first element
     taken) and broadcast via ``expand`` to ``x.shape[:-3] + (self.dim_s,)``.
     """
     if not self.learn_std_s1vx:
         fixed_std = tensorify(x.device, self.std_s1vx_val)[0]
         return fixed_std.expand(x.shape[:-3] + (self.dim_s, ))
     bn_features = self._get_bn(x)
     return self.f_std_s(bn_features)
Пример #6
0
 def std_v1x(self, x):
     """Return the std tensor for ``v``, computed from ``x``.

     When ``self.learn_std_v1x`` is truthy, the learned head ``self.f_std_v``
     is applied to ``self._get_bn(x)`` if ``self.vbranch`` is truthy, else to
     ``self._get_bb(x)``. Otherwise the fixed value ``self.std_v1x_val`` is
     converted with ``tensorify`` on ``x``'s device (first element taken) and
     broadcast via ``expand`` to ``x.shape[:-3] + (self.dim_v,)``.
     """
     if not self.learn_std_v1x:
         fixed_std = tensorify(x.device, self.std_v1x_val)[0]
         return fixed_std.expand(x.shape[:-3] + (self.dim_v, ))
     # vbranch picks which feature extractor feeds the std head.
     extract = self._get_bn if self.vbranch else self._get_bb
     return self.f_std_v(extract(x))
Пример #7
0
    def __init__(self,
                 shape_s,
                 shape_x,
                 dim_y,
                 mean_x1s,
                 std_x1s,
                 logit_y1s,
                 mean_s1x=None,
                 std_s1x=None,
                 tmean_s1x=None,
                 tstd_s1x=None,
                 mean_s=0.,
                 std_s=1.,
                 learn_tprior=False,
                 src_mvn_prior=False,
                 tgt_mvn_prior=False,
                 device=None):
        """Assemble the distributions over ``s``, ``x`` and ``y``.

        Args:
            shape_s: Shape of ``s``. Must have length 1 if an MVN prior is
                requested (``self._get_priors`` raises otherwise).
            shape_x: Shape passed to the ``ds.Normal`` over ``x``.
            dim_y: ``1`` selects ``ds.Bern`` for ``y``; anything else
                selects ``ds.Catg``.
            mean_x1s, std_x1s: ``mean``/``std`` of the ``ds.Normal`` over
                ``x`` (``self.p_x1s``).
            logit_y1s: ``logits`` of the distribution over ``y``
                (``self.p_y1s``).
            mean_s1x, std_s1x: If ``mean_s1x`` is not None, build
                ``self.q_s1x = ds.Normal('s', ...)`` from them; otherwise
                ``self.q_s1x`` is None.
            tmean_s1x, tstd_s1x: Same as above, for ``self.qt_s1x``.
            mean_s, std_s: Source prior mean/std (callable or value), passed
                to ``self._get_priors``.
            learn_tprior: If False, ``self.pt_s`` is the same object as
                ``self.p_s``. If True, a separate target prior with its own
                tensors is built.
            src_mvn_prior: Build the source prior as an MVN. Its tensors are
                registered under ``'mean_s'``, ``'std_s_diag_param'`` and
                ``'std_s_offdiag'``.
            tgt_mvn_prior: Only consulted when ``learn_tprior``. Chooses an
                MVN target prior (keys ``'tmean_s'``, ``'tstd_s_diag_param'``,
                ``'tstd_s_offdiag'``) or a diagonal one (keys ``'tmean_s'``,
                ``'tstd_s_param'``).
            device: If not None, assigned to ``ds.Distr.default_device``
                before anything else is built. This is shared state outside
                this instance. It is also the device for newly created
                tensors.

        Every tensor collected in ``self._parameter_dict`` has
        ``requires_grad_()`` called on it at the end.
        """
        if device is not None:
            # NOTE(review): mutates shared state on ds.Distr, presumably the
            # default device for every distribution created afterwards.
            ds.Distr.default_device = device
        self._parameter_dict = {}
        self.shape_x = shape_x
        self.dim_y = dim_y
        self.shape_s = shape_s
        self.learn_tprior = learn_tprior

        # Conditionals for x and y, both parameterized by s-dependent inputs.
        self.p_x1s = ds.Normal('x', mean=mean_x1s, std=std_x1s, shape=shape_x)
        y_distr_name = 'Bern' if dim_y == 1 else 'Catg'
        self.p_y1s = getattr(ds, y_distr_name)('y', logits=logit_y1s)

        # Source prior over s; only the MVN variant owns tensors to register.
        self.p_s, src_params = self._get_priors(mean_s, std_s, shape_s,
                                                src_mvn_prior, device)
        if src_mvn_prior:
            self._parameter_dict.update(
                zip(['mean_s', 'std_s_diag_param', 'std_s_offdiag'],
                    src_params))
        self.p_sx = self.p_s * self.p_x1s

        # Optional inference distributions over s (source / target).
        self.q_s1x = (None if mean_s1x is None else
                      ds.Normal('s', mean=mean_s1x, std=std_s1x,
                                shape=shape_s))
        self.qt_s1x = (None if tmean_s1x is None else
                       ds.Normal('s', mean=tmean_s1x, std=tstd_s1x,
                                 shape=shape_s))

        # Target prior over s.
        if not learn_tprior:
            self.pt_s = self.p_s
        elif tgt_mvn_prior:
            self.pt_s, tgt_params = self._get_priors(mean_s, std_s, shape_s,
                                                     True, device)
            self._parameter_dict.update(
                zip(['tmean_s', 'tstd_s_diag_param', 'tstd_s_offdiag'],
                    tgt_params))
        else:
            if callable(mean_s):
                tmean_s = tc.zeros(shape_s, device=device)
            else:
                tmean_s = ds.tensorify(
                    device, mean_s)[0].expand(shape_s).clone().detach()
            # Target std is stored in log space: zeros => exp(0) = std of 1.
            if callable(std_s):
                tstd_s_param = tc.zeros(shape_s, device=device)
            else:
                tstd_s_param = ds.tensorify(
                    device, std_s)[0].log().expand(shape_s).clone().detach()
            self._parameter_dict['tmean_s'] = tmean_s
            self._parameter_dict['tstd_s_param'] = tstd_s_param

            def tstd_s():
                # Closes over the tensor itself, so updates to tstd_s_param
                # are seen whenever this is evaluated.
                return tc.exp(tstd_s_param)

            self.pt_s, _ = self._get_priors(tmean_s, tstd_s, shape_s, False,
                                            device)
        self.pt_sx = self.pt_s * self.p_x1s

        # Make every registered prior tensor trainable.
        for tensor in self._parameter_dict.values():
            tensor.requires_grad_()