Пример #1
0
    def _make_cluster_def(self):
        """Creates a `tf.train.ClusterDef` based on the given `cluster_spec`.

        Populates `self._cluster_def` from `self._cluster_spec`, a dictionary
        mapping job names to lists of task addresses. Jobs are added in sorted
        job-name order, and the address at position `i` of a job's list
        becomes task index `i` of that job.

        Raises:
          TypeError: If `cluster_spec` is not a dictionary mapping strings to
            lists of strings (including when a job maps to a single string
            instead of a list of strings).
        """
        self._cluster_def = tensorflow_server_pb2.ClusterDef()

        # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
        for job_name, task_list in sorted(self._cluster_spec.items()):
            try:
                job_name_bytes = compat.as_bytes(job_name)
            except TypeError:
                # Wrap the value in a 1-tuple so a tuple-valued key is
                # formatted whole instead of being spread across `%`.
                raise TypeError("Job name %r must be bytes or unicode" %
                                (job_name,))

            # A bare string is iterable; without this check it would be
            # silently split into one single-character task per element.
            if isinstance(task_list, (bytes, str)):
                raise TypeError(
                    "Tasks for job %r must be a list of network addresses, "
                    "not a single string" % (job_name,))

            job_def = self._cluster_def.job.add()
            job_def.name = job_name_bytes

            for task_index, task_address in enumerate(task_list):
                try:
                    task_address = compat.as_bytes(task_address)
                except TypeError:
                    raise TypeError(
                        "Task address %r must be bytes or unicode" %
                        (task_address,))
                job_def.tasks[task_index] = task_address
Пример #2
0
    def __init__(self, cluster):
        """Creates a `ClusterSpec`.

        Args:
          cluster: A dictionary mapping one or more job names to lists of
            network addresses, a `tf.train.ClusterDef` protocol buffer, or
            another `ClusterSpec` (whose protobuf is copied).

        Raises:
          TypeError: If `cluster` is not a dictionary mapping strings to lists
            of strings, a `tf.train.ClusterDef` protobuf, or a `ClusterSpec`.
        """
        if isinstance(cluster, dict):
            self._cluster_spec = cluster
            self._make_cluster_def()
            return

        if isinstance(cluster, tensorflow_server_pb2.ClusterDef):
            self._cluster_def = cluster
        elif isinstance(cluster, ClusterSpec):
            # Copy into a fresh protobuf rather than sharing the other spec's.
            self._cluster_def = tensorflow_server_pb2.ClusterDef()
            self._cluster_def.MergeFrom(cluster.as_cluster_def())
        else:
            raise TypeError(
                "`cluster` must be a dictionary mapping one or more "
                "job names to lists of network addresses, or a "
                "`ClusterDef` protocol buffer")

        # Rebuild the job -> address-list mapping from the protobuf. `tasks` is
        # a protobuf map whose iteration order is not guaranteed, so read the
        # addresses in ascending task-index order. That way list position
        # follows task-index order instead of map iteration order.
        self._cluster_spec = {}
        for job_def in self._cluster_def.job:
            self._cluster_spec[job_def.name] = [
                job_def.tasks[task_index]
                for task_index in sorted(job_def.tasks)
            ]
Пример #3
0
    def __init__(self, cluster):
        """Creates a `ClusterSpec`.

        Args:
          cluster: A dictionary mapping one or more job names to (i) a
            list of network addresses, or (ii) a dictionary mapping integer
            task indices to network addresses; or a `tf.train.ClusterDef`
            protocol buffer; or another `ClusterSpec` (whose protobuf is
            copied).

        Raises:
          TypeError: If `cluster` is none of the forms described above, or if
            the tasks for some job are neither a list/tuple nor a dictionary.
        """
        if isinstance(cluster, dict):
            self._cluster_spec = {}
            for job_name, tasks in cluster.items():
                if isinstance(tasks, (list, tuple)):
                    # Normalize a positional list to {task index: address}.
                    job_tasks = dict(enumerate(tasks))
                elif isinstance(tasks, dict):
                    # Copy so the spec does not alias the caller's dictionary.
                    job_tasks = dict(tasks.items())
                else:
                    # Wrap in a 1-tuple so a tuple-valued job name is
                    # formatted whole instead of being spread across `%`.
                    raise TypeError(
                        "The tasks for job %r must be a list or a dictionary "
                        "from integers to strings." % (job_name,))
                self._cluster_spec[job_name] = job_tasks
            self._make_cluster_def()
            return

        if isinstance(cluster, tensorflow_server_pb2.ClusterDef):
            self._cluster_def = cluster
        elif isinstance(cluster, ClusterSpec):
            # Copy into a fresh protobuf rather than sharing the other spec's.
            self._cluster_def = tensorflow_server_pb2.ClusterDef()
            self._cluster_def.MergeFrom(cluster.as_cluster_def())
        else:
            raise TypeError(
                "`cluster` must be a dictionary mapping one or more "
                "job names to lists of network addresses, or a "
                "`ClusterDef` protocol buffer")

        # Rebuild the job -> {task index: address} mapping from the protobuf.
        self._cluster_spec = {}
        for job_def in self._cluster_def.job:
            self._cluster_spec[job_def.name] = dict(job_def.tasks.items())
Пример #4
0
def make_cluster_def(cluster_spec):
    """Returns a `tf.ClusterDef` based on the given `cluster_spec`.

    Jobs are added in sorted job-name order so the resulting protobuf is
    deterministic. The address at position `i` of a job's list becomes task
    index `i` of that job.

    Args:
      cluster_spec: A dictionary mapping one or more job names to lists
        of network addresses.

    Returns:
      A `tf.ClusterDef` protocol buffer.

    Raises:
      TypeError: If `cluster_spec` is not a dictionary mapping strings to lists
        of strings (including when a job maps to a single string instead of a
        list of strings).
    """
    if not isinstance(cluster_spec, dict):
        raise TypeError(
            "`cluster_spec` must be a dictionary mapping one or more "
            "job names to lists of network addresses")

    cluster_def = tensorflow_server_pb2.ClusterDef()

    # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
    for job_name, task_list in sorted(cluster_spec.items()):
        try:
            job_name_bytes = compat.as_bytes(job_name)
        except TypeError:
            # Wrap the value in a 1-tuple so a tuple-valued key is formatted
            # whole instead of being spread across `%`.
            raise TypeError("Job name %r must be bytes or unicode" %
                            (job_name,))

        # A bare string is iterable; without this check it would be silently
        # split into one single-character task per element.
        if isinstance(task_list, (bytes, str)):
            raise TypeError(
                "Tasks for job %r must be a list of network addresses, "
                "not a single string" % (job_name,))

        job_def = cluster_def.job.add()
        job_def.name = job_name_bytes

        for task_index, task_address in enumerate(task_list):
            try:
                task_address = compat.as_bytes(task_address)
            except TypeError:
                raise TypeError("Task address %r must be bytes or unicode" %
                                (task_address,))
            job_def.tasks[task_index] = task_address

    return cluster_def