def _run_inference(self, dataset, summary, threshod=0.5, *, threshold=None):
    """
    Run inference for the dataset and write the inference related data into summary.

    For every image two values are added to ``summary`` under the "explainer" tag:
    an "image" entry carrying the encoded image, and an "inference" entry carrying
    the ground-truth labels/probabilities and the predicted labels/probabilities.
    Images get consecutive string ids "0", "1", ... taken from ``self._count``,
    which is reset to 0 at the start of this call and left at the total number of
    images processed.

    Args:
        dataset (`ds`): the parsed dataset; each element is unpacked with
            ``self._unpack_next_element`` into (inputs, labels, _).
        summary (`SummaryRecord`): the summary object to store the data.
        threshod (float): a label counts as predicted when its probability is
            strictly greater than this value. Default: 0.5. The misspelled name is
            kept for backward compatibility with existing callers.
        threshold (float, optional): keyword-only, correctly spelled alias of
            ``threshod``; when given it takes precedence. Default: None.

    Returns:
        imageid_labels (dict): a dict that maps image_id (str) to the sorted union
            of its ground truth and predicted labels.
    """
    # Accept the documented spelling while keeping the historical one working.
    thresh = threshod if threshold is None else threshold

    imageid_labels = {}
    # NOTE(review): fixed global dataset seed, presumably so a later pass over the
    # same dataset sees images in the same order as the ids assigned here; confirm.
    ds.config.set_seed(58)
    self._count = 0
    for j, next_element in enumerate(dataset):
        now = time()
        inputs, labels, _ = self._unpack_next_element(next_element)
        prob = self._model(inputs).asnumpy()
        for idx, inp in enumerate(inputs):
            image_id = str(self._count)

            gt_labels = labels[idx]
            gt_probs = [float(prob[idx][i]) for i in gt_labels]

            data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
            _, _, _, image_string = _make_image(_normalize(data_np))

            predicted_labels = [int(i) for i in (prob[idx] > thresh).nonzero()[0]]
            predicted_probs = [float(prob[idx][i]) for i in predicted_labels]

            # Set union rather than list concatenation: ``+`` would add element-wise
            # if labels arrive as an ndarray. Sorting gives a deterministic label
            # order, which the explanation step later reads back from this dict.
            union_labs = sorted(set(gt_labels) | set(predicted_labels))
            imageid_labels[image_id] = union_labs

            # First summary value: the image itself.
            explain = Explain()
            explain.image_id = image_id
            explain.image_data = image_string
            summary.add_value("explainer", "image", explain)

            # Second summary value: ground truth and prediction for the same image.
            explain = Explain()
            explain.image_id = image_id
            explain.ground_truth_label.extend(gt_labels)
            explain.inference.ground_truth_prob.extend(gt_probs)
            explain.inference.predicted_label.extend(predicted_labels)
            explain.inference.predicted_prob.extend(predicted_probs)
            summary.add_value("explainer", "inference", explain)

            summary.record(1)
            self._count += 1
        print(
            "Finish running and writing {}-th batch inference data. Time elapsed: {}s"
            .format(j, time() - now))
    return imageid_labels
def _run_exp_step(self, next_element, explainer, imageid_labels, summary):
    """
    Explain one batch and write the resulting saliency maps into summary.

    The images of the batch are assumed to carry the consecutive ids starting at
    ``self._count``; their label sets are looked up in ``imageid_labels``. The
    explainer is invoked once per column of the batched label matrix returned by
    ``self._make_label_batch``. Then one "explanation" summary value is written per
    image and ``self._count`` is advanced by one per image.

    Args:
        next_element (Tuple): data of one step.
        explainer (_Attribution): an Attribution object to generate saliency maps.
        imageid_labels (dict): maps an image_id (str) to its union labels.
        summary (SummaryRecord): the summary object to store the data.

    Returns:
        list[dict]: one dict per image, mapping each union label to its saliency map.
    """
    inputs, labels, _ = self._unpack_next_element(next_element)

    # Ids of this batch are first_id, first_id + 1, ... in batch order.
    first_id = self._count
    unions = [imageid_labels[str(first_id + offset)] for offset in range(len(labels))]
    batch_unions = self._make_label_batch(unions)

    # One explainer call per label column; column k holds each image's k-th label.
    num_columns = len(batch_unions[0])
    column_saliencies = [explainer(inputs, batch_unions[:, col]) for col in range(num_columns)]

    method_name = explainer.__class__.__name__
    results = []
    for img_pos, union in enumerate(unions):
        label_to_saliency = {}
        explain = Explain()
        explain.image_id = str(self._count)
        for col, lab in enumerate(union):
            # Keep the batch dimension (slice, not index) for this image's map.
            saliency = column_saliencies[col][img_pos:img_pos + 1]
            label_to_saliency[lab] = saliency
            rgba = _make_rgba(saliency)
            _, _, _, heatmap_string = _make_image(_normalize(rgba))
            entry = explain.explanation.add()
            entry.explain_method = method_name
            entry.label = lab
            entry.heatmap = heatmap_string
        summary.add_value("explainer", "explanation", explain)
        summary.record(1)
        self._count += 1
        results.append(label_to_saliency)
    return results