コード例 #1
0
    def __call__(self, batch):
        """Collate ``(input_ids, permuted)`` pairs into a masked-LM batch.

        Sequences are right-padded with ``self._pad_token_id``.  Each non-pad
        position is selected for prediction with probability
        ``self._mask_prob``.  Selected positions are replaced by
        ``self._mask_token_id`` with probability ``1 - self._random_prob``;
        of the remaining selected positions, about half are replaced by a
        random id drawn from ``[5, self._vocab_size)``.  Labels keep the
        original id at selected positions and ``TARGET_IDX`` elsewhere.

        Returns:
            ``(input_ids, labels, permuted)``; ``permuted`` is a float tensor.
        """
        token_seqs, permuted_flags = zip(*batch)
        permuted_flags = torch.tensor(permuted_flags, dtype=torch.float)
        tokens = pad_sequence(token_seqs,
                              batch_first=True,
                              padding_value=self._pad_token_id)
        shape = tokens.shape

        # The three torch.rand draws must stay in this order so the RNG
        # stream (and hence the sampled masks) is reproducible.
        selected = (torch.less(torch.rand(shape), self._mask_prob)
                    & torch.not_equal(tokens, self._pad_token_id))
        use_mask_token = torch.less(torch.rand(shape), 1 - self._random_prob)
        coin_flip = torch.less(torch.rand(shape), 0.5)

        labels = torch.where(selected, tokens, TARGET_IDX)

        to_mask = selected & use_mask_token
        to_randomize = selected & ~use_mask_token & coin_flip

        tokens = torch.where(to_mask, self._mask_token_id, tokens)
        # low=5 presumably skips reserved/special token ids — TODO confirm.
        random_tokens = torch.randint_like(tokens, low=5,
                                           high=self._vocab_size)
        tokens = torch.where(to_randomize, random_tokens, tokens)

        return tokens, labels, permuted_flags
コード例 #2
0
    def forward(ctx, X, rank: int = 100):
        """Split a 2-D matrix into shifted low-rank factors ``(A, B)``.

        Presumably the forward of a ``torch.autograd.Function`` (first
        argument is ``ctx``) — TODO confirm.

        Keeps the leading ``rank - 1`` singular triplets of ``X``:

        * ``A`` is ``(x, rank)``: the left singular vectors scaled by the
          singular values, plus one extra column that makes every row of
          ``A`` sum to zero (``[U, -U]`` when ``rank == 2``).
        * ``B`` is ``(rank, y)``: the right singular vectors plus a zero row,
          shifted column-wise so every entry is non-negative.

        Because every row of ``A`` sums to zero, subtracting the same row
        vector from every row of ``B`` does not change ``A @ B``.  That
        product therefore equals the rank-``(rank - 1)`` SVD reconstruction
        of ``X``.  Finally ``A`` is scaled by ``sqrt(x * y) / ||A||_F`` and
        ``B`` by the inverse factor, which again leaves the product unchanged.

        Args:
            ctx: not used in this method.
            X: 2-D tensor of shape ``(x, y)`` (``x, y = X.shape`` enforces 2-D).
            rank: one more than the number of singular triplets kept.  Assumed
                ``>= 2`` and ``rank - 1 <= min(x, y)``; otherwise the slices
                below disagree in size — TODO confirm the intended range.

        Returns:
            ``(A, B)`` as described above.
        """

        # Full SVD (some=False), so U is (x, x) and V is (y, y).
        U, S, V = torch.svd(X, compute_uv=True, some=False)

        # Keep the first rank-1 components; fold singular values into U.
        S = torch.diag(S[0:(rank - 1)])
        U = torch.matmul(U[:, 0:(rank - 1)], S)
        V = torch.transpose(V, 0, 1)[0:(rank - 1), :]

        x, y = X.shape

        # NOTE(review): Unew/Vnew/__U/__V below are computed but never used
        # afterwards (A and B are built from U and V directly).  Inside the
        # loop, Unew/Vnew also never change with `i`.  Candidate for removal.
        Unew = U[:, 0]
        Vnew = V[0, :]

        __U = torch.where(torch.less(torch.min(V[0, :]), torch.min(-V[0, :])),
                          -(Unew.view(x, 1)), Unew.view(x, 1))
        __V = torch.where(torch.less(torch.min(V[0, :]), torch.min(-V[0, :])),
                          -(Vnew.view(1, y)), Vnew.view(1, y))
        if rank > 2:
            for i in range(1, rank - 1):
                Unew = Unew.view(x, 1)
                Vnew = Vnew.view(1, y)
                __U = torch.where(
                    torch.less(torch.min(V[0, :]), torch.min(-V[0, :])),
                    torch.cat((__U, -Unew), dim=1),
                    torch.cat((__U, Unew), dim=1))
                __V = torch.where(
                    torch.less(torch.min(V[0, :]), torch.min(-V[0, :])),
                    torch.cat((__V, -Vnew), dim=0),
                    torch.cat((__V, Vnew), dim=0))

        # Append a column so each row of A sums to zero.
        if rank == 2:
            A = torch.cat((U, -U), dim=1)
        else:
            # transpose(0, -1) on the 1-D row sum is a no-op; view makes it (x, 1).
            Un = torch.transpose(-(torch.sum(U, dim=1)), 0, -1).view(x, 1)
            A = torch.cat((U, Un), dim=1)

        # NOTE(review): torch.zeros / torch.tensor(0.) / FloatTensor below are
        # created on the default (CPU) device; a CUDA X would mismatch.
        B = torch.cat((V, torch.zeros((1, y))), dim=0)

        # Shift every row by the (non-positive part of the) column minimum of
        # V so all entries of B become >= 0.  A @ B is unchanged because the
        # rows of A sum to zero.
        if rank >= 3:
            b, _ = torch.min(V, dim=0)
            B = torch.subtract(B, torch.minimum(torch.tensor(0.), b))
        else:
            B = torch.subtract(B, torch.minimum(torch.tensor(0.), V))
        x = torch.tensor(x)
        y = torch.tensor(y)
        normalize = torch.sqrt(torch.multiply(x, y).type(torch.FloatTensor))
        norm = torch.norm(A)

        # Rebalance the scale between A and B; the product A @ B is preserved.
        return torch.multiply(torch.div(A, norm),
                              normalize), torch.div(torch.multiply(B, norm),
                                                    normalize)
コード例 #3
0
    def forward(self, x):
        """Min-max scale a 4-D input using running per-position extrema.

        In training mode, the minimum and maximum of ``x`` over dims 0 and 1
        are merged into ``self.min`` / ``self.max``; the first training batch
        initialises them.  ``self.max_min`` is then refreshed.  In every mode
        the input is mapped to ``(x - self.min) / self.max_min``.  This means
        at least one training-mode call must happen before evaluation.

        Args:
            x: tensor with exactly 4 dimensions.

        Returns:
            The scaled tensor, with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` is not 4-D.
        """
        # check dims
        if x.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'.format(
                x.dim()))
        if self.training:
            # Running statistics must not hold on to the autograd graph.
            # Otherwise each batch's stats reference every previous batch
            # (unbounded memory) and backward walks into already-freed graphs.
            stats_src = x.detach()
            x_min = torch.amin(stats_src, dim=(0, 1))
            x_max = torch.amax(stats_src, dim=(0, 1))

            if self.first:
                self.max = x_max
                self.min = x_min
                self.first = False
            else:
                # Element-wise merge.  Unlike mask * value arithmetic this
                # cannot produce 0 * inf = nan when an extremum is infinite.
                self.max = torch.maximum(self.max, x_max)
                self.min = torch.minimum(self.min, x_min)

            # Tiny epsilon avoids division by zero where min == max.
            self.max_min = self.max - self.min + 1e-13

        # scale batch
        return (x - self.min) / self.max_min
コード例 #4
0
ファイル: math_ops.py プロジェクト: malfet/pytorch
 def comparison_ops(self):
     """Run torch's comparison, predicate and sorting ops on random vectors.

     Draws two random length-4 vectors (``x`` first, then ``y``) and returns
     a 30-element tuple of op results in a fixed order.
     """
     x, y = torch.randn(4), torch.randn(4)
     results = (
         # closeness / ordering / equality
         torch.allclose(x, y),
         torch.argsort(x),
         torch.eq(x, y),
         torch.equal(x, y),
         torch.ge(x, y),
         torch.greater_equal(x, y),
         torch.gt(x, y),
         torch.greater(x, y),
         torch.isclose(x, y),
         # element predicates
         torch.isfinite(x),
         torch.isin(x, y),
         torch.isinf(x),
         torch.isposinf(x),
         torch.isneginf(x),
         torch.isnan(x),
         torch.isreal(x),
         torch.kthvalue(x, 1),
         # remaining comparisons
         torch.le(x, y),
         torch.less_equal(x, y),
         torch.lt(x, y),
         torch.less(x, y),
         # element-wise extrema
         torch.maximum(x, y),
         torch.minimum(x, y),
         torch.fmax(x, y),
         torch.fmin(x, y),
         torch.ne(x, y),
         torch.not_equal(x, y),
         # sorting / selection
         torch.sort(x),
         torch.topk(x, 1),
         torch.msort(x),
     )
     return results
コード例 #5
0
def merge_clusters(v, size = 1.0, precision = 0.02):
    """Merge polygon labels whose point sets overlap.

    The last column of ``v`` holds a polygon id.  For a pair of polygons,
    if more than 0.0004 of all their point pairs lie closer than
    ``precision``, every point of the second polygon is relabelled with the
    first id.  The pass over all pairs is repeated up to
    ``polygons.shape[0]`` times and stops early once the (deduplicated)
    pair list no longer changes.

    Args:
        v: array-like accepted by ``torch.tensor``.  Columns 3:6 are re-read
            as float32 bit patterns via ``np.frombuffer``, which only yields
            3 floats per row if ``v`` has a 4-byte dtype (e.g. uint32) —
            TODO confirm.
        size: unused (only referenced by the commented-out code below).
        precision: distance below which two points count as overlapping.

    Returns:
        ``v`` with updated polygon ids, as a ``np.uint32`` numpy array.

    NOTE(review): requires CUDA.  ``'cuda:0'`` is hard-coded, and
    ``'cuda:1'`` is used for the distance computations when a second GPU
    exists.
    """
    device = 'cuda:0'
    device1 = 'cuda:0'
    if torch.cuda.device_count() > 1:
        device1 = 'cuda:1'

    v = torch.tensor(v).to(device)
    # Distinct polygon ids (last column).
    polygons = torch.unique(v[:, -1:])

    #indexes = np.in1d(np.array(list(annotations.keys())), polygons.cpu().numpy()).nonzero()
    #annotation_centers = np.array(list(annotations.values())).reshape((-1,3))[indexes]
    #centers = torch.tensor(np.array(list(annotation_centers))).to(device)
    #dist = torch.cdist(centers, centers)
    #argwhere = (dist < size).nonzero()
    #argwhere = argwhere[argwhere[:,0] != argwhere[:,1]]
    #argwhere, _ = torch.sort(argwhere, dim = 1)
    #pairs = torch.unique(argwhere, dim = 0).reshape((-1, 2))
    if polygons.shape[0] < 2:
        return np.array(v.cpu(), dtype=np.uint32)

    print('num of polygons ', polygons.shape)
    # Candidate pairs are *indices into `polygons`* (not ids).  j starts at i,
    # so self-pairs are included; they are skipped below as poly_id1 == poly_id2.
    # NOTE(review): repeated torch.cat is quadratic in the number of pairs.
    pairs = torch.tensor(np.array([]).reshape((-1, 2)), dtype = torch.int32)
    for i in range(polygons.shape[0] - 1):
        for j in range(i, polygons.shape[0]):
            pairs = torch.cat((pairs, torch.tensor(np.array([i,j]).reshape((-1,2)))), dim=0)
    #print(pairs, polygons.shape)

    if pairs.shape[0] > 0 and pairs.shape[1] == 2:
        # Repeat merge passes until the pair list stabilises (at most one
        # pass per polygon).
        for ii in range(polygons.shape[0]):
            for i in range(pairs.shape[0]):
                poly_id1, poly_id2 = polygons[pairs[i]]
                if poly_id1 == poly_id2: #merged ones are the same
                    continue

                idx1 = torch.eq(v[:, -1], poly_id1)
                idx2 = torch.eq(v[:, -1], poly_id2)

                # so awkward
                # Reinterpret the raw bytes of columns 3:6 as float32 xyz
                # coordinates (see the dtype assumption in the docstring).
                pts1 = torch.tensor(np.frombuffer(np.array(v[idx1, 3:6].cpu()), dtype=np.float32)).reshape((-1, 3)).to(device1)
                pts2 = torch.tensor(np.frombuffer(np.array(v[idx2, 3:6].cpu()), dtype=np.float32)).reshape((-1, 3)).to(device1)
                if (pts1.shape[0] == 0) or (pts2.shape[0] == 0): #nothing to merge
                    continue

                #print(pts1[0], v[0])
                # All pairwise distances; count pairs closer than `precision`.
                d = torch.cdist(pts1, pts2)
                overlap = torch.less(d, precision).nonzero()
                #overlap = torch.less(overlap[:,0], overlap[:,1])
                #print(pts1, pts2, d, precision)
                if (overlap.shape[0] / (pts1.shape[0] * pts2.shape[0]) > 0.0004):
                    #print(overlap.shape, pts1.shape, pts2.shape, poly_id1, poly_id2)
                    v[idx2, -1] = poly_id1
                    pairs[i][1] = pairs[i][0]
                    # NOTE(review): pairs[i][1] was overwritten on the line above,
                    # so this compares against pairs[i][0] and assigns the same
                    # value, which is a no-op.  Redirecting other pairs that
                    # reference the old second polygon index would need that
                    # index saved before the overwrite.
                    pairs[torch.eq(pairs[:,0], pairs[i][1]), 0:1] = pairs[i][0].clone()    # poly_id2 no longer exist
            # Deduplicated (sorted) pairs; stop when nothing changed this pass.
            old_pairs = torch.unique(pairs, dim = 0)
            if (old_pairs.shape[0] == pairs.shape[0]) and torch.all(torch.eq(old_pairs, pairs)):
                break
            pairs = old_pairs

    return np.array(v.cpu(), dtype=np.uint32)
コード例 #6
0
def calculate_complex_aps(raw_data, metrics_threshold, metrics_operator):
    """Compute per-class "AP" curves for combined metric thresholds.

    For each ``data_key`` in ``metrics_threshold``, the metrics of
    ``raw_data`` whose key satisfies ``key in data_key`` are combined.  A
    sample counts as a hit at threshold index ``t`` only when *every*
    selected metric is strictly below its own threshold ``thresholds[m, t]``.
    The class value at ``t`` is the fraction of samples that hit.

    Args:
        raw_data: ``{metric_key: {class_id: tensor of per-sample values}}``;
            values are presumably 1-D per metric — TODO confirm.
        metrics_threshold: ``{data_key: thresholds}``.  ``thresholds`` is
            expected to be 2-D ``(num_metrics, num_thresholds)``, ordered like
            the matching keys in ``raw_data`` (required for the broadcasting
            below).
        metrics_operator: currently unused; the comparison is always ``<``.

    Returns:
        ``{data_key: {class_id: tensor of shape (num_thresholds,), ...,
        'mean': mean over classes}}``.
    """
    # Container for all aps
    aps = {}

    for data_key, thresholds in metrics_threshold.items():
        aps[data_key] = {}

        # NOTE: `k in data_key` is a membership test when data_key is a tuple
        # of metric names, but a substring test when it is a str.
        keys = [k for k in raw_data.keys() if k in data_key]

        # Gather each class's metric rows, then stack them once.  Stacking
        # incrementally only worked for exactly two metrics.
        rows_per_class = {}
        for key in keys:
            for class_id, values in raw_data[key].items():
                rows_per_class.setdefault(class_id, []).append(values)

        for class_id, rows in rows_per_class.items():
            class_data = torch.stack(rows)  # (num_metrics, num_samples)
            n_of_d = class_data.shape[1]

            # (M, 1, N) vs (M, T, 1) -> (M, T, N) comparison grid.
            e_class_data = torch.unsqueeze(class_data, dim=1)
            e_threshold = torch.unsqueeze(thresholds, dim=-1)

            # Applying operator (currently only "less than")
            applied_threshold = torch.less(e_class_data, e_threshold)

            # A sample hits at threshold t only if all metrics pass: (T, N)
            threshold_mixed = torch.all(applied_threshold, dim=0)

            # Fraction of hitting samples per threshold: (T,)
            aps[data_key][class_id] = torch.sum(threshold_mixed, dim=1) / n_of_d

        # Mean over classes
        aps[data_key]['mean'] = torch.mean(torch.stack(
            list(aps[data_key].values())).float(), dim=0)

    return aps
コード例 #7
0
 def not_done(self, i):
     """Return a 0-d bool tensor: True while step ``i`` should continue.

     Presumably the loop condition of a beam search — TODO confirm.  The
     per-row minimum of ``self.score`` masked by ``self.flag`` is computed.
     Rows with no set entry in ``self.flags`` are pushed down by
     ``utils.big_neg``.  Decoding stops once every such value exceeds
     ``self.logp[:, 0] / self.penalty(n)``, or once ``i`` reaches
     ``n = self.tgt.shape[-1]``.

     The previous version called TensorFlow-style functions (``torch.cast``,
     ``torch.floatx``, ``torch.reduce_*``, ``torch.int_shape``) that do not
     exist in torch.

     NOTE(review): both ``self.flag`` and ``self.flags`` are read; confirm
     both attributes are intended.
     """
     # Keras' floatx() corresponds to torch's configurable default dtype.
     float_dtype = torch.get_default_dtype()
     y = self.score * self.flag.to(float_dtype)
     y = torch.amin(y, dim=1)
     fs = torch.any(self.flags, dim=1)
     old = y + (1.0 - fs.to(float_dtype)) * utils.big_neg
     n = self.tgt.shape[-1]
     new = self.logp[:, 0] / self.penalty(n)
     done = torch.all(torch.greater(old, new))
     # as_tensor lets `i` be a Python int or a tensor.
     return torch.logical_and(torch.less(torch.as_tensor(i), n),
                              torch.logical_not(done))
コード例 #8
0
def reduced_sigmoid_focal_loss(pred,
                               target,
                               weight=None,
                               gamma=2.0,
                               alpha=0.25,
                               reduction='mean',
                               avg_factor=None,
                               threshold=0.5):
    """Reduced focal loss on sigmoid logits (see `Focal Loss <https://arxiv.org/abs/1708.02002>`_).

    ``pt`` is the probability the sigmoid assigns to the *wrong* label.
    Examples with ``pt >= threshold`` keep full weight.  Examples below the
    threshold are down-weighted by ``pt ** gamma / threshold ** gamma``.
    The result is combined with the ``alpha`` class balance and passed to
    ``weight_reduce_loss``.

    Args:
        pred (torch.Tensor): The prediction (logits) with shape (N, C), C is
            the number of classes.
        target (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight, either the
            loss shape, ``(N,)``, or a flattened ``(N * C,)``.
        gamma (float, optional): Exponent of the modulating factor.
            Defaults to 2.0.
        alpha (float, optional): Balance between positive and negative
            targets. Defaults to 0.25.
        reduction (str, optional): How to reduce the loss into a scalar.
            Defaults to 'mean'.
        avg_factor (int, optional): Average factor used to average the loss.
            Defaults to None.
        threshold (float, optional): ``pt`` value above which no
            down-weighting happens. Defaults to 0.5.

    Returns:
        torch.Tensor: the reduced loss.
    """
    bce = nn.BCEWithLogitsLoss(reduction='none')
    ce = bce(pred, target.float())
    prob = pred.sigmoid()

    # Probability mass the model puts on the wrong label.
    pt = (1 - prob) * target + prob * (1 - target)
    hard = torch.greater_equal(pt, threshold).float()
    easy = torch.less(pt, threshold).float()
    modulating_factor = hard + easy * pt.pow(gamma) / torch.tensor(threshold).pow(gamma)
    class_balance = alpha * target + (1 - alpha) * (1 - target)
    loss = ce * (class_balance * modulating_factor)

    def _align(w):
        # Reshape a weight tensor so it broadcasts against `loss`.
        if w.shape == loss.shape:
            return w
        if w.size(0) == loss.size(0):
            # Per-sample weight (num_priors,) -> (num_priors, 1).
            return w.view(-1, 1)
        # Per-anchor-per-class weight flattened to (num_priors * num_class,),
        # e.g. in FSAF.
        assert w.numel() == loss.numel()
        return w.view(loss.size(0), -1)

    if weight is not None:
        weight = _align(weight)
        assert weight.ndim == loss.ndim
    return weight_reduce_loss(loss, weight, reduction, avg_factor)
コード例 #9
0
ファイル: evaluate.py プロジェクト: jialanxin/megnetorch
def count_incorrects(ramans, predict_confidence_round):
    """Count confidence mismatches between a prediction and its target.

    Only ``ramans[0]`` is used.  Its column 0 is the target confidence, and
    its column 1 is converted with ``absolute_position`` and flattened into
    the target position.

    Returns:
        ``(more, less, incorrect, total, target_confidence,
        target_position)``: the first four are ints.  ``more``/``less`` count
        positions where the prediction is above/below the target,
        ``incorrect`` is their sum, and ``total`` is the sum of the target
        confidence.
    """
    sample = ramans[0]
    target_confidence = sample[:, 0]
    target_position = absolute_position(sample[:, 1]).flatten()

    def _total(values):
        # Sum as float and return a detached Python number.
        return values.float().sum().detach().item()

    n_over = _total(torch.greater(predict_confidence_round, target_confidence))
    n_under = _total(torch.less(predict_confidence_round, target_confidence))
    n_wrong = n_over + n_under
    n_total = _total(target_confidence)
    return (int(n_over), int(n_under), int(n_wrong), int(n_total),
            target_confidence, target_position)
コード例 #10
0
def smooth_l1(deltas, targets, sigma=3.0):
    """Element-wise smooth-L1 loss between ``deltas`` and ``targets``.

    With ``d = deltas - targets`` and ``s2 = sigma ** 2``:
    ``0.5 * s2 * d**2`` where ``|d| < 1 / s2``, else ``|d| - 0.5 / s2``.

    The previous version failed at runtime: ``.float32`` is not a tensor
    attribute, and ``Diffs`` was an undefined name.

    Args:
        deltas: predicted values.
        targets: target values, broadcastable against ``deltas``.
        sigma: controls the width of the quadratic region.

    Returns:
        Tensor of per-element losses (not reduced).
    """
    sigma2 = sigma * sigma
    diffs = torch.subtract(deltas, targets)
    abs_diffs = torch.abs(diffs)
    # 1 where the quadratic branch applies, 0 elsewhere.
    smooth_l1_signs = torch.less(abs_diffs, 1.0 / sigma2).to(diffs.dtype)

    smooth_l1_option1 = torch.mul(diffs, diffs) * 0.5 * sigma2
    smooth_l1_option2 = abs_diffs - 0.5 / sigma2
    return torch.mul(smooth_l1_option1, smooth_l1_signs) + torch.mul(
        smooth_l1_option2, 1 - smooth_l1_signs)
コード例 #11
0
def train(epoch):
    """Run one training epoch and print the mean Jaccard index (IoU).

    Every ``decay_every`` epochs (excluding epoch 0), the module-level ``lr``
    is multiplied by ``decay_rate`` and written into every optimizer
    parameter group.  Relies on module-level ``net``, ``optimizer``,
    ``device``, ``scpdb_dataloader_train``, ``bce_logit_loss``,
    ``construct_ground_truth_segmentation``, ``decay_every`` and
    ``decay_rate``.

    Fixes over the previous version:

    * ``global lr``: without it, ``lr = lr * ...`` raised UnboundLocalError.
    * The mean IoU is ``sum(list) / len(list)``; dividing a list by an int
      raised TypeError.
    * The ground truth is moved to ``device`` to match the predictions.

    Args:
        epoch: zero-based epoch index.
    """
    global lr  # the decayed rate must persist across epochs

    if epoch > 0 and epoch % decay_every == 0:
        lr = lr * decay_rate
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    net.train()
    optimizer.zero_grad()

    jaccard_indices = []

    for ligand, protein, site in progressbar.progressbar(
            scpdb_dataloader_train):
        ligand = ligand.to(device)
        protein = protein.to(device)
        site = site.to(device)

        segmentation = net(protein, ligand)
        # The target is built as a numpy array; put it next to the predictions.
        ground_truth_segmentation = torch.from_numpy(
            construct_ground_truth_segmentation(site)).to(device)

        loss = bce_logit_loss(segmentation, ground_truth_segmentation)
        loss.backward()

        # track accuracy
        # NOTE(review): if `net` outputs logits (as `bce_logit_loss` suggests),
        # a 0.5 cut-off on raw logits means p ~= 0.62; 0.0 may be intended.
        predictions = (segmentation > 0.5).int()

        true_positives = torch.logical_and(predictions,
                                           ground_truth_segmentation)
        false_positives = torch.greater(predictions, ground_truth_segmentation)
        false_negatives = torch.less(predictions, ground_truth_segmentation)

        num_correct = torch.sum(true_positives)
        iou = num_correct / (num_correct + torch.sum(false_positives) +
                             torch.sum(false_negatives))
        # Store a Python float so no device tensors are kept alive.
        jaccard_indices.append(iou.item())

        # optimize
        optimizer.step()
        optimizer.zero_grad()

    mean_iou = (sum(jaccard_indices) / len(jaccard_indices)
                if jaccard_indices else float('nan'))
    print("Epoch {} - Train: {:06.4f}".format(epoch, mean_iou))
コード例 #12
0
    def forward(self, x, y):
        """Compute a zero-inflated-Poisson VAE objective for ``x``.

        Encodes ``x`` with ``self.enc`` and decodes to ``(logodds, rate)``
        with ``self.dec``.  It then forms a zero-inflated Poisson
        log-likelihood, choosing per element between the zero-count and
        non-zero-count cases via ``x < 1``, and subtracts a KL-style term.
        The result is averaged with ``torch.mean``.

        NOTE(review): as written this method cannot run:
          * ``pois_llik`` is a tensor but is later called like a function
            (``pois_llik(x, rate)``, ``pois_llik(x, mean)``).
          * ``torch.nn.Softplus(...)`` constructs a module (its first
            argument is ``beta``) instead of applying softplus;
            ``torch.nn.functional.softplus`` was presumably intended.
          * ``mean``, ``prec``, ``prior_prec`` and ``T`` are not defined in
            this block — presumably derived from ``qz`` or module
            attributes; TODO confirm.
          * ``y`` is not used.
        """
        # VAE component
        qz = self.enc(x)
        logodds, rate = self.dec(qz)

        # NOTE(review): a Poisson log-pmf is x*log(rate) - rate - lgamma(x + 1);
        # this adds gammaln(x) instead.  `sp` is presumably scipy.special,
        # which cannot operate on autograd tensors (torch.lgamma can).
        pois_llik = x * torch.log(rate) - rate + sp.gammaln(x)
        # Important identities:
        # log(exp(a) + exp(b)) = a + softplus(b - a)
        # log(sigmoid(x)) = -softplus(-x)
        # NOTE(review): assume pi = sigmoid(logodds) is the zero-inflation
        # weight (consistent with case_non_zero = log(1 - pi) + log P(x)).
        # Then log(pi + (1 - pi) * P(0)) equals
        # -softplus(-logodds) + softplus(pois_llik - logodds); the inner
        # term below differs from that.
        case_zero = -torch.nn.Softplus(-logodds) + torch.nn.Softplus(
            pois_llik(x, rate) + torch.nn.Softplus(-logodds))
        # Non-zero counts: log(1 - pi) + Poisson log-likelihood.
        case_non_zero = -torch.nn.Softplus(logodds) + pois_llik(x, mean)
        # x < 1 picks the zero-count case (assuming non-negative integer x).
        zip_llik = torch.where(torch.less(x, 1), case_zero, case_non_zero)

        # NOTE(review): verify the signs against the closed-form Gaussian KL
        # between the approximate posterior and the prior.
        kl_pz_qz = .5 * (1 + T.log(prec) + prior_prec *
                         (torch.square(mean) + 1 / prec))

        vae_loss = torch.mean(torch.mean(zip_llik) - kl_pz_qz)
        return vae_loss
コード例 #13
0
ファイル: math_ops.py プロジェクト: yanboliang/pytorch
 def forward(self):
     """Exercise torch comparison and sorting ops on 0-d integer tensors.

     Returns:
         The number of op results produced (37).

     The results were previously passed to ``len`` as separate arguments,
     which always raised ``TypeError`` (``len`` takes a single object).
     They are now wrapped in a tuple.
     """
     a = torch.tensor(0)
     b = torch.tensor(1)
     return len((
         torch.allclose(a, b),
         torch.argsort(a),
         torch.eq(a, b),
         torch.eq(a, 1),
         torch.equal(a, b),
         torch.ge(a, b),
         torch.ge(a, 1),
         torch.greater_equal(a, b),
         torch.greater_equal(a, 1),
         torch.gt(a, b),
         torch.gt(a, 1),
         torch.greater(a, b),
         torch.isclose(a, b),
         torch.isfinite(a),
         torch.isin(a, b),
         torch.isinf(a),
         torch.isposinf(a),
         torch.isneginf(a),
         torch.isnan(a),
         torch.isreal(a),
         torch.kthvalue(a, 1),
         torch.le(a, b),
         torch.le(a, 1),
         torch.less_equal(a, b),
         torch.lt(a, b),
         torch.lt(a, 1),
         torch.less(a, b),
         torch.maximum(a, b),
         torch.minimum(a, b),
         torch.fmax(a, b),
         torch.fmin(a, b),
         torch.ne(a, b),
         torch.ne(a, 1),
         torch.not_equal(a, b),
         torch.sort(a),
         torch.topk(a, 1),
         torch.msort(a),
     ))
コード例 #14
0
def focal_loss_for_heat_map(labels,
                            logits,
                            pos_threshold=0.99,
                            alpha=2,
                            beta=4,
                            sum=True):
    """Focal loss for heat maps (e.g. a CenterNet-style heat-map loss).

    With ``p = sigmoid(logits)``, positions where
    ``labels >= pos_threshold`` are positives, weighted by
    ``(1 - p) ** alpha``.  All other positions are negatives, weighted by
    ``(1 - labels) ** beta * p ** alpha``.  Both log terms use the
    overflow-free form ``log1p(exp(-|x|))``.  The total is divided by
    ``num_pos + 1e-4``.

    Args:
        labels: target heat map.
        logits: raw scores with the same shape as ``labels``; cast to float32.
        pos_threshold: label value at or above which a position is positive.
        alpha: exponent applied to the prediction-based weight.
        beta: exponent applied to ``1 - labels`` for negatives.
        sum: if True, reduce the positive and negative terms to scalars;
            otherwise return the element-wise loss.  (The name shadows the
            builtin and is kept for API compatibility.)

    Returns:
        A scalar tensor if ``sum`` is True, else a tensor shaped like
        ``labels``.
    """
    logits = logits.to(torch.float32)
    is_pos = torch.greater_equal(labels, pos_threshold)
    zeros = torch.zeros_like(labels)
    ones = torch.ones_like(labels)
    num_pos = torch.sum(torch.where(is_pos, ones, zeros))

    # torch.sigmoid replaces the deprecated F.sigmoid.
    probs = torch.sigmoid(logits)
    pos_weight = torch.where(is_pos, ones - probs, zeros)
    neg_weight = torch.where(torch.less(labels, pos_threshold), probs, zeros)

    # Shared, numerically stable term; log1p is more accurate than log(1 + .)
    # when exp(-|x|) is tiny.
    log_term = torch.log1p(torch.exp(-torch.abs(logits)))

    # -log(sigmoid(x)) = -min(x, 0) + log(1 + exp(-|x|))
    pure_pos_loss = -torch.minimum(
        logits, logits.new_tensor(0, dtype=logits.dtype)) + log_term
    pos_loss = pure_pos_loss * torch.pow(pos_weight, alpha)
    if sum:
        pos_loss = torch.sum(pos_loss)

    # -log(1 - sigmoid(x)) = max(x, 0) + log(1 + exp(-|x|))
    pure_neg_loss = torch.relu(logits) + log_term
    neg_loss = torch.pow(
        (1 - labels), beta) * torch.pow(neg_weight, alpha) * pure_neg_loss
    if sum:
        neg_loss = torch.sum(neg_loss)
    return (pos_loss + neg_loss) / (num_pos + 1e-4)
コード例 #15
0
ファイル: ops.py プロジェクト: DwightFoster/Pytorch-TecoGAN
def random_flip(input, decision):
    """Mirror ``input`` along dimension 3 where ``decision < 0.5``.

    ``decision`` may be a tensor that broadcasts against ``input`` (e.g. a
    per-sample value) or a Python number.  Where the condition is False,
    ``input`` is returned unchanged.

    The previous version called ``torch.identity`` (which does not exist in
    torch) and passed ``dim=`` to ``torch.flip``, whose argument is ``dims``.
    """
    flipped = torch.flip(input, dims=(3,))
    condition = torch.less(torch.as_tensor(decision, device=input.device), 0.5)
    return torch.where(condition, flipped, input)
コード例 #16
0
def get_mask(lengths, sequence_len):
    """Build a float padding mask from sequence lengths.

    Args:
        lengths: 1-D tensor of valid lengths, one per batch element.
        sequence_len: padded sequence length.

    Returns:
        Float tensor of shape ``(len(lengths), sequence_len)``: 1.0 at
        positions ``< lengths[b]``, else 0.0.  The position index is created
        on ``lengths.device``, so GPU inputs no longer hit a device mismatch.
    """
    positions = tc.arange(sequence_len, device=lengths.device)
    # (1, L) vs (B, 1) broadcasts to (B, L).
    bool_mask = tc.less(positions.unsqueeze(0), lengths.unsqueeze(dim=1))
    return bool_mask.float()
コード例 #17
0
ファイル: _tensor.py プロジェクト: zander9648/ort-customops
 def __lt__(self, other):
     """Element-wise ``self < other``.

     The result is computed eagerly with ``torch.less`` and, in parallel,
     through ``_ox.less``; both are wrapped via ``self.from_torch``.
     """
     lhs, rhs = self._to_binary_tensor_args(other)
     eager_result = torch.less(lhs._t, rhs._t)
     ox_result = _ox.less(*_EagerTensor.ox_args([lhs, rhs]))
     return self.from_torch(eager_result, ox_result)
コード例 #18
0
def cosine_similarity_dim1(a: torch.Tensor, b: torch.Tensor, eps=1e-5):
    """Row-wise cosine similarity of two ``(batch, n)`` tensors.

    The product of the row norms is clamped from below at ``eps`` to avoid
    division by zero, so a zero row yields a similarity of 0.

    Args:
        a: tensor of shape ``(batch, n)``.
        b: tensor of shape ``(batch, n)``.
        eps: lower bound for the product of the norms.

    Returns:
        Tensor of shape ``(batch,)``.
    """
    dot_prod = torch.einsum('bn,bn->b', a, b)
    vecs_lens = torch.norm(a, dim=1) * torch.norm(b, dim=1)
    # clamp keeps eps on the norms' device and dtype; the previous
    # torch.Tensor([eps] * n) was always a CPU float32 tensor.
    return dot_prod / torch.clamp(vecs_lens, min=eps)