def get_risk_factors(args, ssn, exam, risk_metadata_json, json_dir, logger):
    '''
    args:
        - args: run config; per-request metadata paths are written onto it below.
        - ssn: patient identifier.
        - exam: exam (accession) identifier.
        - risk_metadata_json: risk factor metadata, keyed by ssn.
        - json_dir: directory in which to write the temporary metadata JSONs.
        - logger: logger for success/failure messages.
    returns:
        - risk_factor_vector: risk factor vector for the given (ssn, exam).
    '''
    os.makedirs(json_dir, exist_ok=True)
    args.metadata_path = "{}.json".format(
        os.path.join(json_dir, str(uuid.uuid4())))
    args.risk_factor_metadata_path = "{}.json".format(
        os.path.join(json_dir, str(uuid.uuid4())))

    # Write current request to a file to use as a metadata path
    prior_hist = risk_metadata_json[ssn]['any_breast_cancer'] == 1
    metadata_json = [{
        'ssn': ssn,
        'accessions': [{
            'accession': exam,
            'prior_hist': prior_hist
        }]
    }]
    try:
        with open(args.metadata_path, 'w') as f:
            json.dump(metadata_json, f)
        with open(args.risk_factor_metadata_path, 'w') as f:
            json.dump(risk_metadata_json, f)
    except Exception as e:
        delete_jsons(args)
        err_msg = FAIL_TO_SAVE_METADATA_MESSAGE.format(ssn, exam, e, args)
        logger.error(err_msg)
        raise Exception(err_msg)

    # Load risk factor vector from the metadata files, then delete them
    try:
        risk_factor_vectorizer = RiskFactorVectorizer(args)
        sample = {'ssn': ssn, 'exam': exam}
        risk_factor_vector = risk_factor_vectorizer.get_risk_factors_for_sample(
            sample)
        logger.info(SUCCESS_RISK_VEC_MESSAGE.format(ssn, exam, args))
        delete_jsons(args)
        return risk_factor_vector
    except Exception as e:
        delete_jsons(args)
        err_msg = FAIL_TO_GET_RISK_VECTOR_MESSAGE.format(ssn, exam, e, args)
        logger.error(err_msg)
        raise Exception(err_msg)
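
# A minimal usage sketch for get_risk_factors, assuming a hypothetical caller:
# `base_args` is the argparse.Namespace the service was configured with, and
# `risk_metadata` is the risk factor metadata keyed by ssn. The identifiers
# and paths below are illustrative only, not part of the original code.
import copy
import logging

logging.basicConfig(level=logging.INFO)
example_logger = logging.getLogger(__name__)

def example_get_risk_factors(base_args, risk_metadata):
    # Deep-copy args so the per-request metadata paths written by
    # get_risk_factors don't clobber shared state across requests.
    request_args = copy.deepcopy(base_args)
    return get_risk_factors(
        request_args,
        ssn='0000001',               # hypothetical patient identifier
        exam='A0001',                # hypothetical accession number
        risk_metadata_json=risk_metadata,
        json_dir='/tmp/risk_jsons',  # scratch dir for the temporary JSONs
        logger=example_logger)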
def __init__(self, args, transformers, split_group):
    '''
    params: args - config.
    params: transformers - A transformer object, takes in a PIL image,
            performs some transforms and returns a Tensor.
    params: split_group - ['train'|'dev'|'test'].

    constructs: standard pytorch Dataset obj, which can be fed in a
    DataLoader for batching.
    '''
    super(Abstract_Onco_Dataset, self).__init__()
    args.metadata_path = os.path.join(args.metadata_dir,
                                      self.METADATA_FILENAME)
    self.args = args
    self.image_loader = image_loader(args.cache_path, transformers)
    try:
        self.metadata_json = json.load(open(args.metadata_path, 'r'))
    except Exception as e:
        raise Exception(METAFILE_NOTFOUND_ERR.format(args.metadata_path, e))

    self.dataset = self.create_dataset(split_group, args.img_dir)
    if split_group == 'train' and self.args.data_fraction < 1.0:
        self.dataset = np.random.choice(
            self.dataset,
            int(len(self.dataset) * self.args.data_fraction),
            replace=False)
    self.risk_factor_vectorizer = RiskFactorVectorizer(args)
    if self.args.use_risk_factors:
        self.add_risk_factors_to_dataset()

    if 'dist_key' in self.dataset[0] and args.year_weighted_class_bal:
        dist_key = 'dist_key'
    else:
        dist_key = 'y'
    label_dist = [d[dist_key] for d in self.dataset]
    label_counts = Counter(label_dist)
    weight_per_label = 1. / len(label_counts)
    label_weights = {
        label: weight_per_label / count
        for label, count in label_counts.items()
    }
    if args.year_weighted_class_bal or args.class_bal:
        print("Label weights are {}".format(label_weights))
    self.weights = [label_weights[d[dist_key]] for d in self.dataset]
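
# A minimal sketch of how the per-sample `self.weights` computed above is
# typically consumed: feeding a WeightedRandomSampler so a DataLoader draws
# class-balanced batches. `dataset` is assumed to be an instance of a concrete
# subclass; the batch size is illustrative.
from torch.utils import data as torch_data

def make_balanced_loader(dataset, batch_size=32):
    sampler = torch_data.WeightedRandomSampler(
        weights=dataset.weights,
        num_samples=len(dataset),
        replacement=True)
    return torch_data.DataLoader(
        dataset, batch_size=batch_size, sampler=sampler)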
def __init__(self, args, num_chan):
    super(RiskFactorPool, self).__init__(args, num_chan)
    self.args = args
    self.internal_pool = get_pool(args.pool_name)(args, num_chan)
    assert not self.internal_pool.replaces_fc()
    self.dropout = nn.Dropout(args.dropout)
    self.length_risk_factor_vector = RiskFactorVectorizer(args).vector_length
    self.fc = nn.Linear(self.length_risk_factor_vector + num_chan,
                        args.num_classes)
    self.args.hidden_dim = self.length_risk_factor_vector + num_chan
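
# The snippet above defines only __init__; a minimal sketch of a matching
# forward pass, assuming the internal pool returns (logit, hidden) with
# hidden of shape (batch, num_chan), as the replaces_fc() assert suggests.
# This is an illustration, not the repository's actual forward.
import torch

def forward(self, x, risk_factors):
    _, hidden = self.internal_pool(x)                  # image features
    hidden = torch.cat((hidden, risk_factors), dim=1)  # append risk factor vector
    hidden = self.dropout(hidden)
    return self.fc(hidden), hidden                     # logits and hidden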
def __init__(self, args, num_chan):
    super(RiskFactorPool, self).__init__(args, num_chan)
    self.args = args
    self.internal_pool = get_pool(args.pool_name)(args, num_chan)
    assert not self.internal_pool.replaces_fc()
    self.length_risk_factor_vector = RiskFactorVectorizer(args).vector_length
    input_dim = self.length_risk_factor_vector + num_chan
    self.fc1 = nn.Linear(input_dim, num_chan)
    self.relu = nn.ReLU(inplace=True)
    self.bn = nn.BatchNorm1d(num_chan)
    self.dropout = nn.Dropout(args.dropout)
    self.fc2 = nn.Linear(num_chan, args.num_classes)
    self.args.hidden_dim = num_chan
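
# The layers above form a small MLP head over the concatenated image and risk
# factor features, bottlenecking back down to num_chan before classification.
# A minimal equivalent written as an nn.Sequential, to make the intended data
# flow explicit (illustrative; the relu -> bn ordering follows the order the
# modules are defined above):
import torch.nn as nn

def make_risk_factor_head(input_dim, num_chan, num_classes, dropout):
    return nn.Sequential(
        nn.Linear(input_dim, num_chan),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(num_chan),
        nn.Dropout(dropout),
        nn.Linear(num_chan, num_classes))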
def __init__(self, args, num_chan):
    super(RiskFactorPool, self).__init__(args, num_chan)
    self.args = args
    self.internal_pool = get_pool(args.pool_name)(args, num_chan)
    assert not self.internal_pool.replaces_fc()
    self.dropout = nn.Dropout(args.dropout)
    self.length_risk_factor_vector = RiskFactorVectorizer(args).vector_length
    if args.pred_risk_factors:
        for key in args.risk_factor_keys:
            num_key_features = args.risk_factor_key_to_num_class[key]
            key_fc = nn.Linear(self.args.hidden_dim, num_key_features)
            self.add_module('{}_fc'.format(key), key_fc)
    self.args.img_only_dim = self.args.hidden_dim
    self.args.rf_dim = self.length_risk_factor_vector
    self.args.hidden_dim = self.args.rf_dim + self.args.img_only_dim
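
# A minimal sketch of how the per-key fc layers registered above via
# add_module might be used: predicting each risk factor from the image-only
# hidden as an auxiliary task. Hypothetical; the real forward is not part of
# this snippet.
def predict_risk_factors(self, img_hidden):
    # img_hidden: (batch, args.img_only_dim) image features, before the
    # risk factor vector is concatenated on.
    return {
        key: getattr(self, '{}_fc'.format(key))(img_hidden)
        for key in self.args.risk_factor_keys
    }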
def __init__(self, args, transformers, split_group):
    '''
    params: args - config.
    params: transformers - A transformer object, takes in a PIL image,
            performs some transforms and returns a Tensor.
    params: split_group - ['train'|'dev'|'test'].

    constructs: standard pytorch Dataset obj, which can be fed in a
    DataLoader for batching.
    '''
    super(Abstract_Onco_Dataset, self).__init__()
    if args.metadata_dir is not None and args.metadata_path is None:
        args.metadata_path = os.path.join(args.metadata_dir,
                                          self.METADATA_FILENAME)
    self.split_group = split_group
    self.args = args
    self.image_loader = image_loader(args.cache_path, transformers)
    try:
        if 'json' in args.metadata_path:
            self.metadata_json = json.load(open(args.metadata_path, 'r'))
        else:
            assert 'csv' in args.metadata_path
            _reader = csv.DictReader(open(args.metadata_path, 'r'))
            self.metadata_json = [r for r in _reader]
    except Exception as e:
        raise Exception(METAFILE_NOTFOUND_ERR.format(args.metadata_path, e))

    self.path_to_hidden_dict = {}
    self.dataset = self.create_dataset(split_group, args.img_dir)
    if len(self.dataset) == 0:
        return
    if split_group == 'train' and self.args.data_fraction < 1.0:
        self.dataset = np.random.choice(
            self.dataset,
            int(len(self.dataset) * self.args.data_fraction),
            replace=False)
    try:
        self.add_device_to_dataset()
        if "all" not in self.args.allowed_devices:
            self.dataset = [
                d for d in self.dataset
                if (d['device_name'] if isinstance(d['device_name'], str)
                    else d['device_name'][0]) in self.args.allowed_devices
            ]
    except Exception:
        print("Could not add device information to dataset")
    for d in self.dataset:
        if 'exam' in d and 'year' in d:
            args.exam_to_year_dict[d['exam']] = d['year']
        if 'device_name' in d and 'exam' in d:
            args.exam_to_device_dict[d['exam']] = d['device_name']
    print(self.get_summary_statement(self.dataset, split_group))

    if args.use_region_annotation:
        self.region_annotations = parse_region_annotations(args)
    args.h_arr, args.w_arr = None, None
    self.risk_factor_vectorizer = None
    if self.args.use_risk_factors:
        self.risk_factor_vectorizer = RiskFactorVectorizer(args)
        self.add_risk_factors_to_dataset()

    if 'dist_key' in self.dataset[0] and (
            args.year_weighted_class_bal
            or args.shift_class_bal_towards_imediate_cancers
            or args.device_class_bal):
        dist_key = 'dist_key'
    else:
        dist_key = 'y'
    label_dist = [d[dist_key] for d in self.dataset]
    label_counts = Counter(label_dist)
    weight_per_label = 1. / len(label_counts)
    label_weights = {
        label: weight_per_label / count
        for label, count in label_counts.items()
    }
    if args.year_weighted_class_bal or args.class_bal:
        print("Label weights are {}".format(label_weights))
    self.weights = [label_weights[d[dist_key]] for d in self.dataset]
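
# The constructor above dispatches on the metadata filename: JSON is loaded
# wholesale, anything else must be a CSV read via csv.DictReader. A minimal,
# hypothetical example of writing each form -- the field names are guesses,
# since the required columns depend on the concrete create_dataset:
import csv
import json

example_rows = [{'ssn': '0000001', 'exam': 'A0001', 'year': '2014'}]

with open('/tmp/metadata.json', 'w') as f:
    json.dump(example_rows, f)

with open('/tmp/metadata.csv', 'w') as f:
    writer = csv.DictWriter(f, fieldnames=['ssn', 'exam', 'year'])
    writer.writeheader()
    writer.writerows(example_rows)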
class Abstract_Onco_Dataset(data.Dataset):
    """
    Abstract Object for all Onco Datasets. All datasets have some metadata
    property associated with them, a create_dataset method, a task, and
    check_label and get_label functions.
    """
    __metaclass__ = ABCMeta

    def __init__(self, args, transformers, split_group):
        '''
        params: args - config.
        params: transformers - A transformer object, takes in a PIL image,
                performs some transforms and returns a Tensor.
        params: split_group - ['train'|'dev'|'test'].

        constructs: standard pytorch Dataset obj, which can be fed in a
        DataLoader for batching.
        '''
        super(Abstract_Onco_Dataset, self).__init__()
        if args.metadata_dir is not None and args.metadata_path is None:
            args.metadata_path = os.path.join(args.metadata_dir,
                                              self.METADATA_FILENAME)
        self.split_group = split_group
        self.args = args
        self.image_loader = image_loader(args.cache_path, transformers)
        try:
            if 'json' in args.metadata_path:
                self.metadata_json = json.load(open(args.metadata_path, 'r'))
            else:
                assert 'csv' in args.metadata_path
                _reader = csv.DictReader(open(args.metadata_path, 'r'))
                self.metadata_json = [r for r in _reader]
        except Exception as e:
            raise Exception(METAFILE_NOTFOUND_ERR.format(
                args.metadata_path, e))

        self.path_to_hidden_dict = {}
        self.dataset = self.create_dataset(split_group, args.img_dir)
        if len(self.dataset) == 0:
            return
        if split_group == 'train' and self.args.data_fraction < 1.0:
            self.dataset = np.random.choice(
                self.dataset,
                int(len(self.dataset) * self.args.data_fraction),
                replace=False)
        try:
            self.add_device_to_dataset()
            if "all" not in self.args.allowed_devices:
                self.dataset = [
                    d for d in self.dataset
                    if (d['device_name'] if isinstance(d['device_name'], str)
                        else d['device_name'][0]) in self.args.allowed_devices
                ]
        except Exception:
            print("Could not add device information to dataset")
        for d in self.dataset:
            if 'exam' in d and 'year' in d:
                args.exam_to_year_dict[d['exam']] = d['year']
            if 'device_name' in d and 'exam' in d:
                args.exam_to_device_dict[d['exam']] = d['device_name']
        print(self.get_summary_statement(self.dataset, split_group))

        if args.use_region_annotation:
            self.region_annotations = parse_region_annotations(args)
        args.h_arr, args.w_arr = None, None
        self.risk_factor_vectorizer = None
        if self.args.use_risk_factors:
            self.risk_factor_vectorizer = RiskFactorVectorizer(args)
            self.add_risk_factors_to_dataset()

        if 'dist_key' in self.dataset[0] and (
                args.year_weighted_class_bal
                or args.shift_class_bal_towards_imediate_cancers
                or args.device_class_bal):
            dist_key = 'dist_key'
        else:
            dist_key = 'y'
        label_dist = [d[dist_key] for d in self.dataset]
        label_counts = Counter(label_dist)
        weight_per_label = 1. / len(label_counts)
        label_weights = {
            label: weight_per_label / count
            for label, count in label_counts.items()
        }
        if args.year_weighted_class_bal or args.class_bal:
            print("Label weights are {}".format(label_weights))
        self.weights = [label_weights[d[dist_key]] for d in self.dataset]

    @property
    @abstractmethod
    def task(self):
        pass

    @property
    @abstractmethod
    def METADATA_FILENAME(self):
        pass

    @abstractmethod
    def check_label(self, row):
        '''
        Return True if the row contains a valid label for the task.
        :row: - metadata row
        '''
        pass

    @abstractmethod
    def get_label(self, row):
        '''
        Get task specific label for a given metadata row.
        :row: - metadata row which contains label information
        '''
        pass

    def get_summary_statement(self, dataset, split_group):
        '''
        Return summary statement
        '''
        return ""

    @abstractmethod
    def create_dataset(self, split_group, img_dir):
        """
        Creating the dataset from the paths and labels in the json.
        :split_group: - ['train'|'dev'|'test'].
        :img_dir: - The path to the dir containing the images.
        """
        pass

    @staticmethod
    def set_args(args):
        """Sets any args particular to the dataset."""
        pass

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        if self.args.use_precomputed_hiddens:
            return self.get_vector_item(index)
        else:
            return self.get_image_item(index)

    def get_vector_item(self, index):
        sample = self.dataset[index]
        try:
            def get_hidden(path):
                zero_vec = np.zeros(self.args.precomputed_hidden_dim)
                if path in self.path_to_hidden_dict \
                        and not self.args.zero_out_hiddens:
                    return self.path_to_hidden_dict[path]
                return zero_vec

            hiddens_for_paths = np.array(
                [get_hidden(path) for path in sample['paths']])
            x = torch.Tensor(hiddens_for_paths)
            item = {'x': x, 'y': sample['y']}
            for key in DATASET_ITEM_KEYS:
                if key in sample:
                    item[key] = sample[key]
            if self.args.use_risk_factors:
                item['risk_factors'] = sample['risk_factors']
            return item
        except Exception:
            # Falls through and returns None; downstream collation must
            # tolerate failed loads.
            warnings.warn(
                LOAD_FAIL_MSG.format(sample['paths'], traceback.format_exc()))

    def get_image_item(self, index):
        sample = self.dataset[index]
        # Region annotation for each image. Dict for a single image, list of
        # dicts for multi-image.
        if self.args.use_region_annotation:
            region_annotation = get_region_annotation_for_sample(
                sample, self.region_annotations, self.args)
        try:
            if self.args.multi_image:
                additionals = sample['additionals']
                # Add region annotation to existing additionals so
                # transformers can mutate them as need be (i.e. if the image
                # flips or rotates).
                if self.args.use_region_annotation:
                    for img_index, path in enumerate(sample['paths']):
                        if img_index == len(additionals):
                            additionals.append({
                                'region_annotation':
                                region_annotation[img_index]
                            })
                        else:
                            additionals[img_index]['region_annotation'] = \
                                region_annotation[img_index]
                x = self.image_loader.get_images(sample['paths'], additionals)
            else:
                additional = {} if sample['additional'] is None \
                    else sample['additional']
                if self.args.use_region_annotation:
                    additional['region_annotation'] = region_annotation
                x = self.image_loader.get_image(sample['path'], additional)

            item = {
                'x': x,
                'path': "\t".join(sample['paths'])
                        if self.args.multi_image else sample['path'],
                'y': sample['y']
            }

            if self.args.use_region_annotation:
                if self.args.multi_image:
                    for coord in region_annotation[0]:
                        annotation_list = [
                            img_annotation[coord]
                            for img_annotation in region_annotation
                        ]
                        item[coord] = torch.Tensor(annotation_list)
                else:
                    for coord in region_annotation:
                        item[coord] = region_annotation[coord]

            for key in DATASET_ITEM_KEYS:
                if key in sample:
                    item[key] = sample[key]

            if self.args.use_risk_factors:
                item['risk_factors'] = sample['risk_factors']

            return item
        except Exception:
            # Falls through and returns None on a failed load.
            if self.args.multi_image:
                warnings.warn(
                    LOAD_FAIL_MSG.format(sample['paths'],
                                         traceback.format_exc()))
            else:
                warnings.warn(
                    LOAD_FAIL_MSG.format(sample['path'],
                                         traceback.format_exc()))

    def add_risk_factors_to_dataset(self):
        for sample in self.dataset:
            sample['risk_factors'] = \
                self.risk_factor_vectorizer.get_risk_factors_for_sample(
                    sample)

    def add_device_to_dataset(self):
        path_to_device, exam_to_device = self.build_path_to_device_map()

        for d in self.dataset:
            paths = [d['path']] if 'path' in d else d['paths']
            d['device_name'], d['device'], d['device_is_known'] = [], [], []
            for path in paths:
                device = path_to_device[path]
                device_id = DEVICE_TO_ID[device] \
                    if device in DEVICE_TO_ID else 0
                device_is_known = device in DEVICE_TO_ID
                d['device_name'].append(
                    device.replace(' ', '_')
                    if device is not None else "<UNK>")
                d['device'].append(device_id)
                d['device_is_known'].append(device_is_known)

            single_image = len(paths) == 1
            if single_image:
                d['device_name'] = d['device_name'][0]
                d['device'] = d['device'][0]
                d['device_is_known'] = d['device_is_known'][0]
            else:
                d['device_name'] = np.array(d['device_name'])
                d['device'] = np.array(d['device'])
                d['device_is_known'] = np.array(d['device_is_known'],
                                                dtype=int)

        # Note: single_image here carries over from the final loop iteration.
        device_dist = Counter([
            d['device'] if single_image else d['device'][-1]
            for d in self.dataset
        ])
        print("Device Dist: {}".format(device_dist))
        if self.split_group == 'train':
            device_count = list(device_dist.values())
            self.args.device_entropy = entropy(device_count)
            print("Device Entropy: {}".format(self.args.device_entropy))

    def build_path_to_device_map(self):
        path_to_device = {}
        exam_to_device = {}
        for mrn_row in json.load(
                open(
                    '/Mounts/Isilon/metadata/mammo_metadata_all_years_only_breast_cancer_nov21_2019.json',
                    'r')):
            for exam in mrn_row['accessions']:
                exam_id = exam['accession']
                for file, device, view in zip(exam['files'],
                                              exam['manufacturer_models'],
                                              exam['views']):
                    device_name = '{} {}'.format(
                        device, 'C-View') if 'C-View' in view else device
                    path_to_device[file] = device_name
                    exam_to_device[exam_id] = device_name
        return path_to_device, exam_to_device

    def image_paths_by_views(self, exam):
        '''
        Determine images of left and right CCs and MLOs.

        Args:
        exam - a dictionary with views and files sorted relatively.

        returns: 4 lists of image paths of each view by this order:
        left_ccs, left_mlos, right_ccs, right_mlos.
        Force max 1 image per view.

        Note: Validation of cancer side is performed in the query
        scripts/from_db/cancer.py in OncoQueries.
        '''
        source_dir = '/home/{}'.format(
            self.args.unix_username) if self.args.is_ccds_server else ''

        def get_view(view_name):
            image_paths_w_view = [
                (view, image_path)
                for view, image_path in zip(exam['views'], exam['files'])
                if view.startswith(view_name)
            ]
            if self.args.use_c_view_if_available:
                filt_image_paths_w_view = [
                    (view, image_path)
                    for view, image_path in image_paths_w_view
                    if 'C-View' in view
                ]
                if len(filt_image_paths_w_view) > 0:
                    image_paths_w_view = filt_image_paths_w_view
            else:
                image_paths_w_view = [
                    (view, image_path)
                    for view, image_path in image_paths_w_view
                    if 'C-View' not in view
                ]
            image_paths_w_view = image_paths_w_view[:1]
            image_paths = [
                source_dir + path for _, path in image_paths_w_view
            ]
            return image_paths

        left_ccs = get_view('L CC')
        left_mlos = get_view('L MLO')
        right_ccs = get_view('R CC')
        right_mlos = get_view('R MLO')
        return left_ccs, left_mlos, right_ccs, right_mlos
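
# A minimal, hypothetical exam record of the shape image_paths_by_views and
# build_path_to_device_map consume: parallel 'views', 'files', and
# 'manufacturer_models' lists. With use_c_view_if_available set, get_view
# above keeps only the C-View path for 'L CC'; with it unset, C-View entries
# are dropped. Paths and device names here are invented for illustration.
example_exam = {
    'accession': 'A0001',
    'views': ['L CC', 'L CC C-View', 'L MLO', 'R CC', 'R MLO'],
    'files': ['/imgs/l_cc.png', '/imgs/l_cc_cview.png', '/imgs/l_mlo.png',
              '/imgs/r_cc.png', '/imgs/r_mlo.png'],
    'manufacturer_models': ['Selenia Dimensions'] * 5,
}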
class Abstract_Onco_Dataset(data.Dataset):
    """
    Abstract Object for all Onco Datasets. All datasets have some metadata
    property associated with them, a create_dataset method, a task, and
    check_label and get_label functions.
    """
    __metaclass__ = ABCMeta

    def __init__(self, args, transformers, split_group):
        '''
        params: args - config.
        params: transformers - A transformer object, takes in a PIL image,
                performs some transforms and returns a Tensor.
        params: split_group - ['train'|'dev'|'test'].

        constructs: standard pytorch Dataset obj, which can be fed in a
        DataLoader for batching.
        '''
        super(Abstract_Onco_Dataset, self).__init__()
        args.metadata_path = os.path.join(args.metadata_dir,
                                          self.METADATA_FILENAME)
        self.args = args
        self.image_loader = image_loader(args.cache_path, transformers)
        try:
            self.metadata_json = json.load(open(args.metadata_path, 'r'))
        except Exception as e:
            raise Exception(METAFILE_NOTFOUND_ERR.format(
                args.metadata_path, e))

        self.dataset = self.create_dataset(split_group, args.img_dir)
        if split_group == 'train' and self.args.data_fraction < 1.0:
            self.dataset = np.random.choice(
                self.dataset,
                int(len(self.dataset) * self.args.data_fraction),
                replace=False)
        self.risk_factor_vectorizer = RiskFactorVectorizer(args)
        if self.args.use_risk_factors:
            self.add_risk_factors_to_dataset()

        if 'dist_key' in self.dataset[0] and args.year_weighted_class_bal:
            dist_key = 'dist_key'
        else:
            dist_key = 'y'
        label_dist = [d[dist_key] for d in self.dataset]
        label_counts = Counter(label_dist)
        weight_per_label = 1. / len(label_counts)
        label_weights = {
            label: weight_per_label / count
            for label, count in label_counts.items()
        }
        if args.year_weighted_class_bal or args.class_bal:
            print("Label weights are {}".format(label_weights))
        self.weights = [label_weights[d[dist_key]] for d in self.dataset]

    @property
    @abstractmethod
    def task(self):
        pass

    @property
    @abstractmethod
    def METADATA_FILENAME(self):
        pass

    @abstractmethod
    def check_label(self, row):
        '''
        Return True if the row contains a valid label for the task.
        :row: - metadata row
        '''
        pass

    @abstractmethod
    def get_label(self, row):
        '''
        Get task specific label for a given metadata row.
        :row: - metadata row which contains label information
        '''
        pass

    @abstractmethod
    def create_dataset(self, split_group, img_dir):
        """
        Creating the dataset from the paths and labels in the json.
        :split_group: - ['train'|'dev'|'test'].
        :img_dir: - The path to the dir containing the images.
        """
        pass

    @staticmethod
    def set_args(args):
        """Sets any args particular to the dataset."""
        pass

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample = self.dataset[index]
        try:
            additional = {} if sample['additional'] is None \
                else sample['additional']
            x = self.image_loader.get_image(sample['path'], additional)
            item = {'x': x, 'path': sample['path'], 'y': sample['y']}
            if 'exam' in sample:
                item['exam'] = sample['exam']
            if self.args.use_risk_factors:
                # Note, risk factors not supported for target objects
                item['risk_factors'] = sample['risk_factors']
            return item
        except Exception:
            # Falls through and returns None on a failed load.
            warnings.warn(
                LOAD_FAIL_MSG.format(sample['path'], traceback.format_exc()))

    def add_risk_factors_to_dataset(self):
        for sample in self.dataset:
            sample['risk_factors'] = \
                self.risk_factor_vectorizer.get_risk_factors_for_sample(
                    sample)

    def image_paths_by_views(self, exam):
        '''
        Determine images of left and right CCs and MLOs.

        Args:
        exam - a dictionary with views and files sorted relatively.

        returns: 4 lists of image paths of each view by this order:
        left_ccs, left_mlos, right_ccs, right_mlos.
        Force max 1 image per view.

        Note: Validation of cancer side is performed in the query
        scripts/from_db/cancer.py in OncoQueries.
        '''
        left_ccs = [
            image_path
            for view, image_path in zip(exam['views'], exam['files'])
            if view.startswith('L CC')
        ][:1]
        left_mlos = [
            image_path
            for view, image_path in zip(exam['views'], exam['files'])
            if view.startswith('L MLO')
        ][:1]
        right_ccs = [
            image_path
            for view, image_path in zip(exam['views'], exam['files'])
            if view.startswith('R CC')
        ][:1]
        right_mlos = [
            image_path
            for view, image_path in zip(exam['views'], exam['files'])
            if view.startswith('R MLO')
        ][:1]
        return left_ccs, left_mlos, right_ccs, right_mlos
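
# A minimal sketch of the item contract a concrete subclass of this dataset
# exposes: each __getitem__ returns a dict with 'x' (the image tensor),
# 'path', 'y', optionally 'exam', and 'risk_factors' when
# args.use_risk_factors is set. `dataset` is assumed to be an
# already-constructed subclass instance.
from torch.utils import data as torch_data

def iterate_items(dataset, batch_size=16):
    loader = torch_data.DataLoader(dataset, batch_size=batch_size,
                                   shuffle=True)
    for batch in loader:
        x, y = batch['x'], batch['y']
        # forward pass / loss computation would go here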