def test_loads_from_cache(self):
    transformers = [self.transformers['c1']]
    loader = image_loader.image_loader(self.cache_path, transformers)
    image_path = '/some/test'

    # Trick loader by storing specific image in cache
    key = transformers[0].caching_keys()
    loader.cache.add(image_path, key, blue_image)

    # Load image
    output = loader.get_image(image_path, None)
    self.assertEqual(output.getdata()[0], blue_image.getdata()[0])
def __init__(self, args, transformers, split_group):
    '''
    params: args - config.
    params: transformers - A list of transformer objects; each takes in a PIL image,
        performs some transforms and returns a Tensor.
    params: split_group - ['train'|'dev'|'test'].

    constructs: standard pytorch Dataset obj, which can be fed in a DataLoader for batching
    '''
    super(Abstract_Onco_Dataset, self).__init__()

    args.metadata_path = os.path.join(args.metadata_dir,
                                      self.METADATA_FILENAME)
    self.args = args
    self.image_loader = image_loader(args.cache_path, transformers)
    try:
        self.metadata_json = json.load(open(args.metadata_path, 'r'))
    except Exception as e:
        raise Exception(METAFILE_NOTFOUND_ERR.format(
            args.metadata_path, e))

    self.dataset = self.create_dataset(split_group, args.img_dir)
    if split_group == 'train' and self.args.data_fraction < 1.0:
        self.dataset = np.random.choice(
            self.dataset,
            int(len(self.dataset) * self.args.data_fraction),
            replace=False)
    self.risk_factor_vectorizer = RiskFactorVectorizer(args)
    if self.args.use_risk_factors:
        self.add_risk_factors_to_dataset()

    if 'dist_key' in self.dataset[0] and args.year_weighted_class_bal:
        dist_key = 'dist_key'
    else:
        dist_key = 'y'

    label_dist = [d[dist_key] for d in self.dataset]
    label_counts = Counter(label_dist)
    weight_per_label = 1. / len(label_counts)
    label_weights = {
        label: weight_per_label / count
        for label, count in label_counts.items()
    }
    if args.year_weighted_class_bal or args.class_bal:
        print("Label weights are {}".format(label_weights))
    self.weights = [label_weights[d[dist_key]] for d in self.dataset]
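# A quick worked example of the class-balancing arithmetic above (illustrative
# only; the label values are made up). With three negatives and one positive,
# each class ends up with the same total sampling mass:
from collections import Counter

labels = [0, 0, 0, 1]
label_counts = Counter(labels)             # Counter({0: 3, 1: 1})
weight_per_label = 1. / len(label_counts)  # 0.5
label_weights = {l: weight_per_label / c for l, c in label_counts.items()}
# label_weights == {0: 0.5 / 3, 1: 0.5}
weights = [label_weights[l] for l in labels]
# Total mass per class: 3 * (0.5 / 3) == 0.5 for label 0 and 1 * 0.5 for label 1,
# so a weighted sampler draws both classes equally often in expectation.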
def test_adds_to_cache(self):
    transformers = [self.transformers['c1']]
    loader = image_loader.image_loader(self.cache_path, transformers)

    # Save some test image (the cache dir is used only as a convenient tmp location)
    image_path = self.cache_path + 'test.png'
    blue_image.save(image_path)

    # Load image
    output = loader.get_image(image_path, None)
    self.assertEqual(output.getdata()[0], white_image.getdata()[0])

    # Validate that the correct images were cached
    key = transformers[0].caching_keys()
    default_image = loader.cache.get(image_path, 'default/')
    self.assertEqual(default_image.getdata()[0], blue_image.getdata()[0])
    c1_image = loader.cache.get(image_path, key)
    self.assertEqual(c1_image.getdata()[0], white_image.getdata()[0])
def test_non_cachable(self):
    transformers = [self.transformers['nc'], self.transformers['c1']]
    loader = image_loader.image_loader(self.cache_path, transformers)

    # Save some test image
    image_path = self.cache_path + 'test.png'
    blue_image.save(image_path)

    # Load image
    output = loader.get_image(image_path, None)
    self.assertEqual(output.getdata()[0], white_image.getdata()[0])

    # Validate that the default image was cached
    default_image = loader.cache.get(image_path, 'default/')
    self.assertEqual(default_image.getdata()[0], blue_image.getdata()[0])

    # Validate that 'nc' wasn't cached
    key = transformers[0].caching_keys()
    self.assertFalse(loader.cache.exists(image_path, key))

    # Validate that 'c1' wasn't cached either, because it comes after a non-cachable transformer
    key = transformers[1].caching_keys()
    self.assertFalse(loader.cache.exists(image_path, key))
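# The tests above rely on fixtures presumably created in setUp(): blue_image /
# white_image, a temporary self.cache_path, and a self.transformers dict with a
# cachable transformer ('c1') and a non-cachable one ('nc'). A minimal sketch of
# such fixtures is below; apart from caching_keys(), which the tests call
# directly, the transformer interface shown here (a __call__ returning the
# transformed image plus a cachable() flag) is an assumption, not the repo's
# confirmed API.
from PIL import Image

blue_image = Image.new('RGB', (16, 16), color=(0, 0, 255))
white_image = Image.new('RGB', (16, 16), color=(255, 255, 255))

class FakeTransformer(object):
    '''Hypothetical test double: maps any input image to white_image.'''

    def __init__(self, name, cachable):
        self.name = name
        self._cachable = cachable

    def __call__(self, image, additional=None):
        return white_image

    def cachable(self):
        return self._cachable

    def caching_keys(self):
        return self.name + '/'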
def __init__(self, args, transformers, split_group):
    '''
    params: args - config.
    params: transformers - A list of transformer objects; each takes in a PIL image,
        performs some transforms and returns a Tensor.
    params: split_group - ['train'|'dev'|'test'].

    constructs: standard pytorch Dataset obj, which can be fed in a DataLoader for batching
    '''
    super(Abstract_Onco_Dataset, self).__init__()

    if args.metadata_dir is not None and args.metadata_path is None:
        args.metadata_path = os.path.join(args.metadata_dir,
                                          self.METADATA_FILENAME)

    self.split_group = split_group
    self.args = args
    self.image_loader = image_loader(args.cache_path, transformers)
    try:
        if 'json' in args.metadata_path:
            self.metadata_json = json.load(open(args.metadata_path, 'r'))
        else:
            assert 'csv' in args.metadata_path
            _reader = csv.DictReader(open(args.metadata_path, 'r'))
            self.metadata_json = [r for r in _reader]
    except Exception as e:
        raise Exception(METAFILE_NOTFOUND_ERR.format(
            args.metadata_path, e))

    self.path_to_hidden_dict = {}
    self.dataset = self.create_dataset(split_group, args.img_dir)
    if len(self.dataset) == 0:
        return
    if split_group == 'train' and self.args.data_fraction < 1.0:
        self.dataset = np.random.choice(
            self.dataset,
            int(len(self.dataset) * self.args.data_fraction),
            replace=False)

    try:
        self.add_device_to_dataset()
        if "all" not in self.args.allowed_devices:
            self.dataset = [
                d for d in self.dataset
                if (d['device_name'] if isinstance(d['device_name'], str)
                    else d['device_name'][0]) in self.args.allowed_devices
            ]
    except Exception:
        print("Could not add device information to dataset")

    for d in self.dataset:
        if 'exam' in d and 'year' in d:
            args.exam_to_year_dict[d['exam']] = d['year']
        if 'device_name' in d and 'exam' in d:
            args.exam_to_device_dict[d['exam']] = d['device_name']
    print(self.get_summary_statement(self.dataset, split_group))

    if args.use_region_annotation:
        self.region_annotations = parse_region_annotations(args)
    args.h_arr, args.w_arr = None, None

    self.risk_factor_vectorizer = None
    if self.args.use_risk_factors:
        self.risk_factor_vectorizer = RiskFactorVectorizer(args)
        self.add_risk_factors_to_dataset()

    if 'dist_key' in self.dataset[0] and (
            args.year_weighted_class_bal
            or args.shift_class_bal_towards_imediate_cancers
            or args.device_class_bal):
        dist_key = 'dist_key'
    else:
        dist_key = 'y'

    label_dist = [d[dist_key] for d in self.dataset]
    label_counts = Counter(label_dist)
    weight_per_label = 1. / len(label_counts)
    label_weights = {
        label: weight_per_label / count
        for label, count in label_counts.items()
    }
    if args.year_weighted_class_bal or args.class_bal:
        print("Label weights are {}".format(label_weights))
    self.weights = [label_weights[d[dist_key]] for d in self.dataset]
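# A minimal sketch of how such a dataset is typically consumed, assuming a
# concrete subclass (here called `MammoDataset`, a hypothetical name) and an
# `args` namespace carrying a `batch_size` field in addition to the flags used
# above. self.weights pairs naturally with torch's WeightedRandomSampler for
# class-balanced batches:
from torch.utils.data import DataLoader, WeightedRandomSampler

train_dataset = MammoDataset(args, transformers, 'train')
if args.class_bal or args.year_weighted_class_bal:
    sampler = WeightedRandomSampler(train_dataset.weights,
                                    num_samples=len(train_dataset),
                                    replacement=True)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              sampler=sampler)
else:
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True)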