def __init__(self, prior_info):
    """Build per-layer priors for the 300/100/10 fully connected layers.

    ``prior_info`` is a ``(prior_type, prior_hyper)`` pair.  The scale prior
    is dispatched on ``prior_type`` (the chosen ``_prior_*`` helper is
    expected to populate the ``fc*_prior`` dicts), after which a vMF
    direction prior with a maximum-likelihood concentration is attached to
    each layer.

    Raises:
        NotImplementedError: for an unrecognized ``prior_type``.
    """
    prior_type, prior_hyper = prior_info
    self.fc1_prior = None
    self.fc2_prior = None
    self.fc3_prior = None
    # Scale prior dispatch; note the HalfCauchy case is a substring test,
    # unlike the exact comparisons above it.
    if prior_type == 'Gamma':
        self._prior_gamma(prior_type, prior_hyper)
    elif prior_type == 'Weibull':
        self._prior_weibull(prior_type, prior_hyper)
    elif 'HalfCauchy' in prior_type:
        self._prior_halfcauchy(prior_type, prior_hyper)
    else:
        raise NotImplementedError
    # Direction prior: one vMF per layer, concentration set by layer dimension.
    for attr, dim in (('fc1_prior', 300), ('fc2_prior', 100), ('fc3_prior', 10)):
        getattr(self, attr)['direction'] = (
            'vMF', {'concentration': ml_kappa(dim=dim, eps=PRIOR_EPSILON)})
def __init__(self, prior_type, n1, n2):
    """Build priors for a two-layer (n1 -> n2 -> 1) fully connected net.

    The scale prior is dispatched on ``prior_type`` (the chosen ``_prior_*``
    helper is expected to populate ``fc1_prior``/``fc2_prior``), then a
    row/column vMF direction prior is attached to each layer.

    Args:
        prior_type: one of ``'Gamma'``, ``'Weibull'``, ``'HalfCauchy'``.
        n1: input dimension of the first layer.
        n2: output dimension of the first layer / input of the second.

    Raises:
        NotImplementedError: for an unrecognized ``prior_type``.
    """
    self.fc1_prior = None
    self.fc2_prior = None
    if prior_type == 'Gamma':
        self._prior_gamma()
    elif prior_type == 'Weibull':
        self._prior_weibull()
    elif prior_type == 'HalfCauchy':
        self._prior_halfcauchy()
    else:
        raise NotImplementedError
    # NOTE(review): the keys are named '*_concentration' yet the values are
    # passed through softplus_inv, whereas the conv-net prior elsewhere in
    # this file stores raw ml_kappa values under the same key names.  Kept
    # as-is for behavior compatibility — confirm which form the consumer of
    # these dicts expects (likely either the keys should say
    # '*_softplus_inv_concentration' or the softplus_inv wrapper should go).
    self.fc1_prior['direction'] = ('vMF', {
        'row_concentration': softplus_inv(ml_kappa(dim=n1, eps=PRIOR_EPSILON)),
        'col_concentration': softplus_inv(ml_kappa(dim=n2, eps=PRIOR_EPSILON))
    })
    self.fc2_prior['direction'] = ('vMF', {
        'row_concentration': softplus_inv(ml_kappa(dim=n2, eps=PRIOR_EPSILON)),
        'col_concentration': softplus_inv(ml_kappa(dim=1, eps=PRIOR_EPSILON))
    })
def __init__(self, prior_info):
    """Build priors for a LeNet-style net: two 5x5 conv layers and two FC layers.

    A half-Cauchy scale prior is installed unconditionally (the helper is
    expected to populate the ``*_prior`` dicts), then each layer gets a
    row/column vMF direction prior whose concentrations are the
    maximum-likelihood values for the layer's fan-in/fan-out.
    """
    prior_type, prior_hyper = prior_info
    self.conv1_prior = None
    self.conv2_prior = None
    self.fc1_prior = None
    self.fc2_prior = None
    self._prior_halfcauchy(prior_type, prior_hyper)
    # (attribute, row dimension, column dimension) per layer; conv dims are
    # channels x 25 (= 5x5 kernel elements).
    layer_dims = (
        ('conv1_prior', 1 * 25, 20 * 25),
        ('conv2_prior', 20 * 25, 50 * 25),
        ('fc1_prior', 800, 500),
        ('fc2_prior', 500, 10),
    )
    for attr, row_dim, col_dim in layer_dims:
        getattr(self, attr)['direction'] = ('vMF', {
            'row_concentration': ml_kappa(dim=row_dim, eps=PRIOR_EPSILON),
            'col_concentration': ml_kappa(dim=col_dim, eps=PRIOR_EPSILON)
        })
def __init__(self, batch_shape, event_shape):
    """Reparametrized-sample module for a von Mises-Fisher distribution.

    Holds a learnable mean direction ``loc`` of shape
    ``batch_shape + event_shape`` and a scalar softplus-inverse
    concentration.

    Args:
        batch_shape: ``torch.Size`` of independent vMF variables.
        event_shape: dimension of the sphere's ambient space; a bare number
            is promoted to a length-1 ``torch.Size``.  Only 1-D event shapes
            are supported.
    """
    self.batch_shape = batch_shape
    # Promote a bare integer to a one-element torch.Size.
    if isinstance(event_shape, Number):
        event_shape = torch.Size([event_shape])
    self.event_shape = event_shape
    assert len(event_shape) == 1
    dim = int(event_shape[0])
    self.dim = dim
    # Parameters must be registered after Module.__init__.
    super(VonMisesFisherReparametrizedSample, self).__init__()
    self.loc = Parameter(torch.Tensor(batch_shape + event_shape))
    self.softplus_inv_concentration = Parameter(torch.Tensor(torch.Size([1])))
    # Too large a kappa slows rejection sampling, so an upper bound is kept;
    # it is applied in the forward pass.
    self.softplus_inv_concentration_upper_bound = softplus_inv(
        ml_kappa(dim=float(dim), eps=2e-3))
    # Mean/std of the normal used to initialize the softplus-inverse
    # concentration (see reset_parameters).
    self.softplus_inv_concentration_normal_mean = softplus_inv(
        ml_kappa(dim=float(dim), eps=EPSILON))
    self.softplus_inv_concentration_normal_std = 0.001
    # Filled in during sampling / initialization.
    self.beta_sample = None
    self.concentration = None
    self.gradient_correction_required = True
    self.direction_init_method = None
    self.rsample = None
    self.loc_init_type = 'random'
def reset_parameters(self, hyperparams=None):
    """(Re)initialize the vMF parameters: direction ``loc`` and concentration.

    Args:
        hyperparams: optional dict.  The ``'vMF'`` entry may contain:
            ``'direction'`` — either an init-method name (``'kaiming'``,
            ``'kaiming_transpose'``, ``'orthogonal'``) or a tensor of fixed
            directions to copy into ``loc``;
            ``'softplus_inv_concentration_normal_mean'``,
            ``'softplus_inv_concentration_normal_mean_via_epsilon'``,
            ``'softplus_inv_concentration_normal_std'`` — overrides for the
            normal initialization of the softplus-inverse concentration.

    Raises:
        NotImplementedError: for an unrecognized ``'direction'`` value.
    """
    # Fix: the previous signature used a mutable default (``hyperparams={}``),
    # the classic shared-default pitfall; use a ``None`` sentinel instead.
    if hyperparams is None:
        hyperparams = {}
    if 'vMF' in hyperparams:
        vmf = hyperparams['vMF']
        if 'direction' in vmf:
            direction = vmf['direction']
            if isinstance(direction, str):
                # Fix: an unknown name now raises immediately instead of
                # leaving ``direction_init_method`` unset and failing later
                # with an obscure ``TypeError: 'NoneType' ... not callable``.
                if direction == 'kaiming':
                    self.direction_init_method = torch.nn.init.kaiming_normal_
                elif direction == 'kaiming_transpose':
                    self.direction_init_method = kaiming_transpose
                elif direction == 'orthogonal':
                    self.direction_init_method = torch.nn.init.orthogonal_
                else:
                    raise NotImplementedError
                self.direction_init_method(self.loc)
            elif isinstance(direction, torch.Tensor):
                # isinstance (rather than an exact type check) also accepts
                # Tensor subclasses such as ``nn.Parameter``.
                self.loc.data.copy_(direction)
                self.loc_init_type = 'fixed'
            else:
                raise NotImplementedError
            # Project each direction vector back onto the unit sphere.
            self.loc.data /= torch.sum(self.loc.data ** 2,
                                       dim=-1, keepdim=True) ** 0.5
        if 'softplus_inv_concentration_normal_mean' in vmf:
            self.softplus_inv_concentration_normal_mean = vmf[
                'softplus_inv_concentration_normal_mean']
        if 'softplus_inv_concentration_normal_mean_via_epsilon' in vmf:
            # Derive the mean from a target epsilon via the ML concentration.
            epsilon = vmf['softplus_inv_concentration_normal_mean_via_epsilon']
            self.softplus_inv_concentration_normal_mean = softplus_inv(
                ml_kappa(dim=float(self.event_shape[0]), eps=epsilon))
        if 'softplus_inv_concentration_normal_std' in vmf:
            self.softplus_inv_concentration_normal_std = vmf[
                'softplus_inv_concentration_normal_std']
    # The concentration is always re-drawn from a normal in
    # softplus-inverse space using the (possibly overridden) mean/std.
    torch.nn.init.normal_(self.softplus_inv_concentration,
                          self.softplus_inv_concentration_normal_mean,
                          self.softplus_inv_concentration_normal_std)