def test_should_throw_on_invalid_model_trainer_config(self):
    """Parsing a model entry from the invalid config file must raise InvalidConfigurationError."""
    parsed_section = YamlConfigurationParser.parse_section(
        self.INVALID_CONFIG_FILE_PATH, self.MODELS_CONFIG_YML_TAG)
    model_entry = parsed_section[self.SIMPLE_NET_NAME]
    assert_that(
        calling(ModelConfiguration.from_dict).with_args(self.SIMPLE_NET_NAME, model_entry),
        raises(InvalidConfigurationError))
def setUp(self) -> None:
    """Build the two model trainer configurations, a run configuration and a toy conv model."""
    section = YamlConfigurationParser.parse_section(
        self.VALID_CONFIG_FILE_PATH, self.MODELS_CONFIG_YML_TAG)
    # Both named models are parsed from the same YAML section.
    self._model_trainer_configs = [
        ModelConfiguration.from_dict(model_name, section[model_name])
        for model_name in (self.SIMPLE_NET_NAME, self.SIMPLE_NET_NAME_2)
    ]
    self._run_config = RunConfiguration()
    self._model = nn.Conv2d(1, 32, (3, 3))
def from_yml(cls, yml_file, yml_tag="visdom"):
    """Build a VisdomConfiguration from one section of a YAML file.

    Args:
        yml_file: path to the YAML configuration file.
        yml_tag: name of the section to read (defaults to "visdom").

    Returns:
        The VisdomConfiguration built from the parsed section.
    """
    section = YamlConfigurationParser.parse_section(yml_file, yml_tag)
    return VisdomConfiguration.from_dict(section)
# Script entry point: configure runtime, parse CLI args and YAML configs,
# then set up containers for the datasets. (Fragment — continues past this chunk.)
if __name__ == '__main__':
    # Basic settings
    logging.basicConfig(level=logging.INFO)
    # Give both the intra-op and inter-op thread pools one thread per CPU core.
    torch.set_num_threads(multiprocessing.cpu_count())
    torch.set_num_interop_threads(multiprocessing.cpu_count())
    args = ArgsParserFactory.create_parser(
        ArgsParserType.MODEL_TRAINING).parse_args()
    # Create configurations.
    run_config = RunConfiguration(use_amp=args.use_amp,
                                  local_rank=args.local_rank,
                                  amp_opt_level=args.amp_opt_level)
    model_trainer_configs, training_config = YamlConfigurationParser.parse(
        args.config_file)
    dataset_configs = YamlConfigurationParser.parse_section(
        args.config_file, "dataset")
    # Wrap each raw dataset section into a typed DatasetConfiguration.
    # NOTE(review): the trailing comma in "k, v," is legal but confusing — consider "k, v".
    dataset_configs = {
        k: DatasetConfiguration(v) for k, v, in dataset_configs.items()
    }
    data_augmentation_config = YamlConfigurationParser.parse_section(
        args.config_file, "data_augmentation")
    # HTML renderings of the configs — presumably consumed by an experiment
    # dashboard later in the script; TODO confirm against the full file.
    config_html = [
        training_config.to_html(),
        list(map(lambda config: config.to_html(), dataset_configs.values())),
        list(map(lambda config: config.to_html(), model_trainer_configs))
    ]
    # Prepare the data.
    train_datasets = list()
    valid_datasets = list()
# Enable cuDNN and pin the NumPy / stdlib RNG seeds for reproducible runs.
cudnn.enabled = True
np.random.seed(42)
random.seed(42)

# Script entry point: configure runtime, parse CLI args and YAML configs,
# then set up containers for datasets and reconstructors.
# (Fragment — continues past this chunk.)
if __name__ == '__main__':
    # Basic settings
    logging.basicConfig(level=logging.INFO)
    # Give both the intra-op and inter-op thread pools one thread per CPU core.
    torch.set_num_threads(multiprocessing.cpu_count())
    torch.set_num_interop_threads(multiprocessing.cpu_count())
    args = ArgsParserFactory.create_parser(ArgsParserType.MODEL_TRAINING).parse_args()
    # Create configurations.
    run_config = RunConfiguration(use_amp=args.use_amp,
                                  local_rank=args.local_rank,
                                  amp_opt_level=args.amp_opt_level)
    model_trainer_configs, training_config = YamlConfigurationParser.parse(args.config_file)
    dataset_configs = YamlConfigurationParser.parse_section(args.config_file, "dataset")
    # Wrap each raw dataset section into a typed DatasetConfiguration.
    # NOTE(review): the trailing comma in "k, v," is legal but confusing — consider "k, v".
    dataset_configs = {k: DatasetConfiguration(v) for k, v, in dataset_configs.items()}
    # HTML renderings of the configs — presumably consumed by an experiment
    # dashboard later in the script; TODO confirm against the full file.
    config_html = [training_config.to_html(),
                   list(map(lambda config: config.to_html(), dataset_configs.values())),
                   list(map(lambda config: config.to_html(), model_trainer_configs))]
    # Prepare the data.
    train_datasets = list()
    valid_datasets = list()
    test_datasets = list()
    reconstruction_datasets = list()
    augmented_reconstruction_datasets = list()
    normalized_reconstructors = list()
    segmentation_reconstructors = list()
    input_reconstructors = list()
    gt_reconstructors = list()
    augmented_input_reconstructors = list()
def test_should_parse_valid_model_trainer_config(self):
    """Parsing the valid config file yields the expected raw dict and typed sub-configs.

    Checks both the parsed YAML section (against a literal expected dict) and the
    optimizer/scheduler/criterion/metric fields of the resulting ModelConfiguration.
    """
    # Both model entries intentionally share the same settings.
    expected_config_dict = {
        self.SIMPLE_NET_NAME: {
            'type': self.SIMPLE_NET_TYPE,
            'optimizer': {
                'type': self.SIMPLE_NET_OPTIMIZER_TYPE,
                'params': self.SIMPLE_NET_OPTIMIZER_PARAMS
            },
            'scheduler': {
                'type': self.SIMPLE_NET_SCHEDULER_TYPE,
                'params': self.SIMPLE_NET_SCHEDULER_PARAMS
            },
            'criterion': {
                "cycle": {
                    'type': self.SIMPLE_NET_CRITERION_TYPE
                },
                "gan": {
                    'type': self.SIMPLE_NET_CRITERION_TYPE_2
                }
            },
            'metrics': {
                'Dice': {
                    'type': self.SIMPLE_NET_METRIC_TYPE_1,
                    'params': self.SIMPLE_NET_METRIC_PARAMS_1
                },
                'Accuracy': {
                    'type': self.SIMPLE_NET_METRIC_TYPE_2
                }
            },
            'gradients': self.SIMPLE_NET_GRADIENT_CLIPPING
        },
        self.SIMPLE_NET_NAME_2: {
            'type': self.SIMPLE_NET_TYPE,
            'optimizer': {
                'type': self.SIMPLE_NET_OPTIMIZER_TYPE,
                'params': self.SIMPLE_NET_OPTIMIZER_PARAMS
            },
            'scheduler': {
                'type': self.SIMPLE_NET_SCHEDULER_TYPE,
                'params': self.SIMPLE_NET_SCHEDULER_PARAMS
            },
            'criterion': {
                "cycle": {
                    'type': self.SIMPLE_NET_CRITERION_TYPE
                },
                "gan": {
                    'type': self.SIMPLE_NET_CRITERION_TYPE_2
                }
            },
            'metrics': {
                'Dice': {
                    'type': self.SIMPLE_NET_METRIC_TYPE_1,
                    'params': self.SIMPLE_NET_METRIC_PARAMS_1
                },
                'Accuracy': {
                    'type': self.SIMPLE_NET_METRIC_TYPE_2
                }
            },
            'gradients': self.SIMPLE_NET_GRADIENT_CLIPPING
        }
    }
    config_dict = YamlConfigurationParser.parse_section(
        self.VALID_CONFIG_FILE_PATH, self.MODELS_CONFIG_YML_TAG)
    # Only the first model's entry is converted into a typed configuration.
    model_trainer_config = ModelConfiguration.from_dict(
        self.SIMPLE_NET_NAME, config_dict[self.SIMPLE_NET_NAME])
    assert_that(config_dict, equal_to(expected_config_dict))
    assert_that(model_trainer_config.optimizer_config.type,
                equal_to(self.SIMPLE_NET_OPTIMIZER_TYPE))
    assert_that(model_trainer_config.optimizer_config.params,
                equal_to(self.SIMPLE_NET_OPTIMIZER_PARAMS))
    assert_that(model_trainer_config.scheduler_config.type,
                equal_to(self.SIMPLE_NET_SCHEDULER_TYPE))
    assert_that(model_trainer_config.scheduler_config.params,
                equal_to(self.SIMPLE_NET_SCHEDULER_PARAMS))
    # Criterion configs appear to be ordered "cycle" then "gan" — presumably
    # insertion order of the YAML mapping; verify against from_dict.
    assert_that(model_trainer_config.criterions_configs[0].type,
                equal_to(self.SIMPLE_NET_CRITERION_TYPE))
    assert_that(model_trainer_config.criterions_configs[1].type,
                equal_to(self.SIMPLE_NET_CRITERION_TYPE_2))
    assert_that(model_trainer_config.metrics_configs[0].type,
                equal_to(self.SIMPLE_NET_METRIC_TYPE_1))
    assert_that(model_trainer_config.metrics_configs[0].params,
                equal_to(self.SIMPLE_NET_METRIC_PARAMS_1))