Code Example #1
    def __init__(
        self,
        args,
        channel=None,
        ps_channels=None,
        max_minibatch_retry_num=DEFAULT_MAX_MINIBATCH_RETRY_NUM,
        max_allreduce_retry_num=DEFAULT_MAX_ALLREDUCE_RETRY_NUM,
        set_parallelism=False,
    ):
        """
        Arguments:
            channel: The gRPC channel for the master service.
            ps_channels: The gRPC channels for the PS services.
            max_minibatch_retry_num: The maximum number of retries for a
                minibatch whose results (e.g. gradients) are not accepted
                by the master.
            max_allreduce_retry_num: The maximum number of retries for an
                allreduce operation when the allreduce-based distributed
                training strategy is used.
        """
        self._args = args
        self.logger = get_logger("Worker", level=args.log_level.upper())

        if set_parallelism:
            # Explicitly setting the parallelism avoids multi-process hangs,
            # possibly due to an unknown bug in TensorFlow.
            # Must be called before TensorFlow is initialized.
            # set_parallelism is off by default to keep unit tests happy.
            num_threads = os.cpu_count()
            tf.config.threading.set_inter_op_parallelism_threads(num_threads)
            tf.config.threading.set_intra_op_parallelism_threads(num_threads)

        if channel is None:
            self._stub = None
        else:
            self._stub = elasticdl_pb2_grpc.MasterStub(channel)

        self._use_multi_ps = False
        self._ps_vars = {}
        if isinstance(ps_channels, list) and len(ps_channels) > 0:
            self._use_multi_ps = True
            self._ps_stubs = [
                elasticdl_pb2_grpc.PserverStub(c) for c in ps_channels
            ]
            self._var_to_ps = {}
            self._ps_num = len(self._ps_stubs)
        else:
            self._ps_num = 0
        self._distribution_strategy = args.distribution_strategy
        if (self._distribution_strategy
                == DistributionStrategy.PARAMETER_SERVER
                and self._use_multi_ps is False):
            raise ValueError(
                "PS channels are not set up under parameter server strategy")

        self._max_minibatch_retry_num = max_minibatch_retry_num
        self._max_allreduce_retry_num = max_allreduce_retry_num
        self._init_from_args(args)
        self._timing = Timing(args.log_level.upper() == "DEBUG", self.logger)
        self._log_loss_count = 0
Code Example #2
File: main.py, Project: zhaozhy/elasticdl
def main():
    args = parse_worker_args()
    channel = grpc.insecure_channel(
        args.master_addr,
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )

    logger = log_utils.get_logger(__name__)

    logger.info("Starting worker %d", args.worker_id)
    worker = Worker(
        args.worker_id,
        args.job_type,
        args.minibatch_size,
        args.model_zoo,
        channel=channel,
        embedding_service_endpoint=eval(args.embedding_service_endpoint),
        dataset_fn=args.dataset_fn,
        loss=args.loss,
        optimizer=args.optimizer,
        eval_metrics_fn=args.eval_metrics_fn,
        model_def=args.model_def,
        model_params=args.model_params,
        data_reader_params=args.data_reader_params,
        get_model_steps=args.get_model_steps,
    )
    worker.run()
Code Example #3
def main():
    args = parse_worker_args()
    logger = log_utils.get_logger(__name__)
    logger.info("Starting worker %d", args.worker_id)
    if args.master_addr is None:
        raise ValueError("master_addr is missing for worker")

    master_channel = build_channel(args.master_addr)

    ps_channels = []
    if args.ps_addrs:
        ps_addrs = args.ps_addrs.split(",")

        for addr in ps_addrs:
            # addr is of the form "ps-pod-name.namespace.svc:port"
            channel = build_channel(addr)

            # Wait for the channel to become ready via a Future object.
            grpc.channel_ready_future(channel).result()
            logger.info("grpc channel %s to connect pod %s is ready" %
                        (addr, addr.split(".")[0]))
            ps_channels.append(channel)

    worker = Worker(args, channel=master_channel, ps_channels=ps_channels)
    worker.run()
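
Examples #3 and later call a build_channel helper that these snippets do not define. Below is a minimal sketch, assuming it simply wraps grpc.insecure_channel with the same message-size options that Examples #2 and #6 pass explicitly; the GRPC constants and import path are assumptions, not shown in the original code.

import grpc

from elasticdl.python.common.constants import GRPC  # assumed import path


def build_channel(addr):
    # Hypothetical reconstruction of the helper used above; addr is a
    # "host:port" string such as "elasticdl-master.default.svc:50001".
    return grpc.insecure_channel(
        addr,
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )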
Code Example #4
    def __init__(self, args):
        self.logger = get_logger("PS", level=args.log_level.upper())
        self.grads_to_wait = args.grads_to_wait
        self.lr_staleness_modulation = args.lr_staleness_modulation
        self.sync_version_tolerance = args.sync_version_tolerance
        self.use_async = args.use_async
        self.port = args.port
        model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)).__dict__
        self.optimizer = model_module[args.optimizer]()
        self._set_lr_scheduler(model_module, args.learning_rate_scheduler)
        self.ps_id = args.ps_id
        self.num_ps_pods = args.num_ps_pods
        self.num_workers = args.num_workers
        # Create Parameters instance
        self.parameters = Parameters()
        if args.master_addr is None:
            raise ValueError("master_addr is missing for parameter servers")
        self.master_channel = build_channel(args.master_addr)
        self.evaluation_steps = args.evaluation_steps

        self.master_name = get_master_pod_name(args.job_name)
        self.namespace = args.namespace
        self._init_checkpoint_saver(args)
        self._restore_params_from_checkpoint(args.checkpoint_dir_for_init)
        self._debug_info_needed = args.log_level.upper() == "DEBUG"
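
Several of these examples (e.g. #4, #5, #16) load the model definition as a plain Python module and look components up by name, as in model_module[args.optimizer](). The following is a minimal sketch of what a model zoo file might provide to satisfy those lookups; every name here is an illustrative assumption that would have to match the --model_def, --optimizer, and --loss arguments.

import tensorflow as tf


def custom_model():
    # Hypothetical model constructor referenced by --model_def.
    return tf.keras.Sequential(
        [tf.keras.layers.Dense(10, activation="softmax")]
    )


def optimizer(lr=0.1):
    # Looked up and called with no arguments: model_module[args.optimizer]()
    return tf.optimizers.SGD(lr)


def loss(labels, predictions):
    # Hypothetical loss function referenced by --loss.
    return tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(labels, predictions)
    )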
Code Example #5
    def __init__(self, args, task_manager, rendezvous_server=None):
        self.logger = get_logger("master", level=args.log_level.upper())

        self.num_ps_pods = args.num_ps_pods
        self.checkpoint_output_path = args.checkpoint_dir

        # Master addr
        master_ip = os.getenv("MY_POD_IP", "localhost")
        self.master_addr = "%s:%d" % (master_ip, args.port)
        self.job_type = get_job_type(args)

        # Initialize the components from the model definition
        model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)).__dict__

        self._optimizer = model_module[args.optimizer]()

        # TODO: Remove task manager and rendezvous server after
        # refactoring pod manager.
        self.task_manager = task_manager
        self.rendezvous_server = rendezvous_server

        self.evaluation_service = (
            None if args.eval_metrics_fn not in model_module
            else self._create_evaluation_service(
                model_module[args.eval_metrics_fn], args.evaluation_steps))
Code Example #6
def main():
    args = parse_worker_args()
    if args.master_addr is None:
        raise ValueError("master_addr is missing for worker")
    channel = grpc.insecure_channel(
        args.master_addr,
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )

    # TODO: create PS channels here
    ps_addrs = args.ps_addrs.split(",")
    # Just print ps_addrs out to avoid flake8 failure
    # This print can be removed once we initialize ps_channels
    # by using ps_addrs
    print("Parameter server addresses are %s" % ps_addrs)
    ps_channels = None

    logger = log_utils.get_logger(__name__)

    logger.info("Starting worker %d", args.worker_id)
    worker = Worker(args, channel=channel, ps_channels=ps_channels)
    worker.run()
Code Example #7
def main():
    args = parse_worker_args()
    logger = log_utils.get_logger(__name__)
    logger.info("Starting worker %d", args.worker_id)
    if args.master_addr is None:
        raise ValueError("master_addr is missing for worker")

    master_channel = build_channel(args.master_addr)

    ps_channels = []
    if args.ps_addrs:
        # TODO: use ps_addrs from master directly after ps service is working.
        #       Get ps pod ip for ps grpc connection for now.
        ps_addrs = args.ps_addrs.split(",")

        config.load_incluster_config()
        api = client.CoreV1Api()

        for addr in ps_addrs:
            # addr is of the form "ps-pod-name.namespace.svc:port"
            addr_splitted = addr.split(".")
            while True:
                pod = api.read_namespaced_pod(
                    namespace=addr_splitted[1], name=addr_splitted[0]
                )
                if pod.status.pod_ip:
                    break
                # If ps pod is not ready yet, sleep 2 seconds and try again.
                time.sleep(2)
            addr = pod.status.pod_ip + ":" + addr.split(":")[-1]
            channel = grpc.insecure_channel(
                addr,
                options=[
                    (
                        "grpc.max_send_message_length",
                        GRPC.MAX_SEND_MESSAGE_LENGTH,
                    ),
                    (
                        "grpc.max_receive_message_length",
                        GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
                    ),
                ],
            )

            # Wait for the channel to become ready via a Future object.
            grpc.channel_ready_future(channel).result()
            logger.info(
                "grpc channel %s to connect pod %s is ready"
                % (addr, pod.metadata.name)
            )
            ps_channels.append(channel)

    worker = Worker(args, channel=master_channel, ps_channels=ps_channels)
    worker.run()
Code Example #8
    def __init__(self, args):
        self.logger = get_logger("PS", level=args.log_level.upper())

        self.grads_to_wait = args.grads_to_wait
        self.lr_staleness_modulation = args.lr_staleness_modulation
        self.use_async = args.use_async
        self.port = args.port
        model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)
        ).__dict__
        self.optimizer = model_module[args.optimizer]()
        # Create Parameters instance
        self.parameters = Parameters()
Code Example #9
File: main.py, Project: dut3062796s/elasticdl
def main():
    args = parse_worker_args()
    logger = log_utils.get_logger(__name__)
    logger.info("Starting worker %d", args.worker_id)
    if args.master_addr is None:
        raise ValueError("master_addr is missing for worker")

    master_channel = grpc.insecure_channel(
        args.master_addr,
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )

    ps_channels = []
    if args.ps_addrs:
        # TODO: use ps_addrs from master directly after ps service is working.
        #       Get ps pod ip for ps grpc connection for now.
        ps_addrs = args.ps_addrs.split(",")
        from kubernetes import client, config

        config.load_incluster_config()
        api = client.CoreV1Api()

        for addr in ps_addrs:
            # addr is of the form "ps-pod-name.namespace.svc:port"
            addr_splitted = addr.split(".")
            pod = api.read_namespaced_pod(namespace=addr_splitted[1],
                                          name=addr_splitted[0])
            addr = pod.status.pod_ip + ":" + addr.split(":")[-1]
            channel = grpc.insecure_channel(
                addr,
                options=[
                    (
                        "grpc.max_send_message_length",
                        GRPC.MAX_SEND_MESSAGE_LENGTH,
                    ),
                    (
                        "grpc.max_receive_message_length",
                        GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
                    ),
                ],
            )
            ps_channels.append(channel)

    worker = Worker(args, channel=master_channel, ps_channels=ps_channels)
    worker.run()
Code Example #10
File: worker.py, Project: shijungg/elasticdl
    def __init__(
        self,
        args,
        master_client=None,
        ps_client=None,
        max_minibatch_retry_num=DEFAULT_MAX_MINIBATCH_RETRY_NUM,
        set_parallelism=False,
    ):
        """
        Arguments:
            master_client: The gRPC client for the master service.
            ps_client: The client for the parameter server services.
            max_minibatch_retry_num: The maximum number of retries for a
                minibatch whose results (e.g. gradients) are not accepted
                by the master.
        """
        self._args = args
        self.logger = get_logger("Worker", level=args.log_level.upper())

        if set_parallelism:
            # Explicitly setting the parallelism avoids multi-process hangs,
            # possibly due to an unknown bug in TensorFlow.
            # Must be called before TensorFlow is initialized.
            # set_parallelism is off by default to keep unit tests happy.
            num_threads = os.cpu_count()
            tf.config.threading.set_inter_op_parallelism_threads(num_threads)
            tf.config.threading.set_intra_op_parallelism_threads(num_threads)

        self._mc = master_client
        self._ps_client = ps_client
        self._distribution_strategy = args.distribution_strategy
        if (
            self._distribution_strategy
            == DistributionStrategy.PARAMETER_SERVER
        ):
            if self._ps_client is None:
                raise ValueError(
                    "PS channels are not set up under "
                    "parameter server strategy"
                )
            else:
                self._model_versions_from_ps = [
                    -1 for _ in range(self._ps_client.ps_num)
                ]
        self._max_minibatch_retry_num = max_minibatch_retry_num
        self._init_from_args(args)
        self._timing = Timing(args.log_level.upper() == "DEBUG", self.logger)
        self._log_loss_count = 0
        self._var_created = False
Code Example #11
def main():
    args = parse_worker_args()
    logger = log_utils.get_logger(__name__)
    logger.info("Starting worker %d", args.worker_id)
    if args.master_addr is None:
        raise ValueError("master_addr is missing for worker")

    master_channel = build_channel(args.master_addr)

    ps_channels = []
    if args.ps_addrs:
        ps_addrs = args.ps_addrs.split(",")

        for addr in ps_addrs:
            # addr is of the form "ps-pod-name.namespace.svc:port"
            channel = build_channel(addr)

            succeeded = False
            for i in range(CONNECT_PS_MAX_RETRIES):
                try:
                    grpc.channel_ready_future(channel).result(
                        timeout=CONNECT_PS_TIMEOUT)
                    logger.info("grpc channel %s to connect pod %s is ready" %
                                (addr, addr.split(".")[0]))
                    ps_channels.append(channel)
                    succeeded = True
                    break
                except grpc.FutureTimeoutError:
                    logger.warning("Failed to connect pod %s with %d retry" %
                                   (addr.split(".")[0], i))
            if not succeeded:
                raise TimeoutError(
                    "Timed out connecting to pod %s after %d retries" %
                    (addr.split(".")[0], CONNECT_PS_MAX_RETRIES))

    if args.distribution_strategy == DistributionStrategy.ALLREDUCE:
        logger.info("Wait for %s seconds for FTLib consensus service to "
                    "detect the worker pod" %
                    str(_ALLREDUCE_STRATEGY_WARM_UP_SECS))
        time.sleep(_ALLREDUCE_STRATEGY_WARM_UP_SECS)

    worker = Worker(
        args,
        channel=master_channel,
        ps_channels=ps_channels,
        set_parallelism=True,
    )
    worker.run()
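
Example #11 relies on several module-level constants that the snippet does not define. A sketch with placeholder values only; the project's actual settings may differ.

# Placeholder values; these are assumptions, not the project's settings.
CONNECT_PS_MAX_RETRIES = 3             # connection attempts per PS channel
CONNECT_PS_TIMEOUT = 60                # seconds to wait per attempt
_ALLREDUCE_STRATEGY_WARM_UP_SECS = 20  # warm-up wait for FTLib consensus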
Code Example #12
    def __init__(
        self,
        args,
        master_client=None,
        ps_client=None,
        max_minibatch_retry_num=DEFAULT_MAX_MINIBATCH_RETRY_NUM,
        set_parallelism=False,
    ):
        """
        Arguments:
            master_client: The gRPC client for the master service.
            ps_client: The client for the parameter server services.
            max_minibatch_retry_num: The maximum number of retries for a
                minibatch whose results (e.g. gradients) are not accepted
                by the master.
        """
        self._args = args
        self.logger = get_logger("Worker", level=args.log_level.upper())

        if set_parallelism:
            # Explicitly setting the parallelism avoids multi-process hangs,
            # possibly due to an unknown bug in TensorFlow.
            # Must be called before TensorFlow is initialized.
            # set_parallelism is off by default to keep unit tests happy.
            num_threads = os.cpu_count()
            tf.config.threading.set_inter_op_parallelism_threads(num_threads)
            tf.config.threading.set_intra_op_parallelism_threads(num_threads)

        self._mc = master_client
        self._ps_client = ps_client
        self._distribution_strategy = args.distribution_strategy
        self._max_minibatch_retry_num = max_minibatch_retry_num
        self._timing = Timing(args.log_level.upper() == "DEBUG", self.logger)
        self._log_loss_count = 0
        self._var_created = False
        self._job_type = args.job_type
        self._minibatch_size = args.minibatch_size
        self._data_shard_service = DataShardService(self._mc,
                                                    self._minibatch_size)
        self._init_model_from_args(args)
        self._init_task_data_service(args)
        self._init_default_feed_if_needed()
        self._init_callbacks(args)
        self._init_trainer(args)
Code Example #13
File: main.py, Project: zerocurve/elasticdl
def main():
    args = parse_worker_args()
    logger = log_utils.get_logger(__name__)
    logger.info("Starting worker %d", args.worker_id)
    if args.master_addr is None:
        raise ValueError("master_addr is missing for worker")

    master_client = MasterClient(build_channel(args.master_addr),
                                 args.worker_id)

    ps_client = None
    if (args.distribution_strategy == DistributionStrategy.PARAMETER_SERVER
            and args.ps_addrs):
        ps_channels = []
        ps_addrs = args.ps_addrs.split(",")

        for addr in ps_addrs:
            # addr is of the form "ps-pod-name.namespace.svc:port"
            channel = build_channel(addr)

            succeeded = False
            for i in range(CONNECT_PS_MAX_RETRIES):
                try:
                    grpc.channel_ready_future(channel).result(
                        timeout=CONNECT_PS_TIMEOUT)
                    logger.info("grpc channel %s to connect pod %s is ready" %
                                (addr, addr.split(".")[0]))
                    ps_channels.append(channel)
                    succeeded = True
                    break
                except grpc.FutureTimeoutError:
                    logger.warning("Failed to connect pod %s with %d retry" %
                                   (addr.split(".")[0], i))
            if not succeeded:
                raise TimeoutError(
                    "Timed out connecting to pod %s after %d retries" %
                    (addr.split(".")[0], CONNECT_PS_MAX_RETRIES))
        ps_client = PSClient(ps_channels)

    worker = Worker(
        args,
        master_client=master_client,
        ps_client=ps_client,
        set_parallelism=True,
    )
    worker.run()
Code Example #14
    def __init__(self, args):
        self.logger = get_logger("PS", level=args.log_level.upper())

        self.grads_to_wait = args.grads_to_wait
        self.lr_staleness_modulation = args.lr_staleness_modulation
        self.use_async = args.use_async
        self.port = args.port
        model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)).__dict__
        self.optimizer = model_module[args.optimizer]()
        self.ps_id = args.ps_id
        self.num_ps_pods = args.num_ps_pods
        # Create Parameters instance
        self.parameters = Parameters()
        if args.master_addr is None:
            raise ValueError("master_addr is missing for parameter servers")
        self.master_channel = build_channel(args.master_addr)
        self.evaluation_steps = args.evaluation_steps

        self.master_name = get_master_pod_name(args.job_name)
        self.namespace = args.namespace
        self._init_checkpoint_service(args)
Code Example #15
def main():
    args = parse_worker_args()
    logger = log_utils.get_logger(__name__)
    master_addr = args.master_addr
    worker_id = int(args.worker_id)

    logger.info("Starting worker %d", worker_id)

    master_client = MasterClient(build_channel(master_addr), worker_id)

    logger.info("Building PS connection....")
    ps_client = (build_ps_client(args.ps_addrs, logger)
                 if args.distribution_strategy
                 == DistributionStrategy.PARAMETER_SERVER else None)

    logger.info("Have builded PS.")

    worker = Worker(
        args,
        master_client=master_client,
        ps_client=ps_client,
        set_parallelism=True,
    )
    worker.run()
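
Example #15 delegates PS connection setup to a build_ps_client helper that is not shown in these snippets. Below is a sketch under the assumption that it wraps the per-address retry loop from Example #13 and returns a PSClient; build_channel, PSClient, and the CONNECT_PS_* constants are assumed to come from the same project modules used there.

import grpc

# build_channel, PSClient and the CONNECT_PS_* constants are assumed to be
# imported from the same project modules used in Example #13.


def build_ps_client(ps_addrs, logger):
    # Hypothetical helper; returns None when no PS addresses are given.
    if not ps_addrs:
        return None
    ps_channels = []
    for addr in ps_addrs.split(","):
        channel = build_channel(addr)
        for i in range(CONNECT_PS_MAX_RETRIES):
            try:
                grpc.channel_ready_future(channel).result(
                    timeout=CONNECT_PS_TIMEOUT)
                ps_channels.append(channel)
                break
            except grpc.FutureTimeoutError:
                logger.warning("Failed to connect pod %s with %d retry" %
                               (addr.split(".")[0], i))
        else:
            # All retries for this address timed out.
            raise TimeoutError(
                "Timed out connecting to pod %s" % addr.split(".")[0])
    return PSClient(ps_channels)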
Code Example #16
def main():
    args = parse_master_args()
    logger = get_logger("master", level=args.log_level.upper())

    # Master addr
    master_ip = os.getenv("MY_POD_IP", "localhost")
    master_addr = "%s:%d" % (master_ip, args.port)

    # Start TensorBoard service if requested
    if args.tensorboard_log_dir:
        logger.info(
            "Starting TensorBoard service with log directory %s",
            args.tensorboard_log_dir,
        )
        # Start TensorBoard CLI
        tb_service = TensorboardService(args.tensorboard_log_dir, master_ip)
        tb_service.start()
    else:
        tb_service = None

    # Start task queue
    logger.debug(
        "Starting task queue with training data directory %s, "
        "evaluation data directory %s, "
        "and prediction data directory %s",
        args.training_data_dir,
        args.evaluation_data_dir,
        args.prediction_data_dir,
    )

    records_per_task = args.minibatch_size * args.num_minibatches_per_task
    task_d = _make_task_dispatcher(
        args.training_data_dir,
        args.evaluation_data_dir,
        args.prediction_data_dir,
        records_per_task,
        args.num_epochs,
    )
    model_module = load_module(
        get_module_file_path(args.model_zoo, args.model_def)
    ).__dict__
    model_inst = load_model_from_module(
        args.model_def, model_module, args.model_params
    )
    optimizer = model_module[args.optimizer]()

    if all(
        (
            args.training_data_dir,
            args.evaluation_data_dir,
            args.evaluation_throttle_secs or args.evaluation_steps,
        )
    ):
        job_type = JobType.TRAINING_WITH_EVALUATION
    elif all(
        (
            args.evaluation_data_dir,
            not args.training_data_dir,
            not args.prediction_data_dir,
        )
    ):
        job_type = JobType.EVALUATION_ONLY
    elif all(
        (
            args.prediction_data_dir,
            not args.evaluation_data_dir,
            not args.training_data_dir,
        )
    ):
        job_type = JobType.PREDICTION_ONLY
    else:
        job_type = JobType.TRAINING_ONLY

    # Initialize checkpoint service
    if args.checkpoint_steps or job_type == JobType.TRAINING_WITH_EVALUATION:
        logger.info("Starting checkpoint service")
        checkpoint_service = CheckpointService(
            args.checkpoint_dir,
            args.checkpoint_steps,
            args.keep_checkpoint_max,
            job_type == JobType.TRAINING_WITH_EVALUATION,
        )
    else:
        checkpoint_service = None

    # Initialize evaluation service
    evaluation_service = None
    if (
        job_type == JobType.TRAINING_WITH_EVALUATION
        or job_type == JobType.EVALUATION_ONLY
    ):
        logger.info(
            "Starting evaluation service with throttle seconds %d "
            " and evaluation steps %d",
            args.evaluation_throttle_secs,
            args.evaluation_steps,
        )
        evaluation_service = EvaluationService(
            checkpoint_service,
            tb_service,
            task_d,
            args.evaluation_start_delay_secs,
            args.evaluation_throttle_secs,
            args.evaluation_steps,
            job_type == JobType.EVALUATION_ONLY,
        )
        evaluation_service.start()
        task_d.set_evaluation_service(evaluation_service)

    embedding_service_endpoint = None
    embedding_dims = {}
    # Search for embedding layers in the model,
    # if found, initialize embedding service
    layers = find_layer(model_inst, Embedding)
    if layers:
        embedding_service = EmbeddingService()
        embedding_service_endpoint = embedding_service.start_embedding_service(
            job_name=args.job_name,
            image_name=args.worker_image,
            namespace=args.namespace,
            resource_request=args.master_resource_request,
            resource_limit=args.master_resource_limit,
            pod_priority=args.worker_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
        )
        logger.info(
            "Embedding service start succeeded. The endpoint is %s."
            % str(embedding_service_endpoint)
        )
        embedding_dims = dict(
            [(layer.name, layer.output_dim) for layer in layers]
        )

    # The master service
    logger.info("Starting master service")
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=64),
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )
    master_servicer = MasterServicer(
        args.grads_to_wait,
        args.minibatch_size,
        optimizer,
        task_d,
        init_var=model_inst.trainable_variables if model_inst.built else [],
        embedding_dims=embedding_dims,
        checkpoint_filename_for_init=args.checkpoint_filename_for_init,
        checkpoint_service=checkpoint_service,
        evaluation_service=evaluation_service,
        embedding_service_endpoint=embedding_service_endpoint,
        lr_staleness_modulation=args.lr_staleness_modulation,
        use_async=args.use_async,
    )
    elasticdl_pb2_grpc.add_MasterServicer_to_server(master_servicer, server)
    server.add_insecure_port("[::]:{}".format(args.port))
    server.start()
    logger.info("Server started at port: %d", args.port)

    worker_manager = None
    if args.num_workers:
        assert args.worker_image, "Worker image cannot be empty"

        worker_command = ["python"]
        worker_args = [
            "-m",
            "elasticdl.python.worker.main",
            "--master_addr",
            master_addr,
            "--job_type",
            job_type,
            "--embedding_service_endpoint",
            str(embedding_service_endpoint),
        ]
        worker_args.extend(build_arguments_from_parsed_result(args))

        env_dict = parse_envs(args.envs)
        env = []
        for key in env_dict:
            env.append(V1EnvVar(name=key, value=env_dict[key]))

        worker_manager = WorkerManager(
            task_d,
            job_name=args.job_name,
            image_name=args.worker_image,
            command=worker_command,
            args=worker_args,
            namespace=args.namespace,
            num_workers=args.num_workers,
            worker_resource_request=args.worker_resource_request,
            worker_resource_limit=args.worker_resource_limit,
            pod_priority=args.worker_pod_priority,
            volume=args.volume,
            image_pull_policy=args.image_pull_policy,
            restart_policy=args.restart_policy,
            cluster_spec=args.cluster_spec,
            envs=env,
        )
        worker_manager.update_status(WorkerManagerStatus.PENDING)
        logger.info("Launching %d workers", args.num_workers)
        worker_manager.start_workers()
        worker_manager.update_status(WorkerManagerStatus.RUNNING)

    # Start TensorBoard k8s Service if requested
    if tb_service:
        TensorBoardClient(
            job_name=args.job_name,
            image_name=args.worker_image,
            namespace=args.namespace,
        ).start_tensorboard_service()

    try:
        while True:
            if task_d.finished():
                if worker_manager:
                    worker_manager.update_status(WorkerManagerStatus.FINISHED)
                if args.output:
                    master_servicer.save_latest_checkpoint(args.output)
                break
            time.sleep(30)
    except KeyboardInterrupt:
        logger.warning("Server stopping")

    if evaluation_service:
        logger.info("Stopping evaluation service")
        evaluation_service.stop()

    logger.info("Stopping RPC server")
    server.stop(0)

    # Keep TensorBoard running when all the tasks are finished
    if tb_service:
        logger.info(
            "All tasks finished. Keeping TensorBoard service running..."
        )
        while True:
            if tb_service.is_active():
                time.sleep(10)
            else:
                logger.warning(
                    "Unable to keep TensorBoard running. "
                    "It has already terminated"
                )
                break
    logger.info("Master stopped")
Code Example #17
    def __init__(self, args):
        self.logger = get_logger("master", level=args.log_level.upper())

        self.num_ps_pods = args.num_ps_pods
        self.checkpoint_output_path = args.checkpoint_dir
        self.distribution_strategy = args.distribution_strategy

        # Master addr
        master_ip = os.getenv("MY_POD_IP", "localhost")
        self.master_addr = "%s:%d" % (master_ip, args.port)
        self.job_type = Master._get_job_type(args)
        self.rendezvous_server = None
        if self.distribution_strategy == DistributionStrategy.ALLREDUCE:
            self.rendezvous_server = HorovodRendezvousServer(master_ip)

        # Initialize TensorBoard service if requested
        self.tb_service = self._create_tensorboard_service(
            args.tensorboard_log_dir, master_ip
        )
        if self.tb_service:
            self.tb_client = TensorBoardClient(
                job_name=args.job_name,
                image_name=args.worker_image,
                namespace=args.namespace,
            )

        # Initialize the components from the model definition
        self.model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)
        ).__dict__
        self.model_inst = load_model_from_module(
            args.model_def, self.model_module, args.model_params
        )
        self.optimizer = self.model_module[args.optimizer]()
        self._create_data_reader_fn = create_data_reader
        if args.custom_data_reader in self.model_module:
            self._create_data_reader_fn = self.model_module[
                args.custom_data_reader
            ]

        # Initialize the callbacks
        self.callbacks_list = load_callbacks_from_module(
            args.callbacks, self.model_module
        )
        self.callbacks_list.set_model(self.model_inst)
        set_callback_parameters(
            self.callbacks_list,
            batch_size=args.minibatch_size,
            saved_model_path=args.output,
            checkpoint_path=args.checkpoint_dir,
        )
        self._set_completed_steps_by_checkpoint(args.checkpoint_dir_for_init)

        # Start task queue
        records_per_task = args.minibatch_size * args.num_minibatches_per_task
        self.task_d = _make_task_dispatcher(
            args.training_data,
            args.validation_data,
            args.prediction_data,
            records_per_task,
            args.num_epochs,
            args.data_reader_params,
            self._create_data_reader_fn,
            self.callbacks_list,
        )

        self.task_d.add_deferred_callback_create_train_end_task()
        self.evaluation_service = self._create_evaluation_service(args)

        # Initialize instance manager
        self.instance_manager = self._create_instance_manager(args)

        # Initialize master service
        self.master_servicer, self.server = self._create_master_service(args)

        self._should_stop = False
        self._exit_code = 0
        threading.Thread(
            target=self._check_timeout_tasks,
            name="check_timeout_tasks",
            daemon=True,
        ).start()
Code Example #18
File: master.py, Project: xhcom-ui/elasticdl
    def __init__(self, args):
        self.logger = get_logger("master", level=args.log_level.upper())

        self.num_ps_pods = args.num_ps_pods
        self.checkpoint_output_path = args.checkpoint_dir

        # Master addr
        master_ip = os.getenv("MY_POD_IP", "localhost")
        self.master_addr = "%s:%d" % (master_ip, args.port)
        self.job_type = Master._get_job_type(args)

        # Initialize TensorBoard service if requested
        self.tb_service = self._create_tensorboard_service(
            args.tensorboard_log_dir, master_ip)
        if self.tb_service:
            self.tb_client = TensorBoardClient(
                job_name=args.job_name,
                image_name=args.worker_image,
                namespace=args.namespace,
            )

        # Initialize the components from the model definition
        self.model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)).__dict__
        self.model_inst = load_model_from_module(args.model_def,
                                                 self.model_module,
                                                 args.model_params)
        model_handler = ModelHandler.get_model_handler(
            args.distribution_strategy, checkpoint_dir=args.checkpoint_dir)
        self.model_inst = model_handler.get_model_to_train(self.model_inst)
        self.optimizer = self.model_module[args.optimizer]()
        self._create_data_reader_fn = create_data_reader
        if args.custom_data_reader in self.model_module:
            self._create_data_reader_fn = self.model_module[
                args.custom_data_reader]

        # Start task queue
        records_per_task = args.minibatch_size * args.num_minibatches_per_task
        self.task_d = _make_task_dispatcher(
            args.training_data,
            args.validation_data,
            args.prediction_data,
            records_per_task,
            args.num_epochs,
            args.data_reader_params,
            self._create_data_reader_fn,
        )

        saved_model_path = args.output
        if saved_model_path is not None and self.job_type in [
                JobType.TRAINING_ONLY,
                JobType.TRAINING_WITH_EVALUATION,
        ]:
            self.task_d.add_deferred_callback_create_save_model_task(
                saved_model_path)

        self.evaluation_service = self._create_evaluation_service(args)

        # Initialize master service
        self.master_servicer, self.server = self._create_master_service(args)

        # Initialize instance manager
        self.instance_manager = self._create_instance_manager(args)

        self._should_stop = False
        self._exit_code = 0