def prune_model(model, amount, prune_mask, method=prune.L1Unstructured):
    """Globally prune fc1..fc4 of ``model`` on top of an accumulated mask.

    The running ``prune_mask`` is first re-applied to every ``torch.nn.Linear``
    module. A single ``prune.global_unstructured`` pass then removes ``amount``
    of the weights across fc1..fc4. The resulting per-layer masks are AND-ed
    back into ``prune_mask`` in place, and ``prune.remove`` makes the zeros
    permanent in ``weight``.

    Args:
        model: module exposing ``fc1``..``fc4`` and a ``mask_to_device`` method.
        amount: pruning amount forwarded to ``prune.global_unstructured``.
        prune_mask: dict mapping each Linear module name to its mask. It is
            updated in place via ``torch.logical_and(..., out=...)``.
        method: pruning method class used for the global pass.

    Returns:
        The same ``model`` object, moved to CPU and pruned.
    """
    model.to('cpu')
    model.mask_to_device('cpu')

    linear_layers = [
        (layer_name, layer)
        for layer_name, layer in model.named_modules()
        if isinstance(layer, torch.nn.Linear)
    ]

    # Re-apply the mask accumulated by previous pruning rounds.
    for layer_name, layer in linear_layers:
        prune.custom_from_mask(layer, "weight", prune_mask[layer_name])

    targets = tuple(
        (getattr(model, attr), 'weight') for attr in ('fc1', 'fc2', 'fc3', 'fc4')
    )
    prune.global_unstructured(targets, pruning_method=method, amount=amount)

    # Fold the new masks into the running mask, then bake the zeros into the
    # weights (drops weight_orig / weight_mask from the module).
    for layer_name, layer in linear_layers:
        torch.logical_and(layer.weight_mask, prune_mask[layer_name],
                          out=prune_mask[layer_name])
        prune.remove(layer, 'weight')

    return model
Example #2
0
def iou(out, labels):
    """Compute mean intersection-over-union between logits and labels.

    Two layouts are supported:

    * Multi-class: ``out`` is 4-D with more than one channel in dim 1. The
      prediction is the argmax over dim 1. IoU is computed for every class
      present in ``labels`` except 0 (presumably background). Each class IoU
      is averaged over the batch, and the per-class values are averaged into
      a float. If ``labels`` contains no non-zero class, ``nan`` is returned.
    * Otherwise: ``out`` holds binary-mask logits. The prediction is
      ``round(sigmoid(out))`` and the batch-mean IoU is returned as a 0-d
      tensor.

    Intersections and unions are summed over the last two dims. ``1e-8`` is
    added to both, so a sample where a class is absent from both prediction
    and target contributes an IoU of 1.
    """
    with torch.no_grad():
        if out.dim() == 4 and out.size(1) > 1:
            prediction = torch.argmax(out, dim=1)
            ious = []
            for cat in torch.unique(labels):
                cat = int(cat)
                if cat == 0:
                    continue
                pred_c = prediction == cat
                label_c = labels == cat
                intersection = torch.logical_and(pred_c, label_c).sum(-1).sum(-1)
                union = torch.logical_or(pred_c, label_c).sum(-1).sum(-1)
                ious.append(
                    torch.mean((intersection + 1e-8) / (union + 1e-8)).item())
            if not ious:
                # np.mean([]) would emit "Mean of empty slice" and return nan;
                # return the undefined result explicitly instead.
                return float('nan')
            return np.mean(ious)

        prediction = torch.round(torch.sigmoid(out))
        intersection = torch.logical_and(prediction, labels).sum(-1).sum(-1)
        union = torch.logical_or(prediction, labels).sum(-1).sum(-1)
        return torch.mean((intersection + 1e-8) / (union + 1e-8))
Example #3
0
def _get_triplet_mask(labels):
    """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.
    A triplet (i, j, k) is valid if:
        - i, j, k are distinct
        - labels[i] == labels[j] and labels[i] != labels[k]
    Args:
        labels: tf.int32 `Tensor` with shape [batch_size]
    """
    # Check that i, j and k are distinct
    indices_equal = torch.eye(labels.shape[0], device=DEVICE).bool()
    indices_not_equal = indices_equal.logical_not()
    i_not_equal_j = indices_not_equal.unsqueeze(2)
    i_not_equal_k = indices_not_equal.unsqueeze(1)
    j_not_equal_k = indices_not_equal.unsqueeze(0)

    distinct_indices = torch.logical_and(torch.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k)


    # Check if labels[i] == labels[j] and labels[i] != labels[k]
    label_equal = torch.eq(labels.unsqueeze(0), labels.unsqueeze(1))
    i_equal_j = label_equal.unsqueeze(2)
    i_equal_k = label_equal.unsqueeze(1)

    valid_labels = torch.logical_and(i_equal_j, torch.logical_not(i_equal_k))

    # Combine the two masks
    mask = torch.logical_and(distinct_indices, valid_labels)

    return mask
Example #4
0
def _get_triplet_mask(labels):
    """ Returns a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.
    A triplet (i, j, k) is valid if:
        - i, j, k are distinct
        - labels[i] == labels[j] and labels[i] != labels[k]

    Args:
        labels: Long `Tensor` with shape [batch_size]
    Returns:
        Mask: bool `Tensor` with shape [batch_size, batch_size].
    """

    # Check that i, j and k are distinct
    indices_equal = torch.eye(labels.size()[0]).bool().to(device)
    indices_not_equal = torch.logical_not(indices_equal)
    i_not_equal_j = torch.unsqueeze(indices_not_equal, 2)
    i_not_equal_k = torch.unsqueeze(indices_not_equal, 1)
    j_not_equal_k = torch.unsqueeze(indices_not_equal, 0)

    distinct_indices = torch.logical_and(torch.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k)

    # Check if labels[i] == labels[j] and labels[i] != labels[k]
    label_equal = torch.eq(torch.unsqueeze(labels, 0), torch.unsqueeze(labels, 1))
    i_equal_j = torch.unsqueeze(label_equal, 2)
    i_equal_k = torch.unsqueeze(label_equal, 1)

    valid_labels = torch.logical_and(i_equal_j, torch.logical_not(i_equal_k))

    # Combine the two masks
    mask = torch.logical_and(distinct_indices, valid_labels)

    return mask
def scoring(targets, predictions, verbose=True):
    """Compute binary classification metrics for 0/1 targets and predictions.

    Args:
        targets: sequence of ground-truth labels (1 = positive, 0 = negative).
        predictions: sequence of predicted labels, same length as ``targets``.
        verbose: if True, print every metric.

    Returns:
        dict with keys ``"acc"``, ``"f1"``, ``"precision"``, ``"recall"`` and
        ``"mcc"``. A metric whose denominator is zero (e.g. precision when no
        positive was predicted) is reported as 0.0 instead of raising
        ``ZeroDivisionError``.
    """
    def _ratio(numerator, denominator):
        # Undefined metrics (zero denominator) fall back to 0.0.
        return numerator / denominator if denominator else 0.0

    acc = accuracy_score(targets, predictions)
    if verbose:
        print("Accuracy: {}".format(acc))

    targets = torch.tensor(targets)
    predictions = torch.tensor(predictions)
    tp = float(torch.logical_and(predictions == 1, targets == 1).sum())
    fp = float(torch.logical_and(predictions == 1, targets == 0).sum())
    tn = float(torch.logical_and(predictions == 0, targets == 0).sum())
    fn = float(torch.logical_and(predictions == 0, targets == 1).sum())

    recall = _ratio(tp, tp + fn)
    precision = _ratio(tp, tp + fp)
    mcc = _ratio(tp * tn - fp * fn,
                 sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
    f1 = _ratio(2 * recall * precision, recall + precision)

    if verbose:
        print("F1-Score: {}".format(f1))
        print("Precision: {}".format(precision))
        print("Recall: {}".format(recall))
        print("MCC: {}".format(mcc))
    metrics = {
        "acc": acc,
        "f1": f1,
        "precision": precision,
        "recall": recall,
        "mcc": mcc
    }
    return metrics
    def apply(self, gt_boxes, pred_boxes, gt_labels, pred_labels, pred_scores,
              input_data):
        """Keep only boxes whose center range lies in [min_radius, max_radius].

        The range of a box is the Euclidean norm of its ``gravity_center``.
        Ground-truth and predicted boxes are filtered with the same inclusive
        bounds, so a GT box and a prediction at exactly ``self._max_radius``
        are treated consistently.

        Returns:
            (gt_boxes, pred_boxes, gt_labels, pred_labels, pred_scores,
            input_data), where the first five come from
            ``self._apply_box_mask`` and ``input_data`` is passed through.
        """
        def in_range(boxes):
            ranges = torch.norm(boxes.gravity_center, dim=1)
            return torch.logical_and(ranges >= self._min_radius,
                                     ranges <= self._max_radius)

        gt_boxes_mask = in_range(gt_boxes)
        pred_boxes_mask = in_range(pred_boxes)

        (
            gt_boxes,
            pred_boxes,
            gt_labels,
            pred_labels,
            pred_scores,
        ) = self._apply_box_mask(
            gt_boxes_mask,
            pred_boxes_mask,
            gt_boxes,
            pred_boxes,
            gt_labels,
            pred_labels,
            pred_scores,
        )

        return gt_boxes, pred_boxes, gt_labels, pred_labels, pred_scores, input_data
Example #7
0
    def test(self, model, nClass):
        model.to(self.device)
        model.eval()
        self.confMatrix = torch.zeros((nClass, nClass),
                                      dtype=torch.int64).cpu()

        predictions = torch.Tensor().cpu().type(torch.int64)
        labels = torch.Tensor().cpu().type(torch.int64)

        with torch.no_grad():
            for i, (x, y) in enumerate(self.testloader):
                x = x.to(self.device)
                y = y.cpu().type(torch.int64)
                pred = model(x).cpu().type(torch.int64)
                _, pred = torch.max(pred, dim=1)
                predictions = torch.cat([predictions, pred])
                labels = torch.cat([labels, y])

        for i in range(self.confMatrix.size(0)):
            for j in range(i, self.confMatrix.size(1)):
                self.confMatrix[i, j] += torch.logical_and(
                    predictions.eq(i), labels.eq(j)).sum().cpu()
                if i != j:
                    self.confMatrix[j, i] += torch.logical_and(
                        predictions.eq(j), labels.eq(i)).sum().cpu()
def train():
    """Run one optimisation step of the paper/mesh link predictor.

    Negative edges are drawn with ``negative_sampling``. Pairs whose two
    endpoints both have index >= ``data.x_paper.size(0)`` (mesh-mesh) are
    discarded. Remaining negatives get type 1.0 when exactly one endpoint is a
    paper node and 0.0 otherwise, and are sorted by that type.

    The loss is ``0.5 * ((1 - alpha) * loss_pp + alpha * loss_pm)`` with
    ``alpha = 0.7``. ``loss_pp`` and ``loss_pm`` are BCE-with-logits losses on
    the type-0 and type-1 edges respectively.

    Uses the module-level ``model``, ``data``, ``optimizer``,
    ``negative_sampling``, ``get_link_labels``, ``F`` and ``roc_auc_score``.

    Returns:
        (loss, roc_auc_all, roc_auc_type0, roc_auc_type1)
    """
    alpha = 0.7
    model.train()

    pos_index = data.train_pos_edge_index
    pos_type = data.train_pos_edge_type
    n_papers = data.x_paper.size(0)

    # negative sampling
    neg_row, neg_col = negative_sampling(pos_index,
                                         num_nodes=data.num_nodes,
                                         num_neg_samples=pos_index.size(1))
    # drop mesh-mesh pairs: keep pairs with at least one paper endpoint
    keep = torch.logical_or(neg_row < n_papers, neg_col < n_papers)
    neg_row, neg_col = neg_row[keep], neg_col[keep]
    # 1.0 when exactly one endpoint is a paper node, else 0.0
    neg_type = torch.logical_xor(neg_row < n_papers,
                                 neg_col < n_papers).to(torch.float32)
    order = torch.argsort(neg_type)
    neg_index = torch.stack([neg_row, neg_col], dim=0)[:, order]
    neg_type = neg_type[order]

    optimizer.zero_grad()
    z = model.encode()
    link_logits = model.decode(z, pos_index, pos_type, neg_index, neg_type)
    link_labels = get_link_labels(pos_index, neg_index)

    pos_is_pp, neg_is_pp = pos_type == 0, neg_type == 0
    pos_is_pm, neg_is_pm = pos_type == 1, neg_type == 1
    logits_pp = model.decode(z, pos_index[:, pos_is_pp], pos_type[pos_is_pp],
                             neg_index[:, neg_is_pp], neg_type[neg_is_pp])
    logits_pm = model.decode(z, pos_index[:, pos_is_pm], pos_type[pos_is_pm],
                             neg_index[:, neg_is_pm], neg_type[neg_is_pm])
    labels_pp = get_link_labels(pos_index[:, pos_is_pp], neg_index[:, neg_is_pp])
    labels_pm = get_link_labels(pos_index[:, pos_is_pm], neg_index[:, neg_is_pm])

    loss_pp = F.binary_cross_entropy_with_logits(logits_pp, labels_pp)
    loss_pm = F.binary_cross_entropy_with_logits(logits_pm, labels_pm)
    loss = 0.5 * ((1 - alpha) * loss_pp + alpha * loss_pm)
    loss.backward()
    optimizer.step()

    def roc(labels, logits):
        return roc_auc_score(labels.detach().cpu().numpy(),
                             logits.sigmoid().detach().cpu().numpy())

    return (loss, roc(link_labels, link_logits), roc(labels_pp, logits_pp),
            roc(labels_pm, logits_pm))
Example #9
0
    def forward(self, batch_dict):
        """Append per-point colours sampled from camera images to ``points``.

        ``cp_points`` rows are read as
        ``[batch_id, cam_a, x_a, y_a, cam_b, x_b, y_b]`` (columns 0..6).
        Cameras are numbered 1..5 and map to ``batch_dict['image_0'..'image_4']``.
        Each image is indexed as ``image[batch_id][y, x]`` and unpacked as
        ``h, w, _``; presumably (H, W, 3) — TODO confirm. Pixel coordinates are
        clamped into the image.

        The two candidate colours are combined with an element-wise maximum; a
        point with no camera hit keeps zeros. The resulting (N, 3) tensor is
        concatenated to ``batch_dict['points']`` along dim 1.
        """
        points = batch_dict['points']
        cp_points = batch_dict['cp_points']
        colors = torch.zeros(points.shape[0], 6).to(points.device)  # (N, 6)
        cameras = [None] + [batch_dict["image_{}".format(k)] for k in range(5)]

        # (camera-id column, x column, y column, destination colour columns)
        projections = ((1, 2, 3, slice(0, 3)), (4, 5, 6, slice(3, 6)))

        for batch_id in range(batch_dict['batch_size']):
            in_batch = cp_points[:, 0] == batch_id
            for cam_id in range(1, 6):
                image = cameras[cam_id][batch_id]
                h, w, _ = image.size()
                for id_col, x_col, y_col, dest in projections:
                    hit = torch.logical_and(in_batch, cp_points[:, id_col] == cam_id)
                    xs = torch.clamp(cp_points[hit, x_col].long(), 0, w - 1)
                    ys = torch.clamp(cp_points[hit, y_col].long(), 0, h - 1)
                    colors[hit, dest] = image[ys, xs]

        rgb = torch.maximum(colors[:, :3], colors[:, 3:])
        batch_dict['points'] = torch.cat([batch_dict['points'], rgb], dim=1)
        return batch_dict
Example #10
0
    def test_step(self, batch, batch_idx):
        """Evaluate one test batch and accumulate confusion-matrix counts.

        Outputs above 0.5 count as positive predictions. When
        ``self.hparams.amount_labels`` is 1, only the first label column is
        scored. TP/FP/TN/FN counts are added to ``self.test_conf_matrix``, and
        the test loss and accuracy are logged on an ``EvalResult``, which is
        returned.
        """
        input_ids, attention_mask, labels = batch
        outputs = self.forward(input_ids, attention_mask)
        loss = self.loss_function(outputs, labels)

        ones = torch.ones(outputs.shape, device=self.device)
        zeros = torch.zeros(outputs.shape, device=self.device)
        predictions = torch.where(outputs > 0.5, ones, zeros)

        if int(self.hparams.amount_labels) == 1:
            # single-label setup: score only the first column
            labels = labels[:, 0].view(-1)
            predictions = predictions[:, 0].view(-1)

        acc = Accuracy()(predictions, labels)

        # (counter key, predicted value, label value)
        cells = (("tp", 1, 1), ("fp", 1, 0), ("tn", 0, 0), ("fn", 0, 1))
        for key, pred_value, label_value in cells:
            hits = torch.logical_and(predictions == pred_value, labels == label_value)
            self.test_conf_matrix[key] += float(hits.sum())

        result = EvalResult()
        result.log("test_loss", loss)
        result.log("acc", acc)
        return result
Example #11
0
    def compute_metric_counts(self, attribute_preds, attribute_labels):
        """Accumulate per-attribute TP/TN/FP/FN counts for one batch.

        ``attribute_preds`` are raw logits (training uses BCEWithLogits), so a
        sigmoid is applied and the result rounded to 0/1. Predictions and
        labels >= 1 are positive. Counts are summed over dim 0 and added to
        ``self.true_pos_count``, ``self.true_neg_count``,
        ``self.false_pos_count`` and ``self.false_neg_count``.
        """
        predicted = torch.round(torch.sigmoid(attribute_preds)).cpu()
        actual = attribute_labels.cpu()

        pred_pos = predicted >= 1
        pred_neg = ~pred_pos
        label_pos = actual >= 1
        label_neg = ~label_pos

        def count(a, b):
            return torch.sum(torch.logical_and(a, b).int(), dim=0)

        self.true_pos_count += count(pred_pos, label_pos)
        self.true_neg_count += count(pred_neg, label_neg)
        self.false_pos_count += count(pred_pos, label_neg)
        self.false_neg_count += count(pred_neg, label_pos)
Example #12
0
    def forward(self, sample):
        """Randomly crop ``sample['image']`` and move its vertebra keypoints.

        Optional padding (``self.padding``, plus enough to reach ``self.size``
        when ``self.pad_if_needed``) is applied before cropping. Keypoint
        coordinates are shifted by the left/top padding and by the crop offset.
        Keypoints falling outside the crop are dropped.

        ``vertebrae`` rows hold x in column 1 and y in column 2 (x is checked
        against the crop width, y against the height). ``F`` is presumably
        ``torchvision.transforms.functional``; the padding offsets follow its
        ``pad`` convention (int / [left-right, top-bottom] /
        [left, top, right, bottom]) — confirm. The input ``sample`` is not
        modified.

        Returns:
            dict with the cropped ``'image'``, the filtered ``'vertebrae'`` and
            the original ``'info'``.
        """
        img = sample['image']
        # Work on a copy: shifting in place would corrupt the caller's sample
        # (e.g. a cached dataset item would drift every time it is cropped).
        vertebrae = sample['vertebrae'].clone()

        def leading_pad(padding):
            # (left, top) padding for an int, [p], [lr, tb] or [l, t, r, b].
            if isinstance(padding, int):
                return padding, padding
            if len(padding) == 1:
                return padding[0], padding[0]
            return padding[0], padding[1]

        shift_x = shift_y = 0
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)
            pad_left, pad_top = leading_pad(self.padding)
            shift_x += pad_left
            shift_y += pad_top

        width, height = F._get_image_size(img)
        # pad the width if needed (the same amount is added on the left)
        if self.pad_if_needed and width < self.size[1]:
            pad_w = self.size[1] - width
            img = F.pad(img, [pad_w, 0], self.fill, self.padding_mode)
            shift_x += pad_w
        # pad the height if needed (the same amount is added on the top)
        if self.pad_if_needed and height < self.size[0]:
            pad_h = self.size[0] - height
            img = F.pad(img, [0, pad_h], self.fill, self.padding_mode)
            shift_y += pad_h

        top, left, h, w = self.get_params(img, self.size)
        cropped_img = F.crop(img, top, left, h, w)

        vertebrae[:, 1] += shift_x - left
        vertebrae[:, 2] += shift_y - top

        left_check = torch.logical_and(vertebrae[:, 1] < w,
                                       vertebrae[:, 1] >= 0)
        top_check = torch.logical_and(vertebrae[:, 2] < h,
                                      vertebrae[:, 2] >= 0)
        vertebrae = vertebrae[torch.logical_and(top_check, left_check)]

        return {
            'image': cropped_img,
            'vertebrae': vertebrae,
            'info': sample['info']
        }
Example #13
0
    def prune_CambriconS(self, q=0.75):
        """Coarse-grained (Cambricon-S style) pruning of ``self.conv``.

        Output channels are grouped into chunks of ``self.chunk_size``; the
        last chunk may be smaller. For each chunk, the mean absolute weight
        over the chunk's output channels is computed at every remaining
        position. Positions whose mean is not above ``q * std(weights)`` are
        zeroed for every channel in that chunk.

        Sets ``self.mask`` (bool, same shape and device as the weights),
        replaces ``self.conv.weight.data`` with the masked weights, and stores
        the number of zero weights in ``self.sparsity``.
        """
        chunk_size = self.chunk_size
        weights = self.conv.weight.data
        # Allocate on the weights' device instead of a hard-coded .cuda().
        mask = torch.ones(weights.shape, dtype=torch.bool, device=weights.device)
        cutoff = torch.std(weights) * q

        for start in range(0, self.out_channels, chunk_size):
            chunk = weights[start:start + chunk_size]
            # Divide by the actual chunk length so a short last chunk is handled.
            mean_abs = torch.sum(torch.abs(chunk), dim=0) / chunk.size(0)
            # Broadcast the per-position decision over the chunk's channels.
            mask[start:start + chunk_size] &= mean_abs > cutoff

        self.mask = mask
        # prune the weights
        self.conv.weight.data = weights.float() * mask.float()
        # calculate sparsity (number of zero weights)
        self.sparsity = (self.conv.weight.data.numel()
                         - self.conv.weight.data.nonzero().size(0))
Example #14
0
def bounded_mse_loss(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    less_than_target: torch.Tensor,
    greater_than_target: torch.Tensor,
) -> torch.Tensor:
    """
    Loss function for use with regression when some targets are presented as inequalities.

    :param predictions: Model predictions with shape(batch_size, tasks).
    :param targets: Target values with shape(batch_size, tasks).
    :param less_than_target: A tensor with boolean values indicating whether the target is a less-than inequality.
    :param greater_than_target: A tensor with boolean values indicating whether the target is a greater-than inequality.
    :return: A tensor containing loss values of shape(batch_size, tasks).
    """
    # A prediction that already satisfies "< target" is snapped to the target,
    # so it contributes zero loss.
    predictions = torch.where(
        torch.logical_and(predictions < targets, less_than_target), targets,
        predictions)

    # Likewise for predictions that already satisfy "> target".
    predictions = torch.where(
        torch.logical_and(predictions > targets, greater_than_target),
        targets,
        predictions,
    )

    return nn.functional.mse_loss(predictions, targets, reduction="none")
Example #15
0
    def compute_iou(self, pred: Tensor, occ_mask: Tensor, loss_dict: dict,
                    invalid_mask: Tensor):
        """
        Store the two-class (occluded / non-occluded) IOU in ``loss_dict['iou']``.

        Predictions above 0.5 are "occluded". ``~invalid_mask`` marks valid
        non-occluded ground truth. Predicted pixels only enter a union where
        ground truth exists (occluded or valid). The stored value is
        ``(inter_occ + inter_noc) / (union_occ + union_noc)``.

        :param pred: occlusion prediction [N,H,W]
        :param occ_mask: ground truth occlusion mask [N,H,W]
        :param loss_dict: dictionary of losses, updated in place
        :param invalid_mask: invalid disparities (including occ and places without data), [N,H,W]
        """
        occ_pred = pred > 0.5
        noc_pred = ~occ_pred
        noc_gt = ~invalid_mask

        occ_inter = torch.logical_and(occ_pred, occ_mask).sum()
        occ_union = torch.logical_or(torch.logical_and(occ_pred, noc_gt),
                                     occ_mask).sum()

        noc_inter = torch.logical_and(noc_pred, noc_gt).sum()
        noc_union = torch.logical_or(torch.logical_and(noc_pred, occ_mask),
                                     noc_gt).sum()

        total_inter = (occ_inter + noc_inter).float()
        loss_dict['iou'] = total_inter / (occ_union + noc_union)
Example #16
0
def ctc_shrink(logits, pad, blk):
    """Shrink CTC frame logits to one frame per emitted token.

    Tokens are the per-frame argmax. A frame "fires" when its token is not
    ``blk``, not ``pad``, and differs from the previous frame's token. Only
    the first frame of a run of repeats is kept.

    Args:
        logits: (B, T, V) tensor of per-frame scores.
        pad: id of the padding token.
        blk: id of the blank token.

    Returns:
        (logits_shrunk, len_decode). ``logits_shrunk`` is (B, L, V), where L
        is the largest number of fired frames in the batch; shorter sequences
        are zero-padded in ``logits``' dtype and device. ``len_decode`` (B,)
        holds the number of fired frames per sequence.
    """
    B, T, V = logits.size()
    tokens = torch.argmax(logits, -1)
    # Previous token per frame; -1 before the first frame so it always differs.
    prev = torch.cat([torch.full_like(tokens[:, :1], -1), tokens[:, :-1]], dim=1)
    fires = (tokens != blk) & (tokens != pad) & (tokens != prev)

    len_decode = fires.sum(-1)
    max_decode_len = int(len_decode.max())

    rows = []
    for b in range(B):
        kept = logits[b].index_select(0, torch.where(fires[b])[0])
        # Pad with zeros of logits' own dtype/device so cat never mixes dtypes.
        padding = logits.new_zeros(max_decode_len - kept.size(0), V)
        rows.append(torch.cat([kept, padding], 0))

    return torch.stack(rows, 0), len_decode
Example #17
0
def eval_part_full(gt, pred, per_instance=False, yaxis_only=False):
    """Evaluate per-part pose errors and add 5deg5cm / 10deg10cm success flags.

    ``eval_part_model`` must return a dict of tensors containing ``'rdiff'``
    and ``'tdiff'``. Success requires ``rdiff <= 5.0`` and ``tdiff <= 0.05``
    (resp. 10.0 / 0.10); the units are presumably degrees and metres per the
    key names — TODO confirm. Every entry is then split along its last dim
    into ``'{key}_{i}'`` entries.

    Returns:
        (pdiff, per_diff). ``pdiff`` maps each split key to its mean over
        dim 0. ``per_diff`` is a deep copy of the unreduced split entries when
        ``per_instance`` is True, otherwise an empty dict.
    """
    pdiff = eval_part_model(gt, pred, yaxis_only=yaxis_only)
    for name, max_rot, max_trans in (('5deg5cm', 5.0, 0.05),
                                     ('10deg10cm', 10.0, 0.10)):
        pdiff[name] = torch.logical_and(pdiff['rdiff'] <= max_rot,
                                        pdiff['tdiff'] <= max_trans).float()

    split = {}
    for key, value in pdiff.items():
        for i in range(value.shape[-1]):
            split[f'{key}_{i}'] = value[..., i]

    per_diff = deepcopy(split) if per_instance else {}
    means = {key: torch.mean(value, dim=0) for key, value in split.items()}
    return means, per_diff
Example #18
0
def edgeTrainer(targeted, model, attacked_node, y_target, device):
    """Train edge weights, then flip the edge whose weight moved the most.

    A snapshot of ``model.edge_weight`` is taken, and ``train`` runs with SGD
    (lr=0.01) over the parameters returned by ``setRequiresGrad(model)``.
    Weight changes that would grow an edge already at 1 or shrink an edge
    already at 0 are zeroed. The weights are then restored to the snapshot,
    and the entry with the largest remaining absolute change is replaced by
    ``not`` its value (0 <-> 1 for binary weights).

    Returns:
        ``model``, modified in place.
    """
    snapshot = model.edge_weight.clone().detach()
    optimizer = torch.optim.SGD(setRequiresGrad(model), lr=0.01)

    train(model=model,
          targeted=targeted,
          attacked_nodes=attacked_node,
          y_targets=y_target,
          optimizer=optimizer)

    with torch.no_grad():
        delta = model.edge_weight - snapshot
        grows_one = torch.logical_and(snapshot == 1, delta > 0).to(device)
        shrinks_zero = torch.logical_and(snapshot == 0, delta < 0).to(device)
        delta[torch.logical_or(grows_one, shrinks_zero).to(device)] = 0

        target_edge = torch.argmax(torch.abs(delta)).to(device)

        # restore the snapshot, then flip the chosen edge
        model.edge_weight.data = snapshot
        flipped = not model.edge_weight.data[target_edge]
        model.edge_weight.data[target_edge] = flipped
    return model
Example #19
0
    def forward(self, b, s, y):
        """Pairwise ranking loss between items sharing a batch id.

        Uses the pairs (i, j) with i < j, ``b[i] == b[j]`` and
        ``logical_xor(y[i], y[j])`` (exactly one label non-zero). Each pair's
        logit is ``s[i] - s[j]`` and its target is ``y[i]`` as float. The loss
        is ``nn.BCEWithLogitsLoss`` (mean reduction).

        Raises:
            EmptyBatchException: if no pair qualifies.

        Returns:
            dict with ``'loss'`` plus detached ``'logits'`` and ``'labels'``.
        """
        idx = torch.arange(b.shape[0], device=b.device)
        same_batch = b.view(-1, 1) == b.view(1, -1)
        labels_differ = torch.logical_xor(y.view(-1, 1), y.view(1, -1))
        upper = idx.view(-1, 1) < idx.view(1, -1)  # each unordered pair once
        pair_mask = torch.logical_and(torch.logical_and(same_batch, labels_differ), upper)

        if not pair_mask.sum().item():
            raise EmptyBatchException

        logits = (s.view(-1, 1) - s.view(1, -1))[pair_mask]
        targets = y.view(-1, 1).repeat(1, y.shape[0])[pair_mask].float()
        loss = nn.BCEWithLogitsLoss()(logits, targets)

        return {
            'loss': loss,
            'logits': logits.detach(),
            'labels': targets.detach(),
        }
Example #20
0
    def create_attention_mask(self, bs, seq_len, windows, block_length,
                              attention_mask):
        """Build the boolean local-attention mask for blocked sequences.

        Positions ``0..seq_len-1`` are reshaped into ``windows`` blocks of
        ``block_length``; the keys for each block come from ``self.look_around``.
        A key is visible to a query when:

        * their sequence ids match (ids are all ones, multiplied in place by
          ``attention_mask`` on the key side when it is given), and
        * ``-self.window_size < key_pos - query_pos <= 0``.

        Matches mesh-tensorflow's implementation:
        https://github.com/tensorflow/mesh/blob/8bd599a21bad01cef1300a8735c17306ce35db6e/mesh_tensorflow/transformer/attention.py#L805

        Returns:
            bool tensor of shape ``(bs, windows, 1, block_length, K)``,
            assuming ``self.look_around`` returns ``(..., windows, K)`` keys
            per block — confirm against ``look_around``.
        """
        positions = torch.arange(seq_len)[None, :]
        query_pos = positions.reshape(1, windows, block_length)
        key_pos = self.look_around(query_pos, block_length, self.window_size)

        # rel[..., k, q] = key_pos[k] - query_pos[q]
        rel = (key_pos.unsqueeze(-2) - query_pos.unsqueeze(-1)).transpose(-1, -2)

        seq_ids = torch.ones(bs, seq_len)
        query_ids = seq_ids.reshape(-1, windows, block_length)
        key_ids = self.look_around(seq_ids.reshape(-1, windows, block_length),
                                   block_length, self.window_size)
        if attention_mask is not None:
            blocked = attention_mask.to(key_ids.device).reshape(-1, windows, block_length)
            key_ids *= self.look_around(blocked, block_length, self.window_size)

        same_seq = torch.eq(query_ids.unsqueeze(-1),
                            key_ids.unsqueeze(-2)).transpose(-1, -2)
        in_window = torch.logical_and(torch.gt(rel, -self.window_size),
                                      torch.less_equal(rel, 0))
        combined = torch.logical_and(same_seq, in_window)
        return combined.transpose(-1, -2).unsqueeze(2)
Example #21
0
    def __call__(self, batch):
        """Collate ``(input_ids, permuted)`` pairs into a masked-LM batch.

        Sequences are right-padded with ``self._pad_token_id``. Each non-pad
        token is selected with probability ``self._mask_prob``. ``labels``
        keep the original id at selected positions and ``TARGET_IDX``
        elsewhere.

        A selected token is replaced as follows:

        * by ``self._mask_token_id`` with probability ``1 - self._random_prob``;
        * otherwise, with probability 0.5, by a random id in
          ``[5, self._vocab_size)``;
        * otherwise it is kept.

        The order of the random draws is significant for reproducibility
        under a fixed seed.

        Returns:
            (input_ids, labels, permuted), with ``permuted`` as a float tensor.
        """
        input_ids, permuted = zip(*batch)
        permuted = torch.tensor(permuted, dtype=torch.float)
        input_ids = pad_sequence(input_ids,
                                 batch_first=True,
                                 padding_value=self._pad_token_id)
        shape = input_ids.shape

        selected = torch.logical_and(torch.rand(shape) < self._mask_prob,
                                     input_ids != self._pad_token_id)
        use_mask_token = torch.rand(shape) < 1 - self._random_prob
        use_random_token = torch.rand(shape) < 0.5

        labels = torch.where(selected, input_ids, TARGET_IDX)

        # replace with the mask token
        input_ids = torch.where(torch.logical_and(selected, use_mask_token),
                                self._mask_token_id, input_ids)

        # replace with a random token
        swap = torch.logical_and(
            selected, torch.logical_and(~use_mask_token, use_random_token))
        random_ids = torch.randint_like(input_ids, low=5, high=self._vocab_size)
        input_ids = torch.where(swap, random_ids, input_ids)

        return input_ids, labels, permuted
def calc_IOU(pred, y, zero_division_safe=False):
    """Return the intersection-over-union of two boolean-like tensors as a float.

    Without ``zero_division_safe``, an empty union gives ``nan``. With it,
    ``1e-8`` is added to the union so an empty union gives 0.0.
    """
    overlap = torch.sum(torch.logical_and(pred, y))
    coverage = torch.sum(torch.logical_or(pred, y))
    if zero_division_safe:
        coverage = coverage + 1e-8
    return (overlap / coverage).item()
Example #23
0
def thread_get_parts(f_lines, f_types, p_scan, gpu='', percentage=1, scan_idx=0):
    """Extract per-annotation point sets from every face file of one scan.

    ``f_lines`` and ``f_types`` are converted with ``coco_to_seg`` and
    concatenated along dim 3 into a lookup tensor on this thread's device.
    Each file in ``p_scan`` is handled as follows:

    * It is loaded with ``load_gz`` as rows of 9 values (layout ``fijxyzdlp``
      per the original comment).
    * It is randomly down-sampled with replacement to ``percentage`` percent
      of its rows.
    * Its points are looked up in the lookup tensor.

    For every unique (annotation_id, asset_id, type_id) whose type is in
    ``asset_types``, while holding ``mutex``:

    * the raw bytes of columns 3-5 of the matching points are reinterpreted
      as float32 triples, averaged, and stored in the global ``annotations``;
    * the points, extended with the type id and a copy of the last column,
      are pushed to ``consumer_queue``.

    Errors are printed with a traceback and not re-raised.
    """
    global annotations
    try:
        thread_id = threading.get_ident()
        device = torch.device(get_gpu(thread_id, gpu))
        print(scan_idx, thread_id, device, len(annotations))

        # annotation_id,asset_id
        lines = torch.tensor(coco_to_seg(f_lines, 'annotation_id,category_id')).to(device)
        # type_id
        types = torch.tensor(coco_to_seg(f_types, 'category_id')).to(device)
        # stack lines and types
        lines = torch.cat((lines, types), dim=3)

        for face_id, face in enumerate(os.listdir(p_scan)):
            # os.path.join works whether or not p_scan ends with a separator.
            filename = os.path.join(p_scan, face)
            dface = torch.tensor(load_gz(filename, (-1, 9))).to(device)  # fijxyzdlp
            if not dface.shape[0]:  # empty face file
                continue

            # random down-sample (with replacement) to `percentage` percent
            n_keep = int(dface.shape[0] * percentage / 100)
            dface = dface[torch.randint(dface.shape[0], (n_keep,))]

            # NOTE(review): lines is indexed with dface columns (0, 2, 1); the
            # original author was unsure of the i/j order — confirm.
            t1 = lines[dface[:, 0].long(), dface[:, 2].long(), dface[:, 1].long()]
            # keep rows where both asset id (col 1) and type id (col 2) are valid
            asset_ids = t1[torch.logical_and(~torch.eq(t1[:, 1], -1.0),
                                             ~torch.eq(t1[:, 2], -1.0))]
            if not asset_ids.shape[0]:
                continue

            unique_asset_ids = torch.unique(asset_ids[:, :3], dim=0)
            for annotation_id, asset_id, type_id in np.array(unique_asset_ids.cpu()):
                if type_id not in asset_types:
                    print('!!! Warning !!! Unknown type ', type_id)
                    continue

                mask = torch.logical_and(
                    torch.eq(t1[:, 0], annotation_id),
                    torch.logical_and(torch.eq(t1[:, 1], asset_id),
                                      torch.eq(t1[:, 2], type_id)))
                curr = dface[mask]
                ones = torch.ones((curr.shape[0], 1), device=device)

                # Hold the lock via a context manager so it is released even if
                # anything below raises; a leaked lock would block every other
                # thread that acquires `mutex`.
                with mutex:
                    center = np.mean(
                        np.frombuffer(np.array(curr[:, [3, 4, 5]].cpu()),
                                      dtype=np.float32).reshape((-1, 3)),
                        axis=0)
                    annotations[annotation_id] = center

                    curr = torch.cat((curr, ones * type_id, curr[:, -1:]), dim=1)  # fijxyzdlptp
                    curr = np.array(curr.cpu(), dtype=np.int32)
                    consumer_queue.put((curr, scan_idx, face_id, asset_id, type_id))
    except Exception:
        print("**************** Unexpected error on gpu: ", get_gpu(threading.get_ident(), gpu))
        print(sys.exc_info())
        traceback.print_exc()
Example #24
0
 def backward(ctx, grad_output):
     """Backward pass that masks gradients on out-of-range inputs.

     ``ctx.saved_tensors`` holds ``(X, qmin, qmax)``. Gradient entries are
     zeroed where ``X < qmin`` and the gradient is positive, or where
     ``X > qmax`` and the gradient is negative. All other entries pass
     through unchanged. ``grad_output`` itself is not modified.

     Returns:
         The gradient for the first input followed by five ``None`` values.
     """
     X, qmin, qmax = ctx.saved_tensors
     grad_input = grad_output.detach().clone()
     below_and_rising = torch.logical_and(X < qmin, grad_input > 0)
     above_and_falling = torch.logical_and(X > qmax, grad_input < 0)
     grad_input[torch.logical_or(below_and_rising, above_and_falling)] = 0
     return (grad_input,) + (None,) * 5
Example #25
0
    def disagreement_proc(self, labels_confi, la, lb, l_compare):
        """
        Filter confident labels for the disagreement exchange of unlabeled data.

        A position stays True only if all of the following hold:

        * it is already True in ``labels_confi``;
        * ``la`` and ``lb`` agree there;
        * both differ from ``l_compare``.

        Returns a bool tensor.
        """
        agree = la == lb
        both_differ = torch.logical_and(la != l_compare, lb != l_compare)
        return torch.logical_and(torch.logical_and(labels_confi, agree), both_differ)
Example #26
0
def cal_Hamming(feat1, feat2, mask1=None, mask2=None):
    """Return the fraction of masked positions where ``feat1`` and ``feat2`` disagree.

    Elements are compared by truthiness (``torch.logical_xor``), so any
    non-zero value counts as "set". Only positions where both masks are True
    are counted.

    Args:
        feat1, feat2: tensors to compare (must broadcast together).
        mask1, mask2: optional boolean masks for ``feat1`` / ``feat2``. A
            missing mask defaults to all-True with its feature's shape; a
            supplied mask is always honoured, even if the other is omitted.

    Returns:
        0-d float tensor: mismatches inside the combined mask divided by the
        number of positions in that mask (NaN when the combined mask is empty).
    """
    # Default each mask independently so a caller-supplied mask is never discarded.
    if mask1 is None:
        mask1 = torch.ones_like(feat1, dtype=torch.bool)
    if mask2 is None:
        mask2 = torch.ones_like(feat2, dtype=torch.bool)
    mask = torch.logical_and(mask1, mask2)
    mismatches = torch.logical_and(torch.logical_xor(feat1, feat2), mask)
    return mismatches.to(torch.float).sum() / mask.to(torch.float).sum()
Example #27
0
    def _get_error_edges_new_labels(self):
        """Gets error edges set and edge label for editing the model.

        Runs the model's adjacency prediction, compares it with the ground
        truth ``self.gt_adj_torch`` (a 2-D ground truth is treated as a single
        relation channel) and builds the targets used to edit the model.

        Two kinds of error edges are used:
          * "false positives": ground truth has no connection, prediction
            differs -> relabelled 0;
          * "false negatives" of relation channel 0: ground truth is set,
            prediction differs -> relabelled 1.

        Returns:
          (error edges set,
            edge prediction of old model,
            edge label for model editing,
            edge weights for cross-entropy loss in model editing).

          * error edges set: ``(k, 2)`` tensor of ``(row, col)`` indices,
            false positives first, then false negatives.
          * edge prediction of old model: detached CPU copy of the prediction.
          * edge label: flat ``num_nodes * num_nodes`` tensor of argmax class
            indices over [no-connection, relation channels], with error edges
            overwritten by their corrected labels.
          * edge weights: flat tensor on ``self.device``. Error edges get
            ``1 / k`` and all other entries get ``1 / (n*n - n - k)``.
            Diagonal entries are then set to 0. When ``k == 0`` no per-edge
            reweighting is done.
        """
        self.model.eval()
        old_adj_pred, _ = self.model.pred_adj()
        old_adj_pred = old_adj_pred.detach().cpu()
        if self.gt_adj_torch.dim() == 3:
            adj_gt = self.gt_adj_torch
        else:
            adj_gt = self.gt_adj_torch.unsqueeze(2)

        # "No connection" indicator: 1 exactly where the channels sum to 0.
        no_con_adj_gt = 1.0 - adj_gt.sum(dim=2)
        no_con_adj_pred = 1.0 - old_adj_pred.sum(dim=2)

        # Entry 0: ground truth says "no connection" but the prediction differs.
        error_edge_idxs = [
            torch.nonzero(
                torch.logical_and(no_con_adj_gt != no_con_adj_pred,
                                  no_con_adj_gt == 1.0))
        ]
        # Entries 1..: per relation channel, set in ground truth but mispredicted.
        for i in range(self.args.num_relation_types - 1):
            error_edge_idxs.append(
                torch.nonzero(
                    torch.logical_and(adj_gt[:, :, i] != old_adj_pred[:, :, i],
                                      adj_gt[:, :, i] == 1.0)))

        num_nodes = self.args.num_nodes
        adj_label = torch.cat(
            (no_con_adj_pred.unsqueeze(dim=2), old_adj_pred),
            dim=-1).argmax(dim=2).reshape(num_nodes * num_nodes)

        # Only the first two error sets are used for editing.
        false_pos_idxs, false_neg_idxs = error_edge_idxs[0], error_edge_idxs[1]
        edit_edge_idxs = torch.cat((false_pos_idxs, false_neg_idxs), dim=0)
        # Flatten (row, col) pairs into indices of the row-major flat adjacency.
        false_pos_flat = false_pos_idxs[:, 0] * num_nodes + false_pos_idxs[:, 1]
        false_neg_flat = false_neg_idxs[:, 0] * num_nodes + false_neg_idxs[:, 1]
        adj_label[false_pos_flat] = 0
        adj_label[false_neg_flat] = 1

        weights = torch.ones(size=(len(adj_label), ), device=self.device) / (
            num_nodes * num_nodes - num_nodes - len(edit_edge_idxs))
        num_edits = len(false_pos_flat) + len(false_neg_flat)
        # With no error edges there is nothing to reweight (and 1.0 / 0 would raise).
        if num_edits > 0:
            weights[false_pos_flat] = 1.0 / num_edits
            weights[false_neg_flat] = 1.0 / num_edits
        # Zero the weight of the diagonal (self-loop) entries.
        ignore_idxs = torch.arange(num_nodes) * num_nodes + torch.arange(
            num_nodes)
        weights[ignore_idxs] = 0.0

        return edit_edge_idxs, old_adj_pred, adj_label, weights
def calc_Specificity(pred, y, zero_division_safe=False):
    """Return the specificity TN / (TN + FP) of binary predictions as a float.

    ``(pred - 1) ** 2`` and ``(y - 1) ** 2`` are non-zero exactly where the
    value differs from 1, so for 0/1 tensors they mark predicted and actual
    negatives. A position is a true negative when both are non-zero.

    Args:
        pred: predicted labels (intended as 0/1 values).
        y: ground-truth labels (intended as 0/1 values).
        zero_division_safe: add ``1e-8`` to the denominator so that input with
            no negatives yields 0.0 instead of NaN.

    Returns:
        Python float.
    """
    is_negative = (y - 1) ** 2
    predicted_negative = (pred - 1) ** 2
    true_negatives = torch.sum(torch.logical_and(predicted_negative, is_negative))
    negatives = torch.sum(is_negative)
    denominator = negatives + 1e-8 if zero_division_safe else negatives
    return (true_negatives / denominator).item()
def relative_attention(q, rpe, v):
    """Attention over a square grid with Manhattan-distance relative position encodings.

    Every key j receives the score ``q_i . rpe[-1]``. Keys inside the
    (2*max_dist-1) x (2*max_dist-1) window around query i that lie on the
    grid and have Manhattan distance ``d < max_dist`` instead receive
    ``q_i . rpe[d]``. The output is the score-weighted sum of values divided
    by the sum of scores (no softmax is applied).

    Args:
        q: query tensor. The code reads ``shape[0]`` as batch, ``shape[1]`` as
            heads, ``shape[-2]`` as sequence length ``L`` and ``shape[-1]`` as
            feature dim — presumably (batch, heads, L, dim); confirm with caller.
        rpe: per-distance encodings with last dim equal to ``q.shape[-1]``.
            Row ``d`` is used for distance ``d``; the last row is the clipped /
            "far" encoding.
        v: value tensor, presumably (batch, heads, L, dim_v) — TODO confirm.

    Returns:
        Tensor of shape (batch, heads, L, v.shape[-1]).

    Note:
        The position tables built below have ``img_dim**2 + 1`` rows, so they
        only line up with the ``L`` query rows when ``L == img_dim**2 + 1``.
        The extra last row is placed at grid row ``img_dim``, column 0
        (presumably a non-spatial token such as a class token — confirm).
    """
    dev = q.device
    clipping_dist = rpe.shape[0] - 1  # index of the last rpe row (same value as max_dist)
    v_sum = v.sum(dim=-2)  # sum of all values; pairs with the baseline score below
    max_dist_enc = rpe[-1]  # encoding used for every key outside the window
    max_dist = rpe.shape[0] - 1
    L = q.shape[-2]
    img_dim = int(math.sqrt(L))
    dim = q.shape[-1]  # NOTE(review): unused
    batch_size = q.shape[0]
    heads = q.shape[1]

    # One column per (dx, dy) window offset, each in [-(max_dist-1), max_dist-1];
    # dx varies fastest (repeat), dy slowest (repeat_interleave). Shape (L, (2*max_dist-1)**2).
    x_diffs = torch.arange(-max_dist + 1, max_dist,
                           device=dev).repeat(max_dist * 2 - 1).expand(L, -1)
    y_diffs = torch.arange(-max_dist + 1, max_dist,
                           device=dev).repeat_interleave(max_dist * 2 -
                                                         1).expand(L, -1)
    # Absolute neighbour column / row for each query row and offset. Grid rows
    # come first in row-major order; one extra row sits at (col 0, row img_dim).
    x_pos = x_diffs + torch.cat([
        torch.arange(img_dim, device=dev).repeat(img_dim).unsqueeze(0).T,
        torch.zeros((1, 1), dtype=torch.long, device=dev)
    ])
    y_pos = y_diffs + torch.cat([
        torch.arange(img_dim,
                     device=dev).repeat_interleave(img_dim).unsqueeze(0).T,
        torch.zeros((1, 1), dtype=torch.long, device=dev) + img_dim
    ])
    diffs = torch.abs(x_diffs) + torch.abs(y_diffs)  # Manhattan distance of each offset
    # A neighbour counts only if it lies on the grid and is closer than max_dist.
    valid = torch.logical_and(
        torch.logical_and(torch.logical_and(x_pos >= 0, x_pos < img_dim),
                          torch.logical_and(y_pos >= 0, y_pos < img_dim)),
        diffs < max_dist)
    diffs[valid != True] = clipping_dist  # point invalid offsets at the last rpe row

    # Scores relative to the far encoding; the last column is therefore 0.
    q_dot_rpe = q @ (rpe - max_dist_enc).T
    q_idx = torch.arange(0, diffs.shape[0],
                         device=dev).unsqueeze(1).expand_as(diffs)
    # Gather, per query row and offset, the score at that offset's distance.
    q_rel = q_dot_rpe[:, :, q_idx, diffs]
    q_rel[:, :, valid != True] = 0

    # Largest flat (row-major) offset reachable from the window.
    rel_window = (clipping_dist - 1) * img_dim + clipping_dist - 1
    v_padded = torch.zeros(
        (batch_size, heads, rel_window * 2 + L, v.shape[-1]), device=dev)
    v_padded[:, :, rel_window:rel_window + L, :] = v
    # v_unfolded[..., i, j, :] == v[..., i + j - rel_window, :] (zero outside 0..L-1).
    v_unfolded = v_padded.unfold(2, rel_window * 2 + 1,
                                 1).permute(0, 1, 2, 4, 3)

    # Baseline: every key scored with the far encoding, (q . rpe[-1]) * sum_j v_j.
    q_max_dist = torch.einsum('...ij,j->...i', q, max_dist_enc)
    out = torch.einsum('...i,...j->...ij', q_max_dist, v_sum)
    local_dim = (clipping_dist * 2 - 1)  # window side length
    # Add the in-window corrections one window row (fixed dy) at a time:
    # offset column i*local_dim + k maps to flat value offset
    # (i - (max_dist-1)) * img_dim + (k - (max_dist-1)).
    for i in range(local_dim):
        out += torch.einsum('...ij,...ijk->...ik',
                            q_rel.narrow(-1, i * local_dim, local_dim),
                            v_unfolded.narrow(-2, img_dim * i, local_dim))
    # Normaliser: the sum over all keys of the same scores used above.
    D_inv = 1. / (q @ (max_dist_enc * L) + q_rel.sum(dim=-1))
    out = out * D_inv.unsqueeze(-1)
    return out
Example #30
0
 def forward(self, x, threshold):
     """Flip values of ``x`` that disagree with their pooled neighbourhood, then clip to [0, 1].

     ``x`` gets a singleton dimension at dim 1 and is passed through
     ``self.avgpool`` (presumably an average-pooling module — confirm); only
     that added dimension is removed again, giving the pooled map ``y``. Then:

     * where ``x > threshold`` and ``y <= self.low_threshold``, 1 is subtracted;
     * where ``x <= threshold`` and ``y > self.high_threshold``, 1 is added;

     and the result is clipped to [0, 1].

     Args:
         x: input tensor (presumably a 0/1 map — confirm). Assumes
             ``self.avgpool`` keeps its non-channel shape so ``y`` lines up
             element-wise with ``x``.
         threshold: value above which an element of ``x`` counts as "on".

     Returns:
         The clipped tensor (broadcast shape of ``x`` and the pooled map).
     """
     # Squeeze only the dim added above. A bare squeeze() would also drop a
     # batch or spatial dim of size 1 and then mis-broadcast against x.
     pooled = self.avgpool(x.unsqueeze(1)).squeeze(1)
     is_on = torch.gt(x, threshold)
     surrounded_by_high = torch.gt(pooled, self.high_threshold)
     surrounded_by_low = torch.logical_not(torch.gt(pooled, self.low_threshold))
     decrease = torch.logical_and(is_on, surrounded_by_low)
     increase = torch.logical_and(torch.logical_not(is_on), surrounded_by_high)
     return torch.clip(x + increase.long() - decrease.long(), min=0, max=1)