Beispiel #1
0
 def forward(self, longlabel, floatlabel, maxnum, feature):
     """Center-style loss over a batch.

     Args:
         longlabel: integer class indices used to pick rows of ``self.center``
             and entries of the per-class count vector.
         floatlabel: the same labels as a floating-point tensor (fed to
             ``torch.histc``).
         maxnum: largest class index; the histogram has ``int(maxnum + 1)``
             buckets over ``[0, int(maxnum)]``.
         feature: features compared with the selected centers.

     Returns:
         Scalar tensor: mean over the batch of
         ``sqrt(total squared distance) / count[label]``.

     NOTE(review): the squared distance is summed over *all* elements (no
     ``dim``), so the square root is one scalar shared by every sample; a
     per-sample distance would need ``dim=1`` — confirm intent.
     """
     centers = self.center.index_select(dim=0, index=longlabel)
     class_sizes = torch.histc(floatlabel, bins=int(maxnum + 1), min=0, max=int(maxnum))
     sample_sizes = class_sizes.index_select(dim=0, index=longlabel)
     sq_total = ((feature - centers) ** 2).sum()
     return (sq_total.sqrt() / sample_sizes).mean()
Beispiel #2
0
def feat_map_shape_hist(inp, nb_bins=8):
    """Histogram ``inp`` and view the counts as ``(inp.shape[0], inp.shape[1], nb_bins)``.

    Reference: https://discuss.pytorch.org/t/differentiable-torch-histc/25865

    A single histogram with ``inp.shape[0] * inp.shape[1] * nb_bins`` buckets
    is computed over *all* values of ``inp`` (range = global min..max). It is
    computed without gradients, moved to CUDA and reshaped to 3-D.

    NOTE(review): because the histogram is global, slice ``[i, j]`` of the
    result is a contiguous run of global buckets, not a histogram of entry
    ``(i, j)`` — confirm that is what callers expect.
    NOTE(review): relies on a module-level ``device`` and on CUDA being
    available (``.cuda()``).
    """
    dim0, dim1 = inp.shape[0], inp.shape[1]
    flat = inp.contiguous().to(device)
    with torch.no_grad():
        lo = flat.min().item()
        hi = flat.max().item()
        counts = torch.histc(flat, bins=dim0 * dim1 * nb_bins, min=lo, max=hi)
        hist_ = counts.cuda().view(dim0, dim1, nb_bins)
    return hist_
Beispiel #3
0
def accuracy(pred_cls, true_cls, nclass=15, drop=drop):
    """Per-class true-positive counts and ground-truth counts.

    Args:
        pred_cls: predicted class indices (tensor).
        true_cls: ground-truth class indices, same shape as ``pred_cls``.
        nclass: number of classes; classes are ``0 .. nclass - 1``.
        drop: collection of class indices to skip. The default is the
            module-level ``drop`` captured when this function was defined.

    Returns:
        ``(tpos, per_cls_counts)`` as NumPy arrays with one entry per class not
        in ``drop``, in ascending class order. ``tpos`` counts elements where
        both prediction and label equal the class. ``per_cls_counts`` counts
        elements whose label is the class.
    """
    # Unit-wide buckets over [0, nclass]: label k lands in bucket k.
    positive = torch.histc(true_cls.cpu().float(), bins=nclass, min=0, max=nclass).numpy()
    per_cls_counts = []
    tpos = []
    for i in range(nclass):
        if i in drop:
            continue
        # Logical AND of the two masks. The previous ``(a + b).eq(2)`` never
        # matched once comparisons return bool tensors: bool + bool stays
        # bool, so the sum is never 2.
        true_positive = ((pred_cls == i) & (true_cls == i)).sum().item()
        tpos.append(true_positive)
        per_cls_counts.append(positive[i])
    return np.array(tpos), np.array(per_cls_counts)
Beispiel #4
0
    def forward(self, x_orig):
        # type: (Tensor) -> Tensor
        """Accumulate a histogram of ``x_orig`` into ``self.histogram``.

        On the first call the ``min_val``/``max_val`` buffers are empty. The
        range is then taken from this batch, and ``self.histogram`` is filled
        directly via ``out=``.

        On later calls:

        * the range is widened to cover both the stored range and this batch;
        * it is then adjusted with ``self._adjust_min_max`` (which also yields
          the downsample rate and start index);
        * the new batch is histogrammed over that range;
        * the result is merged with the stored histogram. This is plain
          addition when the range is unchanged, otherwise
          ``self._combine_histograms``.

        Buffers are updated in place with ``resize_``/``copy_``.

        Returns:
            ``x_orig`` unchanged; all statistics are computed on a detached
            copy so no autograd history is recorded.
        """
        x = x_orig.detach()
        min_val = self.min_val
        max_val = self.max_val
        # Empty buffers mark the first call (nothing observed yet).
        if min_val.numel() == 0 or max_val.numel() == 0:
            min_val = torch.min(x)
            max_val = torch.max(x)
            self.min_val.resize_(min_val.shape)
            self.min_val.copy_(min_val)
            self.max_val.resize_(max_val.shape)
            self.max_val.copy_(max_val)
            torch.histc(x, self.bins, min=min_val, max=max_val, out=self.histogram)
        else:
            new_min = torch.min(x)
            new_max = torch.max(x)
            combined_min = torch.min(new_min, min_val)
            combined_max = torch.max(new_max, max_val)
            # combine the existing histogram and new histogram into 1 histogram
            # We do this by first upsampling the histogram to a dense grid
            # and then downsampling the histogram efficiently
            combined_min, combined_max, downsample_rate, start_idx = \
                self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
            combined_histogram = torch.histc(x, self.bins, min=combined_min, max=combined_max)
            # Same range as before: bucket edges line up, so counts simply add.
            if combined_min == min_val and combined_max == max_val:
                combined_histogram += self.histogram
            else:
                combined_histogram = self._combine_histograms(
                    combined_histogram,
                    self.histogram,
                    self.upsample_rate,
                    downsample_rate,
                    start_idx,
                    self.bins)

            self.histogram.resize_(combined_histogram.shape)
            self.histogram.copy_(combined_histogram)
            self.min_val.resize_(combined_min.shape)
            self.min_val.copy_(combined_min)
            self.max_val.resize_(combined_max.shape)
            self.max_val.copy_(combined_max)
        return x_orig
def stage1_test(net, testloader, device):
    """Evaluate ``net`` and print energy histograms of known/unknown samples.

    Runs ``net`` over ``testloader`` without gradients. Accuracy is reported
    for the argmax of ``out["normweight_fea2cen"]`` against the targets. A
    per-sample energy ``logsumexp(out["normweight_fea2cen"], dim=1)`` is also
    collected.

    Samples whose target equals the largest target value are treated as
    "unknown" and the rest as "known". Both groups are histogrammed with
    ``args.bins`` buckets over the overall energy range, and the histograms
    are printed.

    NOTE(review): depends on module-level ``args`` and ``progress_bar``.
    NOTE(review): an empty loader leaves the sample count at 0, which would
    divide by zero — confirm the loader is never empty.
    """
    n_correct = 0
    n_seen = 0
    energies = []
    labels = []
    with torch.no_grad():
        for step, (inputs, targets) in enumerate(testloader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = net(inputs)["normweight_fea2cen"]  # shape [batch, class]
            energies.append(torch.logsumexp(logits, dim=1, keepdim=False))
            labels.append(targets)

            predicted = logits.max(1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()

            progress_bar(
                step, len(testloader), '| Acc: %.3f%% (%d/%d)' %
                (100. * n_correct / n_seen, n_correct, n_seen))

    print("\nTesting results is {:.2f}%".format(100. * n_correct / n_seen))

    # Energy analysis: the largest label value marks the unknown class.
    energy_all = torch.cat(energies, dim=0)
    label_all = torch.cat(labels, dim=0)
    is_unknown = label_all == label_all.max()
    lo = energy_all.min().data
    hi = energy_all.max().data
    unknown_hist = torch.histc(energy_all[is_unknown], bins=args.bins, min=lo, max=hi)
    known_hist = torch.histc(energy_all[~is_unknown], bins=args.bins, min=lo, max=hi)
    print(f"unknown_hist: \n{unknown_hist}")
    print(f"known_hist: \n{known_hist}")
Beispiel #6
0
def batch_intersection_union(predict, target, nclass):
    """Batch Intersection of Union

    Per-class intersection and union pixel counts for a batch.

    Args:
        predict: input 4D tensor of class scores; the class is taken with
            ``torch.max(predict, 1)``.
        target: label 3D tensor; negative values mark ignored pixels.
        nclass: number of categories (int)

    Returns:
        ``(area_inter, area_union)``: float CPU tensors of length ``nclass``.
    """
    _, predict = torch.max(predict, 1)
    # Shift both maps by +1 so that 0 can mean "not counted". Ignored pixels
    # (target < 0) and mismatched pixels of the intersection map are set to
    # 0, which lies outside the histogram range [1, nclass]. Without the shift
    # those zeros landed in class 0's bucket and inflated its prediction and
    # intersection counts.
    mini = 1
    maxi = nclass
    predict = predict.cpu().float() + 1
    target = target.cpu().float() + 1
    predict = predict * (target > 0).float()
    intersection = predict * (predict == target).float()
    area_inter = torch.histc(intersection, bins=nclass, min=mini, max=maxi)
    area_pred = torch.histc(predict, bins=nclass, min=mini, max=maxi)
    area_lab = torch.histc(target, bins=nclass, min=mini, max=maxi)
    area_union = area_pred + area_lab - area_inter
    return area_inter, area_union
Beispiel #7
0
 def forward(self, x_orig):
     # type: (Tensor) -> Tensor
     """Accumulate a histogram of ``x_orig`` into ``self.histogram``.

     First call (``min_val`` or ``max_val`` is None): records this batch's
     min/max and its ``self.bins``-bucket histogram.

     Later calls: the batch is histogrammed over its own range. A zeroed
     histogram over the union of the stored and new ranges is then passed,
     together with the stored histogram and with the new one, to
     ``self._combine_histograms``. Its return value is ignored, so it
     presumably fills ``combined_histogram`` in place — its body is not shown
     here.

     Returns:
         ``x``, the *detached* input.
         NOTE(review): returning the detached tensor cuts the autograd graph
         for callers that use the output, whereas other observer variants
         return ``x_orig`` — confirm which is intended.
     """
     x = x_orig.detach()
     min_val = self.min_val
     max_val = self.max_val
     if min_val is None or max_val is None:
         min_val = torch.min(x)
         max_val = torch.max(x)
         self.min_val = min_val
         self.max_val = max_val
         self.histogram = torch.histc(x,
                                      self.bins,
                                      min=min_val,
                                      max=max_val)
     else:
         new_min = torch.min(x)
         new_max = torch.max(x)
         new_histogram = torch.histc(x, self.bins, min=new_min, max=new_max)
         # combine the existing histogram and new histogram into 1 histogram
         combined_histogram = torch.zeros_like(self.histogram)
         combined_min = torch.min(new_min, min_val)
         combined_max = torch.max(new_max, max_val)
         # Redistribute the stored histogram onto the combined range.
         self._combine_histograms(
             combined_histogram,
             combined_min.item(),
             combined_max.item(),
             self.histogram,
             min_val.item(),
             max_val.item(),
         )
         # Then the new batch's histogram onto the same combined range.
         self._combine_histograms(
             combined_histogram,
             combined_min.item(),
             combined_max.item(),
             new_histogram,
             new_min.item(),
             new_max.item(),
         )
         self.histogram = combined_histogram
         self.min_val = combined_min
         self.max_val = combined_max
     return x
Beispiel #8
0
    def getHistogramsTorch(img):
        """Concatenated H, S, V and LBP histograms of an image, as a NumPy array.

        Args:
            img: image array accepted by ``Image.fromarray`` after a cast to
                ``uint8`` (presumably RGB, HxWx3 — TODO confirm).

        Returns:
            1-D NumPy array of length ``4 * NBINS``. It holds the H, S and V
            histograms followed by the histogram of the uniform LBP code image
            (24 points, radius 3). Each has ``NBINS`` buckets over [0, 255].
        """

        def _hist(channel):
            # Histogram of one 2-D channel with NBINS buckets over [0, 255].
            return torch.histc(torch.from_numpy(channel).float().to(DEVICE),
                               bins=NBINS,
                               min=0.0,
                               max=255.0)

        img = Image.fromarray(img.astype(np.uint8))

        h, s, v = img.convert('HSV').split()
        histH = _hist(np.array(h))
        histS = _hist(np.array(s))
        histV = _hist(np.array(v))

        # Uniform LBP with 24 sampling points on a radius-3 circle.
        # NOTE(review): skimage documents 'uniform' codes as 0..P+1, so these
        # values likely occupy only the lowest buckets of [0, 255] — confirm.
        lbp = ft.local_binary_pattern(img.convert('L'), 24, 3, 'uniform')
        histLBP = _hist(lbp)

        hist = torch.stack((histH, histS, histV, histLBP)).view(-1)
        # .cpu() is required: Tensor.numpy() raises for CUDA tensors, which is
        # what the histograms are if DEVICE is a GPU.
        return hist.cpu().numpy()
Beispiel #9
0
def _get_vector_label_from_map(target, num_class):
    batch = target.size(0)
    mvect = torch.zeros(batch, num_class)
    for i in range(batch):
        hist = torch.histc(target[i].cpu().data.float(),
                           bins=num_class,
                           min=0,
                           max=num_class - 1)
        vect = hist > 0
        mvect[i] = vect
    return mvect
Beispiel #10
0
 def calculate_weight(self, target):
     """
     calculate weights of classes based on the training crop

     Counts labels ``0 .. self.num_classes - 1`` in ``target`` using
     unit-wide buckets over ``[0, self.num_classes]``. The class frequencies
     are then turned into per-class weights for present classes:

     * ``self.norm`` true:  ``upper_bound / freq + 1``
     * ``self.norm`` false: ``upper_bound * (1 - freq) + 1``

     Classes absent from ``target`` get weight 1.0.

     Returns:
         1-D float tensor of length ``self.num_classes``.
     """
     if not target.is_floating_point():
         # torch.histc requires a floating-point input (at least on CPU).
         target = target.float()
     bins = torch.histc(target, bins=self.num_classes, min=0.0, max=self.num_classes)
     hist_norm = bins.float() / bins.sum()
     present = bins != 0
     if self.norm:
         scaled = self.upper_bound * (1 / hist_norm)
     else:
         scaled = self.upper_bound * (1. - hist_norm)
     # torch.where instead of multiplying by the mask. For absent classes
     # 1 / 0 = inf and 0 * inf = nan, and an empty histogram gives 0 / 0 = nan
     # everywhere; either way NaN weights leaked through before.
     hist = torch.where(present, scaled, torch.zeros_like(scaled)) + 1.0
     return hist
Beispiel #11
0
    def forward(self, feat_t0, feat_t1, ground_truth):
        """Histogram-based loss between positive-pair and negative-pair distances.

        Args:
            feat_t0, feat_t1: feature maps ``(n, c, h, w)``. They are viewed as
                ``(c, h*w)``, so only ``n == 1`` is supported.
            ground_truth: ``h*w`` labels. 0 marks a positive pair; any other
                value marks a negative pair.

        Returns:
            ``self.distance(hist_pos, hist_neg)``. Each histogram has 100
            buckets over [0, 1] of ``self.various_distance`` values and is
            normalised by the number of positions in its group.

        NOTE(review): as before, the histograms are built from detached
        copies and wrapped as new leaf tensors with ``requires_grad=True``.
        No gradient reaches ``feat_t0``/``feat_t1`` through this loss, since
        ``torch.histc`` is not differentiable. Confirm this is intended.
        """
        n, c, h, w = feat_t0.data.shape
        out_t0_rz = torch.transpose(feat_t0.view(c, h * w), 1, 0)
        out_t1_rz = torch.transpose(feat_t1.view(c, h * w), 1, 0)
        #### inspired by Source code from Histogram loss ###
        ### distance for every position ###
        distance = torch.squeeze(self.various_distance(out_t0_rz, out_t1_rz), dim=1)
        ### split positions into positive (gt == 0) and negative pairs ###
        # Boolean masks replace np.squeeze(np.where(...), 1), which raised
        # unless exactly one position matched. They also replace the
        # hard-coded .cuda(): the mask follows the distance's device.
        gt = ground_truth.view(h * w).to(distance.device)
        pos_mask = gt == 0
        neg_mask = ~pos_mask
        pos_size = int(pos_mask.sum().item())
        neg_size = int(neg_mask.sum().item())
        pos_dist_ls_t = distance[pos_mask].detach().cpu()
        neg_dist_ls_t = distance[neg_mask].detach().cpu()
        ### build similarity histograms of positive and negative pairs ###
        hist_pos = (torch.histc(pos_dist_ls_t, bins=100, min=0, max=1) / pos_size).requires_grad_(True)
        hist_neg = (torch.histc(neg_dist_ls_t, bins=100, min=0, max=1) / neg_size).requires_grad_(True)
        loss = self.distance(hist_pos, hist_neg)
        return loss
Beispiel #12
0
def batch_intersection_union(output, target, nclass):
    """Per-class intersection and union pixel counts (mIoU building blocks).

    Predictions (``argmax`` over dim 1 of ``output``) and ``target`` are
    shifted by +1 so that 0 can mean "ignored". Pixels with ``target < 0``,
    and mismatched pixels of the intersection map, become 0 and fall outside
    the histogram range ``[1, nclass]``.

    Returns:
        ``(area_inter, area_union)`` as float CPU tensors of length ``nclass``.
    """
    lo, hi = 1, nclass
    pred = (torch.argmax(output, 1) + 1).float()
    lab = target.float() + 1

    pred = pred * (lab > 0).float()
    hit = pred * (pred == lab).float()

    def _count(values):
        # Histogram on CPU, one bucket per class label 1..nclass.
        return torch.histc(values.cpu(), bins=nclass, min=lo, max=hi)

    area_inter = _count(hit)
    area_union = _count(pred) + _count(lab) - area_inter
    assert not (area_inter > area_union).any().item(), \
        "Intersection area should be smaller than Union area"
    return area_inter.float(), area_union.float()
def batch_intersection_union(predict, target, num_class, labeled):
    """Per-class intersection/union counts for labels ``1 .. num_class``.

    Args:
        predict: integer class map (same shape as ``target``).
        target: integer label map.
        num_class: number of classes; counted labels are ``1 .. num_class``.
        labeled: mask of pixels to keep. Other pixels of ``predict`` are
            zeroed and therefore not counted.

    Returns:
        ``(area_inter, area_union)`` as NumPy arrays of length ``num_class``.
    """
    kept = predict * labeled.long()
    matched = kept * (kept == target).long()

    edges = dict(bins=num_class, min=1, max=num_class)
    area_inter, area_pred, area_lab = (
        torch.histc(t.float(), **edges) for t in (matched, kept, target))
    area_union = area_pred + area_lab - area_inter
    assert (area_inter <= area_union
            ).all(), "Intersection area should be smaller than Union area"
    return area_inter.cpu().numpy(), area_union.cpu().numpy()
Beispiel #14
0
 def forward(self, x):
     """Accumulate a histogram of ``x`` into ``self.histogram``.

     First call (any of ``min_val``/``max_val``/``histogram`` is None): the
     range is set to this batch's ``[min, max]``, widened by half its span on
     each side, and ``self.histogram`` is filled over that range.

     Later calls histogram ``x`` over the stored range and add it to the
     stored histogram. Values outside the range are not counted, so a warning
     is issued when ``x`` extends past it.

     Returns:
         ``x`` (the same object passed in), like the other observers; the
         method previously returned None.
     """
     min_val = self.min_val
     max_val = self.max_val
     histogram = self.histogram
     if min_val is None or max_val is None or histogram is None:
         x_min = torch.min(x)
         x_max = torch.max(x)
         # Renamed from ``range``, which shadowed the builtin.
         margin = 0.5 * (x_max - x_min)
         self.min_val = x_min - margin
         self.max_val = x_max + margin
         self.histogram = torch.histc(x,
                                      self.bins,
                                      min=self.min_val,
                                      max=self.max_val)
     else:
         # Warn when incoming data falls outside the stored range. The previous
         # test (min_val < min(x) or max_val > max(x)) was inverted and fired
         # for data *inside* the range.
         if torch.min(x) < min_val or torch.max(x) > max_val:
             warnings.warn(
                 "Incoming data is outside the min_val/max_val range.")
         new_histogram = torch.histc(x, self.bins, min=min_val, max=max_val)
         self.histogram = new_histogram + histogram
     return x
def singlebar_hist(Out_list: torch.Tensor, args, name: str):
    """Save a bar chart of the histogram of ``Out_list``.

    The histogram has ``args.hist_bins`` buckets spanning ``Out_list``'s own
    min..max. The figure is written to ``<args.histfolder>/<name>.png`` at
    ``args.plot_quality`` dpi and then closed.
    """
    lo, hi = Out_list.min(), Out_list.max()
    counts = torch.histc(Out_list, bins=args.hist_bins, min=lo, max=hi)
    positions = np.arange(args.hist_bins)
    counts = counts.data.cpu().numpy()
    plt.bar(positions, counts, color='C1')
    target_file = os.path.join(args.histfolder, name + '.png')
    plt.savefig(target_file, bbox_inches='tight', dpi=args.plot_quality)
    plt.close()
Beispiel #16
0
 def forward(self, feature, label):
     """Center loss: squared distance to the class center, scaled per class.

     Args:
         feature: features, one row per sample (compared with rows of
             ``self.center``).
         label: 1-D tensor with one class index per sample (integer or
             floating point).

     Returns:
         ``self.lambdas / 2 * mean_i(||feature_i - center[label_i]||^2 / n[label_i])``,
         where ``n[c]`` is how many samples in the batch carry label ``c``.
     """
     index = label.long()
     center_exp = self.center.index_select(dim=0, index=index)
     # Tensor.max() instead of the builtin max(label). The builtin iterated
     # the tensor element by element in Python and was evaluated twice.
     top = int(label.max().item())
     # torch.histc needs a floating-point input; keep float dtypes as given.
     values = label if label.is_floating_point() else label.float()
     # top + 1 buckets over [0, top]: label k falls into bucket k.
     count = torch.histc(values, bins=top + 1, min=0, max=top)
     count_exp = count.index_select(dim=0, index=index)
     loss = self.lambdas / 2 * torch.mean(
         torch.div(torch.sum(torch.pow(feature - center_exp, 2), dim=1),
                   count_exp))
     return loss
def batch_intersection_union(output, target, nclass):
    """mIoU helper: per-class intersection and union pixel counts.

    Args:
        output: class-score tensor; the prediction is ``argmax`` over dim 1.
        target: integer label tensor; negative values mark ignored pixels.
        nclass: number of classes.

    Returns:
        ``(area_inter, area_union)`` float tensors of length ``nclass`` on the
        inputs' device.
    """
    # Shift by +1 so that 0 is free to mean "no class". Ignored pixels and
    # mismatched pixels of the intersection map are mapped to 0, which lies
    # outside the histogram range [1, nclass] and is not counted.
    first_class, last_class = 1, nclass
    pred = (torch.argmax(output, 1) + 1).float()
    lab = target.float() + 1
    pred = torch.where(lab > 0, pred, torch.zeros_like(pred))
    inter = torch.where(pred == lab, pred, torch.zeros_like(pred))

    def _area(values):
        # Bucket k - 1 holds the pixel count of shifted label k.
        return torch.histc(values, bins=nclass, min=first_class, max=last_class)

    area_inter = _area(inter)
    area_union = _area(pred) + _area(lab) - area_inter
    assert not (area_inter > area_union).any().item(
    ), "Intersection area should be smaller than Union area"
    return area_inter.float(), area_union.float()
Beispiel #18
0
def _scale_channel(im: torch.Tensor) -> torch.Tensor:
    r"""Scale the data in the channel to implement equalize.

    Args:
        input (torch.Tensor): image tensor with shapes like :math:`(H, W)` or :math:`(D, H, W)`.
    Returns:
        torch.Tensor: image tensor with the batch in the zero position.
    """
    min_ = im.min()
    max_ = im.max()

    if min_.item() < 0. and not torch.isclose(
            min_, torch.tensor(0., dtype=min_.dtype)):
        raise ValueError(
            f"Values in the input tensor must greater or equal to 0.0. Found {min_.item()}."
        )
    if max_.item() > 1. and not torch.isclose(
            max_, torch.tensor(1., dtype=max_.dtype)):
        raise ValueError(
            f"Values in the input tensor must lower or equal to 1.0. Found {max_.item()}."
        )

    ndims = len(im.shape)
    if ndims not in (2, 3):
        raise TypeError(
            f"Input tensor must have 2 or 3 dimensions. Found {ndims}.")

    im = im * 255
    # Compute the histogram of the image channel.
    histo = torch.histc(im, bins=256, min=0, max=255)
    # For the purposes of computing the step, filter out the nonzeros.
    nonzero_histo = torch.reshape(histo[histo != 0], [-1])
    step = (torch.sum(nonzero_histo) - nonzero_histo[-1]) // 255

    def build_lut(histo, step):
        # Compute the cumulative sum, shifting by step // 2
        # and then normalization by step.
        lut = (torch.cumsum(histo, 0) + (step // 2)) // step
        # Shift lut, prepending with 0.
        lut = torch.cat([torch.zeros(1, device=lut.device), lut[:-1]])
        # Clip the counts to be in range.  This is done
        # in the C code for image.point.
        return torch.clamp(lut, 0, 255)

    # If step is zero, return the original image.  Otherwise, build
    # lut from the full histogram and step and then index from it.
    if step == 0:
        result = im
    else:
        # can't index using 2d index. Have to flatten and then reshape
        result = torch.gather(build_lut(histo, step), 0, im.flatten().long())
        result = result.reshape_as(im)

    return result / 255.
Beispiel #19
0
 def _get_batch_label_vector(target, nclass):
         # target is a 3D Variable BxHxW, output is 2D BxnClass
         batch = target.size(0)
         tvect = torch.zeros(batch, nclass)
         for i in range(batch):
             hist = torch.histc(target[i].cpu().data.float(),
                             bins=nclass, min=0,
                             max=nclass-1)
             vect = hist > 0
             tvect[i] = vect
         return tvect
Beispiel #20
0
    def forward(self, xs, label):
        """Center loss on L2-normalised features.

        Args:
            xs: features, one row per sample. Rows are normalised with
                ``f.normalize`` before comparison.
            label: integer class index per sample (used with
                ``index_select``).

        Returns:
            ``sum_i ||xs_i - center[label_i]||^2 / n[label_i]``, where ``n[c]``
            is the number of samples labelled ``c`` in the batch.
        """
        xs = f.normalize(xs)
        # Select the center of each sample's class.
        cen_select = self.center.index_select(dim=0, index=label)
        # Count samples per class (e.g. labels [0, 0, 1] -> [2, 1]), using one
        # bucket per row of self.center (label k -> bucket k). The count was
        # previously hard-coded to 10 classes (bins=10, max=9), which broke
        # the index_select below for any label >= 10.
        num_classes = self.center.size(0)
        count = torch.histc(label.float(), bins=num_classes, min=0, max=num_classes - 1)
        # Per-sample size of its own class, e.g. labels [0, 0, 1] -> [2, 2, 1].
        count_dis = count.index_select(dim=0, index=label)

        return torch.sum(
            torch.sum((xs - cen_select)**2, dim=1) / count_dis.float())
Beispiel #21
0
def f_compute_hist(data, bins):
    """Histogram of ``data`` scaled so that its buckets average to 1.

    ``torch.histc`` is called with its default range (``min=max=0``, i.e.
    the data's own min..max). The counts are then multiplied by
    ``bins / total_count`` so that they sum to ``bins``.

    Args:
        data: tensor to histogram.
        bins: number of buckets.

    Returns:
        1-D tensor of length ``bins``. It is all zeros when nothing was
        counted (e.g. empty ``data``) or when ``torch.histc`` raises; in the
        latter case the error is printed.
    """
    try:
        hist_data = torch.histc(data, bins=bins)
    except Exception as e:
        # Deliberate best-effort: report and fall back to an empty histogram.
        print(e)
        return torch.zeros(bins)

    total = torch.sum(hist_data)
    if total == 0:
        # Nothing counted: avoid 0 / 0 = NaN buckets.
        return torch.zeros_like(hist_data)
    ## A kind of normalization of histograms: divide by total sum
    return (hist_data * bins) / total
Beispiel #22
0
def compute_mIOU(logits, target):
    """Mean IoU (in percent) over the foreground classes ``1 .. n_classes - 1``.

    Args:
        logits: class scores ``(B, n_classes, H, W)``.
        target: integer labels ``(B, H, W)``. Class 0 is background and is
            excluded from the mean.

    Returns:
        0-dim tensor ``100 * mean_c(inter_c / (union_c + 1e-10))``. Classes
        absent from both prediction and target contribute an IoU of 0.
    """
    n_classes = logits.shape[1]
    n_fg = n_classes - 1
    pred = torch.argmax(logits, dim=1)

    # Mismatched pixels become 0 (background) and are therefore not counted.
    intersection = pred * (pred == target)

    def _count(labels):
        # One unit-wide bucket per foreground class: the range
        # [0.5, n_classes - 0.5] puts label k in bucket k - 1 and excludes
        # background (0). The previous range [1, n_classes - 1] has zero width
        # when n_classes == 2, which torch.histc cannot use as given.
        # .float(): torch.histc requires floating-point input, and argmax
        # returns int64.
        return torch.histc(labels.float(), bins=n_fg, min=0.5, max=n_classes - 0.5)

    area_intersection = _count(intersection)
    area_pred = _count(pred)
    area_target = _count(target)
    area_union = area_pred + area_target - area_intersection

    return torch.mean(area_intersection / (area_union + 1e-10)) * 100.
 def even_split(self, num_train, num_dev_test):
     """Cap the held-out share at 20% and split it into dev and test sizes.

     While the held-out items are more than 20% of the total and there are
     more than two of them, one item is moved from held-out to train. The
     held-out count ``k`` is then split into ``k // 2`` dev items and
     ``k - k // 2`` test items. This is the same split that the previous
     two-bucket ``torch.histc`` over ``0 .. k - 1`` produced.

     Returns:
         ``(num_train, num_dev, num_test)``; the last two are Python ints.
     """
     while num_dev_test / (num_train +
                           num_dev_test) > 0.2 and num_dev_test > 2:
         num_dev_test -= 1
         num_train += 1
     # Integer arithmetic replaces the deprecated torch.range + histc. It is
     # exact for any size and needs no tensor round-trip.
     held_out = int(num_dev_test)
     num_dev = held_out // 2
     num_test = held_out - num_dev
     return num_train, num_dev, num_test
Beispiel #24
0
def batch_intersection_union(output, target, nclass):
    """mIoU: per-class intersection and union pixel counts.

    Args:
        output: class-score tensor; the prediction is ``argmax`` over dim 1.
        target: integer label tensor. ``-1`` (any negative value) marks
            ignored pixels, typically background / boundary.
        nclass: number of classes.

    Returns:
        ``(area_inter, area_union)`` float tensors of length ``nclass`` on the
        inputs' device.
    """
    # Work with labels shifted into 1..nclass; 0 then means "not counted".
    label = target.float() + 1
    keep = (label > 0).float()
    guess = (torch.argmax(output, 1) + 1).float() * keep
    agree = guess * (guess == label).float()

    area_inter, area_pred, area_lab = [
        torch.histc(t, bins=nclass, min=1, max=nclass)
        for t in (agree, guess, label)
    ]
    area_union = area_pred + area_lab - area_inter
    assert torch.sum(area_inter > area_union).item() == 0, \
        "Intersection area should be smaller than Union area"
    return area_inter.float(), area_union.float()
Beispiel #25
0
 def forward(self, xs, ys):
     """Class-size-normalised center loss.

     For each sample ``i`` with label ``y_i``, the term is the mean over
     feature dimensions of ``(xs_i - center[y_i]) ** 2``, divided by
     ``2 * n[y_i]``. Here ``n[c]`` counts label ``c`` in the batch, using
     ``self.cls_num`` buckets over ``[0, cls_num - 1]``. The per-sample terms
     are summed and divided by ``self.cls_num``.
     """
     idx = ys.long()
     centers = self.center.index_select(dim=0, index=idx)
     class_count = torch.histc(ys.float(), bins=self.cls_num, min=0,
                               max=self.cls_num - 1)
     per_sample = class_count.index_select(dim=0, index=idx).float()
     sq_err = (xs - centers).pow(2).mean(dim=1)
     weighted = sq_err / (2. * per_sample)
     return weighted.sum() / self.cls_num
Beispiel #26
0
    def get_iou(self, output, target):
        """Per-class intersection and union pixel counts for IoU.

        Args:
            output: raw scores ``(B, C, H, W)`` (argmax over dim 1 is taken)
                or an already-argmaxed class map. A tuple is reduced to its
                first element.
            target: class-index map; 255 is treated as "ignore".

        Both maps are converted to ``uint8`` and shifted by +1, so 255 wraps
        to 0. Pixels whose shifted target is 0 are dropped, and shifted
        classes ``1 .. self.num_classes`` are counted. The inputs are not
        modified.

        Returns:
            ``(area_inter, area_union)`` NumPy arrays of length
            ``self.num_classes``; ``self.epsilon`` is added to the union.
        """
        if isinstance(output, tuple):
            output = output[0]

        if len(output.size()) == 4:  # Case of raw outputs
            _, pred = torch.max(output, 1)
        else:  # Case of argmax
            pred = output

        # Compute on CPU. The previous ``device == torch.device('cuda')``
        # checks never matched tensors on 'cuda:0' (the index differs). The
        # move only happened implicitly via ``.type(torch.ByteTensor)``.
        pred = pred.cpu().to(torch.uint8)
        target = target.cpu().to(torch.uint8)

        # shift by 1 so that 255 is 0 (uint8 wrap-around). Out-of-place: when
        # the inputs were already uint8 CPU tensors, the conversion above
        # returned them as-is, and the old in-place ``+= 1`` mutated the
        # caller's tensors.
        pred = pred + 1
        target = target + 1

        pred = pred * (target > 0)
        inter = pred * (pred == target)
        area_inter = torch.histc(inter.float(),
                                 bins=self.num_classes,
                                 min=1,
                                 max=self.num_classes)
        area_pred = torch.histc(pred.float(),
                                bins=self.num_classes,
                                min=1,
                                max=self.num_classes)
        area_mask = torch.histc(target.float(),
                                bins=self.num_classes,
                                min=1,
                                max=self.num_classes)
        area_union = area_pred + area_mask - area_inter + self.epsilon

        return area_inter.numpy(), area_union.numpy()
Beispiel #27
0
def calculate_KL_divergence(real_data,
                            predicted_data,
                            min_obs,
                            max_obs,
                            bin_size=0.1):
    """
        KL(real || predicted) between binned empirical distributions.

        real_data: in case of 1D-GMM this is the dataset
        predicted_data: this is the data produced by the generators
        min_obs: the lower limit of the dataset
        max_obs: the upper limit of the dataset
            data ranges from [min_obs, max_obs)
        bin_size: the size of the bin for each class

        Both inputs are histogrammed into ``int((max_obs - min_obs) / bin_size)``
        buckets over [min_obs, max_obs]. Each histogram is normalised by its
        own sample count. The function returns ``cross_entropy - entropy``,
        with logs taken via the module-level ``custom_log``. Intermediate
        values are printed.
    """
    num_bins = int((max_obs - min_obs) / (bin_size))

    print(real_data.type(), predicted_data.type())

    # Normalise each histogram by its *own* sample count. predicted_data was
    # previously divided by real_data's count, which skews the KL whenever
    # the two sample sizes differ. Both are cast to float so that torch.dot
    # receives matching dtypes.
    real_bins = torch.histc(real_data, num_bins, min_obs,
                            max_obs).float() / float(real_data.size(0))
    pred_bins = torch.histc(predicted_data, num_bins, min_obs,
                            max_obs).float() / float(predicted_data.size(0))

    print(real_bins, pred_bins)

    real_entropy = -torch.dot(real_bins, custom_log(real_bins))
    cross_entropy = -torch.dot(real_bins, custom_log(pred_bins))

    kl = cross_entropy - real_entropy
    print(cross_entropy, real_entropy)
    print(kl)
    return kl
    def forward(ctx, input, minV, maxV):
        """Shannon entropy (bits) of the integer values in ``input``.

        ``[minV, maxV]`` is divided into one bucket per integer value, and
        the number of elements falling into each bucket is counted. ``input``
        is expected to hold integers within ``[minV, maxV]``.

        Saves ``(input, p)`` for backward, where ``p`` is the per-value
        probability vector of length ``maxV - minV + 1``.
        NOTE(review): the matching backward (not shown) reads ``p``; it now
        receives true probabilities.

        Returns:
            0-dim tensor ``-sum_k p_k * log2(p_k)`` over the non-zero ``p_k``.
        """
        p = torch.histc(input, bins=(maxV - minV + 1), min=minV, max=maxV)

        # Convert counts to probabilities by dividing by the number of
        # elements. The previous ``p / p.shape[0]`` divided by the number of
        # buckets, which does not yield a distribution.
        p = p / input.numel()

        ctx.save_for_backward(input, p)

        # Empty buckets contribute 0 (the limit of p * log p as p -> 0).
        # Evaluating 0 * log2(0) directly gives NaN.
        nz = p[p > 0]
        entropy = -nz.mul(nz.log2()).sum()
        return entropy
Beispiel #29
0
def _torch_histc_cast(input: torch.Tensor, bins: int, min: int, max: int) -> torch.Tensor:
    """Helper function to make torch.histc work with other than fp32/64.

    The function torch.histc is only implemented for fp32/64 which makes impossible to be used by fp16 or others. What
    this function does, is cast input data type to fp32, apply torch.inverse, and cast back to the input dtype.
    """
    if not isinstance(input, torch.Tensor):
        raise AssertionError(f"Input must be torch.Tensor. Got: {type(input)}.")
    dtype: torch.dtype = input.dtype
    if dtype not in (torch.float32, torch.float64):
        dtype = torch.float32
    return torch.histc(input.to(dtype), bins, min, max).to(input.dtype)
Beispiel #30
0
def energy_hist(Out_list: torch.Tensor, Target_list:torch.Tensor, args, name:str):
    """Histogram the outputs of known vs. unknown samples and save them.

    Samples whose target equals ``Target_list.max()`` are "unknown"; all
    others are "known". Both groups get ``args.hist_bins`` buckets over the
    full ``Out_list`` range.

    * With ``args.hist_norm``, each histogram is divided by its total and
      ``"_normed"`` is appended to ``name``.
    * With ``args.hist_save``, the pair is plotted via ``plot_bar``.

    The histograms are always saved with ``torch.save`` to
    ``<args.histfolder>/<name>.pkl``.

    NOTE(review): an empty group makes the normalisation 0 / 0 (NaN).
    """
    is_unknown = Target_list == Target_list.max()
    lo = Out_list.min().data
    hi = Out_list.max().data
    unknown_hist, known_hist = (
        torch.histc(group, bins=args.hist_bins, min=lo, max=hi)
        for group in (Out_list[is_unknown], Out_list[~is_unknown])
    )
    if args.hist_norm:
        unknown_hist = unknown_hist / unknown_hist.sum()
        known_hist = known_hist / known_hist.sum()
        name = name + "_normed"
    if args.hist_save:
        plot_bar(unknown_hist, known_hist, args, name)
    payload = {"unknown": unknown_hist, "known": known_hist}
    torch.save(payload, os.path.join(args.histfolder, name + '.pkl'))
    print(f"{name} processed.")
Beispiel #31
0
    def update(self, x, y, z):
        """Advance the quantization calibration state machine by one step.

        ``self.state`` is a dict whose ``'stage'`` selects the work done:

        * stage 0: widen ``state['min']``/``state['max']`` with the min/max of
          ``|y|``.
        * stage 1: if ``state['quantized_out']``, add a 2048-bucket histogram
          of ``|y|`` over ``[state['min'], state['max']]`` to
          ``state['histogram']``.
        * stage 2: copy the quantized_in/out flags onto the module and fill
          ``self.scale`` from ``self.quantizer``. Scales are either recorded
          ones looked up by ``id()`` of ``x``/``y``/``z``, or computed from
          the histogram or the weight. When input quantization is on, the
          weight is packed with ``transpose_pack``. Finally ``self.state`` and
          ``self.quantizer`` are cleared.

        Does nothing when ``self.state`` is None. The ``if self.state and``
        guards also skip every stage when the state is an empty dict.

        NOTE(review): stage 1 hard-codes ``bins=2048`` while stage 2 passes
        ``bins=self.state['bins']`` to ``quantizer.scale`` — confirm the two
        always agree.
        """
        if self.state == None:
            return

        if self.state and self.state['stage'] == 0:
            # Running min/max of |y| across calls.
            y_abs = torch.abs(y.data)
            self.state['min'] = min(self.state['min'], torch.min(y_abs))
            self.state['max'] = max(self.state['max'], torch.max(y_abs))

        if self.state and self.state['stage'] == 1:
            # Update histogram
            if self.state['quantized_out']:
                y_abs = torch.abs(y.data)
                self.state['histogram'] += torch.histc(y_abs.cpu(), bins=2048, min=self.state['min'], max=self.state['max'])


        if self.state and self.state['stage'] == 2:
            # Activate or not quantization
            self.quantized_in = self.state['quantized_in']
            self.quantized_out = self.state['quantized_out']

            # Compute scales
            # Input scale: the one recorded for tensor x (keyed by object id)
            # when input quantization is on, else 1.
            self.scale[0] = self.quantizer.history[id(x)] if self.quantized_in else 1.
            # Scale recorded for z when z is given, else 1.
            self.scale[-1] = self.quantizer.history[id(z)] if z is not None else 1.
            if self.quantized_out:
                # Output scale from the accumulated histogram. It is also
                # recorded under id(y), presumably so consumers of y can look
                # it up — confirm.
                scale = self.quantizer.scale(self.state['histogram'], True, bins = self.state['bins'])
                self.scale[-2] = self.quantizer.history[id(y)] = scale
            else:
                self.scale[-2] = 1.

            # Scale weights
            if self.quantized_in and self.weight is not None:
                self.scale[1] = self.quantizer.scale(self.weight.data, False)
                transpose_pack(self.weight, self.scale[1])

            # Reset
            self.state = None
            self.quantizer = None
Beispiel #32
0
def get_doc_freqs_t(cnts):
    """Return word --> # of docs it appears in (torch version).

    ``cnts`` is a sparse COO tensor whose first index dimension is the word
    (size ``cnts.size(0)``). Each stored coordinate counts once for its
    word. The result is therefore the per-word document frequency, provided
    each coordinate is one (word, document) pair — TODO confirm the layout
    with callers.

    Returns:
        Float tensor of length ``cnts.size(0)``.
    """
    # coalesce() merges duplicate coordinates, which an uncoalesced sparse
    # tensor may hold and _indices() would count twice. indices() is the
    # public accessor and requires a coalesced tensor.
    words = cnts.coalesce().indices()[0].float()
    n_words = cnts.size(0)
    # Unit-wide buckets over [0, n_words]: word id k lands in bucket k.
    return torch.histc(words, bins=n_words, min=0, max=n_words)