def test_get_optimizer():
    """Check that ``get_optimizer`` resolves known keys and rejects unknown ones.

    Every supported key must resolve to a callable, because callers instantiate
    the returned object with model parameters and optimizer kwargs. An
    unsupported key must raise ``ValueError`` whose message matches
    ``"Invalid optimizer_key:"``.
    """
    optimizer_keys = ["adam", "adagrad", "sgd"]
    invalid_optimizer_keys = ["adddam"]  # misspelled key that must be rejected

    for key in optimizer_keys:
        optimizer_factory = get_optimizer(key)
        # Previously the result was discarded; make sure it is actually usable.
        assert callable(optimizer_factory), f"get_optimizer({key!r}) returned a non-callable"

    for key_invalid in invalid_optimizer_keys:
        # ``match`` is applied with re.search against the exception message.
        with pytest.raises(ValueError, match="Invalid optimizer_key:"):
            get_optimizer(key_invalid)
def test_get_optimizer():
    """Check that ``get_optimizer`` resolves known keys and rejects unknown ones.

    Every supported key must resolve to a callable, because callers instantiate
    the returned object with model parameters and optimizer kwargs. An
    unsupported key must raise ``ValueError`` whose message matches
    ``'Invalid optimizer_key:'``.
    """
    optimizer_keys = ['adam', 'adagrad', 'sgd']
    invalid_optimizer_keys = ['adddam']  # misspelled key that must be rejected

    for key in optimizer_keys:
        optimizer_factory = get_optimizer(key)
        # Previously the result was discarded; make sure it is actually usable.
        assert callable(optimizer_factory), f'get_optimizer({key!r}) returned a non-callable'

    for key_invalid in invalid_optimizer_keys:
        # ``match`` is applied with re.search against the exception message.
        with pytest.raises(ValueError, match='Invalid optimizer_key:'):
            get_optimizer(key_invalid)
def __init__(self, params=None, json_path=None):
    """Resolve configuration and build the distributed model and training components.

    The configuration is produced by ``get_from_dicts(params, default_parameters)``
    and then passed through ``get_from_json`` together with ``json_path``. It is
    validated by ``_sanity_check`` before anything is constructed.

    Args:
        params: optional parameter mapping handed to ``get_from_dicts``.
        json_path: optional JSON path handed to ``get_from_json``.
    """
    super().__init__()

    # Configuration: dict-level resolution first, then the JSON step.
    resolved = get_from_dicts(params, default_parameters)
    self.params = get_from_json(json_path, resolved)
    self._sanity_check()
    logging.info("Model parameters : %s", self.params)

    self.input_type = self.params['input']['type']

    # The process group is initialised before the DistributedDataParallel wrapper is created.
    self.init_process_group()

    # Build the network, move it to the configured device, then wrap it.
    network = self.model_fn(self.params)
    device = get_device(self.params)
    self.model = DistributedDataParallel(network.to(device))

    # Optimizer: look up the factory by name, then instantiate it over the model parameters.
    optimizer_cfg = self.params['optimizer']
    optimizer_factory = get_optimizer(key=optimizer_cfg['name'])
    self.optimizer = optimizer_factory(
        params=self.model.parameters(), **optimizer_cfg['kwargs'])

    # Loss: look up the factory by name and instantiate it with its kwargs.
    loss_cfg = self.params['loss']
    loss_factory = get_loss_fn(key=loss_cfg['name'])
    self.loss = loss_factory(**loss_cfg['kwargs'])

    # Metric: the looked-up object is stored as-is; it is not instantiated here.
    self.metric = get_metric_fn(key=self.params['output']['metric'])
def __init__(self, params=None, json_path=None):
    """Resolve configuration, prepare the output directory, and build training components.

    The configuration is produced by ``get_from_dicts(params, default_parameters)``
    and then passed through ``get_from_json`` together with ``json_path``. It is
    validated by ``_sanity_check``. The directory named by
    ``params["output"]["save_model_dir"]`` is then created (user-expanded and
    resolved, parents included, no error if it already exists) before the model
    is built.

    Args:
        params: optional parameter mapping handed to ``get_from_dicts``.
        json_path: optional JSON path handed to ``get_from_json``.
    """
    super().__init__()

    # Configuration: dict-level resolution first, then the JSON step.
    resolved = get_from_dicts(params, default_parameters)
    self.params = get_from_json(json_path, resolved)
    self._sanity_check()

    # Ensure the model save directory exists.
    save_dir = Path(self.params["output"]["save_model_dir"]).expanduser().resolve()
    save_dir.mkdir(parents=True, exist_ok=True)

    logging.info("Model parameters : %s", self.params)
    self.input_type = self.params["input"]["type"]

    # The process group is initialised before the DistributedDataParallel wrapper is created.
    self.init_process_group()

    # Build the network, move it to the configured device, then wrap it.
    network = self.model_fn(self.params)
    device = get_device(self.params)
    self.model = DistributedDataParallel(network.to(device))

    # Optimizer: look up the factory by name, then instantiate it over the model parameters.
    optimizer_cfg = self.params["optimizer"]
    optimizer_factory = get_optimizer(key=optimizer_cfg["name"])
    self.optimizer = optimizer_factory(
        params=self.model.parameters(), **optimizer_cfg["kwargs"]
    )

    # Loss: look up the factory by name and instantiate it with its kwargs.
    loss_cfg = self.params["loss"]
    loss_factory = get_loss_fn(key=loss_cfg["name"])
    self.loss = loss_factory(**loss_cfg["kwargs"])

    # Metric: the looked-up object is stored as-is; it is not instantiated here.
    self.metric = get_metric_fn(key=self.params["output"]["metric"])