def test_create_subset_random_sampler_range_str(self):
    """
    Tests whether SubsetRandomSampler accepts 'indices' given as option 2:
    a range encoded as a string ("start, stop").
    """
    # Sampler section: the range string '0, 20' is expected to yield 20 indices.
    sampler_params = {'name': 'SubsetRandomSampler', 'indices': '0, 20'}

    cfg = ConfigInterface()
    cfg.add_default_params(sampler_params)

    # Build the sampler on top of a mock problem.
    built = SamplerFactory.build(TestProblemMockup(), cfg)

    # The string range must translate into exactly 20 samples.
    self.assertEqual(len(built), 20)
def test_create_subset_random_sampler_list_of_indices(self):
    """
    Tests whether SubsetRandomSampler accepts 'indices' given as option 3:
    an explicit list of indices (as obtained from a parsed YAML config).
    """
    # Use yaml.safe_load: yaml.load() without an explicit Loader is deprecated
    # (PyYAML 5.1) and raises TypeError since PyYAML 6.0; safe_load works on all
    # versions and only builds plain Python objects.
    yaml_list = yaml.safe_load('[0, 2, 5, 10]')
    config = ConfigInterface()
    config.add_default_params({'name': 'SubsetRandomSampler', 'indices': yaml_list})
    # Create the sampler.
    sampler = SamplerFactory.build(TestProblemMockup(), config)
    # Check number of samples: one per listed index.
    self.assertEqual(len(sampler), 4)
def test_create_subset_random_sampler_range(self):
    """
    Tests whether SubsetRandomSampler accepts 'indices' given as option 1:
    a Python range object.
    """
    # Sampler section keyed by 'type', with indices covering range(20).
    params = dict(type='SubsetRandomSampler', indices=range(20))

    cfg = ConfigInterface()
    cfg.add_default_params(params)

    # Build the sampler for the "training" section of a mock task.
    task = TestTaskMockup()
    built = SamplerFactory.build(task, cfg, "training")

    # A range of 20 elements must give 20 samples.
    self.assertEqual(len(built), 20)
def test_create_subset_random_sampler_file(self):
    """
    Tests whether SubsetRandomSampler accepts 'indices' given as option 4:
    the name of a file containing comma-separated indices.
    """
    import os
    import tempfile

    # Use a private temporary directory instead of a hard-coded /tmp path:
    # it is portable (no /tmp on Windows), cannot collide with concurrent runs,
    # and is removed automatically when the block exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        filename = os.path.join(tmp_dir, "tmp_indices.txt")
        # Store indices to file, separating elements with commas.
        indices = np.asarray([1, 2, 3, 4, 5], dtype=int)
        indices.tofile(filename, sep=',', format="%s")

        config = ConfigInterface()
        config.add_default_params({'name': 'SubsetRandomSampler', 'indices': filename})
        # Create the sampler while the indices file still exists.
        sampler = SamplerFactory.build(TestProblemMockup(), config)
        # Check number of samples: one per index stored in the file.
        self.assertEqual(len(sampler), 5)

#if __name__ == "__main__":
#    unittest.main()
def build(self, log=True):
    """
    Creates the problem, the optional sampler and the DataLoader described by
    the configuration section.

    :param log: If True, logs the sizes of the created problem/sampler and any
        detected errors (DEFAULT: True).

    :return: Number of detected errors: 0 on success, 1 when a configuration
        error or a missing configuration key was detected.
    """
    try:
        # Instantiate the component described by the 'problem' section.
        problem, problem_class = ComponentFactory.build("problem", self.config["problem"])

        # Only classes derived (even indirectly) from Problem are accepted.
        if not ComponentFactory.check_inheritance(problem_class, ptp.Problem.__name__):
            raise ConfigurationError(
                "Class '{}' is not derived from the Problem class!".format(
                    problem_class.__name__))
        self.problem = problem

        # The sampler is optional: the factory may return None.
        self.sampler = SamplerFactory.build(self.problem, self.config["sampler"])
        has_sampler = self.sampler is not None
        if has_sampler:
            # DataLoader treats sampler and shuffle as exclusive - force shuffle off.
            self.config['dataloader'].add_config_params({'shuffle': False})

        # Wrap the problem in a DataLoader; configuration keys are read in the
        # same order as the DataLoader arguments below.
        batch_size = self.config['problem']['batch_size']
        loader_cfg = self.config['dataloader']
        self.dataloader = DataLoader(
            dataset=self.problem,
            batch_size=batch_size,
            shuffle=loader_cfg['shuffle'],
            sampler=self.sampler,
            batch_sampler=loader_cfg['batch_sampler'],
            num_workers=loader_cfg['num_workers'],
            collate_fn=self.problem.collate_fn,
            pin_memory=loader_cfg['pin_memory'],
            drop_last=loader_cfg['drop_last'],
            timeout=loader_cfg['timeout'],
            worker_init_fn=self.worker_init_fn)

        if log:
            # Report the sizes of what was just created.
            self.logger.info("Problem for '{}' loaded (size: {})".format(
                self.name, len(self.problem)))
            if has_sampler:
                self.logger.info("Sampler for '{}' created (size: {})".format(
                    self.name, len(self.sampler)))

        # Success: no errors detected.
        return 0

    except ConfigurationError as err:
        if log:
            self.logger.error(
                "Detected configuration error while creating the problem instance:\n {}"
                .format(err))
        return 1

    except KeyError as err:
        if log:
            self.logger.error(
                "Detected key error while creating the problem instance: required key {} is missing"
                .format(err))
        return 1