def f(x):
    # aminmax.out returns a tuple of tensors.
    # functionalization should properly handle the tuple.
    out_min = torch.empty(4)
    out_max = torch.empty(4)
    torch.aminmax(x, dim=0, out=(out_max, out_min))
    return out_max
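# A minimal illustrative sketch (the helper below is hypothetical, not part of the
# test above): torch.aminmax returns a (min, max) named tuple, and the out= argument
# follows the same ordering, so the first out tensor receives the per-dim minimum.
def _aminmax_out_ordering_demo():
    x = torch.arange(12.0).reshape(3, 4)
    result = torch.aminmax(x, dim=0)          # result.min, result.max
    out_a = torch.empty(4)
    out_b = torch.empty(4)
    torch.aminmax(x, dim=0, out=(out_a, out_b))
    assert torch.equal(out_a, result.min)     # first slot is the minimum
    assert torch.equal(out_b, result.max)     # second slot is the maximum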
def _calculate_min_max_stats(self, x_copy):
    r"""Calculates and stores the per_channel min, max stats with forward values.
    Does calculation based on channel axis: self.ch_axis

    Args
        x_copy: A copy of the forward data

    Returns the passed in x_copy
    """
    # get the current min and max vals
    min_val = self.min_val
    max_val = self.max_val
    x_dim = x_copy.size()

    new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
    new_axis_list[self.ch_axis] = 0
    new_axis_list[0] = self.ch_axis

    y = x_copy.permute(new_axis_list)
    # Need to match dtype of min/max because the updates to buffers
    # are done in place and types need to match for comparisons
    y = y.to(self.min_val.dtype)
    y = torch.flatten(y, start_dim=1)

    if min_val.numel() == 0 or max_val.numel() == 0:
        min_val, max_val = torch.aminmax(y, dim=1)
    else:
        min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
        min_val = torch.min(min_val_cur, min_val)
        max_val = torch.max(max_val_cur, max_val)

    self.min_val.resize_(min_val.shape)
    self.max_val.resize_(max_val.shape)
    self.min_val.copy_(min_val)
    self.max_val.copy_(max_val)

    return x_copy
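# A minimal standalone sketch of the per-channel pattern used above (the helper
# name and ch_axis default are illustrative assumptions): swap the channel axis
# to dim 0, flatten everything else, then reduce along dim 1.
def _per_channel_min_max(x, ch_axis=0):
    axes = list(range(x.dim()))
    axes[ch_axis], axes[0] = axes[0], axes[ch_axis]
    y = torch.flatten(x.permute(axes), start_dim=1)
    return torch.aminmax(y, dim=1)  # one (min, max) pair per channel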
def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self):
    x = torch.rand((3, 3))
    actual = torch.zeros(3)
    with enable_torch_dispatch_mode(SchemaCheckMode()):
        torch.aminmax(x, dim=0, out=[actual, actual])
    self.assertEqual(torch.amax(x, dim=0), actual)
def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self):
    x = torch.rand((3, 3))
    actual = torch.zeros(3)
    schema_check = SchemaCheckMode()
    with enable_torch_dispatch_mode(schema_check):
        torch.aminmax(x, dim=0, out=[actual, actual])
    self.assertEqual([('aten::aminmax', 'min'), ('aten::aminmax', 'max')], schema_check.mutated)
    self.assertEqual(
        [
            ('aten::aminmax', 'min', 'output_0'),
            ('aten::aminmax', 'min', 'output_1'),
            ('aten::aminmax', 'max', 'output_0'),
            ('aten::aminmax', 'max', 'output_1'),
        ],
        schema_check.aliasing,
    )
def reduction_ops(self):
    a = torch.randn(4)
    b = torch.randn(4)
    return (
        torch.argmax(a),
        torch.argmin(a),
        torch.amax(a),
        torch.amin(a),
        torch.aminmax(a),
        torch.all(a),
        torch.any(a),
        torch.max(a),
        torch.min(a),
        torch.dist(a, b),
        torch.logsumexp(a, 0),
        torch.mean(a),
        torch.nanmean(a),
        torch.median(a),
        torch.nanmedian(a),
        torch.mode(a),
        torch.norm(a),
        torch.nansum(a),
        torch.prod(a),
        torch.quantile(a, torch.tensor([0.25, 0.5, 0.75])),
        torch.nanquantile(a, torch.tensor([0.25, 0.5, 0.75])),
        torch.std(a),
        torch.std_mean(a),
        torch.sum(a),
        torch.unique(a),
        torch.unique_consecutive(a),
        torch.var(a),
        torch.var_mean(a),
        torch.count_nonzero(a),
    )
def _calculate_range_stats(self, x_copy):
    r"""Calculates and stores range stats with forward values.

    Args
        x_copy: A copy of the forward data

    Returns the passed in x_copy
    """
    # get the min, max values of the data
    min_val_cur, max_val_cur = torch.aminmax(x_copy)

    # calculate new epoch range values
    epoch_min_val = torch.min(self.epoch_activation_min, min_val_cur)
    epoch_max_val = torch.max(self.epoch_activation_max, max_val_cur)

    self.epoch_activation_min.copy_(epoch_min_val)
    self.epoch_activation_max.copy_(epoch_max_val)

    # calculate the average batch activation range
    current_batch_range = max_val_cur - min_val_cur
    new_range = (
        self.average_batch_activation_range * self.num_batches_tracked + current_batch_range
    ) / (self.num_batches_tracked + 1)

    self.average_batch_activation_range = new_range
    self.num_batches_tracked += 1  # new batch was processed

    return x_copy
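# A minimal numeric sketch of the running-average update above (the helper name
# and the values are illustrative): with 2 batches tracked at an average range
# of 3.0, a new batch range of 6.0 moves the average to (3.0 * 2 + 6.0) / 3 == 4.0.
def _running_range_update_demo():
    average_range = torch.tensor(3.0)
    num_batches_tracked = 2
    batch = torch.tensor([-2.0, 1.0, 4.0])
    batch_min, batch_max = torch.aminmax(batch)   # -2.0, 4.0
    current_batch_range = batch_max - batch_min   # 6.0
    new_range = (average_range * num_batches_tracked + current_batch_range) / (
        num_batches_tracked + 1
    )
    assert torch.isclose(new_range, torch.tensor(4.0))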
def forward(self, x_orig):
    if x_orig.numel() == 0:
        return x_orig
    x = x_orig.detach()
    min_val, max_val = torch.aminmax(x)
    if self.min_val.numel():
        min_val = torch.min(min_val, self.min_val)
    if self.max_val.numel():
        max_val = torch.max(max_val, self.max_val)
    self.min_val = min_val
    self.max_val = max_val
    return x_orig
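# A minimal standalone sketch of the running min/max merge used above (the helper
# name is an illustrative assumption): fold each new batch's extrema into the
# stats seen so far with element-wise min/max.
def _merge_running_min_max(batches):
    running_min, running_max = None, None
    for batch in batches:
        batch_min, batch_max = torch.aminmax(batch)
        if running_min is None:
            running_min, running_max = batch_min, batch_max
        else:
            running_min = torch.min(running_min, batch_min)
            running_max = torch.max(running_max, batch_max)
    return running_min, running_max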
def run_model_and_common_checks(self, model, ex_input, num_epochs, batch_size):
    # split up data into batches
    split_up_data = torch.split(ex_input, batch_size)
    for epoch in range(num_epochs):
        # reset all model report obs
        model.apply(
            lambda module: module.reset_batch_and_epoch_values()
            if isinstance(module, ModelReportObserver)
            else None
        )

        # quick check that a reset occurred
        self.assertEqual(
            getattr(model, "obs1").average_batch_activation_range,
            torch.tensor(float(0)),
        )
        self.assertEqual(getattr(model, "obs1").epoch_activation_min, torch.tensor(float("inf")))
        self.assertEqual(getattr(model, "obs1").epoch_activation_max, torch.tensor(float("-inf")))

        # loop through the batches and run through
        for index, batch in enumerate(split_up_data):
            num_tracked_so_far = getattr(model, "obs1").num_batches_tracked
            self.assertEqual(num_tracked_so_far, index)

            # get general info about the batch and the model to use later
            batch_min, batch_max = torch.aminmax(batch)
            current_average_range = getattr(model, "obs1").average_batch_activation_range
            current_epoch_min = getattr(model, "obs1").epoch_activation_min
            current_epoch_max = getattr(model, "obs1").epoch_activation_max

            # run input through
            model(ex_input)

            # check that average batch activation range updated correctly
            correct_updated_value = (current_average_range * num_tracked_so_far + (batch_max - batch_min)) / (
                num_tracked_so_far + 1
            )
            self.assertEqual(
                getattr(model, "obs1").average_batch_activation_range,
                correct_updated_value,
            )

            if current_epoch_max - current_epoch_min > 0:
                self.assertEqual(
                    getattr(model, "obs1").get_batch_to_epoch_ratio(),
                    correct_updated_value / (current_epoch_max - current_epoch_min),
                )
def forward(self, x):
    x_copy = x.detach()  # avoid keeping autograd tape
    x_copy = x_copy.to(self.epoch_activation_min.dtype)

    min_val_cur, max_val_cur = torch.aminmax(x_copy)

    # calculate new epoch range values
    epoch_min_val = torch.min(self.epoch_activation_min, min_val_cur)
    epoch_max_val = torch.max(self.epoch_activation_max, max_val_cur)

    self.epoch_activation_min.copy_(epoch_min_val)
    self.epoch_activation_max.copy_(epoch_max_val)

    # calculate the average batch activation range
    current_batch_range = max_val_cur - min_val_cur
    new_range = (
        self.average_batch_activation_range * self.num_batches_tracked + current_batch_range
    ) / (self.num_batches_tracked + 1)

    self.average_batch_activation_range = new_range
    self.num_batches_tracked += 1  # new batch was processed

    # return the passed-in value
    return x
def _extract_weight_info(self, model: GraphModule) -> Dict[str, Dict]:
    r"""
    Takes in a calibrated GraphModule and then finds the relevant observers.
    It then extracts the weight information for each layer an observer is attached to.

    Args
        model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

    Returns a dict mapping module fqns (str) to a dict with keys:
        "per_channel_max" : maps to the per_channel max values
        "per_channel_min" : maps to the per_channel min values
        "global_max" : maps to the global max recorded
        "global_min" : maps to the global min recorded
    """
    # return dictionary mapping observer fqns to desired info
    weight_info: Dict[str, Dict] = {}

    for fqn, module in model.named_modules():
        # if module is supported and it has a pre-observer
        if self._is_supported(module):
            # we don't need the actual observer, just the module weights
            # calculate min and max vals
            min_val, max_val = torch.aminmax(module.weight, dim=self.ch_axis)

            # flatten entries since conv can have multiple dimensions
            min_val = torch.flatten(min_val)
            max_val = torch.flatten(max_val)

            weight_info[fqn] = {
                self.PER_CHANNEL_MAX_KEY: max_val,
                self.PER_CHANNEL_MIN_KEY: min_val,
                self.GLOBAL_MAX_KEY: max(max_val),
                self.GLOBAL_MIN_KEY: min(min_val),
            }

    return weight_info
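# A minimal standalone sketch of the weight extraction above (the helper name,
# the ch_axis default, and the plain-string keys are illustrative assumptions):
# reduce along ch_axis with torch.aminmax, flatten the results, and take the
# global extrema from the flattened values.
def _weight_min_max_demo(weight, ch_axis=0):
    min_val, max_val = torch.aminmax(weight, dim=ch_axis)
    min_val = torch.flatten(min_val)
    max_val = torch.flatten(max_val)
    return {
        "per_channel_min": min_val,
        "per_channel_max": max_val,
        "global_min": min_val.min(),
        "global_max": max_val.max(),
    }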