Example #1
0
    def __init__(self, prior_info):
        """Build the per-layer priors of a three-layer fully connected model.

        Args:
            prior_info: a ``(prior_type, prior_hyper)`` pair. ``prior_type``
                is ``'Gamma'``, ``'Weibull'``, or any string containing
                ``'HalfCauchy'``; both parts are forwarded to the matching
                ``_prior_*`` helper.

        Raises:
            NotImplementedError: if ``prior_type`` matches none of the
                supported families.
        """
        prior_type, prior_hyper = prior_info
        self.fc1_prior = None
        self.fc2_prior = None
        self.fc3_prior = None
        # The _prior_* helpers are expected to replace the None placeholders
        # with dicts, since 'direction' entries are written into them below.
        if prior_type == 'Gamma':
            self._prior_gamma(prior_type, prior_hyper)
        elif prior_type == 'Weibull':
            self._prior_weibull(prior_type, prior_hyper)
        elif 'HalfCauchy' in prior_type:
            # Substring match: every prior_type containing 'HalfCauchy' is
            # routed here; the helper receives prior_type to tell them apart.
            self._prior_halfcauchy(prior_type, prior_hyper)
        else:
            raise NotImplementedError(
                'Unsupported prior type: {!r}'.format(prior_type))

        # vMF direction prior for each layer; the concentration comes from
        # ml_kappa with the given dim (presumably the layer widths -- TODO
        # confirm).
        direction_dims = (
            (self.fc1_prior, 300),
            (self.fc2_prior, 100),
            (self.fc3_prior, 10),
        )
        for layer_prior, dim in direction_dims:
            layer_prior['direction'] = ('vMF', {
                'concentration':
                ml_kappa(dim=dim, eps=PRIOR_EPSILON)
            })
 def __init__(self, prior_type, n1, n2):
     """Build the priors of a two-layer fully connected model.

     Args:
         prior_type: ``'Gamma'``, ``'Weibull'`` or ``'HalfCauchy'`` (exact
             match); selects which ``_prior_*`` helper fills in the priors.
         n1: ``dim`` passed to ``ml_kappa`` for fc1's row concentration.
         n2: ``dim`` passed to ``ml_kappa`` for fc1's column concentration
             and fc2's row concentration.

     Raises:
         NotImplementedError: if ``prior_type`` is not one of the names above.
     """
     self.fc1_prior = None
     self.fc2_prior = None
     # The _prior_* helpers are expected to replace the None placeholders
     # with dicts, since 'direction' entries are written into them below.
     if prior_type == 'Gamma':
         self._prior_gamma()
     elif prior_type == 'Weibull':
         self._prior_weibull()
     elif prior_type == 'HalfCauchy':
         self._prior_halfcauchy()
     else:
         raise NotImplementedError(
             'Unsupported prior type: {!r}'.format(prior_type))
     # NOTE(review): the concentrations are stored after softplus_inv even
     # though the keys are plain '*_concentration'; a previously commented-out
     # variant used '*_softplus_inv_concentration' keys for the same values.
     # Confirm which parametrisation the consumer of 'direction' expects.
     self.fc1_prior['direction'] = ('vMF', {
         'row_concentration':
         softplus_inv(ml_kappa(dim=n1, eps=PRIOR_EPSILON)),
         'col_concentration':
         softplus_inv(ml_kappa(dim=n2, eps=PRIOR_EPSILON)),
     })
     self.fc2_prior['direction'] = ('vMF', {
         'row_concentration':
         softplus_inv(ml_kappa(dim=n2, eps=PRIOR_EPSILON)),
         'col_concentration':
         softplus_inv(ml_kappa(dim=1, eps=PRIOR_EPSILON)),
     })
Example #3
0
    def __init__(self, prior_info):
        """Set up priors for a conv1 -> conv2 -> fc1 -> fc2 model.

        ``prior_info`` is unpacked as ``(prior_type, prior_hyper)`` and both
        parts are handed to ``_prior_halfcauchy``. Each layer's prior dict
        then receives a ``'direction'`` entry of the form
        ``('vMF', {'row_concentration': ..., 'col_concentration': ...})``
        with both concentrations computed by ``ml_kappa``.
        """
        prior_type, prior_hyper = prior_info
        self.conv1_prior = None
        self.conv2_prior = None
        self.fc1_prior = None
        self.fc2_prior = None
        # NOTE(review): no dispatch on prior_type here -- the HalfCauchy
        # helper is always used. Confirm that is intended for this model.
        self._prior_halfcauchy(prior_type, prior_hyper)

        # (prior attribute, row dim, col dim) fed to ml_kappa. The conv dims
        # are written as N * 25, presumably channels times a 5x5 kernel --
        # TODO confirm.
        direction_dims = (
            ('conv1_prior', 1 * 25, 20 * 25),
            ('conv2_prior', 20 * 25, 50 * 25),
            ('fc1_prior', 800, 500),
            ('fc2_prior', 500, 10),
        )
        for attr_name, row_dim, col_dim in direction_dims:
            row_kappa = ml_kappa(dim=row_dim, eps=PRIOR_EPSILON)
            col_kappa = ml_kappa(dim=col_dim, eps=PRIOR_EPSILON)
            layer_prior = getattr(self, attr_name)
            layer_prior['direction'] = ('vMF', {
                'row_concentration': row_kappa,
                'col_concentration': col_kappa,
            })
Example #4
0
 def __init__(self, batch_shape, event_shape):
     """Create the parameters of a reparametrised von Mises-Fisher sample.

     ``loc`` and ``softplus_inv_concentration`` are allocated with
     ``torch.Tensor(size)``, i.e. uninitialised; they are presumably filled
     in later by ``reset_parameters`` -- TODO confirm.

     Args:
         batch_shape: shape concatenated in front of ``event_shape`` to form
             the shape of ``loc``; must support ``+`` with a ``torch.Size``.
         event_shape: a number ``d`` (wrapped as ``torch.Size([d])``) or a
             one-dimensional shape.

     Raises:
         ValueError: if ``event_shape`` does not have exactly one dimension.
     """
     # Initialise the parent first: Parameters may only be assigned after
     # Module.__init__ (presumably the base class, given the Parameter
     # attributes), and doing it up front keeps every attribute write after it.
     super(VonMisesFisherReparametrizedSample, self).__init__()
     self.batch_shape = batch_shape
     if isinstance(event_shape, Number):
         event_shape = torch.Size([event_shape])
     # Explicit check instead of `assert`, which is stripped under `python -O`.
     if len(event_shape) != 1:
         raise ValueError(
             'event_shape must have exactly one dimension, got {}'.format(
                 event_shape))
     self.event_shape = event_shape
     self.dim = int(event_shape[0])
     dim_float = float(event_shape[0])
     self.loc = Parameter(torch.Tensor(batch_shape + event_shape))
     self.softplus_inv_concentration = Parameter(
         torch.Tensor(torch.Size([1])))
     # A very large kappa slows down rejection sampling, so an upper bound is
     # precomputed here (per the original note, it is applied in the forward
     # pass).
     self.softplus_inv_concentration_upper_bound = softplus_inv(
         ml_kappa(dim=dim_float, eps=2e-3))
     self.beta_sample = None
     self.concentration = None
     self.gradient_correction_required = True
     # Default mean/std of the normal init applied to
     # softplus_inv_concentration in reset_parameters.
     self.softplus_inv_concentration_normal_mean = softplus_inv(
         ml_kappa(dim=dim_float, eps=EPSILON))
     self.softplus_inv_concentration_normal_std = 0.001
     self.direction_init_method = None
     self.rsample = None
     self.loc_init_type = 'random'
Example #5
0
    def reset_parameters(self, hyperparams=None):
        """(Re)initialise ``loc`` and ``softplus_inv_concentration`` in place.

        Args:
            hyperparams: optional dict; only its ``'vMF'`` entry is read. That
                entry may contain:

                * ``'direction'``: either one of the strings ``'kaiming'``,
                  ``'kaiming_transpose'``, ``'orthogonal'`` (the matching
                  initialiser is stored in ``direction_init_method`` and
                  applied to ``loc``), or a tensor copied into ``loc`` (which
                  also sets ``loc_init_type = 'fixed'``). Either way, ``loc``
                  is then rescaled to unit L2 norm along its last dimension.
                * ``'softplus_inv_concentration_normal_mean'``: new mean.
                * ``'softplus_inv_concentration_normal_mean_via_epsilon'``:
                  epsilon passed to ``ml_kappa`` to derive the mean; takes
                  precedence over the previous key when both are given.
                * ``'softplus_inv_concentration_normal_std'``: new std.

                ``None`` (the default) behaves like an empty dict.

        Raises:
            NotImplementedError: if ``'direction'`` is an unrecognised string,
                or is neither a string nor a tensor.
        """
        if hyperparams is None:
            hyperparams = {}
        if 'vMF' in hyperparams:
            vmf = hyperparams['vMF']
            if 'direction' in vmf:
                direction = vmf['direction']
                if isinstance(direction, str):
                    if direction == 'kaiming':
                        self.direction_init_method = torch.nn.init.kaiming_normal_
                    elif direction == 'kaiming_transpose':
                        self.direction_init_method = kaiming_transpose
                    elif direction == 'orthogonal':
                        self.direction_init_method = torch.nn.init.orthogonal_
                    else:
                        # Without this, an unknown name would call whatever
                        # method was stored before (or None -> TypeError).
                        raise NotImplementedError(
                            'Unknown vMF direction initialisation: {!r}'.format(
                                direction))
                    # NOTE(review): loc_init_type is not reset to 'random'
                    # here after a previous 'fixed' init -- confirm intended.
                    self.direction_init_method(self.loc)
                elif isinstance(direction, torch.Tensor):
                    self.loc.data.copy_(direction)
                    self.loc_init_type = 'fixed'
                else:
                    raise NotImplementedError

                # Rescale to unit L2 norm along the last dimension.
                # NOTE(review): an all-zero slice divides by zero -> NaN.
                self.loc.data /= torch.sum(self.loc.data**2,
                                           dim=-1,
                                           keepdim=True)**0.5

            if 'softplus_inv_concentration_normal_mean' in vmf:
                self.softplus_inv_concentration_normal_mean = vmf[
                    'softplus_inv_concentration_normal_mean']
            if 'softplus_inv_concentration_normal_mean_via_epsilon' in vmf:
                epsilon = vmf[
                    'softplus_inv_concentration_normal_mean_via_epsilon']
                self.softplus_inv_concentration_normal_mean = softplus_inv(
                    ml_kappa(dim=float(self.event_shape[0]), eps=epsilon))
            if 'softplus_inv_concentration_normal_std' in vmf:
                self.softplus_inv_concentration_normal_std = vmf[
                    'softplus_inv_concentration_normal_std']
        torch.nn.init.normal_(self.softplus_inv_concentration,
                              self.softplus_inv_concentration_normal_mean,
                              self.softplus_inv_concentration_normal_std)