Example #1
0
def get_loss(kernel, eps):
    """Build the Sinkhorn ``SamplesLoss`` matching ``kernel.kernel_type``.

    Args:
        kernel: Object exposing ``kernel_type``. The two ``*power_quaternion``
            variants additionally read ``kernel.power``.
        eps: Forwarded as the ``blur`` argument of ``SamplesLoss``.

    Returns:
        A tensorized Sinkhorn ``SamplesLoss`` for 'quaternion',
        'squared_euclidean', 'power_quaternion' and 'sum_power_quaternion';
        ``None`` for any other kernel type.
    """
    kind = kernel.kernel_type
    shared = dict(blur=eps, backend='tensorized')

    if kind == 'squared_euclidean':
        # Only variant that relies on ``p=2`` instead of an explicit cost.
        return SamplesLoss("sinkhorn", p=2, diameter=4., **shared)

    if kind == 'quaternion':
        return SamplesLoss("sinkhorn",
                           diameter=3.15,
                           cost=utils.quaternion_geodesic_distance,
                           **shared)

    # The remaining variants bind ``kernel.power`` as the first argument of
    # their distance function and share every other setting.
    if kind == 'power_quaternion':
        cost_fn = partial(utils.power_quaternion_geodesic_distance,
                          kernel.power)
    elif kind == 'sum_power_quaternion':
        cost_fn = partial(utils.sum_power_quaternion_geodesic_distance,
                          kernel.power)
    else:
        return None

    return SamplesLoss("sinkhorn", diameter=10., cost=cost_fn, **shared)
Example #2
0
def make_default_loss_fn(ot_schedule=None, bce_schedule=None, ot_loss=None):
    """Create a loss closure mixing a per-sample BCE term and an OT term.

    Args:
        ot_schedule: Callable ``state -> coefficient`` for the OT term.
            Defaults to a constant 1.0.
        bce_schedule: Callable ``state -> coefficient`` for the BCE term.
            Defaults to a constant 1.0.
        ot_loss: Callable invoked as ``ot_loss(vector_masses, vector_coords,
            raster_masses, raster_coords)``. Defaults to a Sinkhorn
            ``SamplesLoss`` (p=2, blur=.05, scaling=.6, reach=6.).

    Returns:
        ``compute_loss(state)``. It reads 'vector_masses', 'vector_coords',
        'raster_masses' and 'raster_coords' from ``state`` (plus 'render' and
        'raster' when the BCE coefficient is non-zero), stores the weighted
        per-sample total under ``state['loss_per_sample']`` and returns the
        sum of that total as a tensor.
    """
    ot_loss = ot_loss or SamplesLoss(
        "sinkhorn", p=2, blur=.05, scaling=.6, reach=6.)
    bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')

    ot_schedule = ot_schedule or (lambda state: 1.0)
    bce_schedule = bce_schedule or (lambda state: 1.0)

    def compute_loss(state):
        vector_masses = state['vector_masses']
        vector_coords = state['vector_coords']
        raster_masses = state['raster_masses']
        raster_coords = state['raster_coords']

        bce_coef = bce_schedule(state)
        ot_coef = ot_schedule(state)

        # A term whose coefficient is zero is skipped entirely (not merely
        # multiplied by zero) so its potentially expensive forward pass is
        # never run.
        bce_per_sample = bce_loss(state['render'], state['raster']).mean(
            dim=(1, 2)) if bce_coef != 0. else 0.
        ot_loss_per_sample = ot_loss(vector_masses, vector_coords,
                                     raster_masses,
                                     raster_coords) if ot_coef != 0. else 0.

        total_loss_per_sample = bce_coef * bce_per_sample + ot_coef * ot_loss_per_sample

        if not torch.is_tensor(total_loss_per_sample):
            # Both terms were skipped, leaving a plain Python number; wrap it
            # so ``.sum()`` below (and callers expecting a tensor) still work.
            total_loss_per_sample = torch.as_tensor(
                total_loss_per_sample, dtype=torch.get_default_dtype())

        state['loss_per_sample'] = total_loss_per_sample

        return total_loss_per_sample.sum()

    return compute_loss
Example #3
0
    def __init__(self, models, eps=0.01, lr=1e-2, opt=torch.optim.Adam,
                 max_iter=10, niter=15, batchsize=128, n_pairs=10, tol=1e-3,
                 weight_decay=1e-5, order='random', unsymmetrize=True,
                 scaling=.9):
        """Store the hyper-parameters and build the Sinkhorn loss.

        ``eps`` and ``scaling`` are consumed only here, as the ``blur`` and
        ``scaling`` arguments of the ``SamplesLoss`` kept in ``self.sk``.
        Every other argument is stored unchanged under an attribute of the
        same name. The instance starts with ``is_fitted`` set to False.
        """
        self.models = models
        self.sk = SamplesLoss("sinkhorn", p=2, blur=eps, scaling=scaling,
                              backend="auto")

        # Optimiser configuration.
        self.lr, self.opt = lr, opt

        # Iteration / batching configuration.
        self.max_iter, self.niter = max_iter, niter
        self.batchsize, self.n_pairs = batchsize, n_pairs

        # Remaining options, stored verbatim.
        self.tol, self.weight_decay = tol, weight_decay
        self.order, self.unsymmetrize = order, unsymmetrize

        self.is_fitted = False
Example #4
0
 def __init__(self, blur=.01, scaling=.9, diameter=None, p: int = 2):
     """Create the wrapped Sinkhorn ``SamplesLoss`` with debiasing disabled.

     All four arguments are forwarded under the same names to ``SamplesLoss``.
     """
     super(WasLoss, self).__init__()
     sinkhorn_options = dict(blur=blur,
                             scaling=scaling,
                             debias=False,
                             diameter=diameter,
                             p=p)
     self.loss = SamplesLoss("sinkhorn", **sinkhorn_options)
Example #5
0
    def __init__(self, config, device):
        """Build the debiased Sinkhorn solver from ``config`` and keep ``device``.

        Reads ``config.sinkhorn_blur`` and ``config.sinkhorn_scaling``.
        """
        blur = config.sinkhorn_blur
        scaling = config.sinkhorn_scaling
        self.ot_solver = SamplesLoss("sinkhorn", p=2, blur=blur,
                                     scaling=scaling, debias=True)
        self.device = device
Example #6
0
 def __init__(self, config):
     """Keep ``config`` and create the Sinkhorn loss followed by a linear layer.

     Reads the keys 'p', 'blur', 'scaling', 'D_in' and 'D_out'.
     """
     super(NeuralNetRougeRegModel, self).__init__()
     self.config = config
     cfg = self.config
     sinkhorn_kwargs = {key: cfg[key] for key in ('p', 'blur', 'scaling')}
     self.sinkhorn = SamplesLoss(loss='sinkhorn', **sinkhorn_kwargs)
     self.layer = nn.Linear(cfg['D_in'], cfg['D_out'])
 def __init__(self, config):
     """Keep ``config`` and create a bias-free square linear layer, then the Sinkhorn loss.

     Reads the keys 'D', 'p', 'blur' and 'scaling'.
     """
     super(LinSinkhornRegModel, self).__init__()
     self.config = config
     dim = self.config['D']
     self.layer = nn.Linear(dim, dim, bias=False)
     p, blur, scaling = (self.config[key] for key in ('p', 'blur', 'scaling'))
     self.sinkhorn = SamplesLoss(loss='sinkhorn', p=p, blur=blur,
                                 scaling=scaling)
Example #8
0
 def __init__(self, config):
     """Keep ``config``, create the learnable matrix ``M`` and the Sinkhorn loss.

     ``M`` is a ``D_in x D_out`` parameter initialised with ``torch.randn``.
     Also reads the keys 'p', 'blur' and 'scaling'.
     """
     super(TransformSinkhornRegModel, self).__init__()
     self.config = config
     cfg = self.config
     initial_matrix = torch.randn(cfg['D_in'], cfg['D_out'])
     self.M = nn.Parameter(initial_matrix)
     self.sinkhorn = SamplesLoss(loss='sinkhorn',
                                 **{key: cfg[key]
                                    for key in ('p', 'blur', 'scaling')})
 def __init__(self, config):
     """Keep ``config`` and create a square linear layer, the Sinkhorn loss and a sigmoid.

     Reads the keys 'D', 'p', 'blur' and 'scaling'.
     """
     super(NNSinkhornPRModel, self).__init__()
     self.config = config
     width = self.config['D']
     self.layer = nn.Linear(width, width)
     sinkhorn_settings = dict(p=self.config['p'],
                              blur=self.config['blur'],
                              scaling=self.config['scaling'])
     self.sinkhorn = SamplesLoss(loss='sinkhorn', **sinkhorn_settings)
     self.sigm = nn.Sigmoid()
 def __init__(self, config):
     """Keep ``config`` and create an ``AvgModel``, the Sinkhorn loss and a sigmoid.

     Reads the keys 'D', 'p', 'blur' and 'scaling'.
     """
     super(CondLinSinkhornPRModel, self).__init__()
     self.config = config
     cfg = self.config
     self.model = AvgModel(cfg['D'])
     p, blur, scaling = cfg['p'], cfg['blur'], cfg['scaling']
     self.sinkhorn = SamplesLoss(loss='sinkhorn', p=p, blur=blur,
                                 scaling=scaling)
     self.sigm = nn.Sigmoid()
Example #11
0
def batch_EMD_loss(x, y):
    """Average a default ``SamplesLoss()`` over the leading axis of ``x`` and ``y``.

    Both inputs must be 3-dimensional, since their ``size()`` is unpacked into
    exactly three values. Entry ``i`` of ``x`` is compared with entry ``i`` of
    ``y``, and the number of entries is taken from ``x``.
    """
    batch_size, _, _ = x.size()
    _, _, _ = y.size()  # same three-way unpacking as for x
    criterion = SamplesLoss()
    total = sum(criterion(x[idx], y[idx]) for idx in range(batch_size))
    return total / batch_size
Example #12
0
 def triplet_distance(self, anchor, positive, negative):
     """Hinge on the gap between two Sinkhorn losses.

     Returns ``relu(S(anchor, positive) - S(anchor, negative) + self.margin)``
     where ``S`` is a Sinkhorn ``SamplesLoss`` and every input is first passed
     through ``torch.Tensor(...)``.
     """
     sinkhorn = SamplesLoss("sinkhorn", p=2, blur=0.05, scaling=.99,
                            backend="online")
     to_tensor = torch.Tensor
     positive_term = sinkhorn(to_tensor(anchor), to_tensor(positive))
     negative_term = sinkhorn(to_tensor(anchor), to_tensor(negative))
     gap = positive_term - negative_term
     return torch.nn.functional.relu(gap + self.margin)
Example #13
0
 def triplet_distance(self, anchor, positive, negative):
     """Hinge on the gap between two squared Sinkhorn losses.

     Each Sinkhorn ``SamplesLoss`` value is squared and summed before the
     anchor/negative term is subtracted from the anchor/positive term;
     ``self.margin`` is then added and the result is clamped at zero.
     """
     sinkhorn = SamplesLoss(loss="sinkhorn", p=2, blur=0.05, scaling=.99,
                            backend="online")
     positive_term = sinkhorn(anchor, positive).pow(2).sum()
     negative_term = sinkhorn(anchor, negative).pow(2).sum()
     return torch.nn.functional.relu(positive_term - negative_term +
                                     self.margin)
Example #14
0
def train_approx(args, fmodel, gmodel, device, approx_loader, f_optimizer, g_optimizer, output_samples, epoch):
    """Run one epoch of alternating updates of ``fmodel`` and ``gmodel``.

    For every batch ``fmodel`` is stepped first (with ``gmodel``'s output
    computed under ``no_grad``), then ``gmodel`` is stepped (with ``fmodel``'s
    output computed under ``no_grad``). Both steps minimise the mean Sinkhorn
    ``SamplesLoss`` between the softmaxed, clamped slice of ``output_samples``
    for the batch and reparameterised Dirichlet samples whose concentration is
    ``softmax(fmodel(x)) * softplus(gmodel(x))``.

    ``args`` is not used here. ``cost_func`` and ``num_samples`` are looked up
    outside this function. The loss of the first batch is printed after each
    of the two steps.
    """
    gmodel.train()
    fmodel.train()
    sinkhorn = SamplesLoss(loss='sinkhorn', p=2, blur=0.05, backend='tensorized', cost=cost_func)

    def reference_batch(index):
        # Slice of the stored samples belonging to batch ``index``.
        size = approx_loader.batch_size
        chunk = output_samples[index * size:(index + 1) * size].to(device)
        return F.softmax(chunk, dim=2).clamp(0.0001, 0.9999)

    def dirichlet_sinkhorn(reference, concentration):
        # rsample keeps the draw differentiable w.r.t. the concentration.
        drawn = torch.distributions.Dirichlet(concentration).rsample((num_samples,))
        return sinkhorn(reference, drawn.permute(1,0,2)).mean()

    def report(value):
        print('Train Epoch: {}, Loss: {:.6f}'.format(
            epoch, value.item()))

    for batch_idx, (data, target) in enumerate(approx_loader):
        data, target = data.to(device), target.to(device)

        # ---- fmodel step: gmodel output and reference are constants ----
        f_optimizer.zero_grad()
        with torch.no_grad():
            # Original author's note: to be consistent with KL, exp() was
            # replaced by softplus, i.e. alpha0 = softplus(g). For mmd, exp()
            # can be used directly for faster convergence without tuning
            # hyper-parameters.
            g_out = F.softplus(gmodel(data))
            output = reference_batch(batch_idx)
        f_out = F.softmax(fmodel(data), dim=1)

        loss = dirichlet_sinkhorn(output, f_out.mul(g_out))
        loss.backward()
        f_optimizer.step()
        if batch_idx == 0:
            report(loss)

        # ---- gmodel step: fmodel output and reference are constants ----
        g_optimizer.zero_grad()
        g_out = F.softplus(gmodel(data))
        with torch.no_grad():
            output = reference_batch(batch_idx)
            f_out = F.softmax(fmodel(data), dim=1)

        loss = dirichlet_sinkhorn(output, f_out.mul(g_out))
        loss.backward()
        g_optimizer.step()
        if batch_idx == 0:
            report(loss)
Example #15
0
 def __init__(self, sigma, a):
     """Store ``sigma`` and ``a`` and create the three criteria used by the loss.

     ``self.loss`` is a verbose Sinkhorn ``SamplesLoss`` (p=1, blur=.1,
     scaling=0.8); ``self.l1`` and ``self.mse`` are ``nn.L1Loss`` and
     ``nn.MSELoss``.
     """
     super(MyLoss, self).__init__()
     self.sigma, self.a = sigma, a
     sinkhorn_settings = dict(p=1, blur=.1, scaling=0.8, verbose=True)
     self.loss = SamplesLoss("sinkhorn", **sinkhorn_settings)
     self.l1, self.mse = nn.L1Loss(), nn.MSELoss()
Example #16
0
 def __init__(self, particles, rm_map, eps, max_iter, particles_type):
     """Store the evaluation inputs and build the quaternion Sinkhorn loss.

     ``eps`` is kept and also used as the ``blur`` of the ``SamplesLoss``.
     ``max_iter`` is accepted but not used by this constructor.
     """
     super(SinkhornEval, self).__init__()
     self.eps, self.particles_type = eps, particles_type
     self.particles, self.RM_map = particles, rm_map
     self.loss = SamplesLoss(
         "sinkhorn",
         blur=eps,
         diameter=3.15,
         cost=utils.quaternion_geodesic_distance,
         backend='tensorized',
     )
Example #17
0
def train_approx(args, fmodel, gmodel, device, approx_loader, f_optimizer, g_optimizer, output_samples, epoch):
    """Run one epoch of alternating updates of ``fmodel`` and ``gmodel``.

    For every batch ``fmodel`` is stepped first (with ``gmodel``'s output
    computed under ``no_grad``), then ``gmodel`` is stepped (with ``fmodel``'s
    output computed under ``no_grad``). Both steps minimise the mean Sinkhorn
    ``SamplesLoss`` between the clamped slice of ``output_samples`` for the
    batch and reparameterised Dirichlet samples whose concentration is
    ``softmax(fmodel(x)) * exp(gmodel(x))``.

    ``args`` is not used here. ``cost_func`` and ``num_samples`` are looked up
    outside this function. The loss of the first batch is printed after each
    of the two steps.
    """
    gmodel.train()
    fmodel.train()
    sinkhorn = SamplesLoss(loss='sinkhorn', p=2, blur=0.05, backend='tensorized', cost=cost_func)

    def reference_batch(index):
        # Slice of the stored samples belonging to batch ``index``.
        size = approx_loader.batch_size
        return output_samples[index * size:(index + 1) * size].to(
            device).clamp(0.0001, 0.9999)

    def dirichlet_sinkhorn(reference, concentration):
        # rsample keeps the draw differentiable w.r.t. the concentration.
        drawn = torch.distributions.Dirichlet(concentration).rsample((num_samples,))
        return sinkhorn(reference, drawn.permute(1, 0, 2).contiguous()).mean()

    def report(value):
        print('Train Epoch: {}, Loss: {:.6f}'.format(
            epoch, value.item()))

    for batch_idx, (data, target) in enumerate(approx_loader):
        data, target = data.to(device), target.to(device)

        # ---- fmodel step: gmodel output and reference are constants ----
        f_optimizer.zero_grad()
        with torch.no_grad():
            g_out = torch.exp(gmodel(data))
            output = reference_batch(batch_idx)
        f_out = F.softmax(fmodel(data), dim=1)

        loss = dirichlet_sinkhorn(output, f_out.mul(g_out))
        loss.backward()
        f_optimizer.step()
        if batch_idx == 0:
            report(loss)

        # ---- gmodel step: fmodel output and reference are constants ----
        g_optimizer.zero_grad()
        g_out = torch.exp(gmodel(data))
        with torch.no_grad():
            output = reference_batch(batch_idx)
            f_out = F.softmax(fmodel(data), dim=1)

        loss = dirichlet_sinkhorn(output, f_out.mul(g_out))
        loss.backward()
        g_optimizer.step()
        if batch_idx == 0:
            report(loss)
Example #18
0
    def forward(self, batch, labels, gt_labels=None):
        """
        Margin loss over sampled triplets, using Sinkhorn losses as distances.

        Args:
            batch:   torch.Tensor: Input of embeddings with size (BS x DIM)
            labels: nparray/list: For each element of the batch assigns a class [0,...,C-1], shape: (BS x 1)
            gt_labels: Passed through unchanged to the triplet sampler.

        Returns:
            Scalar tensor: the summed positive and negative hinge terms,
            divided by ``pair_count`` unless that count is zero.
        """
        if isinstance(labels, torch.Tensor):
            labels = labels.detach().cpu().numpy()

        # The sampler is either a plain callable or an object with ``give``.
        if callable(self.sampler):
            sampled_triplets = self.sampler(batch, labels, gt_labels)
        else:
            sampled_triplets = self.sampler.give(batch, labels, gt_labels)

        # Loop-invariant: build the Sinkhorn loss once rather than per triplet.
        wloss = SamplesLoss(loss="sinkhorn",
                            p=2,
                            blur=0.05,
                            scaling=.99,
                            backend="online")

        d_ap, d_an = [], []
        for triplet in sampled_triplets:
            anchor = batch[triplet[0], :]
            positive = batch[triplet[1], :]
            negative = batch[triplet[2]]
            d1 = wloss(anchor, positive)
            d2 = wloss(anchor, negative)
            # The small constant keeps the square root differentiable at zero.
            d_ap.append((d1.pow(2).sum() + 1e-8).pow(1 / 2))
            d_an.append((d2.pow(2).sum() + 1e-8).pow(1 / 2))
        d_ap, d_an = torch.stack(d_ap), torch.stack(d_an)

        if self.beta_constant:
            beta = self.beta
        else:
            # Per-anchor-class beta. Follow the device of the distances
            # instead of forcing ``torch.cuda.FloatTensor``, which failed on
            # CPU and on GPUs other than the current default one.
            beta = torch.stack([
                self.beta[labels[triplet[0]]] for triplet in sampled_triplets
            ]).to(device=d_ap.device, dtype=torch.float32)

        pos_loss = torch.nn.functional.relu(d_ap - beta + self.margin)
        neg_loss = torch.nn.functional.relu(beta - d_an + self.margin)

        # Already on the device of the losses; only the float cast is needed.
        pair_count = torch.sum((pos_loss > 0.) + (neg_loss > 0.)).to(
            torch.float32)

        if pair_count == 0.:
            loss = torch.sum(pos_loss + neg_loss)
        else:
            loss = torch.sum(pos_loss + neg_loss) / pair_count
        return loss
Example #19
0
    def explain(graph_num):
        """Optimise a soft node mask for one graph and return edge importances.

        A mask with one entry per embedding row of graph ``graph_num`` is
        trained for 50 Adam steps (lr 1e-1). Depending on which of '-', '+'
        and 's' occur in ``loss_str``, the objective adds the mean distance to
        the negative graphs, subtracts the mean distance to the positive
        graphs, and adds the distance to the graph itself. ``distance_str``
        selects between the Sinkhorn distance ('ot') and the distance between
        the mask-weighted sum and the mean embedding (anything else).

        Returns:
            An ``N x N`` numpy array (N = number of nodes of the graph) whose
            entry for each edge ``(u, v)``, set symmetrically, is
            ``importance[u] + importance[v]`` with
            ``importance = 1 - softmax(mask)``.
        """
        cur_embs = torch.Tensor(graph_embs[graph_num])
        distance = SamplesLoss("sinkhorn", p=1, blur=.01)

        positive_ids, negative_ids = closest(graph_num, graph_distance, size=similar_size)
        positive_embs = [torch.Tensor(graph_embs[i]) for i in positive_ids]
        negative_embs = [torch.Tensor(graph_embs[i]) for i in negative_ids]

        mask = torch.nn.Parameter(torch.zeros(len(cur_embs)))
        optimizer = torch.optim.Adam([mask], lr=1e-1)

        def ot_dist(mask, embs):
            # Masked graph as a weighted point cloud vs. default weights.
            return distance(mask.softmax(0), cur_embs,
                            distance.generate_weights(embs), embs)

        def pooled_dist(mask, embs):
            pooled = (cur_embs * mask.softmax(0).reshape(-1, 1)).sum(axis=0)
            return torch.dist(pooled, embs.mean(axis=0))

        mydist = ot_dist if distance_str == 'ot' else pooled_dist

        history = []
        for _ in range(50):
            loss_pos = torch.stack([mydist(mask, emb) for emb in positive_embs]).mean()
            loss_neg = torch.stack([mydist(mask, emb) for emb in negative_embs]).mean()
            loss_self = mydist(mask, cur_embs)

            loss = 0
            if '-' in loss_str:
                loss = loss + loss_neg
            if '+' in loss_str:
                loss = loss - loss_pos
            if 's' in loss_str:
                loss = loss + loss_self

            history.append(dict(loss_neg=loss_neg.item(),
                                loss_self=loss_self.item(),
                                loss_pos=loss_pos.item(),
                                loss=loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        node_importance = list(1 - mask.softmax(0).detach().numpy().ravel())

        graph = nx_graphs[graph_num]
        n_nodes = graph.number_of_nodes()
        masked_adj = np.zeros((n_nodes, n_nodes))
        for u, v in graph.edges():
            u, v = int(u), int(v)
            weight = node_importance[u] + node_importance[v]
            masked_adj[u, v] = weight
            masked_adj[v, u] = weight
        return masked_adj
Example #20
0
def wasserstein(a, b,
                blur: float = 0.05,
                scaling: float = 0.8,
                splits: int = 1
                ):
    """Mean Sinkhorn loss over every pair of chunks of ``a`` and ``b``.

    Args:
        a, b: Feature sets; each is cut into ``splits`` equal parts with
            ``np.split``.
        blur: ``blur`` of the ``SamplesLoss`` (original note: smaller seems
            better, and possibly slower).
        scaling: ``scaling`` of the ``SamplesLoss`` (original note: the closer
            to 1, the slower but more accurate).
        splits: Number of chunks per input (original note: set as small as
            possible so that the data fits on the GPU).

    Returns:
        ``np.mean`` of the loss values over all (chunk of a, chunk of b) pairs.
    """
    a_chunks = np.split(a, splits)
    b_chunks = np.split(b, splits)

    sinkhorn = SamplesLoss("sinkhorn", p=2, blur=blur, scaling=scaling, backend="tensorized")

    pair_values = []
    for x in loop_cuda(a_chunks):
        for y in loop_cuda(b_chunks):
            pair_values.append(sinkhorn(x, y).item())
    return np.mean(pair_values)
Example #21
0
def make_parameterized_aligner(device, config, crossing_model=None):
    """Assemble a ``StatefulBatchAligner`` whose loss is driven by ``config``.

    Args:
        device: Passed to the crossing-model loader, the aligner and
            ``init_ot_aligner``.
        config: Mapping. Keys read here: 'crossing_model_weights' (only when
            ``crossing_model`` is None), 'ot_loss' (kwargs for the Sinkhorn
            ``SamplesLoss``), 'lr', 'coord_only_grads', and the optional
            'bce_loss_enabled' (default False), 'ot_loss_enabled' (default
            True), 'perceptual_bce', 'perceptual_bce_from_step' and
            'use_best_batch' (default False).
        crossing_model: Optional pre-loaded model; loaded from ``config``
            when omitted.

    Returns:
        The configured aligner.
    """
    if crossing_model is None:
        crossing_model = load_crossing_model(config['crossing_model_weights'],
                                             device)

    def _delay_until_start_step(component):
        # Factory that gives each closure its own ``component``. Defining the
        # closure directly in the loop below made every one of them call the
        # loss of the *last* layer (late binding of the loop variable).
        def perceptual_loss(state):
            step = state['current_step']
            if step >= config['perceptual_bce_from_step']:
                return component(state)
            else:
                return torch.scalar_tensor(0.0)

        return perceptual_loss

    loss = LossComposition()
    ot_loss = SamplesLoss("sinkhorn", **config['ot_loss'])
    loss.add(
        make_default_loss_fn(bce_schedule=(lambda state: 1.0) if config.get(
            'bce_loss_enabled', False) else (lambda state: 0.0),
                             ot_schedule=(lambda state: 1.0) if config.get(
                                 'ot_loss_enabled', True) else
                             (lambda state: 0.0),
                             ot_loss=ot_loss))
    if 'perceptual_bce' in config:
        for layer in config['perceptual_bce']:
            loss_component = perceptual_bce(crossing_model, layer)
            if 'perceptual_bce_from_step' not in config:
                loss.add(loss_component)
            else:
                loss.add(_delay_until_start_step(loss_component))

    # Without the OT term, ``not_too_thin`` is added as an extra loss term.
    if not config.get('ot_loss_enabled', True):
        loss.add(not_too_thin)

    grad_transformer = compose(strip_confidence_grads,
                               coords_only_grads(config['coord_only_grads']))

    aligner = StatefulBatchAligner(device=device)
    init_ot_aligner(aligner,
                    loss_fn=loss,
                    device=device,
                    optimize_fn=make_default_optimize_fn(
                        aligner,
                        lr=config['lr'],
                        transform_grads=grad_transformer,
                        base_optimizer=optim.Adam,
                    ))

    if config.get('use_best_batch', False):
        aligner.add_callback(save_best_batch)

    return aligner
    def set_data(self, X, Y):
        '''Store X and Y (each in R^1 for now) and rebuild networks, optimisers and losses.

        Inputs that are not already 2-D are reshaped into single-column
        tensors. Two identically structured networks (causal and anticausal)
        are created, each with a 'noise' buffer of one row per sample, along
        with one Adam optimiser and one ``SamplesLoss`` per direction.
        '''
        assert len(X) == len(Y)

        def as_column(tensor):
            # Keep 2-D tensors as they are; flatten anything else to (n, 1).
            return tensor if tensor.ndim == 2 else tensor.view(-1, 1)

        def build_fcm_net():
            hidden = self.num_hiddens
            return torch.nn.Sequential(
                torch.nn.Linear(2, hidden), torch.nn.ReLU(),
                torch.nn.Linear(hidden, hidden), torch.nn.ReLU(),
                torch.nn.Linear(hidden, 1))

        def build_loss():
            return SamplesLoss(loss=self.loss,
                               p=self.p,
                               blur=self.blur,
                               scaling=self.scaling)

        # Work on copies so the caller's tensors are never modified.
        self._X, self._Y = X.clone(), Y.clone()
        self._X, self._Y = as_column(self._X), as_column(self._Y)
        self._XY = torch.cat([self._X, self._Y], 1)
        self._X, self._Y, self._XY = (t.type(dtype)
                                      for t in (self._X, self._Y, self._XY))

        # Causal net first, then anticausal, so that weight initialisation
        # consumes random numbers in a fixed order.
        causal_net = build_fcm_net()
        anticausal_net = build_fcm_net()
        causal_net.register_buffer('noise', torch.Tensor(len(self._X), 1))
        anticausal_net.register_buffer('noise', torch.Tensor(len(self._Y), 1))
        self._fcm_net_causal = causal_net.type(dtype)
        self._fcm_net_anticausal = anticausal_net.type(dtype)

        self.optim_causal = torch.optim.Adam(
            self._fcm_net_causal.parameters(), lr=self.lr)
        self.optim_anticausal = torch.optim.Adam(
            self._fcm_net_anticausal.parameters(), lr=self.lr)

        self.L_causal = build_loss()
        self.L_anticausal = build_loss()

        self.data_is_set = True
Example #23
0
def cal_diff(x, y, norm="org", criterion="qt"):
    """Average a row-wise discrepancy between ``x`` and ``y`` over the first axis.

    Args:
        x, y: Inputs indexed along their first axis; the row count is taken
            from ``x``.
        norm: Optional rescaling applied to both inputs first: "softmax",
            "logsoftmax" or "line" (``linear_normalization``). "Gaussian" is
            accepted but not implemented; it and every other value leave the
            inputs unchanged.
        criterion: "kl" for the symmetrised ``nn.KLDivLoss``, "qt" for
            ``qt_cal``; any other value sums the first output of the Sinkhorn
            ``SamplesLoss`` built with ``potentials=True``.

    Returns:
        The accumulated per-row value divided by the number of rows.
    """
    # Rescale the original vectors.
    if norm == "softmax":
        x, y = F.softmax(x), F.softmax(y)
    elif norm == "logsoftmax":
        x, y = F.log_softmax(x), F.log_softmax(y)
    elif norm == "line":
        # Linear normalisation.
        x, y = linear_normalization(x), linear_normalization(y)
    # norm == "Gaussian": Gaussian normalisation was planned but never
    # implemented, so the inputs pass through untouched.

    # Each row of the batch is scored separately and the scores are summed
    # (rather than applying the criterion to the whole batch at once).
    n_rows = x.shape[0]
    ot_solver = SamplesLoss("sinkhorn",
                            p=2,
                            blur=.05,
                            scaling=.9,
                            debias=False,
                            potentials=True)
    total = 0.0
    for idx in range(n_rows):
        row_x, row_y = x[idx], y[idx]
        if criterion == "kl":
            kl = nn.KLDivLoss()
            # Averaging both directions accounts for the asymmetry of KL.
            symmetric_kl = (kl(row_x, row_y) + kl(row_y, row_x)) / 2
            total += symmetric_kl.item()
        elif criterion == "qt":
            total += qt_cal(row_x, row_y)
        else:
            # With potentials=True the solver returns a pair; only the first
            # element is accumulated.
            first_potential, _ = ot_solver(row_x, row_y)
            total += torch.sum(first_potential).item()
    return total / n_rows
 def __init__(self,
              weights: torch.Tensor,
              blur=.01,
              scaling=.9,
              diameter=None,
              p: int = 2):
     """Keep ``weights`` and create a non-debiased Sinkhorn ``SamplesLoss``.

     ``blur``, ``scaling``, ``diameter`` and ``p`` are forwarded under the
     same names to ``SamplesLoss``.
     """
     super(WeightedSamplesLoss, self).__init__()
     self.weights = weights
     sinkhorn_options = dict(blur=blur,
                             scaling=scaling,
                             debias=False,
                             diameter=diameter,
                             p=p)
     self.loss = SamplesLoss("sinkhorn", **sinkhorn_options)
 def reset_net(self):
     """Recreate the network, its ``SamplesLoss`` and its Adam optimiser.

     The network maps 2 inputs to 1 output through two hidden layers of width
     ``self.num_hiddens`` and gets a 'noise' buffer with one row per element
     of ``self._X``.
     """
     width = self.num_hiddens
     layers = [
         torch.nn.Linear(2, width),
         torch.nn.ReLU(),
         torch.nn.Linear(width, width),
         torch.nn.ReLU(),
         torch.nn.Linear(width, 1),
     ]
     net = torch.nn.Sequential(*layers)
     net.register_buffer('noise', torch.Tensor(len(self._X), 1))
     self._net = net.type(dtype)

     self._L = SamplesLoss(loss=self.loss, p=self.p, blur=self.blur,
                           scaling=self.scaling)

     adam_options = dict(lr=self.lr, weight_decay=self.weight_decay)
     self._opt = torch.optim.Adam(self._net.parameters(), **adam_options)
Example #26
0
def display_scaling(scaling=.5, Nits=9, debias=True):
    """Plot the Sinkhorn gradients for a geometric sequence of blur values.

    One subplot is drawn per step ``i`` in ``range(Nits)`` (three per row),
    using ``blur = scaling**i``. The module-level data ``A_i``, ``X_i``,
    ``B_j`` and ``Y_j`` are cloned for each step, so they are never modified.
    """
    n_rows = (Nits - 1) // 3 + 1
    plt.figure(figsize=(12, n_rows * 4))

    for step in range(Nits):
        blur = scaling**step
        sinkhorn = SamplesLoss("sinkhorn",
                               p=2,
                               blur=blur,
                               diameter=1.,
                               scaling=scaling,
                               debias=debias)

        # Fresh copies of the data; gradients are needed for a_i, x_i, b_j.
        a_i = A_i.clone()
        x_i = X_i.clone()
        b_j = B_j.clone()
        y_j = Y_j.clone()
        for leaf in (a_i, x_i, b_j):
            leaf.requires_grad = True

        # Loss and its gradients w.r.t. both weight vectors and the x samples.
        F_i, G_j, dx_i = grad(sinkhorn(a_i, x_i, b_j, y_j), [a_i, b_j, x_i])

        # Original note: the generalized "Brenier map" is (minus) the
        # gradient of the Sinkhorn loss with respect to the Wasserstein
        # metric.
        BrenierMap = -dx_i / (a_i.view(-1, 1) + 1e-7)

        # ---- display ----
        ax = plt.subplot(n_rows, 3, step + 1)
        # Off-screen point: hack to prevent a slight change of axis.
        ax.scatter([10], [10])

        display_potential(ax, G_j, "#E2C5C5")
        display_potential(ax, F_i, "#C8DFF9")

        display_samples(ax, y_j, b_j, [(.55, .55, .95)])
        display_samples(ax, x_i, a_i, [(.95, .55, .55)], v=BrenierMap)

        ax.set_title("iteration {}, blur = {:.3f}".format(step + 1, blur))

        unit_ticks = [0, 1]
        ax.set_xticks(unit_ticks)
        ax.set_yticks(unit_ticks)
        ax.axis([0, 1, 0, 1])
        ax.set_aspect('equal', adjustable='box')

    plt.tight_layout()
Example #27
0
def WL(X, Y, L):
  """Return ``D ** (1/L)`` where ``D`` is the debiased Sinkhorn loss with ``p=L``.

  ``X`` and ``Y`` are numpy arrays. Both clouds are given the same uniform
  weights ``1/n`` with ``n = X.shape[0]``, and every array is converted to a
  float32 CUDA tensor before the loss is evaluated.
  """
  loss = SamplesLoss("sinkhorn", blur=0.01, scaling=0.99, debias=True, p=L, backend="tensorized")

  def to_cuda_float(array):
    return torch.from_numpy(array).type(torch.float32).cuda()

  n = X.shape[0]
  uniform = np.ones(n) / n

  D = loss.forward(to_cuda_float(uniform),
                   to_cuda_float(X),
                   to_cuda_float(uniform),
                   to_cuda_float(Y)).item()
  return D ** (1/L)
 def __init__(self,
              model_type='wav2vec',
              PRETRAINED_MODEL_PATH='/path/to/wav2vec_large.pt'):
     """Create the default ``SamplesLoss`` and load the feature extractor.

     Only ``model_type == 'wav2vec'`` is supported; any other value prints a
     message and exits the process through ``sys.exit()``. For 'wav2vec' the
     checkpoint at ``PRETRAINED_MODEL_PATH`` is loaded and only the model's
     ``feature_extractor`` is kept, switched to eval mode.
     """
     super().__init__()
     self.model_type = model_type
     self.wass_dist = SamplesLoss()

     if model_type != 'wav2vec':
         print('Please assign a loss model')
         sys.exit()

     ckpt = torch.load(PRETRAINED_MODEL_PATH)
     self.model = Wav2VecModel.build_model(ckpt['args'], task=None)
     self.model.load_state_dict(ckpt['model'])
     extractor = self.model.feature_extractor
     self.model = extractor
     self.model.eval()
def compute_maps(objFileName, bodyFileName):
    """Compute displacement, position and validity maps between a body and a mesh.

    Body points are sampled from the body mesh with the barycentric data in
    'dat.pkl' and moved towards points sampled from the object mesh by
    ``gradient_flow`` under a Sinkhorn ``SamplesLoss``. Points whose ``a``
    value from the first flow exceeds 0.3 are kept and moved by a second
    flow; their displacement in that second flow forms the displacement map.

    Args:
        objFileName: Path of the mesh that is sampled (256 * 256 points).
        bodyFileName: Path of the body mesh.

    Returns:
        ``(disp_map, dp_map, legal_map)``, each reshaped to a 256 x 256 grid
        (with 3 channels for the first two).
        NOTE(review): the reshapes assume 'dat.pkl' yields exactly 256 * 256
        body points - confirm.
    """
    body_pts = trimesh.load(bodyFileName, process=False)
    mesh = trimesh.load(objFileName, process=False)
    # NOTE(review): unpickling can execute arbitrary code, so 'dat.pkl' must
    # come from a trusted source. The context manager closes the file handle,
    # which was previously left open.
    with open('dat.pkl', 'rb') as handle:
        p = pkl.load(handle)
    size = 256
    body_pts = trimesh.triangles.barycentric_to_points(
        body_pts.vertices[p['vertices']], p['weights'])
    X_i = torch.tensor(body_pts).type(dtype)
    Y_j = torch.tensor(mesh.sample(size * size)).type(dtype)

    # First flow: move the body points and obtain the per-point value ``a``.
    x_i, a = gradient_flow(SamplesLoss("sinkhorn", p=2, blur=.01),
                           X_i,
                           Y_j,
                           lr=0.5,
                           do_a=True)
    mask = a > 0.3
    x_i_new = x_i[mask].detach()

    # Second flow: only the retained points.
    x_i_final, _ = gradient_flow(SamplesLoss("sinkhorn", p=2, blur=.01),
                                 x_i_new,
                                 Y_j,
                                 lr=0.5,
                                 do_a=False)

    # Displacement is zero wherever the mask rejected the point.
    x_i_back = torch.zeros_like(x_i).type(dtype)
    x_i_back[mask] = x_i_final - x_i_new
    legal_map = mask.reshape([size, size])
    dp_map = X_i.reshape([size, size, 3])
    disp_map = x_i_back.reshape([size, size, 3])
    # Debug visualisation (disabled):
    #   plt.figure(1)
    #   plt.imshow(disp_map.detach()[:,:,0].cpu().numpy())
    #   plt.figure(2)
    #   plt.imshow(dp_map[:,:,0].cpu().numpy())
    #   plt.figure(3)
    #   plt.imshow(legal_map.cpu().numpy())
    #   plt.show()
    return disp_map, dp_map, legal_map
Example #30
0
 def emd_loss(self, P, activations):
     """Sinkhorn loss between ``P`` and a similarity matrix built from ``activations``.

     Pairwise squared distances ``d`` between the rows of ``activations`` are
     turned into ``(1 + d / dof) ** (-(dof + 1) / 2)`` with
     ``dof = self.dof``. The diagonal is zeroed and the matrix is normalised
     to sum to one before it is compared with ``P``.
     """
     n = activations.size(0)

     # ||a_i||^2 + ||a_j||^2 - 2 <a_i, a_j>
     squared_norms = torch.sum(torch.pow(activations, 2), 1)
     inner_products = torch.matmul(activations,
                                   torch.transpose(activations, 0, 1))
     Q = squared_norms + squared_norms.view([-1, 1]) - 2 * inner_products

     Q = torch.pow(1 + Q / self.dof, -(self.dof + 1) / 2)

     # Zero out the diagonal.
     off_diagonal = torch.from_numpy(np.ones((n, n)) - np.eye(n)).to(
         self.device)
     Q = Q * off_diagonal
     Q = Q / torch.sum(Q)

     sinkhorn = SamplesLoss(loss="sinkhorn", p=2, blur=.05)
     return sinkhorn(P, Q)