def forward(self, x, label=None):

        # Part of negative pairs
        if self.sample_type == 'PoN':
            x1 = x[:, 0, :]
            x2 = x[:, 1, :]
            nloss1 = torch.pdist(x1,
                                 p=2).pow(2).mul(-self.t).exp().mean().log()
            nloss2 = torch.pdist(x2,
                                 p=2).pow(2).mul(-self.t).exp().mean().log()
            nloss = (nloss1 + nloss2) / 2

        # All positive and negative pairs
        elif self.sample_type == 'APN':
            x = torch.reshape(x, (-1, x.size()[-1]))
            nloss = torch.pdist(x, p=2).pow(2).mul(-self.t).exp().mean().log()

        # All negative pairs
        elif self.sample_type == 'AN':
            x1 = x[:, 0, :]
            x2 = x[:, 1, :]
            K = x.size()[0]

            nloss1 = torch.pdist(x1, p=2).pow(2).mul(-self.t).exp().sum()
            nloss2 = torch.pdist(x2, p=2).pow(2).mul(-self.t).exp().sum()
            nloss3 = F.pairwise_distance(x1.unsqueeze(-1),
                                         x2.unsqueeze(-1).transpose(
                                             0, 2)).pow(2).mul(-self.t).exp()
            nloss3 = nloss3 * (1 - torch.eye(K).cuda())
            nloss3 = nloss3.sum()
            nloss = torch.log(
                torch.div(nloss1 + nloss2 + nloss3,
                          K * (K - 1) / 2 + K * (K - 1) / 2 + K * (K - 1)))

        return nloss, 0
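
A brute-force sketch (not part of the original class) that checks the 'AN' pair counting: the denominator K*(K-1)/2 + K*(K-1)/2 + K*(K-1) is exactly the number of negative pairs, i.e. all within-view pairs plus all cross-view pairs except the K positives. Random CPU embeddings are assumed.

import torch

K, D, t = 8, 16, 2.0
x = torch.randn(K, 2, D)
x1, x2 = x[:, 0, :], x[:, 1, :]

sq = []
for i in range(K):
    for j in range(K):
        if i < j:        # within-view negative pairs
            sq.append((x1[i] - x1[j]).pow(2).sum())
            sq.append((x2[i] - x2[j]).pow(2).sum())
        if i != j:       # cross-view pairs, excluding the K positives
            sq.append((x1[i] - x2[j]).pow(2).sum())

assert len(sq) == K * (K - 1) * 2  # matches the denominator above
nloss_brute = torch.stack(sq).mul(-t).exp().sum().div(len(sq)).log()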
Example #2
def relative_teacher_distances(features_a, features_b, normalize=False, distance="l2", **kwargs):
    """Distillation loss between the teacher and the student comparing distances
    instead of embeddings.

    Reference:
        * Lu Yu et al.
          Learning Metrics from Teachers: Compact Networks for Image Embedding.
          CVPR 2019.

    :param features_a: ConvNet features of one model (e.g., the teacher).
    :param features_b: ConvNet features of the other model (e.g., the student).
    :return: A float scalar loss.
    """
    if normalize:
        features_a = F.normalize(features_a, dim=-1, p=2)
        features_b = F.normalize(features_b, dim=-1, p=2)

    if distance == "l2":
        p = 2
    elif distance == "l1":
        p = 1
    else:
        raise ValueError("Invalid distance for relative teacher {}.".format(distance))

    pairwise_distances_a = torch.pdist(features_a, p=p)
    pairwise_distances_b = torch.pdist(features_b, p=p)

    return torch.mean(torch.abs(pairwise_distances_a - pairwise_distances_b))
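
A hypothetical usage sketch with random features (the 32 x 128 shapes are assumptions):

teacher_feats = torch.randn(32, 128)
student_feats = torch.randn(32, 128)
loss = relative_teacher_distances(teacher_feats, student_feats, normalize=True)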
Example #3
def mmd(
    X: torch.Tensor,
    Y: torch.Tensor,
    implementation: str = "tp_sutherland",
    z_score: bool = False,
    bandwidth: str = "X",
) -> torch.Tensor:
    """Estimate MMD^2 statistic with Gaussian kernel

    Several implementations are available in order to validate accuracy and compare speed.
    The widely used median heuristic is used to select the bandwidth of the Gaussian kernel.
    """
    if torch.isnan(X).any() or torch.isnan(Y).any():
        return torch.tensor(float("nan"))

    tic = time.time()  # noqa

    if z_score:
        X_mean = torch.mean(X, dim=0)
        X_std = torch.std(X, dim=0)
        X = (X - X_mean) / X_std
        Y = (Y - X_mean) / X_std

    n_1 = X.shape[0]
    n_2 = Y.shape[0]

    # Bandwidth
    if bandwidth == "X":
        sigma_tensor = torch.median(torch.pdist(X))
    elif bandwidth == "XY":
        sigma_tensor = torch.median(torch.pdist(torch.cat([X, Y])))
    else:
        raise NotImplementedError

    # Compute MMD
    if implementation == "tp_sutherland":
        K = tp_ExpQuadKernel(X, Y, sigma=sigma_tensor)
        statistic = tp_mmd2_unbiased(K)

    elif implementation == "tp_djolonga":
        alpha = 1 / (2 * sigma_tensor**2)
        test = tp_MMDStatistic(n_1, n_2)
        statistic = test(X, Y, [alpha])

    else:
        raise NotImplementedError

    toc = time.time()  # noqa
    # log.info(f"Took {toc-tic:.3f}sec")

    return statistic
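
A usage sketch, assuming the tp_* helpers imported by this module are available: two samples from shifted Gaussians should give a clearly positive statistic.

torch.manual_seed(0)
X = torch.randn(200, 5)
Y = torch.randn(200, 5) + 0.5
stat = mmd(X, Y, implementation="tp_sutherland", z_score=True)
print(float(stat))  # positive for distinguishable samples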
Example #4
def disp_vector_dist():
    # Plots the distance between each pair of digit-mean encodings.
    # Interestingly, in model VAE20190716_1551 these distances are all
    # nearly the same, consistent with the mean encodings being
    # distributed on a 20-sphere.
    for batch_idx, (data, which_digit) in enumerate(train_loader):
        break
    plt.clf()
    fc21current, fc22current = model.encode(data.cuda().view(-1, 784))

    fc21disp = torch.zeros((10, model.z_dimension))

    for i in range(10):
        fc21disp[i, :] = torch.mean(fc21current[which_digit == i, :], 0)

    mu_dist = torch.pdist(fc21disp)

    dist_disp = np.zeros((10, 10))

    counter = 0
    for i in range(10):
        for j in range(i + 1, 10):
            dist_disp[i, j] = mu_dist[counter]
            counter = counter + 1
    plt.title('Vector Distances')
    plt.imshow(dist_disp)
    plt.pause(0.5)
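
The counter loop above unpacks torch.pdist's condensed vector into the upper triangle by hand; scipy's squareform performs the same mapping directly (a sketch):

from scipy.spatial.distance import squareform

dist_disp = squareform(mu_dist.detach().cpu().numpy())  # 10 x 10 symmetric matrix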
Example #5
def stochastic_instance_closeness_pdist_loss(mX, easy_ratio, hard_ratio):
    dists = torch.pdist(mX)**2

    num_dists = dists.shape[0]

    dists, _ = torch.sort(dists)

    loss = 0.0

    if easy_ratio > 0:
        easy_dists = dists[:int(num_dists * easy_ratio)]

        easy_loss = torch.mean(easy_dists)

        loss += easy_loss

    if hard_ratio > 0:
        hard_dists = dists[int(-num_dists * hard_ratio):]

        hard_loss = (hard_ratio /
                     (hard_ratio + easy_ratio)) * torch.mean(hard_dists)

        loss += hard_loss

    return loss
Example #6
def _get_initial_lengthscale(f_X_samples):
    if torch.cuda.is_available():
        f_X_samples = f_X_samples.cuda()

    initial_lengthscale = torch.pdist(f_X_samples).mean()

    return initial_lengthscale.cpu()
Example #7
 def get_triplets(self, embs, labels, selection_fn=None):
     if self.cpu:
         embs = embs.cpu()
     dm = torch.pdist(embs).cpu()
     labels = labels.cpu().data.numpy()
     triplets = []
     for label in set(labels):
         mask = labels == label
         if mask.sum() < 2:
             continue  # skip labels with fewer than two samples (no positive pair)
         label_idx = np.where(mask)[0]
         neg_idx = torch.LongTensor(np.where(np.logical_not(mask))[0])
         pos_pairs = torch.LongTensor(
             list(itertools.combinations(label_idx, 2)))
         pos_dists = dm[condensed_index(pos_pairs[:, 0], pos_pairs[:, 1],
                                        embs.shape[0])]
         for (i, j), dist in zip(pos_pairs, pos_dists):
             loss = dist - dm[condensed_index(i, neg_idx,
                                              embs.shape[0])] + self.margin
             loss = loss.data.cpu().numpy()
             if selection_fn is None:
                 hard_idx = self.selection_fn(loss)
             else:
                 hard_idx = selection_fn(loss)
             if hard_idx is not None:
                 triplets.append([i, j, neg_idx[hard_idx]])
     if not triplets:
         print('No triplets found... Sampling random hard ones.')
         triplets = self.get_triplets(embs, torch.LongTensor(labels),
                                      random_hard_negative)
     return torch.LongTensor(triplets)
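
The condensed_index helper used above is not shown in this snippet; a minimal module-level sketch of the standard mapping from a pair (i, j), i < j, of n points into torch.pdist's condensed (row-major upper-triangle) ordering, written to accept tensors:

def condensed_index(i, j, n):
    # Position of pair (i, j) in the length n*(n-1)/2 output of torch.pdist.
    i, j = torch.as_tensor(i), torch.as_tensor(j)
    lo, hi = torch.minimum(i, j), torch.maximum(i, j)
    return n * lo - lo * (lo + 1) // 2 + (hi - lo - 1)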
Example #8
 def _rbf_kernel(x):
     n = x.shape[0]
     pw_dists = torch.pdist(x)**2  # Upper triangle
     med = torch.median(pw_dists)
     return SteinVariationalGradientDescent._upper_tria_to_full(
         n,
         torch.exp(-(torch.log(torch.tensor(n, dtype=float)) /
                     (med**2 + 1e-8)) * pw_dists))
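
SteinVariationalGradientDescent._upper_tria_to_full is not shown in this snippet. A standalone sketch, under the assumption that it scatters the condensed kernel values into a symmetric n x n matrix with ones on the diagonal (exp(0) = 1 for an RBF kernel):

import torch

def upper_tria_to_full(n, tri):
    full = torch.ones(n, n, dtype=tri.dtype)
    iu = torch.triu_indices(n, n, offset=1)
    full[iu[0], iu[1]] = tri
    full[iu[1], iu[0]] = tri
    return full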
Example #9
 def semi_loss_function(self, image_z, text_z, label):
     a = []
     for i in range(0, len(label) - 1):
         for j in range(i + 1, len(label)):
             if label[i] == -1 or label[j] == -1:
                 a.append(0.)
             elif label[i] == label[j]:
                 a.append(1.)
             else:
                 a.append(-1.)
     a = torch.from_numpy(np.array(a, dtype=np.float32))
     image_dist = torch.pdist(image_z)
     image_loss = self.trade_off * torch.sum(a * image_dist) / len(label)
     text_dist = torch.pdist(text_z)
     text_loss = self.trade_off * torch.sum(a * text_dist) / len(label)
     semi_loss = image_loss + text_loss
     return semi_loss
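
The double loop builds the pair labels in exactly the condensed order used by torch.pdist. A vectorized sketch of the same construction, assuming label is a 1-D LongTensor:

n = len(label)
iu = torch.triu_indices(n, n, offset=1)
li, lj = label[iu[0]], label[iu[1]]
a = torch.full((iu.shape[1],), -1.0)
a[li == lj] = 1.0
a[(li == -1) | (lj == -1)] = 0.0  # pairs involving unlabeled samples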
Example #10
    def forward(self, tree_data: StrDict) -> torch.Tensor:
        """
        Calculate the pairwise distances between the input trees

        :param tree_data: collated trees from the `route_distances.utils.collate_trees` function.
        :return: the distances in condensed form
        """
        lstm_enc = self._tree_lstm(tree_data)
        return torch.pdist(lstm_enc)
Example #11
def NearestNeighborMatch(pts, desc):
        # input:
        # pts: BxNx3 or BxNx2 for SIFT
        # desc: Bx256xN or BX128XN for SIFT
        # output: 
        # corr: BxNx4 correspondence
        # score: BxN matching score

        # compute distance matrix (torch.pdist only accepts 2-D input,
        # so apply it to each batch element separately)
        dist = torch.stack([torch.pdist(d) for d in desc.transpose(1, 2).contiguous()])
        distmat = torch_squareform(dist)
Example #12
def compute_pdist_matrix(batch, p=2.0):
    """
    Computes the matrix of pairwise distances w.r.t. p-norm
    :param batch: torch.Tensor, input vectors
    :param p:     float, norm parameter, such that ||x||p = (sum_i |x_i|^p)^(1/p)
    :return: torch.Tensor, matrix A such that A_ij = ||batch[i] - batch[j]||_p (for flattened batch[i] and batch[j])
    """
    mat = torch.zeros(batch.shape[0], batch.shape[0], device=batch.device)
    ind = torch.triu_indices(batch.shape[0], batch.shape[0], offset=1, device=batch.device)
    mat[ind[0], ind[1]] = torch.pdist(batch.view(batch.shape[0], -1), p=p)

    return mat + mat.transpose(0, 1)
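
A quick sanity check (hypothetical usage): for p = 2 the result should agree with torch.cdist on the flattened batch.

batch = torch.randn(6, 3, 4)
flat = batch.view(batch.shape[0], -1)
assert torch.allclose(compute_pdist_matrix(batch), torch.cdist(flat, flat), atol=1e-6)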
Example #13
    def __pdist_dist(self, x1):
        # Build the full (square-form) distance matrix from torch.pdist's condensed output
        dists = torch.pdist(x1)
        inds_l, inds_r = triu_indices(x1.shape[-2], 1)
        res = torch.zeros(*x1.shape[:-2],
                          x1.shape[-2],
                          x1.shape[-2],
                          dtype=x1.dtype,
                          device=x1.device)
        res[..., inds_l, inds_r] = dists
        res[..., inds_r, inds_l] = dists

        return res
Example #14
def row_diff(x: Union[torch.Tensor, QTensor]) -> float:
    r"""
    Implementation of the row_diff value from the paper "PAIRNORM: TACKLING OVERSMOOTHING IN GNNS" (ICLR 2020) by
    Lingxiao Zhao and Leman Akoglu

    The row-diff measure is the average of all pairwise distances between the node features (i.e., rows of
    the representation matrix) and quantifies node-wise oversmoothing.
    """
    if x.__class__.__name__ == "QTensor":
        x = x.stack(dim=0).permute(1, 0, 2)

    x = x.reshape(x.size(0), -1)
    pdist = torch.pdist(x, p=2)
    row_diff_value = (pdist.sum() / pdist.numel()).item()
    return row_diff_value
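
Hypothetical usage on a random representation matrix (the 100 x 64 shape is an arbitrary assumption):

h = torch.randn(100, 64)  # 100 nodes, 64 features
print(row_diff(h))        # mean pairwise L2 distance between rows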
Example #15
    def step(self, batch, prefix: str, model=None) -> Dict:
        batch = AtomsBatch(**batch)
        y_true = batch.R_orig

        if model is None:
            y_pred = self.forward(batch)[1]
        else:
            y_pred = model(batch)[1]

        # predict residual
        y_pred = y_pred + batch.R

        assert y_pred.shape == y_true.shape, f'{y_pred.shape}, {y_true.shape}'

        # TODO try to group by molecule.
        mae = mae_loss(torch.pdist(y_pred), torch.pdist(y_true))
        loss = torch.log(mae)
        size = len(y_true)

        return {
            f'{prefix}_loss': loss,
            f'{prefix}_mae': mae,
            f'{prefix}_size': size,
        }
Example #16
def test_isomorphism(loader, model, device, p=2, eps=1e-2):

    model.eval()
    y = None
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            y = torch.cat(
                (y, model(data)), 0) if y is not None else model(data)

    mm = torch.pdist(y, p=p)
    num_not_distinguished = (mm < eps).sum().item()

    # inds = np.where((squareform(mm.cpu().numpy()) + np.diag(np.ones(y.shape[0]))) < eps)
    # print('Non-isomorphic pairs that are not distinguished: {}'.format(inds))

    return mm, num_not_distinguished
Example #17
    def get_pairs(self, embeddings, labels):
        if self.cpu:
            embeddings = embeddings.cpu()
        dm = torch.pdist(embeddings)
        labels = labels.cpu().data.numpy()
        all_pairs = torch.LongTensor(utils.comb_index(len(labels), 2))
        pos_pairs = all_pairs[(
            labels[all_pairs[:, 0]] == labels[all_pairs[:, 1]]).nonzero()]
        neg_pairs = all_pairs[(labels[all_pairs[:, 0]] !=
                               labels[all_pairs[:, 1]]).nonzero()]

        dists = dm[condensed_index(neg_pairs[:, 0], neg_pairs[:, 1],
                                   embeddings.shape[0])].cpu()
        # find most similar negatives
        _, top = torch.topk(dists, len(pos_pairs), largest=False)
        neg_pairs = neg_pairs[torch.LongTensor(top)]

        return pos_pairs, neg_pairs
Example #18
def find_neg_anchors(e_actv, e_ap, discriminator):
    """Find negative anchors within a batch. 
    Embeddings with same discriminator are removed
    """
    # Computing distance matrix
    n = len(e_actv)
    dm = torch.pdist(e_actv)
    # Convert to the full n x n matrix
    tri = torch.zeros((n, n))
    tri[np.triu_indices(n, 1)] = dm
    fmatrix = torch.tril(tri.T, 1) + tri
    # Removing diagonal
    fmatrix += sys.maxsize * (torch.eye(n, n))
    # Getting the minimum
    idxs = fast_filter(fmatrix, discriminator)
    dn = e_actv[idxs]

    return dn
Example #19
def col_diff(x: Union[torch.Tensor, QTensor]) -> float:
    r"""
    Implementation of the col_diff value from the paper "PAIRNORM: TACKLING OVERSMOOTHING IN GNNS" (ICLR 2020) by
    Lingxiao Zhao and Leman Akoglu

    The col-diff is the average of pairwise distances between (L1-normalized) columns of the representation matrix
    and quantifies feature-wise oversmoothing.
    """
    if x.__class__.__name__ == "QTensor":
        x = x.stack(dim=0).permute(1, 0, 2)

    x = x.reshape(x.size(0), -1)
    col_norms = x.norm(p=1, dim=0, keepdim=True)
    xnormed = x / col_norms
    xnormed = xnormed.t()
    pdist = torch.pdist(xnormed, p=2)
    col_diff_value = (pdist.sum() / pdist.numel()).item()

    return col_diff_value
Example #20
def test_squareform():
        import time
        a = torch.randn(32,500,256)
        time1 = time.time()
        distmat1 = torch.norm(a.unsqueeze(1) - a.unsqueeze(2), dim=3)
        print('implementation1 time '+str(time.time()-time1))
        time2 = time.time()
        dist = torch.stack([torch.pdist(x) for x in a])  # torch.pdist only accepts 2-D input
        distmat2 = torch_squareform(dist)
        print('implementation2 time '+str(time.time()-time2))
        time3 = time.time()
        b = (a**2).sum(2)
        c = (a**2).sum(2)
        d = torch.matmul(a, a.transpose(1,2))
        #distmat3 = torch.pow(b.unsqueeze(1)+c.unsqueeze(2)-2*d, 0.5)
        distmat3 = torch.sqrt(b.unsqueeze(1)+c.unsqueeze(2)-2*d)
        distmat3[torch.isnan(distmat3)] = 0
        print('implementation3 time '+str(time.time()-time3))
        assert (distmat1-distmat2).abs().max() < 1e-5
        print((distmat1-distmat3).abs().max())
        assert (distmat1-distmat3).abs().max() < 5e-2
Example #21
    def _cluster_batch(self, p, features):
        """
        Cluster a batch of frames
        
        Args:
            p (torch.Tensor): detections
            features (torch.Tensor): features

        Returns:
            p_out, feat_out: aggregated probabilities and cluster-averaged features
        """

        clusterer = self._clusterer
        p_aggregation = self.p_aggregation

        if p.size(1) > 1:
            raise ValueError("Not Supported shape for propbabilty.")
        p_out = torch.zeros_like(p).view(p.size(0), p.size(1), -1)
        feat_out = features.clone().view(features.size(0), features.size(1),
                                         -1)
        """Frame wise clustering."""
        for i in range(features.size(0)):
            ix = p[i, 0] > 0
            if (ix == 0.).all().item():
                continue
            alg_ix = (p[i].view(-1) > 0).nonzero().squeeze(1)

            p_frame = p[i, 0, ix].view(-1)
            f_frame = features[i, :, ix]
            # flatten samples and put them in the first dim
            f_frame = f_frame.reshape(f_frame.size(0), -1).permute(1, 0)

            filter_mask = self._filter(f_frame[:, 1:4], f_frame[:, 1:4])
            if self.match_dims == 2:
                dist_mat = torch.pdist(f_frame[:, 1:3])
            elif self.match_dims == 3:
                dist_mat = torch.pdist(f_frame[:, 1:4])
            else:
                raise ValueError

            dist_mat = torch.from_numpy(
                scipy.spatial.distance.squareform(dist_mat))
            dist_mat[
                ~filter_mask] = 999999999999.  # those who shall not match shall be separated, only finite vals ...

            if dist_mat.shape[0] == 1:
                warnings.warn(
                    "I don't know how this can happen but there seems to be a"
                    " single an isolated difficult case ...",
                    stacklevel=3)
                n_clusters = 1
                labels = torch.tensor([0])
            else:
                clusterer.fit(dist_mat)
                n_clusters = clusterer.n_clusters_
                labels = torch.from_numpy(clusterer.labels_)

            for c in range(n_clusters):
                in_cluster = labels == c
                feat_ix = alg_ix[in_cluster]
                if p_aggregation == 'sum':
                    p_agg = p_frame[in_cluster].sum()
                elif p_aggregation == 'max':
                    p_agg = p_frame[in_cluster].max()
                elif p_aggregation == 'pbinom_cdf':
                    z = binom_pdiverse(p_frame[in_cluster].view(-1))
                    p_agg = z[1:].sum()
                elif p_aggregation == 'pbinom_pdf':
                    z = binom_pdiverse(p_frame[in_cluster].view(-1))
                    p_agg = z[1]
                else:
                    raise ValueError

                p_out[i, 0, feat_ix[
                    0]] = p_agg  # only set first element to some probability
                """Average the features."""
                feat_av = (feat_out[i, :, feat_ix] * p_frame[in_cluster]
                           ).sum(1) / p_frame[in_cluster].sum()
                feat_out[i, :, feat_ix] = feat_av.unsqueeze(1).repeat(
                    1, in_cluster.sum())

        return p_out.reshape(p.size()), feat_out.reshape(features.size())
Example #22
def uniform_loss(x, t=2):
    return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()
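
This one-liner matches the uniformity objective from Wang & Isola, "Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere" (ICML 2020); in math form:

\mathcal{L}_{\text{uniform}} = \log \mathbb{E}_{x, y \sim p_{\text{data}}} \left[ e^{-t \lVert f(x) - f(y) \rVert_2^2} \right]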
Example #23
 def lunif(x):
     sq_pdist = torch.pdist(x, p=2).pow(2)
     return sq_pdist.mul(-self.t).exp().mean().log()
Example #24
    def forward(self, im_q, im_k):
        r"""
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
        Output:
            MoCoLosses object containing the loss terms (and logits if contrastive loss is used)
        """

        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = F.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = F.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        moco_loss_ctor_dict = {}

        # lazily computed & cached!
        def get_q_bdot_k():
            if not hasattr(get_q_bdot_k, 'result'):
                get_q_bdot_k.result = (q * k).sum(dim=1)
            assert get_q_bdot_k.result._version == 0
            return get_q_bdot_k.result

        # lazily computed & cached!
        def get_q_dot_queue():
            if not hasattr(get_q_dot_queue, 'result'):
                get_q_dot_queue.result = q @ self.queue.clone().detach()
            assert get_q_dot_queue.result._version == 0
            return get_q_dot_queue.result

        # l_contrastive
        if self.contr_tau is not None:
            # compute logits
            # positive logits: Nx1
            l_pos = get_q_bdot_k().unsqueeze(-1)
            # negative logits: NxK
            l_neg = get_q_dot_queue()

            # logits: Nx(1+K)
            logits = torch.cat([l_pos, l_neg], dim=1)

            # apply temperature
            logits /= self.contr_tau

            moco_loss_ctor_dict['logits_contr'] = logits
            moco_loss_ctor_dict['loss_contr'] = F.cross_entropy(
                logits, self.scalar_label.expand(logits.shape[0]))

        # l_align
        if self.align_alpha is not None:
            if self.align_alpha == 2:
                moco_loss_ctor_dict[
                    'loss_align'] = 2 - 2 * get_q_bdot_k().mean()
            elif self.align_alpha == 1:
                moco_loss_ctor_dict['loss_align'] = (q - k).norm(dim=1,
                                                                 p=2).mean()
            else:
                moco_loss_ctor_dict['loss_align'] = (2 -
                                                     2 * get_q_bdot_k()).pow(
                                                         self.align_alpha /
                                                         2).mean()

        # l_uniform
        if self.unif_t is not None:
            sq_dists = (2 - 2 * get_q_dot_queue()).flatten()
            if self.unif_intra_batch:
                sq_dists = torch.cat([sq_dists, torch.pdist(q, p=2).pow(2)])
            moco_loss_ctor_dict['loss_unif'] = sq_dists.mul(
                -self.unif_t).exp().mean().log()

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        return MoCoLosses(**moco_loss_ctor_dict)
Example #25
def train_epoch(cfg,
                epoch,
                model,
                device,
                train_loader,
                optimizer,
                criterion,
                limit=None):
    model.train()

    if cfg.taskweights:
        weights = np.nansum(train_loader.dataset.labels, axis=0)
        weights = weights.max() - weights + weights.mean()
        weights = weights / weights.max()
        weights = torch.from_numpy(weights).to(device).float()
        print("task weights", weights)

    avg_loss = []
    t = tqdm(train_loader)
    for batch_idx, samples in enumerate(t):

        if limit and (batch_idx > limit):
            print("breaking out")
            break

        optimizer.zero_grad()

        images = samples["img"].float().to(device)
        targets = samples["lab"].to(device)

        outputs = model(images)

        loss = torch.zeros(1).to(device).float()
        for task in range(targets.shape[1]):
            task_output = outputs[:, task]
            task_target = targets[:, task]
            mask = ~torch.isnan(task_target)
            task_output = task_output[mask]
            task_target = task_target[mask]
            if len(task_target) > 0:
                task_loss = criterion(task_output.float(), task_target.float())
                if cfg.taskweights:
                    loss += weights[task] * task_loss
                else:
                    loss += task_loss

        # here regularize the weight matrix when label_concat is used
        if cfg.label_concat_reg:
            if not cfg.label_concat:
                raise Exception("cfg.label_concat must be true")
            weight = model.classifier.weight
            num_labels = len(xrv.datasets.default_pathologies)
            num_datasets = weight.shape[0] // num_labels
            weight_stacked = weight.reshape(num_datasets, num_labels, -1)
            label_concat_reg_lambda = torch.tensor(0.1).to(device).float()
            for task in range(num_labels):
                dists = torch.pdist(weight_stacked[:, task], p=2).mean()
                loss += label_concat_reg_lambda * dists

        loss = loss.sum()

        if cfg.featurereg:
            feat = model.features(images)
            loss += feat.abs().sum()

        if cfg.weightreg:
            loss += model.classifier.weight.abs().sum()

        loss.backward()

        avg_loss.append(loss.detach().cpu().numpy())
        t.set_description(
            f'Epoch {epoch + 1} - Train - Loss = {np.mean(avg_loss):4.4f}')

        optimizer.step()

    return np.mean(avg_loss)
Example #26
def uniform_distribution(x, t=2):
    return torch.pdist(x, p=2).pow(2).mul(-t).exp().log()
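
Since log(exp(-t * d^2)) = -t * d^2 elementwise, this variant returns the scaled negative squared distance of every pair rather than a scalar loss (note the missing .mean() compared to uniform_loss above); a quick numerical check:

x = torch.randn(8, 4)
assert torch.allclose(uniform_distribution(x),
                      torch.pdist(x).pow(2).mul(-2.0), atol=1e-4)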
Example #27
def uniformity_loss(x1: Tensor, x2: Tensor, t=2) -> Tensor:
    sq_pdist_x1 = torch.pdist(x1, p=2).pow(2)
    uniformity_x1 = sq_pdist_x1.mul(-t).exp().mean().log()
    sq_pdist_x2 = torch.pdist(x2, p=2).pow(2)
    uniformity_x2 = sq_pdist_x2.mul(-t).exp().mean().log()
    return (uniformity_x1 + uniformity_x2) / 2
Example #28
def gp_torch_train(train_x: Tensor,
                   train_y: Tensor,
                   n_inducing_points: int,
                   tkwargs: Dict[str, Any],
                   init,
                   scale: bool,
                   covar_name: str,
                   gp_file: Optional[str],
                   save_file: str,
                   input_wp: bool,
                   outcome_transform: Optional[OutcomeTransform] = None,
                   options: Dict[str, Any] = None) -> SingleTaskGP:
    assert train_y.ndim > 1, train_y.shape
    assert gp_file or init, (gp_file, init)
    likelihood = gpytorch.likelihoods.GaussianLikelihood()

    if init:
        # build hyp
        print("Initialize GP hparams...")
        print("Doing Kmeans init...")
        assert n_inducing_points > 0, n_inducing_points
        kmeans = MiniBatchKMeans(n_clusters=n_inducing_points,
                                 batch_size=min(10000, train_x.shape[0]),
                                 n_init=25)
        start_time = time.time()
        kmeans.fit(train_x.cpu().numpy())
        end_time = time.time()
        print(f"K means took {end_time - start_time:.1f}s to finish...")
        inducing_points = torch.from_numpy(kmeans.cluster_centers_.copy())

        output_scale = None
        if scale:
            output_scale = train_y.var().item()
        lscales = torch.empty(1, train_x.shape[1])
        for i in range(train_x.shape[1]):
            lscales[0, i] = torch.pdist(train_x[:, i].view(
                -1, 1)).median().clamp(min=0.01)
        base_covar_module = query_covar(covar_name=covar_name,
                                        scale=scale,
                                        outputscale=output_scale,
                                        lscales=lscales)

        covar_module = InducingPointKernel(base_covar_module,
                                           inducing_points=inducing_points,
                                           likelihood=likelihood)

        input_warp_tf = None
        if input_wp:
            # Apply input warping
            # initialize input_warping transformation
            input_warp_tf = CustomWarp(
                indices=list(range(train_x.shape[-1])),
                # use a prior with median at 1.
                # when a=1 and b=1, the Kumaraswamy CDF is the identity function
                concentration1_prior=LogNormalPrior(0.0, 0.75**0.5),
                concentration0_prior=LogNormalPrior(0.0, 0.75**0.5),
            )

        model = SingleTaskGP(train_x,
                             train_y,
                             covar_module=covar_module,
                             likelihood=likelihood,
                             input_transform=input_warp_tf,
                             outcome_transform=outcome_transform)
    else:
        # load model
        output_scale = 1  # will be overwritten when loading model
        lscales = torch.ones(
            train_x.shape[1])  # will be overwritten when loading model
        base_covar_module = query_covar(covar_name=covar_name,
                                        scale=scale,
                                        outputscale=output_scale,
                                        lscales=lscales)
        covar_module = InducingPointKernel(base_covar_module,
                                           inducing_points=torch.empty(
                                               n_inducing_points,
                                               train_x.shape[1]),
                                           likelihood=likelihood)

        input_warp_tf = None
        if input_wp:
            # Apply input warping
            # initialize input_warping transformation
            input_warp_tf = Warp(
                indices=list(range(train_x.shape[-1])),
                # use a prior with median at 1.
                # when a=1 and b=1, the Kumaraswamy CDF is the identity function
                concentration1_prior=LogNormalPrior(0.0, 0.75**0.5),
                concentration0_prior=LogNormalPrior(0.0, 0.75**0.5),
            )
        model = SingleTaskGP(train_x,
                             train_y,
                             covar_module=covar_module,
                             likelihood=likelihood,
                             input_transform=input_warp_tf,
                             outcome_transform=outcome_transform)
        print("Loading GP from file")
        state_dict = torch.load(gp_file)
        model.load_state_dict(state_dict)

    print("GP regression")
    start_time = time.time()
    model.to(**tkwargs)
    model.train()

    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    # set approx_mll to False since we are using an exact marginal log likelihood
    # fit_gpytorch_model(mll, optimizer=fit_gpytorch_torch, approx_mll=False, options=options)
    fit_gpytorch_torch(mll,
                       options=options,
                       approx_mll=False,
                       clip_by_value=True if input_wp else False,
                       clip_value=10.0)
    end_time = time.time()
    print(f"Regression took {end_time - start_time:.1f}s to finish...")

    print("Save GP model...")
    torch.save(model.state_dict(), save_file)
    print("Done training of GP.")

    model.eval()
    return model
Example #29
def exp_dec_error_pytorch_2(x, t=2):
    sq_pdist = torch.pdist(x, p=2).pow(2)
    return sq_pdist.mul(-t).exp().mean()
Example #30
 def lunif(self, x, t=2):
     sq_dist = torch.pdist(x, p=2).pow(2)
     return sq_dist.mul(-t).exp().mean().log()