Esempio n. 1
0
 def __init__(self, model_params=None, json_path=None):
     """Resolve model configuration and build the underlying TF Estimator.

     Args:
         model_params: optional dict of overrides merged over
             ``default_parameters``.
         json_path: optional path to a JSON file merged in afterwards
             (presumably JSON values take final precedence — confirm in
             ``get_from_json``).
     """
     super().__init__()
     # Layer the configuration sources: explicit dict over library
     # defaults, then the JSON file on top.
     merged = get_from_dicts(model_params, default_parameters)
     self.model_params = get_from_json(json_path, merged)
     self._sanity_checks()
     logging.info("Model parameters : %s", self.model_params)
     self.input_type = self.model_params['input']['type']
     self.model_dir = self.model_params['output']['save_model_dir']
     self.config = get_tf_config(self.model_params)
     # The Estimator owns checkpointing/session handling for model_dir.
     self.model = tf.estimator.Estimator(
         model_fn=self.model_fn,
         model_dir=self.model_dir,
         params=self.model_params,
         config=self.config,
     )
Esempio n. 2
0
    def __init__(self, params=None, json_path=None):
        """Resolve configuration, then wire up the distributed model,
        optimizer, loss function, and metric.

        Args:
            params: optional dict of overrides merged over ``default_parameters``.
            json_path: optional path to a JSON config merged in last.
        """
        super().__init__()
        # Config precedence: JSON file on top of explicit dict on top of defaults.
        self.params = get_from_json(json_path, get_from_dicts(params, default_parameters))
        self._sanity_check()
        logging.info("Model parameters : %s", self.params)
        self.input_type = self.params['input']['type']

        # Join the process group before constructing DDP — NOTE(review):
        # assumes init_process_group sets up torch.distributed; confirm.
        self.init_process_group()
        network = self.model_fn(self.params).to(get_device(self.params))
        self.model = DistributedDataParallel(network)

        opt_cfg = self.params['optimizer']
        self.optimizer = get_optimizer(key=opt_cfg['name'])(
            params=self.model.parameters(),
            **opt_cfg['kwargs'])
        loss_cfg = self.params['loss']
        self.loss = get_loss_fn(key=loss_cfg['name'])(**loss_cfg['kwargs'])
        self.metric = get_metric_fn(key=self.params['output']['metric'])
Esempio n. 3
0
    def __init__(self, params=None, json_path=None):
        """Resolve configuration, ensure the save directory exists, then
        wire up the distributed model, optimizer, loss, and metric.

        Args:
            params: optional dict of overrides merged over ``default_parameters``.
            json_path: optional path to a JSON config merged in last.
        """
        super().__init__()
        # Config precedence: JSON file on top of explicit dict on top of defaults.
        self.params = get_from_json(json_path, get_from_dicts(params, default_parameters))
        self._sanity_check()
        # Create the checkpoint directory eagerly so later saves cannot
        # fail on a missing path.
        save_dir = Path(self.params["output"]["save_model_dir"]).expanduser().resolve()
        save_dir.mkdir(parents=True, exist_ok=True)
        logging.info("Model parameters : %s", self.params)
        self.input_type = self.params["input"]["type"]

        # Join the process group before wrapping the network in DDP.
        self.init_process_group()
        network = self.model_fn(self.params).to(get_device(self.params))
        self.model = DistributedDataParallel(network)

        opt_cfg = self.params["optimizer"]
        self.optimizer = get_optimizer(key=opt_cfg["name"])(
            params=self.model.parameters(),
            **opt_cfg["kwargs"],
        )
        loss_cfg = self.params["loss"]
        self.loss = get_loss_fn(key=loss_cfg["name"])(**loss_cfg["kwargs"])
        self.metric = get_metric_fn(key=self.params["output"]["metric"])
Esempio n. 4
0
def test_merge_dicts():
    """User-supplied keys override defaults; missing keys fall back to them."""
    overrides = {"learning_rate": 0.05}
    defaults = {"learning_rate": 0.08, "embedding_size": 256}
    merged = get_from_dicts(overrides, defaults)
    # The override wins for the shared key; the default fills the gap.
    assert merged["learning_rate"] == 0.05
    assert merged["embedding_size"] == 256