Example #1
    def testGetEmptyTask(self):
        master = MasterServicer(
            3,
            _TaskDispatcher({}, {}, {}, records_per_task=3, num_epochs=2),
            evaluation_service=None,
        )

        req = elasticdl_pb2.GetTaskRequest()

        # No task yet, make sure the returned versions are as expected.
        req.worker_id = 1
        task = master.get_task(req, None)
        self.assertEqual("", task.shard_name)
        self.assertEqual(0, task.model_version)

        master._version = 1
        task = master.get_task(req, None)
        self.assertEqual("", task.shard_name)
        self.assertEqual(1, task.model_version)
Example #2
    def test_get_empty_task(self):
        self.master.task_manager = create_task_manager([], [])
        master_servicer = MasterServicer(
            self.master.task_manager,
            self.master.instance_manager,
            None,
            None,
        )

        req = elasticai_api_pb2.GetTaskRequest()

        # No task yet, make sure the returned versions are as expected.
        req.worker_id = 1
        task = master_servicer.get_task(req, None)
        self.assertEqual("", task.shard.name)
        self.assertEqual(0, task.model_version)

        master_servicer._version = 1
        task = master_servicer.get_task(req, None)
        self.assertEqual("", task.shard.name)
        self.assertEqual(1, task.model_version)
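
Examples #1 and #2 exercise the same contract against two generations of the proto: the older elasticdl_pb2 exposes a flat task.shard_name, while elasticai_api_pb2 nests it as task.shard.name. In both, an empty name is the "no task" sentinel. A minimal sketch of the worker-side loop this implies (fetch_tasks is a hypothetical helper, and elasticai_api_pb2 is assumed importable as in Example #2):

def fetch_tasks(servicer, worker_id):
    # Keep pulling tasks until the master signals "no task" with an
    # empty shard name, mirroring the loops in Examples #3 and #4.
    while True:
        req = elasticai_api_pb2.GetTaskRequest()
        req.worker_id = worker_id
        task = servicer.get_task(req, None)
        if not task.shard.name:
            break
        yield task
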
Example #3
    def testReportTaskResult(self):
        task_d = _TaskDispatcher(
            {
                "shard_1": (0, 10),
                "shard_2": (0, 9)
            },
            {},
            {},
            records_per_task=3,
            num_epochs=2,
        )
        master = MasterServicer(
            3,
            task_d,
            evaluation_service=None,
        )

        # task to number of runs.
        tasks = defaultdict(int)
        while True:
            req = elasticdl_pb2.GetTaskRequest()
            req.worker_id = random.randint(1, 10)
            task = master.get_task(req, None)
            if not task.shard_name:
                break
            self.assertEqual(task_d._doing[task.task_id][0], req.worker_id)
            task_key = (task.shard_name, task.start, task.end)
            tasks[task_key] += 1
            report = elasticdl_pb2.ReportTaskResultRequest()
            report.task_id = task.task_id
            if task.start == 0 and tasks[task_key] == 1:
                # Simulate error reports.
                report.err_message = "Worker error"
            master.report_task_result(report, None)

        self.assertDictEqual(
            {
                ("shard_1", 0, 3): 3,
                ("shard_1", 3, 6): 2,
                ("shard_1", 6, 9): 2,
                ("shard_1", 9, 10): 2,
                ("shard_2", 0, 3): 3,
                ("shard_2", 3, 6): 2,
                ("shard_2", 6, 9): 2,
            },
            tasks,
        )
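
The expected counts follow directly from the dispatcher's sharding: with records_per_task=3, shard_1 = (0, 10) splits into (0, 3), (3, 6), (6, 9), (9, 10) and shard_2 = (0, 9) into (0, 3), (3, 6), (6, 9). num_epochs=2 runs every task twice, and the simulated error on each shard's first task (start == 0) re-queues it once more, hence 3 runs for those two keys. A minimal sketch of the splitting rule (split_shard is a hypothetical helper, not the library API):

def split_shard(start, end, records_per_task):
    # Chop [start, end) into consecutive ranges of at most
    # records_per_task records; the tail range may be shorter.
    return [(s, min(s + records_per_task, end))
            for s in range(start, end, records_per_task)]

assert split_shard(0, 10, 3) == [(0, 3), (3, 6), (6, 9), (9, 10)]
assert split_shard(0, 9, 3) == [(0, 3), (3, 6), (6, 9)]
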
Example #4
    def test_report_task_result(self):
        self.master.task_manager = create_task_manager(
            [("shard_1", 0, 10), ("shard_2", 0, 9)], [], 2
        )
        master = MasterServicer(
            self.master.task_manager,
            self.master.instance_manager,
            None,
            None,
        )

        # task to number of runs.
        tasks = defaultdict(int)
        while True:
            req = elasticai_api_pb2.GetTaskRequest()
            req.worker_id = random.randint(1, 10)
            task = master.get_task(req, None)
            if not task.shard.name:
                break
            self.assertEqual(self.master.task_manager._doing[task.task_id][0],
                             req.worker_id)
            task_key = (task.shard.name, task.shard.start, task.shard.end)
            tasks[task_key] += 1
            report = elasticai_api_pb2.ReportTaskResultRequest()
            report.task_id = task.task_id
            if task.shard.start == 0 and tasks[task_key] == 1:
                # Simulate error reports.
                report.err_message = "Worker error"
            master.report_task_result(report, None)

        self.assertDictEqual(
            {
                ("shard_1", 0, 3): 3,
                ("shard_1", 3, 6): 2,
                ("shard_1", 6, 9): 2,
                ("shard_1", 9, 10): 2,
                ("shard_2", 0, 3): 3,
                ("shard_2", 3, 6): 2,
                ("shard_2", 6, 9): 2,
            },
            tasks,
        )
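
Example #4 also asserts the master's bookkeeping: task_manager._doing maps a task_id to a tuple whose first element is the worker that claimed it. A hedged sketch (assumed names such as _todo, not the library code) of the report path that produces the extra run counted above:

def report_task_result(self, request, _):
    # Pop the claimed task; on error, put it back so another
    # worker retries it (the third run asserted in the test).
    worker_id, task = self._doing.pop(request.task_id)
    if request.err_message:
        self._todo.append(task)
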
Example #5
def distributed_train_and_evaluate(
    feature_shape,
    model_zoo_path,
    model_def,
    model_params="",
    eval_metrics_fn="eval_metrics_fn",
    loss="loss",
    training=True,
    dataset_name=DatasetName.IMAGE_DEFAULT,
    callback_classes=[],
    use_async=False,
    get_model_steps=1,
    ps_channels=None,
    pservers=None,
    distribution_strategy=DistributionStrategy.PARAMETER_SERVER,
):
    """Runs distributed training and evaluation with a local master. Grpc
    calls are mocked by local master call.

    Args:
        feature_shape: The shape of model input.
        model_zoo_path: The directory that contains user-defined model files
            or a specific model file.
        model_def: The import path to the model definition function/class in
            the model zoo, e.g. "cifar10_subclass.CustomModel".
        model_params: The dictionary of model parameters in a string that will
            be used to instantiate the model, e.g. "param1=1,param2=2".
        eval_metrics_fn: The name of the evaluation metrics function defined
            in the model file.
        loss: The name of the loss function defined in the model file.
        training: True for job type `TRAIN_WITH_EVALUATION`, False for
            job type `EVALUATION`.
        dataset_name: A dataset name from `DatasetName`.
        callback_classes: A List of callbacks that will be called at given
            stages of the training procedure.
        use_async: A bool. True if using asynchronous updates.
        get_model_steps: The worker calls `get_model` on the parameter
            server once every `get_model_steps` steps.
        ps_channels: A channel list to all parameter server pods.
        pservers: A list of parameter server pods.
        distribution_strategy: The distribution strategy used by workers, e.g.
            DistributionStrategy.PARAMETER_SERVER or
            DistributionStrategy.AllreduceStrategy.

    Returns:
        An integer indicating the model version after the distributed training
        and evaluation.
    """
    job_type = (JobType.TRAINING_WITH_EVALUATION
                if training else JobType.EVALUATION_ONLY)
    evaluation_steps = 1 if job_type == JobType.TRAINING_WITH_EVALUATION else 0
    batch_size = 8 if dataset_name == DatasetName.IMAGENET else 16
    pservers = pservers or []
    ps_channels = ps_channels or []

    model_module = load_module(get_module_file_path(model_zoo_path,
                                                    model_def)).__dict__

    for channel in ps_channels:
        grpc.channel_ready_future(channel).result()
    worker_arguments = [
        "--worker_id",
        "1",
        "--job_type",
        job_type,
        "--minibatch_size",
        batch_size,
        "--model_zoo",
        model_zoo_path,
        "--model_def",
        model_def,
        "--model_params",
        model_params,
        "--loss",
        loss,
        "--get_model_steps",
        get_model_steps,
        "--distribution_strategy",
        distribution_strategy,
    ]
    args = parse_worker_args(worker_arguments)
    worker = Worker(args, ps_channels=ps_channels)

    if dataset_name in [DatasetName.IMAGENET, DatasetName.FRAPPE]:
        record_num = batch_size
    else:
        record_num = 128
    shards = {
        create_recordio_file(record_num, dataset_name, feature_shape): (
            0,
            record_num,
        )
    }
    if training:
        training_shards = shards
        evaluation_shards = shards
    else:
        training_shards = {}
        evaluation_shards = shards
    task_d = _TaskDispatcher(
        training_shards,
        evaluation_shards,
        {},
        records_per_task=64,
        num_epochs=1,
    )

    if training:
        evaluation_service = EvaluationService(
            None,
            task_d,
            0,
            0,
            evaluation_steps,
            False,
            model_module[eval_metrics_fn],
        )
    else:
        evaluation_service = EvaluationService(
            None,
            task_d,
            0,
            0,
            evaluation_steps,
            True,
            model_module[eval_metrics_fn],
        )
    task_d.set_evaluation_service(evaluation_service)

    master = MasterServicer(
        batch_size,
        task_d,
        evaluation_service=evaluation_service,
    )
    callbacks = [
        callback_class(master, worker) for callback_class in callback_classes
    ]

    in_process_master = InProcessMaster(master, callbacks)
    worker._stub = in_process_master
    for pservicer in pservers:
        pservicer._master_stub = in_process_master

    worker.run()

    req = elasticdl_pb2.GetTaskRequest()
    req.worker_id = 1
    task = master.get_task(req, None)
    # No more task.
    if task.shard_name:
        raise RuntimeError(
            "There are unfinished tasks after the worker exits.")
    return master._version
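
A hypothetical invocation of the helper, using the docstring's own example model_def; the model-zoo path and feature shape below are illustrative, not taken from the source:

model_version = distributed_train_and_evaluate(
    feature_shape=(32, 32, 3),  # illustrative CIFAR-10-like input shape
    model_zoo_path="model_zoo",  # illustrative model-zoo directory
    model_def="cifar10_subclass.CustomModel",
    training=True,
)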