def __init__(self,
             dataset_dir,
             vocab_args,
             term_graph_args,
             image_types=None,
             size=None,
             normalize=True,
             image_transform_configs=[],
             sampling_window=None,
             sampling_rate=1,
             report_transform_configs=[],
             task_configs=[],
             max_len=None,
             skip_scans=False,
             split=None):
    """
    Initialize the report side of the dataset and, unless skipped, the
    image side.

    Args:
        dataset_dir: forwarded to ImageDataset.__init__ and
            ReportDataset.__init__.
        vocab_args (dict): keyword arguments for WordPieceVocab.
        term_graph_args (dict): keyword arguments for TermGraph.
        image_types, size, normalize, sampling_window, sampling_rate:
            forwarded unchanged to ImageDataset.__init__.
        image_transform_configs: passed to ImageDataset.__init__ as
            `transform_configs`.
        report_transform_configs: passed to ReportDataset.__init__ as
            `transform_configs`.
        task_configs (list): task config dicts; each one is stored in
            self.task_configs under its "fn" entry.
        max_len: stored as self.max_len.
        skip_scans (bool): when True, ImageDataset.__init__ is not called.
        split: forwarded to both base initializers.
    """
    self.skip_scans = skip_scans

    # the image base class is only set up when scans are actually wanted
    if not skip_scans:
        image_kwargs = dict(image_types=image_types,
                            size=size,
                            normalize=normalize,
                            sampling_window=sampling_window,
                            transform_configs=image_transform_configs,
                            sampling_rate=sampling_rate,
                            split=split)
        ImageDataset.__init__(self, dataset_dir, **image_kwargs)

    ReportDataset.__init__(self,
                           dataset_dir,
                           split=split,
                           transform_configs=report_transform_configs)

    # index the task configs by the name of the function they configure
    configs_by_fn = {}
    for config in task_configs:
        configs_by_fn[config["fn"]] = config
    self.task_configs = configs_by_fn

    self.max_len = max_len
    self.vocab = WordPieceVocab(**vocab_args)
    self.term_graph = TermGraph(**term_graph_args)
def __init__(self,
             dataset_dir,
             vocab_args={},
             transform_configs=[],
             pct_tokens=0.15,
             pct_random=0.10,
             pct_same=0.10,
             max_len=None,
             split=None):
    """
    Initialize the underlying ReportDataset and store the masking settings.

    WARNING: The vocab file used here MUST match the vocab file used in your
    model.

    Args:
        dataset_dir: forwarded to ReportDataset.__init__.
        vocab_args (dict): keyword arguments for WordPieceVocab.
        transform_configs: forwarded to ReportDataset.__init__.
        pct_tokens (float): mlm masking parameter, stored as self.pct_tokens.
        pct_random (float): mlm masking parameter, stored as self.pct_random.
        pct_same (float): mlm masking parameter, stored as self.pct_same.
        max_len: stored as self.max_len.
        split: forwarded to ReportDataset.__init__.
    """
    # Pass by keyword so the call cannot silently swap `transform_configs`
    # and `split` if the base signature orders them differently.
    ReportDataset.__init__(self,
                           dataset_dir,
                           transform_configs=transform_configs,
                           split=split)

    self.vocab = WordPieceVocab(**vocab_args)

    # parameters for mlm masking
    self.pct_tokens = pct_tokens
    self.pct_random = pct_random
    self.pct_same = pct_same

    self.max_len = max_len
def __init__(self,
             dir,
             model_dir,
             dataset_class="ReportDataset",
             dataset_args={},
             term_graph_dir="data/pet_ct_terms/terms.json",
             terms="all",
             match_task="fdg_abnorm",
             split_fn="split_impression_sections",
             max_len=200,
             vocab_args={},
             seed=123,
             cuda=True,
             devices=[0]):
    """
    Build the dataloader, term graph, vocab and model used for labeling.

    Args:
        dir: passed to the parent initializer.
        model_dir: directory handed to self._load_model.
        dataset_class (str): name of a class looked up on `datasets`.
        dataset_args (dict): keyword arguments for that dataset class.
        term_graph_dir (str): argument passed to TermGraph.
        terms: "all" to use every name in the term graph's `term_names`,
            otherwise the value is stored as-is in self.terms.
        match_task (str): stored as self.match_task.
        split_fn (str): name of a module-level function, resolved through
            globals().
        max_len (int): stored as self.max_len.
        vocab_args (dict): keyword arguments for WordPieceVocab.
        seed: not used by this initializer.
        cuda (bool): stored as self.cuda.
        devices (list): stored as self.devices; its first entry becomes
            self.device.
    """
    super().__init__(dir)

    self.cuda = cuda
    self.devices = devices
    self.device = devices[0]

    # the section splitter is referenced by name and resolved at runtime
    self.split_fn = globals()[split_fn]

    dataset_cls = getattr(datasets, dataset_class)
    self.dataloader = DataLoader(dataset_cls(**dataset_args), batch_size=1)

    logging.info("Loading TermGraph and Vocab...")
    self.match_task = match_task
    self.term_graph = TermGraph(term_graph_dir)
    self.terms = self.term_graph.term_names if terms == "all" else terms
    self.vocab = WordPieceVocab(**vocab_args)
    self.max_len = max_len

    logging.info("Loading Model...")
    self._load_model(model_dir)
def __init__(self,
             dataset_dir,
             split=None,
             task_configs={},
             vocab_args={},
             max_len=200,
             cls_label=True):
    """
    Load the match table for the requested split and build the vocab.

    Args:
        dataset_dir (str): directory containing `match_labels.csv` and a
            `split` sub-directory with one `<split>.csv` per split.
        split: None or "all" reads `match_labels.csv`; any other value
            reads `split/<split>.csv`.
        task_configs: not used by this initializer.
        vocab_args (dict): keyword arguments for WordPieceVocab.
        max_len: stored as self.max_len.
        cls_label (bool): stored as self.cls_label.
    """
    self.dataset_dir = dataset_dir
    self.split = split

    if split in (None, "all"):
        path_parts = ('match_labels.csv',)
    else:
        path_parts = ('split', f'{split}.csv')
    matches_path = os.path.join(dataset_dir, *path_parts)

    # keep only the rows whose "not_applicable" column equals False
    all_matches = pd.read_csv(matches_path)
    is_applicable = all_matches["not_applicable"] == False  # noqa: E712
    self.matches_df = all_matches[is_applicable]

    self.vocab = WordPieceVocab(**vocab_args)
    self.max_len = max_len
    self.cls_label = cls_label
def __init__(self,
             scan_encoder_class=None,
             scan_encoder_args={},
             bert_class=None,
             bert_args={},
             scan_decoder_class=None,
             scan_decoder_args={},
             task_configs=[],
             vocab_args={},
             loss_weighting=None,
             optim_class="Adam",
             optim_args={},
             scheduler_class=None,
             scheduler_args={},
             pretrained_configs=[],
             cuda=True,
             devices=[0]):
    """
    Assemble the optional scan encoder, the BERT module, the optional scan
    decoder and one head per task.

    The argument dicts supplied by the caller are never modified: entries
    that must be converted (a dict "config") are converted on a copy, so the
    same configuration can be used to build more than one model.

    Args:
        scan_encoder_class (str): name of a class in `modules`, or None for
            no scan encoder.
        scan_encoder_args (dict): keyword arguments for the scan encoder.
        bert_class (str): "BertModelPreTrained", "BertForPretraining",
            "BertModel", or the name of a class in `modules`.
        bert_args (dict): keyword arguments for the selected BERT class. For
            "BertModel" its "config" entry is a dict that is turned into a
            BertConfig.
        scan_decoder_class (str): name of a class in `modules`, or None for
            no scan decoder.
        scan_decoder_args (dict): keyword arguments for the scan decoder.
        task_configs (list): dicts with "task", "class", "args" and an
            optional "inputs" entry ("pool" when absent).
        vocab_args (dict): keyword arguments for WordPieceVocab.
        loss_weighting: forwarded to self._build_loss.
        optim_class, optim_args, scheduler_class, scheduler_args,
        pretrained_configs, cuda, devices: forwarded positionally to the
            parent initializer.
    """
    super().__init__(optim_class, optim_args, scheduler_class,
                     scheduler_args, pretrained_configs, cuda, devices)

    self.encodes_scans = scan_encoder_class is not None
    if self.encodes_scans:
        scan_encoder = getattr(modules,
                               scan_encoder_class)(**scan_encoder_args)
        self.scan_encoder = nn.DataParallel(scan_encoder,
                                            device_ids=self.devices)

    if bert_class == "BertModelPreTrained":
        bert = BertModel.from_pretrained(**bert_args)
    elif bert_class == "BertForPretraining":
        bert = BertForPreTraining.from_pretrained(**bert_args)
    elif bert_class == "BertModel":
        # convert the config on a copy so the caller's dict keeps its
        # plain-dict "config" entry
        bert_kwargs = dict(bert_args)
        bert_kwargs["config"] = BertConfig.from_dict(bert_kwargs["config"])
        bert = BertModel(**bert_kwargs)
    else:
        bert = getattr(modules, bert_class)(**bert_args)
    self.bert = nn.DataParallel(bert, device_ids=self.devices)

    self.decodes_scans = scan_decoder_class is not None
    if self.decodes_scans:
        self.scan_decoder = getattr(modules,
                                    scan_decoder_class)(**scan_decoder_args)

    task_heads = {}
    self.task_inputs = {}
    for task_head_config in task_configs:
        task = task_head_config["task"]
        head_class = getattr(modules, task_head_config["class"])
        # shallow copy: the "config" entry is replaced below and the
        # caller's task config must stay untouched
        head_args = dict(task_head_config["args"])

        self.task_inputs[task] = task_head_config.get("inputs", "pool")

        if "config" in head_args:
            # bert task heads take config object for parameters, must
            # convert from dict
            config = head_args["config"]
            head_args["config"] = namedtuple("Config",
                                             config.keys())(*config.values())

        if head_class is BertOnlyMLMHead:
            embs = self.bert.module.embeddings.word_embeddings.weight
            task_heads[task] = head_class(bert_model_embedding_weights=embs,
                                          **head_args)
        else:
            task_heads[task] = head_class(**head_args)

    self.task_heads = torch.nn.ModuleDict(task_heads)

    self.vocab = WordPieceVocab(**vocab_args)

    self._build_loss(loss_weighting)
    self._post_init()
class BertScanModel(BaseModel):
    """
    Report model built around a BERT module, with an optional scan encoder,
    an optional scan decoder, and one output head per configured task.
    """

    def __init__(self,
                 scan_encoder_class=None,
                 scan_encoder_args={},
                 bert_class=None,
                 bert_args={},
                 scan_decoder_class=None,
                 scan_decoder_args={},
                 task_configs=[],
                 vocab_args={},
                 loss_weighting=None,
                 optim_class="Adam",
                 optim_args={},
                 scheduler_class=None,
                 scheduler_args={},
                 pretrained_configs=[],
                 cuda=True,
                 devices=[0]):
        """
        Assemble the optional scan encoder, the BERT module, the optional
        scan decoder and one head per task.

        The argument dicts supplied by the caller are never modified:
        entries that must be converted (a dict "config") are converted on a
        copy, so the same configuration can build more than one model.

        Args:
            scan_encoder_class (str): name of a class in `modules`, or None
                for no scan encoder.
            scan_encoder_args (dict): keyword arguments for the scan encoder.
            bert_class (str): "BertModelPreTrained", "BertForPretraining",
                "BertModel", or the name of a class in `modules`.
            bert_args (dict): keyword arguments for the selected BERT class.
                For "BertModel" its "config" entry is a dict that is turned
                into a BertConfig.
            scan_decoder_class (str): name of a class in `modules`, or None
                for no scan decoder.
            scan_decoder_args (dict): keyword arguments for the scan decoder.
            task_configs (list): dicts with "task", "class", "args" and an
                optional "inputs" entry ("pool" when absent).
            vocab_args (dict): keyword arguments for WordPieceVocab.
            loss_weighting: forwarded to self._build_loss.
            optim_class, optim_args, scheduler_class, scheduler_args,
            pretrained_configs, cuda, devices: forwarded positionally to the
                parent initializer.
        """
        super().__init__(optim_class, optim_args, scheduler_class,
                         scheduler_args, pretrained_configs, cuda, devices)

        self.encodes_scans = scan_encoder_class is not None
        if self.encodes_scans:
            scan_encoder = getattr(modules,
                                   scan_encoder_class)(**scan_encoder_args)
            self.scan_encoder = nn.DataParallel(scan_encoder,
                                                device_ids=self.devices)

        if bert_class == "BertModelPreTrained":
            bert = BertModel.from_pretrained(**bert_args)
        elif bert_class == "BertForPretraining":
            bert = BertForPreTraining.from_pretrained(**bert_args)
        elif bert_class == "BertModel":
            # convert the config on a copy so the caller's dict keeps its
            # plain-dict "config" entry
            bert_kwargs = dict(bert_args)
            bert_kwargs["config"] = BertConfig.from_dict(
                bert_kwargs["config"])
            bert = BertModel(**bert_kwargs)
        else:
            bert = getattr(modules, bert_class)(**bert_args)
        self.bert = nn.DataParallel(bert, device_ids=self.devices)

        self.decodes_scans = scan_decoder_class is not None
        if self.decodes_scans:
            self.scan_decoder = getattr(
                modules, scan_decoder_class)(**scan_decoder_args)

        task_heads = {}
        self.task_inputs = {}
        for task_head_config in task_configs:
            task = task_head_config["task"]
            head_class = getattr(modules, task_head_config["class"])
            # shallow copy: the "config" entry is replaced below and the
            # caller's task config must stay untouched
            head_args = dict(task_head_config["args"])

            self.task_inputs[task] = task_head_config.get("inputs", "pool")

            if "config" in head_args:
                # bert task heads take config object for parameters, must
                # convert from dict
                config = head_args["config"]
                head_args["config"] = namedtuple(
                    "Config", config.keys())(*config.values())

            if head_class is BertOnlyMLMHead:
                embs = self.bert.module.embeddings.word_embeddings.weight
                task_heads[task] = head_class(
                    bert_model_embedding_weights=embs, **head_args)
            else:
                task_heads[task] = head_class(**head_args)

        self.task_heads = torch.nn.ModuleDict(task_heads)

        self.vocab = WordPieceVocab(**vocab_args)

        self._build_loss(loss_weighting)
        self._post_init()

    def show_attention(self, inputs):
        """
        Run only the BERT module on the report and return its raw outputs.

        Must be initialized with BertConfig where
        config.output_attentions = True

        Args:
            inputs (dict): must contain "report", which is converted with
                self.vocab.to_input_tensor.
        """
        report_inputs = inputs["report"]
        report_input_ids, attention_mask = self.vocab.to_input_tensor(
            report_inputs, device=self.device)
        outputs = self.bert(input_ids=report_input_ids,
                            attention_mask=attention_mask)
        return outputs

    def forward(self, inputs, targets):
        """
        Compute the output of every task head.

        Args:
            inputs (dict): "report" holds the report input passed to
                self.vocab.to_input_tensor; "scan" holds the scan input and
                is only read when a scan encoder is configured.
            targets: not used by this method.
        Return:
            task_outputs (dict): task name -> output of that task's head.
        Raises:
            ValueError: if a task's input type is neither "seq" nor "pool".
        """
        if self.encodes_scans:
            scan_inputs = inputs["scan"]
            scan_encodings = self.scan_encoder(scan_inputs)
            # TODO: 3d positional encodings
            # flatten every dim after the second, then swap the last two dims
            scan_encodings = scan_encodings.view(scan_encodings.shape[0],
                                                 scan_encodings.shape[1], -1)
            scan_encodings = scan_encodings.permute(0, 2, 1)

        report_inputs = inputs["report"]
        report_input_ids, attention_mask = self.vocab.to_input_tensor(
            report_inputs, device=self.device)
        bert_seq, pool = self.bert(input_ids=report_input_ids,
                                   attention_mask=attention_mask,
                                   output_all_encoded_layers=False)

        if self.decodes_scans:
            # NOTE(review): `scan_encodings` only exists when a scan encoder
            # is configured; a decoder without an encoder fails here.
            scan_seq, pool = self.scan_decoder(
                scan_encodings=scan_encodings,
                report_encodings=bert_seq,
                attention_mask=attention_mask,
                output_all_encoded_layers=False)
            # residual
            seq = bert_seq + scan_seq
        else:
            seq = bert_seq

        task_outputs = {}
        for task, task_head in self.task_heads.items():
            if self.task_inputs[task] == "seq":
                task_outputs[task] = task_head(seq)
            elif self.task_inputs[task] == "pool":
                task_outputs[task] = task_head(pool)
            else:
                raise ValueError("Input type not recognized.")

        return task_outputs

    def predict(self, inputs, probabilities=True):
        """
        Run forward and post-process each task output.

        Args:
            inputs (dict): same structure as for forward.
            probabilities (bool): when True return the softmax over the last
                dim; otherwise return the argmax over the last dim (with
                that dim kept).
        Return:
            task_predictions (dict): task name -> tensor.
        """
        task_outputs = self.forward(inputs, None)
        task_predictions = {
            task: nn.functional.softmax(output, dim=-1)
            for task, output in task_outputs.items()
        }

        if not probabilities:
            task_predictions = {
                task: torch.argmax(output, dim=-1, keepdim=True)
                for task, output in task_predictions.items()
            }

        return task_predictions

    def loss(self, outputs, targets):
        """
        Sum self.loss_fn over every task in `outputs`.

        Args:
            outputs (dict): task name -> output tensor; flattened to
                (-1, last_dim) before the loss is applied.
            targets (dict): task name -> target tensor; flattened to 1-D.
        Return:
            the summed loss (0 when `outputs` is empty).
        """
        total_loss = 0
        for task, output in outputs.items():
            curr_loss = self.loss_fn(output.view(-1, output.shape[-1]),
                                     targets[task].view(-1))
            total_loss += curr_loss
        return total_loss

    def _build_loss(self, loss_weighting=None):
        """
        Create self.loss_fn, a cross-entropy loss that ignores index -1.

        Args:
            loss_weighting: None for unweighted classes, or "log" to weight
                each vocab token by 1 / log(freq + 1), leaving tokens with
                frequency 0 at weight 1.
        Raises:
            ValueError: for any other weighting scheme.
        """
        if loss_weighting is None:
            class_weights = None
        elif loss_weighting == "log":
            class_weights = torch.ones(len(self.vocab))
            for token, freq in self.vocab.token_to_freq.items():
                if freq == 0:
                    continue
                class_weights[
                    self.vocab.token_to_idx[token]] = 1 / np.log(freq + 1)
        else:
            raise ValueError("Loss weighting scheme not recognized.")

        self.loss_fn = nn.CrossEntropyLoss(weight=class_weights,
                                           ignore_index=-1)

    def _log_predictions(self, inputs, targets, predictions, info=None):
        """
        Log inputs, targets and argmax predictions for the "scan_mlm" and
        "scan_match" tasks when they are present in `predictions`.

        Args:
            inputs (dict): must contain "report".
            targets (dict): task name -> target tensor.
            predictions (dict): task name -> output tensor; reduced with an
                argmax over the last dim before logging.
            info: optional sequence whose first element has an "exam_id"
                entry; the exam id line is skipped when this is None.
        """
        predictions = {
            task: torch.argmax(output, dim=-1)
            for task, output in predictions.items()
        }

        if "scan_mlm" in predictions:
            # `info` is optional, so only log the exam id when we have it
            if info is not None:
                logging.info(f"Exam ID: {info[0]['exam_id']}")

            mlm_preds, mlm_targets = predictions["scan_mlm"], targets[
                "scan_mlm"]
            # keep only the positions that carry a target (-1 = no target)
            mlm_preds = mlm_preds[mlm_targets != -1]
            mlm_targets = mlm_targets[mlm_targets != -1]

            # handle batch size of 1
            if len(mlm_targets.shape) < 2:
                mlm_preds = mlm_preds.unsqueeze(0)
                mlm_targets = mlm_targets.unsqueeze(0)

            if mlm_preds.shape[-1] != 0:
                mlm_preds = self.vocab.from_output_tensor(mlm_preds)
                mlm_targets = self.vocab.from_output_tensor(mlm_targets)

                for curr_inputs, target, pred in zip(inputs["report"],
                                                     mlm_targets, mlm_preds):
                    logging.info(f"Inputs: {curr_inputs}")
                    logging.info(f"Targets: {' --- '.join(target)}")
                    logging.info(f"Predict: {' --- '.join(pred)}")

        if "scan_match" in predictions:
            logging.info(f"Inputs: {inputs['report'][0]}")
            logging.info(f"Matched: {targets['scan_match']}")
            logging.info(f"Predicted: {predictions['scan_match']}")
class MatchDataset:
    """
    Dataset over the rows of a match-label CSV. Each item is a tokenized
    text together with per-token labels marking the matched span.
    """

    def __init__(self,
                 dataset_dir,
                 split=None,
                 task_configs={},
                 vocab_args={},
                 max_len=200,
                 cls_label=True):
        """
        Load the match table for the requested split and build the vocab.

        Args:
            dataset_dir (str): directory containing `match_labels.csv` and a
                `split` sub-directory with one `<split>.csv` per split.
            split: None or "all" reads `match_labels.csv`; any other value
                reads `split/<split>.csv`.
            task_configs: not used by this class.
            vocab_args (dict): keyword arguments for WordPieceVocab.
            max_len: maximum number of tokens kept before the sentence is
                wrapped; None disables truncation.
            cls_label (bool): when True the first token also receives the
                row's label.
        """
        self.dataset_dir = dataset_dir
        self.split = split

        if split is None or split == "all":
            matches_path = os.path.join(dataset_dir, 'match_labels.csv')
        else:
            matches_path = os.path.join(dataset_dir, 'split', f'{split}.csv')
        matches_df = pd.read_csv(matches_path)
        # keep only the rows whose "not_applicable" column equals False
        self.matches_df = matches_df[
            matches_df["not_applicable"] == False]  # noqa: E712

        self.vocab = WordPieceVocab(**vocab_args)
        self.max_len = max_len
        self.cls_label = cls_label

    def __len__(self):
        """Return the number of rows in the match table."""
        return len(self.matches_df)

    def __getitem__(self, idx):
        """
        Build the model inputs and per-token targets for one row.

        Tokens are located in the row's text left to right with `str.find`;
        a token that lies in the match's [start, end) range gets the row's
        label (1 for "abnormal", 0 otherwise) and every other token gets -1.
        Tokens that cannot be found in the text (for example an unknown-word
        placeholder) are left at -1 and do not move the search position, so
        the tokens after them stay aligned.

        Args:
            idx (int): positional row index.
        Return:
            inputs (dict): {"report": list of tokens}
            targets (dict): {"fdg_abnorm": long tensor, one entry per token}
            info (dict): {"exam_id": ..., "term_name": ...}
        """
        match = self.matches_df.iloc[idx]
        text = match["text"]
        tokens = self.vocab.tokenize(text)
        if self.max_len is not None and self.max_len < len(tokens):
            tokens = tokens[:self.max_len]
        tokens = self.vocab.wrap_sentence(tokens)

        # get numeric label: anything that is not "abnormal" counts as 0
        label = 1 if match["fdg_abnormality_label"] == "abnormal" else 0

        labels = -1 * torch.ones(len(tokens), dtype=torch.long)
        if self.cls_label:
            labels[0] = label

        # label tokens
        match_start, match_end = match["start"], match["end"]
        find_start = 0
        for token_idx, token in enumerate(tokens):
            if token == "[CLS]" or token == "[SEP]":
                continue

            if token.startswith("##"):
                # remove pounds
                token = token[2:]

            token_start = text.find(token, find_start)
            if token_start == -1:
                # Not present in the raw text: leave it unlabeled and keep
                # the cursor where it is. Using -1 as an offset would rewind
                # the search and mis-align every later token.
                continue
            token_end = token_start + len(token)
            find_start = token_end

            # NOTE(review): the second test compares the exclusive token end
            # against the range; confirm whether `end` in the CSV is
            # inclusive or exclusive before tightening it.
            if ((match_start <= token_start < match_end)
                    or (match_start <= token_end < match_end)):
                labels[token_idx] = label

        inputs = {"report": tokens}
        targets = {"fdg_abnorm": labels}
        info = {"exam_id": match["exam_id"], "term_name": match["term_name"]}

        return inputs, targets, info
class BertPretrainingDataset(ReportDataset):
    """
    Report dataset that yields tokenized reports with masked-language-model
    targets and the exam label.
    """

    def __init__(self,
                 dataset_dir,
                 vocab_args={},
                 transform_configs=[],
                 pct_tokens=0.15,
                 pct_random=0.10,
                 pct_same=0.10,
                 max_len=None,
                 split=None):
        """
        Initialize the underlying ReportDataset and store the masking
        settings.

        WARNING: The vocab file used here MUST match the vocab file used in
        your model.

        Args:
            dataset_dir: forwarded to ReportDataset.__init__.
            vocab_args (dict): keyword arguments for WordPieceVocab.
            transform_configs: forwarded to ReportDataset.__init__.
            pct_tokens (float): fraction of the non-special tokens selected
                as prediction targets (rounded up).
            pct_random (float): probability that a selected token is
                replaced by a random vocabulary token in the input.
            pct_same (float): probability that a selected token is left
                unchanged in the input.
            max_len: maximum number of tokens kept before the sentence is
                wrapped; None disables truncation.
            split: forwarded to ReportDataset.__init__.
        """
        # Pass by keyword so the call cannot silently swap
        # `transform_configs` and `split` if the base signature orders them
        # differently.
        ReportDataset.__init__(self,
                               dataset_dir,
                               transform_configs=transform_configs,
                               split=split)

        self.vocab = WordPieceVocab(**vocab_args)

        # parameters for mlm masking
        self.pct_tokens = pct_tokens
        self.pct_random = pct_random
        self.pct_same = pct_same

        self.max_len = max_len

    def get_targets(self, tasks=[]):
        """
        Yield the targets dict of every exam in self.exams_df, in order.

        Args:
            tasks (list): task names forwarded to self._get_targets.
        """
        dataset = H5Dataset(self.dataset_name, self.data_dir, mode="read")
        for _, exam in self.exams_df.iterrows():
            yield self._get_targets(exam, dataset, tasks=tasks)

    def _get_targets(self, exam, dataset=None, tasks=[]):
        """
        Build the targets dict for one exam.

        Args:
            exam: row of self.exams_df; its 'label' entry is read.
            dataset: not used by this method.
            tasks (list): only "abnorm" is recognized.
        Return:
            targets (dict): {"abnorm": tensor(label)} when "abnorm" is
                requested, otherwise empty.
        """
        targets = {}
        if "abnorm" in tasks:
            targets["abnorm"] = torch.tensor(exam['label'])
        return targets

    def __len__(self):
        """Return the number of exams."""
        return len(self.exams_df.index)

    def __getitem__(self, idx):
        """
        Build the masked report and its targets for one exam.

        Args:
            idx: positional index of the exam (converted with int()).
        Return:
            inputs (dict): {"report": list of tokens after masking}
            targets (dict): {"mlm": long tensor of target token ids with -1
                at positions that are not targets, "abnorm": tensor(label)}
            info (dict): {"exam_id": ..., "patient_id": ...}
        """
        dataset = H5Dataset(self.dataset_name, self.data_dir, mode="read")
        exam = self.exams_df.iloc[int(idx)]
        label = exam['label']
        exam_id = exam['exam_id']
        patient_id = exam['patient_id']

        report = self._get_report(exam, dataset)
        report = self.vocab.tokenize(report)

        if self.max_len is not None and self.max_len < len(report):
            report = report[:self.max_len]

        report = self.vocab.wrap_sentence(report)
        report, mask_labels = self._mask_inputs(report)

        inputs = {"report": report}
        targets = {"mlm": mask_labels, "abnorm": torch.tensor(label)}
        info = {"exam_id": exam_id, "patient_id": patient_id}

        return inputs, targets, info

    def _mask_inputs(self, report):
        """
        Select tokens as prediction targets and corrupt the input there.

        For every selected position the target is always the id of the
        ORIGINAL token. The input at that position is, with probability
        `pct_random`, replaced by a random vocabulary token; with
        probability `pct_same`, left unchanged; otherwise replaced by
        "[MASK]".

        Args:
            report (list): wrapped token list; the first and last positions
                are never selected.
        Return:
            masked_report (list): copy of `report` with the corruptions
                applied.
            labels (torch.Tensor): long tensor, -1 everywhere except at the
                selected positions.
        """
        # don't sample [CLS] or [SEP] tokens
        # TODO: make this work for middle of sequence [SEP]
        n_tokens = max(
            0,
            np.ceil((len(report) - 2) * self.pct_tokens).astype(int))
        token_idxs = np.random.choice(np.arange(1,
                                                len(report) - 1),
                                      size=n_tokens,
                                      replace=False)

        labels = -1 * torch.ones(len(report), dtype=torch.long)
        masked_report = report.copy()
        vocab_idxs = list(self.vocab.idx_to_token.keys())

        for token_idx in token_idxs:
            # the model must always recover the token that was really there
            labels[token_idx] = self.vocab.token_to_idx[report[token_idx]]

            sample = np.random.rand()
            if sample < self.pct_random:
                # corrupt the INPUT with a random vocabulary token
                random_idx = vocab_idxs[np.random.randint(len(vocab_idxs))]
                masked_report[token_idx] = self.vocab.idx_to_token[random_idx]
            elif sample < self.pct_random + self.pct_same:
                # keep the original token in the input
                pass
            else:
                masked_report[token_idx] = "[MASK]"

        return masked_report, labels
class BertScanDataset(ImageDataset, ReportDataset):
    """
    Dataset pairing each exam's report with its scan (unless scans are
    skipped) and producing the targets of the configured tasks.
    """

    def __init__(self,
                 dataset_dir,
                 vocab_args,
                 term_graph_args,
                 image_types=None,
                 size=None,
                 normalize=True,
                 image_transform_configs=[],
                 sampling_window=None,
                 sampling_rate=1,
                 report_transform_configs=[],
                 task_configs=[],
                 max_len=None,
                 skip_scans=False,
                 split=None):
        """
        Initialize the report side of the dataset and, unless skipped, the
        image side.

        Args:
            dataset_dir: forwarded to both base initializers.
            vocab_args (dict): keyword arguments for WordPieceVocab.
            term_graph_args (dict): keyword arguments for TermGraph.
            image_types, size, normalize, sampling_window, sampling_rate:
                forwarded unchanged to ImageDataset.__init__.
            image_transform_configs: passed to ImageDataset.__init__ as
                `transform_configs`.
            report_transform_configs: passed to ReportDataset.__init__ as
                `transform_configs`.
            task_configs (list): task config dicts, stored in
                self.task_configs keyed by their "fn" entry.
            max_len: maximum number of tokens kept before the sentence is
                wrapped; None disables truncation.
            skip_scans (bool): when True, ImageDataset.__init__ is not
                called and items carry None as their scan.
            split: forwarded to both base initializers.
        """
        self.skip_scans = skip_scans

        if not skip_scans:
            image_kwargs = dict(image_types=image_types,
                                size=size,
                                normalize=normalize,
                                sampling_window=sampling_window,
                                transform_configs=image_transform_configs,
                                sampling_rate=sampling_rate,
                                split=split)
            ImageDataset.__init__(self, dataset_dir, **image_kwargs)

        ReportDataset.__init__(self,
                               dataset_dir,
                               split=split,
                               transform_configs=report_transform_configs)

        # index the task configs by the name of the method they configure
        configs_by_fn = {}
        for config in task_configs:
            configs_by_fn[config["fn"]] = config
        self.task_configs = configs_by_fn

        self.max_len = max_len
        self.vocab = WordPieceVocab(**vocab_args)
        self.term_graph = TermGraph(**term_graph_args)

    def __len__(self):
        """Return the number of exams."""
        return len(self.exams_df.index)

    def __getitem__(self, idx):
        """
        Build the inputs and task targets for one exam.

        Args:
            idx: positional index of the exam (converted with int()).
        Return:
            inputs (dict): {"report": list of tokens, "scan": images or
                None when scans are skipped}
            targets (dict): one entry per configured task ("scan_match",
                "scan_mlm").
            info (dict): {"exam_id": ..., "patient_id": ...}
        """
        dataset = H5Dataset(self.dataset_name, self.data_dir, mode="read")
        exam = self.exams_df.iloc[int(idx)]
        label = exam['label']
        exam_id = exam['exam_id']
        patient_id = exam['patient_id']

        report = self._get_report(exam, dataset)
        if self.skip_scans:
            images = None
        else:
            images = self._get_images(exam, dataset)

        targets = {}

        # must perform scan_match first
        if "scan_match" in self.task_configs:
            match_args = self.task_configs["scan_match"]["args"]
            report, match_label = self.scan_match(exam_id, report, dataset,
                                                  **match_args)
            targets["scan_match"] = match_label

        # tokenize, trim if over max length, and wrap sentence
        report = self.vocab.tokenize(report)
        too_long = self.max_len is not None and self.max_len < len(report)
        if too_long:
            report = report[:self.max_len]
        report = self.vocab.wrap_sentence(report)

        if "scan_mlm" in self.task_configs:
            mlm_args = self.task_configs["scan_mlm"]["args"]
            report, mlm_labels = self.scan_mlm(exam_id, report, **mlm_args)
            targets["scan_mlm"] = mlm_labels

        info = {"exam_id": exam_id, "patient_id": patient_id}
        inputs = {"report": report, "scan": images}
        return inputs, targets, info

    def scan_match(self, exam_id, report, dataset, pct_same):
        """
        Keep the exam's own report with probability `pct_same`, otherwise
        swap in the report of a randomly drawn exam.

        Args:
            exam_id: not used by this method.
            report: the exam's own report.
            dataset: passed to self._get_report for the random exam.
            pct_same (float): probability of keeping the report.
        Return:
            report: the kept or replaced report.
            label (torch.Tensor): tensor(1) when kept, tensor(0) when
                replaced.
        """
        keep_own_report = torch.rand(1).item() < pct_same
        if keep_own_report:
            return report, torch.tensor(1)

        self.mismatched = True
        random_idx = np.random.randint(len(self.exams_df))
        other_exam = self.exams_df.iloc[random_idx]
        return self._get_report(other_exam, dataset), torch.tensor(0)

    def scan_mlm(self,
                 exam_id,
                 report,
                 rand_default_mask_prob=0.025,
                 term_default_mask_prob=0.85,
                 term_to_mask_prob={}):
        """
        Mask report tokens, driven by the term graph's matches.

        A token without term matches is masked with probability
        `rand_default_mask_prob`. A token with matches is masked when any
        of its matched terms has a truthy "mask_sample" entry after the
        term graph's bernoulli_sample call.

        Args:
            exam_id: not used by this method.
            report (list): wrapped token list.
            rand_default_mask_prob (float): see above.
            term_default_mask_prob (float): passed to bernoulli_sample as
                `default_prob`.
            term_to_mask_prob (dict): passed to bernoulli_sample as
                `term_to_prob`.
        Return:
            masked_report (list): copy of `report` with "[MASK]" at the
                masked positions.
            labels (torch.Tensor): long tensor holding the original token
                id at masked positions and -1 elsewhere.
        """
        # TODO:
        matches = self.term_graph.get_matches(report)

        masked_report = report.copy()
        labels = -1 * torch.ones(len(report), dtype=torch.long)

        # mask tokens
        self.term_graph.bernoulli_sample(term_to_prob=term_to_mask_prob,
                                         default_prob=term_default_mask_prob,
                                         sample_name="mask_sample")

        for position, token_matches in enumerate(matches):
            original_token = report[position]

            if len(token_matches) == 0:
                masked = bool(
                    torch.bernoulli(torch.tensor(rand_default_mask_prob)))
            else:
                masked = any(self.term_graph[term]["mask_sample"]
                             for term in token_matches)

            if not masked:
                continue
            masked_report[position] = "[MASK]"
            labels[position] = self.vocab.token_to_idx[original_token]

        return masked_report, labels
class ExamLabelsPredictor(Process):
    """
    Process that runs a trained model over report sections and aggregates
    the per-match outputs into per-term exam labels.
    """

    def __init__(self,
                 dir,
                 model_dir,
                 dataset_class="ReportDataset",
                 dataset_args={},
                 term_graph_dir="data/pet_ct_terms/terms.json",
                 terms="all",
                 match_task="fdg_abnorm",
                 split_fn="split_impression_sections",
                 max_len=200,
                 vocab_args={},
                 seed=123,
                 cuda=True,
                 devices=[0]):
        """
        Build the dataloader, term graph, vocab and model used for labeling.

        Args:
            dir: passed to the parent initializer.
            model_dir: directory handed to self._load_model.
            dataset_class (str): name of a class looked up on `datasets`.
            dataset_args (dict): keyword arguments for that dataset class.
            term_graph_dir (str): argument passed to TermGraph.
            terms: "all" to use every name in the term graph's
                `term_names`, otherwise the value is stored as-is.
            match_task (str): key of the model prediction to read.
            split_fn (str): name of a module-level function, resolved
                through globals(), used to split a report into sections.
            max_len (int): maximum number of tokens kept per section before
                the sentence is wrapped.
            vocab_args (dict): keyword arguments for WordPieceVocab.
            seed: not used by this class.
            cuda (bool): forwarded to the model constructor.
            devices (list): forwarded to the model constructor; its first
                entry is used when loading weights.
        """
        super().__init__(dir)
        self.cuda = cuda
        self.devices = devices
        self.device = devices[0]

        self.split_fn = globals()[split_fn]

        dataset = getattr(datasets, dataset_class)(**dataset_args)
        self.dataloader = DataLoader(dataset, batch_size=1)

        logging.info("Loading TermGraph and Vocab...")
        self.match_task = match_task
        self.term_graph = TermGraph(term_graph_dir)
        self.terms = self.term_graph.term_names if terms == "all" else terms

        self.vocab = WordPieceVocab(**vocab_args)
        self.max_len = max_len

        logging.info("Loading Model...")
        self._load_model(model_dir)

    def _load_model(self, model_dir):
        """
        Rebuild the model described by `<model_dir>/params.json` and load
        its weights from `<model_dir>/best`.

        Args:
            model_dir (str): directory of a trained model. The weights are
                read from `best/weights.pth.tar`, falling back to
                `best/weights.link` when that file does not exist.
        """
        with open(os.path.join(model_dir, "params.json")) as f:
            args = json.load(f)["process_args"]
        model_class = args["model_class"]
        model_args = args["model_args"]

        if "task_configs" in args:
            # every task config is the default config overridden by its own
            # entries
            task_configs = []
            for task_config in args["task_configs"]:
                merged_config = args["default_task_config"].copy()
                merged_config.update(task_config)
                task_configs.append(merged_config)
            # only set when task configs exist; otherwise model_args is used
            # as loaded
            model_args["task_configs"] = task_configs

        model_class = getattr(models, model_class)
        self.model = model_class(cuda=self.cuda,
                                 devices=self.devices,
                                 **model_args)

        model_dir = os.path.join(model_dir, "best")
        model_path = os.path.join(model_dir, "weights.pth.tar")
        if not os.path.isfile(model_path):
            model_path = os.path.join(model_dir, "weights.link")
        self.model.load_weights(model_path, device=self.device)

    def label_exam(self, label, report, info):
        """
        Compute per-term labels for one exam's report.

        The report is lowercased and split into sections. Sections with
        term matches are run through the model, and each match contributes
        the mean of column 1 of the model output over the match's tokens.
        A term's probability combines its own contributions and those of
        its descendants as 1 - prod(1 - p); with no contributions it is 0.
        Matches that cover no tokens are skipped so that they cannot turn
        the result into NaN.

        Args:
            label: not used by this method.
            report: sequence whose first element is the report string.
            info: not used by this method.
        Return:
            labels (dict): (term, 0) -> 1 - prob and (term, 1) -> prob for
                every term in self.terms.
        """
        report_sections = self.split_fn(report[0].lower())
        term_to_outputs = defaultdict(list)
        for report_section in report_sections:
            curr_matches = self.term_graph.match_string(report_section)
            if not curr_matches:
                # skip report sections without matches
                continue

            tokens = self.vocab.tokenize(report_section)
            if len(tokens) > self.max_len:
                tokens = tokens[:self.max_len]
            tokens = self.vocab.wrap_sentence(tokens)

            inputs = {"report": [tokens]}
            output = self.model.predict(inputs)[self.match_task]
            output = output.cpu().detach().numpy().squeeze()

            for match in curr_matches:
                match_idxs = self.vocab.get_tokens_in_range(
                    tokens, report_section, match["start"], match["end"])
                match["output"] = output[match_idxs, 1]
                if np.size(match["output"]) == 0:
                    # no token of this match survived (e.g. it lies beyond
                    # max_len); the mean of nothing would be NaN
                    continue
                term_to_outputs[match["term_name"]].append(
                    np.mean(match["output"]))

        labels = {}
        for term in self.terms:
            all_outputs = term_to_outputs[term][:]
            for descendant in self.term_graph.get_descendants(term):
                all_outputs.extend(term_to_outputs[descendant])
            all_outputs = np.array(all_outputs)
            prob = 1 - np.prod(1 - all_outputs)
            labels[(term, 0)] = 1 - prob
            labels[(term, 1)] = prob

        return labels

    def _run(self, overwrite=False):
        """
        Label every exam from the dataloader and write the result to
        `exam_labels.csv` in self.dir, one row per exam id.

        Args:
            overwrite: not used by this method.
        """
        exam_id_to_labels = {}
        for label, report, info in tqdm(self.dataloader):
            labels = self.label_exam(label, report, info)
            exam_id_to_labels[info["exam_id"][0]] = labels

        labels_df = pd.DataFrame.from_dict(exam_id_to_labels, orient="index")
        labels_df.to_csv(os.path.join(self.dir, "exam_labels.csv"))