def _run_inference(self, summary, threshold=0.5):
        """
        Run inference for the dataset and write the inference-related data into summary.

        Args:
            summary (SummaryRecord): The summary object to store the data.
            threshold (float): The threshold for prediction.

        Returns:
            dict, The map of sample id to the union of its ground truth and predicted labels.
        """
        sample_id_labels = {}
        self._sample_index = 0
        ds.config.set_seed(self._DATASET_SEED)
        for j, next_element in enumerate(self._dataset):
            now = time()
            inputs, labels, _ = self._unpack_next_element(next_element)
            prob = self._full_network(inputs).asnumpy()

            for idx, inp in enumerate(inputs):
                gt_labels = labels[idx]
                gt_probs = [float(prob[idx][i]) for i in gt_labels]

                data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
                original_image = _np_to_image(_normalize(data_np), mode='RGB')
                original_image_path = self._save_original_image(self._sample_index, original_image)

                predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]]
                predicted_probs = [float(prob[idx][i]) for i in predicted_labels]

                union_labs = list(set(gt_labels + predicted_labels))
                sample_id_labels[str(self._sample_index)] = union_labs

                explain = Explain()
                explain.sample_id = self._sample_index
                explain.image_path = original_image_path
                summary.add_value("explainer", "sample", explain)

                explain = Explain()
                explain.sample_id = self._sample_index
                explain.ground_truth_label.extend(gt_labels)
                explain.inference.ground_truth_prob.extend(gt_probs)
                explain.inference.predicted_label.extend(predicted_labels)
                explain.inference.predicted_prob.extend(predicted_probs)

                summary.add_value("explainer", "inference", explain)

                summary.record(1)

                self._sample_index += 1
            self._spaced_print("Finish running and writing {}-th batch inference data."
                               " Time elapsed: {:.3f} s".format(j, time() - now),
                               end='')
        return sample_id_labels
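
The thresholding and label-union step above is easy to check in isolation. A minimal NumPy sketch with made-up probabilities and ground-truth labels (not taken from the listing):

import numpy as np

prob_row = np.array([0.1, 0.8, 0.4, 0.9])  # per-class probabilities for one sample
gt_labels = [1, 2]                          # ground-truth label indices
threshold = 0.5

# Labels whose probability exceeds the threshold become predictions.
predicted_labels = [int(i) for i in (prob_row > threshold).nonzero()[0]]
union_labs = list(set(gt_labels + predicted_labels))

print(predicted_labels)    # [1, 3]
print(sorted(union_labs))  # [1, 2, 3]
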
    def _run_exp_step(self, next_element, explainer, sample_id_labels,
                      summary):
        """
        Run the explanation for each step and write explanation results into summary.

        Args:
            next_element (Tuple): Data of one step.
            explainer (_Attribution): An Attribution object to generate saliency maps.
            sample_id_labels (dict): A dict that maps the sample id to its union labels.
            summary (SummaryRecord): The summary object to store the data.

        Returns:
            list, List of dicts, each mapping a label to its corresponding saliency map.
        """
        inputs, labels, _ = self._unpack_next_element(next_element)
        sample_index = self._sample_index
        unions = []
        for _ in range(len(labels)):
            unions_labels = sample_id_labels[str(sample_index)]
            unions.append(unions_labels)
            sample_index += 1

        batch_unions = self._make_label_batch(unions)
        saliency_dict_lst = []

        if isinstance(explainer, RISE):
            batch_saliency_full = explainer(inputs, batch_unions)
        else:
            batch_saliency_full = []
            for i in range(len(batch_unions[0])):
                batch_saliency = explainer(inputs, batch_unions[:, i])
                batch_saliency_full.append(batch_saliency)
            concat = ms.ops.operations.Concat(1)
            batch_saliency_full = concat(tuple(batch_saliency_full))

        for idx, union in enumerate(unions):
            saliency_dict = {}
            explain = Explain()
            explain.sample_id = self._sample_index
            for k, lab in enumerate(union):
                saliency = batch_saliency_full[idx:idx + 1, k:k + 1]
                saliency_dict[lab] = saliency

                saliency_np = _normalize(saliency.asnumpy().squeeze())
                saliency_image = _np_to_image(saliency_np, mode='L')
                heatmap_path = self._save_heatmap(explainer.__class__.__name__,
                                                  lab, self._sample_index,
                                                  saliency_image)

                explanation = explain.explanation.add()
                explanation.explain_method = explainer.__class__.__name__
                explanation.heatmap_path = heatmap_path
                explanation.label = lab

            summary.add_value("explainer", "explanation", explain)
            summary.record(1)

            self._sample_index += 1
            saliency_dict_lst.append(saliency_dict)
        return saliency_dict_lst
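
For non-RISE explainers, the loop above runs the explainer once per label column, concatenates the results along axis 1, and later slices out one (sample, label) map at a time. A NumPy stand-in for the shape bookkeeping, with random data and hypothetical sizes:

import numpy as np

batch, num_labels, h, w = 2, 3, 4, 4
# One explainer call per label column, each returning (batch, 1, h, w).
per_label = [np.random.rand(batch, 1, h, w) for _ in range(num_labels)]
batch_saliency_full = np.concatenate(per_label, axis=1)  # (batch, num_labels, h, w)

idx, k = 0, 1
saliency = batch_saliency_full[idx:idx + 1, k:k + 1]     # same slicing as above
print(saliency.shape)  # (1, 1, 4, 4)
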
    def _run_hoc(self, summary, sample_id, sample_input, prob):
        """
        Run HOC search for a sample image, and then save the result to summary.

        Args:
            summary (SummaryRecord): The summary object to store the data.
            sample_id (int): The sample ID.
            sample_input (Union[Tensor, np.ndarray]): Sample image tensor in CHW or NCHW (N=1).
            prob (Union[Tensor, np.ndarray]): The sample's classification prediction outputs; HOC runs for
                labels whose prediction output is strictly larger than the HOC searcher's threshold (0.5 by default).
        """
        if isinstance(sample_input, ms.Tensor):
            sample_input = sample_input.asnumpy()
        if len(sample_input.shape) == 3:
            sample_input = np.expand_dims(sample_input, axis=0)
        has_rec = False
        explain = Explain()
        explain.sample_id = sample_id
        str_mask = hoc.auto_str_mask(sample_input)
        compiled_mask = None
        for label_idx, label_prob in enumerate(prob):
            if label_prob > self._hoc_searcher.threshold:
                if compiled_mask is None:
                    compiled_mask = hoc.compile_mask(str_mask, sample_input)
                try:
                    edit_tree, layer_outputs = self._hoc_searcher.search(
                        sample_input, label_idx, compiled_mask)
                except hoc.NoValidResultError:
                    log.warning(
                        f"No Hierarchical Occlusion result was found in sample#{sample_id} "
                        f"label:{self._labels[label_idx]}, skipped.")
                    continue
                has_rec = True
                hoc_rec = explain.hoc.add()
                hoc_rec.label = label_idx
                hoc_rec.mask = str_mask
                layer_count = edit_tree.max_layer + 1
                for layer in range(layer_count):
                    steps = edit_tree.get_layer_or_leaf_steps(layer)
                    layer_output = layer_outputs[layer]
                    hoc_layer = hoc_rec.layer.add()
                    hoc_layer.prob = layer_output
                    for step in steps:
                        hoc_layer.box.extend(list(step.box))
        if has_rec:
            summary.add_value("explainer", "hoc", explain)
            summary.record(1)
            self._manifest['hierarchical_occlusion'] = True
Example #4
    def _run_hoc(self, summary, sample_id, sample_input, prob):
        """
        Run HOC search for a sample image, and then save the result to summary.

        Args:
            summary (SummaryRecord): The summary object to store the data.
            sample_id (int): The sample ID.
            sample_input (Union[Tensor, np.ndarray]): Sample image tensor in CHW or NCHW (N=1).
            prob (Union[Tensor, np.ndarray]): The sample's classification prediction outputs; HOC runs for
                labels whose prediction output is strictly larger than the HOC searcher's threshold (0.5 by default).
        """
        if isinstance(sample_input, ms.Tensor):
            sample_input = sample_input.asnumpy()
        if len(sample_input.shape) == 3:
            sample_input = np.expand_dims(sample_input, axis=0)

        explain = None
        str_mask = hoc.auto_str_mask(sample_input)
        compiled_mask = None

        for label_idx, label_prob in enumerate(prob):
            if label_prob <= self._hoc_searcher.threshold:
                continue
            if compiled_mask is None:
                compiled_mask = hoc.compile_mask(str_mask, sample_input)
            try:
                edit_tree, layer_outputs = self._hoc_searcher.search(
                    sample_input, label_idx, compiled_mask)
            except hoc.NoValidResultError:
                log.warning(
                    f"No Hierarchical Occlusion result was found in sample#{sample_id} "
                    f"label:{self._labels[label_idx]}, skipped.")
                continue

            if explain is None:
                explain = Explain()
                explain.sample_id = sample_id

            self._add_hoc_result_to_explain(label_idx, str_mask, edit_tree,
                                            layer_outputs, explain)

        if explain is not None:
            summary.add_value("explainer", "hoc", explain)
            summary.record(1)
            self._manifest['hierarchical_occlusion'] = True
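
This refactored variant delegates record building to an _add_hoc_result_to_explain helper that is not shown in the listing. Judging from the inline code in the earlier _run_hoc variant, a plausible sketch of that helper (not the verified implementation) would be:

def _add_hoc_result_to_explain(self, label_idx, str_mask, edit_tree,
                               layer_outputs, explain):
    """Append one HOC search result to the Explain record (sketch)."""
    hoc_rec = explain.hoc.add()
    hoc_rec.label = label_idx
    hoc_rec.mask = str_mask
    layer_count = edit_tree.max_layer + 1
    for layer in range(layer_count):
        steps = edit_tree.get_layer_or_leaf_steps(layer)
        hoc_layer = hoc_rec.layer.add()
        hoc_layer.prob = layer_outputs[layer]
        for step in steps:
            hoc_layer.box.extend(list(step.box))
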
Example #5
    def _add_exp_step_samples(self, explainer, sample_label_sets,
                              batch_saliency_full, summary):
        """
        Add explanation results of samples to summary record.

        Args:
            explainer (Attribution): The explainer to be run.
            sample_label_sets (list[list[int]]): The label sets of samples.
            batch_saliency_full (Tensor): The saliency output from explainer.
            summary (SummaryRecord): The summary record.
        """
        saliency_dict_lst = []
        has_saliency_rec = False
        for idx, label_set in enumerate(sample_label_sets):
            saliency_dict = {}
            explain = Explain()
            explain.sample_id = self._sample_index
            for k, lab in enumerate(label_set):
                saliency = batch_saliency_full[idx:idx + 1, k:k + 1]
                saliency_dict[lab] = saliency

                saliency_np = _normalize(saliency.asnumpy().squeeze())
                saliency_image = _np_to_image(saliency_np, mode='L')
                heatmap_path = self._save_heatmap(explainer.__class__.__name__,
                                                  lab, self._sample_index,
                                                  saliency_image)

                explanation = explain.explanation.add()
                explanation.explain_method = explainer.__class__.__name__
                explanation.heatmap_path = heatmap_path
                explanation.label = lab

                has_saliency_rec = True

            summary.add_value("explainer", "explanation", explain)
            summary.record(1)

            self._sample_index += 1
            saliency_dict_lst.append(saliency_dict)

        return saliency_dict_lst, has_saliency_rec
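
Each saliency slice above is normalized and written out as a grayscale heatmap via the _normalize and _np_to_image helpers, which are not shown in this listing. A minimal PIL-based stand-in for that conversion (the real helpers may differ):

import numpy as np
from PIL import Image

saliency_np = np.random.rand(32, 32).astype(np.float32)  # stand-in saliency map
lo, hi = saliency_np.min(), saliency_np.max()
normalized = (saliency_np - lo) / (hi - lo + 1e-10)      # rescale to [0, 1]
heatmap = Image.fromarray((normalized * 255).astype(np.uint8), mode='L')
heatmap.save('heatmap.png')
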
Example #6
    def _run_inference(self, summary, threshold=0.5):
        """
        Run inference for the dataset and write the inference-related data into summary.

        Args:
            summary (SummaryRecord): The summary object to store the data.
            threshold (float): The threshold for prediction.

        Returns:
            dict, The map of sample id to the union of its ground truth and predicted labels.
        """
        has_uncertainty_rec = False
        sample_id_labels = {}
        self._sample_index = 0
        ds.config.set_seed(self._DATASET_SEED)
        for j, next_element in enumerate(self._dataset):
            now = time()
            inputs, labels, _ = self._unpack_next_element(next_element)
            prob = self._full_network(inputs).asnumpy()

            if self._uncertainty is not None:
                prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs)
            else:
                prob_var = None

            for idx, inp in enumerate(inputs):
                gt_labels = labels[idx]
                gt_probs = [float(prob[idx][i]) for i in gt_labels]

                if prob_var is not None:
                    gt_prob_vars = [float(prob_var[idx][i]) for i in gt_labels]
                    gt_itl_lows, gt_itl_his, gt_prob_sds = \
                        self._calc_beta_intervals(gt_probs, gt_prob_vars)

                data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
                original_image = _np_to_image(_normalize(data_np), mode='RGB')
                original_image_path = self._save_original_image(self._sample_index, original_image)

                predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]]
                predicted_probs = [float(prob[idx][i]) for i in predicted_labels]

                if prob_var is not None:
                    predicted_prob_vars = [float(prob_var[idx][i]) for i in predicted_labels]
                    predicted_itl_lows, predicted_itl_his, predicted_prob_sds = \
                        self._calc_beta_intervals(predicted_probs, predicted_prob_vars)

                union_labs = list(set(gt_labels + predicted_labels))
                sample_id_labels[str(self._sample_index)] = union_labs

                explain = Explain()
                explain.sample_id = self._sample_index
                explain.image_path = original_image_path
                summary.add_value("explainer", "sample", explain)

                explain = Explain()
                explain.sample_id = self._sample_index
                explain.ground_truth_label.extend(gt_labels)
                explain.inference.ground_truth_prob.extend(gt_probs)
                explain.inference.predicted_label.extend(predicted_labels)
                explain.inference.predicted_prob.extend(predicted_probs)

                if prob_var is not None:
                    explain.inference.ground_truth_prob_sd.extend(gt_prob_sds)
                    explain.inference.ground_truth_prob_itl95_low.extend(gt_itl_lows)
                    explain.inference.ground_truth_prob_itl95_hi.extend(gt_itl_his)
                    explain.inference.predicted_prob_sd.extend(predicted_prob_sds)
                    explain.inference.predicted_prob_itl95_low.extend(predicted_itl_lows)
                    explain.inference.predicted_prob_itl95_hi.extend(predicted_itl_his)

                    has_uncertainty_rec = True

                summary.add_value("explainer", "inference", explain)
                summary.record(1)

                if self._is_hoc_registered:
                    self._run_hoc(summary, self._sample_index, inputs[idx], prob[idx])

                self._sample_index += 1
            self._spaced_print("Finish running and writing {}-th batch inference data."
                               " Time elapsed: {:.3f} s".format(j, time() - now))

        if has_uncertainty_rec:
            self._manifest["uncertainty"] = True

        return sample_id_labels
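
_calc_beta_intervals is not shown in this listing. Given that it takes per-label probability means and variances and returns 95% interval bounds plus standard deviations, a plausible sketch is to moment-match a Beta distribution to each (mean, variance) pair; this is an assumption, not the verified implementation, and degenerate variances are not handled:

import numpy as np
from scipy.stats import beta

def calc_beta_intervals(means, variances, coverage=0.95):
    """Moment-match Beta(a, b) to each (mean, variance) pair (sketch)."""
    tail = (1.0 - coverage) / 2.0
    lows, his, sds = [], [], []
    for m, v in zip(means, variances):
        sds.append(float(np.sqrt(v)))
        common = m * (1.0 - m) / v - 1.0  # method-of-moments factor
        a, b = m * common, (1.0 - m) * common
        lows.append(float(beta.ppf(tail, a, b)))
        his.append(float(beta.ppf(1.0 - tail, a, b)))
    return lows, his, sds
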
Example #7
    def _run_inference(self, dataset, summary, threshold=0.5):
        """
        Run inference for the dataset and write the inference-related data into summary.

        Args:
            dataset (Dataset): The parsed dataset.
            summary (SummaryRecord): The summary object to store the data.
            threshold (float): The threshold for prediction.

        Returns:
            dict, A dict that maps each image_id to the union of its ground truth and predicted labels.
        """
        spacer = '{:120}\r'
        imageid_labels = {}
        ds.config.set_seed(_SEED)
        self._count = 0
        for j, next_element in enumerate(dataset):
            now = time()
            inputs, labels, _ = self._unpack_next_element(next_element)
            prob = self._model(inputs).asnumpy()
            if self._uncertainty is not None:
                prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs)
                prob_sd = np.sqrt(prob_var)
            else:
                prob_var = prob_sd = None

            for idx, inp in enumerate(inputs):
                gt_labels = labels[idx]
                gt_probs = [float(prob[idx][i]) for i in gt_labels]

                data_np = _convert_image_format(
                    np.expand_dims(inp.asnumpy(), 0), 'NCHW')
                original_image = _np_to_image(_normalize(data_np), mode='RGB')
                original_image_path = self._save_original_image(
                    self._count, original_image)

                predicted_labels = [
                    int(i) for i in (prob[idx] > threshold).nonzero()[0]
                ]
                predicted_probs = [
                    float(prob[idx][i]) for i in predicted_labels
                ]

                has_uncertainty = False
                gt_prob_sds = gt_prob_itl95_lows = gt_prob_itl95_his = None
                predicted_prob_sds = predicted_prob_itl95_lows = predicted_prob_itl95_his = None
                if prob_var is not None:
                    gt_prob_sds = [float(prob_sd[idx][i]) for i in gt_labels]
                    predicted_prob_sds = [
                        float(prob_sd[idx][i]) for i in predicted_labels
                    ]
                    try:
                        gt_prob_itl95_lows, gt_prob_itl95_his = \
                            _calc_prob_interval(0.95, gt_probs, [float(prob_var[idx][i]) for i in gt_labels])
                        predicted_prob_itl95_lows, predicted_prob_itl95_his = \
                            _calc_prob_interval(0.95, predicted_probs, [float(prob_var[idx][i])
                                                                        for i in predicted_labels])
                        has_uncertainty = True
                    except ValueError:
                        log.error(traceback.format_exc())
                        log.error("Error on calculating uncertainty")

                union_labs = list(set(gt_labels + predicted_labels))
                imageid_labels[str(self._count)] = union_labs

                explain = Explain()
                explain.sample_id = self._count
                explain.image_path = original_image_path
                summary.add_value("explainer", "sample", explain)

                explain = Explain()
                explain.sample_id = self._count
                explain.ground_truth_label.extend(gt_labels)
                explain.inference.ground_truth_prob.extend(gt_probs)
                explain.inference.predicted_label.extend(predicted_labels)
                explain.inference.predicted_prob.extend(predicted_probs)

                if has_uncertainty:
                    explain.inference.ground_truth_prob_sd.extend(gt_prob_sds)
                    explain.inference.ground_truth_prob_itl95_low.extend(
                        gt_prob_itl95_lows)
                    explain.inference.ground_truth_prob_itl95_hi.extend(
                        gt_prob_itl95_his)

                    explain.inference.predicted_prob_sd.extend(
                        predicted_prob_sds)
                    explain.inference.predicted_prob_itl95_low.extend(
                        predicted_prob_itl95_lows)
                    explain.inference.predicted_prob_itl95_hi.extend(
                        predicted_prob_itl95_his)

                summary.add_value("explainer", "inference", explain)

                summary.record(1)

                self._count += 1
            print(spacer.format(
                "Finish running and writing {}-th batch inference data."
                " Time elapsed: {:.3f} s".format(j,
                                                 time() - now)),
                  end='')
        return imageid_labels
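
The spacer = '{:120}\r' pattern above pads each progress message to a fixed width and ends with a carriage return, so each batch message overwrites the previous one on the same console line. In isolation:

from time import sleep

spacer = '{:120}\r'
for j in range(3):
    print(spacer.format("Finish running and writing {}-th batch inference data.".format(j)), end='')
    sleep(0.1)
print()
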
Example #8
    def _run_sample(self, summary, next_element, sample_id_labels, threshold):
        """
        Run inference for a sample.

        Args:
            summary (SummaryRecord): The summary object to store the data.
            next_element (tuple): The next batch element from the dataset.
            sample_id_labels (dict): The sample id to labels dictionary.
            threshold (float): The threshold for prediction.
        """
        inputs, labels, _ = self._unpack_next_element(next_element)
        prob = self._full_network(inputs).asnumpy()

        if self._uncertainty is not None:
            prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs)
        else:
            prob_var = None

        for idx, inp in enumerate(inputs):
            gt_labels = labels[idx]
            gt_probs = [float(prob[idx][i]) for i in gt_labels]

            if prob_var is not None:
                gt_prob_vars = [float(prob_var[idx][i]) for i in gt_labels]
                gt_itl_lows, gt_itl_his, gt_prob_sds = \
                    self._calc_beta_intervals(gt_probs, gt_prob_vars)

            data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0),
                                            'NCHW')
            original_image = _np_to_image(_normalize(data_np), mode='RGB')
            original_image_path = self._save_original_image(
                self._sample_index, original_image)

            predicted_labels = [
                int(i) for i in (prob[idx] > threshold).nonzero()[0]
            ]
            predicted_probs = [float(prob[idx][i]) for i in predicted_labels]

            if prob_var is not None:
                predicted_prob_vars = [
                    float(prob_var[idx][i]) for i in predicted_labels
                ]
                predicted_itl_lows, predicted_itl_his, predicted_prob_sds = \
                    self._calc_beta_intervals(predicted_probs, predicted_prob_vars)

            union_labs = list(set(gt_labels + predicted_labels))
            sample_id_labels[str(self._sample_index)] = union_labs

            explain = Explain()
            explain.sample_id = self._sample_index
            explain.image_path = original_image_path
            summary.add_value("explainer", "sample", explain)

            explain = Explain()
            explain.sample_id = self._sample_index
            explain.ground_truth_label.extend(gt_labels)
            explain.inference.ground_truth_prob.extend(gt_probs)
            explain.inference.predicted_label.extend(predicted_labels)
            explain.inference.predicted_prob.extend(predicted_probs)

            if prob_var is not None:
                explain.inference.ground_truth_prob_sd.extend(gt_prob_sds)
                explain.inference.ground_truth_prob_itl95_low.extend(
                    gt_itl_lows)
                explain.inference.ground_truth_prob_itl95_hi.extend(gt_itl_his)
                explain.inference.predicted_prob_sd.extend(predicted_prob_sds)
                explain.inference.predicted_prob_itl95_low.extend(
                    predicted_itl_lows)
                explain.inference.predicted_prob_itl95_hi.extend(
                    predicted_itl_his)

                self._manifest["uncertainty"] = True

            summary.add_value("explainer", "inference", explain)
            summary.record(1)

            if self._is_hoc_registered:
                self._run_hoc(summary, self._sample_index, inputs[idx],
                              prob[idx])

            self._sample_index += 1