Code Example #1
File: observer.py  Project: loadbxh/Torch
    def forward(self, x_orig):
        x = x_orig.detach()  # avoid keeping autograd tape
        x = x.to(self.min_vals.dtype)
        min_vals = self.min_vals
        max_vals = self.max_vals
        x_dim = x.size()

        new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
        new_axis_list[self.ch_axis] = 0
        new_axis_list[0] = self.ch_axis
        y = x.permute(new_axis_list)
        y = torch.flatten(y, start_dim=1)
        if min_vals.numel() == 0 or max_vals.numel() == 0:
            min_vals, max_vals = torch._aminmax(y, 1)
        else:
            min_vals_cur, max_vals_cur = torch._aminmax(y, 1)
            min_vals = min_vals + self.averaging_constant * (min_vals_cur -
                                                             min_vals)
            max_vals = max_vals + self.averaging_constant * (max_vals_cur -
                                                             max_vals)
        self.min_vals.resize_(min_vals.shape)
        self.max_vals.resize_(max_vals.shape)
        self.min_vals.copy_(min_vals)
        self.max_vals.copy_(max_vals)
        return x_orig
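
Example #1 keeps an exponential moving average of per-channel minima and maxima: the channel axis is permuted to the front, the tensor is flattened per channel, and each batch nudges min_vals/max_vals toward the current batch statistics by averaging_constant. Below is a minimal calibration sketch, assuming this forward() belongs to torch.quantization.MovingAveragePerChannelMinMaxObserver (newer PyTorch releases rename the buffers to min_val/max_val); the shapes and constants are illustrative only.

import torch
from torch.quantization import MovingAveragePerChannelMinMaxObserver

# Hypothetical calibration loop: each call runs a forward() like the one above.
obs = MovingAveragePerChannelMinMaxObserver(averaging_constant=0.01, ch_axis=0)
for _ in range(8):
    obs(torch.randn(16, 3, 32, 32))          # updates the running per-channel range
scale, zero_point = obs.calculate_qparams()  # one (scale, zero_point) per channel on ch_axis
print(scale.shape, zero_point.shape)         # torch.Size([16]) torch.Size([16])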
Code Example #2
File: observer.py  Project: loadbxh/Torch
    def _forward(self, x_orig):
        x = x_orig.detach()  # avoid keeping autograd tape
        min_vals = self.min_vals
        max_vals = self.max_vals
        x_dim = x.size()

        new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
        new_axis_list[self.ch_axis] = 0
        new_axis_list[0] = self.ch_axis
        y = x.permute(new_axis_list)
        # Need to match dtype of min/max because the updates to buffers
        # are done in place and types need to match for comparisons
        y = y.to(self.min_vals.dtype)
        y = torch.flatten(y, start_dim=1)
        if min_vals.numel() == 0 or max_vals.numel() == 0:
            min_vals, max_vals = torch._aminmax(y, 1)
        else:
            min_vals_cur, max_vals_cur = torch._aminmax(y, 1)
            min_vals = torch.min(min_vals_cur, min_vals)
            max_vals = torch.max(max_vals_cur, max_vals)
        self.min_vals.resize_(min_vals.shape)
        self.max_vals.resize_(max_vals.shape)
        self.min_vals.copy_(min_vals)
        self.max_vals.copy_(max_vals)
        return x_orig
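
Example #2 is the plain per-channel variant: instead of an exponential moving average it simply widens the stored range with torch.min/torch.max, so the recorded interval never shrinks. A minimal sketch, assuming the class is torch.quantization.PerChannelMinMaxObserver; the weight shape is illustrative.

import torch
from torch.quantization import PerChannelMinMaxObserver

obs = PerChannelMinMaxObserver(ch_axis=0)
weight = torch.randn(8, 4, 3, 3)             # e.g. a conv weight, output channels on dim 0
obs(weight)                                  # records one (min, max) pair per output channel
scale, zero_point = obs.calculate_qparams()
print(scale.shape)                           # torch.Size([8])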
Code Example #3
    def forward(self, x_orig):
        # type: (torch.Tensor) -> torch.Tensor
        x = x_orig.detach()
        min_val = self.min_val
        max_val = self.max_val
        same_values = min_val.item() == max_val.item()
        is_uninitialized = min_val == float('inf') and max_val == float('-inf')
        if is_uninitialized or same_values:
            min_val, max_val = torch._aminmax(x)
            self.min_val.resize_(min_val.shape)
            self.min_val.copy_(min_val)
            self.max_val.resize_(max_val.shape)
            self.max_val.copy_(max_val)
            assert min_val.numel() == 1 and max_val.numel() == 1, (
                "histogram min/max values must be scalar."
            )
            torch.histc(x, self.bins, min=int(min_val), max=int(max_val), out=self.histogram)
        else:
            new_min, new_max = torch._aminmax(x)
            combined_min = torch.min(new_min, min_val)
            combined_max = torch.max(new_max, max_val)
            # combine the existing histogram and new histogram into 1 histogram
            # We do this by first upsampling the histogram to a dense grid
            # and then downsampling the histogram efficiently
            combined_min, combined_max, downsample_rate, start_idx = \
                self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
            assert combined_min.numel() == 1 and combined_max.numel() == 1, (
                "histogram min/max values must be scalar."
            )
            combined_histogram = torch.histc(x, self.bins, min=int(combined_min), max=int(combined_max))
            if combined_min == min_val and combined_max == max_val:
                combined_histogram += self.histogram
            else:
                combined_histogram = self._combine_histograms(
                    combined_histogram,
                    self.histogram,
                    self.upsample_rate,
                    downsample_rate,
                    start_idx,
                    self.bins)

            self.histogram.resize_(combined_histogram.shape)
            self.histogram.copy_(combined_histogram)
            self.min_val.resize_(combined_min.shape)
            self.min_val.copy_(combined_min)
            self.max_val.resize_(combined_max.shape)
            self.max_val.copy_(combined_max)
        return x_orig
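
Example #3 accumulates a running histogram: on the first call (or when min equals max) it records the raw range and fills self.histogram directly; on later calls it merges the new batch histogram with the stored one after _adjust_min_max picks a common grid, upsampling and then downsampling via _combine_histograms. A minimal usage sketch, assuming torch.quantization.HistogramObserver with its bins and upsample_rate arguments; the values below are illustrative, not recommendations.

import torch
from torch.quantization import HistogramObserver

obs = HistogramObserver(bins=2048, upsample_rate=128)
for _ in range(4):
    obs(torch.randn(1024))        # forward() merges each batch into self.histogram
print(obs.histogram.sum())        # total count of observed elements (here 4096)
print(obs.calculate_qparams())    # range chosen from the histogram, not just raw min/max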
Code Example #4
    def forward(self, x_orig):
        x = x_orig.detach()  # avoid keeping autograd tape
        x = x.to(self.min_val.dtype)
        min_val = self.min_val
        max_val = self.max_val
        if min_val == float('inf') and max_val == float('-inf'):
            min_val, max_val = torch._aminmax(x)
        else:
            min_val_cur, max_val_cur = torch._aminmax(x)
            min_val = min_val + self.averaging_constant * (min_val_cur - min_val)
            max_val = max_val + self.averaging_constant * (max_val_cur - max_val)
        self.min_val.resize_(min_val.shape)
        self.max_val.resize_(max_val.shape)
        self.min_val.copy_(min_val)
        self.max_val.copy_(max_val)
        return x_orig
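
Example #4 is the per-tensor counterpart of Example #1: a single min/max pair updated with an exponential moving average, so outliers decay over time instead of widening the range permanently. A minimal sketch, assuming torch.quantization.MovingAverageMinMaxObserver; averaging_constant=0.01 mirrors the library default.

import torch
from torch.quantization import MovingAverageMinMaxObserver

obs = MovingAverageMinMaxObserver(averaging_constant=0.01)
for _ in range(100):
    obs(torch.randn(256))         # min_val/max_val drift toward each batch's statistics
print(obs.min_val, obs.max_val)   # smoothed running range
print(obs.calculate_qparams())    # scale and zero_point derived from that range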
Code Example #5
    def forward(self, x_orig):
        # type: (Tensor) -> Tensor
        x = x_orig.detach()
        min_val = self.min_val
        max_val = self.max_val
        same_values = False
        if min_val.numel() > 0 and max_val.numel() > 0:
            same_values = min_val.item() == max_val.item()
        if min_val.numel() == 0 or max_val.numel() == 0 or same_values:
            min_val, max_val = torch._aminmax(x)
            self.min_val.resize_(min_val.shape)
            self.min_val.copy_(min_val)
            self.max_val.resize_(max_val.shape)
            self.max_val.copy_(max_val)
            torch.histc(x,
                        self.bins,
                        min=min_val,
                        max=max_val,
                        out=self.histogram)
        else:
            new_min, new_max = torch._aminmax(x)
            combined_min = torch.min(new_min, min_val)
            combined_max = torch.max(new_max, max_val)
            # combine the existing histogram and new histogram into 1 histogram
            # We do this by first upsampling the histogram to a dense grid
            # and then downsampling the histogram efficiently
            combined_min, combined_max, downsample_rate, start_idx = \
                self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
            combined_histogram = torch.histc(x,
                                             self.bins,
                                             min=combined_min,
                                             max=combined_max)
            if combined_min == min_val and combined_max == max_val:
                combined_histogram += self.histogram
            else:
                combined_histogram = self._combine_histograms(
                    combined_histogram, self.histogram, self.upsample_rate,
                    downsample_rate, start_idx, self.bins)

            self.histogram.resize_(combined_histogram.shape)
            self.histogram.copy_(combined_histogram)
            self.min_val.resize_(combined_min.shape)
            self.min_val.copy_(combined_min)
            self.max_val.resize_(combined_max.shape)
            self.max_val.copy_(combined_max)
        return x_orig
Code Example #6
File: observer.py  Project: loadbxh/Torch
    def forward(self, x_orig):
        r"""Records the running minimum and maximum of ``x``."""
        x = x_orig.detach()  # avoid keeping autograd tape
        x = x.to(self.min_val.dtype)
        min_val_cur, max_val_cur = torch._aminmax(x)
        min_val = torch.min(min_val_cur, self.min_val)
        max_val = torch.max(max_val_cur, self.max_val)
        self.min_val.copy_(min_val)
        self.max_val.copy_(max_val)
        return x_orig
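
Example #6 is the simplest observer: every call only widens the stored range by taking an element-wise min/max against the buffers. A minimal sketch, assuming torch.quantization.MinMaxObserver.

import torch
from torch.quantization import MinMaxObserver

obs = MinMaxObserver()
obs(torch.tensor([-1.0, 0.5]))
obs(torch.tensor([0.2, 3.0]))
print(obs.min_val, obs.max_val)              # tensor(-1.) tensor(3.): the range only grows
scale, zero_point = obs.calculate_qparams()  # qparams covering [-1, 3]
print(scale, zero_point)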