Beispiel #1
0
def lstsq(b, y, alpha=0.01):
    """
    Batched linear least-squares for pytorch with optional L2 (ridge) regularization.

    Solves ``(b^T b + alpha * I) x = b^T y`` independently for every batch
    element.

    Parameters
    ----------

    b : shape(L, M, N)
        Batch of design matrices.
    y : shape(L, M)
        Batch of targets.
    alpha : float
        Ridge penalty added to the diagonal of ``b^T b``; ``0`` disables
        regularization.

    Returns
    -------
    tuple of (coefficients, model, residuals)
        ``coefficients`` has shape (L, N); ``model`` (the fitted values) and
        ``residuals`` (``y - model``) have shape (L, M).

    """
    bT = b.transpose(-1, -2)
    AA = torch.bmm(bT, b)
    if alpha != 0:
        # Out-of-place add: avoids mutating a diagonal view of an autograd result.
        eye = torch.eye(AA.shape[-1], dtype=AA.dtype, device=AA.device)
        AA = AA + alpha * eye
    RHS = torch.bmm(bT, y[:, :, None])
    # torch.gesv was removed from PyTorch; torch.linalg.solve(A, B) solves A X = B.
    X = torch.linalg.solve(AA, RHS)
    fit = torch.bmm(b, X)[..., 0]
    res = y - fit
    return X[..., 0], fit, res
Beispiel #2
0
    def loss_function(self, *args, **kwargs) -> dict:
        r"""
        Computes the DIP-VAE loss function.

        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}

        The DIP regulariser pushes the covariance of the aggregate posterior
        towards the identity, using (DIP-II)
        ``Cov[z] = Cov_batch[mu] + diag(E_batch[exp(log_var)])``.

        :param args: ``(recons, input, mu, log_var)``; ``mu`` and ``log_var``
            are indexed as ``[B x D]`` below.
        :param kwargs: unused (the minibatch weight ``M_N`` is currently
            ignored, see ``kld_weight``).
        :return: dict with ``'loss'``, ``'Reconstruction_Loss'``, ``'KLD'``
            (reported as the negated KL term) and ``'DIP_Loss'``.
        """
        recons = args[0]
        x = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = 1  # * kwargs['M_N'] # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, x, reduction='sum')

        kld_loss = torch.sum(
            -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1),
            dim=0)

        # DIP Loss: centre mu over the batch (dim 0), not over the latent
        # dims, and divide by B so the product is a [D x D] covariance.
        batch_size = mu.size(0)
        centered_mu = mu - mu.mean(dim=0, keepdim=True)  # [B x D]
        cov_mu = centered_mu.t().matmul(centered_mu) / batch_size  # [D X D]

        # Add Variance for DIP Loss II: E_batch[Sigma] is diagonal and holds
        # the batch-averaged variances exp(log_var) = sigma^2.
        cov_z = cov_mu + torch.diag_embed(log_var.exp().mean(dim=0))  # [D x D]
        # For DIP Loss I use: cov_z = cov_mu

        cov_diag = torch.diag(cov_z)  # [D]
        cov_offdiag = cov_z - torch.diag(cov_diag)  # [D x D]
        dip_loss = self.lambda_offdiag * torch.sum(cov_offdiag ** 2) + \
                   self.lambda_diag * torch.sum((cov_diag - 1) ** 2)

        loss = recons_loss + kld_weight * kld_loss + dip_loss
        return {
            'loss': loss,
            'Reconstruction_Loss': recons_loss,
            'KLD': -kld_loss,
            'DIP_Loss': dip_loss
        }
    def forward(self, z_a, z_b):
        """Forward head.

        Computes the Barlow Twins loss from the empirical cross-correlation
        matrix of the batch-normalised representations, summed across
        processes when ``torch.distributed`` is initialised.

        Args:
            z_a (Tensor): NxD representation from one randomly augmented image.
            z_b (Tensor): NxD representation from another version of augmentation.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        N, _ = z_a.shape

        distributed = (torch.distributed.is_available()
                       and torch.distributed.is_initialized())
        world_size = torch.distributed.get_world_size() if distributed else 1

        # empirical cross-correlation matrix (DxD)
        c = self.bn(z_a).T @ self.bn(z_b)
        # Normalise by the *global* batch size: after all_reduce, c is the sum
        # of the per-process matrices, so dividing by the local N alone would
        # inflate every correlation by a factor of world_size.
        c.div_(N * world_size)
        if distributed:
            # sum the cross-correlation matrix between all gpus
            torch.distributed.all_reduce(c)
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = off_diagonal(c).pow_(2).sum()
        loss = on_diag + self.lambd * off_diag

        losses = dict()
        losses["loss"] = loss
        return losses
Beispiel #4
0
 def forward(self, X):
     """Run the recurrent update over every time step of ``X``.

     At each step the first ``visible_size`` columns of the input state are
     set to the current observation, mixed with ``self.W(R)`` through the
     gates in ``self.Lambda`` and passed through ``self.phi``.

     Returns the diagonal of the visible block of the last pre-activation
     ``U``. ``X`` is indexed as (n_batches, n_time_steps, features); the
     reshape below requires ``features == visible_size``.
     """
     dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     n_batch, n_steps = X.shape[0], X.shape[1]
     vis, full = self.visible_size, self.full_size
     state_shape = (n_batch, vis, full)

     R = torch.zeros(state_shape, dtype=torch.float32).to(dev)
     Y = torch.zeros(state_shape, dtype=torch.float32).to(dev)
     gates = self.Lambda.repeat(n_batch, 1).view(state_shape).to(dev)

     for step in range(n_steps):
         # Every one of the ``vis`` rows of a sample receives the same observation.
         obs = X[:, step, :]
         tiled = obs.repeat(vis, 1).view(vis, n_batch, vis).transpose(0, 1)
         Y[:, :, :vis] = tiled
         U = torch.mul(gates, self.W(R)) + torch.mul((1 - gates), Y)
         R = self.phi(U)
     return torch.diagonal(U[:, :, :vis], dim1=-1, dim2=-2)
Beispiel #5
0
    def forward(self, x):
        """Compute five max-based pooling features per channel.

        in: N x d x m x m
        out: N x (5 * d) x m, ordered operator-major (all d channels of the
        diagonal, then of the diagonal max, row max, column max, global max).
        """
        n_batch = x.size(0)
        width = x.size(-1)

        diag = torch.diagonal(x, dim1=2, dim2=3)   # N x d x m
        row_max = x.max(dim=3)[0]                  # N x d x m
        col_max = x.max(dim=2)[0]                  # N x d x m
        diag_max = diag.max(dim=2)[0].unsqueeze(-1).expand_as(diag)
        # Global max: reduce the column maxima over the remaining axis.
        global_max = col_max.max(dim=2)[0].unsqueeze(-1).expand_as(diag)

        features = [diag, diag_max, row_max, col_max, global_max]
        stacked = torch.stack(features, dim=1)     # N x 5 x d x m
        return stacked.reshape(n_batch, -1, width)
Beispiel #6
0
    def forward(self, y1, y2):
        """Barlow Twins loss for two augmented views ``y1`` and ``y2``.

        Both views go through the shared backbone and projector. The loss
        penalises how far the cross-correlation matrix of the
        batch-normalised embeddings is from the identity.
        """
        emb_a = self.projector(self.backbone(y1))
        emb_b = self.projector(self.backbone(y2))

        # Cross-correlation matrix, normalised by the global batch size
        # (per-device size times number of processes) and summed across
        # processes.
        corr = self.bn(emb_a).T @ self.bn(emb_b)
        global_batch = self.per_device_batch_size * self.trainer.num_processes
        corr.div_(global_batch)
        self.all_reduce(corr)

        # --scale-loss: the reference Barlow Twins code multiplies both terms
        # by a constant factor; reproduced here for parity.
        scale = self.scale_loss
        on_diag = torch.diagonal(corr).add_(-1).pow_(2).sum().mul(scale)
        off_diag = self.off_diagonal(corr).pow_(2).sum().mul(scale)
        return on_diag + self.lambd * off_diag
Beispiel #7
0
def train(net, data_loader, train_optimizer):
    """Run one Barlow Twins training epoch and return the mean loss per sample.

    The function reads names that are not defined in this block:
    ``batch_size``, ``corr_neg_one``, ``lmbda``, ``epoch``, ``epochs``,
    ``feature_dim``, ``dataset``, the ``off_diagonal`` helper and ``tqdm``.

    Args:
        net: model returning ``(feature, projection)`` for a batch.
        data_loader: yields ``((view_1, view_2), target)`` tuples.
        train_optimizer: optimizer stepping ``net``'s parameters.

    Returns:
        float: the loss averaged over all samples seen in the epoch.
    """
    net.train()
    # Decide once which off-diagonal target is used, so the loss and the
    # progress-bar label cannot disagree: -1 (HSIC-inspired) or 0 (paper).
    use_neg_one = bool(corr_neg_one)
    off_corr = -1 if use_neg_one else 0
    total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)
    for data_tuple in train_bar:
        (pos_1, pos_2), _ = data_tuple
        pos_1 = pos_1.cuda(non_blocking=True)
        pos_2 = pos_2.cuda(non_blocking=True)
        _, out_1 = net(pos_1)
        _, out_2 = net(pos_2)
        # The last batch may be smaller than ``batch_size`` when the loader
        # keeps it, so normalise and count with the actual batch size.
        cur_bsz = out_1.size(0)

        # Barlow Twins
        # normalize the representations along the batch dimension
        out_1_norm = (out_1 - out_1.mean(dim=0)) / out_1.std(dim=0)
        out_2_norm = (out_2 - out_2.mean(dim=0)) / out_2.std(dim=0)

        # cross-correlation matrix
        c = torch.matmul(out_1_norm.T, out_2_norm) / cur_bsz

        # loss
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        if use_neg_one:
            # inspired by HSIC
            # encouraging off_diag to be negative ones
            off_diag = off_diagonal(c).add_(1).pow_(2).sum()
        else:
            # the loss described in the original Barlow Twin's paper
            # encouraging off_diag to be zero
            off_diag = off_diagonal(c).pow_(2).sum()
        loss = on_diag + lmbda * off_diag

        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        total_num += cur_bsz
        total_loss += loss.item() * cur_bsz
        train_bar.set_description(
            'Train Epoch: [{}/{}] Loss: {:.4f} off_corr:{} lmbda:{:.4f} bsz:{} f_dim:{} dataset: {}'.format(
                epoch, epochs, total_loss / total_num, off_corr, lmbda,
                batch_size, feature_dim, dataset))
    return total_loss / total_num
Beispiel #8
0
def contrastive_loss(z1, z2, temp=1):
    """NT-Xent style contrastive loss between two batches of embeddings.

    For anchor ``z1[i]`` the positive is ``z2[i]``. The negatives are every
    other row of ``z1`` and every other row of ``z2``. Logits are cosine
    similarities divided by ``temp``.

    Args:
        z1: (B, D) embeddings of the first view.
        z2: (B, D) embeddings of the second view.
        temp: softmax temperature (> 0).

    Returns:
        Scalar tensor: mean loss over the B anchors.
    """
    # get unit vectors
    z1_unit = z1 / torch.norm(z1, p=2, dim=1).view(-1, 1)
    z2_unit = z2 / torch.norm(z2, p=2, dim=1).view(-1, 1)

    # z_i . z_j for all i, j in z1, and z_i . z_k for all i in z1, k in z2
    intra_cos_sims = z1_unit @ z1_unit.T
    inter_cos_sims = z1_unit @ z2_unit.T

    # Drop the trivial z_i . z_i term by masking it out. Subtracting e from
    # the sum, as before, is only right for temp == 1 and an exact 1.0
    # self-similarity.
    self_mask = torch.eye(intra_cos_sims.size(0), dtype=torch.bool,
                          device=intra_cos_sims.device)
    intra_cos_sims = intra_cos_sims.masked_fill(self_mask, float('-inf'))
    logits = torch.cat((intra_cos_sims, inter_cos_sims), dim=1) / temp

    # -log(exp(pos) / sum(exp(logits))) via logsumexp. The positive logit uses
    # the same temperature as the denominator, and nothing can overflow.
    pos_logits = torch.diagonal(inter_cos_sims) / temp
    losses = torch.logsumexp(logits, dim=1) - pos_logits

    return torch.mean(losses)
Beispiel #9
0
 def forward(self, bert_token_b, bert_segment_b, bert_masks_b,
             bert_clause_b, doc_len, adj, y_mask_b):
     """Encode a batch with BERT and predict clause-level and pair-level outputs.

     Clause representations are gathered from the BERT output, reduced and
     expanded into a pairwise (B, N, N, H) table, which ``self.mdrnn``
     processes under the mask built from ``y_mask_b``.

     Returns ``(couples_pred, pred_e, pred_c)``: ``self.tab_pred`` applied to
     the whole table, and the two outputs of ``self.pred`` applied to the
     table's diagonal. ``doc_len`` and ``adj`` are accepted but not used here.
     """
     tokens = bert_token_b.to(DEVICE)
     attention = bert_masks_b.to(DEVICE)
     segments = bert_segment_b.to(DEVICE)
     bert_output = self.bert(input_ids=tokens,
                             attention_mask=attention,
                             token_type_ids=segments)

     clause_idx = bert_clause_b.to(DEVICE)
     clause_h = self.reduce(self.batched_index_select(bert_output, clause_idx))
     pair_table = F.relu(self.seq2mat(clause_h, clause_h), inplace=True)  # (B, N, N, H)

     clause_mask = torch.IntTensor(y_mask_b).to(DEVICE)
     pair_mask = clause_mask.unsqueeze(1) & clause_mask.unsqueeze(2)
     table, _ = self.mdrnn(pair_table, states=None, masks=pair_mask)

     diag_feats = torch.diagonal(table, offset=0, dim1=1, dim2=2).permute(0, 2, 1)
     pred_e, pred_c = self.pred(diag_feats)
     couples_pred = self.tab_pred(table)
     return couples_pred, pred_e, pred_c
    def backward(ctx, dK):
        r"""
        Backward function from the affinity matrix :math:`\mathbf K` to the node-wise affinity matrix
        :math:`\mathbf K_p` and the edge-wise affinity matrix :math:`\mathbf K_e`.

        Returns six gradients: only the first two (for ``Ke`` and ``Kp``) may be non-``None``, and each is
        computed only when ``ctx.needs_input_grad`` requests it.
        """
        Ke, Kp = ctx.saved_tensors
        Kro1t, Kro2t = ctx.K
        dKe = dKp = None
        if ctx.needs_input_grad[0]:
            dKe = bilinear_diag_torch(Kro1t, dK.contiguous(), Kro2t)
            # The flat gradient is read as (B, cols, rows) and transposed,
            # i.e. treated as a column-major flattening of each Ke matrix.
            dKe = dKe.view(Ke.shape[0], Ke.shape[2],
                           Ke.shape[1]).transpose(1, 2)
        if ctx.needs_input_grad[1]:
            # The node-wise gradient is read off the diagonal of dK.
            dKp = torch.diagonal(dK, dim1=-2, dim2=-1)
            dKp = dKp.view(Kp.shape[0], Kp.shape[2],
                           Kp.shape[1]).transpose(1, 2)

        return dKe, dKp, None, None, None, None
Beispiel #11
0
    def test(self, validation=False):
        """Evaluate ``netF``/``netC``, printing overall and per-class accuracy.

        Args:
            validation: evaluate on ``self.source_val_loader`` when True,
                otherwise on ``self.target_loader``.

        Returns:
            tuple: ``(test_loss, mean_class_acc, net_class_acc)``. These are:
            the sum over batches of ``criterion / len(loader)``; the mean
            per-class accuracy in percent, over classes present in the loader;
            and the overall accuracy in percent.
        """
        self.netF.eval()
        self.netC.eval()

        test_loss = 0
        correct = 0
        size = 0
        num_class = self.nclasses
        confusion_matrix = torch.zeros(num_class, num_class)

        loader = self.source_val_loader if validation else self.target_loader

        with torch.no_grad():
            for imgs, labels in loader:
                imgs = imgs.cuda()
                labels = labels.cuda()

                logits = self.netC(self.netF(imgs))
                size += imgs.size(0)
                pred = logits.max(1)[1]  # get the index of the max log-probability
                # Vectorised confusion-matrix update (rows: true, cols: predicted)
                # instead of one Python-level increment per sample.
                flat_idx = labels.view(-1).long() * num_class + pred.view(-1).long()
                counts = torch.bincount(flat_idx.cpu(), minlength=num_class * num_class)
                confusion_matrix += counts.view(num_class, num_class).to(confusion_matrix.dtype)
                correct += pred.eq(labels).cpu().sum()
                test_loss += self.criterion(logits, labels) / len(loader)
        print(
            '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} C ({:.0f}%)\n'.
            format(test_loss, correct, size, 100. * (float(correct) / size)))
        support = torch.sum(confusion_matrix, dim=1)
        class_acc = torch.diagonal(confusion_matrix).float() / support * 100.0
        print('Classwise accuracy')
        print(class_acc)
        # A class absent from the loader has 0/0 = NaN accuracy; exclude it
        # from the mean instead of letting a single NaN poison the metric.
        mean_class_acc = torch.mean(class_acc[support > 0])
        net_class_acc = 100. * float(correct) / size
        return test_loss.data, mean_class_acc, net_class_acc
Beispiel #12
0
def banddiag(orig_x, lu, ld, fill=0):
    """Extract the band around the diagonal of the last two dims of ``orig_x``.

    ``orig_x`` (rows R, cols C in its last two dims) is padded along dim -2
    with ``lu`` rows above and ``ld`` rows below, all equal to ``fill``.

    Returns:
        tuple: ``(band, x)``.
            ``band`` has shape ``(..., min(R, C), lu + ld + 1)``, with
            ``band[..., i, k] == padded[..., i + k, i]``; ``k == lu`` is the
            main diagonal.
            ``x`` is a view of the padded tensor restricted to the original
            rows, so its values equal ``orig_x``.
    """
    pad_top = list(orig_x.shape)
    pad_bottom = list(orig_x.shape)
    pad_top[-2] = lu
    pad_bottom[-2] = ld
    opts = dict(device=orig_x.device, dtype=orig_x.dtype)
    # ``fill`` used to be ignored (the padding was always zeros).
    x = torch.cat(
        [
            torch.full(pad_top, fill, **opts),
            orig_x,
            torch.full(pad_bottom, fill, **opts),
        ],
        dim=-2,
    )
    unf = x.unfold(-2, lu + ld + 1, 1)
    return (
        torch.diagonal(unf, 0, -3, -2).transpose(-2, -1),
        x.narrow(-2, lu, orig_x.shape[-2]),
    )
Beispiel #13
0
    def forward(self, x, labels):
        '''
        Additive-margin softmax loss on cosine logits.

        The inputs and the rows of ``self.fc.weight`` are both L2-normalised,
        so ``wf[i, j]`` is the cosine between sample ``i`` and class ``j``.

        input shape (N, in_features); labels shape (N,) with integer values
        in ``[0, out_features)``. Returns the scalar mean loss.
        '''
        assert len(x) == len(labels)
        assert torch.min(labels) >= 0
        assert torch.max(labels) < self.out_features

        # Rebinding the loop variable in ``for W in self.fc.parameters()``
        # never normalised the weights; normalise them explicitly.
        # Assumes self.fc is a bias-free nn.Linear: the old loop would have
        # failed on a 1-D bias. TODO confirm.
        W = F.normalize(self.fc.weight, dim=1)
        x = F.normalize(x, dim=1)

        wf = F.linear(x, W)
        target = wf.gather(1, labels.view(-1, 1)).squeeze(1)
        numerator = self.s * (target - self.m)
        # Non-target logits of every row, in their original column order.
        is_target = F.one_hot(labels, num_classes=wf.size(1)).bool()
        excl = wf[~is_target].view(wf.size(0), -1)
        denominator = torch.exp(numerator) + torch.sum(torch.exp(self.s * excl), dim=1)
        L = numerator - torch.log(denominator)
        return -torch.mean(L)
Beispiel #14
0
    def test_to_hermitian(self):
        """Check that ``sm.to_hermitian`` produces a Hermitian matrix."""
        source = torch.rand(4, 2, 3, 3)
        herm = sm.to_hermitian(source)
        re_part = sm.real(herm)
        im_part = sm.imag(herm)

        # 1 - the real part equals its own transpose (symmetric)
        self.assertAllEqual(re_part, re_part.transpose(-1, -2))

        # 2 - the imaginary part has an all-zero diagonal
        im_diag = torch.diagonal(im_part, dim1=-2, dim2=-1)
        self.assertAllEqual(im_diag, torch.zeros_like(im_diag))

        # 3 - strictly-upper imaginary entries are the negated mirror image
        # of the strictly-lower ones
        upper = torch.triu(im_part, diagonal=1)
        lower = torch.tril(im_part, diagonal=-1)
        self.assertAllEqual(upper, lower.transpose(-1, -2) * -1)
Beispiel #15
0
        def keep(state1, state2):
            """Return a per-row boolean mask of samples to keep.

            A row is kept when:
              * column 0 of both states is finite;
              * in either state it is within ``radius`` of another row;
              * the acceleration criterion below holds.
            ``radius`` and ``acc_abs`` come from the enclosing scope.
            """
            finite_both = torch.isfinite(state1[:, 0]) & torch.isfinite(state2[:, 0])

            dist1 = PedPedPotential.norm_r_ab(PedPedPotential.r_ab(state1))
            dist2 = PedPedPotential.norm_r_ab(PedPedPotential.r_ab(state2))
            close = (dist1 < radius) | (dist2 < radius)
            # A row is never "close" to itself.
            torch.diagonal(close)[:] = False
            close = close.any(dim=-1)

            accelerating = (state1[:, 4:6].abs() > acc_abs) | (state2[:, 4:6].abs() > acc_abs)
            accelerating = accelerating.any(dim=-1)
            # keep 10% of samples without acc:
            idle = ~accelerating
            accelerating[idle] = torch.rand(accelerating[idle].shape) < 0.1
            # symmetrize: every row gets the same value, True if any row qualifies
            accelerating[:] = accelerating.any()

            return finite_both & close & accelerating
Beispiel #16
0
    def decode(self, spin_up, spin_down, num_qubits):
        """Contract the MPS matrices into one coefficient per basis state.

        For every state string produced by ``seq_gen(num_qubits)``, the
        matrices selected by its characters are multiplied left to right, and
        the trace of the product is taken.

        Args:
            spin_up: matrix used for character ``'0'``; per the original note,
                shape (1, mps.size, mps.size).
            spin_down: matrix used for character ``'1'``, same shape.
            num_qubits: passed to ``seq_gen`` to enumerate the states.

        Returns:
            Tensor of all coefficients concatenated along dim 1, then squeezed.
        """
        mps = {'0': spin_up, '1': spin_down}

        coeffs = []
        for state in seq_gen(num_qubits):
            mat = mps[state[0]]
            for site in state[1:]:
                mat = torch.matmul(mat, mps[site])
            diagonal = torch.diagonal(mat, dim1=-1, dim2=-2)
            coeffs.append(torch.sum(diagonal, dim=-1, keepdim=True))

        # Concatenate once. Growing the tensor inside the loop re-copied every
        # earlier coefficient on each iteration (quadratic work).
        return torch.cat(coeffs, dim=1).squeeze()
Beispiel #17
0
 def forward(self, x):
     """Detect boundaries from the diagonal of ``self.di_pool(x)``.

     Returns ``(iv, dv, boundaries)``:
       * ``iv``: log2 of the pooled diagonal divided by its global mean;
       * ``dv``: ``top_pool`` of ``iv`` shifted by ``deriv_size``, minus
         ``bottom_pool`` of the unshifted ``iv``;
       * ``boundaries``: one tensor per batch item (from channel 0), holding
         every index ``j`` with ``dv[j] < 0`` and ``dv[j + 2] > 0``, offset by
         ``window_radius + deriv_size``.
     """
     iv = self.di_pool(x)
     iv = torch.diagonal(iv, dim1=2, dim2=3)
     iv = torch.log2(iv / torch.mean(iv))
     top = self.top_pool(iv[:, :, self.deriv_size:])
     bottom = self.bottom_pool(iv[:, :, :-self.deriv_size])
     dv = (top - bottom)
     # Two-step zero padding on dv's own device and dtype. A bare torch.zeros
     # always lands on the CPU with the default dtype, which breaks GPU inputs.
     pad = dv.new_zeros(dv.shape[0], dv.shape[1], 2)
     left = torch.cat([pad, dv], dim=2)
     right = torch.cat([dv, pad], dim=2)
     # dv falls (negative) and, two steps later, rises (positive).
     band = (left < 0) & (right > 0)
     band = band[:, :, 2:-2]
     boundaries = []
     for i in range(band.shape[0]):
         cur_bound = torch.where(
             band[i, 0])[0] + self.window_radius + self.deriv_size
         boundaries.append(cur_bound)
     return iv, dv, boundaries
 def forward(self, cov: Tensor) -> Tensor:
     """Rescale ``cov`` by the outer product of per-dimension gains.

     ``cov`` is multiplied elementwise by ``w w^T``, with
     ``w = weight / sqrt(mean_var + eps)``; ``weight`` is 1 when
     ``self.weight`` is None. In training, ``mean_var`` is the mean over
     dim 0 of the diagonal of ``cov`` and is passed to ``self.track_mean``.
     In eval, ``self.running_mean`` is used instead. When ``self.bias_var``
     is set, ``var2cov(bias_var)`` is added to the result.
     """
     variances = torch.diagonal(cov, dim1=-1, dim2=-2)
     if self.training:
         stat = torch.mean(variances, dim=0, keepdim=True)
         self.track_mean(stat)
     else:
         stat = self.running_mean
     gain = 1 if self.weight is None else self.weight
     weight = gain / torch.sqrt(stat + self.eps)
     scale = torch.matmul(weight.unsqueeze(-1), weight.unsqueeze(-2))
     cov = cov * scale
     if self.bias_var is not None:
         cov = cov + mnn_core.nn.functional.var2cov(self.bias_var)
     return cov
Beispiel #19
0
    def update(self, y_pred: torch.Tensor, cat_c: torch.Tensor):
        """Accumulate the intra-list similarity (ILS) of one ranked list.

        Items are ranked by ``y_pred`` in descending order and, when ``self.k``
        is set, truncated to the top ``k``. The ILS is the mean pairwise
        cosine similarity between the ranked ``cat_c`` rows, excluding
        self-pairs. It is added to ``self.ils_cat``, and ``self.count`` is
        incremented.

        A list with fewer than two items has no pairs, so its ILS would be
        0/0. Such lists are skipped rather than adding NaN to the running
        total.
        """
        if self.k is None:
            order = torch.argsort(input=y_pred, descending=True)
        else:
            k = min(self.k, y_pred.shape[0])
            _, order = torch.topk(input=y_pred, k=k, largest=True)

        cat_c = cat_c[order]
        n_items = cat_c.size(0)
        if n_items < 2:
            return
        unit = F.normalize(cat_c, dim=-1)
        sim = unit @ unit.T
        pair_sum = torch.sum(sim) - torch.sum(torch.diagonal(sim, 0))
        self.ils_cat += pair_sum / (n_items * (n_items - 1))
        self.count += 1.0
Beispiel #20
0
    def forward(ctx, input):
        """Determinant of square matrices via LU decomposition with pivoting.

        ``det(A) = (-1)**s * prod(diag(U))``, where ``s`` is the number of
        row swaps recorded in the LU pivots. ``input`` and the result are
        saved for the backward pass.

        Args:
            input: tensor of shape (..., n, n).

        Returns:
            Tensor of shape (...) holding the determinants.
        """
        # LUP decompose the matrices. The packed factor stores U on and above
        # its diagonal, so diag(U) can be read from it without lu_unpack.
        inp_lu, pivots = input.lu()

        # Number of row swaps: every pivot that differs from its own (1-based)
        # row index is one swap. The reference index tensor is built on the
        # pivots' device so that CUDA inputs work.
        n = input.shape[-1]
        identity_pivots = torch.arange(1, n + 1, device=pivots.device,
                                       dtype=pivots.dtype)
        s = (pivots != identity_pivots).sum(-1).type(
            torch.get_default_dtype())

        # get the prod of the diag of U
        d = torch.diagonal(inp_lu, dim1=-2, dim2=-1).prod(-1)

        # assemble
        det = ((-1)**s * d)
        ctx.save_for_backward(input, det)

        return det
Beispiel #21
0
def _F1F2(mean: Tensor, cov: Tensor, lower: Tensor,
          upper: Tensor) -> Tuple[Tensor, Tensor]:
    """Compute the auxiliary terms F1 and F2 for a (possibly) censored Gaussian.

    Bounds are standardised as ``alpha = (lower - mean) / std`` and
    ``beta = (upper - mean) / std``, with ``std = sqrt(diag(cov))``. A
    non-finite bound means that side is not censored. Each element then
    falls into one of four cases: censored on both sides, lower only, upper
    only, or neither. These use ``_F1F2_no_inf`` and ``erfcx``.

    A finite bound with a z-score beyond +/-4 (``beta >= 4`` or
    ``alpha <= -4``) is treated as uncensored on that side. The original
    comment below says this is because ``_F1F2_no_inf`` is unstable there.

    Args:
        mean: per-element means, same shape as ``lower``/``upper``
            (presumably (..., n) -- TODO confirm against callers).
        cov: covariance matrices; only their diagonal (last two dims) is used.
        lower: lower bounds (non-finite for "no lower bound").
        upper: upper bounds (non-finite for "no upper bound").

    Returns:
        ``(F1, F2)``, tensors shaped like ``mean``. Both are zero for
        elements treated as uncensored on both sides.
    """
    # "Censored" here means the bound on that side is finite.
    is_cens_up = torch.isfinite(upper)
    is_cens_lo = torch.isfinite(lower)

    std = torch.diagonal(cov, dim1=-2, dim2=-1).sqrt()

    # mask out the infs before any gradients are being tracked:
    # (only finite-bound entries are written, so inf never enters the arithmetic)
    alpha = torch.zeros_like(mean)
    alpha[is_cens_lo] = (lower[is_cens_lo] -
                         mean[is_cens_lo]) / std[is_cens_lo]
    beta = torch.zeros_like(mean)
    beta[is_cens_up] = (upper[is_cens_up] - mean[is_cens_up]) / std[is_cens_up]

    # _F1F2_no_inf unstable for large z-scores, so use the lim(+/-inf) version for those as well
    is_cens_up = is_cens_up & (beta.data < 4.)
    is_cens_lo = is_cens_lo & (alpha.data > -4.)
    is_cens_both = is_cens_up & is_cens_lo

    # z-scores divided by sqrt(2): the form passed to erfcx / _F1F2_no_inf below
    sqrt_2 = 2.**.5
    x = alpha / sqrt_2
    y = beta / sqrt_2

    # uncensored: F1 = F2 = 0 (these entries are never overwritten below)
    F1, F2 = torch.zeros_like(mean), torch.zeros_like(mean)

    # censored both:
    F1[is_cens_both], F2[is_cens_both] = _F1F2_no_inf(x[is_cens_both],
                                                      y[is_cens_both])

    # censored lower, uncensored upper:
    F1[is_cens_lo & ~is_cens_up] = 1. / erfcx(x[is_cens_lo & ~is_cens_up])
    F2[is_cens_lo & ~is_cens_up] = x[is_cens_lo & ~is_cens_up] / erfcx(
        x[is_cens_lo & ~is_cens_up])

    # uncensored lower, censored upper:
    F1[~is_cens_lo & is_cens_up] = -1. / erfcx(-y[~is_cens_lo & is_cens_up])
    F2[~is_cens_lo
       & is_cens_up] = -y[~is_cens_lo & is_cens_up] / erfcx(-y[~is_cens_lo
                                                               & is_cens_up])

    return F1, F2
Beispiel #22
0
def _sampson_dist(F, X, Y, if_homo=False):
    if not if_homo:
        X = utils_misc._homo(X)
        Y = utils_misc._homo(Y)
    if len(X.size()) == 2:
        nominator = (torch.diag(Y @ F @ X.t()))**2
        Fx1 = torch.mm(F, X.t())
        Fx2 = torch.mm(F.t(), Y.t())
        denom = Fx1[0]**2 + Fx1[1]**2 + Fx2[0]**2 + Fx2[1]**2
    else:
        nominator = (torch.diagonal(Y @ F @ X.transpose(1, 2), dim1=1,
                                    dim2=2))**2
        Fx1 = torch.matmul(F, X.transpose(1, 2))
        Fx2 = torch.matmul(F.transpose(1, 2), Y.transpose(1, 2))
        denom = Fx1[:, 0]**2 + Fx1[:, 1]**2 + Fx2[:, 0]**2 + Fx2[:, 1]**2
        # print(nominator.size(), denom.size())

    errors = nominator / denom
    return errors
    def re_loss_direct(self, x):
        """Reconstruction loss evaluated in closed form.

        Sum of three parts:
          * ``log_term``: ``D/2 * (log(2*pi) + 2*log_s)``;
          * ``norm_term``: batch mean of ``||x (W M - I)||^2``, divided by ``2 s^2``;
          * ``trace_term``: ``sum_j (W^T W)_jj * S_j``, divided by ``2 s^2``.
        Here ``W = decoder.W.weight``, ``M = encoder.M.weight``,
        ``s = exp(decoder.log_s)`` and ``S = exp(encoder.log_S)``.
        """
        log_term = (self.D/2) * (np.log(2*np.pi) + 2*self.decoder.log_s)
        s = torch.exp(self.decoder.log_s)

        W = self.decoder.W.weight
        M = self.encoder.M.weight
        # Identity on the weights' device/dtype, so the loss also works on
        # GPU or in float64 (torch.eye defaults to CPU/float32).
        eye = torch.eye(self.D, device=W.device, dtype=W.dtype)
        x_mat = torch.matmul(W, M) - eye
        x_vect = torch.matmul(x, x_mat)
        x_vect_sq_sum = torch.sum(x_vect * x_vect, dim=-1)
        norm_term = torch.mean(x_vect_sq_sum)/(2*s*s)

        S = torch.exp(self.encoder.log_S)
        # diag(W^T W) equals the column sums of W**2; no need for the full product.
        trace_diag = (W * W).sum(dim=0).reshape(S.shape)
        trace_term = torch.sum(trace_diag * S)/(2*s*s)

        return log_term + norm_term + trace_term
Beispiel #24
0
    def _get_constraints(self, y_train_labeled, n, augment):
        """
        Get constraints for the equivalence matrix based on the labeled data.

        Labeled examples occupy the first ``len(y_train_labeled)`` rows and
        columns of the batch. The remaining examples are handled in
        consecutive groups of ``augment`` entries.

        :param y_train_labeled: Labels for the labeled subset of the batch, or None
        :param n: Batch size (labeled + unlabeled observations)
        :param augment: Number of augmented copies of each input example to use. If 0 then no augmentation is performed
        :return: mask: Binary matrix with value 0 in entry (i,j) if it is known whether i and j belong to the same class
                       and 1 else
        :return: known: Binary matrix with value 1 in entry (i,j) if it is known that i and j belong to the same class
                        and 0 else
        """
        mask = (torch.BoolTensor(n, n).zero_() + 1).to(defaults.device)
        known = torch.zeros(n, n).to(defaults.device)
        n_lab = 0
        if y_train_labeled is not None:
            n_lab = len(y_train_labeled)
            onehot = opt_utils.one_hot_embedding(
                y_train_labeled, self.params.nclasses)
            # Labeled pairs: same class iff their one-hot rows match.
            known[0:n_lab, 0:n_lab] = onehot.mm(onehot.t())
            mask[0:n_lab, 0:n_lab] = 0
        # Every example is trivially equivalent to itself.
        torch.diagonal(known).fill_(1)
        torch.diagonal(mask).fill_(0)

        if augment > 0:
            n_unlab = n - n_lab
            for group in range(int(n_unlab // augment)):
                lo = n_lab + group * augment
                hi = lo + augment
                # Entries within one group share a class.
                known[lo:hi, lo:hi] = 1
                mask[lo:hi, lo:hi] = 0

        return mask, known
Beispiel #25
0
    def _forward_equivariance(self, src, tgt, equi_src, equi_tgt, T):
        """Invariance loss plus a metric-learning loss on equivariant features.

        ``equi_tgt`` is resampled with ``self._interpolate(..., T)``. Both
        equivariant features are flattened per sample and compared with
        ``pairwise_distance_matrix``. For row ``i``, the positive distance is
        the diagonal entry and the negative comes from
        ``batch_hard_negative_mining``.

        Returns:
            ``(total_loss, inv_info, equi_info)``:
              * ``total_loss = inv_loss + self.alpha * equi_loss``;
              * ``inv_info = [inv_loss, acc, fp, cn]`` from
                ``self._forward_invariance``;
              * ``equi_info = [equi_loss, accuracy, mean positive distance,
                mean negative distance]``.
        """
        inv_loss, acc, fp, cn = self._forward_invariance(src, tgt)

        # equi feature: nb, nc, na
        n_batch = src.size(0)

        # so3 interpolation of the target features
        tgt_flat = self._interpolate(equi_tgt, T,
                                     sigma=self.sigma).view(n_batch, -1)
        src_flat = equi_src.view(n_batch, -1)

        dists = pairwise_distance_matrix(src_flat, tgt_flat)
        furthest_positive = torch.diagonal(dists)
        closest_negative = batch_hard_negative_mining(dists)

        gap = furthest_positive - closest_negative
        if self.loss == 'hard':
            per_sample = F.relu(gap + self.margin)
        elif self.loss == 'soft':
            per_sample = F.softplus(gap, beta=self.margin)
        elif self.loss == 'contrastive':
            per_sample = furthest_positive + F.relu(self.margin - closest_negative)
        else:
            per_sample = gap

        # evaluate accuracy
        _, nearest = torch.topk(dists, k=self.k_precision, dim=1, largest=False)
        accuracy = torch.sum(nearest.int() == self.gt_idx).float() / float(n_batch)

        equi_loss = per_sample.mean()
        total_loss = inv_loss + self.alpha * equi_loss
        inv_info = [inv_loss, acc, fp, cn]
        equi_info = [
            equi_loss, accuracy,
            furthest_positive.mean(),
            closest_negative.mean()
        ]
        return total_loss, inv_info, equi_info
Beispiel #26
0
def KRt_from_projection(
        P: torch.Tensor,
        eps: float = 1e-6) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""Decompose the Projection matrix into Camera-Matrix, Rotation Matrix and Translation vector.

    The left 3x3 block is split by an RQ decomposition, obtained from a QR
    decomposition of its row-reversed transpose. Signs are then flipped so
    the camera matrix has a positive diagonal.

    Args:
        P: the projection matrix with shape :math:`(B, 3, 4)`.
        eps: offset added to the diagonal of the upper-triangular factor
            before taking its sign, so an exactly-zero entry counts as
            positive.

    Returns:
        - The Camera matrix with shape :math:`(B, 3, 3)`.
        - The Rotation matrix with shape :math:`(B, 3, 3)`.
        - The Translation vector with shape :math:`(B, 3, 1)`.

    Raises:
        AssertionError: if ``P`` is not of shape :math:`(B, 3, 4)`.
    """
    if P.shape[-2:] != (3, 4):
        raise AssertionError("P must be of shape [B, 3, 4]")
    if len(P.shape) != 3:
        raise AssertionError(
            f"P must be a batched tensor of shape [B, 3, 4], got {tuple(P.shape)}")

    submat_3x3 = P[:, 0:3, 0:3]
    last_column = P[:, 0:3, 3].unsqueeze(-1)

    # Trick to turn QR-decomposition into RQ-decomposition
    reverse = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]],
                           device=P.device,
                           dtype=P.dtype).unsqueeze(0)
    submat_3x3 = torch.matmul(reverse, submat_3x3).permute(0, 2, 1)
    ortho_mat, upper_mat = linalg_qr(submat_3x3)
    ortho_mat = torch.matmul(reverse, ortho_mat.permute(0, 2, 1))
    upper_mat = torch.matmul(reverse,
                             torch.matmul(upper_mat.permute(0, 2, 1), reverse))

    # Turning the `upper_mat's` diagonal elements to positive.
    diagonals = torch.diagonal(upper_mat, dim1=-2, dim2=-1) + eps
    signs = torch.sign(diagonals)
    signs_mat = torch.diag_embed(signs)

    K = torch.matmul(upper_mat, signs_mat)
    R = torch.matmul(signs_mat, ortho_mat)
    t = torch.matmul(torch.inverse(K), last_column)

    return K, R, t
Beispiel #27
0
    def forward(self,
                AttM,
                input,
                label=None,
                TrainOrTest='Train',
                clsGroup=None,
                showAtt=False):
        """Score each sample against classifiers conditioned on that sample.

        An attention vector computed from each normalised input re-weights
        the class attribute matrix ``AttM``. ``self.L1``/``self.L2`` map the
        result to one L2-normalised classifier per (sample, class), and each
        sample is scored only against its own classifiers.

        Args:
            AttM: class attributes, reshaped to (cls_num, att_dims).
            input: (batch, feature_dim) features. Assumes the classifier
                output dim equals feature_dim -- TODO confirm.
            label, clsGroup: unused.
            TrainOrTest: ``'Train'`` returns only the scaled scores; any other
                value also returns ``out / scale_cls``.
            showAtt: additionally return the attention weights ``att``.

        Returns:
            ``out`` of shape (batch, cls_num), or a tuple that also contains
            ``out / scale_cls`` and/or ``att`` as described above.
        """
        cls_num = len(AttM)
        x = F.normalize(input)
        att = self.img_guaid(x)
        att = self.ReLU(att)
        # Explicit dim: calling softmax without one is deprecated. Assumes att
        # is 2-D (batch, att_dims), where the implicit choice was this same
        # last dim -- TODO confirm.
        att = F.softmax(att / self.tmp, dim=-1).reshape(-1, 1, self.att_dims)
        att = att + torch.ones_like(att)
        att = att.repeat(1, cls_num, 1)

        AttM = AttM.reshape(1, -1, self.att_dims)
        AttM = AttM.repeat(len(input), 1, 1)

        AttM = att * AttM

        W1 = self.L1(AttM)
        W1 = self.ReLU1(W1)
        W2 = self.L2(W1)
        classifier = self.ReLU2(W2)
        classifier = F.normalize(classifier, p=2, dim=-1, eps=1e-12)

        # (B, C, B) scores of every classifier against every sample; keep
        # only each sample's own classifiers (the b == b' diagonal).
        out = self.scale_cls * (torch.matmul(classifier, x.t()) + self.bias)
        out = out.permute(1, 0, 2)
        out = torch.diagonal(out, offset=0, dim1=1, dim2=2)
        out = out.t()

        if TrainOrTest == 'Train':
            if showAtt:
                return out, att
            else:
                return out
        else:
            if showAtt:
                return out, out / self.scale_cls, att
            else:
                return out, out / self.scale_cls
Beispiel #28
0
def randneg_train(query_encode_func,
                  doc_encode_func,
                  input_query_ids,
                  query_attention_mask,
                  input_doc_ids,
                  doc_attention_mask,
                  other_doc_ids=None,
                  other_doc_attention_mask=None,
                  hard_pair_mask=None):
    """Pairwise softmax loss of each positive document against sampled negatives.

    Each query is scored against its own positive document (the diagonal of
    the in-batch score matrix) and against *every* document in
    ``other_doc_ids``. Each (positive, negative) pair forms a 2-way softmax,
    and the per-pair loss is the negative log-probability of the positive.

    Args:
        query_encode_func, doc_encode_func: ``(ids, attention_mask) -> embeddings``.
        input_query_ids, query_attention_mask: the query batch.
        input_doc_ids, doc_attention_mask: one positive document per query.
        other_doc_ids, other_doc_attention_mask: negatives shaped
            (batch, docs per query, length). Required despite the ``None``
            defaults.
        hard_pair_mask: optional per-pair weights, flattened to match the
            pairs. Only weighted pairs contribute, and the loss is divided by
            the sum of the weights.

    Returns:
        1-tuple with the mean pair loss.

    Raises:
        ValueError: if ``other_doc_ids`` or ``other_doc_attention_mask`` is None.
    """
    # Fail fast with a clear message instead of an AttributeError after both
    # encoders have already run.
    if other_doc_ids is None or other_doc_attention_mask is None:
        raise ValueError(
            "randneg_train requires other_doc_ids and other_doc_attention_mask")

    query_embs = query_encode_func(input_query_ids, query_attention_mask)
    doc_embs = doc_encode_func(input_doc_ids, doc_attention_mask)

    with autocast(enabled=False):
        batch_scores = torch.matmul(query_embs, doc_embs.T)
        # score of each query with its own positive document
        single_positive_scores = torch.diagonal(batch_scores, 0)
    # other_doc_ids: batch size, per query doc, length
    other_doc_num = other_doc_ids.shape[0] * other_doc_ids.shape[1]
    other_doc_ids = other_doc_ids.reshape(other_doc_num, -1)
    other_doc_attention_mask = other_doc_attention_mask.reshape(
        other_doc_num, -1)
    other_doc_embs = doc_encode_func(other_doc_ids, other_doc_attention_mask)

    with autocast(enabled=False):
        # (queries x all negatives), flattened query-major
        other_batch_scores = torch.matmul(query_embs, other_doc_embs.T)
        other_batch_scores = other_batch_scores.reshape(-1)
        positive_scores = single_positive_scores.reshape(-1, 1).repeat(
            1, other_doc_num).reshape(-1)
        other_logit_matrix = torch.stack(
            [positive_scores, other_batch_scores], dim=1)
        other_lsm = F.log_softmax(other_logit_matrix, dim=1)
        other_loss = -1.0 * other_lsm[:, 0]
        if hard_pair_mask is not None:
            hard_pair_mask = hard_pair_mask.reshape(-1)
            other_loss = other_loss * hard_pair_mask
            second_loss, second_num = other_loss.sum(), hard_pair_mask.sum()
            # No selected pair -> loss 0 instead of 0/0 = NaN.
            second_num = torch.where(second_num > 0, second_num,
                                     torch.ones_like(second_num))
        else:
            second_loss, second_num = other_loss.sum(), len(other_loss)
    return (second_loss / second_num, )
Beispiel #29
0
    def duel_forward(self, feature, package_valid_matrix):
        """Dueling-style head: per-item advantages plus a shared value.

        Item validity is the diagonal of ``package_valid_matrix[:, :, :, 0]``.
        The value stream is summed over valid items, and the advantages are
        centred by their mean over valid items. Invalid items then get a
        large negative offset.

        Returns:
            Tensor of shape (batch, items).
        """
        hidden = self.GAT(self.multilinear1, self.linear1(feature),
                          package_valid_matrix)
        hidden = hidden.transpose(0, 1)
        valid = torch.diagonal(package_valid_matrix[:, :, :, 0], dim1=1, dim2=2)
        padding = (1 - valid).bool()

        encoded = self.transformer_encoder(hidden, src_key_padding_mask=padding)
        adv_seq = self.adv_encoder_layers(
            encoded, src_key_padding_mask=padding).transpose(1, 0)
        val_seq = self.value_encoder_layers(
            encoded, src_key_padding_mask=padding).transpose(1, 0)
        adv = self.advlinear(adv_seq).squeeze(-1)

        pooled = (val_seq * valid.unsqueeze(-1)).sum(dim=1)
        value = self.outputlayer(pooled)  # B * 1

        n_valid = valid.sum(dim=1)
        mean_adv = (adv * valid).sum(dim=1) / n_valid
        q = value - mean_adv.unsqueeze(1) + adv
        return q - 9999999999 * (1 - valid)
    def forward(self, z1, z2, **kwargs) -> Tensor:
        """InfoNCE loss with ``z2[i]`` as the positive for anchor ``z1[i]``.

        Logits are the pairwise dot products divided by ``self.tau``; they
        are cosine similarities when ``self.norm`` is set. The loss is the
        mean cross-entropy of each row against its diagonal entry. Variance,
        covariance and uniformity regularisers are added when their weights
        are positive.
        """
        sim_matrix = torch.einsum('ik,jk->ij', z1, z2)

        if self.norm:
            z1_abs = z1.norm(dim=1)
            z2_abs = z2.norm(dim=1)
            sim_matrix = sim_matrix / torch.einsum('i,j->ij', z1_abs, z2_abs)

        logits = sim_matrix / self.tau
        # -log(exp(l_ii) / sum_j exp(l_ij)) computed with logsumexp: the same
        # value as exponentiating first, but it cannot overflow to inf/inf.
        loss = (torch.logsumexp(logits, dim=1) - torch.diagonal(logits)).mean()
        if self.variance_reg > 0:
            loss += self.variance_reg * (std_loss(z1) + std_loss(z2))
        if self.covariance_reg > 0:
            loss += self.covariance_reg * (cov_loss(z1) + cov_loss(z2))
        if self.uniformity_reg > 0:
            loss += self.uniformity_reg * uniformity_loss(z1, z2)
        return loss
def pdist(embeddings, squared=False):
    """Pairwise Euclidean distances between the rows of ``embeddings``.

    Uses ``||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2``.

    Args:
        embeddings: (batch_size, dim) tensor.
        squared: return squared distances instead of distances.

    Returns:
        (batch_size, batch_size) distance matrix.
    """
    gram = embeddings @ embeddings.t()
    sq_norms = gram.diagonal()
    dist = sq_norms.unsqueeze(0) - 2.0 * gram + sq_norms.unsqueeze(1)
    # Round-off can push some squared distances slightly below zero; clip them.
    dist = torch.relu(dist)
    if squared:
        return dist
    # sqrt has an infinite gradient at 0: nudge exact zeros before taking it,
    # then zero those entries again afterwards.
    zero = (dist == 0.0).float()
    dist = torch.sqrt(dist + zero * 1e-16)
    return dist * (1.0 - zero)
Beispiel #32
0
def zero_mean_covariance(covariance, stability=0.0):
    '''Output covariance of ReLU for zero-mean Gaussian input.

    f(x) = max(x, 0).

    Assume ``outer`` forms the outer product ``s s^T`` of the standard
    deviations ``s``, and let ``rho`` be the input correlations. The result
    is ``s_i s_j / (2 pi) * (rho acos(-rho) + sqrt(1 - rho^2) - 1)``.

    Args:
        covariance: Input covariance matrix (Size, Size).
        stability: correlations are clamped to
            ``[stability - 1, 1 - stability]``. Use zero for exact results;
            in training, a value like 1e-4 keeps the computation stable.

    Returns:
        Output covariance of ReLU for zero-mean Gaussian input (Size, Size).
    '''
    std = torch.sqrt(torch.diagonal(covariance, 0, -2, -1))
    scale = outer(std)
    rho = (covariance / scale).clamp_(stability - 1.0, 1.0 - stability)
    q = torch.acos(-rho) * rho + torch.sqrt(1.0 - (rho**2.0)) - 1.0
    out = scale * q * (1.0 / (2.0 * math.pi))
    # Zero-variance inputs give 0/0 = NaN above; report zero covariance instead.
    out[torch.isnan(out)] = 0
    return out