def f(x):
    # aminmax.out returns a (min, max) tuple of tensors.
    # functionalization should properly handle the tuple.
    out_min = torch.empty(4)
    out_max = torch.empty(4)
    torch.aminmax(x, dim=0, out=(out_min, out_max))
    return out_max
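
A minimal standalone sketch (values are illustrative, not from the test above) of the out= contract the snippet relies on: torch.aminmax fills the out tuple in (min, max) order, matching torch.amin and torch.amax.

import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
out_min = torch.empty(4)
out_max = torch.empty(4)
torch.aminmax(x, dim=0, out=(out_min, out_max))  # out is filled in (min, max) order
assert torch.equal(out_min, torch.amin(x, dim=0))
assert torch.equal(out_max, torch.amax(x, dim=0))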
Example #2
    def _calculate_min_max_stats(self, x_copy):
        r"""Calculates and stores the per_channel min, max stats with forward values.
        Does calculation based on channel axis: self.ch_axis

        Args:
            x_copy: A copy of the forward data

        Returns the passed-in x_copy
        """
        # get the current min and max vals
        min_val = self.min_val
        max_val = self.max_val
        x_dim = x_copy.size()

        new_axis_list = list(range(len(x_dim)))
        new_axis_list[self.ch_axis] = 0
        new_axis_list[0] = self.ch_axis
        y = x_copy.permute(new_axis_list)
        # Need to match dtype of min/max because the updates to buffers
        # are done in place and types need to match for comparisons
        y = y.to(self.min_val.dtype)
        y = torch.flatten(y, start_dim=1)
        if min_val.numel() == 0 or max_val.numel() == 0:
            min_val, max_val = torch.aminmax(y, dim=1)
        else:
            min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
            min_val = torch.min(min_val_cur, min_val)
            max_val = torch.max(max_val_cur, max_val)

        self.min_val.resize_(min_val.shape)
        self.max_val.resize_(max_val.shape)
        self.min_val.copy_(min_val)
        self.max_val.copy_(max_val)

        return x_copy
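
A minimal standalone sketch (toy shapes, with ch_axis assumed to be 1) of the permute-and-flatten trick used above: move the channel axis to the front, flatten everything else into one dimension, then reduce each row with a single aminmax call.

import torch

x = torch.randn(2, 3, 4)  # e.g. (batch, channel, feature)
ch_axis = 1
axes = list(range(x.dim()))
axes[0], axes[ch_axis] = axes[ch_axis], axes[0]
y = torch.flatten(x.permute(axes), start_dim=1)  # shape (3, 8)
min_val, max_val = torch.aminmax(y, dim=1)       # one (min, max) per channel
assert min_val.shape == (3,) and max_val.shape == (3,)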
Example #3
    def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self):
        x = torch.rand((3, 3))
        actual = torch.zeros(3)
        with enable_torch_dispatch_mode(SchemaCheckMode()):
            torch.aminmax(x, dim=0, out=[actual, actual])
        self.assertEqual(torch.amax(x, dim=0), actual)
Example #4
    def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self):
        x = torch.rand((3, 3))
        actual = torch.zeros(3)
        schema_check = SchemaCheckMode()
        with enable_torch_dispatch_mode(schema_check):
            torch.aminmax(x, dim=0, out=[actual, actual])
        self.assertEqual([('aten::aminmax', 'min'), ('aten::aminmax', 'max')],
                         schema_check.mutated)
        self.assertEqual([('aten::aminmax', 'min', 'output_0'),
                          ('aten::aminmax', 'min', 'output_1'),
                          ('aten::aminmax', 'max', 'output_0'),
                          ('aten::aminmax', 'max', 'output_1')],
                         schema_check.aliasing)
Example #5
    def reduction_ops(self):
        a = torch.randn(4)
        b = torch.randn(4)
        return (
            torch.argmax(a),
            torch.argmin(a),
            torch.amax(a),
            torch.amin(a),
            torch.aminmax(a),
            torch.all(a),
            torch.any(a),
            torch.max(a),
            torch.min(a),
            torch.dist(a, b),
            torch.logsumexp(a, 0),
            torch.mean(a),
            torch.nanmean(a),
            torch.median(a),
            torch.nanmedian(a),
            torch.mode(a),
            torch.norm(a),
            torch.nansum(a),
            torch.prod(a),
            torch.quantile(a, torch.tensor([0.25, 0.5, 0.75])),
            torch.nanquantile(a, torch.tensor([0.25, 0.5, 0.75])),
            torch.std(a),
            torch.std_mean(a),
            torch.sum(a),
            torch.unique(a),
            torch.unique_consecutive(a),
            torch.var(a),
            torch.var_mean(a),
            torch.count_nonzero(a),
        )
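
A minimal standalone sketch (illustrative values) of the default reduction used in the list above: with no dim argument, torch.aminmax reduces over all elements and returns a named (min, max) tuple that also unpacks like a plain tuple.

import torch

a = torch.tensor([3.0, -1.0, 2.0])
mn, mx = torch.aminmax(a)  # unpacks like a plain tuple
assert mn.item() == -1.0 and mx.item() == 3.0
result = torch.aminmax(a)  # named fields are also available
assert result.min.item() == -1.0 and result.max.item() == 3.0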
Example #6
    def _calculate_range_stats(self, x_copy):
        r"""Calculates and stores range stats with forward values.

        Args:
            x_copy: A copy of the forward data

        Returns the passed-in x_copy
        """
        # get the min, max values of the data
        min_val_cur, max_val_cur = torch.aminmax(x_copy)

        # calculate new epoch range values
        epoch_min_val = torch.min(self.epoch_activation_min, min_val_cur)
        epoch_max_val = torch.max(self.epoch_activation_max, max_val_cur)

        self.epoch_activation_min.copy_(epoch_min_val)
        self.epoch_activation_max.copy_(epoch_max_val)

        # calculate the average batch activation range
        current_batch_range = max_val_cur - min_val_cur
        new_range = (
            self.average_batch_activation_range * self.num_batches_tracked +
            current_batch_range) / (self.num_batches_tracked + 1)

        self.average_batch_activation_range = new_range
        self.num_batches_tracked += 1  # new batch was processed

        return x_copy
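
A hand-worked sketch (made-up numbers) of the running-average update above, new = (avg * n + cur) / (n + 1), which folds the current batch range into the average without storing past batches.

avg, n = 4.0, 3    # average batch range so far, batches tracked
cur = 8.0          # range of the current batch
new = (avg * n + cur) / (n + 1)
assert new == 5.0  # (4 * 3 + 8) / 4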
Example #7
    def forward(self, x_orig):
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.detach()
        min_val, max_val = torch.aminmax(x)
        if self.min_val.numel():
            min_val = torch.min(min_val, self.min_val)
        if self.max_val.numel():
            max_val = torch.max(max_val, self.max_val)
        self.min_val = min_val
        self.max_val = max_val
        return x_orig
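
A minimal standalone sketch (toy tensors) of the running min/max merge this forward performs: take the elementwise min/max of the current batch stats against the stored values.

import torch

running_min, running_max = torch.tensor(-1.0), torch.tensor(2.0)
batch_min, batch_max = torch.aminmax(torch.tensor([-3.0, 1.0]))
running_min = torch.min(batch_min, running_min)  # tensor(-3.)
running_max = torch.max(batch_max, running_max)  # tensor(2.)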
Example #8
    def run_model_and_common_checks(self, model, ex_input, num_epochs, batch_size):
        # split up data into batches
        split_up_data = torch.split(ex_input, batch_size)
        for epoch in range(num_epochs):
            # reset all model report obs
            model.apply(
                lambda module: module.reset_batch_and_epoch_values()
                if isinstance(module, ModelReportObserver)
                else None
            )

            # quick check that a reset occurred
            self.assertEqual(
                getattr(model, "obs1").average_batch_activation_range,
                torch.tensor(float(0)),
            )
            self.assertEqual(getattr(model, "obs1").epoch_activation_min, torch.tensor(float("inf")))
            self.assertEqual(getattr(model, "obs1").epoch_activation_max, torch.tensor(float("-inf")))

            # loop through the batches and run through
            for index, batch in enumerate(split_up_data):

                num_tracked_so_far = getattr(model, "obs1").num_batches_tracked
                self.assertEqual(num_tracked_so_far, index)

                # get general info about the batch and the model to use later
                batch_min, batch_max = torch.aminmax(batch)
                current_average_range = getattr(model, "obs1").average_batch_activation_range
                current_epoch_min = getattr(model, "obs1").epoch_activation_min
                current_epoch_max = getattr(model, "obs1").epoch_activation_max

                # run the current batch through
                model(batch)

                # check that average batch activation range updated correctly
                correct_updated_value = (current_average_range * num_tracked_so_far + (batch_max - batch_min)) / (
                    num_tracked_so_far + 1
                )
                self.assertEqual(
                    getattr(model, "obs1").average_batch_activation_range,
                    correct_updated_value,
                )

                if current_epoch_max - current_epoch_min > 0:
                    self.assertEqual(
                        getattr(model, "obs1").get_batch_to_epoch_ratio(),
                        correct_updated_value / (current_epoch_max - current_epoch_min),
                    )
Example #9
    def forward(self, x):
        x_copy = x.detach()  # avoid keeping autograd tape
        x_copy = x_copy.to(self.epoch_activation_min.dtype)
        min_val_cur, max_val_cur = torch.aminmax(x_copy)

        # calculate new epoch range values
        epoch_min_val = torch.min(self.epoch_activation_min, min_val_cur)
        epoch_max_val = torch.max(self.epoch_activation_max, max_val_cur)

        self.epoch_activation_min.copy_(epoch_min_val)
        self.epoch_activation_max.copy_(epoch_max_val)

        # calculate the average batch activation range
        current_batch_range = max_val_cur - min_val_cur
        new_range = (
            self.average_batch_activation_range * self.num_batches_tracked +
            current_batch_range) / (self.num_batches_tracked + 1)

        self.average_batch_activation_range = new_range
        self.num_batches_tracked += 1  # new batch was processed

        # return the passed-in value
        return x
Example #10
    def _extract_weight_info(self, model: GraphModule) -> Dict[str, Dict]:
        r"""
        Takes in a calibrated GraphModule and finds the relevant observers.
        It then extracts the weight information for each layer an observer is attached to.

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a dict mapping module fqns (str) to a dict with keys:
            "per_channel_max" : maps to the per_channel max values
            "per_channel_min" : maps to the per_channel min values
            "global_max" : maps to the global max recorded
            "global_min" : maps to the global min recorded
        """
        # return dictionary mapping observer fqns to desired info
        weight_info: Dict[str, Dict] = {}

        for fqn, module in model.named_modules():
            # if module is supported and it has a pre-observer
            if self._is_supported(module):
                # we don't need actual observer, just the module weights
                # calculate min and max vals
                min_val, max_val = torch.aminmax(module.weight,
                                                 dim=self.ch_axis)

                # flatten entries since conv can have multiple dimensions
                min_val = torch.flatten(min_val)
                max_val = torch.flatten(max_val)

                weight_info[fqn] = {
                    self.PER_CHANNEL_MAX_KEY: max_val,
                    self.PER_CHANNEL_MIN_KEY: min_val,
                    self.GLOBAL_MAX_KEY: max(max_val),
                    self.GLOBAL_MIN_KEY: min(min_val),
                }

        return weight_info
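
A minimal standalone sketch (hypothetical 4-D conv weight, ch_axis assumed to be 0) of the reduction and flattening above: aminmax along the channel axis leaves the remaining dimensions, which flatten() collapses into the 1-D vectors stored in weight_info.

import torch

weight = torch.randn(8, 3, 3, 3)                 # (out_ch, in_ch, kH, kW)
min_val, max_val = torch.aminmax(weight, dim=0)  # shape (3, 3, 3)
min_val, max_val = torch.flatten(min_val), torch.flatten(max_val)  # shape (27,)
global_min, global_max = min_val.min(), max_val.max()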