def test_get_model_spec(self):
    """get_model_spec resolves every requested component of the test model.

    Loads ``test_module.custom_model`` from the test model zoo and checks
    that each of the six returned components is not ``None``.
    """
    (
        model,
        dataset_fn,
        loss,
        optimizer,
        eval_metrics_fn,
        prediction_outputs_processor,
    ) = get_model_spec(
        model_zoo=_model_zoo_path,
        model_def="test_module.custom_model",
        dataset_fn="dataset_fn",
        loss="loss",
        optimizer="optimizer",
        eval_metrics_fn="eval_metrics_fn",
        model_params="",
        prediction_outputs_processor="PredictionOutputsProcessor",
    )
    components = {
        "model": model,
        "dataset_fn": dataset_fn,
        "loss": loss,
        "optimizer": optimizer,
        "eval_metrics_fn": eval_metrics_fn,
        "prediction_outputs_processor": prediction_outputs_processor,
    }
    # subTest reports every missing component by name instead of stopping
    # at the first failure; assertIsNotNone gives a clearer failure message
    # than assertTrue(x is not None).
    for name, component in components.items():
        with self.subTest(component=name):
            self.assertIsNotNone(component)
def __init__(
    self,
    worker_id,
    job_type,
    minibatch_size,
    model_zoo,
    dataset_fn="dataset_fn",
    loss="loss",
    optimizer="optimizer",
    eval_metrics_fn="eval_metrics_fn",
    channel=None,
    embedding_service_endpoint=None,
    model_def=None,
    model_params="",
    prediction_outputs_processor="PredictionOutputsProcessor",
    max_minibatch_retry_num=DEFAULT_MAX_MINIBATCH_RETRY_NUM,
    get_model_steps=1,
):
    """
    Arguments:
        worker_id: The worker ID.
        job_type: The job type. When it equals
            `JobType.TRAINING_WITH_EVALUATION`, the task data service is
            created with evaluation enabled.
        minibatch_size: The size of the minibatch used for each iteration.
        model_zoo: The directory that contains user-defined model files or
            a specific model file.
        dataset_fn: The name of the dataset function defined in the model
            file.
        loss: The name of the loss function defined in the model file.
        optimizer: The name of the optimizer defined in the model file.
        eval_metrics_fn: The name of the evaluation metrics function defined
            in the model file.
        channel: The gRPC channel to the master. If None, no master stub is
            created.
        embedding_service_endpoint: The endpoint to the embedding service.
        model_def: The import path to the model definition function/class in
            the model zoo, e.g. "cifar10_subclass.CustomModel".
        model_params: The dictionary of model parameters in a string that
            will be used to instantiate the model, e.g. "param1=1,param2=2".
        prediction_outputs_processor: The name of the prediction output
            processor class defined in the model file.
        max_minibatch_retry_num: The maximum number of a minibatch retry as
            its results (e.g. gradients) are not accepted by master.
        get_model_steps: Worker will perform `get_model` from the parameter
            server every this many steps.
    """
    self._worker_id = worker_id
    self._job_type = job_type
    self._minibatch_size = minibatch_size
    (
        self._model,
        self._dataset_fn,
        self._loss,
        self._opt_fn,
        self._eval_metrics_fn,
        self._prediction_outputs_processor,
    ) = get_model_spec(
        model_zoo=model_zoo,
        model_def=model_def,
        dataset_fn=dataset_fn,
        loss=loss,
        optimizer=optimizer,
        eval_metrics_fn=eval_metrics_fn,
        model_params=model_params,
        prediction_outputs_processor=prediction_outputs_processor,
    )
    self._init_embedding_layer()
    self._var_created = self._model.built
    if channel is None:
        self._stub = None
    else:
        self._stub = elasticdl_pb2_grpc.MasterStub(channel)
    self._embedding_service_endpoint = embedding_service_endpoint
    self._max_minibatch_retry_num = max_minibatch_retry_num
    self._model_version = -1
    self._task_data_service = TaskDataService(
        self, self._job_type == JobType.TRAINING_WITH_EVALUATION
    )
    self._get_model_steps = get_model_steps
def __init__(
    self,
    worker_id,
    job_type,
    minibatch_size,
    model_zoo,
    dataset_fn="dataset_fn",
    loss="loss",
    optimizer="optimizer",
    eval_metrics_fn="eval_metrics_fn",
    channel=None,
    embedding_service_endpoint=None,
    model_def=None,
    model_params="",
    prediction_outputs_processor="PredictionOutputsProcessor",
    max_minibatch_retry_num=DEFAULT_MAX_MINIBATCH_RETRY_NUM,
    get_model_steps=1,
):
    """Set up a worker and load its model specification from the model zoo.

    Arguments:
        worker_id: Identifier of this worker.
        job_type: Kind of job being run; compared with
            `JobType.TRAINING_WITH_EVALUATION` when building the task data
            service.
        minibatch_size: Number of records processed per iteration.
        model_zoo: Directory holding user-defined model files, or the path
            of a single model file.
        dataset_fn: Name of the dataset function in the model file.
        loss: Name of the loss function in the model file.
        optimizer: Name of the optimizer in the model file.
        eval_metrics_fn: Name of the evaluation-metrics function in the
            model file.
        channel: gRPC channel to the master service; with None, no master
            stub is created.
        embedding_service_endpoint: Address of the embedding service.
        model_def: Import path of the model definition function/class inside
            the model zoo, e.g. "cifar10_subclass.CustomModel".
        model_params: Model constructor parameters encoded as a string,
            e.g. "param1=1,param2=2".
        prediction_outputs_processor: Name of the prediction output
            processor class in the model file.
        max_minibatch_retry_num: Upper bound on how often a minibatch is
            retried when its results (e.g. gradients) are rejected by the
            master.
        get_model_steps: The worker calls `get_model` on the parameter
            server once every this many steps.
    """
    self._worker_id, self._job_type, self._minibatch_size = (
        worker_id,
        job_type,
        minibatch_size,
    )

    spec_kwargs = dict(
        model_zoo=model_zoo,
        model_def=model_def,
        dataset_fn=dataset_fn,
        loss=loss,
        optimizer=optimizer,
        eval_metrics_fn=eval_metrics_fn,
        model_params=model_params,
        prediction_outputs_processor=prediction_outputs_processor,
    )
    model_spec = get_model_spec(**spec_kwargs)
    (
        self._model,
        self._dataset_fn,
        self._loss,
        self._opt_fn,
        self._eval_metrics_fn,
        self._prediction_outputs_processor,
    ) = model_spec

    # NOTE: the statement order below is kept as-is, since `self` is passed
    # to `_init_embedding_layer` and `TaskDataService`, which may read
    # attributes set earlier.
    self._init_embedding_layer()
    self._var_created = self._model.built

    self._stub = (
        None if channel is None else elasticdl_pb2_grpc.MasterStub(channel)
    )
    self._embedding_service_endpoint = embedding_service_endpoint
    self._max_minibatch_retry_num = max_minibatch_retry_num
    self._model_version = -1

    with_evaluation = self._job_type == JobType.TRAINING_WITH_EVALUATION
    self._task_data_service = TaskDataService(self, with_evaluation)
    self._get_model_steps = get_model_steps