Example #1
    def run_hacked_export_quantization(self, x):
        from nncf.quantization.layers import (
            ExportQuantizeToFakeQuantize, ExportQuantizeToONNXQuantDequant,
            QuantizerExportMode, get_scale_zp_from_input_low_input_high)
        from nncf.utils import no_jit_trace
        with no_jit_trace():
            input_range = abs(self.scale) + self.eps
            # todo: take bias into account during input_low/input_high calculation
            input_low = input_range * self.level_low / self.level_high
            input_high = input_range

            if self._export_mode == QuantizerExportMode.ONNX_QUANTIZE_DEQUANTIZE_PAIRS:
                y_scale, y_zero_point = get_scale_zp_from_input_low_input_high(
                    self.level_low, self.level_high, input_low, input_high)

        if self._export_mode == QuantizerExportMode.ONNX_QUANTIZE_DEQUANTIZE_PAIRS:
            return ExportQuantizeToONNXQuantDequant.apply(
                x, y_scale, y_zero_point)
        if self._export_mode == QuantizerExportMode.FAKE_QUANTIZE:
            # deliberately distorted export: halve the input and double the
            # output range passed to the FakeQuantize node
            x = x / 2.0
            return ExportQuantizeToFakeQuantize.apply(x, self.levels,
                                                      input_low, input_high,
                                                      input_low * 2,
                                                      input_high * 2)
        raise RuntimeError('Unknown quantizer export mode')
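In this export path the (input_low, input_high) range computed under no_jit_trace is converted into the (y_scale, y_zero_point) pair expected by ONNX QuantizeLinear/DequantizeLinear nodes. The helper get_scale_zp_from_input_low_input_high is not shown above; below is a minimal sketch of the usual asymmetric-range conversion. The function name and the rounding details are assumptions for illustration, not NNCF's exact implementation.

import torch

def scale_zp_from_range(level_low, level_high, input_low, input_high):
    # Illustrative sketch only: map one integer step to a step of the
    # floating-point range, then find the integer level that represents 0.0.
    # The real get_scale_zp_from_input_low_input_high may handle rounding,
    # clamping and per-channel tensors differently.
    y_scale = (input_high - input_low) / (level_high - level_low)
    y_zero_point = torch.round(level_low - input_low / y_scale)
    return y_scale, y_zero_point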
Example #2
    def forward(self, weight):
        if is_tracing_state():
            with no_jit_trace():
                # while tracing, apply the mask in place so only the masked
                # weight appears in the traced graph
                return weight.mul_(self.binary_mask)
        tmp_tensor = self._calc_training_binary_mask(weight)
        self.binary_mask = self._calc_binary_mask(weight)
        return apply_binary_mask_impl(tmp_tensor, weight)
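Here forward applies a precomputed binary mask to the weight: during tracing the mask is multiplied in place so the exported graph only sees the masked weight, while in training the mask is recomputed. The mask itself comes from _calc_binary_mask, which is not shown; for magnitude-based sparsity it typically amounts to thresholding weight magnitudes, as in the hypothetical sketch below (the function name and threshold argument are assumptions, not NNCF's actual API).

import torch

def calc_binary_mask_by_magnitude(weight, threshold):
    # Illustrative sketch: keep weights whose absolute value exceeds the
    # threshold and zero out the rest; NNCF's _calc_binary_mask may differ.
    return (weight.abs() > threshold).type_as(weight)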
Example #3
    def forward(self, conv_weight):
        if is_tracing_state():
            with no_jit_trace():
                # identity pass-through while tracing
                return conv_weight
        return conv_weight
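All of these examples rely on no_jit_trace to keep auxiliary computations out of the graph while torch.jit tracing is active (is_tracing_state presumably reports whether a tracing state exists). A context manager with that behaviour can be sketched by temporarily clearing PyTorch's tracing state; this uses private torch._C calls and is an illustration only, not necessarily how nncf.utils.no_jit_trace is implemented.

from contextlib import contextmanager

import torch

@contextmanager
def no_jit_trace():
    # Illustrative sketch: stash the current tracing state, disable tracing
    # for the duration of the block, then restore it on exit.
    state = torch._C._get_tracing_state()
    torch._C._set_tracing_state(None)
    try:
        yield
    finally:
        torch._C._set_tracing_state(state)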
Example #4
    def forward(ctx, loc_data, conf_data, prior_data, detection_output_params):
        """
        Args:
            loc_data: (tensor) Loc preds from loc layers
                Shape: [batch,num_priors*4]
            conf_data: (tensor) Conf preds from conf layers
                Shape: [batch,num_priors*num_classes]
            prior_data: (tensor) Prior boxes and variances from priorbox layers
                Shape: [1,2,num_priors*4]
            detection_output_params: parameters controlling filtering and NMS
                (num_classes, background_label_id, confidence_threshold,
                nms_threshold, top_k, keep_top_k)
        """
        with no_jit_trace(), no_nncf_trace():
            if detection_output_params.nms_threshold <= 0:
                raise ValueError('nms_threshold must be positive.')
            device = loc_data.device
            batch_size = loc_data.size(0)  # batch size
            num_priors = int(loc_data.size(1) / 4)
            loc_data = loc_data.view(batch_size, num_priors, 4)
            conf_data = conf_data.view(batch_size, num_priors, -1)
            prior_data = prior_data.view(1, 2, num_priors, 4)
            output = torch.zeros(batch_size, 1,
                                 detection_output_params.keep_top_k,
                                 7).to(device)

            conf_preds = conf_data.view(
                batch_size, num_priors,
                detection_output_params.num_classes).transpose(2, 1)

            # Decode predictions into bboxes.
            for i in range(batch_size):
                output_for_img = torch.zeros(0, 7).to(device)
                decoded_boxes = decode(loc_data[i], prior_data[0])
                # For each class, perform nms
                conf_scores = conf_preds[i].clone()

                total_detections_count = 0
                # indices of confident detections for each class
                all_indices = dict()
                boxes = dict()
                for cl in range(0, detection_output_params.num_classes):
                    if cl == detection_output_params.background_label_id:
                        continue
                    c_mask = conf_scores[cl].gt(
                        detection_output_params.confidence_threshold)
                    scores = conf_scores[cl][c_mask]
                    if scores.size(0) == 0:
                        continue
                    conf_scores[cl, :scores.size()[0]] = scores
                    conf_scores[cl, scores.size()[0]:] = 0
                    l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
                    boxes[cl] = decoded_boxes[l_mask].view(-1, 4)
                    # idx of highest scoring and non-overlapping boxes per class
                    all_indices[cl], count = nms(
                        boxes[cl], scores,
                        detection_output_params.nms_threshold,
                        detection_output_params.top_k)
                    all_indices[cl] = all_indices[cl][:count]
                    total_detections_count += count

                # list of tuples (score, label, idx)
                score_index_pairs = list()
                for label, indices in all_indices.items():
                    indices = indices.cpu().numpy()
                    for idx in indices:
                        score_index_pairs.append(
                            (conf_scores[label, idx], label, idx))

                score_index_pairs.sort(key=lambda tup: tup[0], reverse=True)
                score_index_pairs = score_index_pairs[
                    :detection_output_params.keep_top_k]

                all_indices_new = dict()
                for _, label, idx in score_index_pairs:
                    if label not in all_indices_new:
                        all_indices_new[label] = [idx]
                    else:
                        all_indices_new[label].append(idx)

                for label, indices in all_indices_new.items():
                    # columns: image id, class label, score, box (4 values)
                    img_ids = torch.full((len(indices), 1), i,
                                         dtype=torch.float, device=device)
                    labels = torch.full((len(indices), 1), label,
                                        dtype=torch.float, device=device)
                    out = torch.cat(
                        (img_ids, labels,
                         conf_scores[label, indices].unsqueeze(1).to(device),
                         boxes[label][indices].to(device)), 1)
                    output_for_img = torch.cat((output_for_img, out), 0)

                output[i, 0, :output_for_img.size()[0]] = output_for_img
        return output
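The decode call above turns the loc predictions back into corner-form boxes using the priors and their variances before NMS and the top-k filtering are applied. A sketch of the standard SSD decoding is shown below; the decode() used in this example receives priors and variances stacked in a single [2, num_priors, 4] tensor, so its exact signature may differ from this illustrative version.

import torch

def decode_ssd(loc, priors, variances=(0.1, 0.2)):
    # Illustrative sketch of the common ssd.pytorch-style decoding.
    # loc:    [num_priors, 4] predicted offsets (dx, dy, dw, dh)
    # priors: [num_priors, 4] prior boxes in center-size form (cx, cy, w, h)
    boxes = torch.cat((
        priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
        priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
    # convert from center-size to corner form (xmin, ymin, xmax, ymax)
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    return boxes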