Example #1
    def __init__(
        self,
        args,
        channel=None,
        ps_channels=None,
        max_minibatch_retry_num=DEFAULT_MAX_MINIBATCH_RETRY_NUM,
        max_allreduce_retry_num=DEFAULT_MAX_ALLREDUCE_RETRY_NUM,
        set_parallelism=False,
    ):
        """
        Arguments:
            channel: The channel for the gRPC master service.
            ps_channels: The gRPC channels for the parameter server (PS)
                service.
            max_minibatch_retry_num: The maximum number of retries for a
                minibatch whose results (e.g. gradients) are not accepted
                by the master.
            max_allreduce_retry_num: The maximum number of retries for the
                allreduce operation when the allreduce-based distributed
                training strategy is used.
        """
        self._args = args
        self.logger = get_logger("Worker", level=args.log_level.upper())

        if set_parallelism:
            # Explicitly setting the parallelism avoids multi-process hangs,
            # possibly due to an unknown bug in TensorFlow.
            # This must be called before TensorFlow is initialized.
            # set_parallelism is False by default to keep unit tests happy.
            num_threads = os.cpu_count()
            tf.config.threading.set_inter_op_parallelism_threads(num_threads)
            tf.config.threading.set_intra_op_parallelism_threads(num_threads)

        if channel is None:
            self._stub = None
        else:
            self._stub = elasticdl_pb2_grpc.MasterStub(channel)

        self._use_multi_ps = False
        self._ps_vars = {}
        self._ps_num = 0
        if isinstance(ps_channels, list) and len(ps_channels) > 0:
            self._use_multi_ps = True
            self._ps_stubs = [
                elasticdl_pb2_grpc.PserverStub(c) for c in ps_channels
            ]
            self._var_to_ps = {}
            self._ps_num = len(self._ps_stubs)
        self._distribution_strategy = args.distribution_strategy
        if (self._distribution_strategy
                == DistributionStrategy.PARAMETER_SERVER
                and self._use_multi_ps is False):
            raise ValueError(
                "PS channels are not set up under parameter server strategy")

        self._max_minibatch_retry_num = max_minibatch_retry_num
        self._max_allreduce_retry_num = max_allreduce_retry_num
        self._init_from_args(args)
        self._timing = Timing(args.log_level.upper() == "DEBUG", self.logger)
        self._log_loss_count = 0
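
A minimal usage sketch for the constructor above (not part of the original example): the endpoint addresses and the pre-parsed args object are assumptions used only for illustration; the channels are opened with grpc.insecure_channel and handed to the Worker.

    import grpc

    # Illustrative endpoints; a real deployment supplies its own addresses.
    master_channel = grpc.insecure_channel("localhost:50001")
    ps_channels = [
        grpc.insecure_channel(addr)
        for addr in ("localhost:2222", "localhost:2223")
    ]

    worker = Worker(
        args,  # an already-parsed worker argument namespace (assumed to exist)
        channel=master_channel,
        ps_channels=ps_channels,
    )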
Example #2
    def __init__(self, ps_channels):
        self.ps_stubs = [
            elasticdl_pb2_grpc.PserverStub(c) for c in ps_channels
        ]
        self.ps_num = len(self.ps_stubs)
        self.parameter_to_ps = {}
        self.ps_to_parameter = {}
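
Example #2 only builds the stubs and the two empty lookup tables; how a parameter name gets assigned to a PS is left open. A minimal sketch, assuming a stable hash of the name modulo ps_num is an acceptable placement policy (the method name assign_parameter is hypothetical, not from the original):

    import hashlib

    def assign_parameter(self, name):
        # Hypothetical helper: a stable hash so every worker maps the
        # same parameter name to the same PS id.
        ps_id = int(hashlib.md5(name.encode("utf-8")).hexdigest(), 16) % self.ps_num
        self.parameter_to_ps[name] = ps_id
        self.ps_to_parameter.setdefault(ps_id, []).append(name)
        return ps_id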
Example #3
    def __init__(
        self,
        args,
        channel=None,
        ps_channels=None,
        max_minibatch_retry_num=DEFAULT_MAX_MINIBATCH_RETRY_NUM,
        max_allreduce_retry_num=DEFAULT_MAX_ALLREDUCE_RETRY_NUM,
    ):
        """
        Arguments:
            channel: The channel for the gRPC master service.
            ps_channels: The gRPC channels for the parameter server (PS)
                service.
            max_minibatch_retry_num: The maximum number of retries for a
                minibatch whose results (e.g. gradients) are not accepted
                by the master.
            max_allreduce_retry_num: The maximum number of retries for the
                allreduce operation when the allreduce-based distributed
                training strategy is used.
        """
        self._args = args
        if channel is None:
            self._stub = None
        else:
            self._stub = elasticdl_pb2_grpc.MasterStub(channel)

        self._use_multi_ps = False
        if isinstance(ps_channels, list) and len(ps_channels) > 0:
            self._use_multi_ps = True
            self._ps_stubs = [
                elasticdl_pb2_grpc.PserverStub(c) for c in ps_channels
            ]
            self._var_to_ps = {}
        self._max_minibatch_retry_num = max_minibatch_retry_num
        self._max_allreduce_retry_num = max_allreduce_retry_num
        self._init_from_args(args)
Example #4
    def create_server_and_stub(self, grads_to_wait, lr_staleness_modulation,
                               use_async, **kwargs):
        args = PserverArgs(grads_to_wait=grads_to_wait,
                           lr_staleness_modulation=lr_staleness_modulation,
                           use_async=use_async,
                           port=self._port,
                           model_zoo=_test_model_zoo_path,
                           model_def="test_module.custom_model",
                           **kwargs)
        pserver = ParameterServer(args)
        pserver.prepare()
        self._parameters = pserver.parameters
        self._server = pserver.server
        self._stub = elasticdl_pb2_grpc.PserverStub(self._channel)

        self._lr = 0.1
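
How the helper above might be driven from a test is sketched below; this is an assumption, not part of the original. The port and argument values are illustrative, and self._channel is assumed to be opened against the same port the parameter server binds to.

    import grpc

    def setUp(self):
        # Sketch of a test setUp (assumed): open a local channel to the
        # port the pserver will listen on, then build server and stub.
        self._port = 12345
        self._channel = grpc.insecure_channel("localhost:%d" % self._port)
        self.create_server_and_stub(
            grads_to_wait=8, lr_staleness_modulation=False, use_async=True
        )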