Example #1
    def __init__(
        self,
        args,
        channel=None,
        ps_channels=None,
        max_minibatch_retry_num=DEFAULT_MAX_MINIBATCH_RETRY_NUM,
        max_allreduce_retry_num=DEFAULT_MAX_ALLREDUCE_RETRY_NUM,
        set_parallelism=False,
    ):
        """
        Arguments:
            args: The parsed worker arguments.
            channel: The gRPC channel for the master service.
            ps_channels: The gRPC channels for the PS services.
            max_minibatch_retry_num: The maximum number of retries for a
                minibatch whose results (e.g. gradients) are not accepted
                by the master.
            max_allreduce_retry_num: The maximum number of retries for the
                allreduce operation if an allreduce-based distributed
                training strategy is used.
            set_parallelism: Whether to explicitly set TensorFlow's
                inter-op and intra-op thread parallelism to the number
                of CPUs.
        """
        self._args = args
        self.logger = get_logger("Worker", level=args.log_level.upper())

        if set_parallelism:
            # Explicitly setting the parallelism avoids multi-process hangs,
            # possibly due to an unknown bug in TensorFlow.
            # This must be called before TensorFlow is initialized.
            # Parallelism is not set by default to keep the unit tests happy.
            num_threads = os.cpu_count()
            tf.config.threading.set_inter_op_parallelism_threads(num_threads)
            tf.config.threading.set_intra_op_parallelism_threads(num_threads)

        if channel is None:
            self._stub = None
        else:
            self._stub = elasticdl_pb2_grpc.MasterStub(channel)

        self._use_multi_ps = False
        self._ps_vars = {}
        # Default to no PS; an empty or missing channel list leaves _ps_num at 0.
        self._ps_num = 0
        if isinstance(ps_channels, list) and len(ps_channels) > 0:
            self._use_multi_ps = True
            self._ps_stubs = [
                elasticdl_pb2_grpc.PserverStub(c) for c in ps_channels
            ]
            self._var_to_ps = {}
            self._ps_num = len(self._ps_stubs)
        self._distribution_strategy = args.distribution_strategy
        if (
            self._distribution_strategy
            == DistributionStrategy.PARAMETER_SERVER
            and not self._use_multi_ps
        ):
            raise ValueError(
                "PS channels are not set up under "
                "parameter server strategy"
            )

        self._max_minibatch_retry_num = max_minibatch_retry_num
        self._max_allreduce_retry_num = max_allreduce_retry_num
        self._init_from_args(args)
        self._timing = Timing(args.log_level.upper() == "DEBUG", self.logger)
        self._log_loss_count = 0
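
A minimal wiring sketch (not part of the original example) of how this constructor could be called under the parameter server strategy. The host names, ports, and the argparse namespace below are illustrative assumptions; a real ElasticDL worker receives a fully parsed argument namespace and channel addresses from its launcher.

import argparse

import grpc

# Hypothetical, minimal argument namespace; only attributes read directly by
# __init__ are shown. A real run also needs whatever _init_from_args expects.
# DistributionStrategy is the same constant referenced in __init__ above.
args = argparse.Namespace(
    log_level="info",
    distribution_strategy=DistributionStrategy.PARAMETER_SERVER,
)

master_channel = grpc.insecure_channel("elasticdl-master:50001")
ps_channels = [
    grpc.insecure_channel("elasticdl-ps-%d:30001" % i) for i in range(2)
]

# Worker is the class whose __init__ is shown above.
worker = Worker(args, channel=master_channel, ps_channels=ps_channels)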
Example #2
    def __init__(
        self,
        args,
        master_client=None,
        ps_client=None,
        max_minibatch_retry_num=DEFAULT_MAX_MINIBATCH_RETRY_NUM,
        set_parallelism=False,
    ):
        """
        Arguments:
            args: The parsed worker arguments.
            master_client: The client for the gRPC master service.
            ps_client: The client for the PS services.
            max_minibatch_retry_num: The maximum number of retries for a
                minibatch whose results (e.g. gradients) are not accepted
                by the master.
            set_parallelism: Whether to explicitly set TensorFlow's
                inter-op and intra-op thread parallelism to the number
                of CPUs.
        """
        self._args = args
        self.logger = get_logger("Worker", level=args.log_level.upper())

        if set_parallelism:
            # Explicitly setting the parallelism avoids multi-process hangs,
            # possibly due to an unknown bug in TensorFlow.
            # This must be called before TensorFlow is initialized.
            # Parallelism is not set by default to keep the unit tests happy.
            num_threads = os.cpu_count()
            tf.config.threading.set_inter_op_parallelism_threads(num_threads)
            tf.config.threading.set_intra_op_parallelism_threads(num_threads)

        self._mc = master_client
        self._ps_client = ps_client
        self._distribution_strategy = args.distribution_strategy
        if (
            self._distribution_strategy
            == DistributionStrategy.PARAMETER_SERVER
        ):
            if self._ps_client is None:
                raise ValueError(
                    "PS channels are not set up under "
                    "parameter server strategy"
                )
            else:
                self._model_versions_from_ps = [
                    -1 for _ in range(self._ps_client.ps_num)
                ]
        self._max_minibatch_retry_num = max_minibatch_retry_num
        self._init_from_args(args)
        self._timing = Timing(args.log_level.upper() == "DEBUG", self.logger)
        self._log_loss_count = 0
        self._var_created = False
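
The only attribute of ps_client that this constructor touches directly is ps_num, which sizes the per-parameter-server model-version bookkeeping. A self-contained illustration of that bookkeeping with a stand-in client (the _FakePSClient name is invented here for illustration, not an ElasticDL API):

class _FakePSClient:
    # Stand-in exposing only the attribute read by __init__ above.
    ps_num = 3

ps_client = _FakePSClient()

# Mirrors the initialization in __init__: one "not yet pulled" (-1)
# model version per parameter server shard.
model_versions_from_ps = [-1 for _ in range(ps_client.ps_num)]
assert model_versions_from_ps == [-1, -1, -1]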
Example #3
    def __init__(
        self,
        args,
        master_client=None,
        ps_client=None,
        max_minibatch_retry_num=DEFAULT_MAX_MINIBATCH_RETRY_NUM,
        set_parallelism=False,
    ):
        """
        Arguments:
            args: The parsed worker arguments.
            master_client: The client for the gRPC master service.
            ps_client: The client for the PS services.
            max_minibatch_retry_num: The maximum number of retries for a
                minibatch whose results (e.g. gradients) are not accepted
                by the master.
            set_parallelism: Whether to explicitly set TensorFlow's
                inter-op and intra-op thread parallelism to the number
                of CPUs.
        """
        self._args = args
        self.logger = get_logger("Worker", level=args.log_level.upper())

        if set_parallelism:
            # Explicitly setting the parallelism avoids multi-process hangs,
            # possibly due to an unknown bug in TensorFlow.
            # This must be called before TensorFlow is initialized.
            # Parallelism is not set by default to keep the unit tests happy.
            num_threads = os.cpu_count()
            tf.config.threading.set_inter_op_parallelism_threads(num_threads)
            tf.config.threading.set_intra_op_parallelism_threads(num_threads)

        self._mc = master_client
        self._ps_client = ps_client
        self._distribution_strategy = args.distribution_strategy
        self._max_minibatch_retry_num = max_minibatch_retry_num
        self._timing = Timing(args.log_level.upper() == "DEBUG", self.logger)
        self._log_loss_count = 0
        self._var_created = False
        self._job_type = args.job_type
        self._minibatch_size = args.minibatch_size
        self._data_shard_service = DataShardService(
            self._mc, self._minibatch_size
        )
        self._init_model_from_args(args)
        self._init_task_data_service(args)
        self._init_default_feed_if_needed()
        self._init_callbacks(args)
        self._init_trainer(args)
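
Example #3 pushes nearly all setup into dedicated _init_* helpers, so the constructor reads as an ordered checklist. A compressed, self-contained skeleton of that delegation pattern (the class and method names below are placeholders for illustration, not ElasticDL APIs):

import argparse


class WorkerSkeleton:
    """Illustrates the delegation style of Example #3's constructor."""

    def __init__(self, args, master_client=None):
        self._mc = master_client
        self._minibatch_size = args.minibatch_size
        # The call order mirrors the original: each helper relies on the
        # state established by the ones before it.
        self._init_model(args)
        self._init_data_service(args)
        self._init_callbacks(args)
        self._init_trainer(args)

    def _init_model(self, args):
        pass

    def _init_data_service(self, args):
        pass

    def _init_callbacks(self, args):
        pass

    def _init_trainer(self, args):
        pass


WorkerSkeleton(argparse.Namespace(minibatch_size=64))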