コード例 #1
0
    def test_should_throw_on_invalid_model_trainer_config(self):
        """Building a ModelConfiguration from the invalid fixture must fail.

        The models section of ``INVALID_CONFIG_FILE_PATH`` is parsed, and
        ``ModelConfiguration.from_dict`` is expected to raise
        ``InvalidConfigurationError`` for the ``SIMPLE_NET_NAME`` entry.
        """
        models_section = YamlConfigurationParser.parse_section(
            self.INVALID_CONFIG_FILE_PATH, self.MODELS_CONFIG_YML_TAG)
        simple_net_section = models_section[self.SIMPLE_NET_NAME]

        # `calling(...).with_args(...)` defers the call so the matcher can
        # observe the exception instead of it escaping the test.
        build_simple_net = calling(ModelConfiguration.from_dict).with_args(
            self.SIMPLE_NET_NAME, simple_net_section)
        assert_that(build_simple_net, raises(InvalidConfigurationError))
コード例 #2
0
    def setUp(self) -> None:
        """Prepare the fixtures shared by the tests.

        Creates one ``ModelConfiguration`` per network declared in the valid
        YAML fixture (``SIMPLE_NET_NAME`` then ``SIMPLE_NET_NAME_2``), a
        default ``RunConfiguration``, and a single ``nn.Conv2d`` layer
        (1 input channel, 32 output channels, 3x3 kernel) as the model.
        """
        models_section = YamlConfigurationParser.parse_section(
            self.VALID_CONFIG_FILE_PATH, self.MODELS_CONFIG_YML_TAG)

        network_names = (self.SIMPLE_NET_NAME, self.SIMPLE_NET_NAME_2)
        self._model_trainer_configs = [
            ModelConfiguration.from_dict(net_name, models_section[net_name])
            for net_name in network_names
        ]
        self._run_config = RunConfiguration()
        self._model = nn.Conv2d(1, 32, (3, 3))
コード例 #3
0
ファイル: config.py プロジェクト: banctilrobitaille/kerosene
 def from_yml(cls, yml_file, yml_tag="visdom"):
     """Build a Visdom configuration from one section of a YAML file.

     Args:
         yml_file: YAML source handed to
             ``YamlConfigurationParser.parse_section`` (presumably a path).
         yml_tag: Name of the section to read; defaults to ``"visdom"``.

     Returns:
         The object produced by ``VisdomConfiguration.from_dict`` for that
         section.
     """
     # NOTE(review): this names VisdomConfiguration explicitly rather than
     # using ``cls``. If this is a classmethod on a subclassable type,
     # subclasses calling from_yml still get the base type. Confirm intent.
     return VisdomConfiguration.from_dict(
         YamlConfigurationParser.parse_section(yml_file, yml_tag))
コード例 #4
0
ファイル: main.py プロジェクト: sami-ets/DeepNormalize
if __name__ == '__main__':
    # Script entry point: configure logging and torch threading, parse the
    # CLI arguments and the YAML config file, then build the configuration
    # objects used by the (not shown) data-preparation and training code.

    # Basic settings
    # Log at INFO level and let torch use every CPU core for both
    # intra-op and inter-op parallelism.
    logging.basicConfig(level=logging.INFO)
    torch.set_num_threads(multiprocessing.cpu_count())
    torch.set_num_interop_threads(multiprocessing.cpu_count())
    # Command-line arguments for a model-training run.
    args = ArgsParserFactory.create_parser(
        ArgsParserType.MODEL_TRAINING).parse_args()

    # Create configurations.
    # Run-time options (use_amp, local_rank, amp_opt_level) come from the CLI;
    # the model-trainer and training configs come from the YAML file.
    run_config = RunConfiguration(use_amp=args.use_amp,
                                  local_rank=args.local_rank,
                                  amp_opt_level=args.amp_opt_level)
    model_trainer_configs, training_config = YamlConfigurationParser.parse(
        args.config_file)
    # Read the raw "dataset" section, then wrap each value in a
    # DatasetConfiguration while keeping the same keys.
    dataset_configs = YamlConfigurationParser.parse_section(
        args.config_file, "dataset")
    # NOTE(review): the trailing comma in `k, v,` is legal but unusual;
    # `for k, v in ...` is equivalent.
    dataset_configs = {
        k: DatasetConfiguration(v)
        for k, v, in dataset_configs.items()
    }
    # The "data_augmentation" section is kept as the raw parsed value
    # (not wrapped in a configuration class).
    data_augmentation_config = YamlConfigurationParser.parse_section(
        args.config_file, "data_augmentation")
    # HTML renderings of the configs, as a nested list:
    # [training html, [dataset htmls...], [model-trainer htmls...]].
    # Presumably consumed by a visualizer further down; TODO confirm.
    config_html = [
        training_config.to_html(),
        list(map(lambda config: config.to_html(), dataset_configs.values())),
        list(map(lambda config: config.to_html(), model_trainer_configs))
    ]

    # Prepare the data.
    # Accumulators, presumably filled by the data-loading code that follows.
    train_datasets = list()
    valid_datasets = list()
コード例 #5
0
ファイル: main_cc.py プロジェクト: jizongFox/DeepNormalize
# Module-level setup that runs on import: turn on the cuDNN backend
# (presumably `torch.backends.cudnn`; confirm against the imports) and seed
# Python's and NumPy's RNGs with a fixed value for repeatable runs.
# NOTE(review): torch's own RNG is not seeded in the visible code.
cudnn.enabled = True

np.random.seed(42)
random.seed(42)

if __name__ == '__main__':
    # Script entry point: configure logging and torch threading, parse the
    # CLI arguments and the YAML config file, then build the configuration
    # objects used by the (not shown) data-preparation and training code.

    # Basic settings
    # Log at INFO level and let torch use every CPU core for both
    # intra-op and inter-op parallelism.
    logging.basicConfig(level=logging.INFO)
    torch.set_num_threads(multiprocessing.cpu_count())
    torch.set_num_interop_threads(multiprocessing.cpu_count())
    # Command-line arguments for a model-training run.
    args = ArgsParserFactory.create_parser(ArgsParserType.MODEL_TRAINING).parse_args()

    # Create configurations.
    # Run-time options come from the CLI; model-trainer, training and dataset
    # configs come from the YAML file.
    run_config = RunConfiguration(use_amp=args.use_amp, local_rank=args.local_rank, amp_opt_level=args.amp_opt_level)
    model_trainer_configs, training_config = YamlConfigurationParser.parse(args.config_file)
    dataset_configs = YamlConfigurationParser.parse_section(args.config_file, "dataset")
    # Wrap each raw dataset section in a DatasetConfiguration, keeping its key.
    # NOTE(review): the trailing comma in `k, v,` is legal but unusual.
    dataset_configs = {k: DatasetConfiguration(v) for k, v, in dataset_configs.items()}
    # Nested list: [training html, [dataset htmls...], [model-trainer htmls...]].
    config_html = [training_config.to_html(), list(map(lambda config: config.to_html(), dataset_configs.values())),
                   list(map(lambda config: config.to_html(), model_trainer_configs))]

    # Prepare the data.
    # Empty accumulators, presumably filled per dataset by the code that
    # follows this excerpt.
    train_datasets = list()
    valid_datasets = list()
    test_datasets = list()
    reconstruction_datasets = list()
    augmented_reconstruction_datasets = list()
    normalized_reconstructors = list()
    segmentation_reconstructors = list()
    input_reconstructors = list()
    gt_reconstructors = list()
    augmented_input_reconstructors = list()
コード例 #6
0
    def test_should_parse_valid_model_trainer_config(self):
        """The valid fixture parses to the expected dict and ModelConfiguration.

        Both networks in the fixture are declared with identical settings, so
        one expected per-network mapping is reused for each of them. Only the
        ``SIMPLE_NET_NAME`` entry is turned into a ``ModelConfiguration`` and
        checked field by field.
        """
        expected_model_dict = {
            'type': self.SIMPLE_NET_TYPE,
            'optimizer': {
                'type': self.SIMPLE_NET_OPTIMIZER_TYPE,
                'params': self.SIMPLE_NET_OPTIMIZER_PARAMS,
            },
            'scheduler': {
                'type': self.SIMPLE_NET_SCHEDULER_TYPE,
                'params': self.SIMPLE_NET_SCHEDULER_PARAMS,
            },
            'criterion': {
                'cycle': {'type': self.SIMPLE_NET_CRITERION_TYPE},
                'gan': {'type': self.SIMPLE_NET_CRITERION_TYPE_2},
            },
            'metrics': {
                'Dice': {
                    'type': self.SIMPLE_NET_METRIC_TYPE_1,
                    'params': self.SIMPLE_NET_METRIC_PARAMS_1,
                },
                'Accuracy': {'type': self.SIMPLE_NET_METRIC_TYPE_2},
            },
            'gradients': self.SIMPLE_NET_GRADIENT_CLIPPING,
        }
        expected_config_dict = {
            self.SIMPLE_NET_NAME: expected_model_dict,
            self.SIMPLE_NET_NAME_2: expected_model_dict,
        }

        config_dict = YamlConfigurationParser.parse_section(
            self.VALID_CONFIG_FILE_PATH, self.MODELS_CONFIG_YML_TAG)
        model_trainer_config = ModelConfiguration.from_dict(
            self.SIMPLE_NET_NAME, config_dict[self.SIMPLE_NET_NAME])

        # The raw parse must match exactly before the object mapping is checked.
        assert_that(config_dict, equal_to(expected_config_dict))

        cfg = model_trainer_config
        # Optimizer and scheduler: type name plus parameter mapping.
        assert_that(cfg.optimizer_config.type, equal_to(self.SIMPLE_NET_OPTIMIZER_TYPE))
        assert_that(cfg.optimizer_config.params, equal_to(self.SIMPLE_NET_OPTIMIZER_PARAMS))
        assert_that(cfg.scheduler_config.type, equal_to(self.SIMPLE_NET_SCHEDULER_TYPE))
        assert_that(cfg.scheduler_config.params, equal_to(self.SIMPLE_NET_SCHEDULER_PARAMS))
        # Criteria: index 0 should be the "cycle" entry and index 1 the "gan" one.
        assert_that(cfg.criterions_configs[0].type, equal_to(self.SIMPLE_NET_CRITERION_TYPE))
        assert_that(cfg.criterions_configs[1].type, equal_to(self.SIMPLE_NET_CRITERION_TYPE_2))
        # Metrics: the first entry should be "Dice" with its params.
        assert_that(cfg.metrics_configs[0].type, equal_to(self.SIMPLE_NET_METRIC_TYPE_1))
        assert_that(cfg.metrics_configs[0].params, equal_to(self.SIMPLE_NET_METRIC_PARAMS_1))