Example No. 1
0
    def _run_inference(self, dataset, summary, threshold=0.5):
        """
        Run inference for the dataset and write the inference related data into summary.

        Args:
            dataset (`ds`): the parsed dataset to run inference on.
            summary (`SummaryRecord`): the summary object to store the data.
            threshold (float): probability threshold above which a class is
                reported as a predicted label. Default: 0.5.

        Returns:
            imageid_labels (dict): a dict that maps each image_id (str) to the
                union of its ground truth and predicted labels.
        """
        imageid_labels = {}
        # Fix the dataset seed so a later pass over the same dataset (e.g. the
        # explanation pass) visits the samples in the same order.
        ds.config.set_seed(58)
        self._count = 0
        for batch_idx, next_element in enumerate(dataset):
            start = time()
            inputs, labels, _ = self._unpack_next_element(next_element)
            prob = self._model(inputs).asnumpy()
            for idx, inp in enumerate(inputs):
                gt_labels = labels[idx]
                gt_probs = [float(prob[idx][i]) for i in gt_labels]

                # Encode the input image into a serialized image string for
                # the summary record.
                data_np = _convert_image_format(
                    np.expand_dims(inp.asnumpy(), 0), 'NCHW')
                _, _, _, image_string = _make_image(_normalize(data_np))

                # A label counts as "predicted" when its probability exceeds
                # the threshold.
                predicted_labels = [
                    int(i) for i in (prob[idx] > threshold).nonzero()[0]
                ]
                predicted_probs = [
                    float(prob[idx][i]) for i in predicted_labels
                ]

                union_labs = list(set(gt_labels + predicted_labels))
                imageid_labels[str(self._count)] = union_labs

                # First record: the raw image data.
                explain = Explain()
                explain.image_id = str(self._count)
                explain.image_data = image_string
                summary.add_value("explainer", "image", explain)

                # Second record: the inference results for the same image_id.
                explain = Explain()
                explain.image_id = str(self._count)
                explain.ground_truth_label.extend(gt_labels)
                explain.inference.ground_truth_prob.extend(gt_probs)
                explain.inference.predicted_label.extend(predicted_labels)
                explain.inference.predicted_prob.extend(predicted_probs)
                summary.add_value("explainer", "inference", explain)

                summary.record(1)

                self._count += 1
            print(
                "Finish running and writing {}-th batch inference data. Time elapsed: {}s"
                .format(batch_idx,
                        time() - start))
        return imageid_labels
Example No. 2
0
    def _run_exp_step(self, next_element, explainer, imageid_labels, summary):
        """
        Run the explanation for one step and write the results into summary.

        Args:
            next_element (Tuple): data of one step.
            explainer (_Attribution): an Attribution object used to generate
                saliency maps.
            imageid_labels (dict): a dict that maps each image_id to its
                union-label list.
            summary (SummaryRecord): the summary object to store the data.

        Returns:
            List of dict, each mapping a label to its saliency map.
        """
        inputs, labels, _ = self._unpack_next_element(next_element)

        # Gather the union-label list for every sample of this batch, in
        # image-id order starting at the current running counter.
        first_id = self._count
        unions = [imageid_labels[str(first_id + offset)]
                  for offset in range(len(labels))]

        batch_unions = self._make_label_batch(unions)

        # Run the explainer once per label column over the whole batch.
        batch_saliency_full = [explainer(inputs, batch_unions[:, col])
                               for col in range(len(batch_unions[0]))]

        saliency_dict_lst = []
        for sample_idx, union in enumerate(unions):
            saliency_dict = {}
            explain = Explain()
            explain.image_id = str(self._count)
            for col, lab in enumerate(union):
                # Keep the batch dimension by slicing instead of indexing.
                saliency = batch_saliency_full[col][sample_idx:sample_idx + 1]
                saliency_dict[lab] = saliency

                saliency_np = _make_rgba(saliency)
                _, _, _, saliency_string = _make_image(_normalize(saliency_np))

                explanation = explain.explanation.add()
                explanation.explain_method = explainer.__class__.__name__
                explanation.label = lab
                explanation.heatmap = saliency_string

            summary.add_value("explainer", "explanation", explain)
            summary.record(1)

            self._count += 1
            saliency_dict_lst.append(saliency_dict)
        return saliency_dict_lst