def __init__(self, D_in, D_out, T, timestep, kernel, M=10,
             mean_f=zero_mean, mean_g=zero_mean,
             sigma_f_prior=dist.Uniform(1., 2.), alpha_f_prior=dist.Uniform(0.25, 0.75),
             sigma_g_prior=dist.Uniform(1., 2.), alpha_g_prior=dist.Uniform(0.25, 0.75)):
    """Build the two sparse-GP components of the flow model.

    The parent constructor is called with an empty layer list. Two SGPR
    modules are then created:

    * ``f`` maps ``D_in -> D_in`` and uses ``kernel(D_in, ...)`` built from
      the ``*_f_prior`` arguments;
    * ``g`` maps ``D_in -> D_out`` and uses ``kernel(D_out, ...)`` built from
      the ``*_g_prior`` arguments.

    Args:
        D_in: input dimensionality.
        D_out: output dimensionality of ``g``.
        T: stored as ``self.T`` (presumably the number of flow steps; confirm).
        timestep: stored as a float32 tensor in ``self.dt``.
        kernel: kernel constructor, called as
            ``kernel(dim, sigma_prior=..., alpha_prior=...)``.
        M: number of inducing points passed to each SGPR.
        mean_f, mean_g: mean functions for ``f`` and ``g``.
        sigma_f_prior, alpha_f_prior: priors handed to ``f``'s kernel.
        sigma_g_prior, alpha_g_prior: priors handed to ``g``'s kernel.
    """
    super(FlowGP, self).__init__(D_in, [], kernel, M)
    self.T = T
    self.M = M
    self.dt = torch.tensor(timestep).float()
    self.D_in, self.D_out = D_in, D_out

    # Kernel construction and SGPR construction alternate (f, then g) as in
    # the original, so any random initialisation happens in the same order.
    f_kernel = kernel(D_in, sigma_prior=sigma_f_prior, alpha_prior=alpha_f_prior)
    self.f = SGPR(D_in=D_in, M=M, kernel=f_kernel, D_out=D_in, mean=mean_f)
    g_kernel = kernel(D_out, sigma_prior=sigma_g_prior, alpha_prior=alpha_g_prior)
    self.g = SGPR(D_in=D_in, M=M, kernel=g_kernel, D_out=D_out, mean=mean_g)

    self.checkpoints = []
def mh_update(samples, sample_labels, width_proposal, prior, l_k, inputs, labels,
              width_inc=1.02, width_dec=0.5, update_steps=250, sess=tf.Session()):
    """Run MH moves restricted to the level set ``s(x) >= l_k``, then adapt proposal widths.

    Adapted from the reference implementation accompanying the paper, which
    provides an efficient vectorised MH update.

    Each step proposes, per pixel,
    ``x_new ~ Uniform([x - w, x + w] ∩ [prior.low, prior.high])``. It then
    evaluates the TensorFlow graph on the proposal. A chain's proposal is
    accepted only when the MH test passes AND ``s(x_new) >= l_k``.

    NOTE(review): this reads the module-level ``loss``, ``logits`` and
    ``compute_property_function``. The default ``sess`` is created once, when
    the module is imported, and is shared by every call.

    Args:
        samples: array-like of N images; reshaped to (N, 28*28).
        sample_labels: labels fed to the ``labels`` placeholder.
        width_proposal: per-chain proposal half-widths (assumed shape (N,) —
            it is unsqueezed to broadcast over pixels).
        prior: ``Uniform`` whose ``low``/``high`` bound every pixel.
        l_k: current level; proposals whose property falls below it are
            rejected.
        inputs, labels: TensorFlow placeholders to feed.
        width_inc, width_dec: multiplicative width updates applied when a
            chain's acceptance rate is above / below 0.124.
        update_steps: number of MH steps.
        sess: TensorFlow session used to evaluate the graph.

    Returns:
        ``(x, width_proposal)``: updated samples as a numpy array of shape
        (N, 28, 28), and the adapted widths.
    """
    import torch
    import torch.distributions as dist

    sample_size = samples.shape[0]
    x = torch.tensor(samples).view(-1, 28 * 28)
    acc_count = torch.zeros(sample_size)
    for _ in range(update_steps):
        # Forward proposal q(x_new | x): a box around x clipped to the prior's support.
        g_bottom = dist.Uniform(low=torch.max(x - width_proposal.unsqueeze(-1), prior.low),
                                high=torch.min(x + width_proposal.unsqueeze(-1), prior.high))
        x_new = g_bottom.sample()
        _, logits_run_time = sess.run(
            [loss, logits],
            feed_dict={inputs: x_new.view(-1, 28, 28).numpy(), labels: sample_labels})
        s_xn = compute_property_function(logits_run_time, sample_labels)
        # Reverse proposal q(x | x_new). Clipping makes the boxes asymmetric,
        # so the proposal terms do not cancel in the MH ratio.
        g_top = dist.Uniform(low=torch.max(x_new - width_proposal.unsqueeze(-1), prior.low),
                             high=torch.min(x_new + width_proposal.unsqueeze(-1), prior.high))
        lg_alpha = (prior.log_prob(x_new) - prior.log_prob(x)
                    + g_top.log_prob(x) - g_bottom.log_prob(x_new)).sum(dim=1)
        acceptance = torch.min(lg_alpha, torch.zeros_like(lg_alpha))
        log_u = torch.log(torch.rand_like(acceptance))
        # Accept only if the MH test passes and the proposal stays in the level set.
        acc_idx = torch.tensor((log_u <= acceptance).numpy() & (s_xn >= l_k))
        acc_count += acc_idx.float()
        x = torch.where(acc_idx.unsqueeze(-1), x_new, x)

    # Adapt widths from the per-chain acceptance RATE. The raw count must be
    # divided by the number of steps before comparing with the 0.124 target;
    # otherwise a single acceptance already counts as "above target".
    acc_ratio = acc_count / max(update_steps, 1)
    width_proposal = torch.where(acc_ratio > 0.124, width_proposal * width_inc, width_proposal)
    width_proposal = torch.where(acc_ratio < 0.124, width_proposal * width_dec, width_proposal)
    return x.view(-1, 28, 28).numpy(), width_proposal
def __init__(self, D_out, sigma_prior=dist.Uniform(1., 2.), alpha_prior=dist.Uniform(1., 2.)):
    """Create the kernel's two hyper-parameters as learnable (1, 1) tensors.

    Each parameter is initialised with one draw from its prior.
    ``D_out`` is accepted for interface compatibility but is not used here.

    Args:
        D_out: unused by this constructor.
        sigma_prior: distribution sampled to initialise ``self.sigma``
            (presumably the kernel amplitude; confirm).
        alpha_prior: distribution sampled to initialise ``self.alpha``
            (presumably the length-scale; confirm).
    """
    super(SquaredExp, self).__init__()
    init_shape = (1, 1)
    self.sigma = nn.Parameter(sigma_prior.sample(sample_shape=init_shape))
    self.alpha = nn.Parameter(alpha_prior.sample(sample_shape=init_shape))
def get_random_rectangles_params(params_shape, images_height, images_width,
                                 erase_scale_range, aspect_ratio_range):
    """Sample random rectangle sizes and top-left positions inside an image.

    Args:
        params_shape: shape of the batch of rectangles to draw.
        images_height, images_width: image size in pixels.
        erase_scale_range: (min, max) fraction of the image area covered by
            each rectangle.
        aspect_ratio_range: (min, max) aspect ratio, used as height / width.
            When the range straddles 1, ratios below and above 1 are each
            chosen with probability 1/2.

    Returns:
        ``(widths, heights, xs, ys)``: int32 tensors of shape ``params_shape``.
        Side lengths are clamped to [1, image side], and every rectangle lies
        fully inside the image.
    """
    area = images_height * images_width
    scale_low, scale_high = erase_scale_range[0], erase_scale_range[1]
    target_areas = tdist.Uniform(scale_low, scale_high).sample(params_shape) * area

    ratio_low, ratio_high = aspect_ratio_range[0], aspect_ratio_range[1]
    if ratio_low < 1. < ratio_high:
        # Sample each side of 1 separately so that ratios below and above 1
        # are equally likely regardless of the interval lengths.
        below_one = tdist.Uniform(ratio_low, 1).sample(params_shape)
        above_one = tdist.Uniform(1, ratio_high).sample(params_shape)
        pick_below = torch.round(torch.rand(params_shape)).bool()
        aspect_ratios = torch.where(pick_below, below_one, above_one)
    else:
        aspect_ratios = tdist.Uniform(ratio_low, ratio_high).sample(params_shape)

    # Side lengths from area and ratio (h / w == ratio), clamped to the image.
    heights = torch.round((target_areas * aspect_ratios) ** 0.5).clamp(
        1., float(images_height)).int()
    widths = torch.round((target_areas / aspect_ratios) ** 0.5).clamp(
        1., float(images_width)).int()

    # Top-left corners uniform over the positions that keep the box in-bounds.
    xs = (torch.rand(params_shape) * (images_width - widths + 1).float()).int()
    ys = (torch.rand(params_shape) * (images_height - heights + 1).float()).int()
    return widths, heights, xs, ys
def sample(self, sample_shape=torch.Size()):
    """Draw a coupled pair ``(X, Y)`` with ``X ~ self.p`` and ``Y ~ self.q``.

    Uses the rejection-based maximal coupling:

    1. Draw ``X ~ p`` and ``W ~ U(0, p(X))``. Where ``W <= q(X)``, set
       ``Y = X``.
    2. For the remaining elements, repeatedly draw ``Y* ~ q`` and
       ``W* ~ U(0, q(Y*))``. Accept ``Y*`` once ``W* > p(Y*)``.

    If ``lookup`` (module level) holds a compiled kernel for
    ``(type(p), type(q))``, that kernel computes ``Y`` instead.

    Args:
        sample_shape: sample shape forwarded to both ``p`` and ``q``.

    Returns:
        Tuple ``(X, Y)`` of tensors with identical shape.
    """
    X = self.p.sample(sample_shape)

    # Short-circuit into C++ when a specialised kernel exists for this pair.
    key = (type(self.p), type(self.q))
    if key in lookup.keys():
        func = lookup[key]
        return (X, func(self.p.loc, self.p.scale, self.q.loc, self.q.scale, X))

    Y = X.clone()
    W = dist.Uniform(0, self.p.log_prob(X).exp()).sample()
    # Elements where W falls under q's density are coupled (Y == X) and final.
    msk = W <= self.q.log_prob(Y).exp()
    # NOTE: every iteration redraws proposals for all elements and only keeps
    # the ones still pending; a custom coupled kernel would avoid this waste.
    while not msk.all():
        # Draw with the same sample_shape so every sample gets its own proposal.
        Yp = self.q.sample(sample_shape)
        threshold = self.p.log_prob(Yp).exp()
        W = dist.Uniform(0, self.q.log_prob(Yp).exp()).sample()
        add = ~msk & (W > threshold)
        Y = torch.where(add, Yp, Y)
        msk = msk | add
    return (X, Y)
def __init__(self, D, K, low=0.1, high=0.2, device="cuda:0", train_b=False):
    """Create a (K, D) positive weight layer with a Uniform(0, 1) prior over K units.

    Args:
        D: output dimensionality (columns of W, size of the bias).
        K: number of latent units (rows of W, size of the prior).
        low, high: bounds of the Uniform used to initialise the weights.
        device: device for the prior, the weights and the bias.
        train_b: whether the bias receives gradients.
    """
    super().__init__()
    # Independent Uniform(0, 1) prior over the K units.
    self.prior = td.Uniform(torch.zeros(K, device=device), torch.ones(K, device=device))
    init_W = td.Uniform(low, high).sample([K, D]).to(device)
    # Store the inverse-softplus of the initial weights, log(exp(W) - 1).
    # Presumably softplus(_W) is applied elsewhere to recover positive weights
    # (confirm). expm1 computes exp(W) - 1 without cancellation for small W.
    self._W = nn.Parameter(torch.log(torch.expm1(init_W)))
    # Create the bias on the same device as the other tensors; previously it
    # was always allocated on the CPU.
    self.b = nn.Parameter(torch.zeros(D, device=device), requires_grad=train_b)
def __init__(self, noise_strength: float):
    """Set up the uniform distributions used to draw random signal parameters.

    The attribute names suggest amplitude, frequency, phase and additive
    noise of a sinusoid-like signal (confirm against the sampling code).

    Args:
        noise_strength: half-width of the symmetric noise range
            ``[-noise_strength, noise_strength]``.
    """
    uniform = distributions.Uniform
    self.amp_dist = uniform(low=0.2, high=1)
    self.freq_dist = uniform(low=0.1, high=0.25)
    self.phase_dist = uniform(low=0, high=math.pi)
    self.noise_dist = uniform(low=-noise_strength, high=noise_strength)
def get_simulator_and_prior(task):
    """Return the ``(simulator, prior)`` pair for a named benchmark task.

    Args:
        task: one of ``"nonlinear-gaussian"``, ``"nonlinear-gaussian-gaussian"``,
            ``"two-moons"``, ``"linear-gaussian"``, ``"lotka-volterra"``,
            ``"lotka-volterra-gaussian"`` or ``"mg1"``.

    Returns:
        The simulator instance and a prior distribution over its parameters.

    Raises:
        ValueError: if ``task`` is not recognised.
    """

    def box_prior(simulator, low, high):
        # Independent uniform prior on [low, high] for every simulator parameter.
        return distributions.Uniform(
            low=low * torch.ones(simulator.parameter_dim),
            high=high * torch.ones(simulator.parameter_dim),
        )

    def gaussian_prior(dim, variance=1):
        # Zero-mean isotropic Gaussian prior.
        return distributions.MultivariateNormal(
            loc=torch.zeros(dim), covariance_matrix=variance * torch.eye(dim)
        )

    if task == "nonlinear-gaussian":
        simulator = simulators.NonlinearGaussianSimulator()
        return simulator, box_prior(simulator, -3, 3)
    if task == "nonlinear-gaussian-gaussian":
        return simulators.NonlinearGaussianSimulator(), gaussian_prior(5)
    if task == "two-moons":
        simulator = simulators.TwoMoonsSimulator()
        return simulator, box_prior(simulator, -2, 2)
    if task == "linear-gaussian":
        dim, std = 20, 0.5
        return simulators.LinearGaussianSimulator(dim=dim, std=std), gaussian_prior(dim)
    if task == "lotka-volterra":
        simulator = simulators.LotkaVolterraSimulator(
            summarize_observations=True, gaussian_prior=False
        )
        return simulator, box_prior(simulator, -5, 2)
    if task == "lotka-volterra-gaussian":
        simulator = simulators.LotkaVolterraSimulator(
            summarize_observations=True, gaussian_prior=True
        )
        return simulator, gaussian_prior(4, variance=2)
    if task == "mg1":
        simulator = simulators.MG1Simulator(summarize_observations=True)
        prior = distributions_.MG1Uniform(
            low=torch.zeros(3), high=torch.Tensor([10.0, 10.0, 1.0 / 3.0])
        )
        return simulator, prior
    raise ValueError(f"'{task}' simulator choice not understood.")
def __init__(self, is_test):
    """Define the synthetic data-generating distributions.

    * ``pX``: equal-weight mixture of Uniform(0, 0.45) and Uniform(0.35, 1).
    * ``pY1``: Uniform(0, 1) with shape (1,).
    * ``pY2``: callable mapping ``X`` to ``Normal(sin(10 * X), 0.05)``.

    Args:
        is_test: currently unused.
    """
    super().__init__()
    # NOTE(review): a previously commented-out variant used a plain
    # Uniform(0, 1) for pX when ``is_test`` was set; ``is_test`` now has no
    # effect — confirm this is intended.
    weights = D.Categorical(torch.ones(2, ))
    components = D.Uniform(torch.tensor([0.0, 0.35]), torch.tensor([0.45, 1.0]))
    self.pX = D.MixtureSameFamily(weights, components)
    self.pY1 = D.Uniform(torch.tensor([0.0]), torch.tensor([1.0]))

    def conditional_y2(X):
        # Noisy sinusoid of X.
        return D.Normal(torch.sin(10 * X), 0.05)

    self.pY2 = conditional_y2
def uniform_perturbation(sample, x_min, x_max, n=1, sigma=1, seed=0, image_size=28):
    """Draw ``n`` uniform perturbations of each image within +/- ``sigma``.

    Each pixel is sampled from ``Uniform(max(p - sigma, x_min), min(p + sigma, x_max))``.

    Args:
        sample: array-like of images; reshaped to (-1, image_size**2).
        x_min, x_max: pixel bounds, either both Python numbers or both tensors.
        n: number of draws per image.
        sigma: half-width of the perturbation box.
        seed: currently unused (no RNG seeding happens here).
        image_size: side length of the square images.

    Returns:
        ``(x, prior)``: numpy array of shape (n * num_images, image_size,
        image_size), and the per-pixel ``Uniform`` that was sampled.

    Raises:
        ValueError: if ``x_min``/``x_max`` are neither both numbers nor both
        tensors.
    """
    import torch
    import torch.distributions as dist

    flat = torch.tensor(sample).view(-1, image_size * image_size)
    number_types = (int, float, complex)
    if isinstance(x_min, number_types) and isinstance(x_max, number_types):
        lower_clip, upper_clip = torch.tensor([x_min]), torch.tensor([x_max])
    elif isinstance(x_min, torch.Tensor) and isinstance(x_max, torch.Tensor):
        lower_clip, upper_clip = x_min, x_max
    else:
        raise ValueError('Type of x_min and x_max {0} is not supported'.format(type(x_min)))

    prior = dist.Uniform(low=torch.max(flat - sigma, lower_clip),
                         high=torch.min(flat + sigma, upper_clip))
    draws = prior.sample(torch.Size([n]))
    return draws.view(-1, image_size, image_size).numpy(), prior
def __init__(self, D_in, layers_sizes, kernel, M=10,
             sigma_prior=dist.Uniform(1., 2), alpha_prior=dist.Uniform(0.25, 0.75)):
    """Stack one SGPR layer per entry of ``layers_sizes``.

    Layer ``i`` maps the previous width (``D_in`` for the first layer) to
    ``layers_sizes[i]``. Each layer gets its own kernel, built as
    ``kernel(out_dim, sigma_prior=..., alpha_prior=...)``.

    Args:
        D_in: input dimensionality of the first layer.
        layers_sizes: output width of each layer, in order.
        kernel: kernel constructor.
        M: number of inducing points per layer.
        sigma_prior, alpha_prior: priors handed to every layer's kernel.
    """
    super(DGP, self).__init__()
    self.layers_sizes = layers_sizes
    self.layers = nn.ModuleList([])
    in_dim = D_in
    for out_dim in layers_sizes:
        layer_kernel = kernel(out_dim, sigma_prior=sigma_prior, alpha_prior=alpha_prior)
        # NOTE(review): this mean ignores its input and returns the
        # ``identity_mean`` object itself. That is only correct if SGPR treats
        # ``mean`` as a factory; otherwise it should be ``mean=identity_mean``.
        # Confirm against SGPR.
        layer = SGPR(in_dim, M=M, kernel=layer_kernel, D_out=out_dim,
                     mean=lambda X: identity_mean)
        self.layers.append(layer)
        in_dim = out_dim
def sample_z(args):
    """Draw a batch of generator inputs from the prior.

    Args:
        args: namespace providing ``batch_size``, ``cat_dim``, ``noise_dim``,
            ``cont_dim`` and ``device``.

    Returns:
        ``(z, z_cat, z_noise, z_cont)`` on ``args.device``:

        * ``z_cat``: one-hot codes drawn uniformly over ``cat_dim`` categories.
        * ``z_noise``, ``z_cont``: Uniform(-1, 1) draws.
        * ``z``: concatenation ``[z_noise, z_cat, z_cont]`` along dim 1.
    """
    batch = args.batch_size
    z_cat = OneHotCategorical(logits=torch.zeros(batch, args.cat_dim)).sample()
    unit_box = dist.Uniform(-1, 1)
    z_noise = unit_box.sample(torch.Size((batch, args.noise_dim)))
    z_cont = unit_box.sample(torch.Size((batch, args.cont_dim)))
    # Concatenate the incompressible noise, discrete latents and continuous latents.
    z = torch.cat([z_noise, z_cat, z_cont], dim=1)
    return tuple(t.to(args.device) for t in (z, z_cat, z_noise, z_cont))
def test_():
    """Smoke-run NUTS on the Lotka-Volterra oscillating likelihood and save marginal plots.

    Runs three chains (10000 // 3 samples each, 200 warm-up steps) under a
    Uniform(-5, 2)^4 prior. Histogram marginals are written to a hard-coded
    PDF path.
    """
    # NOTE(review): ``loc`` is 2-D, left over from an earlier 2-D Gaussian
    # experiment, yet it is passed as ``ground_truth`` for 4-D samples below.
    # Confirm this is intended.
    loc = torch.Tensor([0, 0])

    prior = distributions.Uniform(low=-5 * torch.ones(4), high=2 * torch.ones(4))
    from nsf import distributions as distributions_
    likelihood = distributions_.LotkaVolterraOscillating()
    potential_function = PotentialFunction(likelihood, prior)

    from pyro.infer.mcmc import NUTS
    kernel = NUTS(potential_fn=potential_function)

    num_chains = 3
    sampler = MCMC(
        kernel=kernel,
        num_samples=10000 // num_chains,
        warmup_steps=200,
        initial_params={"": torch.zeros(num_chains, 4)},
        num_chains=num_chains,
    )
    sampler.run()
    samples = next(iter(sampler.get_samples().values()))

    utils.plot_hist_marginals(utils.tensor2numpy(samples),
                              ground_truth=utils.tensor2numpy(loc),
                              lims=[-6, 3])
    plt.savefig("/home/conor/Dropbox/phd/projects/lfi/out/mcmc.pdf")
    plt.close()
def dequantize(x, constraint=0.9, inverse=False, logit=False):
    """Dequantize a batch of images, or invert the logit pre-processing.

    Forward (``inverse=False``) adds uniform noise: ``(x * 255 + u) / 256``
    with ``u ~ U[0, 1)`` sampled on ``x``'s device. By default only this tensor
    is returned, which is the behaviour existing callers rely on.

    With ``logit=True`` the dequantized values are additionally squeezed into
    ``[(1 - c) / 2, (1 + c) / 2]`` (``c = constraint``) and mapped through
    logit. In that case ``(logit_x, log_det_J)`` is returned, where
    ``log_det_J`` has shape (B, 1) and is summed over non-batch dimensions.

    Inverse (``inverse=True``) undoes the logit/affine squeeze and returns
    ``(x, 0)``.

    Args:
        x: for the forward direction, a 4-D (B, C, H, W) tensor (assumed to be
            scaled to [0, 1] — TODO confirm).
        constraint: squeeze factor ``c`` in (0, 1).
        inverse: run the inverse transform instead.
        logit: in the forward direction, also apply the logit transform.

    Raises:
        ValueError: in the forward direction, if ``x`` is not 4-D.
    """
    if inverse:
        x = 2 / (torch.exp(-x) + 1) - 1
        x /= constraint
        x = (x + 1) / 2
        return x, 0

    B, C, H, W = x.shape  # also enforces a 4-D batch, as before
    # Sample on x's device (default float dtype) to avoid a CPU/GPU mismatch.
    noise = torch.rand(x.shape, device=x.device)
    x = (x * 255 + noise) / 256
    if not logit:
        return x

    # Squeeze into [(1 - c) / 2, (1 + c) / 2] so the logit stays finite.
    x = x * 2 - 1
    x *= constraint
    x = (x + 1) / 2
    logit_x = x.log() - (1 - x).log()
    # Log-det of the squeeze + logit; constants are created on x's device.
    pre_logit_scale = x.new_tensor([constraint]).log() - x.new_tensor([1 - constraint]).log()
    log_det_J = F.softplus(logit_x) + F.softplus(-logit_x) \
        - F.softplus(-pre_logit_scale)
    return logit_x, log_det_J.view(B, -1).sum(1, keepdim=True)
def brute_force(model, domain, count_iterations=100, count_particles=1000000):
    """Naive Monte Carlo estimate of ``P(model(x) >= 0)`` for ``x ~ Uniform(domain)``.

    Args:
        model: callable whose output is squeezed on the last dim to one score
            per sample.
        domain: tensor with per-dimension bounds in ``domain[:, 0]`` (low) and
            ``domain[:, 1]`` (high). The bounds are moved to the GPU when the
            module-level ``CUDA`` flag is set.
        count_iterations: number of batches to evaluate.
        count_particles: samples per batch (previously hard-coded to 1e6).

    Returns:
        ``(log_prob, max_val, count_above, count_total)``:

        * ``log_prob`` is ``log(count_above / count_total)``, or ``-inf`` if no
          sample scored >= 0.
        * ``max_val`` is the largest score observed.
    """
    if CUDA:
        low_params = domain[:, 0].cuda()
        high_params = domain[:, 1].cuda()
    else:
        low_params = domain[:, 0]
        high_params = domain[:, 1]
    prior = dist.Uniform(low=low_params, high=high_params)

    count_particles = int(count_particles)
    count_above = 0
    count_total = 0
    max_val = -math.inf
    start = time.time()
    # Pure evaluation: skipping autograd avoids keeping a graph for every batch.
    with torch.no_grad():
        for i in range(count_iterations):
            x = prior.sample(torch.Size([count_particles]))
            s_x = model(x).squeeze(-1)
            # Integer sum: exact even beyond float32's 2**24 limit.
            count_above += int((s_x >= 0).sum().item())
            count_total += count_particles
            max_val = max(s_x.max().item(), max_val)
            if i % 1000 == 0:
                time_per_iter = (time.time() - start) / (i + 1)
                time_left = (count_iterations - i) * time_per_iter / 60
                print(f'{i/count_iterations*100}% done ({round(time_left,3)} mins left)')

    print(f'{count_above} adversarial examples observed from {count_total} samples')
    if count_above > 0:
        return math.log(count_above) - math.log(count_total), max_val, count_above, count_total
    return -math.inf, max_val, count_above, count_total
def simulate(
    self,
    input_obj: torch.Tensor,
    num_sim: int,
    seed: Optional[int] = None,
) -> np.ndarray:
    """Sample from the per-row Folded Logistic distributions predicted for ``input_obj``.

    ``self.forward(input_obj)[:, 0]`` gives the locations and ``[:, 1]`` the
    scales.

    Args:
        input_obj: model input.
        num_sim: number of draws per row.
        seed: if given, passed to ``torch.manual_seed`` first (this reseeds the
            global RNG).

    Returns:
        numpy array of shape (num_sim, N) with non-negative samples.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Get the predicted location and scale parameters.
    predictions = self.forward(input_obj)
    locations, scales = predictions[:, 0], predictions[:, 1]

    # Build a Folded Logistic distribution:
    #   X ~ Uniform(0, 1)
    #   Y = a + b * logit(X) ~ Logistic(a, b)
    #   Z = |Y|              ~ FoldedLogistic(a, b)
    # The base is created on the parameters' device/dtype so GPU predictions
    # work; a scalar Uniform(0, 1) would always live on the CPU.
    base_distribution = dists.Uniform(torch.zeros_like(locations), torch.ones_like(locations))
    transforms = [
        dists.transforms.SigmoidTransform().inv,
        dists.transforms.AffineTransform(loc=locations, scale=scales),
        dists.transforms.AbsTransform(),
    ]
    folded_logistic_dists = dists.transformed_distribution.TransformedDistribution(
        base_distribution, transforms)

    # Sample from the distributions; move to host memory for numpy.
    samples = folded_logistic_dists.sample((num_sim,))
    return samples.detach().cpu().numpy()
def getGPmodel(self):
    """Build a Pyro GP regression model on the object's initial (X, y) data.

    The Pyro RNG is reseeded from ``self.seed`` and the param store is
    cleared. The kernel is Matern-5/2 with priors ``Uniform(0.5, 1.5)`` on
    the variance and ``Gamma(7.5, 1)`` on the lengthscale. The model is a
    ``GPRegression`` with ``noise=self.noise`` and ``jitter=1e-4``.

    Returns:
        the ``gp.models.GPRegression`` instance.
    """
    pyro.set_rng_seed(self.seed)
    pyro.clear_param_store()
    args = self.getInitXY()

    # ``== True`` is kept on purpose: it is not the same test as plain
    # truthiness for every return type.
    is_one_dim = self.dim1(args) == True
    x_values = self.getFlotList(args[0])
    if is_one_dim:
        X4Kernel = torch.tensor(x_values)
        dim = 1
    else:
        X4Kernel = torch.tensor([x_values])
        dim = X4Kernel.shape[1]
    y4Kernel = torch.tensor(self.getFlotList(args[1]))

    kernel = gp.kernels.Matern52(input_dim=dim)
    kernel.set_prior("variance", dist.Uniform(torch.tensor(0.5), torch.tensor(1.5)))
    kernel.set_prior("lengthscale", dist.Gamma(torch.tensor(7.5), torch.tensor(1.)))
    return gp.models.GPRegression(X4Kernel, y4Kernel, kernel,
                                  noise=torch.tensor(self.noise), jitter=1.0e-4)
def __init__(self, n_hidden=100, bottom_width=7, in_ch=128):
    """Register the layers of a two-stage transposed-convolution generator.

    ``l0`` maps a ``n_hidden`` vector to ``bottom_width**2 * in_ch`` features.
    Presumably ``forward`` reshapes these to (N, in_ch, bottom_width,
    bottom_width) — confirm. Each transposed conv (kernel 4, stride 2,
    padding 1) doubles the spatial size, so with the defaults the maps go
    7x7 -> 14x14 -> 28x28.

    Args:
        n_hidden: latent input size.
        bottom_width: spatial size of the first feature map.
        in_ch: channels of the first feature map (halved by ``dc1``).
    """
    super(Generator, self).__init__()
    self.n_hidden = n_hidden
    self.in_ch = in_ch
    self.bottom_width = bottom_width
    self.uniform = distributions.Uniform(-1, 1)

    n_flat = self.bottom_width * self.bottom_width * self.in_ch
    mid_ch = self.in_ch // 2
    upsample = dict(kernel_size=4, stride=2, padding=1, bias=False)

    # register parameters (order kept: l0, dc1, dc2, bn0, bn1)
    self.l0 = nn.Linear(in_features=self.n_hidden, out_features=n_flat, bias=False)
    # -> (N, in_ch // 2, 2 * bottom_width, 2 * bottom_width), e.g. (N, 64, 14, 14)
    self.dc1 = nn.ConvTranspose2d(in_channels=self.in_ch, out_channels=mid_ch, **upsample)
    # -> (N, 1, 4 * bottom_width, 4 * bottom_width), e.g. (N, 1, 28, 28)
    self.dc2 = nn.ConvTranspose2d(in_channels=mid_ch, out_channels=1, **upsample)
    self.bn0 = nn.BatchNorm1d(num_features=n_flat)
    self.bn1 = nn.BatchNorm2d(num_features=mid_ch)
def logistic_distribution(loc: Tensor, scale: Tensor):
    """Return ``Logistic(loc, scale)`` built as a transformed ``Uniform(0, 1)``.

    If ``U ~ Uniform(0, 1)`` then ``loc + scale * logit(U) ~ Logistic(loc, scale)``.

    Args:
        loc: location tensor.
        scale: scale tensor (broadcast against ``loc``).

    Returns:
        ``torch.distributions.TransformedDistribution``.
    """
    # The base must be Uniform(0, 1). The upper bound used to be
    # ``scale.new_zeros(1)``, i.e. the degenerate Uniform(0, 0), which argument
    # validation rejects.
    base_distribution = td.Uniform(loc.new_zeros(1), loc.new_ones(1))
    transforms = [
        td.SigmoidTransform().inv,
        td.AffineTransform(loc=loc, scale=scale),
    ]
    return td.TransformedDistribution(base_distribution, transforms)
def run_pgdn(net, data, get_feature, logger, ad_save_path, params):
    """ Computes PGD on negative loss for given data.

    For every (x, y) batch in ``data``:

    1. Run ``_pgdn`` with a ``Uniform(-epsilon, epsilon)`` distribution.
    2. If ``converged >= 0``, save the adversarial example as
       ``<ad_save_path>/<filename>.wav`` and log it with its new prediction.
    3. Otherwise, log the original prediction twice.

    The filename comes from ``data.dataset.filenames[i]``, which assumes a
    batch size of 1 and a non-shuffled loader (confirm).

    Args:
        net: model exposing ``to``, ``eval`` and ``predict``.
        data: iterable data loader with a ``dataset.filenames`` list.
        get_feature: callable mapping a waveform tensor to model features.
        logger: object with an ``append(row)`` method.
        ad_save_path: output directory for adversarial examples.
        params: namespace with at least ``epsilon``.
    """
    noise_dist = dists.Uniform(-params.epsilon, params.epsilon)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    net.eval()

    for i, (x, y) in enumerate(data):
        print('pgd sample {}/{}'.format(i + 1, len(data)))
        filename = data.dataset.filenames[i]
        if '.wav' not in filename:
            filename = filename + '.wav'

        ad_ex, db, converged = _pgdn(net, params, x.to(device), y.to(device), get_feature, noise_dist)
        orig_pred = net.predict(get_feature(x.to(device))).item()
        true_pred = y.item()

        if not converged >= 0:
            print('\ncould not find robust adv. example for file {}'.format(filename))
            logger.append([filename, converged, 0, False, true_pred, orig_pred, orig_pred])
            continue

        print('\nsave current file: {} with distortion: {}db (r:{}/o:{})'.format(
            filename, db, true_pred, orig_pred))
        new_pred = net.predict(get_feature(ad_ex)).item()
        save_adv_example(ad_ex.cpu(), os.path.join(ad_save_path, filename))
        logger.append([filename, converged, db, True, true_pred, orig_pred, new_pred])
def sample_gamma(self, shape, scale):
    """Reparameterised sample of ``Gamma(shape, 1) * scale`` using shape augmentation.

    Shape augmentation (boosting) with ``B = 10``:

    1. Draw ``s ~ Gamma(shape + B, 1)``.
    2. Recover the Marsaglia-Tsang noise ``eps`` that maps to ``s``, and
       treat it as fixed noise (no gradient).
    3. Rebuild ``z = h(eps, shape + B)`` differentiably in ``shape``.
    4. Reduce the shape back using
       ``Gamma(a) = Gamma(a + B) * prod_{i=1..B} U_i ** (1 / (a + i - 1))``.

    Gradients flow to ``shape`` through both ``h`` and the boost exponents,
    and to ``scale`` through the final product.

    Args:
        shape: tensor of Gamma shape parameters, any rank.
        scale: tensor broadcastable against ``shape``.

    Returns:
        tensor with the broadcast shape of ``shape`` and ``scale``.
    """
    augment = 10
    boosted = shape + augment

    with torch.no_grad():
        sample = distributions.Gamma(boosted, 1).sample()
        # eps is the standard-normal proposal that Marsaglia-Tsang maps to
        # ``sample``. It must be gradient-free: otherwise h(eps(a), a)
        # collapses algebraically to ``sample`` and d z / d shape is ~0.
        eps = torch.sqrt(9. * boosted - 3.) * (
            ((sample / (boosted - (1. / 3.))) ** (1. / 3.)) - 1.)
    z = (boosted - (1. / 3.)) * ((1. + (eps / torch.sqrt(9. * boosted - 3.))) ** 3.)

    # Reduce the augmentation factor. Only the uniforms are noise; the
    # exponents depend on ``shape`` and must keep their gradient.
    with torch.no_grad():
        factor_range = torch.arange(1, augment + 1, dtype=torch.float, device=self.device)
        u = torch.rand(shape.shape + (augment,), device=self.device)
        # Guard against u == 0, whose log would yield NaN gradients below.
        u = u.clamp_min(torch.finfo(u.dtype).tiny)
    exponents = 1. / (shape.unsqueeze(-1) + factor_range - 1. + 1e-12)
    u_prod = torch.prod(u ** exponents, -1)
    return z * u_prod * scale
def random(state, algorithm, body):
    '''Random action sampling that returns the same data format as default(), but without forward pass. Uses gym.space.sample()

    Returns ``(action, action_pd)``:

    * ``action`` is ``body.action_space.sample()`` as a tensor on
      ``algorithm.net.device``.
    * ``action_pd`` is a distribution shaped like the action space, on the
      same device.

    Raises NotImplementedError for 'multi_continuous', 'multi_binary' and
    unknown action types.
    '''
    state = try_preprocess(
        state, algorithm, body, append=True)  # for consistency with init_action_pd inner logic
    if body.action_type == 'discrete':
        action_pd = distributions.Categorical(logits=torch.ones(
            body.action_space.high, device=algorithm.net.device))
    elif body.action_type == 'continuous':
        # Place the bounds on the net's device, like the other branches and
        # like the returned ``action``, so e.g. log_prob(action) works on GPU.
        action_pd = distributions.Uniform(
            low=torch.tensor(body.action_space.low, device=algorithm.net.device).float(),
            high=torch.tensor(body.action_space.high, device=algorithm.net.device).float())
    elif body.action_type == 'multi_discrete':
        action_pd = distributions.Categorical(
            logits=torch.ones(body.action_space.high.size, body.action_space.high[0],
                              device=algorithm.net.device))
    elif body.action_type == 'multi_continuous':
        raise NotImplementedError
    elif body.action_type == 'multi_binary':
        raise NotImplementedError
    else:
        raise NotImplementedError

    sample = body.action_space.sample()
    action = torch.tensor(sample, device=algorithm.net.device)
    return action, action_pd
def MHKernel(X, q, pi):
    """Apply one Metropolis-Hastings step independently to every chain in ``X``.

    Args:
        X: current states. Leading dims index chains; trailing dims are the
            event dims of ``pi``.
        q: callable mapping a state tensor to a proposal distribution
            ``q(. | X)``.
        pi: target distribution (only ``log_prob`` is used).

    Returns:
        New tensor holding, per chain, either the proposal (accepted) or the
        old state.
    """
    Xp = q(X).sample()
    # log of pi(x') q(x | x') / (pi(x) q(x' | x)), one value per chain.
    log_alpha = (pi.log_prob(Xp) + q(Xp).log_prob(X)) - (pi.log_prob(X) + q(X).log_prob(Xp))
    # One uniform per chain (equal to X.size() when pi has scalar events).
    U = dist.Uniform(0, 1).sample(log_alpha.shape)
    accept = U.log() <= log_alpha
    # Broadcast the per-chain decision over any event dims of X. torch.where
    # replaces the former ``1 - mask`` arithmetic, which errors on bool tensors.
    accept = accept.reshape(accept.shape + (1,) * (X.dim() - accept.dim()))
    return torch.where(accept, Xp, X)
def random(state, algorithm, body):
    '''Random action sampling that returns the same data format as default(), but without forward pass. Uses gym.space.sample()

    ``state`` and ``algorithm`` are unused. Returns ``(action, action_pd)``:

    * ``action`` is ``body.action_space.sample()`` as a tensor.
    * ``action_pd`` is a float32 Uniform over ``[action_space.low,
      action_space.high]``.
    '''
    space = body.action_space
    low = torch.from_numpy(np.array(space.low)).float()
    high = torch.from_numpy(np.array(space.high)).float()
    action_pd = distributions.Uniform(low=low, high=high)
    action = torch.tensor(space.sample())
    return action, action_pd
def __init__(self, loc, scale, **kwargs):
    """Logistic(loc, scale) expressed as transforms of a Uniform(0, 1) base.

    If ``U ~ Uniform(0, 1)`` then ``loc + scale * logit(U) ~ Logistic(loc, scale)``.

    Args:
        loc: location; anything ``torch.as_tensor`` accepts.
        scale: scale; anything ``torch.as_tensor`` accepts.
        **kwargs: forwarded to the base ``Uniform`` (e.g. ``validate_args``).
    """
    loc, scale = map(torch.as_tensor, (loc, scale))
    # Integer inputs such as ``(0, 1)`` would produce an integer-typed Uniform
    # base, which cannot be sampled; promote them to the default float dtype.
    default_dtype = torch.get_default_dtype()
    if not loc.is_floating_point():
        loc = loc.to(default_dtype)
    if not scale.is_floating_point():
        scale = scale.to(default_dtype)
    base_distribution = ptd.Uniform(torch.zeros_like(loc), torch.ones_like(loc), **kwargs)
    transforms = [
        ptd.SigmoidTransform().inv,
        ptd.AffineTransform(loc=loc, scale=scale),
    ]
    super().__init__(base_distribution, transforms)
def unif_sample(sample_shape=(1,), low=0, high=1):
    """Generate a reparameterised sample from U(low, high).

    :param sample_shape: The shape S (any sequence of ints; an immutable tuple
        default avoids a shared mutable default argument)
    :param low: Lower bound on uniform (number or tensor)
    :param high: Upper bound on uniform (number or tensor)
    :return: A tensor of shape S + broadcast(low, high).shape containing
        samples from the uniform; gradients flow to tensor bounds.
    """
    return dist.Uniform(low, high).rsample(sample_shape)
def sample_relation_type(self, prev_parts):
    """
    Sample a relation type from the prior for the current stroke,
    conditioned on the previous strokes

    The first stroke is always 'unihist'. Later strokes draw a category from
    ``self.rel_mixdist``. Attaching categories then choose a previous stroke
    uniformly at random; 'mid' additionally chooses a sub-stroke uniformly
    and a spline coordinate uniformly.

    Parameters
    ----------
    prev_parts : list of StrokeType
        previous part types

    Returns
    -------
    r : RelationType
        relation type sample

    Raises
    ------
    TypeError
        if the sampled category is not recognised
    """
    for part in prev_parts:
        assert isinstance(part, StrokeType)
    nprev = len(prev_parts)
    stroke_ix = nprev  # index of the stroke being sampled

    # first sample the relation category
    if nprev == 0:
        category = 'unihist'
    else:
        category = self.__relation_categories[self.rel_mixdist.sample()]

    if category == 'unihist':
        # sample a global position for this stroke index; (1,2) -> (2,)
        gpos = torch.squeeze(self.Spatial.sample(torch.tensor([stroke_ix])))
        return RelationIndependent(category, gpos, self.Spatial.xlim, self.Spatial.ylim)

    if category not in ['start', 'end', 'mid']:
        raise TypeError('invalid relation')

    # pick the previous stroke to attach to, uniformly
    attach_ix = dist.Categorical(probs=torch.ones(nprev)).sample()
    if category != 'mid':
        return RelationAttach(category, attach_ix)

    # 'mid': pick a sub-stroke uniformly, then a type-level spline coordinate
    nsub = prev_parts[attach_ix].nsub
    attach_subix = dist.Categorical(probs=torch.ones(nsub)).sample()
    _, lb, ub = bspline_gen_s(self.ncpt, 1)
    eval_spot = dist.Uniform(lb, ub).sample()
    return RelationAttachAlong(category, attach_ix, attach_subix, eval_spot, self.ncpt)
def __init__(self):
    """Set up a Gaussian on 4 log-parameters together with the box [-5, 2]^4.

    ``_gaussian`` has mean ``log([0.01, 0.5, 1, 0.01])`` and isotropic
    standard deviation 0.5. ``_uniform`` is the box. ``_log_normalizer`` is
    ``-log`` of the Gaussian's probability mass inside the box; presumably it
    is added to the Gaussian log-density elsewhere to truncate it to the box
    (confirm).
    """
    low, high = -5., 2.
    mean = torch.log(torch.Tensor([0.01, 0.5, 1, 0.01]))
    sigma = 0.5
    covariance = sigma**2 * torch.eye(4)
    self._gaussian = distributions.MultivariateNormal(
        loc=mean, covariance_matrix=covariance)
    self._uniform = distributions.Uniform(low=low * torch.ones(4),
                                          high=high * torch.ones(4))
    # Per-coordinate mass of N(mean_i, sigma^2) on [low, high]:
    #   Phi(b) - Phi(a) = 0.5 * (erf(b / sqrt(2)) - erf(a / sqrt(2))),
    # with a, b the standardised bounds. erf is not the normal CDF, so both
    # the 1/sqrt(2) and the 0.5 are required; without them the "mass" can
    # reach 2.
    erf_scale = sigma * math.sqrt(2.)
    mass = 0.5 * (torch.erf((high - mean) / erf_scale) - torch.erf((low - mean) / erf_scale))
    self._log_normalizer = -torch.log(mass).sum()
def logistic_distribution(loc, log_scale):
    """Build a logistic distribution as transforms of a Uniform(0, 1) base.

    The base Uniform has ``loc``'s shape and device. It is pushed through the
    project's ``LogisticTransform`` and then an affine map with location
    ``loc`` and scale ``exp(log_scale) + 1e-5``.

    Args:
        loc: location tensor.
        log_scale: log of the scale (tensor).

    Returns:
        ``distributions.TransformedDistribution``.
    """
    # the small floor keeps the scale strictly positive
    scale = torch.exp(log_scale) + 1e-5
    unit_uniform = distributions.Uniform(torch.zeros_like(loc), torch.ones_like(loc))
    return distributions.TransformedDistribution(
        unit_uniform,
        [LogisticTransform(), distributions.AffineTransform(loc=loc, scale=scale)],
    )
def score_relation_type(self, prev_parts, r):
    """
    Compute the log-probability of the relation type of the current stroke
    under the prior

    The score sums:

    * the category term, 0 for the first stroke, else from
      ``self.rel_mixdist``;
    * the category-specific type-level terms, which mirror the sampling
      procedure (uniform attachment indices, uniform spline coordinate for
      'mid').

    Parameters
    ----------
    prev_parts : list of StrokeType
        previous stroke types
    r : RelationType
        relation type to score

    Returns
    -------
    ll : tensor
        scalar; log-probability of the relation type

    Raises
    ------
    TypeError
        if ``r.category`` is not recognised
    """
    assert isinstance(r, RelationType)
    for part in prev_parts:
        assert isinstance(part, StrokeType)
    nprev = len(prev_parts)
    stroke_ix = nprev

    # first score the relation category
    if nprev == 0:
        ll = 0.
    else:
        category_ix = torch.tensor(self.__relation_categories.index(r.category),
                                   dtype=torch.long)
        ll = self.rel_mixdist.log_prob(category_ix)

    if r.category == 'unihist':
        # score the type-level location; (2,) -> (1,2)
        gpos = r.gpos.view(1, 2)
        return ll + torch.squeeze(self.Spatial.score(gpos, torch.tensor([stroke_ix])))

    if r.category not in ['start', 'end', 'mid']:
        raise TypeError('invalid relation')

    # score the stroke attachment index (uniform over previous strokes)
    ll = ll + dist.Categorical(probs=torch.ones(nprev)).log_prob(r.attach_ix)
    if r.category != 'mid':
        return ll

    # 'mid': sub-stroke index, then the type-level spline coordinate
    nsub = prev_parts[r.attach_ix].nsub
    ll = ll + dist.Categorical(probs=torch.ones(nsub)).log_prob(r.attach_subix)
    _, lb, ub = bspline_gen_s(self.ncpt, 1)
    return ll + dist.Uniform(lb, ub).log_prob(r.eval_spot)