def get_loss(kernel, eps):
    """Build the Sinkhorn ``SamplesLoss`` that matches a kernel's cost.

    Args:
        kernel: Object exposing ``kernel_type``. Its ``power`` attribute is
            additionally read for the two ``*power_quaternion`` types.
        eps: Value forwarded as the ``blur`` argument of ``SamplesLoss``.

    Returns:
        A ``SamplesLoss("sinkhorn", ...)`` using the ``'tensorized'``
        backend, with the cost and diameter selected by ``kernel_type``.

    Raises:
        ValueError: If ``kernel.kernel_type`` is not one of the supported
            values. The function used to fall off the end and return
            ``None``, which only failed later when the result was called.
    """
    kernel_type = kernel.kernel_type

    if kernel_type == 'quaternion':
        return SamplesLoss("sinkhorn",
                           blur=eps,
                           diameter=3.15,
                           cost=utils.quaternion_geodesic_distance,
                           backend='tensorized')

    if kernel_type == 'squared_euclidean':
        return SamplesLoss("sinkhorn",
                           p=2,
                           blur=eps,
                           diameter=4.,
                           backend='tensorized')

    if kernel_type == 'power_quaternion':
        # The kernel's power is bound as the first positional argument of
        # the distance function.
        dist = partial(utils.power_quaternion_geodesic_distance, kernel.power)
        return SamplesLoss("sinkhorn",
                           blur=eps,
                           diameter=10.,
                           cost=dist,
                           backend='tensorized')

    if kernel_type == 'sum_power_quaternion':
        dist = partial(utils.sum_power_quaternion_geodesic_distance,
                       kernel.power)
        return SamplesLoss("sinkhorn",
                           blur=eps,
                           diameter=10.,
                           cost=dist,
                           backend='tensorized')

    raise ValueError(
        "unsupported kernel_type: {!r}".format(kernel_type))
def make_default_loss_fn(ot_schedule=None, bce_schedule=None, ot_loss=None):
    """Create a loss callable mixing a BCE term and an optimal-transport term.

    Args:
        ot_schedule: Callable ``state -> coefficient`` weighting the OT term.
            Defaults to a constant ``1.0``.
        bce_schedule: Callable ``state -> coefficient`` weighting the BCE
            term. Defaults to a constant ``1.0``.
        ot_loss: Callable invoked as ``ot_loss(vector_masses, vector_coords,
            raster_masses, raster_coords)``. Defaults to a Sinkhorn
            ``SamplesLoss`` with ``p=2, blur=.05, scaling=.6, reach=6.``.

    Returns:
        ``compute_loss(state)``: reads ``vector_masses``, ``vector_coords``,
        ``raster_masses``, ``raster_coords`` (and ``render`` / ``raster``
        when the BCE coefficient is non-zero) from ``state``, stores the
        per-sample total under ``state['loss_per_sample']`` and returns its
        sum.
    """
    ot_loss = ot_loss or SamplesLoss(
        "sinkhorn", p=2, blur=.05, scaling=.6, reach=6.)
    bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
    ot_schedule = ot_schedule or (lambda state: 1.0)
    bce_schedule = bce_schedule or (lambda state: 1.0)

    def compute_loss(state):
        vector_masses = state['vector_masses']
        vector_coords = state['vector_coords']
        raster_masses = state['raster_masses']
        raster_coords = state['raster_coords']

        bce_coef = bce_schedule(state)
        ot_coef = ot_schedule(state)

        # A term whose coefficient is zero is skipped entirely, so its
        # inputs are never touched.
        if bce_coef != 0.:
            bce_per_sample = bce_loss(state['render'],
                                      state['raster']).mean(dim=(1, 2))
        else:
            bce_per_sample = 0.

        if ot_coef != 0.:
            ot_loss_per_sample = ot_loss(vector_masses, vector_coords,
                                         raster_masses, raster_coords)
        else:
            ot_loss_per_sample = 0.

        total_loss_per_sample = (bce_coef * bce_per_sample +
                                 ot_coef * ot_loss_per_sample)

        # With both terms disabled the total is a plain Python number and
        # ``.sum()`` used to raise AttributeError; promote it to a tensor.
        if not torch.is_tensor(total_loss_per_sample):
            total_loss_per_sample = torch.as_tensor(
                total_loss_per_sample,
                dtype=torch.float32,
                device=getattr(vector_coords, 'device', None))

        state['loss_per_sample'] = total_loss_per_sample
        return total_loss_per_sample.sum()

    return compute_loss
def __init__(self,
             models,
             eps=0.01,
             lr=1e-2,
             opt=torch.optim.Adam,
             max_iter=10,
             niter=15,
             batchsize=128,
             n_pairs=10,
             tol=1e-3,
             weight_decay=1e-5,
             order='random',
             unsymmetrize=True,
             scaling=.9):
    """Store the hyper-parameters and build the Sinkhorn loss.

    ``eps`` and ``scaling`` are only used to configure ``self.sk`` (as
    ``blur`` and ``scaling``); every other argument is stored unchanged on
    the instance under its own name. ``is_fitted`` starts as ``False``.
    """
    self.models = models

    # p=2 Sinkhorn loss; the backend choice is delegated via "auto".
    sinkhorn_options = dict(p=2, blur=eps, scaling=scaling, backend="auto")
    self.sk = SamplesLoss("sinkhorn", **sinkhorn_options)

    # Optimiser settings.
    self.lr = lr
    self.opt = opt

    # Iteration / batching settings.
    self.max_iter = max_iter
    self.niter = niter
    self.batchsize = batchsize
    self.n_pairs = n_pairs
    self.tol = tol

    self.weight_decay = weight_decay
    self.order = order
    self.unsymmetrize = unsymmetrize

    self.is_fitted = False
def __init__(self, blur=.01, scaling=.9, diameter=None, p: int = 2):
    """Create the wrapped Sinkhorn ``SamplesLoss`` with debiasing disabled.

    Args:
        blur: Forwarded as ``blur``.
        scaling: Forwarded as ``scaling``.
        diameter: Forwarded as ``diameter`` (``None`` by default).
        p: Forwarded as ``p``.
    """
    super(WasLoss, self).__init__()
    sinkhorn_options = dict(
        blur=blur,
        scaling=scaling,
        debias=False,
        diameter=diameter,
        p=p,
    )
    self.loss = SamplesLoss("sinkhorn", **sinkhorn_options)
def __init__(self, config, device):
    """Build a debiased p=2 Sinkhorn solver from ``config`` and keep ``device``.

    Args:
        config: Object providing ``sinkhorn_blur`` and ``sinkhorn_scaling``.
        device: Stored unchanged as ``self.device``.
    """
    blur = config.sinkhorn_blur
    scaling = config.sinkhorn_scaling
    self.ot_solver = SamplesLoss(
        "sinkhorn",
        p=2,
        blur=blur,
        scaling=scaling,
        debias=True,
    )
    self.device = device
def __init__(self, config):
    """Set up the Sinkhorn loss and a ``D_in -> D_out`` linear layer.

    Args:
        config: Mapping read for the keys ``'p'``, ``'blur'``, ``'scaling'``,
            ``'D_in'`` and ``'D_out'``; also kept as ``self.config``.
    """
    super(NeuralNetRougeRegModel, self).__init__()
    self.config = config

    cfg = self.config
    self.sinkhorn = SamplesLoss(
        loss='sinkhorn',
        p=cfg['p'],
        blur=cfg['blur'],
        scaling=cfg['scaling'],
    )
    self.layer = nn.Linear(cfg['D_in'], cfg['D_out'])
def __init__(self, config):
    """Set up a bias-free square linear layer and the Sinkhorn loss.

    Args:
        config: Mapping read for the keys ``'D'``, ``'p'``, ``'blur'`` and
            ``'scaling'``; also kept as ``self.config``.
    """
    super(LinSinkhornRegModel, self).__init__()
    self.config = config

    dim = self.config['D']
    self.layer = nn.Linear(dim, dim, bias=False)

    loss_options = {
        'p': self.config['p'],
        'blur': self.config['blur'],
        'scaling': self.config['scaling'],
    }
    self.sinkhorn = SamplesLoss(loss='sinkhorn', **loss_options)
def __init__(self, config):
    """Set up a randomly initialised ``D_in x D_out`` matrix and the loss.

    Args:
        config: Mapping read for the keys ``'D_in'``, ``'D_out'``, ``'p'``,
            ``'blur'`` and ``'scaling'``; also kept as ``self.config``.
    """
    super(TransformSinkhornRegModel, self).__init__()
    self.config = config

    # Learnable matrix drawn from a standard normal distribution.
    initial_matrix = torch.randn(self.config['D_in'], self.config['D_out'])
    self.M = nn.Parameter(initial_matrix)

    cfg = self.config
    self.sinkhorn = SamplesLoss(
        loss='sinkhorn',
        p=cfg['p'],
        blur=cfg['blur'],
        scaling=cfg['scaling'],
    )
def __init__(self, config):
    """Set up a square linear layer, the Sinkhorn loss and a sigmoid.

    Args:
        config: Mapping read for the keys ``'D'``, ``'p'``, ``'blur'`` and
            ``'scaling'``; also kept as ``self.config``.
    """
    super(NNSinkhornPRModel, self).__init__()
    self.config = config

    width = self.config['D']
    self.layer = nn.Linear(width, self.config['D'])

    self.sinkhorn = SamplesLoss(
        loss='sinkhorn',
        p=self.config['p'],
        blur=self.config['blur'],
        scaling=self.config['scaling'],
    )
    self.sigm = nn.Sigmoid()
def __init__(self, config):
    """Set up the ``AvgModel`` sub-model, the Sinkhorn loss and a sigmoid.

    Args:
        config: Mapping read for the keys ``'D'``, ``'p'``, ``'blur'`` and
            ``'scaling'``; also kept as ``self.config``.
    """
    super(CondLinSinkhornPRModel, self).__init__()
    self.config = config

    self.model = AvgModel(self.config['D'])

    cfg = self.config
    sinkhorn_options = dict(p=cfg['p'], blur=cfg['blur'],
                            scaling=cfg['scaling'])
    self.sinkhorn = SamplesLoss(loss='sinkhorn', **sinkhorn_options)
    self.sigm = nn.Sigmoid()
def batch_EMD_loss(x, y):
    """Average a default ``SamplesLoss()`` over the leading batch axis.

    Both inputs must be three-dimensional (their ``size()`` is unpacked into
    three values). Sample ``i`` of ``x`` is compared with sample ``i`` of
    ``y`` and the per-sample losses are averaged over the batch size of
    ``x``.

    Args:
        x: Tensor whose ``size()`` is ``(batch, points, dim)``.
        y: Tensor with a three-element ``size()``; indexed with the same
            batch indices as ``x``.

    Returns:
        Sum of the per-sample losses divided by the batch size of ``x``.
    """
    batch_size, _, _ = x.size()
    _, _, _ = y.size()  # only checks that y is three-dimensional

    criterion = SamplesLoss()
    accumulated = sum(criterion(x[i], y[i]) for i in range(batch_size))
    return accumulated / batch_size
def triplet_distance(self, anchor, positive, negative):
    """Hinge on the difference of two Sinkhorn losses plus ``self.margin``.

    Each input is converted with ``torch.Tensor(...)`` before being passed
    to the loss.

    Returns:
        ``relu(loss(anchor, positive) - loss(anchor, negative) + margin)``.
    """
    # return torch.nn.functional.relu((anchor-positive).pow(2).sum()-(anchor-negative).pow(2).sum()+self.margin)
    sinkhorn = SamplesLoss("sinkhorn", p=2, blur=0.05, scaling=.99,
                           backend="online")

    def distance_from_anchor(other):
        return sinkhorn(torch.Tensor(anchor), torch.Tensor(other))

    to_positive = distance_from_anchor(positive)
    to_negative = distance_from_anchor(negative)
    hinge_input = to_positive - to_negative + self.margin
    return torch.nn.functional.relu(hinge_input)
def triplet_distance(self, anchor, positive, negative):
    """Hinge on squared Sinkhorn losses of (anchor, positive/negative).

    Returns:
        ``relu(sum(d_ap ** 2) - sum(d_an ** 2) + self.margin)`` where
        ``d_ap`` and ``d_an`` are the Sinkhorn losses between the anchor and
        the positive / negative inputs.
    """
    sinkhorn = SamplesLoss(loss="sinkhorn", p=2, blur=0.05, scaling=.99,
                           backend="online")
    to_positive = sinkhorn(anchor, positive)
    to_negative = sinkhorn(anchor, negative)

    positive_term = to_positive.pow(2).sum()
    negative_term = to_negative.pow(2).sum()
    return torch.nn.functional.relu(
        positive_term - negative_term + self.margin)
def train_approx(args, fmodel, gmodel, device, approx_loader, f_optimizer,
                 g_optimizer, output_samples, epoch):
    """Run one epoch of alternating updates for ``fmodel`` and ``gmodel``.

    For every batch, ``softmax(fmodel(data))`` is multiplied element-wise by
    ``softplus(gmodel(data))`` to form Dirichlet concentrations.
    ``num_samples`` reparameterised draws are compared, through a Sinkhorn
    ``SamplesLoss`` using ``cost_func``, with the matching slice of
    ``output_samples`` (softmaxed on ``dim=2`` and clamped to
    ``[0.0001, 0.9999]``). ``fmodel`` is stepped first with ``gmodel``
    frozen, then ``gmodel`` is stepped with ``fmodel`` frozen. The loss of
    the first batch is printed after each of the two steps.

    ``args`` and the batch targets are not used by the computation.
    """
    gmodel.train()
    fmodel.train()
    LF = SamplesLoss(loss='sinkhorn', p=2, blur=0.05, backend='tensorized',
                     cost=cost_func)

    def reference_probs(idx):
        # Slice of the stored samples that lines up with batch ``idx``.
        start = idx * approx_loader.batch_size
        stop = (idx + 1) * approx_loader.batch_size
        with torch.no_grad():
            logits = output_samples[start:stop].to(device)
            return F.softmax(logits, dim=2).clamp(0.0001, 0.9999)

    def sinkhorn_to_reference(mean_probs, concentration, reference):
        alpha = mean_probs.mul(concentration)
        dirichlet = torch.distributions.Dirichlet(alpha)
        draws = dirichlet.rsample((num_samples,)).permute(1, 0, 2)
        return LF(reference, draws).mean()

    def report_first_batch(idx, value):
        if idx == 0:
            print('Train Epoch: {}, Loss: {:.6f}'.format(
                epoch, value.item()))

    for batch_idx, (data, target) in enumerate(approx_loader):
        data, target = data.to(device), target.to(device)

        # --- update fmodel; gmodel output carries no gradient ---
        f_optimizer.zero_grad()
        with torch.no_grad():
            # To be consistent with KL, softplus is used instead of exp(),
            # i.e. alpha0 = softplus(g). For mmd, exp() can be used directly
            # for faster convergence without tuning hyper-parameters.
            g_out = F.softplus(gmodel(data))
        reference = reference_probs(batch_idx)
        f_out = F.softmax(fmodel(data), dim=1)
        loss = sinkhorn_to_reference(f_out, g_out, reference)
        loss.backward()
        f_optimizer.step()
        report_first_batch(batch_idx, loss)

        # --- update gmodel; fmodel output carries no gradient ---
        g_optimizer.zero_grad()
        g_out = F.softplus(gmodel(data))
        reference = reference_probs(batch_idx)
        with torch.no_grad():
            f_out = F.softmax(fmodel(data), dim=1)
        loss = sinkhorn_to_reference(f_out, g_out, reference)
        loss.backward()
        g_optimizer.step()
        report_first_batch(batch_idx, loss)
def __init__(self, sigma, a):
    """Store ``sigma`` / ``a`` and create the three loss modules.

    ``self.loss`` is a verbose p=1 Sinkhorn ``SamplesLoss`` with
    ``blur=.1`` and ``scaling=0.8``; ``self.l1`` and ``self.mse`` are the
    standard L1 and MSE losses.
    """
    super(MyLoss, self).__init__()
    self.sigma, self.a = sigma, a

    # No backend is passed here (a "tensorized" override is disabled).
    sinkhorn_options = dict(p=1, blur=.1, scaling=0.8, verbose=True)
    self.loss = SamplesLoss("sinkhorn", **sinkhorn_options)

    self.l1 = nn.L1Loss()
    self.mse = nn.MSELoss()
def __init__(self, particles, rm_map, eps, max_iter, particles_type):
    """Store the evaluation inputs and build the quaternion Sinkhorn loss.

    Args:
        particles: Stored as ``self.particles``.
        rm_map: Stored as ``self.RM_map``.
        eps: Stored as ``self.eps`` and used as the loss ``blur``.
        max_iter: Accepted but not used in this constructor.
        particles_type: Stored as ``self.particles_type``.
    """
    super(SinkhornEval, self).__init__()
    self.eps = eps
    self.particles_type = particles_type
    self.particles = particles
    self.RM_map = rm_map

    loss_options = dict(
        blur=eps,
        diameter=3.15,
        cost=utils.quaternion_geodesic_distance,
        backend='tensorized',
    )
    self.loss = SamplesLoss("sinkhorn", **loss_options)
def train_approx(args, fmodel, gmodel, device, approx_loader, f_optimizer,
                 g_optimizer, output_samples, epoch):
    """Run one epoch of alternating updates for ``fmodel`` and ``gmodel``.

    For every batch, ``softmax(fmodel(data))`` is multiplied element-wise by
    ``exp(gmodel(data))`` to form Dirichlet concentrations. ``num_samples``
    reparameterised draws are compared, through a Sinkhorn ``SamplesLoss``
    using ``cost_func``, with the matching slice of ``output_samples``
    clamped to ``[0.0001, 0.9999]``. ``fmodel`` is stepped first with
    ``gmodel`` frozen, then ``gmodel`` is stepped with ``fmodel`` frozen.
    The loss of the first batch is printed after each of the two steps.

    ``args`` and the batch targets are not used by the computation.
    """
    gmodel.train()
    fmodel.train()
    LF = SamplesLoss(loss='sinkhorn', p=2, blur=0.05, backend='tensorized',
                     cost=cost_func)

    def reference_samples(idx):
        # Slice of the stored samples that lines up with batch ``idx``.
        start = idx * approx_loader.batch_size
        stop = (idx + 1) * approx_loader.batch_size
        with torch.no_grad():
            return output_samples[start:stop].to(device).clamp(
                0.0001, 0.9999)

    def sinkhorn_to_reference(mean_probs, concentration, reference):
        alpha = mean_probs.mul(concentration)
        dirichlet = torch.distributions.Dirichlet(alpha)
        draws = dirichlet.rsample((num_samples,))
        draws = draws.permute(1, 0, 2).contiguous()
        return LF(reference, draws).mean()

    def report_first_batch(idx, value):
        if idx == 0:
            print('Train Epoch: {}, Loss: {:.6f}'.format(
                epoch, value.item()))

    for batch_idx, (data, target) in enumerate(approx_loader):
        data, target = data.to(device), target.to(device)

        # --- update fmodel; gmodel output carries no gradient ---
        f_optimizer.zero_grad()
        with torch.no_grad():
            g_out = torch.exp(gmodel(data))
        reference = reference_samples(batch_idx)
        f_out = F.softmax(fmodel(data), dim=1)
        loss = sinkhorn_to_reference(f_out, g_out, reference)
        loss.backward()
        f_optimizer.step()
        report_first_batch(batch_idx, loss)

        # --- update gmodel; fmodel output carries no gradient ---
        g_optimizer.zero_grad()
        g_out = torch.exp(gmodel(data))
        reference = reference_samples(batch_idx)
        with torch.no_grad():
            f_out = F.softmax(fmodel(data), dim=1)
        loss = sinkhorn_to_reference(f_out, g_out, reference)
        loss.backward()
        g_optimizer.step()
        report_first_batch(batch_idx, loss)
def forward(self, batch, labels, gt_labels=None):
    """Margin loss over sampled triplets using Sinkhorn losses as distances.

    Args:
        batch: torch.Tensor: Input of embeddings with size (BS x DIM)
        labels: nparray/list: For each element of the batch assigns a class
            [0,...,C-1], shape: (BS x 1). A tensor is converted to a numpy
            array first.
        gt_labels: Forwarded unchanged to the sampler.

    Returns:
        Sum of the positive and negative hinge terms, divided by the count
        of active terms when that count is non-zero.
    """
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()

    # The sampler is either a plain callable or an object with ``give``.
    if callable(self.sampler):
        sampled_triplets = self.sampler(batch, labels, gt_labels)
    else:
        sampled_triplets = self.sampler.give(batch, labels, gt_labels)

    # The loss object does not depend on the triplet, so it is built once
    # instead of once per loop iteration.
    wloss = SamplesLoss(loss="sinkhorn", p=2, blur=0.05, scaling=.99,
                        backend="online")

    d_ap, d_an = [], []
    for triplet in sampled_triplets:
        anchor = batch[triplet[0], :]
        positive = batch[triplet[1], :]
        negative = batch[triplet[2]]

        d1 = wloss(anchor, positive)
        d2 = wloss(anchor, negative)

        # The small constant keeps the square root differentiable at zero.
        d_ap.append((d1.pow(2).sum() + 1e-8).pow(1 / 2))
        d_an.append((d2.pow(2).sum() + 1e-8).pow(1 / 2))
    d_ap, d_an = torch.stack(d_ap), torch.stack(d_an)

    if self.beta_constant:
        beta = self.beta
    else:
        # One beta per triplet, picked by the anchor's class label. Cast to
        # float32 on the distances' device instead of hard-coding
        # torch.cuda.FloatTensor, so CPU runs work as well.
        beta = torch.stack([
            self.beta[labels[triplet[0]]] for triplet in sampled_triplets
        ]).to(device=d_ap.device, dtype=torch.float32)

    pos_loss = torch.nn.functional.relu(d_ap - beta + self.margin)
    neg_loss = torch.nn.functional.relu(beta - d_an + self.margin)

    # NOTE(review): on torch versions where comparisons return bool tensors,
    # ``+`` acts as a logical OR here, so this counts triplets with at least
    # one active term rather than active pairs. Confirm which is intended.
    pair_count = torch.sum((pos_loss > 0.) + (neg_loss > 0.)).to(
        dtype=torch.float32)

    if pair_count == 0.:
        loss = torch.sum(pos_loss + neg_loss)
    else:
        loss = torch.sum(pos_loss + neg_loss) / pair_count

    # if self.nu: loss = loss + beta_regularisation_loss.type(torch.cuda.FloatTensor)
    return loss
def explain(graph_num):
    """Learn a softmax node mask for one graph and return a weighted adjacency.

    A per-node mask (initialised to zeros) is optimised with Adam for 50
    steps. Depending on the module-level ``loss_str``, the objective adds
    the mean distance to the "negative" graphs (``'-'``), subtracts the mean
    distance to the "positive" graphs (``'+'``) and adds the distance to the
    graph's own embeddings (``'s'``). The distance is the Sinkhorn loss when
    the module-level ``distance_str`` equals ``'ot'``, otherwise
    ``torch.dist`` between the mask-weighted sum of the graph's embeddings
    and the mean of the other graph's embeddings.

    Args:
        graph_num: Index into the module-level ``graph_embs`` and
            ``nx_graphs``.

    Returns:
        An ``(N, N)`` numpy array in which each edge ``(u, v)`` of the graph
        holds ``importance[u] + importance[v]`` symmetrically, where
        ``importance = 1 - softmax(mask)``.
    """
    own_embs = torch.Tensor(graph_embs[graph_num])
    sinkhorn = SamplesLoss("sinkhorn", p=1, blur=.01)

    positive_ids, negative_ids = closest(graph_num, graph_distance,
                                         size=similar_size)
    positive_embs = [torch.Tensor(graph_embs[idx]) for idx in positive_ids]
    negative_embs = [torch.Tensor(graph_embs[idx]) for idx in negative_ids]

    mask = torch.nn.Parameter(torch.zeros(len(own_embs)))
    optimizer = torch.optim.Adam([mask], lr=1e-1)

    if distance_str == 'ot':

        def masked_distance(logits, other_embs):
            # Mask weights on this graph vs. generated weights on the other.
            return sinkhorn(logits.softmax(0), own_embs,
                            sinkhorn.generate_weights(other_embs),
                            other_embs)
    else:

        def masked_distance(logits, other_embs):
            weights = logits.softmax(0).reshape(-1, 1)
            pooled = (own_embs * weights).sum(axis=0)
            return torch.dist(pooled, other_embs.mean(axis=0))

    history = []
    for _ in range(50):
        loss_pos = torch.stack(
            [masked_distance(mask, embs) for embs in positive_embs]).mean()
        loss_neg = torch.stack(
            [masked_distance(mask, embs) for embs in negative_embs]).mean()
        loss_self = masked_distance(mask, own_embs)

        loss = 0
        if '-' in loss_str:
            loss = loss + loss_neg
        if '+' in loss_str:
            loss = loss - loss_pos
        if 's' in loss_str:
            loss = loss + loss_self

        history.append(
            dict(loss_neg=loss_neg.item(),
                 loss_self=loss_self.item(),
                 loss_pos=loss_pos.item(),
                 loss=loss.item()))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    node_importance = list(1 - mask.softmax(0).detach().numpy().ravel())

    graph = nx_graphs[graph_num]
    node_count = graph.number_of_nodes()
    masked_adj = np.zeros((node_count, node_count))
    for src, dst in graph.edges():
        src, dst = int(src), int(dst)
        weight = node_importance[src] + node_importance[dst]
        masked_adj[src, dst] = weight
        masked_adj[dst, src] = weight
    return masked_adj
def wasserstein(a, b,
                blur: float = 0.05,  # smaller seems better, possibly slower
                scaling: float = 0.8,  # closer to 1: slower, more accurate
                splits: int = 1  # as small as possible while fitting on GPU
                ):
    """
    Compute the Wasserstein distance between two sets of features, a and b,
    of arbitrary size.

    Both inputs are cut into ``splits`` parts with ``np.split``; the p=2
    Sinkhorn loss is evaluated for every pair of parts (one from ``a``, one
    from ``b``) and the mean of those values is returned.
    """
    # separate data into portions that can be handled by the GPU
    a_parts = np.split(a, splits)
    b_parts = np.split(b, splits)

    # initiate loss function
    sinkhorn = SamplesLoss("sinkhorn", p=2, blur=blur, scaling=scaling,
                           backend="tensorized")

    # compute distance for every (part of a, part of b) combination
    pairwise = []
    for part_a in loop_cuda(a_parts):
        for part_b in loop_cuda(b_parts):
            pairwise.append(sinkhorn(part_a, part_b).item())
    return np.mean(pairwise)
def make_parameterized_aligner(device, config, crossing_model=None):
    """Assemble a ``StatefulBatchAligner`` from a configuration mapping.

    Args:
        device: Passed to the crossing-model loader, the aligner and
            ``init_ot_aligner``.
        config: Mapping. Required keys read here: ``'ot_loss'`` (keyword
            arguments for the Sinkhorn ``SamplesLoss``),
            ``'coord_only_grads'`` and ``'lr'``; ``'crossing_model_weights'``
            is read only when ``crossing_model`` is ``None``. Optional keys:
            ``'bce_loss_enabled'`` (default ``False``), ``'ot_loss_enabled'``
            (default ``True``), ``'perceptual_bce'`` (iterable of layers),
            ``'perceptual_bce_from_step'`` and ``'use_best_batch'`` (default
            ``False``).
        crossing_model: Optional pre-loaded model; loaded from the configured
            weights when omitted.

    Returns:
        The configured aligner.
    """
    if crossing_model is None:
        crossing_model = load_crossing_model(config['crossing_model_weights'],
                                             device)

    loss = LossComposition()
    ot_loss = SamplesLoss("sinkhorn", **config['ot_loss'])
    loss.add(
        make_default_loss_fn(
            bce_schedule=(lambda state: 1.0) if config.get(
                'bce_loss_enabled', False) else (lambda state: 0.0),
            ot_schedule=(lambda state: 1.0) if config.get(
                'ot_loss_enabled', True) else (lambda state: 0.0),
            ot_loss=ot_loss))

    def delay_until_configured_step(component):
        # Each call binds its own ``component``. Defining the closure
        # directly in the loop made every delayed loss call the component of
        # the LAST layer, because closures look the loop variable up at call
        # time.
        def perceptual_loss(state):
            step = state['current_step']
            if step >= config['perceptual_bce_from_step']:
                return component(state)
            return torch.scalar_tensor(0.0)

        return perceptual_loss

    if 'perceptual_bce' in config:
        for layer in config['perceptual_bce']:
            loss_component = perceptual_bce(crossing_model, layer)
            if 'perceptual_bce_from_step' not in config:
                loss.add(loss_component)
            else:
                loss.add(delay_until_configured_step(loss_component))

    # ``not_too_thin`` is only added when the OT term is disabled.
    if not config.get('ot_loss_enabled', True):
        loss.add(not_too_thin)

    grad_transformer = compose(strip_confidence_grads,
                               coords_only_grads(config['coord_only_grads']))

    aligner = StatefulBatchAligner(device=device)
    init_ot_aligner(aligner,
                    loss_fn=loss,
                    device=device,
                    optimize_fn=make_default_optimize_fn(
                        aligner,
                        lr=config['lr'],
                        transform_grads=grad_transformer,
                        base_optimizer=optim.Adam,
                    ))

    if config.get('use_best_batch', False):
        aligner.add_callback(save_best_batch)

    return aligner
def set_data(self, X, Y):
    '''Store the X / Y samples and (re)build both networks and losses.

    X and Y data, each in R^1 for now (the problem is simplified to point
    clouds). Inputs are cloned, reshaped to a single column unless already
    two-dimensional, concatenated into ``self._XY`` and cast to the
    module-level ``dtype``. Two identical networks (2 -> hidden -> hidden
    -> 1 with ReLU activations), their Adam optimisers and two
    ``SamplesLoss`` objects are then created, and ``self.data_is_set`` is
    set to True.
    '''
    assert len(X) == len(Y)

    def as_column(values):
        # Work on a copy; anything that is not 2-D becomes a single column.
        copied = values.clone()
        return copied if copied.ndim == 2 else copied.view(-1, 1)

    self._X, self._Y = as_column(X), as_column(Y)
    self._XY = torch.cat([self._X, self._Y], 1)
    self._X = self._X.type(dtype)
    self._Y = self._Y.type(dtype)
    self._XY = self._XY.type(dtype)

    def build_fcm_net(n_rows):
        hidden = self.num_hiddens
        net = torch.nn.Sequential(
            torch.nn.Linear(2, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, 1))
        # Uninitialised (n_rows, 1) buffer registered under 'noise'.
        net.register_buffer('noise', torch.Tensor(n_rows, 1))
        return net.type(dtype)

    self._fcm_net_causal = build_fcm_net(len(self._X))
    self._fcm_net_anticausal = build_fcm_net(len(self._Y))

    self.optim_causal = torch.optim.Adam(
        self._fcm_net_causal.parameters(), lr=self.lr)
    self.optim_anticausal = torch.optim.Adam(
        self._fcm_net_anticausal.parameters(), lr=self.lr)

    self.L_causal = SamplesLoss(loss=self.loss, p=self.p, blur=self.blur,
                                scaling=self.scaling)
    self.L_anticausal = SamplesLoss(loss=self.loss, p=self.p,
                                    blur=self.blur, scaling=self.scaling)

    self.data_is_set = True
def cal_diff(x, y, norm="org", criterion="qt"):
    """Average a per-sample discrepancy between ``x`` and ``y``.

    Args:
        x, y: Inputs indexed along their first axis; ``x.shape[0]`` fixes
            the number of samples.
        norm: Optional rescaling applied to both inputs first:
            ``"softmax"``, ``"logsoftmax"`` or ``"line"`` (the file's
            ``linear_normalization``). ``"Gaussian"`` and any other value
            leave the inputs unchanged.
        criterion: ``"kl"`` for a symmetrised ``nn.KLDivLoss``, ``"qt"`` for
            the file's ``qt_cal``; any other value sums the first potential
            returned by a non-debiased Sinkhorn solver.

    Returns:
        The accumulated per-sample values divided by ``x.shape[0]``.
    """
    # Rescale the original vectors.
    if norm == "softmax":
        rescale = F.softmax
    elif norm == "logsoftmax":
        rescale = F.log_softmax
    elif norm == "line":
        # Linear normalisation.
        rescale = linear_normalization
    else:
        # "Gaussian" normalisation is a placeholder and not implemented.
        rescale = None

    if rescale is not None:
        x = rescale(x)
        y = rescale(y)

    # Each batch entry is scored on its own and the scores are averaged
    # (rather than scoring the whole batch in a single call).
    dim0 = x.shape[0]
    blur = .05
    OT_solver = SamplesLoss("sinkhorn", p=2, blur=blur, scaling=.9,
                            debias=False, potentials=True)

    if criterion == "kl":
        criterion_kl = nn.KLDivLoss()

        def sample_score(a, b):
            # Symmetrised to account for the asymmetry of KL.
            return ((criterion_kl(a, b) + criterion_kl(b, a)) / 2).item()
    elif criterion == "qt":

        def sample_score(a, b):
            return qt_cal(a, b)
    else:

        def sample_score(a, b):
            F_i, G_j = OT_solver(a, b)
            return torch.sum(F_i).item()

    result = 0.0
    for i in range(dim0):
        result += sample_score(x[i], y[i])
    return result / dim0
def __init__(self, weights: torch.Tensor, blur=.01, scaling=.9,
             diameter=None, p: int = 2):
    """Keep ``weights`` and create a Sinkhorn loss with debiasing disabled.

    Args:
        weights: Stored unchanged as ``self.weights``.
        blur: Forwarded as ``blur``.
        scaling: Forwarded as ``scaling``.
        diameter: Forwarded as ``diameter`` (``None`` by default).
        p: Forwarded as ``p``.
    """
    super(WeightedSamplesLoss, self).__init__()
    self.weights = weights

    sinkhorn_options = dict(
        blur=blur,
        scaling=scaling,
        debias=False,
        diameter=diameter,
        p=p,
    )
    self.loss = SamplesLoss("sinkhorn", **sinkhorn_options)
def reset_net(self):
    """Recreate the network, its loss and its optimiser from scratch.

    The network maps 2 inputs through two hidden layers of
    ``self.num_hiddens`` units (ReLU activations) to 1 output, carries a
    ``'noise'`` buffer with one row per entry of ``self._X`` and is cast to
    the module-level ``dtype``.
    """
    hidden = self.num_hiddens
    layers = [
        torch.nn.Linear(2, hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden, hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden, 1),
    ]
    self._net = torch.nn.Sequential(*layers)

    # Uninitialised (len(X), 1) buffer registered under 'noise'.
    noise_buffer = torch.Tensor(len(self._X), 1)
    self._net.register_buffer('noise', noise_buffer)
    self._net = self._net.type(dtype)

    self._L = SamplesLoss(loss=self.loss, p=self.p, blur=self.blur,
                          scaling=self.scaling)
    self._opt = torch.optim.Adam(self._net.parameters(),
                                 lr=self.lr,
                                 weight_decay=self.weight_decay)
def display_scaling(scaling=.5, Nits=9, debias=True):
    """Plot the Sinkhorn potentials and gradient field for a range of blurs.

    One subplot is drawn per iteration ``i`` in ``range(Nits)``, three per
    row, using ``blur = scaling ** i``. The module-level measures ``A_i`` /
    ``X_i`` and ``B_j`` / ``Y_j`` are cloned for every subplot, so the
    originals are left untouched.

    Args:
        scaling: Base of the blur sequence, also passed as ``scaling``.
        Nits: Number of subplots.
        debias: Forwarded to ``SamplesLoss``.
    """
    n_rows = (Nits - 1) // 3 + 1
    plt.figure(figsize=((12, n_rows * 4)))

    for i in range(Nits):
        blur = scaling**i
        Loss = SamplesLoss("sinkhorn", p=2, blur=blur, diameter=1.,
                           scaling=scaling, debias=debias)

        # Work on copies of the data...
        a_i, x_i = A_i.clone(), X_i.clone()
        b_j, y_j = B_j.clone(), Y_j.clone()

        # ...and track gradients for the weights and the source positions:
        for tensor in (a_i, x_i, b_j):
            tensor.requires_grad = True

        # Loss value and its gradients:
        Loss_xy = Loss(a_i, x_i, b_j, y_j)
        F_i, G_j, dx_i = grad(Loss_xy, [a_i, b_j, x_i])

        # The generalized "Brenier map" is (minus) the gradient of the
        # Sinkhorn loss with respect to the Wasserstein metric:
        BrenierMap = -dx_i / (a_i.view(-1, 1) + 1e-7)

        # Display ------------------------------------------------------------
        ax = plt.subplot(n_rows, 3, i + 1)
        ax.scatter([10], [10])  # hack to prevent a slight change of axis

        display_potential(ax, G_j, "#E2C5C5")
        display_potential(ax, F_i, "#C8DFF9")

        display_samples(ax, y_j, b_j, [(.55, .55, .95)])
        display_samples(ax, x_i, a_i, [(.95, .55, .55)], v=BrenierMap)

        ax.set_title("iteration {}, blur = {:.3f}".format(i + 1, blur))

        ax.set_xticks([0, 1])
        ax.set_yticks([0, 1])
        ax.axis([0, 1, 0, 1])
        ax.set_aspect('equal', adjustable='box')

    plt.tight_layout()
def WL(X, Y, L):
    """Return the ``L``-th root of a debiased Sinkhorn loss between X and Y.

    Both point sets receive the same uniform weights ``1 / n`` with
    ``n = X.shape[0]``. All arrays are converted to float32 CUDA tensors
    before the loss is evaluated.

    Args:
        X, Y: Numpy arrays of points.
        L: Exponent passed as ``p`` to the loss and used for the final root.
    """
    # A disabled alternative based on ot.dist / ot.emd2 existed here; only
    # the Sinkhorn path is active.
    sinkhorn = SamplesLoss("sinkhorn", blur=0.01, scaling=0.99, debias=True,
                           p=L, backend="tensorized")

    n = X.shape[0]
    uniform = np.ones(n) / n

    def to_gpu(array):
        return torch.from_numpy(array).type(torch.float32).cuda()

    value = sinkhorn.forward(to_gpu(uniform), to_gpu(X),
                             to_gpu(uniform), to_gpu(Y)).item()
    return value ** (1 / L)
def __init__(self, model_type='wav2vec',
             PRETRAINED_MODEL_PATH='/path/to/wav2vec_large.pt'):
    """Load the wav2vec feature extractor used by this loss.

    Args:
        model_type: Only ``'wav2vec'`` is supported; any other value prints
            a message and exits the process via ``sys.exit()``.
        PRETRAINED_MODEL_PATH: Checkpoint path passed to ``torch.load``; the
            checkpoint is read for its ``'args'`` and ``'model'`` entries.
    """
    super().__init__()
    self.model_type = model_type
    self.wass_dist = SamplesLoss()

    if model_type != 'wav2vec':
        print('Please assign a loss model')
        sys.exit()

    checkpoint = torch.load(PRETRAINED_MODEL_PATH)
    self.model = Wav2VecModel.build_model(checkpoint['args'], task=None)
    self.model.load_state_dict(checkpoint['model'])
    # Only the feature extractor is kept, switched to eval mode.
    self.model = self.model.feature_extractor
    self.model.eval()
def compute_maps(objFileName, bodyFileName):
    """Compute displacement, position and validity maps for a body / mesh pair.

    Body points are obtained from barycentric coordinates stored in
    ``dat.pkl`` (keys ``'vertices'`` and ``'weights'``) and the target mesh
    is sampled at ``256 * 256`` points. ``gradient_flow`` is run twice with
    a p=2 Sinkhorn loss: once with ``do_a=True`` on all body points, and
    once with ``do_a=False`` on the points whose returned ``a`` value
    exceeds 0.3.

    Args:
        objFileName: Path of the mesh that is sampled as the target.
        bodyFileName: Path of the body mesh supplying the source vertices.

    Returns:
        ``(disp_map, dp_map, legal_map)``: the displacement of the retained
        points (zeros elsewhere) reshaped to ``(256, 256, 3)``, the original
        body points reshaped to ``(256, 256, 3)``, and the boolean retention
        mask reshaped to ``(256, 256)``.
    """
    body_pts = trimesh.load(bodyFileName, process=False)
    mesh = trimesh.load(objFileName, process=False)

    # Use a context manager so the file handle is closed deterministically;
    # it used to be left open.
    # NOTE(review): unpickling executes arbitrary code, so 'dat.pkl' must
    # come from a trusted source.
    with open('dat.pkl', 'rb') as pkl_file:
        p = pkl.load(pkl_file)

    size = 256
    body_pts = trimesh.triangles.barycentric_to_points(
        body_pts.vertices[p['vertices']], p['weights'])

    X_i = torch.tensor(body_pts).type(dtype)
    Y_j = torch.tensor(mesh.sample(size * size)).type(dtype)

    # First pass over every body point.
    x_i, a = gradient_flow(SamplesLoss("sinkhorn", p=2, blur=.01),
                           X_i, Y_j, lr=0.5, do_a=True)

    # Second pass restricted to the points above the 0.3 threshold.
    mask = a > 0.3
    x_i_new = x_i[mask].detach()
    x_i_final, _ = gradient_flow(SamplesLoss("sinkhorn", p=2, blur=.01),
                                 x_i_new, Y_j, lr=0.5, do_a=False)

    # Scatter the second-pass displacement back; unselected points stay 0.
    x_i_back = torch.zeros_like(x_i).type(dtype)
    x_i_back[mask] = x_i_final - x_i_new

    legal_map = mask.reshape([size, size])
    dp_map = X_i.reshape([size, size, 3])
    disp_map = x_i_back.reshape([size, size, 3])

    return disp_map, dp_map, legal_map
def emd_loss(self, P, activations):
    """Sinkhorn loss between ``P`` and an affinity matrix of ``activations``.

    Pairwise squared distances between the rows of ``activations`` are
    divided by ``self.dof``, mapped through
    ``(1 + d) ** (-(self.dof + 1) / 2)``, zeroed on the diagonal and
    normalised to sum to one. The resulting matrix ``Q`` is compared with
    ``P`` using a p=2 Sinkhorn ``SamplesLoss`` with ``blur=.05``.

    Args:
        P: First argument of the loss.
        activations: Tensor whose first axis indexes the samples.

    Returns:
        The value returned by the loss for ``(P, Q)``.
    """
    n = activations.size(0)

    # ||a_i||^2 + ||a_j||^2 - 2 <a_i, a_j>
    squared_norms = activations.pow(2).sum(1)
    gram = torch.matmul(activations, activations.transpose(0, 1))
    Q = squared_norms + squared_norms.view([-1, 1]) - 2 * gram

    Q = Q / self.dof
    Q = torch.pow(1 + Q, -(self.dof + 1) / 2)

    # Zero out the diagonal.
    off_diagonal = torch.from_numpy(np.ones((n, n)) - np.eye(n)).to(
        self.device)
    Q = Q * off_diagonal
    Q = Q / torch.sum(Q)

    sinkhorn = SamplesLoss(loss="sinkhorn", p=2, blur=.05)
    return sinkhorn(P, Q)