Example #1
    def test_loads_from_cache(self):
        transformers = [self.transformers['c1']]
        loader = image_loader.image_loader(self.cache_path, transformers)
        image_path = '/some/test'

        # Trick the loader by pre-storing a specific image in the cache
        key = transformers[0].caching_keys()
        loader.cache.add(image_path, key, blue_image)

        # Load image
        output = loader.get_image(image_path, None)
        self.assertEqual(output.getdata()[0], blue_image.getdata()[0])
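This test and the later ones rely on module-level image fixtures (blue_image, white_image) and on transformers that expose a caching_keys() method. A minimal sketch of such fixtures, with hypothetical class and key names (the real test module presumably defines its own), might look like this:

from PIL import Image

# Hypothetical fixtures; names and sizes are assumptions, not the project's actual values.
blue_image = Image.new('RGB', (32, 32), color='blue')
white_image = Image.new('RGB', (32, 32), color='white')

class WhitenTransformer(object):
    '''Toy cachable transformer that always produces a white image.'''

    def caching_keys(self):
        # Key under which this transformer's output would be stored in the cache.
        return 'whiten/'

    def __call__(self, image):
        # The call signature image_loader actually uses is an assumption here.
        return Image.new('RGB', image.size, color='white')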
Example #2
    def __init__(self, args, transformers, split_group):
        '''
        params: args - config.
        params: transformers - a list of transformer objects; each takes a PIL image, applies some transforms, and returns a Tensor.
        params: split_group - ['train'|'dev'|'test'].

        constructs: a standard PyTorch Dataset object, which can be fed to a DataLoader for batching.
        '''
        super(Abstract_Onco_Dataset, self).__init__()
        args.metadata_path = os.path.join(args.metadata_dir,
                                          self.METADATA_FILENAME)

        self.args = args
        self.image_loader = image_loader(args.cache_path, transformers)
        try:
            self.metadata_json = json.load(open(args.metadata_path, 'r'))
        except Exception as e:
            raise Exception(METAFILE_NOTFOUND_ERR.format(
                args.metadata_path, e))

        self.dataset = self.create_dataset(split_group, args.img_dir)
        if split_group == 'train' and self.args.data_fraction < 1.0:
            self.dataset = np.random.choice(
                self.dataset,
                int(len(self.dataset) * self.args.data_fraction),
                replace=False)
        self.risk_factor_vectorizer = RiskFactorVectorizer(args)
        if self.args.use_risk_factors:
            self.add_risk_factors_to_dataset()

        if 'dist_key' in self.dataset[0] and args.year_weighted_class_bal:
            dist_key = 'dist_key'
        else:
            dist_key = 'y'

        label_dist = [d[dist_key] for d in self.dataset]
        label_counts = Counter(label_dist)
        weight_per_label = 1. / len(label_counts)
        label_weights = {
            label: weight_per_label / count
            for label, count in label_counts.items()
        }
        if args.year_weighted_class_bal or args.class_bal:
            print("Label weights are {}".format(label_weights))
        self.weights = [label_weights[d[dist_key]] for d in self.dataset]
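Because the constructor builds a standard PyTorch Dataset and computes per-sample weights, a typical way to consume a concrete subclass is through a DataLoader, optionally with a WeightedRandomSampler for class balancing. A minimal usage sketch, where MammoDataset is a hypothetical subclass name and args is assumed to expose batch_size and num_workers:

from torch.utils.data import DataLoader, WeightedRandomSampler

# MammoDataset is a stand-in name for a concrete Abstract_Onco_Dataset subclass.
train_data = MammoDataset(args, transformers, 'train')
sampler = WeightedRandomSampler(weights=train_data.weights,
                                num_samples=len(train_data),
                                replacement=True)
train_loader = DataLoader(train_data,
                          batch_size=args.batch_size,
                          sampler=sampler,
                          num_workers=args.num_workers)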
Example #3
    def test_adds_to_cache(self):
        transformers = [self.transformers['c1']]
        loader = image_loader.image_loader(self.cache_path, transformers)

        # Save a test image (the cache dir is used here only as a convenient tmp location)
        image_path = self.cache_path + 'test.png'
        blue_image.save(image_path)

        # Load image
        output = loader.get_image(image_path, None)
        self.assertEqual(output.getdata()[0], white_image.getdata()[0])

        # Validate that the correct images were cached
        key = transformers[0].caching_keys()
        default_image = loader.cache.get(image_path, 'default/')
        self.assertEqual(default_image.getdata()[0], blue_image.getdata()[0])

        c1_image = loader.cache.get(image_path, key)
        self.assertEqual(c1_image.getdata()[0], white_image.getdata()[0])
Example #4
    def test_non_cachable(self):
        transformers = [self.transformers['nc'], self.transformers['c1']]
        loader = image_loader.image_loader(self.cache_path, transformers)

        # Save a test image
        image_path = self.cache_path + 'test.png'
        blue_image.save(image_path)

        # Load image
        output = loader.get_image(image_path, None)
        self.assertEqual(output.getdata()[0], white_image.getdata()[0])

        # Validate that the default image was cached
        default_image = loader.cache.get(image_path, 'default/')
        self.assertEqual(default_image.getdata()[0], blue_image.getdata()[0])

        # Validate that 'nc' wasn't cached
        key = transformers[0].caching_keys()
        self.assertFalse(loader.cache.exists(image_path, key))

        # Validate that 'c1' wasn't cached either, because it comes after a non-cacheable transformer
        key = transformers[1].caching_keys()
        self.assertFalse(loader.cache.exists(image_path, key))
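The last two assertions capture the caching rule this test exercises: only the prefix of transformers up to (but not including) the first non-cacheable one ever gets cached, so anything downstream of 'nc' is recomputed on every load. A small sketch of that prefix rule, using a hypothetical is_cachable() predicate (the real API may differ):

def cachable_prefix_keys(transformers):
    '''Return caching keys for the prefix of transformers whose output may be cached.'''
    keys = []
    for t in transformers:
        if not t.is_cachable():  # hypothetical predicate, assumed for illustration
            break                # nothing after a non-cacheable step is cached
        keys.append(t.caching_keys())
    return keys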
Example #5
    def __init__(self, args, transformers, split_group):
        '''
        params: args - config.
        params: transformers - a list of transformer objects; each takes a PIL image, applies some transforms, and returns a Tensor.
        params: split_group - ['train'|'dev'|'test'].

        constructs: a standard PyTorch Dataset object, which can be fed to a DataLoader for batching.
        '''
        super(Abstract_Onco_Dataset, self).__init__()

        if args.metadata_dir is not None and args.metadata_path is None:
            args.metadata_path = os.path.join(args.metadata_dir,
                                              self.METADATA_FILENAME)

        self.split_group = split_group
        self.args = args
        self.image_loader = image_loader(args.cache_path, transformers)

        try:
            if 'json' in args.metadata_path:
                self.metadata_json = json.load(open(args.metadata_path, 'r'))
            else:
                assert 'csv' in args.metadata_path
                _reader = csv.DictReader(open(args.metadata_path, 'r'))
                self.metadata_json = [r for r in _reader]
        except Exception as e:
            raise Exception(METAFILE_NOTFOUND_ERR.format(
                args.metadata_path, e))

        self.path_to_hidden_dict = {}
        self.dataset = self.create_dataset(split_group, args.img_dir)
        if len(self.dataset) == 0:
            return
        if split_group == 'train' and self.args.data_fraction < 1.0:
            self.dataset = np.random.choice(
                self.dataset,
                int(len(self.dataset) * self.args.data_fraction),
                replace=False)
        try:
            self.add_device_to_dataset()
            if "all" not in self.args.allowed_devices:
                self.dataset = [
                    d for d in self.dataset
                    if (d['device_name'] if isinstance(d['device_name'], str)
                        else d['device_name'][0]) in self.args.allowed_devices
                ]
        except Exception:
            print("Could not add device information to dataset")
        for d in self.dataset:
            if 'exam' in d and 'year' in d:
                args.exam_to_year_dict[d['exam']] = d['year']
            if 'device_name' in d and 'exam' in d:
                args.exam_to_device_dict[d['exam']] = d['device_name']
        print(self.get_summary_statement(self.dataset, split_group))
        if args.use_region_annotation:
            self.region_annotations = parse_region_annotations(args)
        args.h_arr, args.w_arr = None, None
        self.risk_factor_vectorizer = None
        if self.args.use_risk_factors:
            self.risk_factor_vectorizer = RiskFactorVectorizer(args)
            self.add_risk_factors_to_dataset()

        if 'dist_key' in self.dataset[0] and (
                args.year_weighted_class_bal
                or args.shift_class_bal_towards_imediate_cancers
                or args.device_class_bal):
            dist_key = 'dist_key'
        else:
            dist_key = 'y'

        label_dist = [d[dist_key] for d in self.dataset]
        label_counts = Counter(label_dist)
        weight_per_label = 1. / len(label_counts)
        label_weights = {
            label: weight_per_label / count
            for label, count in label_counts.items()
        }
        if args.year_weighted_class_bal or args.class_bal:
            print("Label weights are {}".format(label_weights))
        self.weights = [label_weights[d[dist_key]] for d in self.dataset]
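The weight computation at the end of both constructors gives each label an equal total share of sampling mass, split evenly among that label's examples. For instance, with a toy distribution of three negatives and one positive:

from collections import Counter

label_dist = [0, 0, 0, 1]                    # toy label distribution
label_counts = Counter(label_dist)           # Counter({0: 3, 1: 1})
weight_per_label = 1. / len(label_counts)    # 0.5 of the total mass per label
label_weights = {label: weight_per_label / count
                 for label, count in label_counts.items()}
# label_weights == {0: 0.5 / 3, 1: 0.5}; each class's weights sum to 0.5
weights = [label_weights[y] for y in label_dist]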