Beispiel #1
0
def sum_scan_exclusive(x, dim):
    """Return the reversed running sum of ``x`` along ``dim``.

    Position ``i`` of the result (along ``dim``) holds
    ``x[i] + x[i + 1] + ... + x[last]``; it is evaluated as
    ``total - cumsum(x) + x``.

    Args:
        x (Tensor): values to scan.
        dim (int): dimension along which the scan runs.

    Returns:
        Tensor: a new tensor shaped like ``x``; ``x`` is left untouched.
    """
    # Negated prefix sums: entry i is -(x[0] + ... + x[i]).
    neg_prefix = torch.cumsum(-x, dim=dim)

    # The last slice along `dim` is -total. Clone it so the in-place
    # subtraction below does not read from memory it is overwriting.
    final = neg_prefix.size(dim) - 1
    neg_total = neg_prefix.narrow(dim, final, 1).clone()

    neg_prefix.sub_(neg_total.expand_as(neg_prefix))  # total - cumsum(x)
    neg_prefix.add_(x)                                # put x[i] itself back in
    return neg_prefix
Beispiel #2
0
    def __init__(self, values, env_to_world = None):
        """Store an environment texture and precompute its sampling tables.

        Args:
            values: a ``pyredner.Texture`` or a ``torch.Tensor`` (a tensor is
                wrapped in ``pyredner.Texture``). The texels must be
                contiguous float32, and must live on the GPU exactly when
                ``pyredner.get_use_gpu()`` is true.
            env_to_world: float32 matrix; it is stored together with its
                inverse (``world_to_env``). ``None`` (the default) means a
                freshly created 4x4 identity.

        Attributes set: ``values``, ``env_to_world``, ``world_to_env``,
        ``sample_cdf_ys``, ``sample_cdf_xs`` and ``pdf_norm``.
        """
        # Create the default per call. A tensor written in the signature is
        # built once at definition time and would be shared (and stored) by
        # every instance, so an in-place edit on one map would leak to all.
        if env_to_world is None:
            env_to_world = torch.eye(4, 4)

        # Convert to constant texture if necessary
        if isinstance(values, torch.Tensor):
            values = pyredner.Texture(values)

        assert(values.texels.is_contiguous())
        assert(values.texels.dtype == torch.float32)
        if pyredner.get_use_gpu():
            assert(values.texels.is_cuda)
        else:
            assert(not values.texels.is_cuda)

        assert(env_to_world.dtype == torch.float32)

        # Build sampling table from the per-texel luminance
        # (weighted sum of the first three channels).
        luminance = 0.212671 * values.texels[:, :, 0] + \
                    0.715160 * values.texels[:, :, 1] + \
                    0.072169 * values.texels[:, :, 2]
        height = luminance.shape[0]
        width = luminance.shape[1]

        # For each y, compute CDF over x
        sample_cdf_xs_ = torch.cumsum(luminance, dim = 1)
        # Each row is weighted by sin(pi * (y + 0.5) / height).
        y_weight = torch.sin(
            math.pi * (torch.arange(height,
                dtype = torch.float32, device = luminance.device) + 0.5)
            / float(height))
        # Compute CDF over y from the weighted row totals
        sample_cdf_ys_ = torch.cumsum(sample_cdf_xs_[:, -1] * y_weight, dim = 0)
        pdf_norm = (height * width) / \
            (sample_cdf_ys_[-1].item() * (2 * math.pi * math.pi))

        # Normalize to [0, 1); the 1e-8 floor avoids dividing by zero for
        # all-black rows / an all-black map.
        sample_cdf_xs = (sample_cdf_xs_ - sample_cdf_xs_[:, 0:1]) / \
            torch.max(sample_cdf_xs_[:, (width - 1):width],
                1e-8 * torch.ones(sample_cdf_xs_.shape[0], 1, device = sample_cdf_ys_.device))
        sample_cdf_ys = (sample_cdf_ys_ - sample_cdf_ys_[0]) / \
            torch.max(sample_cdf_ys_[-1], torch.tensor([1e-8], device = sample_cdf_ys_.device))

        self.values = values
        self.env_to_world = env_to_world
        self.world_to_env = torch.inverse(env_to_world).contiguous()
        self.sample_cdf_ys = sample_cdf_ys.contiguous()
        self.sample_cdf_xs = sample_cdf_xs.contiguous()
        self.pdf_norm = pdf_norm
Beispiel #3
0
 def invert(self, head_coords, segment_len, mean_angle, eigenworms):
     """Rebuild skeleton coordinates from eigenworm coefficients.

     The coefficients are multiplied by ``self.eigen_components`` to obtain
     per-segment angles, ``mean_angle`` is added back, every angle is turned
     into a step of length ``segment_len`` and the steps are accumulated
     starting from ``head_coords``.

     Returns:
         Tensor of shape ``(-1, self.n_angles + 1, 2)`` with (x, y) points;
         the first point of each skeleton is the head.
     """
     theta = torch.matmul(eigenworms, self.eigen_components)
     theta += mean_angle.view(-1, 1)

     # Unit step per segment: x follows sin(theta), y follows cos(theta).
     n_seg = self.n_angles
     step_x = torch.sin(theta).view(-1, n_seg)
     step_y = torch.cos(theta).view(-1, n_seg)
     steps = torch.stack((step_x, step_y), dim=2) * segment_len.view(-1, 1, 1)

     # Prepend the head so that the running sum yields absolute coordinates.
     origin = head_coords.view(-1, 1, 2)
     return torch.cumsum(torch.cat([origin, steps], 1), dim=1)
def _h_eigenworms_inv_T(head_x, head_y, segment_l, 
                        mean_angle, eigenworms):
    '''
    Convert one skeleton from eigenworm coefficients back to xy coordinates.

    The coefficients are projected through the first ``eigenworms.size(0)``
    rows of ``EIGENWORMS_COMPONENTS_T``, ``mean_angle`` is added, and the
    resulting angles are integrated segment by segment starting at
    ``(head_x, head_y)``. Returns a two-column tensor of (x, y) rows.
    '''
    n_used = eigenworms.size(0)
    basis = EIGENWORMS_COMPONENTS_T[:n_used]
    theta = torch.mm(eigenworms.view(1, -1), basis)
    theta += mean_angle

    def _accumulate(origin, unit_steps):
        # Scale the unit steps, prepend the head coordinate and integrate.
        steps = unit_steps * segment_l
        path = torch.cat([origin.view(1, 1), steps], 1)
        return torch.cumsum(path, dim=1)

    xs = _accumulate(head_x, torch.sin(theta))
    ys = _accumulate(head_y, torch.cos(theta))

    # Column 0 is x, column 1 is y.
    return torch.cat((xs.view(-1, 1), ys.view(-1, 1)), 1)
Beispiel #5
0
def split(tensor, split_size_or_sections, dim=0):
    """Splits the tensor into chunks.

    If ``split_size_or_sections`` is an integer type, then ``tensor`` will be
    split into equally sized chunks (if possible).
    Last chunk will be smaller if the tensor size along a given dimension
    is not divisible by ``split_size``.
    If ``split_size_or_sections`` is a sequence (list or tuple), then
    ``tensor`` will be split into ``len(split_size_or_sections)`` chunks with
    sizes in ``dim`` according to ``split_size_or_sections``.

    Arguments:
        tensor (Tensor): tensor to split.
        split_size_or_sections (int) or (sequence of int): size of a single
            chunk or sizes for each chunk
        dim (int): dimension along which to split the tensor; negative
            values count from the last dimension.

    Returns:
        tuple of Tensor: the chunks, each one a ``narrow`` view of ``tensor``.

    Raises:
        ValueError: if explicit section sizes do not add up to the size of
            ``tensor`` along ``dim``.
    """
    if dim < 0:
        dim += tensor.dim()
    dim = int(dim)
    dim_size = tensor.size(dim)

    if isinstance(split_size_or_sections, int):
        split_size = split_size_or_sections
        # Ceiling division; the final chunk takes whatever is left over.
        num_splits = (dim_size + split_size - 1) // split_size
        last_split_size = split_size - (split_size * num_splits - dim_size)
        return tuple(
            tensor.narrow(
                dim,
                int(i * split_size),
                int(split_size if i < num_splits - 1 else last_split_size))
            for i in range(num_splits))

    sections = [int(length) for length in split_size_or_sections]
    total = sum(sections)
    if dim_size != total:
        raise ValueError(
            "Sum of split sizes ({}) does not match tensor size ({}) "
            "along dim {}".format(total, dim_size, dim))

    # Accumulate the offsets with Python integers: this stays exact for any
    # size (a float tensor cumsum does not) and works for tuples as well.
    chunks = []
    start = 0
    for length in sections:
        chunks.append(tensor.narrow(dim, start, length))
        start += length
    return tuple(chunks)
Beispiel #6
0
def cumsum(input, dim):
    """Return the cumulative sum of ``input`` along dimension ``dim``."""
    running_total = th.cumsum(input, dim)
    return running_total
    def apply(cls, network_output, num_classes, anchors, conf_thresh):
        """Decode raw region-layer output into per-image detections.

        Args:
            network_output: tensor (or legacy ``Variable``) that can be viewed
                as ``(batch, num_anchors, 5 + num_classes, h * w)``. A 3-D
                input is treated as a single image. NOTE: the tensor is
                decoded IN PLACE (``unsqueeze_``, ``sigmoid_``, ``exp_`` ...).
            num_classes (int): number of classes predicted per anchor.
            anchors: sequence of ``(width, height)`` anchor sizes.
            conf_thresh (float): minimum score for a detection to be kept.

        Returns:
            list: one tensor per image of the batch. Each row is
            ``[x_center, y_center, width, height, score, class_index]`` with
            the box values divided by the grid size. When nothing passes the
            threshold every entry is an empty tensor.
        """
        num_anchors = len(anchors)
        anchors = torch.Tensor(anchors)
        if isinstance(network_output, Variable):
            network_output = network_output.data

        # Check dimensions
        if network_output.dim() == 3:
            network_output.unsqueeze_(0)

        # Variables
        cuda = network_output.is_cuda
        batch = network_output.size(0)
        h = network_output.size(2)
        w = network_output.size(3)

        # Compute xc,yc, w,h, box_score on Tensor
        lin_x = torch.linspace(0, w-1, w).repeat(h, 1).view(h*w)
        lin_y = torch.linspace(0, h-1, h).repeat(w, 1).t().contiguous().view(h*w)
        anchor_w = anchors[:, 0].contiguous().view(1, num_anchors, 1)
        anchor_h = anchors[:, 1].contiguous().view(1, num_anchors, 1)
        if cuda:
            lin_x = lin_x.cuda()
            lin_y = lin_y.cuda()
            anchor_w = anchor_w.cuda()
            anchor_h = anchor_h.cuda()

        network_output = network_output.view(batch, num_anchors, -1, h*w)   # -1 == 5+num_classes (we can drop feature maps if 1 class)
        network_output[:, :, 0, :].sigmoid_().add_(lin_x).div_(w)           # X center
        network_output[:, :, 1, :].sigmoid_().add_(lin_y).div_(h)           # Y center
        network_output[:, :, 2, :].exp_().mul_(anchor_w).div_(w)            # Width
        network_output[:, :, 3, :].exp_().mul_(anchor_h).div_(h)            # Height
        network_output[:, :, 4, :].sigmoid_()                               # Box score

        conf_scores = network_output[:, :, 4, :]

        # Compute class_score
        if num_classes > 1:
            if torch.__version__.startswith('0.3'):
                cls_scores = torch.nn.functional.softmax(Variable(network_output[:, :, 5:, :], volatile=True), 2).data
            else:
                with torch.no_grad():
                    cls_scores = torch.nn.functional.softmax(network_output[:, :, 5:, :], 2)
            # Shared by both version branches: scale the class probabilities
            # by the box confidence and flatten to (batch, anchors, h*w*classes)
            # so the result lines up with the masks built below.
            cls_scores = (cls_scores * conf_scores.unsqueeze(2).expand_as(cls_scores)).transpose(2, 3)
            cls_scores = cls_scores.contiguous().view(cls_scores.size(0), cls_scores.size(1), -1)
        else:
            cls_scores = network_output[:, :, 4, :]

        score_thresh = cls_scores > conf_thresh
        score_thresh_flat = score_thresh.view(-1)

        if score_thresh.sum() == 0:
            return [torch.Tensor([]) for _ in range(batch)]

        # Mask select boxes > conf_thresh
        coords = network_output.transpose(2, 3)[..., 0:4]
        # Repeat every box once per class so coords matches cls_scores.
        coords = coords.unsqueeze(3).expand(coords.size(0), coords.size(1), coords.size(2),
                num_classes, coords.size(3)).contiguous().view(coords.size(0), coords.size(1), -1, coords.size(3))
        coords = coords[score_thresh[..., None].expand_as(coords)].view(-1, 4)
        scores = cls_scores[score_thresh].view(-1, 1)
        idx = torch.arange(num_classes).repeat(batch, num_anchors, w*h)
        if cuda:
            # Follow the input's device instead of moving to the GPU
            # unconditionally, which broke CPU inference.
            idx = idx.cuda()
        idx = idx[score_thresh].view(-1, 1).float()
        detections = torch.cat([coords, scores, idx], dim=1)

        # Get indexes of splits between images of batch
        max_det_per_batch = num_anchors * h * w * num_classes
        slices = [slice(max_det_per_batch * i, max_det_per_batch * (i+1)) for i in range(batch)]
        det_per_batch = torch.IntTensor([int(score_thresh_flat[s].int().sum()) for s in slices])
        split_idx = torch.cumsum(det_per_batch, dim=0)

        # Group detections per image of batch
        boxes = []
        start = 0
        for end in split_idx.tolist():
            boxes.append(detections[start: end])
            start = end

        return boxes
Beispiel #8
0
 def fn(x):
     """Return the 0-based row positions of ``x`` as a long tensor.

     The positions are produced as a cumulative sum of ones (minus one) on
     the device of ``x``, with one entry per element of ``x.shape[0]``.
     """
     counts = torch.cumsum(
         torch.ones(x.shape[0], dtype=torch.long, device=x.device), 0)
     return counts - 1
Beispiel #9
0
    def projection_linf(self, points_to_project, w_hyperplane, b_hyperplane):
        """Compute, per row, an offset that moves a point onto a hyperplane.

        For every row ``i`` the returned offset ``d[i]`` is built so that
        ``points_to_project[i] + d[i]`` reaches the hyperplane
        ``w_hyperplane[i] . x = b_hyperplane[i]`` with a small max-norm step.
        The candidate offsets are capped by ``1 - t`` / ``-t``, i.e. the code
        presumably assumes points inside the box [0, 1]^n -- confirm with the
        caller.

        Args:
            points_to_project: tensor of shape (rows, features).
            w_hyperplane: tensor of the same shape with the normals.
            b_hyperplane: tensor of shape (rows,) with the offsets.

        Returns:
            Tensor shaped like ``points_to_project``; entries whose normal
            component is zero are zero. The inputs are not modified.
        """
        t = points_to_project.clone()
        w = w_hyperplane.clone()
        b = b_hyperplane.clone()

        # Flip the hyperplanes whose signed residual w.t - b is negative so
        # that every row is handled with the same orientation.
        ind2 = ((w * t).sum(1) - b < 0).nonzero().squeeze()
        ind2 = self.check_shape(ind2)
        w[ind2] *= -1
        b[ind2] *= -1

        c5 = (w < 0).float()
        a = torch.ones(t.shape).to(self.device)
        # Largest admissible move per coordinate: up to 1 - t where w < 0,
        # down to -t elsewhere; coordinates with w == 0 never move.
        d = (a * c5 - t) * (w != 0).float()
        a -= a * (1 - c5)

        # p is the distance of each coordinate to the bound it moves towards.
        p = torch.ones(t.shape).to(self.device) * c5 - t * (2 * c5 - 1)
        indp = torch.argsort(p, dim=1)

        b = b - (w * t).sum(1)
        b0 = (w * d).sum(1)
        b1 = b0.clone()

        counter = 0
        # Coordinates ordered by decreasing p. Flipping along dim 1 directly
        # (instead of unsqueeze/flip/squeeze) keeps the tensor 2-D even when
        # there is a single feature, where squeeze() dropped that axis.
        indp2 = indp.flip(dims=(1,))
        u = torch.arange(0, w.shape[0])
        ws = w[u.unsqueeze(1), indp2]
        bs2 = -ws * d[u.unsqueeze(1), indp2]

        s = torch.cumsum(ws.abs(), dim=1)
        sb = torch.cumsum(bs2, dim=1) + b0.unsqueeze(1)

        b2 = sb[u, -1] - s[u, -1] * p[u, indp[u, 0]]
        # c_l: rows solved with the final cumulative entry;
        # c2: rows that need the binary search below.
        c_l = (b - b2 > 0).nonzero().squeeze()
        c2 = ((b - b1 > 0) * (b - b2 <= 0)).nonzero().squeeze()
        c_l = self.check_shape(c_l)
        c2 = self.check_shape(c2)

        lb = torch.zeros(c2.shape[0])
        ub = torch.ones(c2.shape[0]) * (w.shape[1] - 1)
        nitermax = torch.ceil(torch.log2(torch.tensor(w.shape[1]).float()))

        # Binary search over the sorted coordinates, all rows of c2 at once.
        while counter < nitermax:
            counter4 = torch.floor((lb + ub) / 2)
            counter2 = counter4.long()
            indcurr = indp[c2, -counter2 - 1]
            b2 = sb[c2, counter2] - s[c2, counter2] * p[c2, indcurr]
            c = b[c2] - b2 > 0
            ind3 = c.nonzero().squeeze()
            ind32 = (~c).nonzero().squeeze()
            ind3 = self.check_shape(ind3)
            ind32 = self.check_shape(ind32)
            lb[ind3] = counter4[ind3]
            ub[ind32] = counter4[ind32]
            counter += 1

        lb = lb.long()

        # nelement is a method: without the call the comparison was between
        # a bound method and 0, which is always true.
        if c_l.nelement() != 0:
            lmbd_opt = (torch.max(
                (b[c_l] - sb[c_l, -1]) / (-s[c_l, -1]),
                torch.zeros(sb[c_l, -1].shape).to(self.device))).unsqueeze(-1)
            d[c_l] = (2 * a[c_l] - 1) * lmbd_opt

        lmbd_opt = (torch.max(
            (b[c2] - sb[c2, lb]) / (-s[c2, lb]),
            torch.zeros(sb[c2, lb].shape).to(self.device))).unsqueeze(-1)
        d[c2] = torch.min(lmbd_opt, d[c2]) * c5[c2]\
            + torch.max(-lmbd_opt, d[c2]) * (1 - c5[c2])

        return d * (w != 0).float()
Beispiel #10
0
    def _run_one_fw(self,
                    pixel_model,
                    pixel_inp,
                    cat_var,
                    target,
                    base_eps,
                    avoid_target=True):
        batch_size, channels, height, width = pixel_inp.size()
        pixel_inp_jpeg = self._jpeg_cat(pixel_inp, cat_var, base_eps,
                                        batch_size, height, width)
        s = pixel_model(pixel_inp_jpeg)

        for it in range(self.nb_its):
            loss = self.criterion(s, target)
            loss.backward()

            if avoid_target:
                grad = cat_var.grad.data
            else:
                grad = -cat_var.grad.data

            def where_float(cond, if_true, if_false):
                return cond.float() * if_true + (1 - cond.float()) * if_false

            def where_long(cond, if_true, if_false):
                return cond.long() * if_true + (1 - cond.long()) * if_false

            abs_grad = torch.abs(grad).view(batch_size, -1)
            num_pixels = abs_grad.size()[1]
            sign_grad = torch.sign(grad)

            bound = where_float(sign_grad > 0, self.l1_max - cat_var,
                                cat_var + self.l1_max).view(batch_size, -1)

            k_min = torch.zeros((batch_size, 1),
                                dtype=torch.long,
                                requires_grad=False,
                                device='cuda')
            k_max = torch.ones((batch_size, 1),
                               dtype=torch.long,
                               requires_grad=False,
                               device='cuda') * num_pixels

            # cum_bnd[k] is meant to track the L1 norm we end up with if we take
            # the k indices with the largest gradient magnitude and push them to their boundary values (0 or 255)
            values, indices = torch.sort(abs_grad, descending=True)
            bnd = torch.gather(bound, 1, indices)
            # subtract bnd because we don't want the cumsum to include the final element
            cum_bnd = torch.cumsum(bnd, 1) - bnd

            # this is hard-coded as floor(log_2(256 * 256 * 3))
            for _ in range(17):
                k_mid = (k_min + k_max) // 2
                l1norms = torch.gather(cum_bnd, 1, k_mid)
                k_min = where_long(l1norms > base_eps, k_min, k_mid)
                k_max = where_long(l1norms > base_eps, k_mid, k_max)

            # next want to set the gradient of indices[0:k_min] to their corresponding bound
            magnitudes = torch.zeros((batch_size, num_pixels),
                                     requires_grad=False,
                                     device='cuda')
            for bi in range(batch_size):
                magnitudes[bi, indices[bi, :k_min[bi, 0]]] = bnd[bi, :k_min[bi,
                                                                            0]]
                magnitudes[bi, indices[bi, k_min[
                    bi, 0]]] = base_eps[bi] - cum_bnd[bi, k_min[bi, 0]]

            delta_it = sign_grad * magnitudes.view(cat_var.size())
            # These should always be exactly epsilon
            # l1_check = torch.norm(delta_it.view(batch_size, -1), 1.0, dim=1) / num_pixels
            # print('l1_check: %s' % l1_check)
            cat_var.data = cat_var.data + (delta_it - cat_var.data) / (it +
                                                                       1.0)

            if it != self.nb_its - 1:
                # self.jpeg scales rounding_vars by base_eps, so we divide to rescale
                # its coordinates to [-1, 1]
                cat_var_temp = cat_var / base_eps[:, None]
                pixel_inp_jpeg = self._jpeg_cat(pixel_inp, cat_var_temp,
                                                base_eps, batch_size, height,
                                                width)
                s = pixel_model(pixel_inp_jpeg)
            cat_var.grad.data.zero_()
        return cat_var
Beispiel #11
0
# Build the IRIM from the step models, the gradient function and the number
# of image channels.
model = IRIM(step_models,grad_fun,im_channels)
# Wrap the model to be trained with invert to learn
model = MemoryFreeInvertibleModule(model)
model.to(device)
# Use DataParallel if multiple devices are available
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)

optimizer = torch.optim.Adam(model.parameters(), learning_rate)

# We generate a simple toy data set where the ground truth data has the same values in the image
# dimensions but different values across batch and channel dimensions. This demonstrates that the
# IRIM can deal with the implicit structure in the data, with a high range of values, and it can even
# do extrapolation.
x = torch.ones(n_samples,im_channels,*[im_size]*conv_nd, requires_grad=False, device=device)
# Cumulative sums over the batch and then the channel dimension make every
# (sample, channel) pair hold a different constant value.
x = torch.cumsum(x,0)
x = torch.cumsum(x,1)
# Observations are the ground truth plus standard normal noise.
y = x + torch.randn_like(x)

# Training and test split. This will result in an extrapolation problem on the test set.
y, y_test = torch.chunk(y,2,0)
x, x_test = torch.chunk(x,2,0)

# Initial states of the IRIM: the observation followed by zero channels that
# pad it up to n_channels[0] channels.
x_in = torch.cat((y,torch.zeros(y.size(0),n_channels[0]-im_channels,*[im_size]*conv_nd, device=device)),1)
x_test_in = torch.cat((y_test,torch.zeros(y_test.size(0),n_channels[0]-im_channels,*[im_size]*conv_nd, device=device)),1)
x_in.requires_grad_(True)
x_test_in.requires_grad_(False)

# Training loop: gradients are cleared at the start of every iteration.
for i in range(3000):
    optimizer.zero_grad()
    def add_whole_word_mask(self, source, p):
        """Mask (or delete) whole words in ``source`` and return the result.

        About ``ceil(p * number_of_word_starts)`` words are selected at
        random. When ``self.mask_span_distribution`` is set, span lengths are
        sampled from it; zero-length spans become insertions handled by
        ``self.add_insertion_noise``. A selected start token is either dropped
        (``self.replace_length == 0``) or overwritten with ``self.mask_idx``,
        and with probability ``self.random_ratio`` replaced by a random
        vocabulary index instead.

        Note: when tokens are overwritten, ``source`` is modified in place
        before the filtered copy is returned; the tensor returned by
        ``self.word_starts`` is modified as well.
        """
        is_word_start = self.word_starts(source)
        num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
        num_inserts = 0
        if num_to_mask == 0:
            return source

        if self.mask_span_distribution is not None:
            lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))

            # Make sure we have enough to mask
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < num_to_mask:
                lengths = torch.cat([lengths, self.mask_span_distribution.sample(sample_shape=(num_to_mask,))], dim=0)
                cum_length = torch.cumsum(lengths, 0)

            # Trim to masking budget: find the first span at which the running
            # total reaches the budget and shorten it so the total is exact.
            i = 0
            while cum_length[i] < num_to_mask:
                i += 1
            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
            # From here on num_to_mask counts spans rather than words.
            num_to_mask = i + 1
            lengths = lengths[:num_to_mask]

            # Handle 0-length mask (inserts) separately
            lengths = lengths[lengths > 0]
            num_inserts = num_to_mask - lengths.size(0)
            num_to_mask -= num_inserts
            if num_to_mask == 0:
                return self.add_insertion_noise(source, num_inserts / source.size(0))

            assert (lengths > 0).all()
        else:
            # Without a span distribution every mask covers exactly one word.
            lengths = torch.ones((num_to_mask,)).long()
        assert is_word_start[-1] == 0
        word_starts = is_word_start.nonzero()
        # Random subset of word-start positions, one per span.
        indices = word_starts[torch.randperm(word_starts.size(0))[:num_to_mask]].squeeze(1)
        # True where the span gets random tokens instead of the mask token.
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio

        source_length = source.size(0)
        assert source_length - 1 not in indices
        to_keep = torch.ones(source_length, dtype=torch.bool)
        is_word_start[-1] = 255 # acts as a long length, so spans don't go over the end of doc
        if self.replace_length == 0:
            # drop the start token of every selected word
            to_keep[indices] = 0
        else:
            # keep index, but replace it with [MASK]
            source[indices] = self.mask_idx
            source[indices[mask_random]] = torch.randint(1, len(self.vocab), size=(mask_random.sum(),))

        if self.mask_span_distribution is not None:
            assert len(lengths.size()) == 1
            assert lengths.size() == indices.size()
            # The start token of each span has been handled above.
            lengths -= 1
            # Advance all spans one token per pass until each is used up.
            while indices.size(0) > 0:
                assert lengths.size() == indices.size()
                # A span loses one unit of length whenever the next token
                # starts a new word.
                lengths -= is_word_start[indices + 1].long()
                uncompleted = lengths >= 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                lengths = lengths[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = torch.randint(1, len(self.vocab), size=(mask_random.sum(),))
        else:
            # A bit faster when all lengths are 1
            while indices.size(0) > 0:
                # Continue only while the next token is not a word start.
                uncompleted = is_word_start[indices + 1] == 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = torch.randint(1, len(self.vocab), size=(mask_random.sum(),))

                assert source_length - 1 not in indices

        # Drop every position that was marked for deletion.
        source = source[to_keep]

        if num_inserts > 0:
            source = self.add_insertion_noise(source, num_inserts / source.size(0))

        return source
Beispiel #13
0
 def total_tests(self, table):
     """Returns a tensor of shape [num_regions, num_days].

     Each entry is the running total, along dimension 1, of
     ``self.new_tests(table)``.
     """
     new_per_day = self.new_tests(table)
     running_total = torch.cumsum(new_per_day, dim=1)
     return running_total
def mean_average_precision(pred_boxes,
                           true_boxes,
                           iou_threshold=0.5,
                           box_format="midpoint",
                           num_classes=20):
    """
    Calculates mean average precision 

    Parameters:
        pred_boxes (list): list of lists containing all bboxes with each bboxes
        specified as [train_idx, class_prediction, prob_score, x1, y1, x2, y2]
        true_boxes (list): Similar as pred_boxes except all the correct ones 
        iou_threshold (float): threshold where predicted bboxes is correct
        box_format (str): "midpoint" or "corners" used to specify bboxes
        num_classes (int): number of classes

    Returns:
        float: mAP value across all classes given a specific IoU threshold.
        Classes without any ground truth box are left out of the mean; if no
        class has ground truth the result is 0.
    """

    # list storing all AP for respective classes
    average_precisions = []

    # used for numerical stability later on
    epsilon = 1e-6

    for c in range(num_classes):
        # Only keep the predictions and targets of the current class c
        detections = [det for det in pred_boxes if det[1] == c]
        ground_truths = [gt for gt in true_boxes if gt[1] == c]
        total_true_bboxes = len(ground_truths)

        # If none exists for this class then we can safely skip
        if total_true_bboxes == 0:
            continue

        # Group the ground truths per training example once, so that each
        # detection only looks at the boxes of its own image instead of
        # rescanning the complete list.
        gts_by_image = {}
        for gt in ground_truths:
            gts_by_image.setdefault(gt[0], []).append(gt)

        # One flag per ground truth box: 0 = not matched yet, 1 = matched.
        # E.g. img 0 with 3 boxes and img 1 with 5 boxes gives
        # {0: tensor([0, 0, 0]), 1: tensor([0, 0, 0, 0, 0])}
        matched = {
            img: torch.zeros(len(boxes)) for img, boxes in gts_by_image.items()
        }

        # sort by box probabilities which is index 2
        detections.sort(key=lambda det: det[2], reverse=True)
        TP = torch.zeros(len(detections))
        FP = torch.zeros(len(detections))

        for detection_idx, detection in enumerate(detections):
            # Only compare against ground truths of the same training idx
            best_iou = 0
            best_gt_idx = None

            for idx, gt in enumerate(gts_by_image.get(detection[0], ())):
                iou = intersection_over_union(
                    torch.tensor(detection[3:]),
                    torch.tensor(gt[3:]),
                    box_format=box_format,
                )

                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            # A detection is a true positive only if it overlaps enough with
            # a ground truth box that has not been claimed by a higher
            # scoring detection; anything else is a false positive.
            if (best_iou > iou_threshold
                    and matched[detection[0]][best_gt_idx] == 0):
                TP[detection_idx] = 1
                matched[detection[0]][best_gt_idx] = 1
            else:
                FP[detection_idx] = 1

        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = torch.divide(TP_cumsum, (TP_cumsum + FP_cumsum + epsilon))
        # Start the curve at (recall 0, precision 1)
        precisions = torch.cat((torch.tensor([1.0]), precisions))
        recalls = torch.cat((torch.tensor([0.0]), recalls))
        # torch.trapz for numerical integration
        average_precisions.append(torch.trapz(precisions, recalls))

    # No class had ground truth: avoid dividing by zero.
    if not average_precisions:
        return torch.tensor(0.0)

    return sum(average_precisions) / len(average_precisions)
    def test_offset_verts(self):
        """Check ``offset_verts`` against a mesh rebuilt from offset vertices.

        The comparison is run twice: once on a freshly created mesh and once
        after all cached attributes of the mesh have been computed.
        """

        def naive_offset_verts(mesh, vert_offsets_packed):
            # Reference result: add the offsets to the packed vertices, cut
            # them back into per-mesh tensors and build a new Meshes object.
            shifted = mesh.verts_packed() + vert_offsets_packed
            counts = mesh.num_verts_per_mesh().tolist()
            verts = list(shifted.split(counts, 0))
            faces = [f.clone() for f in mesh.faces_list()]
            return Meshes(verts=verts, faces=faces)

        # Accessors compared one by one, in this order.
        packed_and_padded = (
            "faces_padded",
            "verts_padded",
            "faces_packed",
            "verts_packed",
            "edges_packed",
            "verts_packed_to_mesh_idx",
            "mesh_to_verts_packed_first_idx",
            "num_verts_per_mesh",
            "faces_packed_to_mesh_idx",
            "mesh_to_faces_packed_first_idx",
            "num_faces_per_mesh",
            "edges_packed_to_mesh_idx",
            "verts_padded_to_packed_idx",
        )
        normals_and_areas = (
            "verts_normals_packed",
            "verts_normals_padded",
            "faces_normals_packed",
            "faces_normals_padded",
            "faces_areas_packed",
            "mesh_to_edges_packed_first_idx",
        )

        def assert_same(actual, expected, accessor):
            self.assertClose(
                getattr(actual, accessor)(), getattr(expected, accessor)())

        N = 5
        mesh = TestMeshes.init_mesh(N, 10, 100)
        all_v = mesh.verts_packed().size(0)
        verts_per_mesh = mesh.num_verts_per_mesh()
        for force in (0, 1):
            if force:
                # force mesh to have computed attributes
                mesh._compute_packed(refresh=True)
                mesh._compute_padded()
                mesh._compute_edges_packed()
                mesh.verts_padded_to_packed_idx()
                mesh._compute_face_areas_normals(refresh=True)
                mesh._compute_vertex_normals(refresh=True)

            deform = torch.rand((all_v, 3),
                                dtype=torch.float32,
                                device=mesh.device)
            # new meshes class to hold the deformed mesh
            new_mesh_naive = naive_offset_verts(mesh, deform)

            new_mesh = mesh.offset_verts(deform)

            # check verts_list & faces_list; bounds[i]:bounds[i + 1] is the
            # slice of the packed offsets that belongs to mesh i
            bounds = [0] + torch.cumsum(verts_per_mesh, 0).tolist()
            for i in range(N):
                lo, hi = bounds[i], bounds[i + 1]
                self.assertClose(
                    new_mesh.verts_list()[i],
                    mesh.verts_list()[i] + deform[lo:hi],
                )
                self.assertClose(new_mesh.verts_list()[i],
                                 new_mesh_naive.verts_list()[i])
                self.assertClose(mesh.faces_list()[i],
                                 new_mesh_naive.faces_list()[i])
                self.assertClose(new_mesh.faces_list()[i],
                                 new_mesh_naive.faces_list()[i])
                # check faces and vertex normals
                self.assertClose(
                    new_mesh.verts_normals_list()[i],
                    new_mesh_naive.verts_normals_list()[i],
                )
                self.assertClose(
                    new_mesh.faces_normals_list()[i],
                    new_mesh_naive.faces_normals_list()[i],
                )

            # check padded & packed
            for accessor in packed_and_padded:
                assert_same(new_mesh, new_mesh_naive, accessor)
            self.assertTrue(all(new_mesh.valid == new_mesh_naive.valid))
            self.assertTrue(new_mesh.equisized == new_mesh_naive.equisized)

            # check face areas, normals and vertex normals
            for accessor in normals_and_areas:
                assert_same(new_mesh, new_mesh_naive, accessor)
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bbox_list,
             gt_label_list,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute the classification and box-regression losses of the head.

        Args:
            cls_scores (list[Tensor]): per-level class scores; each is
                permuted with ``(0, 2, 3, 1)`` and flattened to
                ``(-1, self.cls_out_channels)``.
            bbox_preds (list[Tensor]): per-level box predictions, flattened
                to ``(-1, 4)`` the same way.
            gt_bbox_list, gt_label_list: ground truth forwarded to
                ``self.get_targets``.
            img_metas: not used in this method.
            gt_bboxes_ignore: not used in this method.

        Returns:
            dict: ``loss_cls`` and ``loss_bbox``. When there is no positive
            sample both are zeros that stay connected to the predictions.

        Side effects: appends to ``self.cls_LRP_hist`` / ``self.reg_LRP_hist``
        and, every ``self.period`` calls with positives, recomputes
        ``self.SB_weight`` from their means and clears the histories.
        """
        assert len(cls_scores) == len(bbox_preds)

        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                 bbox_preds[0].device)
        num_imgs = cls_scores[0].size(0)
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_labels, flatten_bbox_targets = self.get_targets(
            gt_bbox_list, gt_label_list, featmap_sizes, points)
        # Targets and predictions go through the same conversion.
        flatten_bbox_targets = self.get_bbox_targets(flatten_bbox_targets,
                                                     featmap_sizes, points,
                                                     num_imgs)
        flatten_bbox_preds = self.get_bbox_targets(flatten_bbox_preds,
                                                   featmap_sizes, points,
                                                   num_imgs)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        pos_inds = (
            (flatten_labels >= 0)
            & (flatten_labels < self.background_label)).nonzero().view(-1)
        num_pos = len(pos_inds)

        if num_pos > 0:
            pos_bbox_preds = flatten_bbox_preds[pos_inds]
            pos_bbox_targets = flatten_bbox_targets[pos_inds]

            pos_weights = pos_bbox_targets.new_zeros(
                pos_bbox_targets.size()) + 1.0

            loss_bbox = self.loss_bbox(pos_bbox_preds, pos_bbox_targets,
                                       pos_weights)

            flat_labels = vectorize_labels(flatten_labels, self.num_classes)
            flat_preds = flatten_cls_scores.reshape(-1)
            loss_cls, rank, order = self.aLRP_Loss.apply(
                flat_preds, flat_labels, loss_bbox)

            #Order the regression losses considering the scores.
            ordered_losses_bbox = loss_bbox[order.detach()].flip(dims=[0])

            #Compute aLRP Regression Loss
            loss_bbox = ((torch.cumsum(ordered_losses_bbox, dim=0) /
                          rank[order.detach()].detach().flip(dims=[0])).mean())

            self.cls_LRP_hist.append(float(loss_cls.item()))
            self.reg_LRP_hist.append(float(loss_bbox.item()))
            self.counter += 1

            # Periodically re-balance the regression loss against the
            # classification loss using the averages since the last update.
            if self.counter == self.period:
                self.SB_weight = (np.mean(self.reg_LRP_hist) + np.mean(
                    self.cls_LRP_hist)) / np.mean(self.reg_LRP_hist)
                self.cls_LRP_hist.clear()
                self.reg_LRP_hist.clear()
                self.counter = 0

            loss_bbox *= self.SB_weight

        else:
            # No positives: return zeros that are still connected to the
            # predictions. Both values are already single tensors here, so
            # calling torch.cat on them again raises a TypeError.
            loss_cls = flatten_cls_scores.sum() * 0
            loss_bbox = flatten_bbox_preds.sum() * 0

        return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)
Beispiel #17
0
def rational_linear_spline(inputs,
                           unnormalized_widths,
                           unnormalized_heights,
                           unnormalized_derivatives,
                           unnormalized_lambdas,
                           inverse=False,
                           left=0., right=1., bottom=0., top=1.,
                           min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                           min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                           min_derivative=DEFAULT_MIN_DERIVATIVE):
    """Evaluate a piecewise rational-linear spline (or its inverse) elementwise.

    Bin widths and heights are obtained from a softmax over the last dimension
    of the unnormalized parameters, knot derivatives from a softplus, and the
    per-bin split position ``lam`` from a sigmoid squashed into
    ``[0.025, 0.975]``.

    Args:
        inputs: tensor of values to transform.
        unnormalized_widths: raw bin widths; the last dimension indexes bins.
        unnormalized_heights: raw bin heights; the last dimension indexes bins.
        unnormalized_derivatives: raw derivatives at the knots (indexed both
            at ``bin`` and ``bin + 1``).
        unnormalized_lambdas: raw per-bin split positions.
        inverse: if True, evaluate the inverse map instead of the forward one.
        left, right: interval covered by the cumulative widths.
        bottom, top: interval covered by the cumulative heights.
        min_bin_width, min_bin_height, min_derivative: lower bounds added to
            the normalized widths, heights and derivatives.

    Returns:
        Tuple ``(outputs, logabsdet)``.

    Raises:
        transforms.InputOutsideDomain: if any input lies outside
            ``[left, right]`` (this same interval is checked in both the
            forward and the inverse direction).
        ValueError: if the minimal bin width or height times the number of
            bins exceeds 1.
    """
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise transforms.InputOutsideDomain()

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError('Minimal bin width too large for the number of bins')
    if min_bin_height * num_bins > 1.0:
        raise ValueError('Minimal bin height too large for the number of bins')

    def _knots_and_sizes(unnormalized, min_size, low, high):
        # Normalized bin sizes -> cumulative knot positions on [low, high].
        sizes = F.softmax(unnormalized, dim=-1)
        sizes = min_size + (1 - min_size * num_bins) * sizes
        knots = torch.cumsum(sizes, dim=-1)
        knots = F.pad(knots, pad=(1, 0), mode='constant', value=0.0)
        knots = (high - low) * knots + low
        # Pin the end points exactly, then recompute the sizes from the knots.
        knots[..., 0] = low
        knots[..., -1] = high
        return knots, knots[..., 1:] - knots[..., :-1]

    cumwidths, widths = _knots_and_sizes(
        unnormalized_widths, min_bin_width, left, right)
    cumheights, heights = _knots_and_sizes(
        unnormalized_heights, min_bin_height, bottom, top)
    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    # The inverse looks the bin up along the height axis, the forward pass
    # along the width axis.
    search_grid = cumheights if inverse else cumwidths
    bin_idx = utils.searchsorted(search_grid, inputs)[..., None]

    def _at_bin(per_bin):
        # Select, for every input, the entry belonging to its bin.
        return per_bin.gather(-1, bin_idx)[..., 0]

    input_cumwidths = _at_bin(cumwidths)
    input_bin_widths = _at_bin(widths)
    input_cumheights = _at_bin(cumheights)
    input_heights = _at_bin(heights)
    input_delta = _at_bin(heights / widths)
    input_derivatives = _at_bin(derivatives)
    input_derivatives_plus_one = _at_bin(derivatives[..., 1:])
    lam = _at_bin(0.95 * torch.sigmoid(unnormalized_lambdas) + 0.025)

    # Weights (w*) and values (y*) at the left knot (a), right knot (b) and
    # the intermediate point (c) located at position lam inside the bin.
    wa = 1
    wb = torch.sqrt(input_derivatives / input_derivatives_plus_one) * wa
    wc = (lam * wa * input_derivatives
          + (1 - lam) * wb * input_derivatives_plus_one) / input_delta
    ya = input_cumheights
    yb = input_heights + input_cumheights
    yc = ((1 - lam) * wa * ya + lam * wb * yb) / ((1 - lam) * wa + lam * wb)

    if inverse:
        # Masks selecting the segment before / after the intermediate point.
        first_seg = (inputs <= yc).float()
        second_seg = (inputs > yc).float()

        numerator = (lam * wa * (ya - inputs)) * first_seg \
            + ((wc - lam * wb) * inputs + lam * wb * yb - wc * yc) * second_seg

        denominator = ((wc - wa) * inputs + wa * ya - wc * yc) * first_seg \
            + ((wc - wb) * inputs + wb * yb - wc * yc) * second_seg

        theta = numerator / denominator
        outputs = theta * input_bin_widths + input_cumwidths

        derivative_numerator = (
            wa * wc * lam * (yc - ya) * first_seg
            + wb * wc * (1 - lam) * (yb - yc) * second_seg) * input_bin_widths
    else:
        theta = (inputs - input_cumwidths) / input_bin_widths

        # Masks selecting the segment before / after the intermediate point.
        first_seg = (theta <= lam).float()
        second_seg = (theta > lam).float()

        numerator = (wa * ya * (lam - theta) + wc * yc * theta) * first_seg \
            + (wc * yc * (1 - theta) + wb * yb * (theta - lam)) * second_seg

        denominator = (wa * (lam - theta) + wc * theta) * first_seg \
            + (wc * (1 - theta) + wb * (theta - lam)) * second_seg

        outputs = numerator / denominator

        derivative_numerator = (
            wa * wc * lam * (yc - ya) * first_seg
            + wb * wc * (1 - lam) * (yb - yc) * second_seg) / input_bin_widths

    logabsdet = torch.log(derivative_numerator) - 2 * torch.log(abs(denominator))

    return outputs, logabsdet
Beispiel #18
0
 def model(u_in):
     """Run ``u_in`` through G1 -> F1 -> G2 and accumulate the result.

     Returns a tuple ``(y_hat, v_hat)`` where ``v_hat`` is the output of the
     G1/F1/G2 chain and ``y_hat`` is its cumulative sum along dimension 1
     multiplied by ``ts`` (presumably the sampling time - confirm at the
     definition of ``ts``).
     """
     v_hat = G2(F1(G1(u_in)))
     return torch.cumsum(v_hat, dim=1) * ts, v_hat
Beispiel #19
0
def qr(a, tiles_per_proc=1, calc_q=True, overwrite_a=False):
    """

    Calculates the QR decomposition of a 2D DNDarray.
    Factor the matrix `a` as *qr*, where `q` is orthonormal and `r` is upper-triangular.

    Parameters
    ----------
    a : DNDarray
        DNDarray which will be decomposed
    tiles_per_proc : int, single element torch.Tensor
        optional, default: 1
        number of tiles per process to operate on
        a single element torch.Tensor is converted to a Python int before use
    calc_q : bool
        optional, default: True
        whether or not to calculate Q
        if True, function returns (Q, R)
        if False, function returns (None, R)
    overwrite_a : bool
        optional, default: False
        if True, function overwrites the DNDarray a, with R
        if False, a new array will be created for R

    Returns
    -------
    namedtuple of Q and R
        if calc_q == True, function returns QR(Q=Q, R=R)
        if calc_q == False, function returns QR(Q=None, R=R)

    Raises
    ------
    TypeError
        if `a` is not a DNDarray, `tiles_per_proc` is neither an int nor a torch.Tensor,
        or `calc_q` / `overwrite_a` is not a bool
    ValueError
        if `tiles_per_proc` is a torch.Tensor with more or fewer than one element,
        or `a` is not 2 dimensional

    Notes
    -----
    This function is built on top of PyTorch's QR function. torch.qr() using LAPACK on the backend.
    Basic information about QR factorization/decomposition can be found at
    https://en.wikipedia.org/wiki/QR_factorization

    The algorithms are based on the CAQR and TSQR algorithms. For more information see references.

    References
    ----------
    [0]  W. Zheng, F. Song, L. Lin, and Z. Chen, “Scaling Up Parallel Computation of Tiled QR
            Factorizations by a Distributed Scheduling Runtime System and Analytical Modeling,”
            Parallel Processing Letters, vol. 28, no. 01, p. 1850004, 2018.
    [1] Bilel Hadri, Hatem Ltaief, Emmanuel Agullo, Jack Dongarra. Tile QR Factorization with
            Parallel Panel Processing for Multicore Architectures. 24th IEEE International Parallel
            and DistributedProcessing Symposium (IPDPS 2010), Apr 2010, Atlanta, United States.
            inria-00548899
    [2] Gene H. Golub and Charles F. Van Loan. 1996. Matrix Computations (3rd Ed.).

    Examples
    --------
    >>> a = ht.random.randn(9, 6, split=0)
    >>> qr = ht.linalg.qr(a)
    >>> print(ht.allclose(a, ht.dot(qr.Q, qr.R)))
    [0/1] True
    [1/1] True
    >>> st = torch.randn(9, 6)
    >>> a = ht.array(st, split=1)
    >>> a_comp = ht.array(st, split=0)
    >>> q, r = ht.linalg.qr(a)
    >>> print(ht.allclose(a_comp, ht.dot(q, r)))
    [0/1] True
    [1/1] True
    """
    if not isinstance(a, dndarray.DNDarray):
        raise TypeError("'a' must be a DNDarray")
    if not isinstance(tiles_per_proc, (int, torch.Tensor)):
        raise TypeError("tiles_per_proc must be an int or a torch.Tensor, "
                        "currently {}".format(type(tiles_per_proc)))
    if not isinstance(calc_q, bool):
        raise TypeError("calc_q must be a bool, currently {}".format(
            type(calc_q)))
    if not isinstance(overwrite_a, bool):
        raise TypeError("overwrite_a must be a bool, currently {}".format(
            type(overwrite_a)))
    if isinstance(tiles_per_proc, torch.Tensor):
        # Only tensors that do not hold exactly one element are invalid; a
        # single element tensor is a documented, accepted input.
        if tiles_per_proc.numel() != 1:
            raise ValueError(
                "tiles_per_proc must be a single element torch.Tenor or int, "
                "currently has {} entries".format(tiles_per_proc.numel()))
        # Hand the tiling code the same plain int type as the default value.
        tiles_per_proc = int(tiles_per_proc.item())
    if len(a.shape) != 2:
        raise ValueError("Array 'a' must be 2 dimensional")

    QR = collections.namedtuple("QR", "Q, R")

    if a.split is None:
        # Not distributed: run the local torch QR directly.
        q, r = a._DNDarray__array.qr(some=False)
        q = factories.array(q, device=a.device)
        r = factories.array(r, device=a.device)
        ret = QR(q if calc_q else None, r)
        return ret
    # =============================== Prep work ====================================================
    r = a if overwrite_a else a.copy()
    # r.create_square_diag_tiles(tiles_per_proc=tiles_per_proc)
    r_tiles = tiling.SquareDiagTiles(arr=r, tiles_per_proc=tiles_per_proc)
    tile_columns = r_tiles.tile_columns
    tile_rows = r_tiles.tile_rows
    if calc_q:
        # Q starts as the identity and is tiled to match R.
        q = factories.eye((r.gshape[0], r.gshape[0]),
                          split=0,
                          dtype=r.dtype,
                          comm=r.comm,
                          device=r.device)
        q_tiles = tiling.SquareDiagTiles(arr=q, tiles_per_proc=tiles_per_proc)
        q_tiles.match_tiles(r_tiles)
    else:
        q, q_tiles = None, None
    # ==============================================================================================

    if a.split == 0:
        rank = r.comm.rank
        active_procs = torch.arange(r.comm.size, device=r.device.torch_device)
        # Drop processes whose local shape has no rows from the active set.
        empties = torch.nonzero(input=r_tiles.lshape_map[..., 0] == 0,
                                as_tuple=False)
        empties = empties[0] if empties.numel() > 0 else []
        for e in empties:
            active_procs = active_procs[active_procs != e]
        tile_rows_per_pr_trmd = r_tiles.tile_rows_per_process[:active_procs[-1]
                                                              + 1]

        q_dict = {}
        q_dict_waits = {}
        # Cumulative number of tile rows up to and including each process.
        proc_tile_start = torch.cumsum(torch.tensor(
            tile_rows_per_pr_trmd, device=r.device.torch_device),
                                       dim=0)
        # ------------------------------------ R Calculation ---------------------------------------
        for col in range(
                tile_columns
        ):  # for each tile column (need to do the last rank separately)
            # for each process need to do local qr
            not_completed_processes = torch.nonzero(
                input=col < proc_tile_start, as_tuple=False).flatten()
            if rank not in not_completed_processes or rank not in active_procs:
                # if the process is done calculating R then break the loop
                break
            diag_process = not_completed_processes[0]
            __split0_r_calc(
                r_tiles=r_tiles,
                q_dict=q_dict,
                q_dict_waits=q_dict_waits,
                col_num=col,
                diag_pr=diag_process,
                not_completed_prs=not_completed_processes,
            )
        # ------------------------------------- Q Calculation --------------------------------------
        if calc_q:
            for col in range(tile_columns):
                __split0_q_loop(
                    col=col,
                    r_tiles=r_tiles,
                    proc_tile_start=proc_tile_start,
                    active_procs=active_procs,
                    q0_tiles=q_tiles,
                    q_dict=q_dict,
                    q_dict_waits=q_dict_waits,
                )
    elif a.split == 1:
        # loop over the tile columns
        lp_cols = tile_columns if a.gshape[0] > a.gshape[1] else tile_rows
        for dcol in range(lp_cols):  # dcol is the diagonal column
            __split1_qr_loop(dcol=dcol,
                             r_tiles=r_tiles,
                             q0_tiles=q_tiles,
                             calc_q=calc_q)

    r.balance_()
    if q is not None:
        q.balance_()

    ret = QR(q, r)
    return ret
Beispiel #20
0
def __split1_qr_loop(dcol, r_tiles, q0_tiles, calc_q):
    """

    Helper function to do the QR factorization of the column 'dcol'. This function assumes that the
    target tile is at (dcol, dcol). This is the standard case as it assumes that the diagonal tile
    holds the diagonal entries of the matrix.

    The work is done in two steps: first a QR of the diagonal tile alone, then one merged QR per
    tile row below the diagonal (diagonal tile stacked on top of that row's tile). In both steps
    the process owning the diagonal tile computes Q and broadcasts it; every rank takes part in
    each broadcast. The tiles of `r_tiles` (and of `q0_tiles` when `calc_q` is True) are updated
    in place.

    Parameters
    ----------
    dcol : int
        column of the diagonal process
    r_tiles : tiling.SquareDiagTiles
        input matrix tiles to QR,
        if copy is true in QR then it is a copy of the data, else it is the same as the input
    q0_tiles : tiling.SquareDiagTiles
        the Q matrix tiles as created in the QR function.
        Only accessed when `calc_q` is True.
    calc_q : Boolean
        Flag for whether to calculate Q or not, if False, then Q=None

    Returns
    -------
    None
    """
    r_torch_device = r_tiles.arr._DNDarray__array.device
    q0_torch_device = q0_tiles.arr._DNDarray__array.device if calc_q else None
    # ==================================== R Calculation - single tile =========================
    # loop over each column, need to do the QR for each tile in the column(should be rows)
    # need to get the diagonal process
    rank = r_tiles.arr.comm.rank
    # cumulative number of tile columns up to and including each process
    cols_on_proc = torch.cumsum(torch.tensor(r_tiles.tile_columns_per_process,
                                             device=r_torch_device),
                                dim=0)
    # the first process whose cumulative column count exceeds dcol is the diagonal process
    not_completed_processes = torch.nonzero(input=dcol < cols_on_proc,
                                            as_tuple=False).flatten()
    diag_process = not_completed_processes[0].item()
    tile_rows = r_tiles.tile_rows
    # get the diagonal tile and do qr on it
    # send q to the other processes
    # 1st qr: only on diagonal tile + apply to the row
    if rank == diag_process:
        # do qr on diagonal process
        q1, r1 = r_tiles[dcol, dcol].qr(some=False)
        r_tiles.arr.comm.Bcast(q1.clone(), root=diag_process)
        r_tiles[dcol, dcol] = r1
        # apply q1 to the trailing matrix (other processes)

        # need to convert dcol to a local index
        loc_col = dcol - sum(r_tiles.tile_columns_per_process[:rank])
        hold = r_tiles.local_get(key=(dcol, slice(loc_col + 1, None)))
        if hold is not None:  # if there is more data on that row after the diagonal tile
            r_tiles.local_set(key=(dcol, slice(loc_col + 1, None)),
                              value=torch.matmul(q1.T, hold))
    elif rank > diag_process:
        # recv the Q from the diagonal process, and apply it to the trailing matrix
        st_sp = r_tiles.get_start_stop(key=(dcol, dcol))
        sz = st_sp[1] - st_sp[0], st_sp[3] - st_sp[2]

        # receive buffer: square, with the row extent of the diagonal tile
        q1 = torch.zeros((sz[0], sz[0]),
                         dtype=r_tiles.arr.dtype.torch_type(),
                         device=r_torch_device)
        loc_col = 0
        r_tiles.arr.comm.Bcast(q1, root=diag_process)
        # all local tiles of row dcol are updated on these processes
        hold = r_tiles.local_get(key=(dcol, slice(0, None)))
        r_tiles.local_set(key=(dcol, slice(0, None)),
                          value=torch.matmul(q1.T, hold))
    else:
        # these processes are already done calculating R, only need to calc Q, need to recv q1
        # NOTE: loc_col is not assigned in this branch; it is only read below when
        # rank == diag_process
        st_sp = r_tiles.get_start_stop(key=(dcol, dcol))
        sz = st_sp[1] - st_sp[0], st_sp[3] - st_sp[2]
        q1 = torch.zeros((sz[0], sz[0]),
                         dtype=r_tiles.arr.dtype.torch_type(),
                         device=r_torch_device)
        r_tiles.arr.comm.Bcast(q1, root=diag_process)

    # ================================ Q Calculation - single tile =============================
    if calc_q:
        for row in range(q0_tiles.tile_rows_per_process[rank]):
            # q1 is applied to each tile of the column dcol of q0 then written there
            q0_tiles.local_set(key=(row, dcol),
                               value=torch.matmul(
                                   q0_tiles.local_get(key=(row, dcol)), q1))
    del q1
    # loop over the rest of the rows, combine the tiles, then apply the result to the rest
    # 2nd step: merged QR on the rows
    # ================================ R Calculation - merged tiles ============================
    diag_tile = r_tiles[dcol, dcol]
    # st_sp = r_tiles.get_start_stop(key=(dcol, dcol))
    diag_st_sp = r_tiles.get_start_stop(key=(dcol, dcol))
    diag_sz = diag_st_sp[1] - diag_st_sp[0], diag_st_sp[3] - diag_st_sp[2]
    # (Q) need to get the start stop of diag tile
    for row in range(dcol + 1, tile_rows):
        lp_st_sp = r_tiles.get_start_stop(key=(row, dcol))
        lp_sz = lp_st_sp[1] - lp_st_sp[0], lp_st_sp[3] - lp_st_sp[2]
        if rank == diag_process:
            # cat diag tile and loop tile
            loop_tile = r_tiles[row, dcol]
            loop_cat = torch.cat((diag_tile, loop_tile), dim=0)
            # qr
            ql, rl = loop_cat.qr(some=False)
            # send ql to all
            r_tiles.arr.comm.Bcast(ql.clone().contiguous(), root=diag_process)
            # set rs: the first diag_sz[0] rows go to the diagonal tile, the rest to the loop tile
            r_tiles[dcol, dcol] = rl[:diag_sz[0]]
            r_tiles[row, dcol] = rl[diag_sz[0]:]
            # apply q to rest
            if loc_col + 1 < r_tiles.tile_columns_per_process[rank]:
                upp = r_tiles.local_get(key=(dcol, slice(loc_col + 1, None)))
                low = r_tiles.local_get(key=(row, slice(loc_col + 1, None)))
                hold = torch.matmul(ql.T, torch.cat((upp, low), dim=0))
                # set upper
                r_tiles.local_set(key=(dcol, slice(loc_col + 1, None)),
                                  value=hold[:diag_sz[0]])
                # set lower
                r_tiles.local_set(key=(row, slice(loc_col + 1, None)),
                                  value=hold[diag_sz[0]:])
        elif rank > diag_process:
            # receive buffer: square, sized by the stacked (diag + loop) tile rows
            ql = torch.zeros(
                [lp_sz[0] + diag_sz[0]] * 2,
                dtype=r_tiles.arr.dtype.torch_type(),
                device=r_torch_device,
            )
            r_tiles.arr.comm.Bcast(ql, root=diag_process)
            upp = r_tiles.local_get(key=(dcol, slice(0, None)))
            low = r_tiles.local_get(key=(row, slice(0, None)))
            hold = torch.matmul(ql.T, torch.cat((upp, low), dim=0))
            # set upper
            r_tiles.local_set(key=(dcol, slice(0, None)),
                              value=hold[:diag_sz[0]])
            # set lower
            r_tiles.local_set(key=(row, slice(0, None)),
                              value=hold[diag_sz[0]:])
        else:
            # ranks before the diagonal process only receive ql (used for Q below)
            ql = torch.zeros(
                [lp_sz[0] + diag_sz[0]] * 2,
                dtype=r_tiles.arr.dtype.torch_type(),
                device=r_torch_device,
            )
            r_tiles.arr.comm.Bcast(ql, root=diag_process)
        # ================================ Q Calculation - merged tiles ========================
        if calc_q:
            # split ql into four blocks at index diag_sz[0] along both dimensions
            top_left = ql[:diag_sz[0], :diag_sz[0]]
            top_right = ql[:diag_sz[0], diag_sz[0]:]
            bottom_left = ql[diag_sz[0]:, :diag_sz[0]]
            bottom_right = ql[diag_sz[0]:, diag_sz[0]:]
            # two multiplications: one for the left tiles and one for the right
            # left tiles --------------------------------------------------------------------
            # create r column of the same size as the tile row of q0
            st_sp = r_tiles.get_start_stop(key=(slice(dcol, None), dcol))
            qloop_col_left_sz = st_sp[1] - st_sp[0], st_sp[3] - st_sp[2]
            qloop_col_left = torch.zeros(qloop_col_left_sz,
                                         dtype=q0_tiles.arr.dtype.torch_type(),
                                         device=q0_torch_device)
            # top left block fills the first diag_sz[0] rows
            qloop_col_left[:diag_sz[0]] = top_left
            # bottom left block goes to the rows of tile row 'row' (only care about 0th dim)
            st, sp, _, _ = r_tiles.get_start_stop(key=(row, 0))
            st -= diag_st_sp[
                0]  # adjust these by subtracting the start index of the diag tile
            sp -= diag_st_sp[0]
            qloop_col_left[st:sp] = bottom_left
            # right tiles --------------------------------------------------------------------
            # create r columns tensor of the size of the tile column of index 'row'
            st_sp = q0_tiles.get_start_stop(key=(row, slice(dcol, None)))
            sz = st_sp[1] - st_sp[0], st_sp[3] - st_sp[2]
            # note the swapped extents: the buffer has shape (sz[1], sz[0])
            qloop_col_right = torch.zeros(
                sz[1],
                sz[0],
                dtype=q0_tiles.arr.dtype.torch_type(),
                device=q0_torch_device)
            # top right block fills the first diag_sz[0] rows
            qloop_col_right[:diag_sz[0]] = top_right
            # bottom right block goes to the rows of tile row 'row' (only care about 0th dim)
            st, sp, _, _ = r_tiles.get_start_stop(key=(row, 0))
            st -= diag_st_sp[
                0]  # adjust these by subtracting the start index of the diag tile
            sp -= diag_st_sp[0]
            qloop_col_right[st:sp] = bottom_right
            for qrow in range(q0_tiles.tile_rows_per_process[rank]):
                # the merged Q blocks are applied to each local tile row of q0 (columns dcol
                # onwards); the clone keeps the unmodified row for the second product
                q0_row = q0_tiles.local_get(key=(qrow,
                                                 slice(dcol, None))).clone()
                q0_tiles.local_set(key=(qrow, dcol),
                                   value=torch.matmul(q0_row, qloop_col_left))
                q0_tiles.local_set(key=(qrow, row),
                                   value=torch.matmul(q0_row, qloop_col_right))
        del ql
Beispiel #21
0
def th_rcumsum(tensor, dim=0):
    """Return the reverse cumulative sum of ``tensor`` along ``dim``.

    Element ``i`` of the result (along ``dim``) is the sum of the input
    elements from position ``i`` to the end of that dimension.
    """
    axes = (dim, )
    backwards = tensor.flip(axes)
    running_total = backwards.cumsum(dim)
    return running_total.flip(axes)
Beispiel #22
0
 def forward(ctx, input, dim):
     """Store ``dim`` on ``ctx`` and return the cumulative sum of ``input`` along it."""
     ctx.dim = dim
     return input.cumsum(dim=ctx.dim)
Beispiel #23
0
def cumsum_from_zero(input_: Tensor) -> Tensor:
    """Return the exclusive cumulative sum of ``input_`` along dimension 0.

    The result has the same shape as ``input_``; its first slice is zero and
    slice ``i`` holds the sum of ``input_[:i]``.
    """
    shifted = torch.zeros_like(input_)
    # Write the running sum of all but the last slice directly into the
    # positions after the leading zero slice.
    tail = shifted[1:]
    torch.cumsum(input_[:-1], dim=0, out=tail)
    return shifted
def calculate_mAP(det_boxes, det_labels, det_scores, true_boxes, true_labels):
    """
    Calculate the Mean Average Precision (mAP) of detected objects.

    See https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173 for an explanation

    No difficulty flags are accepted: every ground-truth object is treated as non-difficult.
    Classes without any detection keep an average precision of 0.

    :param det_boxes: list of tensors, one tensor for each image containing detected objects' bounding boxes
    :param det_labels: list of tensors, one tensor for each image containing detected objects' labels
    :param det_scores: list of tensors, one tensor for each image containing detected objects' labels' scores
    :param true_boxes: list of tensors, one tensor for each image containing actual objects' bounding boxes
    :param true_labels: list of tensors, one tensor for each image containing actual objects' labels
    :return: dictionary of class-wise average precisions (keyed via rev_label_map), mean average precision (mAP)
    """
    assert len(det_boxes) == len(det_labels) == len(det_scores) == len(true_boxes) == len(
        true_labels)  # these are all lists of tensors of the same length, i.e. number of images
    n_classes = len(label_map)

    # Store all (true) objects in a single continuous tensor while keeping track of the image it is from
    true_images = list()
    for i, labels in enumerate(true_labels):
        true_images.extend([i] * labels.size(0))
    true_images = torch.LongTensor(true_images).to(
        device)  # (n_objects), n_objects is the total no. of objects across all images
    true_boxes = torch.cat(true_boxes, dim=0)  # (n_objects, 4)
    true_labels = torch.cat(true_labels, dim=0)  # (n_objects)

    assert true_images.size(0) == true_boxes.size(0) == true_labels.size(0)

    # Store all detections in a single continuous tensor while keeping track of the image it is from
    det_images = list()
    for i, labels in enumerate(det_labels):
        det_images.extend([i] * labels.size(0))
    det_images = torch.LongTensor(det_images).to(device)  # (n_detections)
    det_boxes = torch.cat(det_boxes, dim=0)  # (n_detections, 4)
    det_labels = torch.cat(det_labels, dim=0)  # (n_detections)
    det_scores = torch.cat(det_scores, dim=0)  # (n_detections)

    assert det_images.size(0) == det_boxes.size(0) == det_labels.size(0) == det_scores.size(0)

    # Calculate APs for each class (except background)
    average_precisions = torch.zeros((n_classes - 1), dtype=torch.float)  # (n_classes - 1)
    for c in range(1, n_classes):
        # Extract only objects with this class
        true_class_mask = true_labels == c
        true_class_images = true_images[true_class_mask]  # (n_class_objects)
        true_class_boxes = true_boxes[true_class_mask]  # (n_class_objects, 4)
        # All objects count as "easy". The tensor is created on the same device as
        # 'true_class_images' because it is indexed below with masks built from it;
        # a CPU tensor cannot be indexed with a mask that lives on another device.
        true_class_difficulties = torch.zeros(
            true_class_images.size(0), device=true_class_images.device)  # (n_class_objects)
        n_easy_class_objects = (1 - true_class_difficulties).sum().item()  # ignore difficult objects

        # Keep track of which true objects with this class have already been 'detected'
        # So far, none
        true_class_boxes_detected = torch.zeros(
            (true_class_difficulties.size(0)), dtype=torch.uint8,
            device=true_class_images.device)  # (n_class_objects)

        # Extract only detections with this class
        det_class_mask = det_labels == c
        det_class_images = det_images[det_class_mask]  # (n_class_detections)
        det_class_boxes = det_boxes[det_class_mask]  # (n_class_detections, 4)
        det_class_scores = det_scores[det_class_mask]  # (n_class_detections)
        n_class_detections = det_class_boxes.size(0)
        if n_class_detections == 0:
            continue

        # Sort detections in decreasing order of confidence/scores
        det_class_scores, sort_ind = torch.sort(det_class_scores, dim=0, descending=True)  # (n_class_detections)
        det_class_images = det_class_images[sort_ind]  # (n_class_detections)
        det_class_boxes = det_class_boxes[sort_ind]  # (n_class_detections, 4)

        # In the order of decreasing scores, check if true or false positive
        true_positives = torch.zeros((n_class_detections), dtype=torch.float).to(device)  # (n_class_detections)
        false_positives = torch.zeros((n_class_detections), dtype=torch.float).to(device)  # (n_class_detections)
        for d in range(n_class_detections):
            this_detection_box = det_class_boxes[d].unsqueeze(0)  # (1, 4)
            this_image = det_class_images[d]  # (), scalar

            # Find objects in the same image with this class, their difficulties, and whether they have been detected before
            in_this_image = true_class_images == this_image  # (n_class_objects), boolean mask
            object_boxes = true_class_boxes[in_this_image]  # (n_class_objects_in_img)
            object_difficulties = true_class_difficulties[in_this_image]  # (n_class_objects_in_img)
            # If no such object in this image, then the detection is a false positive
            if object_boxes.size(0) == 0:
                false_positives[d] = 1
                continue

            # Find maximum overlap of this detection with objects in this image of this class
            overlaps = find_jaccard_overlap(this_detection_box, object_boxes)  # (1, n_class_objects_in_img)
            max_overlap, ind = torch.max(overlaps.squeeze(0), dim=0)  # (), () - scalars

            # 'ind' is the index of the object in these image-level tensors 'object_boxes', 'object_difficulties'
            # In the original class-level tensors 'true_class_boxes', etc., 'ind' corresponds to object with index...
            # (the index range is built on the mask's device so the boolean indexing is valid there)
            original_ind = torch.arange(
                true_class_boxes.size(0), device=in_this_image.device)[in_this_image][ind]
            # We need 'original_ind' to update 'true_class_boxes_detected'

            # If the maximum overlap is greater than the threshold of 0.5, it's a match
            if max_overlap.item() > 0.5:
                # If the object it matched with is 'difficult', ignore it
                if object_difficulties[ind] == 0:
                    # If this object has already not been detected, it's a true positive
                    if true_class_boxes_detected[original_ind] == 0:
                        true_positives[d] = 1
                        true_class_boxes_detected[original_ind] = 1  # this object has now been detected/accounted for
                    # Otherwise, it's a false positive (since this object is already accounted for)
                    else:
                        false_positives[d] = 1
            # Otherwise, the detection occurs in a different location than the actual object, and is a false positive
            else:
                false_positives[d] = 1

        # Compute cumulative precision and recall at each detection in the order of decreasing scores
        cumul_true_positives = torch.cumsum(true_positives, dim=0)  # (n_class_detections)
        cumul_false_positives = torch.cumsum(false_positives, dim=0)  # (n_class_detections)
        cumul_precision = cumul_true_positives / (
                cumul_true_positives + cumul_false_positives + 1e-10)  # (n_class_detections)
        cumul_recall = cumul_true_positives / n_easy_class_objects  # (n_class_detections)

        # Find the mean of the maximum of the precisions corresponding to recalls above the threshold 't'
        recall_thresholds = torch.arange(start=0, end=1.1, step=.1).tolist()  # (11)
        precisions = torch.zeros((len(recall_thresholds)), dtype=torch.float).to(device)  # (11)
        for i, t in enumerate(recall_thresholds):
            recalls_above_t = cumul_recall >= t
            if recalls_above_t.any():
                precisions[i] = cumul_precision[recalls_above_t].max()
            else:
                precisions[i] = 0.
        average_precisions[c - 1] = precisions.mean()  # c is in [1, n_classes - 1]

    # Calculate Mean Average Precision (mAP)
    mean_average_precision = average_precisions.mean().item()

    # Keep class-wise average precisions in a dictionary
    average_precisions = {rev_label_map[c + 1]: v for c, v in enumerate(average_precisions.tolist())}

    return average_precisions, mean_average_precision
Beispiel #25
0
    def forward(self,
                key,
                value,
                query,
                mask=None,
                aw_prev=None,
                mode='parallel'):
        """Compute monotonic (optionally chunkwise) attention for one query step.

        Args:
            key (FloatTensor): `[B, kmax, key_dim]`
            value (FloatTensor): `[B, kmax, value_dim]`
            query (FloatTensor): `[B, 1, query_dim]`
            mask (ByteTensor): `[B, qmax, kmax]`
            aw_prev (FloatTensor): `[B, kmax, 1 (n_heads)]`, attention weights
                of the previous step; defaults to a one-hot on the first key
            mode (str): recursive/parallel (training) or hard (test)
        Return:
            cv (FloatTensor): `[B, 1, value_dim]`
            alpha (FloatTensor): `[B, kmax, 1 (n_heads)]`, monotonic attention
                weights of this step
        Raises:
            ValueError: if `mode` is not one of the three supported modes

        """
        batch_size, n_keys = key.size()[:2]

        if aw_prev is None:
            # Start from [1, 0, 0, ..., 0]
            aw_prev = key.new_zeros(batch_size, n_keys, 1)
            aw_prev[:, 0:1] = key.new_ones(batch_size, 1, 1)

        # Monotonic energy
        e_mono = self.monotonic_energy(key, query, mask)

        if mode == 'recursive':  # training time
            p_choose = torch.sigmoid(add_gaussian_noise(e_mono))  # `[B, kmax]`
            # [1, 1 - p_choose[0], 1 - p_choose[1], ..., 1 - p_choose[-2]]
            keep_prob_shifted = torch.cat(
                [key.new_ones(batch_size, 1), 1 - p_choose[:, :-1]], dim=1)
            # Unroll the recurrence
            #   q[j] = (1 - p_choose[j]) * q[j - 1] + aw_prev[j]
            #   alpha[j] = p_choose[j] * q[j]
            q = key.new_zeros(batch_size, n_keys + 1)
            for j in range(n_keys):
                carried = keep_prob_shifted[:, j].clone() * q[:, j].clone()
                q[:, j + 1] = carried + aw_prev[:, j, 0].clone()
            alpha = p_choose * q[:, 1:]

        elif mode == 'parallel':  # training time
            p_choose = torch.sigmoid(add_gaussian_noise(e_mono))  # `[B, kmax]`
            # safe_cumprod computes cumprod in logspace with numeric checks
            keep_cumprod = safe_cumprod(1 - p_choose, eps=1e-10)
            # Closed-form solution of the recurrence
            ratio = aw_prev.squeeze(2) / torch.clamp(
                keep_cumprod, min=1e-10, max=1.0)
            alpha = p_choose * keep_cumprod * torch.cumsum(ratio, dim=1)

        elif mode == 'hard':  # test time
            # Attend when monotonic energy is above threshold (Sigmoid > 0.5)
            p_choose = (e_mono > 0).float()

            # Remove any probabilities before the index chosen last time step
            p_choose *= torch.cumsum(aw_prev.squeeze(2), dim=1)  # `[B, kmax]`

            # Exclusive cumprod removes probabilities after the first chosen
            # index, e.g.:
            # p_choose                        = [0, 0, 0, 1, 1, 0, 1, 1]
            # 1 - p_choose                    = [1, 1, 1, 0, 0, 1, 0, 0]
            # exclusive_cumprod(1 - p_choose) = [1, 1, 1, 1, 0, 0, 0, 0]
            # alpha: product of above         = [0, 0, 0, 1, 0, 0, 0, 0]
            alpha = p_choose * exclusive_cumprod(1 - p_choose)

            # Not attended => attend at last encoder output
            # NOTE: Assume that encoder outputs are not padded
            # (according to the original paper, a zero vector is used instead)
            attended = alpha.sum(dim=1)
            for b in range(batch_size):
                if attended[b] == 0:
                    alpha[b, -1] = 1
        else:
            raise ValueError(
                "mode must be 'recursive', 'parallel', or 'hard'.")

        if self.window > 1:
            # Chunkwise attention: redistribute alpha over a window using the
            # chunk energy, then use those weights for the context vector
            e_chunk = self.chunk_energy(key, query, mask)
            weights = efficient_chunkwise_attention(alpha, e_chunk, self.window)
        else:
            weights = alpha

        # Context vector
        cv = torch.bmm(weights.unsqueeze(1), value)

        return cv, alpha.unsqueeze(2)
    def forward(self,
                input_sentences,
                input_sentence_length,
                input_conversation_length,
                target_sentences,
                decode=False):
        """Encode sentences, run the context encoder and decode.

        Args:
            input_sentences: (Variable, LongTensor) [num_sentences, seq_len]
            input_sentence_length: forwarded to the sentence encoder
                together with ``input_sentences``
            input_conversation_length: number of sentences belonging to
                each conversation; used to regroup the flat sentence list
            target_sentences: (Variable, LongTensor) [num_sentences, seq_len]
            decode: when False the decoder is run on ``target_sentences``;
                when True ``beam_decode`` is used instead
        Return:
            decoder_outputs: (Variable, FloatTensor)
                - train: [batch_size, seq_len, vocab_size]
                - decode: beam-search prediction
                  [batch_size, beam_size, max_unroll]
        """
        n_sentences = input_sentences.size(0)
        longest_conversation = input_conversation_length.data.max().item()

        # sentence_hidden: [num_layers * direction, num_sentences, hidden_size]
        _sentence_outputs, sentence_hidden = self.encoder(
            input_sentences, input_sentence_length)

        # -> [num_sentences, num_layers * direction * hidden_size]
        flat_hidden = sentence_hidden.transpose(1, 0).contiguous()
        flat_hidden = flat_hidden.view(n_sentences, -1)

        # Offset of the first sentence of every conversation: exclusive
        # cumulative sum of the conversation lengths.
        leading_zero = to_var(input_conversation_length.data.new(1).zero_())
        shifted_lengths = torch.cat(
            (leading_zero, input_conversation_length[:-1]))
        start = torch.cumsum(shifted_lengths, 0)

        # Cut the flat list into conversations and pad each to the longest one
        # -> [batch_size, max_len, num_layers * direction * hidden_size]
        padded_conversations = []
        for offset, length in zip(start.data.tolist(),
                                  input_conversation_length.data.tolist()):
            conversation = flat_hidden.narrow(0, offset, length)
            padded_conversations.append(pad(conversation,
                                            longest_conversation))
        conversation_hidden = torch.stack(padded_conversations, 0)

        # context_outputs: [batch_size, max_len, context_size]
        context_outputs, context_last_hidden = self.context_encoder(
            conversation_hidden, input_conversation_length)

        # Drop the padded steps again -> [num_sentences, context_size]
        valid_steps = []
        for i, l in enumerate(input_conversation_length.data):
            valid_steps.append(context_outputs[i, :l, :])
        context_outputs = torch.cat(valid_steps)

        # Project the context to the decoder's initial state
        # -> [num_layers, batch_size, hidden_size]
        decoder_init = self.context2decoder(context_outputs).view(
            self.decoder.num_layers, -1, self.decoder.hidden_size)

        if decode:
            # prediction: [batch_size, beam_size, max_unroll]
            prediction, _final_score, _length = self.decoder.beam_decode(
                init_h=decoder_init)
            return prediction

        # train: [batch_size, seq_len, vocab_size]
        return self.decoder(target_sentences,
                            init_h=decoder_init,
                            decode=decode)
Beispiel #27
0
def exclusive_cumsum(xs):
    """Exclusive cumulative summation along dim 1: [a, b, c] => [0, a, a + b].

    A zero column is prepended and the last column is dropped, so every
    output position holds the sum of the entries strictly before it.
    """
    # assert len(xs.size()) == 2
    zero_column = xs.new_zeros(xs.size(0), 1)
    padded = torch.cat([zero_column, xs], dim=1)
    shifted = padded[:, :-1]
    return shifted.cumsum(dim=1)
Beispiel #28
0
    def projection_l2(self, points_to_project, w_hyperplane, b_hyperplane):
        """Compute, row by row, an offset ``d`` from a point and a hyperplane.

        NOTE(review): the ``t / w`` and ``(t - 1) / w`` terms suggest this is
        an L2 projection onto the hyperplane ``w . x = b`` restricted to the
        box [0, 1] -- confirm against the caller.

        Args:
            points_to_project: tensor ``t``; the indexing below
                (``sum(1)``, ``[:, 1:]``) implies a 2-D layout
                ``(rows, features)``.
            w_hyperplane: tensor ``w`` multiplied element-wise with ``t``.
            b_hyperplane: tensor ``b`` subtracted from the row sums of
                ``w * t``.

        Returns:
            The offset tensor ``d``, with entries forced to zero wherever
            ``|w|`` is not larger than 1e-8.
        """
        # Work on copies: w and c are modified in place below.
        t = points_to_project.clone()
        w = w_hyperplane.clone()
        b = b_hyperplane.clone()

        # Signed value of w . t - b for every row.
        c = (w * t).sum(1) - b
        # Flip the sign of the rows with c < 0 so that c is non-negative.
        ind2 = (c < 0).nonzero().squeeze()
        ind2 = self.check_shape(ind2)
        w[ind2] *= -1
        c[ind2] *= -1

        # Column of row indices, used to gather with the per-row sort order.
        u = torch.arange(0, w.shape[0]).unsqueeze(1)

        # Per-coordinate ratio, clamped to [-1e12, 1e12]; coordinates with a
        # (near-)zero weight and entries equal to -1e12 are set to +1e12.
        r = torch.max(t / w, (t - 1) / w)
        u2 = torch.ones(r.shape).to(self.device)
        r = torch.min(r, 1e12 * u2)
        r = torch.max(r, -1e12 * u2)
        r[w.abs() < 1e-8] = 1e12
        r[r == -1e12] = -r[r == -1e12]
        # Sort the ratios of every row; rs2 is rs shifted left by one column.
        rs, indr = torch.sort(r, dim=1)
        rs2 = torch.cat(
            (rs[:, 1:], torch.zeros(rs.shape[0], 1).to(self.device)), 1)
        # The 1e12 sentinel is replaced by 0 in both sorted copies.
        rs[rs == 1e12] = 0
        rs2[rs2 == 1e12] = 0

        # Squared weights in sorted order; ws holds, per column, the total
        # minus the running sum up to and including that column.
        w3 = w**2
        w3s = w3[u, indr]
        w5 = w3s.sum(dim=1, keepdim=True)
        ws = w5 - torch.cumsum(w3s, dim=1)
        # Initial offset, zeroed where the weight is (near-)zero.
        d = -(r * w).clone()
        d = d * (w.abs() > 1e-8).float()
        # s has one more column than rs: a first column followed by a
        # cumulative sum over the sorted breakpoints.
        s = torch.cat(
            ((-w5.squeeze() * rs[:, 0]).unsqueeze(1),
             torch.cumsum(
                 (-rs2 + rs) * ws, dim=1) - w5 * rs[:, 0].unsqueeze(-1)), 1)

        # Split the rows into three groups:
        #   c6 -> rows with s[:, 0] + c < 0
        #   c3 -> rows with (d * w).sum + c > 0 (keep the initial d)
        #   c2 -> rows in neither group (handled by the bisection below)
        c4 = (s[:, 0] + c < 0)
        c3 = ((d * w).sum(dim=1) + c > 0)
        c6 = c4.nonzero().squeeze()
        c2 = ((1 - c4.float()) * (1 - c3.float())).nonzero().squeeze()
        c6 = self.check_shape(c6)
        c2 = self.check_shape(c2)

        # Bisection over the sorted columns, one [lb, ub] interval per c2 row.
        # NOTE(review): lb / ub are created without a device argument, unlike
        # u2 above which is moved to self.device -- confirm this is intended.
        counter = 0
        lb = torch.zeros(c2.shape[0])
        ub = torch.ones(c2.shape[0]) * (w.shape[1] - 1)
        nitermax = torch.ceil(torch.log2(torch.tensor(w.shape[1]).float()))
        counter2 = torch.zeros(lb.shape).long()

        while counter < nitermax:
            # Midpoint of every interval.
            counter4 = torch.floor((lb + ub) / 2)
            counter2 = counter4.long()
            # Rows where s + c is still positive at the midpoint raise their
            # lower bound; the others lower their upper bound.
            c3 = s[c2, counter2] + c[c2] > 0
            ind3 = c3.nonzero().squeeze()
            ind32 = (~c3).nonzero().squeeze()
            ind3 = self.check_shape(ind3)
            ind32 = self.check_shape(ind32)
            lb[ind3] = counter4[ind3]
            ub[ind32] = counter4[ind32]
            counter += 1

        lb = lb.long()
        alpha = torch.zeros([1])

        # c6 rows: a single scale factor along w.
        if c6.nelement() != 0:
            alpha = c[c6] / w5[c6].squeeze(-1)
            d[c6] = -alpha.unsqueeze(-1) * w[c6]

        # c2 rows: scale factor taken at the column found by the bisection.
        if c2.nelement() != 0:
            alpha = (s[c2, lb] + c[c2]) / ws[c2, lb] + rs[c2, lb]
            # Where the divisor ws is zero, alpha is reset to 0.
            if torch.sum(ws[c2, lb] == 0) > 0:
                ind = (ws[c2, lb] == 0).nonzero().squeeze().long()
                ind = self.check_shape(ind)
                alpha[ind] = 0
            # Keep the initial d where alpha exceeds r, otherwise -alpha * w.
            c5 = (alpha.unsqueeze(-1) > r[c2]).float()
            d[c2] = d[c2] * c5 - alpha.unsqueeze(-1) * w[c2] * (1 - c5)

        return d * (w.abs() > 1e-8).float()
Beispiel #29
0
def average_precision(
        # all prediction boxes of considered data set
        pred_boxes,  # pred_boxes (list): [[img_idx, class_pred, prob_score, x1, y1, x2, y2], [...], ...]
        true_boxes,
        iou_threshold=0.5,
        box_format='corners',
        num_classes=20):
    """Mean of the per-class areas under the precision/recall curve.

    Args:
        pred_boxes (list): predicted boxes, each
            ``[img_idx, class_pred, prob_score, x1, y1, x2, y2]``.
        true_boxes (list): ground-truth boxes; entry 0 is read as the image
            index, entry 1 as the class and entries 3.. as the coordinates.
        iou_threshold (float): a detection counts as a true positive only
            when its best IoU with a ground truth is above this value.
        box_format (str): forwarded unchanged to ``intersect_over_union``.
        num_classes (int): classes ``0 .. num_classes - 1`` are evaluated.

    Returns:
        The average over all classes of the trapezoidal area under the
        precision/recall curve.

    Raises:
        ZeroDivisionError: if ``num_classes`` is 0.
    """
    average_precisions = []
    epsilon = 1e-6  # keeps the divisions below finite

    for c in range(num_classes):
        # 1. Boxes of the current class, predictions sorted by confidence
        detections = [box for box in pred_boxes if box[1] == c]
        ground_truths = [box for box in true_boxes if box[1] == c]
        detections.sort(key=lambda x: x[2], reverse=True)

        # 2. Group the ground truths by image once instead of re-scanning the
        #    whole list for every detection.
        gts_by_image = {}
        for gt in ground_truths:
            gts_by_image.setdefault(gt[0], []).append(gt)
        # One 0/1 flag per ground truth: 1 once it has been matched.
        matched = {
            img_idx: torch.zeros(len(gts))
            for img_idx, gts in gts_by_image.items()
        }

        TP = torch.zeros(len(detections))
        FP = torch.zeros(len(detections))
        total_true_bboxes = len(ground_truths)

        # 3. Greedily match every detection to its best-overlapping ground
        #    truth of the same image.
        for detection_idx, detection in enumerate(detections):
            ground_truth_img = gts_by_image.get(detection[0], [])

            best_iou = 0
            # Reset per detection so a previous match can never leak in.
            best_gt_idx = None
            for idx, gt in enumerate(ground_truth_img):
                iou = intersect_over_union(torch.tensor(detection[3:]),
                                           torch.tensor(gt[3:]),
                                           box_format=box_format)
                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            is_true_positive = (
                best_gt_idx is not None
                and best_iou > iou_threshold
                and matched[detection[0]][best_gt_idx] == 0)
            if is_true_positive:
                TP[detection_idx] = 1
                matched[detection[0]][best_gt_idx] = 1
            else:
                # low overlap, no candidate, or ground truth already claimed
                FP[detection_idx] = 1

        # 4. Precision/recall curve, e.g. TP [1, 1, 0, 1, 0] -> [1, 2, 2, 3, 3]
        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = torch.divide(TP_cumsum, (TP_cumsum + FP_cumsum + epsilon))
        # Anchor the curve at (recall=0, precision=1); float anchors keep the
        # concatenation dtype-consistent.
        recalls = torch.cat((torch.zeros(1), recalls))
        precisions = torch.cat((torch.ones(1), precisions))

        # 5. Area under the PR curve
        average_precisions.append(torch.trapz(precisions, recalls))

    return sum(average_precisions) / len(average_precisions)
Beispiel #30
0
    def compute_adjacency_info(vertices: torch.Tensor, faces: torch.Tensor):
        """Build data structures to help speed up connectivity queries. Assumes
        a homogeneous mesh, i.e., each face has the same number of vertices.

        The outputs have the following format: AA, AA_count
        AA_count: [count_0, ..., count_n]
        with AA:
        [[aa_{0,0}, ..., aa_{0,count_0} (, -1, ..., -1)],
         [aa_{1,0}, ..., aa_{1,count_1} (, -1, ..., -1)],
                    ...
         [aa_{n,0}, ..., aa_{n,count_n} (, -1, ..., -1)]]

        Args:
            vertices: only ``shape[0]`` (vertex count) and ``device`` are read.
            faces: integer tensor with one row of vertex indices per face.

        Returns:
            A 14-tuple, in this order:
            ``edge2key`` (dict mapping a sorted vertex pair to its edge id),
            ``edges`` (unique sorted vertex pairs),
            ``vv``, ``nb_edges_per_vertex`` (vertex -> vertices),
            ``ve``, ``nb_edges_per_vertex`` (vertex -> edges, same counts),
            ``vf``, ``nb_faces_per_vertex`` (vertex -> faces),
            ``ff``, ``nb_faces_per_face`` (face -> faces),
            ``ee``, ``nb_edges_per_edge`` (edge -> edges),
            ``ef``, ``nb_faces_per_edge`` (edge -> faces).

        NOTE(review): the per-vertex tables are indexed by vertex id while the
        per-vertex counts come from the sorted occurrences, so this presumably
        expects every vertex to be referenced by at least one face -- confirm.
        """

        device = vertices.device
        facesize = faces.shape[1]
        nb_vertices = vertices.shape[0]
        nb_faces = faces.shape[0]
        # Consecutive vertex pairs of every face plus the closing (last, first)
        # pair, stacked block by block (one block of nb_faces rows per side).
        edges = torch.cat([faces[:, i:i + 2]
                           for i in range(facesize - 1)] + [faces[:, [-1, 0]]],
                          dim=0)
        # Sort the vertex of edges in increasing order
        edges = torch.sort(edges, dim=1)[0]
        # id of corresponding face in edges
        face_ids = torch.arange(nb_faces, device=device,
                                dtype=torch.long).repeat(facesize)
        # remove multiple occurrences and sort by the first vertex
        # the edge key / id is fixed from now as the first axis position
        # edges_ids will give the key of the edges on the original vector
        edges, edges_ids = torch.unique(edges,
                                        sorted=True,
                                        return_inverse=True,
                                        dim=0)
        nb_edges = edges.shape[0]

        # EDGE2FACE
        sorted_edges_ids, order_edges_ids = torch.sort(edges_ids)
        sorted_faces_ids = face_ids[order_edges_ids]
        # indices of first occurrences of each key
        idx_first = torch.where(
            torch.nn.functional.pad(
                sorted_edges_ids[1:] != sorted_edges_ids[:-1], (1, 0),
                value=1))[0]
        nb_faces_per_edge = idx_first[1:] - idx_first[:-1]
        # compute sub_idx (2nd axis indices to store the faces): position of
        # every entry inside its run of equal keys
        offsets = torch.zeros(sorted_edges_ids.shape[0],
                              device=device,
                              dtype=torch.long)
        offsets[idx_first[1:]] = nb_faces_per_edge
        sub_idx = (torch.arange(
            sorted_edges_ids.shape[0], device=device, dtype=torch.long) -
                   torch.cumsum(offsets, dim=0))
        # TODO(cfujitsang): potential way to compute sub_idx differently
        #                   to test with bigger model
        #sub_idx = torch.ones(sorted_edges_ids.shape[0], device=device, dtype=torch.long)
        #sub_idx[0] = 0
        #sub_idx[idx_first[1:]] = 1 - nb_faces_per_edge
        #sub_idx = torch.cumsum(sub_idx, dim=0)
        # the last run has no following "first occurrence", add its length
        nb_faces_per_edge = torch.cat(
            [nb_faces_per_edge, sorted_edges_ids.shape[0] - idx_first[-1:]],
            dim=0)
        max_sub_idx = torch.max(nb_faces_per_edge)
        ef = torch.zeros(
            (nb_edges, max_sub_idx), device=device, dtype=torch.long) - 1
        ef[sorted_edges_ids, sub_idx] = sorted_faces_ids
        # FACE2FACES
        nb_faces_per_face = torch.stack([
            nb_faces_per_edge[edges_ids[i * nb_faces:(i + 1) * nb_faces]]
            for i in range(facesize)
        ],
                                        dim=1).sum(dim=1) - facesize
        ff = torch.cat([
            ef[edges_ids[i * nb_faces:(i + 1) * nb_faces]]
            for i in range(facesize)
        ],
                       dim=1)
        # remove self occurrences
        ff[ff == torch.arange(nb_faces, device=device, dtype=torch.long).view(
            -1, 1)] = -1
        ff = torch.sort(ff, dim=-1, descending=True)[0]
        # a neighbour reached through several shared edges is kept only once
        to_del = (ff[:, 1:] == ff[:, :-1]) & (ff[:, 1:] != -1)
        ff[:, 1:][to_del] = -1
        nb_faces_per_face = nb_faces_per_face - torch.sum(to_del, dim=1)
        max_sub_idx = torch.max(nb_faces_per_face)
        ff = torch.sort(ff, dim=-1, descending=True)[0][:, :max_sub_idx]

        # VERTEX2VERTICES and VERTEX2EDGES
        npy_edges = edges.cpu().numpy()
        edge2key = {tuple(npy_edges[i]): i for i in range(nb_edges)}
        #_edges and double_edges 2nd axis correspond to the triplet:
        # [left vertex, right vertex, edge key]
        _edges = torch.cat(
            [edges, torch.arange(nb_edges, device=device).view(-1, 1)], dim=1)
        double_edges = torch.cat([_edges, _edges[:, [1, 0, 2]]], dim=0)
        double_edges = torch.unique(double_edges, sorted=True, dim=0)
        # TODO(cfujitsang): potential improvement, to test with bigger model:
        #double_edges0, order_double_edges = torch.sort(double_edges[0])
        nb_double_edges = double_edges.shape[0]
        # indices of first occurrences of each key
        idx_first = torch.where(
            torch.nn.functional.pad(
                double_edges[1:, 0] != double_edges[:-1, 0], (1, 0),
                value=1))[0]
        nb_edges_per_vertex = idx_first[1:] - idx_first[:-1]
        # compute sub_idx (2nd axis indices to store the edges)
        offsets = torch.zeros(nb_double_edges, device=device, dtype=torch.long)
        offsets[idx_first[1:]] = nb_edges_per_vertex
        sub_idx = (
            torch.arange(nb_double_edges, device=device, dtype=torch.long) -
            torch.cumsum(offsets, dim=0))
        nb_edges_per_vertex = torch.cat(
            [nb_edges_per_vertex, nb_double_edges - idx_first[-1:]], dim=0)
        max_sub_idx = torch.max(nb_edges_per_vertex)
        vv = torch.zeros(
            (nb_vertices, max_sub_idx), device=device, dtype=torch.long) - 1
        vv[double_edges[:, 0], sub_idx] = double_edges[:, 1]
        ve = torch.zeros(
            (nb_vertices, max_sub_idx), device=device, dtype=torch.long) - 1
        ve[double_edges[:, 0], sub_idx] = double_edges[:, 2]
        # EDGE2EDGES
        ee = torch.cat([ve[edges[:, 0], :], ve[edges[:, 1], :]], dim=1)
        nb_edges_per_edge = nb_edges_per_vertex[
            edges[:, 0]] + nb_edges_per_vertex[edges[:, 1]] - 2
        max_sub_idx = torch.max(nb_edges_per_edge)
        # remove self occurrences
        ee[ee == torch.arange(nb_edges, device=device, dtype=torch.long).view(
            -1, 1)] = -1
        ee = torch.sort(ee, dim=-1, descending=True)[0][:, :max_sub_idx]
        # VERTEX2FACES
        vertex_ordered, order_vertex = torch.sort(faces.view(-1))
        # Position in the flattened faces -> face id. This must be an integer
        # (floor) division: true division yields a float tensor that is not a
        # valid index and cannot be stored in the long tensor `vf` below.
        face_ids_in_vertex_order = order_vertex // facesize
        # indices of first occurrences of each id
        idx_first = torch.where(
            torch.nn.functional.pad(vertex_ordered[1:] != vertex_ordered[:-1],
                                    (1, 0),
                                    value=1))[0]
        nb_faces_per_vertex = idx_first[1:] - idx_first[:-1]
        # compute sub_idx (2nd axis indices to store the faces)
        offsets = torch.zeros(vertex_ordered.shape[0],
                              device=device,
                              dtype=torch.long)
        offsets[idx_first[1:]] = nb_faces_per_vertex
        sub_idx = (torch.arange(
            vertex_ordered.shape[0], device=device, dtype=torch.long) -
                   torch.cumsum(offsets, dim=0))
        # TODO(cfujitsang): it seems that nb_faces_per_vertex == nb_edges_per_vertex ?
        nb_faces_per_vertex = torch.cat(
            [nb_faces_per_vertex, vertex_ordered.shape[0] - idx_first[-1:]],
            dim=0)
        max_sub_idx = torch.max(nb_faces_per_vertex)
        vf = torch.zeros(
            (nb_vertices, max_sub_idx), device=device, dtype=torch.long) - 1
        vf[vertex_ordered, sub_idx] = face_ids_in_vertex_order

        return edge2key, edges, vv, nb_edges_per_vertex, ve, nb_edges_per_vertex, vf, \
            nb_faces_per_vertex, ff, nb_faces_per_face, ee, nb_edges_per_edge, ef, nb_faces_per_edge
Beispiel #31
0
    def decode(self, input_seq, input_lens, top_p=0, max_len=100):
        """Generate one sentence per batch row, token by token.

        Args:
            input_seq: batch of source sequences; ``size(0)`` is the batch
                size. It is handed to ``self.transformer`` unchanged.
            input_lens: not used by this method.
            top_p: 0 selects the most probable token at every step; any other
                value enables nucleus sampling with that cumulative
                probability mass.
            max_len: maximum number of generated tokens.

        Returns:
            A list with one space-joined string per batch row, without the
            leading ``'_go'`` and cut at the first ``'_eos'``.
        """
        batch_size = input_seq.size(0)
        predictions = [['_go'] for _ in range(batch_size)]
        eos_seen = [False for _ in range(batch_size)]

        # Target tensors were always created on the GPU. Keep doing so when
        # CUDA is available, otherwise follow the input's device so that
        # decoding also works on CPU-only hosts.
        if torch.cuda.is_available():
            target_device = torch.device('cuda')
        else:
            target_device = input_seq.device

        def _pad(arr, pad):
            # Given an array of integer arrays, pad all arrays to the same length
            lengths = [len(e) for e in arr]
            max_len = max(lengths)
            return [e + [pad] * (max_len - len(e)) for e in arr], lengths

        with torch.no_grad():
            enc_output = self.transformer.enc(input_seq)

            for t in range(max_len):
                # Create the targets so far; a trailing '_pad' is the slot
                # whose prediction is read below.
                targets = [[
                    self.w2i.get(w, self.w2i['_unk']) for w in row + ['_pad']
                ] for row in predictions]
                target_seq, _ = _pad(targets, pad=self.w2i['_pad'])
                target_seq = torch.tensor(target_seq,
                                          dtype=torch.long,
                                          device=target_device)

                # Pass through transformer, keep the distribution of the
                # last position only
                proba = F.softmax(self.transformer(input_seq,
                                                   target_seq,
                                                   enc_output=enc_output),
                                  dim=-1)[:, -1]

                # Get top candidates
                if top_p == 0:
                    _, topi = proba.topk(1)
                else:
                    s_probs, s_inds = torch.sort(proba, descending=True)
                    cum_probs = torch.cumsum(s_probs, dim=-1)

                    # Remove all outside the nucleus
                    sinds_to_remove = cum_probs > top_p

                    # Shift the mask right by one so that the first token
                    # crossing the threshold (and always the top one) is kept
                    sinds_to_remove[:, 1:] = sinds_to_remove[:, :-1].clone()
                    sinds_to_remove[:, 0] = 0

                    for b in range(s_inds.size(0)):
                        # Map the sorted positions back to vocabulary ids
                        inds_to_remove = s_inds[b][sinds_to_remove[b]]

                        # Set to be filtered in original
                        proba[b, inds_to_remove] = 0

                    # Sample (multinomial accepts unnormalised weights)
                    topi = torch.multinomial(proba.squeeze(0), 1)

                topi = topi.view(-1)
                words = [self.i2w[e.item()] for e in topi]
                for i in range(len(predictions)):
                    predictions[i].append(words[i])
                    if words[i] == '_eos':
                        eos_seen[i] = True

                if all(eos_seen):
                    break

        predicted_sentences = []
        for sentence in predictions:
            # NOTE(review): a sentence that never produced '_eos' loses its
            # last generated token here (slice end -1) -- confirm intended.
            predicted_sentences.append(' '.join(
                sentence[1:-1 if '_eos' not in
                         sentence else sentence.index('_eos')]))

        return predicted_sentences
Beispiel #32
0
def compute_pr_curves(class_hist, total_hist):
    """
    Build per-class precision / recall curves from two histograms.

    ``class_hist`` counts the true samples and ``total_hist`` all samples
    per score bin; both are int64 tensors of shape num_bins x num_classes
    with ``class_hist <= total_hist`` everywhere.

    Bins are accumulated from the highest score bin to the lowest. Classes
    without any true sample are skipped. Points with non-positive recall and
    consecutive duplicate points (caused by empty bins) are dropped.

    Returns a dict with the lists "prec", "recall" and "ap" (one entry per
    kept class; "ap" is whatever ``calc_ap`` returns for the curve).
    """
    both_tensors = torch.is_tensor(class_hist) and torch.is_tensor(total_hist)
    assert both_tensors, "Both arguments must be tensors"
    assert (class_hist.dtype == torch.int64 and total_hist.dtype
            == torch.int64), "Both arguments must contain int64 values"
    assert (class_hist.dim() == 2 and total_hist.dim()
            == 2), "Both arguments must have 2 dimensions, (score_bin, class)"
    assert (class_hist.size() == total_hist.size()), """
        For compute_pr_curve, arguments must be  of same size.
        class_hist.size(): %s
        total_hist.size(): %s
        """ % (
        str(class_hist.size()),
        str(total_hist.size()),
    )
    assert (class_hist > total_hist).sum() == 0, (
        "Invalid. Class histogram must be less than or equal to total histogram"
    )

    num_bins = class_hist.size()[0]

    # Running totals starting at the highest score bucket
    cum_class_hist = torch.flip(class_hist, dims=(0, )).cumsum(dim=0).double()
    cum_total_hist = torch.flip(total_hist, dims=(0, )).cumsum(dim=0).double()
    class_totals = cum_class_hist[-1, :]

    # One column per class
    prec = torch.unbind(cum_class_hist / cum_total_hist, dim=1)
    recall = torch.unbind(cum_class_hist / class_totals, dim=1)
    assert len(prec) == len(
        recall
    ), "The number of precision curves does not match the number of recall curves"

    final_prec, final_recall, final_ap = [], [], []
    for c, (prec_curve, recall_curve) in enumerate(zip(prec, recall)):
        assert (
            recall_curve.size()[0] == num_bins
            and prec_curve.size()[0] == num_bins
        ), "Precision and recall curves do not have the correct number of entries"

        # Nothing to report for a class that was never seen
        if class_totals[c] == 0:
            continue

        kept_recall, kept_prec = [], []
        # Sentinels that no real (recall, precision) pair can equal
        last_r, last_p = -1.0, 1.1
        for r, p in zip(recall_curve, prec_curve):
            r_value = r.item()
            # Invalid point on the PR curve
            if r_value <= 0:
                continue
            # Same point as the previous one (empty bucket)
            if r_value == last_r and p.item() == last_p:
                continue
            kept_recall.append(r)
            kept_prec.append(p)
            last_r, last_p = r_value, p.item()

        if kept_recall:
            new_recall_curve = torch.stack(kept_recall)
            new_prec_curve = torch.stack(kept_prec)
        else:
            new_recall_curve = torch.tensor([], dtype=torch.double)
            new_prec_curve = torch.tensor([], dtype=torch.double)

        final_prec.append(new_prec_curve)
        final_recall.append(new_recall_curve)
        final_ap.append(calc_ap(new_prec_curve, new_recall_curve))

    return {"prec": final_prec, "recall": final_recall, "ap": final_ap}
Beispiel #33
0
    def _accumulate(self, **kwargs):
        '''
        Accumulate the stats of all images to calculate AP.

        Reads ``self.tps`` (boolean true-positive flags, one row per IoU
        threshold), ``self.scores``, ``self.num_gt``, ``self.iou_thres`` and
        ``self.rec_thres``.

        Writes ``self.PRcurve`` (precision sampled at ``self.rec_thres`` for
        every IoU threshold), ``self.APs`` (row means of ``PRcurve``) and
        ``self.best_thres`` (score with the highest F1 per IoU threshold).

        Keyword Args:
            debug: when truthy, plot the raw P-R curve of the first IoU
                threshold with matplotlib.
        '''
        print('accumulating results')
        num_gt = self.num_gt
        tps = torch.cat(self.tps, dim=1)
        # sort all the tps in descending order of score
        scores = torch.cat(self.scores, dim=0)
        scores, sortidx = torch.sort(scores, dim=0, descending=True)
        tps = tps[:, sortidx]
        # False Positive = NOT True Positive
        fps = ~tps
        num_dt = tps.shape[1]
        assert tps.dim() == 2 and tps.shape[0] == len(self.iou_thres)

        # accumulate
        tps, fps = tps.float(), fps.float()
        tp_sum = torch.cumsum(tps, dim=1)
        fp_sum = torch.cumsum(fps, dim=1)
        assert ((tp_sum[:, -1] + fp_sum[:, -1]) == num_dt).all()
        # calculate precision and recall
        precision = tp_sum / (tp_sum + fp_sum)
        recall = tp_sum / num_gt
        # F1 is 0/0 = NaN wherever precision and recall are both zero (e.g.
        # leading false positives). NaN would win the argmax below, so those
        # entries are defined as 0.
        pr_sum = precision + recall
        f1 = 2 * (precision * recall) / pr_sum
        f1[pr_sum == 0] = 0
        if kwargs.get('debug', False):
            import matplotlib.pyplot as plt
            p = precision[0, :].numpy()
            r = recall[0, :].numpy()
            plt.plot(r, p)
            plt.title('P-R curve at IoU=0.5 before smoothing')
            plt.xlabel('Recall')
            plt.ylabel('Precision')
            plt.xlim(0, 1)
            plt.ylim(0, 1)
            plt.show()

        # initialize the approximate P-R curve for all IoU thresholds
        PRcurve = torch.zeros(len(self.iou_thres), len(self.rec_thres))
        # recall is converted to numpy so np.searchsorted can be used below
        recall = recall.numpy()
        for ti, (prec_T, rc_T) in enumerate(zip(precision, recall)):
            assert prec_T.shape[0] == rc_T.shape[0] == num_dt

            # make the Precision monotonically decreasing (right-to-left
            # running maximum; prec_T is a row of `precision`, edited in place)
            for i in range(num_dt - 1, 0, -1):
                if prec_T[i] > prec_T[i - 1]:
                    prec_T[i - 1] = prec_T[i]
            # first detection index reaching each recall threshold
            idxs = np.searchsorted(rc_T, self.rec_thres, side='left')
            # fill in the P-R curve
            for ri, pi in enumerate(idxs):
                if pi >= len(prec_T):
                    # reach the upper bound of Recall; the rest stays 0
                    break
                PRcurve[ti, ri] = prec_T[pi]

        self.PRcurve = PRcurve
        self.APs = self.PRcurve.mean(dim=1)
        self.best_thres = scores[torch.argmax(f1, dim=1)]
Beispiel #34
0
 def forward(ctx, input, dim):
     """Return the cumulative sum of ``input`` along ``dim``.

     ``dim`` is stored on ``ctx`` so it can be read back later.
     """
     ctx.dim = dim
     summed = torch.cumsum(input, dim=ctx.dim)
     return summed
Beispiel #35
0
 def test_cumsum(self):
     """Check the ONNX export of ``torch.cumsum`` along dim 1 (opset 11)."""
     sample = torch.randn(2, 3, 4, requires_grad=True)

     def cumsum_dim1(x):
         return torch.cumsum(x, dim=1)

     self.assertONNX(cumsum_dim1, sample, opset_version=11)
Beispiel #36
0
    def decode(self, input_seq, input_lens, top_p=0, max_len=100, p_copy=0):
        """Generate one sentence per batch column with the encoder/decoder.

        Args:
            input_seq: source token ids; ``size(1)`` is the batch size.
            input_lens: forwarded to ``self.encoder`` with ``input_seq``.
            top_p: 0 selects the highest-scoring token at every step; any
                other value enables nucleus sampling with that cumulative
                probability mass.
            max_len: number of decoding steps (always run in full).
            p_copy: weight of the copy distribution (attention applied to the
                one-hot encoded input) added to the decoder distribution.

        Returns:
            A list with one space-joined string per batch entry, cut before
            the first ``'_eos'``.
        """
        batch_size = input_seq.size(1)
        predictions = torch.zeros((batch_size, max_len))

        # Decoder inputs were always created on the GPU. Keep doing so when
        # CUDA is available, otherwise follow the input's device so that
        # decoding also works on CPU-only hosts.
        if torch.cuda.is_available():
            target_device = torch.device('cuda')
        else:
            target_device = input_seq.device

        with torch.no_grad():
            # Encoder
            encoder_outputs, encoder_hidden = self.encoder(
                input_seq, input_lens)

            # Decoder starts from the encoder state and a row of '_go' tokens
            decoder_hidden = encoder_hidden
            last_word = torch.tensor(
                [[self.w2i['_go'] for _ in range(batch_size)]],
                dtype=torch.long,
                device=target_device)

            # Input one-hot: one vocabulary-sized row per source token
            input_oh = torch.eye(len(self.w2i))[input_seq].to(target_device)
            for t in range(max_len):
                # Pass through decoder
                decoder_output, decoder_hidden, attn = self.decoder(
                    decoder_hidden,
                    last_word,
                    encoder_outputs,
                    ret_logits=top_p > 0,
                    ret_attn=True)
                # Attention mass projected onto the vocabulary
                copy_prob = attn.bmm(input_oh.permute(1, 0,
                                                      2)).permute(1, 0, 2)

                # Get top candidates
                if top_p == 0:
                    _, topi = (torch.exp(decoder_output) +
                               p_copy * copy_prob).data.topk(1)
                else:
                    probs = F.softmax(decoder_output,
                                      dim=-1) + p_copy * copy_prob
                    s_probs, s_inds = torch.sort(probs, descending=True)
                    cum_probs = torch.cumsum(s_probs, dim=-1)

                    # Remove all outside the nucleus
                    sinds_to_remove = cum_probs > top_p

                    # Shift the mask right by one so that the first token
                    # crossing the threshold (and always the top one) is kept
                    sinds_to_remove[:, :,
                                    1:] = sinds_to_remove[:, :, :-1].clone()
                    sinds_to_remove[:, :, 0] = 0

                    for b in range(s_inds.size(1)):
                        # Map the sorted positions back to vocabulary ids
                        inds_to_remove = s_inds[:, b][sinds_to_remove[:, b]]

                        # Set to be filtered in original
                        probs[0, b, inds_to_remove] = 0

                    # Sample (multinomial accepts unnormalised weights)
                    topi = torch.multinomial((probs).squeeze(0), 1)

                topi = topi.view(-1)
                predictions[:, t] = topi

                # Set new last word
                last_word = topi.detach().view(1, -1)

        predicted_sentences = []
        for sentence in predictions:
            sent = []
            for ind in sentence:
                word = self.i2w[ind.long().item()]
                if word == '_eos':
                    break
                sent.append(word)
            predicted_sentences.append(' '.join(sent))

        return predicted_sentences
Beispiel #37
0
def __split0_global_q_dict_set(q_dict_col,
                               col,
                               r_tiles,
                               q_tiles,
                               global_merge_dict=None):
    """
    Distribute the merged Q tensors of one column into a tile-indexed dictionary.

    Every entry of ``q_dict_col`` is cut into four quadrants at ``base_size[0]``
    (along both axes). Each quadrant is either stored directly at its
    ``(row, column)`` tile key or, when the dictionary already holds tiles in
    the relevant column, right-multiplied onto those existing tiles. The Q
    tiling itself is not modified; only the dictionary is written.

    Parameters
    ----------
    q_dict_col : Dict
        The Q values for a given column, should be given as q_dict[col].
        Keys are strings parsed as ``...p0<r0>p1<r1>e...``; each value is
        indexable, with ``[0]`` the merged Q tensor and ``[1]`` its base size.
    col : int, single element torch.Tensor
        Current column for which Q is being calculated.
    r_tiles : tiling.SquareDiagTiles
        Tiling object for 'r'. ``tile_rows_per_process`` and the device of the
        underlying array are read from it.
    q_tiles : tiling.SquareDiagTiles
        Tiling object for Q0. Only ``tile_columns`` is read from it.
    global_merge_dict : Dict, optional
        Dictionary that receives the output; a new one is created if omitted.
        Form of output: key index (row, column) : torch.Tensor

    Returns
    -------
    Dict
        ``global_merge_dict`` (the same object that was passed in, if any)
        after all entries of ``q_dict_col`` have been processed.
    """
    device = r_tiles.arr._DNDarray__array.device
    # cumulative tile-row counts: entry p is the first tile row *after* process p
    tile_row_ends = torch.cumsum(
        torch.tensor(r_tiles.tile_rows_per_process, device=device),
        dim=0,
    )
    # first process whose cumulative tile-row count exceeds ``col``
    diag_proc = torch.nonzero(input=tile_row_ends > col,
                              as_tuple=False)[0].item()
    # shift by one so that entry p is the first tile row *of* process p
    proc_tile_start = torch.cat(
        (torch.tensor([0], device=device), tile_row_ends[:-1]),
        dim=0)

    if global_merge_dict is None:
        global_merge_dict = {}

    # todo: possible improvement -> make the keys have the process they are on as well,
    #  then can async get them if they are not on the diagonal process
    for key in sorted(q_dict_col.keys()):
        # the two ranks taking part in this merge are encoded in the key
        p0_pos = key.find("p0")
        p1_pos = key.find("p1")
        end_pos = key.find("e")
        r0 = int(key[p0_pos + 2:p1_pos])
        r1 = int(key[p1_pos + 2:end_pos])

        entry = q_dict_col[key]
        lp_q = entry[0]
        split = entry[1][0]

        # translate the ranks into global tile indices
        col1 = col if diag_proc == r0 else proc_tile_start[r0].item()
        col2 = proc_tile_start[r1].item()

        # Tiles already present in tile-column col1 / col2. These, together
        # with the snapshot below, are taken BEFORE any write of this
        # iteration, so every product uses the pre-update values.
        existing_keys = set(global_merge_dict.keys())
        tile_cols = q_tiles.tile_columns
        present_col1 = {(i, col1) for i in range(tile_cols)} & existing_keys
        present_col2 = {(i, col2) for i in range(tile_cols)} & existing_keys
        snapshot = global_merge_dict.copy()

        # todo: modify this so that it will get what is needed from the process,
        #  instead of gathering all the qs
        # for matrix of form: | J  K |
        #                     | L  M |
        # each row: (target tile key, quadrant of lp_q, existing tiles to multiply)
        quadrants = (
            ((col1, col1), lp_q[:split, :split], present_col1),  # (J)
            ((col1, col2), lp_q[:split, split:], present_col1),  # (K)
            ((col2, col1), lp_q[split:, :split], present_col2),  # (L)
            ((col2, col2), lp_q[split:, split:], present_col2),  # (M)
        )
        for target, block, sources in quadrants:
            if sources:
                # multiply every existing tile of the source column by the
                # quadrant and store it in the target column of the same row
                for src in sources:
                    global_merge_dict[src[0], target[1]] = snapshot[src] @ block
            else:
                # nothing to multiply with -> the quadrant itself is the tile
                global_merge_dict[target] = block
    return global_merge_dict