def _make_cluster_def(self):
  """Builds `self._cluster_def` (a `tf.train.ClusterDef`) from `self._cluster_spec`.

  Each value in `self._cluster_spec` may be either an iterable of network
  addresses, in which case task indices are assigned by position, or a
  dictionary mapping integer task indices to network addresses.

  Raises:
    TypeError: If a job name or a task address cannot be converted with
      `compat.as_bytes` (i.e. is not bytes or unicode).
  """
  self._cluster_def = tensorflow_server_pb2.ClusterDef()

  # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
  for job_name, tasks in sorted(self._cluster_spec.items()):
    try:
      job_name = compat.as_bytes(job_name)
    except TypeError:
      raise TypeError("Job name %r must be bytes or unicode" % job_name)

    job_def = self._cluster_def.job.add()
    job_def.name = job_name

    if isinstance(tasks, dict):
      # Explicit (possibly sparse) task indices. Enumerating the dict would
      # yield its integer keys as "addresses", so iterate its items instead,
      # sorted by index for a deterministic build order.
      indexed_tasks = sorted(tasks.items())
    else:
      # A plain sequence: the position in the sequence is the task index.
      indexed_tasks = enumerate(tasks)

    for i, task_address in indexed_tasks:
      try:
        task_address = compat.as_bytes(task_address)
      except TypeError:
        raise TypeError(
            "Task address %r must be bytes or unicode" % task_address)
      job_def.tasks[i] = task_address
def __init__(self, cluster):
  """Creates a `ClusterSpec`.

  Args:
    cluster: A dictionary mapping one or more job names to lists of network
      addresses, a `tf.train.ClusterDef` protocol buffer, or another
      `ClusterSpec` (whose `ClusterDef` is copied).

  Raises:
    TypeError: If `cluster` is not a dictionary mapping strings to lists of
      strings, and not a `tf.train.ClusterDef` protobuf.
  """
  if isinstance(cluster, dict):
    # NOTE(review): the caller's dictionary is stored as-is (not copied).
    self._cluster_spec = cluster
    self._make_cluster_def()
    return

  if isinstance(cluster, tensorflow_server_pb2.ClusterDef):
    self._cluster_def = cluster
  elif isinstance(cluster, ClusterSpec):
    self._cluster_def = tensorflow_server_pb2.ClusterDef()
    self._cluster_def.MergeFrom(cluster.as_cluster_def())
  else:
    raise TypeError(
        "`cluster` must be a dictionary mapping one or more "
        "job names to lists of network addresses, or a "
        "`ClusterDef` protocol buffer")

  # Rebuild the dictionary form from the proto. `job_def.tasks` is a
  # protobuf map keyed by task index, and map iteration order is undefined.
  # Sort by index so that list position matches the task index (for
  # indices that are contiguous from 0).
  self._cluster_spec = {}
  for job_def in self._cluster_def.job:
    self._cluster_spec[job_def.name] = [
        address for _, address in sorted(job_def.tasks.items())
    ]
def __init__(self, cluster):
  """Creates a `ClusterSpec`.

  Args:
    cluster: One of:
      * a dictionary from job name to either a list/tuple of network
        addresses (task index = position) or a dictionary from integer task
        index to network address;
      * a `tf.train.ClusterDef` protocol buffer, which is stored directly;
      * another `ClusterSpec`, whose `ClusterDef` is copied.

  Raises:
    TypeError: If `cluster` is none of the above, or if some job's tasks are
      neither a list/tuple nor a dictionary.
  """
  if isinstance(cluster, dict):
    # Normalize every job to a {task_index: address} dictionary.
    self._cluster_spec = {}
    for name, job_tasks in cluster.items():
      if isinstance(job_tasks, dict):
        index_to_address = dict(job_tasks.items())
      elif isinstance(job_tasks, (list, tuple)):
        index_to_address = dict(enumerate(job_tasks))
      else:
        raise TypeError(
            "The tasks for job %r must be a list or a dictionary "
            "from integers to strings." % name)
      self._cluster_spec[name] = index_to_address
    self._make_cluster_def()
    return

  if isinstance(cluster, tensorflow_server_pb2.ClusterDef):
    self._cluster_def = cluster
  elif isinstance(cluster, ClusterSpec):
    self._cluster_def = tensorflow_server_pb2.ClusterDef()
    self._cluster_def.MergeFrom(cluster.as_cluster_def())
  else:
    raise TypeError(
        "`cluster` must be a dictionary mapping one or more "
        "job names to lists of network addresses, or a "
        "`ClusterDef` protocol buffer")

  # Derive the dictionary form from the proto: job name -> {index: address}.
  self._cluster_spec = {
      job_def.name: dict(job_def.tasks.items())
      for job_def in self._cluster_def.job
  }
def make_cluster_def(cluster_spec):
  """Returns a `tf.train.ClusterDef` based on the given `cluster_spec`.

  Args:
    cluster_spec: A dictionary mapping one or more job names to either an
      iterable of network addresses (task indices are assigned by position)
      or a dictionary mapping integer task indices to network addresses.

  Returns:
    A `tf.train.ClusterDef` protocol buffer.

  Raises:
    TypeError: If `cluster_spec` is not a dictionary, or if a job name or a
      task address is not bytes or unicode.
  """
  if not isinstance(cluster_spec, dict):
    raise TypeError(
        "`cluster_spec` must be a dictionary mapping one or more "
        "job names to lists of network addresses")

  cluster_def = tensorflow_server_pb2.ClusterDef()

  # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
  for job_name, tasks in sorted(cluster_spec.items()):
    try:
      job_name = compat.as_bytes(job_name)
    except TypeError:
      raise TypeError("Job name %r must be bytes or unicode" % job_name)

    job_def = cluster_def.job.add()
    job_def.name = job_name

    if isinstance(tasks, dict):
      # Explicit (possibly sparse) task indices. Sort them for a
      # deterministic build order; enumerating the dict would wrongly treat
      # its integer keys as addresses.
      indexed_tasks = sorted(tasks.items())
    else:
      # A plain sequence: the position in the sequence is the task index.
      indexed_tasks = enumerate(tasks)

    for i, task_address in indexed_tasks:
      try:
        task_address = compat.as_bytes(task_address)
      except TypeError:
        raise TypeError("Task address %r must be bytes or unicode" %
                        task_address)
      job_def.tasks[i] = task_address

  return cluster_def