def test_get_metric_fn():
    """Known metric keys resolve without error; an unknown key raises ValueError."""
    valid_keys = ('f1_score', 'accuracy', 'roc_auc', 'precision', 'recall')
    for name in valid_keys:
        get_metric_fn(name)

    for bad_name in ('NotExistMetric',):
        with pytest.raises(ValueError, match='Invalid metric_key:'):
            get_metric_fn(bad_name)
# --- Example #2 (scraped code sample; score: 0) ---
def test_get_metric_fn():
    """Every supported metric key is accepted; unsupported keys raise ValueError."""
    supported = ["f1_score", "accuracy", "roc_auc", "precision", "recall"]
    unsupported = ["NotExistMetric"]

    for metric_name in supported:
        get_metric_fn(metric_name)

    for metric_name in unsupported:
        with pytest.raises(ValueError, match="Invalid metric_key:"):
            get_metric_fn(metric_name)
    def __init__(self, params=None, json_path=None):
        """Resolve configuration, then build the DDP model, optimizer, loss, and metric.

        Args:
            params: Optional parameter overrides layered over the module-level
                ``default_parameters`` (presumably a dict merge — semantics
                depend on ``get_from_dicts``; confirm against its definition).
            json_path: Optional path to a JSON config whose values are applied
                on top of the merged params via ``get_from_json``.
        """
        super().__init__()
        # Caller params over defaults first, then JSON-file overrides on top.
        self.params = get_from_dicts(params, default_parameters)
        self.params = get_from_json(json_path, self.params)
        # Validate the merged config before anything is built from it.
        self._sanity_check()
        logging.info("Model parameters : %s", self.params)
        self.input_type = self.params['input']['type']

        # NOTE(review): the process group must be initialized before wrapping
        # the model in DistributedDataParallel — keep this ordering.
        self.init_process_group()
        self.model = DistributedDataParallel(
            self.model_fn(self.params).to(get_device(self.params)))
        # Optimizer/loss factories are looked up by configured name, then
        # instantiated with the kwargs from the config.
        self.optimizer = get_optimizer(key=self.params['optimizer']['name'])(
            params=self.model.parameters(),
            **self.params['optimizer']['kwargs'])
        self.loss = get_loss_fn(key=self.params['loss']['name'])(
            **self.params['loss']['kwargs'])
        # Metric is a function (not instantiated), keyed by output config.
        self.metric = get_metric_fn(key=self.params['output']['metric'])
# --- Example #4 (scraped code sample; score: 0) ---
    def __init__(self, params=None, json_path=None):
        """Resolve configuration, prepare the output directory, and build the
        DDP model, optimizer, loss, and metric.

        Args:
            params: Optional parameter overrides layered over the module-level
                ``default_parameters`` (presumably a dict merge — semantics
                depend on ``get_from_dicts``; confirm against its definition).
            json_path: Optional path to a JSON config whose values are applied
                on top of the merged params via ``get_from_json``.
        """
        super().__init__()
        # Caller params over defaults first, then JSON-file overrides on top.
        self.params = get_from_dicts(params, default_parameters)
        self.params = get_from_json(json_path, self.params)
        # Validate the merged config before anything is built from it.
        self._sanity_check()
        # Ensure the checkpoint/output directory exists (idempotent thanks to
        # exist_ok=True); ~ and relative segments are expanded first.
        Path(self.params["output"]["save_model_dir"]).expanduser().resolve().mkdir(
            parents=True, exist_ok=True
        )
        logging.info("Model parameters : %s", self.params)
        self.input_type = self.params["input"]["type"]

        # NOTE(review): the process group must be initialized before wrapping
        # the model in DistributedDataParallel — keep this ordering.
        self.init_process_group()
        self.model = DistributedDataParallel(self.model_fn(self.params).to(get_device(self.params)))
        # Optimizer/loss factories are looked up by configured name, then
        # instantiated with the kwargs from the config.
        self.optimizer = get_optimizer(key=self.params["optimizer"]["name"])(
            params=self.model.parameters(), **self.params["optimizer"]["kwargs"]
        )
        self.loss = get_loss_fn(key=self.params["loss"]["name"])(**self.params["loss"]["kwargs"])
        # Metric is a function (not instantiated), keyed by output config.
        self.metric = get_metric_fn(key=self.params["output"]["metric"])