Example 1
    def test_get_dict_from_params_str(self):
        self.assertEqual(
            get_dict_from_params_str('ls=["a", "b"]'), {"ls": ["a", "b"]}
        )
        self.assertEqual(
            get_dict_from_params_str('ls=["a", "b"]; d={"a": 3}'),
            {"ls": ["a", "b"], "d": {"a": 3}},
        )
        self.assertEqual(get_dict_from_params_str(""), None)
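These tests (together with the richer ones in Example 5) pin down the parser's contract: semicolon-separated key=value pairs, values evaluated as Python literals where possible, and None for an empty string. A minimal sketch consistent with that contract, for illustration only (this is not the ElasticDL source):

import ast


def get_dict_from_params_str(params_str):
    """Illustrative sketch only, not the ElasticDL implementation.

    Parses 'k1=v1; k2=v2' into a dict. Values that are valid Python
    literals (lists, dicts, numbers) are evaluated; anything else is
    kept as a plain string. An empty input yields None.
    """
    if not params_str:
        return None
    result = {}
    for pair in params_str.split(";"):
        # Split on the first '=' only, so values such as
        # 'dt=20190011' (Example 5) survive intact.
        key, _, value = pair.strip().partition("=")
        try:
            result[key] = ast.literal_eval(value)
        except (ValueError, SyntaxError):
            result[key] = value
    return result

Note that the naive split on ';' would mishandle a literal value that itself contains a semicolon; the tests above do not exercise that case.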
Example 2
    def _maybe_create_shards(data_origin):
        kwargs = get_dict_from_params_str(data_reader_params)
        partition = kwargs.get("partition", None) if kwargs else None
        return (create_data_reader_fn(
            data_origin=data_origin,
            records_per_task=records_per_task,
            partition=partition,
        ).create_shards() if data_origin else {})
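Because the parser returns None for an empty string (see the last assertion in Example 1), this caller guards the .get lookup. A quick illustration of the guard, reusing the sketch from Example 1:

# With an empty params string the parser yields None, so .get()
# must be skipped.
kwargs = get_dict_from_params_str("")
partition = kwargs.get("partition", None) if kwargs else None
assert partition is None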
Example 3
    def _init_task_data_service(self, args):
        self._task_data_service = TaskDataService(
            self._data_shard_service,
            custom_data_reader=self._custom_data_reader,
            data_reader_params=get_dict_from_params_str(
                args.data_reader_params),
            data_origin=args.training_data,
        )
Example 4
    def _init_from_args(self, args):
        """
        Please refer to elastic/python/common/args.py for more
        details about the arguments of a worker.
        """
        self._worker_id = args.worker_id
        self._job_type = args.job_type
        self._minibatch_size = args.minibatch_size
        (
            model_inst,
            self._dataset_fn,
            self._loss,
            self._opt_fn,
            self._eval_metrics_fn,
            self._prediction_outputs_processor,
            self._custom_data_reader,
        ) = get_model_spec(
            model_zoo=args.model_zoo,
            model_def=args.model_def,
            dataset_fn=args.dataset_fn,
            loss=args.loss,
            optimizer=args.optimizer,
            eval_metrics_fn=args.eval_metrics_fn,
            model_params=args.model_params,
            prediction_outputs_processor=args.prediction_outputs_processor,
            custom_data_reader=args.custom_data_reader,
        )

        self._collective_communicator = (
            CollectiveCommunicator() if self._distribution_strategy
            == DistributionStrategy.ALLREDUCE else None)
        self._model_handler = ModelHandler.get_model_handler(
            self._distribution_strategy, checkpoint_dir=args.checkpoint_dir)
        model_inst = self._model_handler.get_model_to_train(model_inst)
        self.set_model(model_inst)

        self._model_version = -1
        self._model_versions_from_ps = [-1 for _ in range(self._ps_num)]
        self._task_data_service = TaskDataService(
            self,
            self._job_type == JobType.TRAINING_WITH_EVALUATION,
            data_reader_params=get_dict_from_params_str(
                args.data_reader_params),
        )
        if self._dataset_fn is None:
            if hasattr(self._task_data_service.data_reader,
                       "default_dataset_fn"):
                self._dataset_fn = (
                    self._task_data_service.data_reader.default_dataset_fn())
            else:
                raise ValueError(
                    "dataset_fn is required if the data_reader used does "
                    "not provide default implementation of dataset_fn")
        self._get_model_steps = args.get_model_steps
        if self._get_model_steps > 1:
            self._opt = self._opt_fn()
        self._non_embed_grads = {}
        self._evaluation_result = {}
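The dataset_fn fallback above (repeated in Example 9) prefers the user-supplied function and otherwise asks the data reader for a default. A stripped-down illustration of that hasattr-based fallback, using a hypothetical helper name:

def resolve_dataset_fn(dataset_fn, data_reader):
    # Hypothetical helper mirroring the excerpt's logic: prefer the
    # user-supplied dataset_fn, else fall back to the reader's default.
    if dataset_fn is not None:
        return dataset_fn
    if hasattr(data_reader, "default_dataset_fn"):
        return data_reader.default_dataset_fn()
    raise ValueError(
        "dataset_fn is required if the data_reader used does "
        "not provide a default implementation of dataset_fn")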
Example 5
    def test_get_dict_from_params_str(self):
        self.assertEqual(get_dict_from_params_str('ls=["a", "b"]'),
                         {"ls": ["a", "b"]})
        self.assertEqual(
            get_dict_from_params_str('ls=["a", "b"]; d={"a": 3}'),
            {
                "ls": ["a", "b"],
                "d": {
                    "a": 3
                }
            },
        )
        self.assertEqual(
            get_dict_from_params_str('ls=["a", "b"];partition=dt=20190011'),
            {
                "ls": ["a", "b"],
                "partition": "dt=20190011"
            },
        )
        self.assertEqual(get_dict_from_params_str(""), None)
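The third assertion is the interesting case: 'partition=dt=20190011' is split on the first '=' only, and since 'dt=20190011' is not a valid Python literal it stays a plain string. Assuming the literal-evaluation fallback from the Example 1 sketch:

import ast

# 'dt=20190011' is not a Python literal, so evaluation fails and the
# raw string is kept (assumed fallback, consistent with the test).
try:
    value = ast.literal_eval("dt=20190011")
except (ValueError, SyntaxError):
    value = "dt=20190011"
assert value == "dt=20190011"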
Example 6
    def _init_from_args(self, args):
        """
        Please refer to elastic/python/common/args.py for more
        details about the arguments of a worker.
        """
        self._worker_id = args.worker_id
        self._job_type = args.job_type
        self._minibatch_size = args.minibatch_size
        (
            model_inst,
            self._dataset_fn,
            self._loss,
            self._opt_fn,
            self._eval_metrics_fn,
            self._prediction_outputs_processor,
        ) = get_model_spec(
            model_zoo=args.model_zoo,
            model_def=args.model_def,
            dataset_fn=args.dataset_fn,
            loss=args.loss,
            optimizer=args.optimizer,
            eval_metrics_fn=args.eval_metrics_fn,
            model_params=args.model_params,
            prediction_outputs_processor=args.prediction_outputs_processor,
        )

        self._embedding_service_endpoint = eval(
            args.embedding_service_endpoint)

        self._distribution_strategy = args.distribution_strategy
        self._collective_communicator = (
            CollectiveCommunicator() if self._distribution_strategy
            == DistributionStrategy.ALLREDUCE else None)
        self._model_handler = ModelHandler.get_model_handler(
            self._distribution_strategy, checkpoint_dir=args.checkpoint_dir)
        model_inst = self._model_handler.get_model_to_train(model_inst)
        self.set_model(model_inst)

        self._model_version = -1
        self._task_data_service = TaskDataService(
            self,
            self._job_type == JobType.TRAINING_WITH_EVALUATION,
            data_reader_params=get_dict_from_params_str(
                args.data_reader_params),
        )
        self._get_model_steps = args.get_model_steps
        if self._get_model_steps > 1:
            self._opt = self._opt_fn()
            self._non_embed_grads = None
        self._evaluation_result = {}
Example 7
    def __init__(self, args):
        envs = parse_envs(args.envs)
        self._init_environment(envs)

        (
            self.model_inst,
            self.dataset_fn,
            self.loss_fn,
            self.opt_fn,
            self.eval_metrics_fn,
            self.prediction_outputs_processor,
            self.custom_data_reader,
            self.callback_list,
        ) = get_model_spec(
            model_zoo=args.model_zoo,
            model_def=args.model_def,
            dataset_fn=args.dataset_fn,
            loss=args.loss,
            optimizer=args.optimizer,
            eval_metrics_fn=args.eval_metrics_fn,
            model_params=args.model_params,
            prediction_outputs_processor="",
            custom_data_reader=args.custom_data_reader,
            callbacks=args.callbacks,
        )
        self.opt = self.opt_fn()
        self.epoch = args.num_epochs
        self.evaluation_steps = args.evaluation_steps
        self.batch_size = args.minibatch_size
        self.data_reader_params = get_dict_from_params_str(
            args.data_reader_params
        )
        self.records_per_task = (
            args.minibatch_size * args.num_minibatches_per_task
        )

        create_data_reader_fn = (
            create_data_reader
            if self.custom_data_reader is None
            else self.custom_data_reader
        )
        self.data_reader = create_data_reader_fn(
            data_origin=args.training_data,
            records_per_task=self.records_per_task,
            **self.data_reader_params
        )
        self.training_data = args.training_data
        self.validation_data = args.validation_data
        self.save_model_dir = args.output
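One caveat, grounded in the tests from Examples 1 and 5: get_dict_from_params_str("") returns None, so the **self.data_reader_params unpacking above works only when args.data_reader_params is non-empty. A defensive variant (an assumption, not the project's code) coerces None to an empty dict:

# Hedged defensive variant: fall back to {} so ** unpacking stays
# safe even when the params string is empty.
data_reader_params = get_dict_from_params_str(args.data_reader_params) or {}
data_reader = create_data_reader_fn(
    data_origin=args.training_data,
    records_per_task=records_per_task,
    **data_reader_params,
)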
Example 8
    def _create_instance_manager(self, args):
        instance_manager = None

        container_command = ["/bin/bash"]
        if args.num_workers:
            assert args.worker_image, "Worker image cannot be empty"

            worker_client_command = (
                BashCommandTemplate.SET_PIPEFAIL
                + " python -m elasticdl.python.worker.main"
            )
            worker_args = [
                "--master_addr",
                self.master_addr,
                "--job_type",
                self.job_type,
            ]
            worker_args.extend(
                build_arguments_from_parsed_result(args, filter_args=["envs"])
            )
            worker_args = wrap_python_args_with_string(worker_args)
            worker_args.insert(0, worker_client_command)

            if args.use_go_ps:
                opt_type, opt_args = get_optimizer_info(self.optimizer)
                ps_command = "elasticdl_ps"
                ps_command_args = [
                    "-job_name=" + args.job_name,
                    "-namespace=" + args.namespace,
                    "-master_addr=" + self.master_addr,
                    "-port=2222",
                    "-use_async=" + ("true" if args.use_async else "false"),
                    "-grads_to_wait=" + str(args.grads_to_wait),
                    "-lr_staleness_modulation="
                    + ("true" if args.lr_staleness_modulation else "false"),
                    "-sync_version_tolerance="
                    + str(args.sync_version_tolerance),
                    "-evaluation_steps=" + str(args.evaluation_steps),
                    "-num_ps_pods=" + str(args.num_ps_pods),
                    "-num_workers=" + str(args.num_workers),
                    "-checkpoint_dir=" + str(args.checkpoint_dir),
                    "-checkpoint_steps=" + str(args.checkpoint_steps),
                    "-keep_checkpoint_max=" + str(args.keep_checkpoint_max),
                    "-checkpoint_dir_for_init="
                    + str(args.checkpoint_dir_for_init),
                    "-opt_type=" + opt_type,
                    "-opt_args=" + opt_args,
                ]
                ps_command_args = wrap_go_args_with_string(ps_command_args)
                # Execute `source /root/.bashrc_elasticdl` to add the
                # directory containing `elasticdl_ps` to the PATH
                # environment variable.
                ps_args = [
                    "source",
                    "/root/.bashrc_elasticdl",
                    "&&",
                    ps_command,
                ]
                ps_args.extend(ps_command_args)
            else:
                ps_command = (
                    BashCommandTemplate.SET_PIPEFAIL
                    + " python -m elasticdl.python.ps.main"
                )
                ps_command_args = [
                    "--grads_to_wait",
                    str(args.grads_to_wait),
                    "--lr_staleness_modulation",
                    str(args.lr_staleness_modulation),
                    "--sync_version_tolerance",
                    str(args.sync_version_tolerance),
                    "--use_async",
                    str(args.use_async),
                    "--model_zoo",
                    args.model_zoo,
                    "--model_def",
                    args.model_def,
                    "--job_name",
                    args.job_name,
                    "--port",
                    "2222",
                    "--master_addr",
                    self.master_addr,
                    "--namespace",
                    args.namespace,
                    "--evaluation_steps",
                    str(args.evaluation_steps),
                    "--checkpoint_dir",
                    str(args.checkpoint_dir),
                    "--checkpoint_steps",
                    str(args.checkpoint_steps),
                    "--keep_checkpoint_max",
                    str(args.keep_checkpoint_max),
                    "--num_ps_pods",
                    str(args.num_ps_pods),
                    "--checkpoint_dir_for_init",
                    str(args.checkpoint_dir_for_init),
                    "--num_workers",
                    str(args.num_workers),
                    "--log_level",
                    str(args.log_level),
                    "--minibatch_size",
                    str(args.minibatch_size),
                    "--num_minibatches_per_task",
                    str(args.num_minibatches_per_task),
                ]
                ps_args = wrap_python_args_with_string(ps_command_args)
                ps_args.insert(0, ps_command)

            worker_args = ["-c", " ".join(worker_args)]
            ps_args = ["-c", " ".join(ps_args)]

            env_dict = parse_envs(args.envs)
            env = []
            for key in env_dict:
                env.append(V1EnvVar(name=key, value=env_dict[key]))

            kwargs = get_dict_from_params_str(args.aux_params)
            disable_relaunch = kwargs.get("disable_relaunch", False)
            cluster_spec = self._get_image_cluster_spec(args.cluster_spec)

            instance_manager = InstanceManager(
                self.task_d,
                rendezvous_server=self.rendezvous_server,
                job_name=args.job_name,
                image_name=args.worker_image,
                worker_command=container_command,
                worker_args=worker_args,
                namespace=args.namespace,
                num_workers=args.num_workers,
                worker_resource_request=args.worker_resource_request,
                worker_resource_limit=args.worker_resource_limit,
                worker_pod_priority=args.worker_pod_priority,
                num_ps=args.num_ps_pods,
                ps_command=container_command,
                ps_args=ps_args,
                ps_resource_request=args.ps_resource_request,
                ps_resource_limit=args.ps_resource_limit,
                ps_pod_priority=args.ps_pod_priority,
                volume=args.volume,
                image_pull_policy=args.image_pull_policy,
                restart_policy=args.restart_policy,
                cluster_spec=cluster_spec,
                envs=env,
                disable_relaunch=disable_relaunch,
                log_file_path=args.log_file_path,
            )

        return instance_manager
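The same None caveat applies to kwargs.get("disable_relaunch", False) above: per the tests, an empty args.aux_params yields None, on which .get would raise AttributeError. A guarded variant (illustration only, not the project's code):

# Hedged variant with an explicit None guard on the parsed aux params.
kwargs = get_dict_from_params_str(args.aux_params) or {}
disable_relaunch = kwargs.get("disable_relaunch", False)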
Example 9
    def _init_from_args(self, args):
        """
        Please refer to elastic/python/common/args.py for more
        details about the arguments of a worker.
        """
        self._worker_id = args.worker_id
        self._job_type = args.job_type
        self._minibatch_size = args.minibatch_size
        self._log_loss_steps = args.log_loss_steps
        (
            model_inst,
            self._dataset_fn,
            self._loss,
            self._opt_fn,
            self._eval_metrics_fn,
            self._prediction_outputs_processor,
            self._custom_data_reader,
            self._callbacks_list,
        ) = get_model_spec(
            model_zoo=args.model_zoo,
            model_def=args.model_def,
            dataset_fn=args.dataset_fn,
            loss=args.loss,
            optimizer=args.optimizer,
            eval_metrics_fn=args.eval_metrics_fn,
            model_params=args.model_params,
            prediction_outputs_processor=args.prediction_outputs_processor,
            custom_data_reader=args.custom_data_reader,
            callbacks=args.callbacks,
        )

        self._model_handler = ModelHandler.get_model_handler(
            self._distribution_strategy, checkpoint_dir=args.checkpoint_dir)
        model_inst = self._model_handler.get_model_to_train(model_inst)
        self.set_model(model_inst)

        self._model_version = -1
        self._task_data_service = TaskDataService(
            self._mc,
            self._job_type == JobType.TRAINING_WITH_EVALUATION,
            custom_data_reader=self._custom_data_reader,
            data_reader_params=get_dict_from_params_str(
                args.data_reader_params),
            data_origin=args.training_data,
        )
        if self._dataset_fn is None:
            if hasattr(self._task_data_service.data_reader,
                       "default_dataset_fn"):
                self._dataset_fn = (
                    self._task_data_service.data_reader.default_dataset_fn())
            else:
                raise ValueError(
                    "dataset_fn is required if the data_reader used does "
                    "not provide default implementation of dataset_fn")
        self._get_model_steps = args.get_model_steps
        self._opt = self._opt_fn()
        self._model.optimizer = self._opt
        self._non_embed_grads = {}
        self._evaluation_result = {}

        saved_model_exporter = SavedModelExporter(self._task_data_service,
                                                  self._dataset_fn,
                                                  self._model_handler)
        # Place default callbacks at the head so they execute first
        self._callbacks_list.callbacks.insert(0, saved_model_exporter)
        self._callbacks_list.set_model(model_inst)
        set_callback_parameters(
            self._callbacks_list,
            batch_size=args.minibatch_size,
            saved_model_path=args.output,
            checkpoint_path=args.checkpoint_dir,
        )

        self._allreduce_trainer = None
        if self._distribution_strategy == DistributionStrategy.ALLREDUCE:
            master_addr = args.master_addr.split(":")[0]
            self._allreduce_trainer = AllReduceTrainer(self._mc, master_addr,
                                                       self._model, self._loss,
                                                       self._opt)
Example 10
    def __init__(
        self,
        worker_id,
        job_type,
        minibatch_size,
        model_zoo,
        dataset_fn="dataset_fn",
        loss="loss",
        optimizer="optimizer",
        eval_metrics_fn="eval_metrics_fn",
        channel=None,
        embedding_service_endpoint=None,
        model_def=None,
        model_params="",
        data_reader_params="",
        prediction_outputs_processor="PredictionOutputsProcessor",
        max_minibatch_retry_num=DEFAULT_MAX_MINIBATCH_RETRY_NUM,
        get_model_steps=1,
    ):
        """
        Arguments:
            worker_id: The worker ID.
            job_type: The job type.
            minibatch_size: The size of the minibatch used for each iteration.
            model_zoo: The directory that contains user-defined model files
                or a specific model file.
            dataset_fn: The name of the dataset function defined in the
                model file.
            loss: The name of the loss function defined in the model file.
            optimizer: The name of the optimizer defined in the model file.
            eval_metrics_fn: The name of the evaluation metrics function
                defined in the model file.
            channel: The channel for the gRPC master service.
            embedding_service_endpoint: The endpoint to the embedding service.
            model_def: The import path to the model definition
                function/class in the model zoo, e.g.
                "cifar10_subclass.CustomModel".
            model_params: The model parameters used to instantiate the
                model, given as a string of semicolon-separated key-value
                pairs, e.g. "param1=1; param2=2".
            data_reader_params: The data reader parameters used to
                instantiate the data reader, given as a string of
                semicolon-separated key-value pairs,
                e.g. "param1=1; param2=2".
            prediction_outputs_processor: The name of the prediction output
                processor class defined in the model file.
            max_minibatch_retry_num: The maximum number of retries for a
                minibatch whose results (e.g. gradients) are not accepted
                by the master.
            get_model_steps: The worker calls `get_model` on the
                parameter server every this many steps.
        """
        self._worker_id = worker_id
        self._job_type = job_type
        self._minibatch_size = minibatch_size
        (
            model_inst,
            self._dataset_fn,
            self._loss,
            self._opt_fn,
            self._eval_metrics_fn,
            self._prediction_outputs_processor,
        ) = get_model_spec(
            model_zoo=model_zoo,
            model_def=model_def,
            dataset_fn=dataset_fn,
            loss=loss,
            optimizer=optimizer,
            eval_metrics_fn=eval_metrics_fn,
            model_params=model_params,
            prediction_outputs_processor=prediction_outputs_processor,
        )
        self._embedding_service_endpoint = embedding_service_endpoint
        self.set_model(model_inst)

        if channel is None:
            self._stub = None
        else:
            self._stub = elasticdl_pb2_grpc.MasterStub(channel)
        self._max_minibatch_retry_num = max_minibatch_retry_num
        self._model_version = -1
        self._task_data_service = TaskDataService(
            self,
            self._job_type == JobType.TRAINING_WITH_EVALUATION,
            data_reader_params=get_dict_from_params_str(data_reader_params),
        )
        self._get_model_steps = get_model_steps
        if self._get_model_steps > 1:
            self._opt = self._opt_fn()
            self._non_embed_grads = None
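Round-tripping the docstring's "param1=1; param2=2" format through the parser, and assuming the literal-evaluation behavior sketched in Example 1, numeric values come back as ints:

# Assumed behavior for the docstring's example string: numeric
# values parse to ints via literal evaluation.
assert get_dict_from_params_str("param1=1; param2=2") == {
    "param1": 1,
    "param2": 2,
}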