Exemplo n.º 1
0
    def test_should_throw_on_invalid_model_trainer_config(self):
        """Building a model configuration from the invalid fixture must fail.

        Reads the models section of ``INVALID_CONFIG_FILE_PATH`` and checks that
        ``ModelConfiguration.from_dict`` raises ``InvalidConfigurationError``
        for the ``SIMPLE_NET_NAME`` entry.
        """
        models_section = YamlConfigurationParser.parse_section(
            self.INVALID_CONFIG_FILE_PATH, self.MODELS_CONFIG_YML_TAG)
        simple_net_section = models_section[self.SIMPLE_NET_NAME]

        # Deferred call: the matcher invokes from_dict and inspects what it raises.
        build_simple_net_config = calling(
            ModelConfiguration.from_dict).with_args(self.SIMPLE_NET_NAME,
                                                    simple_net_section)

        assert_that(build_simple_net_config, raises(InvalidConfigurationError))
Exemplo n.º 2
0
    def setUp(self) -> None:
        """Create the fixtures shared by the tests of this case.

        Builds one ``ModelConfiguration`` per network entry of the valid YAML
        fixture (``SIMPLE_NET_NAME`` first, then ``SIMPLE_NET_NAME_2``), a
        default ``RunConfiguration`` and a small ``Conv2d`` module.
        """
        models_section = YamlConfigurationParser.parse_section(
            self.VALID_CONFIG_FILE_PATH, self.MODELS_CONFIG_YML_TAG)

        network_names = (self.SIMPLE_NET_NAME, self.SIMPLE_NET_NAME_2)
        self._model_trainer_configs = [
            ModelConfiguration.from_dict(network_name,
                                         models_section[network_name])
            for network_name in network_names
        ]
        self._run_config = RunConfiguration()
        # in_channels=1, out_channels=32, kernel_size=(3, 3).
        self._model = nn.Conv2d(1, 32, (3, 3))
Exemplo n.º 3
0
 def from_yml(cls, yml_file, yml_tag="visdom"):
     """Build a ``VisdomConfiguration`` from one section of a YAML file.

     Args:
         yml_file: Path of the YAML configuration file to read.
         yml_tag: Name of the section to extract; ``"visdom"`` by default.

     Returns:
         The result of ``VisdomConfiguration.from_dict`` applied to the
         parsed section.
     """
     # NOTE(review): ``cls`` is unused, so subclasses still receive a plain
     # VisdomConfiguration; consider ``cls.from_dict`` if that is unintended.
     return VisdomConfiguration.from_dict(
         YamlConfigurationParser.parse_section(yml_file, yml_tag))
Exemplo n.º 4
0
# Seed NumPy's and Python's global RNGs with a fixed value.
# NOTE(review): torch's own RNG is not seeded here.
np.random.seed(43)
random.seed(43)

if __name__ == '__main__':
    # Basic settings
    logging.basicConfig(level=logging.INFO)
    # Give torch every available core for intra-op and inter-op CPU parallelism.
    num_cpus = multiprocessing.cpu_count()
    torch.set_num_threads(num_cpus)
    torch.set_num_interop_threads(num_cpus)
    args = ArgsParserFactory.create_parser(
        ArgsParserType.MODEL_TRAINING).parse_args()

    # Create configurations.
    run_config = RunConfiguration(use_amp=args.use_amp,
                                  local_rank=args.local_rank,
                                  amp_opt_level=args.amp_opt_level)
    model_trainer_configs, training_config = YamlConfigurationParser.parse(
        args.config_file)
    raw_dataset_configs = YamlConfigurationParser.parse_section(
        args.config_file, "dataset")
    # Wrap every raw "dataset" sub-section, keeping the same keys.
    dataset_configs = {
        name: DatasetConfiguration(section)
        for name, section in raw_dataset_configs.items()
    }
    data_augmentation_config = YamlConfigurationParser.parse_section(
        args.config_file, "data_augmentation")
    # HTML renderings: the training config, then one list per dataset and
    # model trainer configuration.
    config_html = [
        training_config.to_html(),
        [config.to_html() for config in dataset_configs.values()],
        [config.to_html() for config in model_trainer_configs],
    ]

    # Prepare the data.
Exemplo n.º 5
0
        parser.add_argument("--local_rank",
                            dest="local_rank",
                            default=0,
                            type=int,
                            help="The local_rank of the GPU.")
        return parser


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    CONFIG_FILE_PATH = "config.yml"
    args = ArgsParserFactory.create_parser().parse_args()
    # Pass by keyword so each value reaches the intended parameter regardless
    # of RunConfiguration's positional order.
    run_config = RunConfiguration(use_amp=args.use_amp,
                                  amp_opt_level=args.amp_opt_level,
                                  local_rank=args.local_rank)

    model_trainer_config, training_config = YamlConfigurationParser.parse(
        CONFIG_FILE_PATH)

    # Initialize the dataset. This is the only part the user must define manually.
    # Both splits share one preprocessing pipeline: convert to a tensor, then
    # normalize with mean 0.1307 and std 0.3081 (the commonly used MNIST values).
    mnist_root = './files/'
    mnist_transform = Compose([ToTensor(), Normalize((0.1307, ), (0.3081, ))])
    train_dataset = torchvision.datasets.MNIST(mnist_root,
                                               train=True,
                                               download=True,
                                               transform=mnist_transform)
    test_dataset = torchvision.datasets.MNIST(mnist_root,
                                              train=False,
                                              download=True,
                                              transform=mnist_transform)
Exemplo n.º 6
0
    def test_should_parse_valid_model_trainer_config(self):
        """The valid fixture parses to the expected dict and typed configuration.

        Both network entries of the models section are expected to share the
        same layout, so the expected mapping is produced from one template.
        The typed assertions then cover the ``SIMPLE_NET_NAME`` entry.
        """

        def expected_network_section():
            # Built fresh on every call so the two entries are distinct objects.
            return {
                'type': self.SIMPLE_NET_TYPE,
                'optimizer': {
                    'type': self.SIMPLE_NET_OPTIMIZER_TYPE,
                    'params': self.SIMPLE_NET_OPTIMIZER_PARAMS
                },
                'scheduler': {
                    'type': self.SIMPLE_NET_SCHEDULER_TYPE,
                    'params': self.SIMPLE_NET_SCHEDULER_PARAMS
                },
                'criterion': {
                    "cycle": {
                        'type': self.SIMPLE_NET_CRITERION_TYPE
                    },
                    "gan": {
                        'type': self.SIMPLE_NET_CRITERION_TYPE_2
                    }
                },
                'metrics': {
                    'Dice': {
                        'type': self.SIMPLE_NET_METRIC_TYPE_1,
                        'params': self.SIMPLE_NET_METRIC_PARAMS_1
                    },
                    'Accuracy': {
                        'type': self.SIMPLE_NET_METRIC_TYPE_2
                    }
                },
                'gradients': self.SIMPLE_NET_GRADIENT_CLIPPING
            }

        expected_config_dict = {
            self.SIMPLE_NET_NAME: expected_network_section(),
            self.SIMPLE_NET_NAME_2: expected_network_section(),
        }
        config_dict = YamlConfigurationParser.parse_section(
            self.VALID_CONFIG_FILE_PATH, self.MODELS_CONFIG_YML_TAG)
        model_trainer_config = ModelConfiguration.from_dict(
            self.SIMPLE_NET_NAME, config_dict[self.SIMPLE_NET_NAME])

        # The raw YAML section must equal the expected mapping.
        assert_that(config_dict, equal_to(expected_config_dict))

        # Optimizer: type name and parameters.
        optimizer_config = model_trainer_config.optimizer_config
        assert_that(optimizer_config.type,
                    equal_to(self.SIMPLE_NET_OPTIMIZER_TYPE))
        assert_that(optimizer_config.params,
                    equal_to(self.SIMPLE_NET_OPTIMIZER_PARAMS))

        # Scheduler: type name and parameters.
        scheduler_config = model_trainer_config.scheduler_config
        assert_that(scheduler_config.type,
                    equal_to(self.SIMPLE_NET_SCHEDULER_TYPE))
        assert_that(scheduler_config.params,
                    equal_to(self.SIMPLE_NET_SCHEDULER_PARAMS))

        # Criterions are checked by position: index 0 carries the "cycle"
        # type and index 1 the "gan" type.
        criterions_configs = model_trainer_config.criterions_configs
        assert_that(criterions_configs[0].type,
                    equal_to(self.SIMPLE_NET_CRITERION_TYPE))
        assert_that(criterions_configs[1].type,
                    equal_to(self.SIMPLE_NET_CRITERION_TYPE_2))

        # Metrics: index 0 carries the "Dice" type and parameters.
        metrics_configs = model_trainer_config.metrics_configs
        assert_that(metrics_configs[0].type,
                    equal_to(self.SIMPLE_NET_METRIC_TYPE_1))
        assert_that(metrics_configs[0].params,
                    equal_to(self.SIMPLE_NET_METRIC_PARAMS_1))