def forward(self, x_orig):
    x = x_orig.detach()  # avoid keeping autograd tape
    x = x.to(self.min_vals.dtype)
    min_vals = self.min_vals
    max_vals = self.max_vals
    x_dim = x.size()

    # Move the channel axis to the front and flatten everything else,
    # so each row of `y` holds all values observed for one channel.
    new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
    new_axis_list[self.ch_axis] = 0
    new_axis_list[0] = self.ch_axis
    y = x.permute(new_axis_list)
    y = torch.flatten(y, start_dim=1)

    if min_vals.numel() == 0 or max_vals.numel() == 0:
        min_vals, max_vals = torch._aminmax(y, 1)
    else:
        # Exponential moving average of the per-channel range,
        # weighted by self.averaging_constant.
        min_vals_cur, max_vals_cur = torch._aminmax(y, 1)
        min_vals = min_vals + self.averaging_constant * (min_vals_cur - min_vals)
        max_vals = max_vals + self.averaging_constant * (max_vals_cur - max_vals)

    self.min_vals.resize_(min_vals.shape)
    self.max_vals.resize_(max_vals.shape)
    self.min_vals.copy_(min_vals)
    self.max_vals.copy_(max_vals)
    return x_orig
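# Illustrative sketch (not part of the observer code above): the per-channel
# forward works by permuting the channel axis to dim 0 and flattening the
# rest, so each row collects every value seen for one channel. The names
# `per_channel_min_max` and `ch_axis` below are hypothetical, and the public
# Tensor.amin / Tensor.amax reductions stand in for the internal
# torch._aminmax used above.
import torch

def per_channel_min_max(x: torch.Tensor, ch_axis: int = 0):
    axes = list(range(x.dim()))
    axes[ch_axis], axes[0] = axes[0], axes[ch_axis]   # swap channel axis to front
    y = torch.flatten(x.permute(axes), start_dim=1)   # one row per channel
    return y.amin(dim=1), y.amax(dim=1)

x = torch.randn(8, 3, 4, 4)                       # e.g. NCHW activations
mins, maxs = per_channel_min_max(x, ch_axis=1)    # one (min, max) pair per channel
assert mins.shape == maxs.shape == (3,)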
def _forward(self, x_orig):
    x = x_orig.detach()  # avoid keeping autograd tape
    min_vals = self.min_vals
    max_vals = self.max_vals
    x_dim = x.size()
    new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
    new_axis_list[self.ch_axis] = 0
    new_axis_list[0] = self.ch_axis
    y = x.permute(new_axis_list)
    # Need to match dtype of min/max because the updates to buffers
    # are done in place and types need to match for comparisons
    y = y.to(self.min_vals.dtype)
    y = torch.flatten(y, start_dim=1)
    if min_vals.numel() == 0 or max_vals.numel() == 0:
        min_vals, max_vals = torch._aminmax(y, 1)
    else:
        min_vals_cur, max_vals_cur = torch._aminmax(y, 1)
        min_vals = torch.min(min_vals_cur, min_vals)
        max_vals = torch.max(max_vals_cur, max_vals)
    self.min_vals.resize_(min_vals.shape)
    self.max_vals.resize_(max_vals.shape)
    self.min_vals.copy_(min_vals)
    self.max_vals.copy_(max_vals)
    return x_orig
def forward(self, x_orig):
    # type: (torch.Tensor) -> torch.Tensor
    x = x_orig.detach()
    min_val = self.min_val
    max_val = self.max_val
    same_values = min_val.item() == max_val.item()
    is_uninitialized = min_val == float('inf') and max_val == float('-inf')
    if is_uninitialized or same_values:
        min_val, max_val = torch._aminmax(x)
        self.min_val.resize_(min_val.shape)
        self.min_val.copy_(min_val)
        self.max_val.resize_(max_val.shape)
        self.max_val.copy_(max_val)
        assert min_val.numel() == 1 and max_val.numel() == 1, (
            "histogram min/max values must be scalar."
        )
        torch.histc(x, self.bins, min=int(min_val), max=int(max_val), out=self.histogram)
    else:
        new_min, new_max = torch._aminmax(x)
        combined_min = torch.min(new_min, min_val)
        combined_max = torch.max(new_max, max_val)
        # combine the existing histogram and new histogram into 1 histogram
        # We do this by first upsampling the histogram to a dense grid
        # and then downsampling the histogram efficiently
        combined_min, combined_max, downsample_rate, start_idx = \
            self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
        assert combined_min.numel() == 1 and combined_max.numel() == 1, (
            "histogram min/max values must be scalar."
        )
        combined_histogram = torch.histc(x, self.bins, min=int(combined_min), max=int(combined_max))
        if combined_min == min_val and combined_max == max_val:
            combined_histogram += self.histogram
        else:
            combined_histogram = self._combine_histograms(
                combined_histogram,
                self.histogram,
                self.upsample_rate,
                downsample_rate,
                start_idx,
                self.bins)
        self.histogram.resize_(combined_histogram.shape)
        self.histogram.copy_(combined_histogram)
        self.min_val.resize_(combined_min.shape)
        self.min_val.copy_(combined_min)
        self.max_val.resize_(combined_max.shape)
        self.max_val.copy_(combined_max)
    return x_orig
def forward(self, x_orig):
    x = x_orig.detach()  # avoid keeping autograd tape
    x = x.to(self.min_val.dtype)
    min_val = self.min_val
    max_val = self.max_val
    if min_val == float('inf') and max_val == float('-inf'):
        min_val, max_val = torch._aminmax(x)
    else:
        # Exponential moving average of the observed range,
        # weighted by self.averaging_constant.
        min_val_cur, max_val_cur = torch._aminmax(x)
        min_val = min_val + self.averaging_constant * (min_val_cur - min_val)
        max_val = max_val + self.averaging_constant * (max_val_cur - max_val)
    self.min_val.resize_(min_val.shape)
    self.max_val.resize_(max_val.shape)
    self.min_val.copy_(min_val)
    self.max_val.copy_(max_val)
    return x_orig
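# Illustrative sketch (assumption, not the library implementation): the
# moving-average observer nudges the stored range toward each new batch's
# range by a fixed fraction, min <- min + c * (min_cur - min). The helper
# name `ema_range` and the constant 0.01 are made up for this example.
import torch

def ema_range(batches, averaging_constant=0.01):
    min_val, max_val = None, None
    for x in batches:
        cur_min, cur_max = x.min(), x.max()
        if min_val is None:                       # first batch initializes the range
            min_val, max_val = cur_min, cur_max
        else:                                     # later batches shift it only slightly
            min_val = min_val + averaging_constant * (cur_min - min_val)
            max_val = max_val + averaging_constant * (cur_max - max_val)
    return min_val, max_val

lo, hi = ema_range(torch.randn(10, 64) for _ in range(5))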
def forward(self, x_orig):
    # type: (Tensor) -> Tensor
    x = x_orig.detach()
    min_val = self.min_val
    max_val = self.max_val
    same_values = False
    if min_val.numel() > 0 and max_val.numel() > 0:
        same_values = min_val.item() == max_val.item()
    if min_val.numel() == 0 or max_val.numel() == 0 or same_values:
        min_val, max_val = torch._aminmax(x)
        self.min_val.resize_(min_val.shape)
        self.min_val.copy_(min_val)
        self.max_val.resize_(max_val.shape)
        self.max_val.copy_(max_val)
        torch.histc(x, self.bins, min=min_val, max=max_val, out=self.histogram)
    else:
        new_min, new_max = torch._aminmax(x)
        combined_min = torch.min(new_min, min_val)
        combined_max = torch.max(new_max, max_val)
        # combine the existing histogram and new histogram into 1 histogram
        # We do this by first upsampling the histogram to a dense grid
        # and then downsampling the histogram efficiently
        combined_min, combined_max, downsample_rate, start_idx = \
            self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
        combined_histogram = torch.histc(x, self.bins, min=combined_min, max=combined_max)
        if combined_min == min_val and combined_max == max_val:
            combined_histogram += self.histogram
        else:
            combined_histogram = self._combine_histograms(
                combined_histogram,
                self.histogram,
                self.upsample_rate,
                downsample_rate,
                start_idx,
                self.bins)
        self.histogram.resize_(combined_histogram.shape)
        self.histogram.copy_(combined_histogram)
        self.min_val.resize_(combined_min.shape)
        self.min_val.copy_(combined_min)
        self.max_val.resize_(combined_max.shape)
        self.max_val.copy_(combined_max)
    return x_orig
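# Illustrative sketch (assumption): when the combined range equals the stored
# range, a new histogram over the same bins can simply be added to the running
# one, which is the `combined_histogram += self.histogram` branch above; the
# upsample/downsample path is only needed when the range grows. The bin count
# and range here are arbitrary.
import torch

bins, lo, hi = 8, -1.0, 1.0
running = torch.histc(torch.randn(1000).clamp(lo, hi), bins, min=lo, max=hi)
batch   = torch.histc(torch.randn(1000).clamp(lo, hi), bins, min=lo, max=hi)
running += batch   # same range, so the bins line up and counts just accumulate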
def forward(self, x_orig): r"""Records the running minimum and maximum of ``x``.""" x = x_orig.detach() # avoid keeping autograd tape x = x.to(self.min_val.dtype) min_val_cur, max_val_cur = torch._aminmax(x) min_val = torch.min(min_val_cur, self.min_val) max_val = torch.max(max_val_cur, self.max_val) self.min_val.copy_(min_val) self.max_val.copy_(max_val) return x_orig