class BertScanModel(BaseModel):
    """
    """
    def __init__(self,
                 scan_encoder_class=None,
                 scan_encoder_args={},
                 bert_class=None,
                 bert_args={},
                 scan_decoder_class=None,
                 scan_decoder_args={},
                 task_configs=[],
                 vocab_args={},
                 loss_weighting=None,
                 optim_class="Adam",
                 optim_args={},
                 scheduler_class=None,
                 scheduler_args={},
                 pretrained_configs=[],
                 cuda=True,
                 devices=[0]):
        """
        """
        super().__init__(optim_class, optim_args, scheduler_class,
                         scheduler_args, pretrained_configs, cuda, devices)

        self.encodes_scans = scan_encoder_class is not None
        if self.encodes_scans:
            self.scan_encoder = getattr(
                modules, scan_encoder_class)(**scan_encoder_args)
            self.scan_encoder = nn.DataParallel(self.scan_encoder,
                                                device_ids=self.devices)

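        # select the BERT backbone: a pretrained BertModel/BertForPreTraining
        # loaded via from_pretrained, a BertModel built from a config dict, or
        # a module defined in this repo's `modules`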
        if bert_class == "BertModelPreTrained":
            self.bert = BertModel.from_pretrained(**bert_args)
        elif bert_class == "BertForPretraining":
            self.bert = BertForPreTraining.from_pretrained(**bert_args)
        elif bert_class == "BertModel":
            bert_args["config"] = BertConfig.from_dict(bert_args["config"])
            self.bert = BertModel(**bert_args)
        else:
            self.bert = getattr(modules, bert_class)(**bert_args)
        self.bert = nn.DataParallel(self.bert, device_ids=self.devices)

        self.decodes_scans = scan_decoder_class is not None
        if self.decodes_scans:
            self.scan_decoder = getattr(
                modules, scan_decoder_class)(**scan_decoder_args)

        self.task_heads = {}
        self.task_inputs = {}
        for task_head_config in task_configs:
            task = task_head_config["task"]
            head_class = getattr(modules, task_head_config["class"])
            args = task_head_config["args"]
            self.task_inputs[task] = task_head_config.get("inputs", "pool")

            if "config" in args:
                # bert task heads expect a config object, so convert the dict
                # into a lightweight namedtuple
                config = args["config"]
                args["config"] = namedtuple("Config",
                                            config.keys())(*config.values())

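            # the MLM head ties its output projection to BERT's input word
            # embedding matrix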
            if head_class is BertOnlyMLMHead:
                embs = self.bert.module.embeddings.word_embeddings.weight
                self.task_heads[task] = head_class(
                    bert_model_embedding_weights=embs, **args)
            else:
                self.task_heads[task] = head_class(**args)

        self.task_heads = torch.nn.ModuleDict(self.task_heads)

        self.vocab = WordPieceVocab(**vocab_args)

        self._build_loss(loss_weighting)

        self._post_init()

    def show_attention(self, inputs):
        """
        Must be initialized with BertConfig where config.output_attentions = True
        """
        # attentions are included in the outputs when the underlying BertConfig
        # was built with output_attentions=True
        report_inputs = inputs["report"]
        report_input_ids, attention_mask = self.vocab.to_input_tensor(
            report_inputs, device=self.device)

        outputs = self.bert(input_ids=report_input_ids,
                            attention_mask=attention_mask)

        return outputs

    def forward(self, inputs, targets):
        """
        Args:
            inputs  (torch.Tensor) a (batch_size, ...) shaped input tensor
            targets     (list) a list of task targets from t0, t1...tn. The last element
                        should be a list(list(str)) representing the target report.

        Return:
            outputs (list) of
        """
        if self.encodes_scans:
            scan_inputs = inputs["scan"]
            scan_encodings = self.scan_encoder(scan_inputs)
            # TODO: 3d positional encodings
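            # flatten all spatial dims into a sequence of scan positions:
            # (batch, channels, *spatial) -> (batch, n_positions, channels)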
            scan_encodings = scan_encodings.view(scan_encodings.shape[0],
                                                 scan_encodings.shape[1], -1)
            scan_encodings = scan_encodings.permute(0, 2, 1)

        report_inputs = inputs["report"]
        report_input_ids, attention_mask = self.vocab.to_input_tensor(
            report_inputs, device=self.device)

        bert_seq, pool = self.bert(input_ids=report_input_ids,
                                   attention_mask=attention_mask,
                                   output_all_encoded_layers=False)

        if self.decodes_scans:
            scan_seq, pool = self.scan_decoder(scan_encodings=scan_encodings,
                                               report_encodings=bert_seq,
                                               attention_mask=attention_mask,
                                               output_all_encoded_layers=False)
            # residual
            seq = bert_seq + scan_seq
        else:
            seq = bert_seq

        task_outputs = {}
        for task, task_head in self.task_heads.items():
            if self.task_inputs[task] == "seq":
                task_outputs[task] = task_head(seq)
            elif self.task_inputs[task] == "pool":
                task_outputs[task] = task_head(pool)
            else:
                raise ValueError("Input type not recognized.")

        return task_outputs

    def predict(self, inputs, probabilities=True):
        """
        """
        task_outputs = self.forward(inputs, None)
        task_predictions = {
            task: nn.functional.softmax(output, dim=-1)
            for task, output in task_outputs.items()
        }

        if not probabilities:
            task_predictions = {
                task: torch.argmax(output, dim=-1, keepdim=True)
                for task, output in task_predictions.items()
            }

        return task_predictions

    def loss(self, outputs, targets):
        """
        """
        total_loss = 0
        for task, output in outputs.items():
            curr_loss = self.loss_fn(output.view(-1, output.shape[-1]),
                                     targets[task].view(-1))
            total_loss += curr_loss

        return total_loss

    def _build_loss(self, loss_weighting=None):
        """
        """
        if loss_weighting is None:
            class_weights = None

        elif loss_weighting == "log":
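            # inverse log-frequency weighting: frequent tokens are down-weighted
            # by 1 / log(freq + 1); zero-frequency tokens keep weight 1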
            class_weights = torch.ones(len(self.vocab))
            for token, freq in self.vocab.token_to_freq.items():
                if freq == 0:
                    continue
                class_weights[
                    self.vocab.token_to_idx[token]] = 1 / np.log(freq + 1)
        else:
            raise ValueError("Loss weighting scheme not recognized.")

        self.loss_fn = nn.CrossEntropyLoss(weight=class_weights,
                                           ignore_index=-1)

    def _log_predictions(self, inputs, targets, predictions, info=None):
        """
        """
        predictions = {
            task: torch.argmax(output, dim=-1)
            for task, output in predictions.items()
        }

        if "scan_mlm" in predictions:
            logging.info(f"Exam ID: {info[0]['exam_id']}")
            mlm_preds, mlm_targets = predictions["scan_mlm"], targets[
                "scan_mlm"]
            mlm_preds = mlm_preds[mlm_targets != -1]
            mlm_targets = mlm_targets[mlm_targets != -1]

            # handle batch size of 1
            if len(mlm_targets.shape) < 2:
                mlm_preds = mlm_preds.unsqueeze(0)
                mlm_targets = mlm_targets.unsqueeze(0)

            if mlm_preds.shape[-1] != 0:
                mlm_preds = self.vocab.from_output_tensor(mlm_preds)
                mlm_targets = self.vocab.from_output_tensor(mlm_targets)

                for curr_inputs, target, pred in zip(inputs["report"],
                                                     mlm_targets, mlm_preds):
                    logging.info(f"Inputs: {curr_inputs}")
                    logging.info(f"Targets: {' --- '.join(target)}")
                    logging.info(f"Predict: {' --- '.join(pred)}")

        if "scan_match" in predictions:
            logging.info(f"Inputs: {inputs['report'][0]}")
            logging.info(f"Matched:   {targets['scan_match']}")
            logging.info(f"Predicted: {predictions['scan_match']}")
class MatchDataset:
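    """
    Token-level classification dataset built from term-match annotations
    (match_labels.csv or the per-split csvs): each item is a matched report
    snippet with per-token fdg-abnormality labels.
    """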
    def __init__(self,
                 dataset_dir,
                 split=None,
                 task_configs={},
                 vocab_args={},
                 max_len=200,
                 cls_label=True):
        """
        """
        self.dataset_dir = dataset_dir
        self.split = split

        if split is None or split == "all":
            matches_path = os.path.join(dataset_dir, 'match_labels.csv')
        else:
            matches_path = os.path.join(dataset_dir, 'split', f'{split}.csv')
        self.matches_df = pd.read_csv(matches_path)
        self.matches_df = self.matches_df[self.matches_df["not_applicable"] ==
                                          False]

        self.vocab = WordPieceVocab(**vocab_args)

        self.max_len = max_len
        self.cls_label = cls_label

    def __len__(self):
        """
        """
        return len(self.matches_df)

    def __getitem__(self, idx):
        """
        """
        match = self.matches_df.iloc[idx]

        text = match["text"]
        tokens = self.vocab.tokenize(text)

        if self.max_len is not None and self.max_len < len(tokens):
            tokens = tokens[:self.max_len]

        tokens = self.vocab.wrap_sentence(tokens)

        # get numeric label (anything other than "abnormal" is treated as normal)
        label = 1 if match["fdg_abnormality_label"] == "abnormal" else 0

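        # -1 marks unlabeled positions, which the model's loss ignores
        # (CrossEntropyLoss with ignore_index=-1)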
        labels = -1 * torch.ones(len(tokens), dtype=torch.long)
        if self.cls_label:
            labels[0] = label

        # label tokens
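        # align each wordpiece with its character span in the raw text; tokens
        # overlapping the match span [start, end) inherit the match label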
        match_start, match_end = match["start"], match["end"]
        find_start = 0
        for token_idx, token in enumerate(tokens):
            if token == "[CLS]" or token == "[SEP]":
                continue

            if token.startswith("##"):
                # strip the wordpiece continuation marker
                token = token[2:]

            token_start = text.find(token, find_start)
            token_end = token_start + len(token)
            find_start = token_end

            if ((token_start >= match_start and token_start < match_end)
                    or (token_end >= match_start and token_end < match_end)):
                labels[token_idx] = label
        inputs = {"report": tokens}
        targets = {"fdg_abnorm": labels}
        info = {"exam_id": match["exam_id"], "term_name": match["term_name"]}

        return inputs, targets, info
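
# Minimal usage sketch (the directory layout and vocab_args are assumptions,
# not defined by this module):
#
#   dataset = MatchDataset("data/match_dataset", split="train",
#                          vocab_args={...})  # forwarded to WordPieceVocab
#   inputs, targets, info = dataset[0]
#   inputs["report"]       -> wordpiece tokens wrapped with [CLS]/[SEP]
#   targets["fdg_abnorm"]  -> LongTensor of per-token labels (-1 = ignored)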
class BertPretrainingDataset(ReportDataset):
    """
    """
    def __init__(self,
                 dataset_dir,
                 vocab_args={},
                 transform_configs=[],
                 pct_tokens=0.15,
                 pct_random=0.10,
                 pct_same=0.10,
                 max_len=None,
                 split=None):
        """
        WARNING: The vocab file used here MUST match the vocab file used in your model.
        """
        ReportDataset.__init__(self,
                               dataset_dir,
                               split=split,
                               transform_configs=transform_configs)
        self.vocab = WordPieceVocab(**vocab_args)

        # parameters for mlm masking
        self.pct_tokens = pct_tokens
        self.pct_random = pct_random
        self.pct_same = pct_same

        self.max_len = max_len

    def get_targets(self, tasks=[]):
        """
        """
        dataset = H5Dataset(self.dataset_name, self.data_dir, mode="read")
        for idx, exam in self.exams_df.iterrows():
            yield self._get_targets(exam, dataset, tasks=tasks)

    def _get_targets(self, exam, dataset=None, tasks=[]):
        """
        """
        targets = {}
        if "abnorm" in tasks:
            targets["abnorm"] = torch.tensor(exam['label'])

        return targets

    def __len__(self):
        """
        """
        return len(self.exams_df.index)

    def __getitem__(self, idx):
        """
        """
        dataset = H5Dataset(self.dataset_name, self.data_dir, mode="read")
        exam = self.exams_df.iloc[int(idx)]

        label = exam['label']
        exam_id = exam['exam_id']
        patient_id = exam['patient_id']

        report = self._get_report(exam, dataset)

        report = self.vocab.tokenize(report)

        if self.max_len is not None and self.max_len < len(report):
            report = report[:self.max_len]

        report = self.vocab.wrap_sentence(report)

        report, mask_labels = self._mask_inputs(report)

        inputs = {"report": report}
        targets = {"mlm": mask_labels, "abnorm": torch.tensor(label)}
        info = {"exam_id": exam_id, "patient_id": patient_id}

        return inputs, targets, info

    def _mask_inputs(self, report):
        """
        """
        # don't sample [CLS] or [SEP] tokens TODO: make this work for middle of sequence [SEP]
        n_tokens = max(
            0,
            np.ceil((len(report) - 2) * self.pct_tokens).astype(int))
        token_idxs = np.random.choice(np.arange(1,
                                                len(report) - 1),
                                      size=n_tokens,
                                      replace=False)

        labels = -1 * torch.ones(len(report), dtype=torch.long)
        masked_report = report.copy()

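        # for each selected position: with prob pct_random swap in a random
        # vocab token, with prob pct_same leave it unchanged, otherwise replace
        # it with [MASK]; the target is always the original token id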
        for token_idx in token_idxs:
            sample = np.random.rand()
            if sample < self.pct_random:
                # replace the input token with a random vocab token; the target
                # stays the original token (standard BERT masking)
                masked_report[token_idx] = str(
                    np.random.choice(list(self.vocab.token_to_idx.keys())))
                labels[token_idx] = self.vocab.token_to_idx[report[token_idx]]

            elif sample < self.pct_random + self.pct_same:
                labels[token_idx] = self.vocab.token_to_idx[report[token_idx]]

            else:
                masked_report[token_idx] = "[MASK]"
                labels[token_idx] = self.vocab.token_to_idx[report[token_idx]]

        return masked_report, labels
class BertScanDataset(ImageDataset, ReportDataset):
    """
    """
    def __init__(self,
                 dataset_dir,
                 vocab_args,
                 term_graph_args,
                 image_types=None,
                 size=None,
                 normalize=True,
                 image_transform_configs=[],
                 sampling_window=None,
                 sampling_rate=1,
                 report_transform_configs=[],
                 task_configs=[],
                 max_len=None,
                 skip_scans=False,
                 split=None):
        """
        """
        self.skip_scans = skip_scans
        if not skip_scans:
            ImageDataset.__init__(self,
                                  dataset_dir,
                                  image_types=image_types,
                                  size=size,
                                  normalize=normalize,
                                  sampling_window=sampling_window,
                                  transform_configs=image_transform_configs,
                                  sampling_rate=sampling_rate,
                                  split=split)
        ReportDataset.__init__(self,
                               dataset_dir,
                               split=split,
                               transform_configs=report_transform_configs)

        self.task_configs = {
            task_config["fn"]: task_config
            for task_config in task_configs
        }

        self.max_len = max_len
        self.vocab = WordPieceVocab(**vocab_args)
        self.term_graph = TermGraph(**term_graph_args)

    def __len__(self):
        """
        """
        return len(self.exams_df.index)

    def __getitem__(self, idx):
        """
        """
        dataset = H5Dataset(self.dataset_name, self.data_dir, mode="read")
        exam = self.exams_df.iloc[int(idx)]
        label, exam_id, patient_id = exam['label'], exam['exam_id'], exam[
            'patient_id']

        report = self._get_report(exam, dataset)
        images = self._get_images(exam,
                                  dataset) if not self.skip_scans else None
        targets = {}

        # must perform scan_match first
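        # (the report swap has to happen before tokenization and masking so the
        # match label and the masked tokens refer to the same text)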
        if "scan_match" in self.task_configs:
            args = self.task_configs["scan_match"]["args"]
            report, labels = self.scan_match(exam_id, report, dataset, **args)
            targets["scan_match"] = labels

        # tokenize, trim if over max length, and wrap sentence
        report = self.vocab.tokenize(report)
        if self.max_len is not None and self.max_len < len(report):
            report = report[:self.max_len]
        report = self.vocab.wrap_sentence(report)

        if "scan_mlm" in self.task_configs:
            args = self.task_configs["scan_mlm"]["args"]
            report, labels = self.scan_mlm(exam_id, report, **args)
            targets["scan_mlm"] = labels

        info = {"exam_id": exam_id, "patient_id": patient_id}

        inputs = {"report": report, "scan": images}

        return inputs, targets, info

    def scan_match(self, exam_id, report, dataset, pct_same):
        """
        """
        sample = torch.rand(1).item()
        label = torch.tensor(1)
        if sample >= pct_same:
            self.mismatched = True
            label = torch.tensor(0)
            random_idx = np.random.randint(len(self.exams_df))
            random_exam = self.exams_df.iloc[random_idx]
            report = self._get_report(random_exam, dataset)

        return report, label

    def scan_mlm(self,
                 exam_id,
                 report,
                 rand_default_mask_prob=0.025,
                 term_default_mask_prob=0.85,
                 term_to_mask_prob={}):
        """
        """
        # TODO:
        matches = self.term_graph.get_matches(report)
        #
        masked_report = report.copy()
        labels = -1 * torch.ones(len(report), dtype=torch.long)

        # mask tokens
        self.term_graph.bernoulli_sample(term_to_prob=term_to_mask_prob,
                                         default_prob=term_default_mask_prob,
                                         sample_name="mask_sample")

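        # a token is masked if any term it matches drew a positive mask_sample;
        # tokens with no term match are masked with the small background prob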
        for token_idx, token_matches in enumerate(matches):
            token = report[token_idx]

            masked = False
            if len(token_matches) == 0:
                masked = bool(
                    torch.bernoulli(torch.tensor(rand_default_mask_prob)))
            else:
                for match in token_matches:
                    if self.term_graph[match]["mask_sample"]:
                        masked = True
                        break

            if masked:
                masked_report[token_idx] = "[MASK]"
                labels[token_idx] = self.vocab.token_to_idx[token]
        return masked_report, labels
class ExamLabelsPredictor(Process):
    """
    """
    def __init__(self,
                 dir,
                 model_dir,
                 dataset_class="ReportDataset",
                 dataset_args={},
                 term_graph_dir="data/pet_ct_terms/terms.json",
                 terms="all",
                 match_task="fdg_abnorm",
                 split_fn="split_impression_sections",
                 max_len=200,
                 vocab_args={},
                 seed=123,
                 cuda=True,
                 devices=[0]):
        """
        """
        super().__init__(dir)
        self.cuda = cuda
        self.devices = devices
        self.device = devices[0]

        self.split_fn = globals()[split_fn]

        dataset = getattr(datasets, dataset_class)(**dataset_args)
        self.dataloader = DataLoader(dataset, batch_size=1)

        logging.info("Loading TermGraph and Vocab...")
        self.match_task = match_task
        self.term_graph = TermGraph(term_graph_dir)
        if terms == "all":
            self.terms = self.term_graph.term_names
        else:
            self.terms = terms

        self.vocab = WordPieceVocab(**vocab_args)
        self.max_len = max_len

        logging.info("Loading Model...")
        self._load_model(model_dir)

    def _load_model(self, model_dir):
        """
        """
        with open(os.path.join(model_dir, "params.json")) as f:
            args = json.load(f)["process_args"]
            model_class = args["model_class"]
            model_args = args["model_args"]
            if "task_configs" in args:
                new_task_configs = []
                for task_config in args["task_configs"]:
                    new_task_config = args["default_task_config"].copy()
                    new_task_config.update(task_config)
                    new_task_configs.append(new_task_config)
            task_configs = new_task_configs

            model_args["task_configs"] = task_configs

        model_class = getattr(models, model_class)
        self.model = model_class(cuda=self.cuda,
                                 devices=self.devices,
                                 **model_args)

        model_dir = os.path.join(model_dir, "best")
        model_path = os.path.join(model_dir, "weights.pth.tar")
        if not os.path.isfile(model_path):
            model_path = os.path.join(model_dir, "weights.link")

        self.model.load_weights(model_path, device=self.device)

    def label_exam(self, label, report, info):
        """
        """
        report_sections = self.split_fn(report[0].lower())
        term_to_outputs = defaultdict(list)

        #logging.info(f"exam_id: {info['exam_id']}")
        for report_section in report_sections:
            curr_matches = self.term_graph.match_string(report_section)
            if not curr_matches:
                # skip report sections without matches
                continue

            tokens = self.vocab.tokenize(report_section)

            if len(tokens) > self.max_len:
                tokens = tokens[:self.max_len]

            tokens = self.vocab.wrap_sentence(tokens)
            inputs = {"report": [tokens]}
            output = self.model.predict(inputs)[self.match_task]
            output = output.cpu().detach().numpy().squeeze()

            #logging.info(f"section:{report_section}")
            for match in curr_matches:
                match_idxs = self.vocab.get_tokens_in_range(
                    tokens, report_section, match["start"], match["end"])

                match["output"] = output[match_idxs, 1]
                term = match["term_name"]
                term_to_outputs[term].append(np.mean(match["output"]))
                #logging.info(f"term: {match['term_name']} - {match['output']}")
            #logging.info("-"*5)

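        # aggregate with a noisy-OR over the term and all of its descendants:
        # P(term) = 1 - prod_i (1 - p_i)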
        labels = {}
        for term in self.terms:
            all_outputs = term_to_outputs[term][:]
            for descendant in self.term_graph.get_descendants(term):
                all_outputs.extend(term_to_outputs[descendant])
            all_outputs = np.array(all_outputs)
            prob = 1 - np.prod(1 - all_outputs)

            labels[(term, 0)] = 1 - prob
            labels[(term, 1)] = prob

            #logging.info(f"term: {term}")
            #logging.info(f"all_outputs: {all_outputs}")
            #logging.info(f"prob: {prob}")

        #logging.info("="*30 + "\n")
        return labels

    def _run(self, overwrite=False):
        """
        """
        exam_id_to_labels = {}
        for idx, (label, report, info) in enumerate(tqdm(self.dataloader)):
            labels = self.label_exam(label, report, info)
            exam_id_to_labels[info["exam_id"][0]] = labels

        labels_df = pd.DataFrame.from_dict(exam_id_to_labels, orient="index")
        labels_df.to_csv(os.path.join(self.dir, "exam_labels.csv"))