Example #1
def main():
    args = parse_worker_args()
    logger = log_utils.get_logger(__name__)
    logger.info("Starting worker %d", args.worker_id)
    if args.master_addr is None:
        raise ValueError("master_addr is missing for worker")

    master_channel = build_channel(args.master_addr)

    ps_channels = []
    if args.ps_addrs:
        ps_addrs = args.ps_addrs.split(",")

        for addr in ps_addrs:
            # addr is of the form "ps-pod-name.namespace.svc:port"
            channel = build_channel(addr)

            # Wait until the channel is ready via its Future object.
            grpc.channel_ready_future(channel).result()
            logger.info("grpc channel %s to connect pod %s is ready" %
                        (addr, addr.split(".")[0]))
            ps_channels.append(channel)

    worker = Worker(args, channel=master_channel, ps_channels=ps_channels)
    worker.run()
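All of these examples call build_channel without showing its definition. Based on the inline expansion in Example #7 below, a minimal sketch might look like this; the message-size limit is an assumption standing in for the project's real GRPC constants:

import grpc

# Assumed limit; Example #7 reads the real values from the project's
# GRPC constants (MAX_SEND_MESSAGE_LENGTH / MAX_RECEIVE_MESSAGE_LENGTH).
_MAX_MESSAGE_LENGTH = 256 * 1024 * 1024  # 256 MiB


def build_channel(addr):
    # Insecure channel with enlarged limits so large tensors fit in a
    # single gRPC message.
    return grpc.insecure_channel(
        addr,
        options=[
            ("grpc.max_send_message_length", _MAX_MESSAGE_LENGTH),
            ("grpc.max_receive_message_length", _MAX_MESSAGE_LENGTH),
        ],
    )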
Example #2
def main():
    args = parse_worker_args()
    logger = log_utils.get_logger(__name__)
    logger.info("Starting worker %d", args.worker_id)
    if args.master_addr is None:
        raise ValueError("master_addr is missing for worker")

    master_channel = build_channel(args.master_addr)

    ps_channels = []
    if args.ps_addrs:
        ps_addrs = args.ps_addrs.split(",")

        for addr in ps_addrs:
            # addr is of the form "ps-pod-name.namespace.svc:port"
            channel = build_channel(addr)

            succeeded = False
            for i in range(CONNECT_PS_MAX_RETRIES):
                try:
                    grpc.channel_ready_future(channel).result(
                        timeout=CONNECT_PS_TIMEOUT)
                    logger.info("grpc channel %s to connect pod %s is ready" %
                                (addr, addr.split(".")[0]))
                    ps_channels.append(channel)
                    succeeded = True
                    break
                except grpc.FutureTimeoutError:
                    logger.warning("Failed to connect to pod %s on retry %d",
                                   addr.split(".")[0], i)
            if not succeeded:
                raise TimeoutError(
                    "Timed out connecting to pod %s after %d retries" %
                    (addr.split(".")[0], CONNECT_PS_MAX_RETRIES))

    if args.distribution_strategy == DistributionStrategy.ALLREDUCE:
        logger.info("Wait for %s seconds for FTLib consensus service to "
                    "detect the worker pod" %
                    str(_ALLREDUCE_STRATEGY_WARM_UP_SECS))
        time.sleep(_ALLREDUCE_STRATEGY_WARM_UP_SECS)

    worker = Worker(
        args,
        channel=master_channel,
        ps_channels=ps_channels,
        set_parallelism=True,
    )
    worker.run()
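Examples #2 and #5 rely on CONNECT_PS_MAX_RETRIES and CONNECT_PS_TIMEOUT, which are defined elsewhere in the project. A self-contained sketch of the same wait-with-retries pattern, with assumed values for both constants, could be:

import grpc

# Assumed values; the real constants live elsewhere in the project.
CONNECT_PS_MAX_RETRIES = 3
CONNECT_PS_TIMEOUT = 60  # seconds per attempt


def wait_until_channel_ready(channel, pod_name, logger):
    # Block until the channel is ready, retrying on timeout and failing
    # once the retry budget is exhausted.
    for i in range(CONNECT_PS_MAX_RETRIES):
        try:
            grpc.channel_ready_future(channel).result(
                timeout=CONNECT_PS_TIMEOUT)
            return
        except grpc.FutureTimeoutError:
            logger.warning("Pod %s not ready on retry %d", pod_name, i)
    raise TimeoutError(
        "Timed out connecting to pod %s after %d retries"
        % (pod_name, CONNECT_PS_MAX_RETRIES))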
Example #3
    def __init__(self, args):
        self.logger = get_logger("PS", level=args.log_level.upper())
        self.grads_to_wait = args.grads_to_wait
        self.lr_staleness_modulation = args.lr_staleness_modulation
        self.sync_version_tolerance = args.sync_version_tolerance
        self.use_async = args.use_async
        self.port = args.port
        model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)).__dict__
        self.optimizer = model_module[args.optimizer]()
        self._set_lr_scheduler(model_module, args.learning_rate_scheduler)
        self.ps_id = args.ps_id
        self.num_ps_pods = args.num_ps_pods
        self.num_workers = args.num_workers
        # Create Parameters instance
        self.parameters = Parameters()
        if args.master_addr is None:
            raise ValueError("master_addr is missing for parameter servers")
        self.master_channel = build_channel(args.master_addr)
        self.evaluation_steps = args.evaluation_steps

        self.master_name = get_master_pod_name(args.job_name)
        self.namespace = args.namespace
        self._init_checkpoint_saver(args)
        self._restore_params_from_checkpoint(args.checkpoint_dir_for_init)
        self._debug_info_needed = args.log_level.upper() == "DEBUG"
Example #4
def create_pserver(model_zoo_path, model_def, grads_to_wait, use_async,
                   num_ps_pods):
    ports = [i + 12345 for i in range(num_ps_pods)]
    channels = []
    for port in ports:
        addr = "localhost:%d" % port
        channel = build_channel(addr)
        channels.append(channel)

    pservers = []
    for port in ports:
        args = PserverArgs(
            grads_to_wait=grads_to_wait,
            use_async=use_async,
            port=port,
            model_zoo=model_zoo_path,
            model_def=model_def,
        )
        pserver = ParameterServer(args)
        pserver.prepare()
        pservers.append(pserver)

    for channel in channels:
        grpc.channel_ready_future(channel).result()

    return ports, channels, pservers
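A typical call in a test, with placeholder values for the model zoo path and model definition, might look like:

# Hypothetical arguments; substitute a real model zoo path and a real
# model definition import path.
ports, channels, pservers = create_pserver(
    model_zoo_path="/tmp/model_zoo",
    model_def="my_model.custom_model",
    grads_to_wait=1,
    use_async=True,
    num_ps_pods=2,
)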
Example #5
def main():
    args = parse_worker_args()
    logger = log_utils.get_logger(__name__)
    logger.info("Starting worker %d", args.worker_id)
    if args.master_addr is None:
        raise ValueError("master_addr is missing for worker")

    master_client = MasterClient(build_channel(args.master_addr),
                                 args.worker_id)

    ps_client = None
    if (args.distribution_strategy == DistributionStrategy.PARAMETER_SERVER
            and args.ps_addrs):
        ps_channels = []
        ps_addrs = args.ps_addrs.split(",")

        for addr in ps_addrs:
            # addr is of the form "ps-pod-name.namespace.svc:port"
            channel = build_channel(addr)

            succeeded = False
            for i in range(CONNECT_PS_MAX_RETRIES):
                try:
                    grpc.channel_ready_future(channel).result(
                        timeout=CONNECT_PS_TIMEOUT)
                    logger.info("grpc channel %s to connect pod %s is ready" %
                                (addr, addr.split(".")[0]))
                    ps_channels.append(channel)
                    succeeded = True
                    break
                except grpc.FutureTimeoutError:
                    logger.warning("Failed to connect to pod %s on retry %d",
                                   addr.split(".")[0], i)
            if not succeeded:
                raise TimeoutError(
                    "Timed out connecting to pod %s after %d retries" %
                    (addr.split(".")[0], CONNECT_PS_MAX_RETRIES))
        ps_client = PSClient(ps_channels)

    worker = Worker(
        args,
        master_client=master_client,
        ps_client=ps_client,
        set_parallelism=True,
    )
    worker.run()
Example #6
    def setUp(self):
        self._port = 9999
        addr = "localhost:%d" % self._port
        self._channel = build_channel(addr)
        embedding_info = elasticdl_pb2.EmbeddingTableInfo()
        embedding_info.name = "layer_a"
        embedding_info.dim = 32
        embedding_info.initializer = "normal"
        self._embedding_info = embedding_info
        self._server = None
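Building the channel before any server exists works because gRPC channels connect lazily. A test would then start a server on the same port and wait for readiness, roughly as in this sketch (the servicer registration is omitted and the helper is hypothetical):

from concurrent import futures

import grpc


def start_test_server(port, channel):
    # Start a bare gRPC server on the given port, then block until the
    # existing channel reports ready.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    server.add_insecure_port("[::]:%d" % port)
    server.start()
    grpc.channel_ready_future(channel).result(timeout=10)
    return server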
Example #7
def main():
    args = parse_worker_args()
    logger = log_utils.get_logger(__name__)
    logger.info("Starting worker %d", args.worker_id)
    if args.master_addr is None:
        raise ValueError("master_addr is missing for worker")

    master_channel = build_channel(args.master_addr)

    ps_channels = []
    if args.ps_addrs:
        # TODO: use ps_addrs from master directly after ps service is working.
        #       Get ps pod ip for ps grpc connection for now.
        ps_addrs = args.ps_addrs.split(",")

        config.load_incluster_config()
        api = client.CoreV1Api()

        for addr in ps_addrs:
            # addr is of the form "ps-pod-name.namespace.svc:port"
            addr_splitted = addr.split(".")
            while True:
                pod = api.read_namespaced_pod(
                    namespace=addr_splitted[1], name=addr_splitted[0]
                )
                if pod.status.pod_ip:
                    break
                # If ps pod is not ready yet, sleep 2 seconds and try again.
                time.sleep(2)
            addr = pod.status.pod_ip + ":" + addr.split(":")[-1]
            channel = grpc.insecure_channel(
                addr,
                options=[
                    (
                        "grpc.max_send_message_length",
                        GRPC.MAX_SEND_MESSAGE_LENGTH,
                    ),
                    (
                        "grpc.max_receive_message_length",
                        GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
                    ),
                ],
            )

            # Wait until the channel is ready via its Future object.
            grpc.channel_ready_future(channel).result()
            logger.info(
                "gRPC channel %s to pod %s is ready",
                addr,
                pod.metadata.name,
            )
            ps_channels.append(channel)

    worker = Worker(args, channel=master_channel, ps_channels=ps_channels)
    worker.run()
Example #8
    def _create_pserver(self, model_def, num):
        self._ports = [i + 12345 for i in range(num)]
        for port in self._ports:
            addr = "localhost:%d" % port
            channel = build_channel(addr)
            self._channels.append(channel)

        self._model_def = model_def
        for port in self._ports:
            args = PserverArgs(
                grads_to_wait=1,
                use_async=True,
                port=port,
                model_zoo=self._model_zoo_path,
                model_def=self._model_def,
            )
            pserver = ParameterServer(args)
            pserver.prepare()
            self._pservers.append(pserver)
Example #9
    def __init__(self, args):
        self.logger = get_logger("PS", level=args.log_level.upper())

        self.grads_to_wait = args.grads_to_wait
        self.lr_staleness_modulation = args.lr_staleness_modulation
        self.use_async = args.use_async
        self.port = args.port
        model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)).__dict__
        self.optimizer = model_module[args.optimizer]()
        self.ps_id = args.ps_id
        self.num_ps_pods = args.num_ps_pods
        # Create Parameters instance
        self.parameters = Parameters()
        if args.master_addr is None:
            raise ValueError("master_addr is missing for parameter servers")
        self.master_channel = build_channel(args.master_addr)
        self.evaluation_steps = args.evaluation_steps

        self.master_name = get_master_pod_name(args.job_name)
        self.namespace = args.namespace
        self._init_checkpoint_service(args)
Example #10
def distributed_train_and_evaluate(
    feature_shape,
    model_zoo_path,
    model_def,
    model_params="",
    eval_metrics_fn="eval_metrics_fn",
    loss="loss",
    training=True,
    dataset_name=DatasetName.IMAGE_DEFAULT,
    use_async=False,
    get_model_steps=1,
    ps_channels=None,
    pservers=None,
    distribution_strategy=DistributionStrategy.PARAMETER_SERVER,
):
    """Runs distributed training and evaluation with a local master. Grpc
    calls are mocked by local master call.

    Args:
        feature_shape: The shape of model input.
        model_zoo_path: The directory that contains user-defined model files
            or a specific model file.
        model_def: The import path to the model definition function/class in
            the model zoo, e.g.  "cifar10_subclass.CustomModel".
        model_params: The dictionary of model parameters in a string that will
            be used to instantiate the model, e.g. "param1=1,param2=2".
        eval_metrics_fn: The name of the evaluation metrics function defined
            in the model file.
        loss: The name of the loss function defined in the model file.
        training: True for job type `TRAIN_WITH_EVALUATION`, False for
            job type `EVALUATION`.
        dataset_name: A dataset name from `DatasetName`.
        use_async: A bool. True if using asynchronous updates.
        get_model_steps: Worker will perform `get_model` from the parameter
            server every this many steps.
        ps_channels: A channel list to all parameter server pods.
        pservers: A list of parameter server pods.
        distribution_strategy: The distribution strategy used by workers,
            e.g. DistributionStrategy.PARAMETER_SERVER or
            DistributionStrategy.ALLREDUCE.

    Returns:
        An integer indicating the model version after the distributed training
        and evaluation.
    """
    job_type = (JobType.TRAINING_WITH_EVALUATION
                if training else JobType.EVALUATION_ONLY)
    evaluation_steps = 1 if job_type == JobType.TRAINING_WITH_EVALUATION else 0
    batch_size = 8 if dataset_name == DatasetName.IMAGENET else 16
    pservers = pservers or []
    ps_channels = ps_channels or []

    model_module = load_module(get_module_file_path(model_zoo_path,
                                                    model_def)).__dict__

    for channel in ps_channels:
        grpc.channel_ready_future(channel).result()
    worker_arguments = [
        "--worker_id",
        "1",
        "--job_type",
        job_type,
        "--minibatch_size",
        batch_size,
        "--model_zoo",
        model_zoo_path,
        "--model_def",
        model_def,
        "--model_params",
        model_params,
        "--loss",
        loss,
        "--get_model_steps",
        get_model_steps,
        "--distribution_strategy",
        distribution_strategy,
    ]
    args = parse_worker_args(worker_arguments)

    if dataset_name in [DatasetName.IMAGENET, DatasetName.FRAPPE]:
        record_num = batch_size
    else:
        record_num = 128
    shards = {
        create_recordio_file(record_num, dataset_name, feature_shape): (
            0,
            record_num,
        )
    }
    if training:
        training_shards = shards
        evaluation_shards = shards
    else:
        training_shards = {}
        evaluation_shards = shards
    task_d = _TaskDispatcher(
        training_shards,
        evaluation_shards,
        {},
        records_per_task=64,
        num_epochs=1,
    )

    if training:
        evaluation_service = EvaluationService(
            None,
            task_d,
            0,
            0,
            evaluation_steps,
            False,
            model_module[eval_metrics_fn],
        )
    else:
        evaluation_service = EvaluationService(
            None,
            task_d,
            0,
            0,
            evaluation_steps,
            True,
            model_module[eval_metrics_fn],
        )
    task_d.set_evaluation_service(evaluation_service)

    def master_creator():
        return MasterServicer(
            batch_size,
            task_d,
            evaluation_service=evaluation_service,
            master=None,
        )

    svc, port = _server(master_creator)
    mc = MasterClient(build_channel("localhost:%d" % port), 1)
    worker = Worker(args, master_client=mc, ps_client=PSClient(ps_channels))

    for pservicer in pservers:
        # FIXME(yancey1989): decouple pserver and master client
        pservicer._master_stub = mc

    worker.run()

    task = mc.get_task()
    # Stop the master servicer.
    svc.stop(0)
    # Verify that no tasks remain.
    if task.shard_name:
        raise RuntimeError(
            "Some tasks were left unfinished after the worker exited.")
    return task.model_version
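For reference, a minimal invocation of this helper, with placeholder model zoo arguments, might be:

# Hypothetical arguments; substitute a real model zoo path and a real
# model definition import path.
model_version = distributed_train_and_evaluate(
    feature_shape=(28, 28),
    model_zoo_path="/tmp/model_zoo",
    model_def="my_model.custom_model",
    training=True,
    use_async=True,
)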