def test_get_optimizer():
    """Check that ``get_optimizer`` resolves known keys and rejects unknown ones.

    Every valid key must yield something callable, because callers invoke the
    result as a constructor, e.g. ``get_optimizer(key=...)(params=..., **kwargs)``.
    An unknown key must raise ``ValueError`` whose message matches
    ``"Invalid optimizer_key:"``.

    NOTE(review): a second function with this same name is defined later in
    this module and rebinds the name, so pytest collects only that later one.
    """
    optimizer_keys = ["adam", "adagrad", "sgd"]
    invalid_optimizer_keys = ["adddam"]

    for key in optimizer_keys:
        optimizer_factory = get_optimizer(key)
        # Merely calling get_optimizer would let a None / non-callable
        # return value pass unnoticed; pin the contract callers rely on.
        assert callable(optimizer_factory), (
            f"get_optimizer({key!r}) returned a non-callable: {optimizer_factory!r}"
        )

    for key_invalid in invalid_optimizer_keys:
        with pytest.raises(ValueError, match="Invalid optimizer_key:"):
            get_optimizer(key_invalid)
def test_get_optimizer():
    """Check that ``get_optimizer`` accepts supported keys and rejects bad ones.

    For each supported key the returned object must be callable, since callers
    use it as ``get_optimizer(key=...)(params=..., **kwargs)``. A misspelt key
    must raise ``ValueError`` with a message matching ``'Invalid optimizer_key:'``.

    NOTE(review): an earlier function with this same name exists in this module;
    this definition rebinds the name and is the one pytest collects.
    """
    optimizer_keys = ['adam', 'adagrad', 'sgd']
    invalid_optimizer_keys = ['adddam']

    for key in optimizer_keys:
        optimizer_factory = get_optimizer(key)
        # Assert on the result rather than only calling it, so a None /
        # non-callable return value is caught.
        assert callable(optimizer_factory), (
            f'get_optimizer({key!r}) returned a non-callable: {optimizer_factory!r}'
        )

    for key_invalid in invalid_optimizer_keys:
        with pytest.raises(ValueError, match='Invalid optimizer_key:'):
            get_optimizer(key_invalid)
    def __init__(self, params=None, json_path=None):
        """Assemble the configuration and build every training component.

        The configuration is resolved in two passes. ``get_from_dicts``
        combines ``params`` with ``default_parameters``, then ``get_from_json``
        combines that result with ``json_path``. ``_sanity_check`` validates it
        before anything is constructed. Which source wins on a conflict is
        decided inside those helpers, not here.

        Args:
            params: optional parameter mapping passed to ``get_from_dicts``.
            json_path: optional path passed to ``get_from_json``.

        Sets ``params``, ``input_type``, ``model`` (a
        ``DistributedDataParallel`` wrapper), ``optimizer``, ``loss`` and
        ``metric``.
        """
        super().__init__()

        # Resolve, validate and log the configuration.
        self.params = get_from_dicts(params, default_parameters)
        self.params = get_from_json(json_path, self.params)
        self._sanity_check()
        logging.info("Model parameters : %s", self.params)
        self.input_type = self.params['input']['type']

        # NOTE(review): init_process_group presumably initialises
        # torch.distributed, which DistributedDataParallel requires -- keep it
        # ahead of the wrap below.
        self.init_process_group()
        network = self.model_fn(self.params)
        device = get_device(self.params)
        self.model = DistributedDataParallel(network.to(device))

        # Optimizer: resolve the factory by name, then build it over the
        # wrapped model's parameters with the configured keyword arguments.
        optimizer_cfg = self.params['optimizer']
        make_optimizer = get_optimizer(key=optimizer_cfg['name'])
        self.optimizer = make_optimizer(
            params=self.model.parameters(), **optimizer_cfg['kwargs'])

        # Loss: same pattern -- factory looked up by name, called with kwargs.
        loss_cfg = self.params['loss']
        make_loss = get_loss_fn(key=loss_cfg['name'])
        self.loss = make_loss(**loss_cfg['kwargs'])

        self.metric = get_metric_fn(key=self.params['output']['metric'])
Example #4
0
    def __init__(self, params=None, json_path=None):
        """Assemble the configuration, prepare the output dir, build components.

        The configuration is resolved in two passes. ``get_from_dicts``
        combines ``params`` with ``default_parameters``, then ``get_from_json``
        combines that result with ``json_path``. ``_sanity_check`` validates it.
        Precedence between sources is decided inside those helpers.

        The directory named by ``params["output"]["save_model_dir"]`` is
        created after ``~`` expansion and resolution, including missing
        parents. An already-existing directory is not an error.

        Args:
            params: optional parameter mapping passed to ``get_from_dicts``.
            json_path: optional path passed to ``get_from_json``.

        Sets ``params``, ``input_type``, ``model`` (a
        ``DistributedDataParallel`` wrapper), ``optimizer``, ``loss`` and
        ``metric``.
        """
        super().__init__()

        # Resolve and validate the configuration.
        self.params = get_from_dicts(params, default_parameters)
        self.params = get_from_json(json_path, self.params)
        self._sanity_check()

        # Ensure the model output directory exists.
        save_dir = Path(self.params["output"]["save_model_dir"])
        save_dir.expanduser().resolve().mkdir(parents=True, exist_ok=True)

        logging.info("Model parameters : %s", self.params)
        self.input_type = self.params["input"]["type"]

        # NOTE(review): init_process_group presumably initialises
        # torch.distributed, which DistributedDataParallel requires -- keep it
        # ahead of the wrap below.
        self.init_process_group()
        network = self.model_fn(self.params)
        device = get_device(self.params)
        self.model = DistributedDataParallel(network.to(device))

        # Optimizer: resolve the factory by name, then build it over the
        # wrapped model's parameters with the configured keyword arguments.
        optimizer_cfg = self.params["optimizer"]
        make_optimizer = get_optimizer(key=optimizer_cfg["name"])
        self.optimizer = make_optimizer(
            params=self.model.parameters(), **optimizer_cfg["kwargs"]
        )

        # Loss: same pattern -- factory looked up by name, called with kwargs.
        loss_cfg = self.params["loss"]
        make_loss = get_loss_fn(key=loss_cfg["name"])
        self.loss = make_loss(**loss_cfg["kwargs"])

        self.metric = get_metric_fn(key=self.params["output"]["metric"])