예제 #1
0
    def _init_model_from_args(self, args):
        """Build the model and its training hooks from parsed worker args.

        See elastic/python/common/args.py for the full list of worker
        arguments consumed here.

        Sets ``_log_loss_steps``, ``_feed``, ``_eval_metrics_fn``,
        ``_prediction_outputs_processor``, ``_custom_data_reader``,
        ``_callbacks_list``, ``_model_handler``, ``_model_inst``,
        ``_model_version`` and ``_get_model_steps`` on ``self``.
        """
        self._log_loss_steps = args.log_loss_steps

        model_spec = get_model_spec(
            model_zoo=args.model_zoo,
            model_def=args.model_def,
            feed=args.feed,
            loss=args.loss,
            optimizer=args.optimizer,
            eval_metrics_fn=args.eval_metrics_fn,
            prediction_outputs_processor=args.prediction_outputs_processor,
            custom_data_reader=args.custom_data_reader,
            callbacks=args.callbacks,
        )
        # get_model_spec returns an 8-tuple; the model, loss and optimizer
        # factory stay local because they are attached to the model below.
        (
            raw_model,
            self._feed,
            loss_fn,
            make_optimizer,
            self._eval_metrics_fn,
            self._prediction_outputs_processor,
            self._custom_data_reader,
            self._callbacks_list,
        ) = model_spec

        # The handler chosen for the distribution strategy decides how the
        # user model is wrapped for training.
        self._model_handler = ModelHandler.get_model_handler(
            self._distribution_strategy, checkpoint_dir=args.checkpoint_dir)
        trainable = self._model_handler.get_model_to_train(raw_model)
        self._model_inst = trainable
        trainable.optimizer = make_optimizer()
        trainable.loss = loss_fn

        # -1 marks "no model version received yet".
        self._model_version = -1
        self._get_model_steps = args.get_model_steps
예제 #2
0
    def test_on_train_end(self):
        """on_train_end() should export a SavedModel with saved_model.pb.

        A PS-strategy ModelHandler is created from a checkpoint saved without
        embeddings, SavedModelExporter.on_train_end() is invoked, and the
        export directory is checked on disk.
        """
        # NOTE(review): other TaskDataService call sites pass a boolean
        # (job_type == JobType.TRAINING_WITH_EVALUATION) as this argument;
        # passing the JobType value itself relies on it being truthy —
        # confirm intent.
        task_data_service = TaskDataService(
            MockWorker(), JobType.TRAINING_WITH_EVALUATION)
        two_rows = np.array([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])
        dataset = tf.data.Dataset.from_tensor_slices(two_rows)
        # Mark a TRAIN_END_CALLBACK task as pending. Field layout presumably
        # mirrors TaskDataService's internal tuple — TODO confirm.
        task_data_service._pending_train_end_callback_task = (
            "", 0, 1, elasticdl_pb2.TRAIN_END_CALLBACK)
        # Serve the fixed in-memory dataset for any requested task.
        task_data_service.get_dataset_by_task = mock.Mock(
            return_value=dataset)

        with tempfile.TemporaryDirectory() as tmp_root:
            ckpt_dir = os.path.join(tmp_root, "checkpoint")
            export_dir = os.path.join(tmp_root, "test_exporter")

            model = custom_model_with_embedding_layer()
            save_checkpoint_without_embedding(model, ckpt_dir)
            handler = ModelHandler.get_model_handler(
                distribution_strategy=DistributionStrategy.PARAMETER_SERVER,
                checkpoint_dir=ckpt_dir,
            )

            exporter = SavedModelExporter(
                task_data_service, dataset_fn, handler)
            exporter.set_params(
                {"batch_size": 10, "saved_model_path": export_dir})
            exporter.set_model(model)
            exporter.on_train_end()

            self.assertTrue(os.path.exists(export_dir))
            self.assertTrue(
                os.path.exists(os.path.join(export_dir, "saved_model.pb")))
예제 #3
0
    def _init_from_args(self, args):
        """Initialize worker state from parsed command-line arguments.

        See elastic/python/common/args.py for the meaning of each argument.

        Raises:
            ValueError: if no ``dataset_fn`` is configured and the data
                reader does not provide ``default_dataset_fn``.
        """
        self._worker_id = args.worker_id
        self._job_type = args.job_type
        self._minibatch_size = args.minibatch_size

        model_spec = get_model_spec(
            model_zoo=args.model_zoo,
            model_def=args.model_def,
            dataset_fn=args.dataset_fn,
            loss=args.loss,
            optimizer=args.optimizer,
            eval_metrics_fn=args.eval_metrics_fn,
            model_params=args.model_params,
            prediction_outputs_processor=args.prediction_outputs_processor,
            custom_data_reader=args.custom_data_reader,
        )
        (
            user_model,
            self._dataset_fn,
            self._loss,
            self._opt_fn,
            self._eval_metrics_fn,
            self._prediction_outputs_processor,
            self._custom_data_reader,
        ) = model_spec

        # A collective communicator is only needed for AllReduce training.
        if self._distribution_strategy == DistributionStrategy.ALLREDUCE:
            self._collective_communicator = CollectiveCommunicator()
        else:
            self._collective_communicator = None

        self._model_handler = ModelHandler.get_model_handler(
            self._distribution_strategy, checkpoint_dir=args.checkpoint_dir)
        self.set_model(self._model_handler.get_model_to_train(user_model))

        # -1 marks "no version received yet", locally and per PS.
        self._model_version = -1
        self._model_versions_from_ps = [-1 for _ in range(self._ps_num)]

        reader_params = get_dict_from_params_str(args.data_reader_params)
        self._task_data_service = TaskDataService(
            self,
            self._job_type == JobType.TRAINING_WITH_EVALUATION,
            data_reader_params=reader_params,
        )

        # Fall back to the data reader's own dataset_fn when none was given.
        if self._dataset_fn is None:
            if not hasattr(self._task_data_service.data_reader,
                           "default_dataset_fn"):
                raise ValueError(
                    "dataset_fn is required if the data_reader used does "
                    "not provide default implementation of dataset_fn")
            self._dataset_fn = (
                self._task_data_service.data_reader.default_dataset_fn())

        self._get_model_steps = args.get_model_steps
        # NOTE(review): _opt is only created when get_model_steps > 1;
        # confirm no caller reads it otherwise.
        if self._get_model_steps > 1:
            self._opt = self._opt_fn()
        self._non_embed_grads = {}
        self._evaluation_result = {}
예제 #4
0
    def _init_from_args(self, args):
        """Initialize worker state from parsed command-line arguments.

        See elastic/python/common/args.py for the meaning of each argument.
        Loads the model spec, parses the embedding service endpoint, wraps
        the model for the configured distribution strategy and builds the
        task data service.
        """
        self._worker_id = args.worker_id
        self._job_type = args.job_type
        self._minibatch_size = args.minibatch_size
        (
            model_inst,
            self._dataset_fn,
            self._loss,
            self._opt_fn,
            self._eval_metrics_fn,
            self._prediction_outputs_processor,
        ) = get_model_spec(
            model_zoo=args.model_zoo,
            model_def=args.model_def,
            dataset_fn=args.dataset_fn,
            loss=args.loss,
            optimizer=args.optimizer,
            eval_metrics_fn=args.eval_metrics_fn,
            model_params=args.model_params,
            prediction_outputs_processor=args.prediction_outputs_processor,
        )

        # NOTE(review/security): eval() executes arbitrary Python taken from
        # the command line. If this argument is only ever a Python literal
        # (e.g. a dict/list of endpoints), ast.literal_eval would be a safe
        # drop-in — confirm the accepted format before switching.
        self._embedding_service_endpoint = eval(
            args.embedding_service_endpoint)

        self._distribution_strategy = args.distribution_strategy
        # A collective communicator is only needed for AllReduce training.
        self._collective_communicator = (
            CollectiveCommunicator() if self._distribution_strategy
            == DistributionStrategy.ALLREDUCE else None)
        self._model_handler = ModelHandler.get_model_handler(
            self._distribution_strategy, checkpoint_dir=args.checkpoint_dir)
        model_inst = self._model_handler.get_model_to_train(model_inst)
        self.set_model(model_inst)

        # -1 marks "no model version received yet".
        self._model_version = -1
        self._task_data_service = TaskDataService(
            self,
            self._job_type == JobType.TRAINING_WITH_EVALUATION,
            data_reader_params=get_dict_from_params_str(
                args.data_reader_params),
        )
        self._get_model_steps = args.get_model_steps
        # Define the attribute for every get_model_steps value so later reads
        # never raise AttributeError when get_model_steps <= 1.
        self._non_embed_grads = None
        if self._get_model_steps > 1:
            self._opt = self._opt_fn()
        self._evaluation_result = {}
예제 #5
0
 def setUp(self):
     """Create a MasterServicer stub and a ModelHandler that talks to it."""
     tf.keras.backend.clear_session()
     master = MasterServicer(
         2,
         3,
         None,
         None,
         init_var=[],
         checkpoint_filename_for_init="",
         checkpoint_service=CheckpointService("", 0, 0, False),
         evaluation_service=None,
     )
     self.master = master
     master._version = 1
     # NOTE(review): other tests pass DistributionStrategy.PARAMETER_SERVER;
     # confirm this string literal is the value that constant holds.
     self.model_handler = ModelHandler.get_model_handler(
         distribution_strategy="ParameterServerStrategy", stub=master)
예제 #6
0
    def _init_from_args(self, args):
        """Initialize worker state from parsed command-line arguments.

        See elastic/python/common/args.py for the meaning of each argument.
        Loads the model spec, wraps the model for the configured
        distribution strategy, builds the task data service, resolves
        ``dataset_fn``, creates the optimizer, and wires up callbacks with a
        SavedModelExporter placed first. Under the AllReduce strategy it
        also creates an AllReduceTrainer.

        Raises:
            ValueError: if no ``dataset_fn`` is configured and the data
                reader does not provide ``default_dataset_fn``.
        """
        self._worker_id = args.worker_id
        self._job_type = args.job_type
        self._minibatch_size = args.minibatch_size
        self._log_loss_steps = args.log_loss_steps
        (
            model_inst,
            self._dataset_fn,
            self._loss,
            self._opt_fn,
            self._eval_metrics_fn,
            self._prediction_outputs_processor,
            self._custom_data_reader,
            self._callbacks_list,
        ) = get_model_spec(
            model_zoo=args.model_zoo,
            model_def=args.model_def,
            dataset_fn=args.dataset_fn,
            loss=args.loss,
            optimizer=args.optimizer,
            eval_metrics_fn=args.eval_metrics_fn,
            model_params=args.model_params,
            prediction_outputs_processor=args.prediction_outputs_processor,
            custom_data_reader=args.custom_data_reader,
            callbacks=args.callbacks,
        )

        self._model_handler = ModelHandler.get_model_handler(
            self._distribution_strategy, checkpoint_dir=args.checkpoint_dir)
        model_inst = self._model_handler.get_model_to_train(model_inst)
        self.set_model(model_inst)

        # -1 marks "no model version received yet".
        self._model_version = -1
        self._task_data_service = TaskDataService(
            self._mc,
            self._job_type == JobType.TRAINING_WITH_EVALUATION,
            custom_data_reader=self._custom_data_reader,
            data_reader_params=get_dict_from_params_str(
                args.data_reader_params),
            data_origin=args.training_data,
        )
        # Fall back to the data reader's own dataset_fn when none was given.
        if self._dataset_fn is None:
            if hasattr(self._task_data_service.data_reader,
                       "default_dataset_fn"):
                self._dataset_fn = (
                    self._task_data_service.data_reader.default_dataset_fn())
            else:
                raise ValueError(
                    "dataset_fn is required if the data_reader used does "
                    "not provide default implementation of dataset_fn")
        self._get_model_steps = args.get_model_steps
        self._opt = self._opt_fn()
        # self._model is presumably populated by set_model() above.
        self._model.optimizer = self._opt
        self._non_embed_grads = {}
        self._evaluation_result = {}

        saved_model_exporter = SavedModelExporter(self._task_data_service,
                                                  self._dataset_fn,
                                                  self._model_handler)
        # Place default callbacks at the head so they execute first.
        self._callbacks_list.callbacks.insert(0, saved_model_exporter)
        self._callbacks_list.set_model(model_inst)
        set_callback_parameters(
            self._callbacks_list,
            batch_size=args.minibatch_size,
            saved_model_path=args.output,
            checkpoint_path=args.checkpoint_dir,
        )

        self._allreduce_trainer = None
        if self._distribution_strategy == DistributionStrategy.ALLREDUCE:
            # master_addr is expected to be "<host>:<port>". Strip only the
            # text after the LAST colon so IPv6 hosts (which contain colons)
            # keep their full address. For "<ipv4-or-name>:<port>" this is
            # identical to split(":")[0].
            master_addr = args.master_addr.rsplit(":", 1)[0]
            self._allreduce_trainer = AllReduceTrainer(self._mc, master_addr,
                                                       self._model, self._loss,
                                                       self._opt)
예제 #7
0
 def setUp(self):
     """Reset Keras state and build a PS-strategy ModelHandler (no ckpt dir)."""
     tf.keras.backend.clear_session()
     handler = ModelHandler.get_model_handler(
         checkpoint_dir="",
         distribution_strategy=DistributionStrategy.PARAMETER_SERVER,
     )
     self.model_handler = handler
예제 #8
0
 def setUp(self):
     """Build a ModelHandler by calling get_model_handler with no arguments."""
     handler = ModelHandler.get_model_handler()
     self.model_handler = handler
예제 #9
0
    def __init__(self, args):
        """Set up the master from parsed command-line arguments.

        Creates, in this order, the following:

        - the logger;
        - the TensorBoard service and client, when requested;
        - the model and optimizer from the model definition;
        - the task dispatcher;
        - an optional deferred save-model task;
        - the evaluation service;
        - the master service and server;
        - the instance manager.

        Later steps read attributes set by earlier ones, so the order
        matters.
        """
        self.logger = get_logger("master", level=args.log_level.upper())

        self.num_ps_pods = args.num_ps_pods
        self.checkpoint_output_path = args.checkpoint_dir

        # Master address: pod IP from the environment (localhost if unset).
        master_ip = os.getenv("MY_POD_IP", "localhost")
        self.master_addr = "%s:%d" % (master_ip, args.port)
        self.job_type = Master._get_job_type(args)

        # TensorBoard is optional; the client is only created alongside a
        # service, so self.tb_client exists only when tb_service is truthy.
        self.tb_service = self._create_tensorboard_service(
            args.tensorboard_log_dir, master_ip)
        if self.tb_service:
            self.tb_client = TensorBoardClient(
                job_name=args.job_name,
                image_name=args.worker_image,
                namespace=args.namespace,
            )

        # Load the user's model definition module and build the model.
        module_path = get_module_file_path(args.model_zoo, args.model_def)
        self.model_module = load_module(module_path).__dict__
        user_model = load_model_from_module(args.model_def,
                                            self.model_module,
                                            args.model_params)
        self.model_inst = user_model
        model_handler = ModelHandler.get_model_handler(
            args.distribution_strategy, checkpoint_dir=args.checkpoint_dir)
        self.model_inst = model_handler.get_model_to_train(user_model)
        self.optimizer = self.model_module[args.optimizer]()
        # Use the model module's data reader factory if it defines one under
        # the configured name, otherwise the default create_data_reader.
        self._create_data_reader_fn = self.model_module.get(
            args.custom_data_reader, create_data_reader)

        # Start task queue
        records_per_task = (
            args.minibatch_size * args.num_minibatches_per_task)
        self.task_d = _make_task_dispatcher(
            args.training_data,
            args.validation_data,
            args.prediction_data,
            records_per_task,
            args.num_epochs,
            args.data_reader_params,
            self._create_data_reader_fn,
        )

        # Only training jobs with an output path schedule a save-model task.
        saved_model_path = args.output
        training_job_types = (
            JobType.TRAINING_ONLY,
            JobType.TRAINING_WITH_EVALUATION,
        )
        if (saved_model_path is not None
                and self.job_type in training_job_types):
            self.task_d.add_deferred_callback_create_save_model_task(
                saved_model_path)

        self.evaluation_service = self._create_evaluation_service(args)
        self.master_servicer, self.server = self._create_master_service(args)
        self.instance_manager = self._create_instance_manager(args)

        self._should_stop = False
        self._exit_code = 0
예제 #10
0
 def setUp(self):
     """Reset Keras state; build a PS ModelHandler on the functional_ckpt data."""
     tf.keras.backend.clear_session()
     ckpt_dir = "elasticdl/python/tests/testdata/functional_ckpt/"
     self.model_handler = ModelHandler.get_model_handler(
         distribution_strategy=DistributionStrategy.PARAMETER_SERVER,
         checkpoint_dir=ckpt_dir,
     )