Example No. 1
    def _set_prior(self, pname, **kwargs):
        # Factorized normal prior; mean/std are configurable via kwargs.
        prior = dist.Normal(
            torch.FloatTensor([kwargs.get('mean', 0.)]).to(self.device),
            torch.FloatTensor([kwargs.get('std', 1.)]).to(self.device))
        dwp_samples = kwargs.get('dwp_samples', 1)

        if pname == 'sn':
            for m in self.modules():
                if isinstance(m, _Bayes):
                    m.kl_function = utils.kl_normal
                    m.prior = prior
        elif pname == 'sn-mc':
            for m in self.modules():
                if isinstance(m, _Bayes):
                    m.kl_function = utils.kl_normal_mc
                    m.prior = prior
        elif pname == 'dwp':
            vae = utils.load_vae(kwargs['vae'], self.device)

            for p in vae.parameters():
                p.requires_grad = False

            klf = utils.kl_dwp(vae, n_tries=dwp_samples)

            for m in self.modules():
                if isinstance(m, BayesConv2d):
                    m.kl_function = klf
                elif isinstance(m, _Bayes):
                    m.kl_function = utils.kl_normal
                    m.prior = prior
        else:
            raise NotImplementedError
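
A hypothetical call to the method above, assuming `net` is an instance of the Bayesian model that defines `_set_prior` and that the checkpoint path is purely illustrative:

    # Alternative 1: closed-form KL against a standard normal prior.
    net._set_prior('sn', mean=0., std=1.)
    # Alternative 2: deep weight prior backed by a pretrained kernel VAE.
    # net._set_prior('dwp', vae='checkpoints/vae_7x7.pth', dwp_samples=5)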
Example No. 2
    def set_prior(self, prior_list, dwp_samples, vae_list, flow_list=None):
        convs = [self.features.conv1, self.features.conv2]
        for i, m in enumerate(convs):
            if not isinstance(m, bayes._Bayes):
                continue

            if prior_list[i] == 'vae':
                vae = utils.load_vae(vae_list[i], self.device)
                for p in vae.parameters():
                    p.requires_grad = False
                m.kl_function = utils.kl_dwp(vae, n_tries=dwp_samples)
            elif prior_list[i] == 'flow':
                flow = utils.load_flow(flow_list[i], self.device)
                for p in flow.parameters():
                    p.requires_grad = False
                m.kl_function = utils.kl_flow(flow, n_tries=dwp_samples)
            elif prior_list[i] == 'sn':
                m.kl_function = utils.kl_normal
                m.prior = dist.Normal(
                    torch.FloatTensor([0.]).to(self.device),
                    torch.FloatTensor([1.]).to(self.device))
            elif prior_list[i] == 'loguniform':
                if self.cfg == 'bayes-mtrunca':
                    m.kl_function = utils.kl_loguniform_with_trunc_alpha
            else:
                raise NotImplementedError
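
A hypothetical per-layer configuration for the method above, assuming `net` is an instance of the model and the checkpoint path is illustrative; conv1 gets a VAE-based deep weight prior, conv2 keeps a standard normal prior:

    net.set_prior(prior_list=['vae', 'sn'],
                  dwp_samples=5,
                  vae_list=['checkpoints/vae_conv1.pth', None],
                  flow_list=None)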
Example No. 3
    def weights_init(self,
                     init_list,
                     vae_list,
                     flow_list=None,
                     pretrained=None,
                     filters_list=None,
                     logvar=-10.):
        self.apply(
            utils.weight_init(module=nn.Conv2d, initf=nn.init.xavier_normal_))
        self.apply(
            utils.weight_init(module=nn.Linear, initf=nn.init.xavier_normal_))
        self.apply(
            utils.weight_init(module=bayes.LogScaleConv2d,
                              initf=utils.const_init(logvar)))
        self.apply(
            utils.weight_init(module=bayes.LogScaleLinear,
                              initf=utils.const_init(logvar)))

        if len(init_list) > 0 and init_list[0] == 'pretrained':
            assert len(init_list) == 1
            w_pretrained = torch.load(pretrained)
            for k, v in w_pretrained.items():
                if k in self.state_dict():
                    self.state_dict()[k].data.copy_(v)
                else:
                    tokens = k.split('.')
                    self.state_dict()['.'.join(tokens[:2] + ['mean'] +
                                               tokens[-1:])].data.copy_(v)
            return

        convs = [self.features.conv1, self.features.conv2]
        for i, m in enumerate(convs):
            init = init_list[i] if i < len(init_list) else 'xavier'
            w = m.mean.weight if isinstance(m, bayes._Bayes) else m.weight
            if init == 'vae':
                vae_path = vae_list[i]
                vae = utils.load_vae(vae_path, device=self.device)
                z = torch.randn(
                    w.size(0) * w.size(1), vae.encoder.z_dim, 1,
                    1).to(vae.device)
                x = vae.decode(z)[0]
                w.data = x.reshape(w.shape)
            elif init == 'flow':
                flow_path = flow_list[i]
                flow = utils.load_flow(flow_path, device=self.device)
                utils.flow_init(flow)(w)
            elif init == 'xavier':
                pass
            elif init == 'filters':
                filters = np.load(filters_list[i])
                N = np.prod(w.shape[:2])
                filters = filters[np.random.permutation(len(filters))[:N]]
                w.data = torch.from_numpy(filters.reshape(*w.shape)).to(
                    self.device)
            else:
                raise NotImplementedError
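
A hypothetical call to the initializer above, assuming `net` is an instance of the model and the checkpoint path is illustrative; conv1 is initialized by decoding random latents through a pretrained kernel VAE, while conv2 falls back to the default Xavier initialization:

    net.weights_init(init_list=['vae', 'xavier'],
                     vae_list=['checkpoints/vae_conv1.pth', None],
                     logvar=-10.)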
Example No. 4
    def set_dwp_regularizer(self, vae_list):
        # Load one pretrained VAE per layer, freeze its parameters, and keep
        # it around for use as a deep-weight-prior regularizer.
        for path in vae_list:
            vae = utils.load_vae(path, device=self.device)
            for p in vae.parameters():
                p.requires_grad = False
            self.vaes.append(vae)
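
A hypothetical call to the method above, assuming `net` defines `set_dwp_regularizer` and an empty `net.vaes` list, with illustrative checkpoint paths (one frozen VAE per convolutional layer):

    net.set_dwp_regularizer(['checkpoints/vae_conv1.pth',
                             'checkpoints/vae_conv2.pth'])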