def remote_exec(self, args, hostname):
    """
    Execute a bash script remotely.

    Args:
        args (list): bash commands
        hostname (str): host name or address

    Returns:
        Process: process handle
    """
    cmd_list = []
    ssh_config = self._ssh_conf[hostname]
    if ssh_config.python_venv:
        cmd_list.append('%s;' % ssh_config.python_venv)
    if ssh_config.env:
        cmd_list.extend(['%s=%s' % (k, v) for k, v in ssh_config.env.items()])
    full_cmd = ' '.join(cmd_list + args)

    remote_cmd = 'ssh -i {} -o StrictHostKeyChecking=no -tt -p {} {}@{} \'bash -c "{}"\' </dev/null' \
        .format(ssh_config.key_file, ssh_config.port, ssh_config.username, hostname, full_cmd)

    logging.debug('$ %s' % remote_cmd)

    if ENV.AUTODIST_DEBUG_REMOTE.val:
        return None

    # pylint: disable=subprocess-popen-preexec-fn
    proc = subprocess.Popen(remote_cmd, shell=True, preexec_fn=os.setsid)
    return proc
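# For reference, a hypothetical rendering of the command string composed by remote_exec above;
# every value below is made up for illustration (real ones come from the node's ssh_config and
# the caller-supplied args).
key_file, port, username, hostname = '~/.ssh/id_rsa', 22, 'ubuntu', '192.168.0.2'
full_cmd = 'source /opt/venv/bin/activate; AUTODIST_MIN_LOG_LEVEL=ERROR python -u server_starter.py'
remote_cmd = ('ssh -i {} -o StrictHostKeyChecking=no -tt -p {} {}@{} \'bash -c "{}"\' </dev/null'
              .format(key_file, port, username, hostname, full_cmd))
print(remote_cmd)
# ssh -i ~/.ssh/id_rsa -o StrictHostKeyChecking=no -tt -p 22 ubuntu@192.168.0.2 'bash -c "..."' </dev/null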
def _compile_strategy(self, strategy):
    """Compile the raw strategy, resolving devices against the current cluster."""
    logging.debug('Raw strategy: %s' % strategy)
    device_resolver = DeviceResolver(self._cluster)
    compiled_strategy = base.StrategyCompiler(self._original_graph_item) \
        .set_device_resolver(device_resolver.resolve_to_device_str) \
        .compile(strategy)
    logging.info('Compiled strategy: %s' % compiled_strategy)
    return compiled_strategy
def _apply(self, *args, **kwargs):
    """
    Apply replication to a graph.

    Returns:
        GraphItem
    """
    new_graph_item = self._graph_item
    if self._num_local_replicas >= 1:
        new_graph_item = self.replicate(new_graph_item)
        logging.debug('Successfully replicated operations')
    return new_graph_item
def wrapper(*args, **kwargs):
    # Assume grads_and_vars is an iterable of tuples.
    # Materialize it here: if it is a generator, we need to be able to iterate over it multiple times.
    grads_and_vars = list(kwargs.get('grads_and_vars') or args[1])
    grads, variables = map(list, zip(*grads_and_vars))
    if _default_graph_item and kwargs.pop('update', True):
        _default_graph_item.extend_gradient_info(grads, variables)
        logging.debug('Registered grads: \n {} with targets: \n {}'.format(grads, variables))
    args = (args[0], grads_and_vars)  # Replace a possible generator with the materialized list
    return fn(*args, **kwargs)
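# Why materialize: a minimal, self-contained sketch (plain Python, no TF objects) showing that a
# generator of (grad, var) pairs can be consumed only once, so it must be turned into a list before
# both registering the gradients and forwarding them to the wrapped apply_gradients.
# The values below are illustrative placeholders.
pairs = ((g, v) for g, v in [('g0', 'v0'), ('g1', 'v1')])   # generator: single-use
materialized = list(pairs)
grads, variables = map(list, zip(*materialized))             # safe: the list can be re-iterated
assert grads == ['g0', 'g1'] and variables == ['v0', 'v1']
assert list(pairs) == []                                      # the original generator is exhausted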
def patch_optimizers():
    """Patch all subclasses of OptimizerV1 and OptimizerV2 so we can store optimizer and gradient information."""
    q = deque(chain(OptimizerV2.__subclasses__(), OptimizerV1.__subclasses__()))
    while q:
        subclass = q.popleft()
        q.extend(list(subclass.__subclasses__()))
        subclass.__init__ = wrap_optimizer_init(subclass.__init__)
        subclass.apply_gradients = wrap_optimizer_apply_gradient(subclass.apply_gradients)
        logging.debug('Optimizer type: %s has been patched' % subclass.__name__)
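# A minimal sketch (plain Python classes, not TF optimizers) of the breadth-first walk above:
# __subclasses__() only returns direct subclasses, so the deque is what lets the patching reach
# grandchildren such as C below. Class names are illustrative.
from collections import deque

class Base: pass
class A(Base): pass
class B(Base): pass
class C(A): pass

q = deque(Base.__subclasses__())
visited = []
while q:
    subclass = q.popleft()
    q.extend(subclass.__subclasses__())
    visited.append(subclass.__name__)
assert set(visited) == {'A', 'B', 'C'}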
def log_graph(graph, name):
    """
    Log the graph on TensorBoard.

    Args:
        graph: the TensorFlow graph to be plotted on TensorBoard
        name: the name of the graph
    """
    directory = os.path.join(autodist.const.DEFAULT_WORKING_DIR, "graphs")
    os.makedirs(directory, exist_ok=True)
    p = os.path.join(directory, name)
    writer.FileWriter(p, graph=graph)
    logging.debug('Graph summary written to: %s' % p)
def __init__(self, key, graph_item, config, cluster):
    """Initialize with the graph item, cluster, and replica device config, and derive the local replica devices."""
    super().__init__(key)
    self._graph_item = graph_item
    self._cluster = cluster
    self._replica_devices = {device_spec.DeviceSpecV2.from_string(s) for s in config}
    self._local_canonical_replica_devices = sorted({
        d.to_string() for d in self._replica_devices
        if self._cluster.get_local_address() == cluster.get_address_from_task(d.job, d.task)
    })
    logging.debug('Local replica devices: {}'.format(self._local_canonical_replica_devices))
    self._num_local_replicas = len(self._local_canonical_replica_devices)
    self._local_worker_id = self._cluster.get_local_worker_task_index()
    self._local_worker_device = '/job:worker/task:{}'.format(self._local_worker_id)
def _initialize_synchronizers(self):
    """Create a Synchronizer per variable (or per partition) from the node config and hand each the cluster information."""
    self._synchronizers = {}
    for node in self._strategy.node_config:
        partitioner = getattr(node, 'partitioner')
        if partitioner:
            for part in node.part_config:
                self._synchronizers[part.var_name] = \
                    Synchronizer.create(part.WhichOneof('synchronizer'),
                                        getattr(part, part.WhichOneof('synchronizer')))
        else:
            self._synchronizers[node.var_name] = \
                Synchronizer.create(node.WhichOneof('synchronizer'),
                                    getattr(node, node.WhichOneof('synchronizer')))

    config = self._strategy.graph_config.replicas
    replica_devices = {device_spec.DeviceSpecV2.from_string(s) for s in config}
    replica_hosts = {self._cluster.get_address_from_task(d.job, d.task) for d in replica_devices}
    self._num_workers = len(replica_hosts)

    local_canonical_replica_devices = sorted({
        d.to_string() for d in replica_devices
        if self._cluster.get_local_address() == self._cluster.get_address_from_task(d.job, d.task)
    })
    logging.debug('Local replica devices: {}'.format(local_canonical_replica_devices))
    self._num_local_replicas = len(local_canonical_replica_devices)

    local_worker_id = self._cluster.get_local_worker_task_index()
    local_worker_device = '/job:worker/task:{}'.format(local_worker_id)

    for synchronizer in self._synchronizers.values():
        synchronizer.assign_cluster_information(
            num_workers=self._num_workers,
            num_replicas=self._num_local_replicas,
            worker_device=local_worker_device,
            worker_id=local_worker_id,
            canonical_replica_devices=sorted({d.to_string() for d in replica_devices}),
            is_chief=self._cluster.is_chief())
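# A small sketch of how the replica device strings above resolve to hosts, using the public
# tf.DeviceSpec (the V2 device spec in TF2). The device strings are illustrative: two GPUs on the
# same worker count as two replicas but only one worker host.
import tensorflow as tf

config = ['/job:worker/task:0/device:GPU:0',
          '/job:worker/task:0/device:GPU:1',
          '/job:worker/task:1/device:GPU:0']
devices = {tf.DeviceSpec.from_string(s) for s in config}
hosts = {(d.job, d.task) for d in devices}
assert len(devices) == 3 and len(hosts) == 2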
def wrapper(*args, **kwargs):
    # args[0] should be `self`, an instance of the optimizer class being constructed
    containing_class = type(args[0])
    class_name = containing_class.__name__

    # For calls like super(AdamWeightDecay, self).__init__(*args, **kwargs), containing_class.__name__
    # returns the current class (AdamWeightDecay) instead of the parent class (Adam).
    # Avoid registering the optimizer again for this pattern by checking fn.__qualname__.
    if not fn.__qualname__.startswith(class_name):
        return fn(*args, **kwargs)

    if _default_graph_item and kwargs.pop('update', True):
        _default_graph_item.extend_optimizer_info(containing_class, *args, **kwargs)
        logging.debug('Registered optimizer: {} \nwith args: {} \nkwargs: {}'.format(class_name, args, kwargs))
    return fn(*args, **kwargs)
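# A minimal, self-contained sketch (plain Python, not real optimizer classes) of the __qualname__
# guard above: on a super().__init__ call, type(self) names the subclass while fn.__qualname__
# still names the parent that defined fn, so the prefix check fails and the optimizer is not
# registered a second time. Class names are illustrative.
class Adam:
    def __init__(self):
        pass

class AdamWeightDecay(Adam):
    def __init__(self):
        super().__init__()

fn = Adam.__init__
obj = AdamWeightDecay()
class_name = type(obj).__name__                     # 'AdamWeightDecay'
assert fn.__qualname__ == 'Adam.__init__'
assert not fn.__qualname__.startswith(class_name)   # super() call detected -> skip registration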
def _clean_stale_servers():
    """Kill stale local tf.server starter processes, excluding the current starter and its parent."""
    # pylint: disable=anomalous-backslash-in-string
    cmd = """ps aux | awk "/{}/ && !/ssh/ && ! /{}/ && ! /{}/" | awk "{{print \$2}}" | xargs kill -9"""  # noqa: W605
    # Match processes of the local stale servers | exclude ssh, the current starter's pid and ppid | keep pids | kill them
    cmd = cmd.format(os.path.splitext(os.path.basename(__file__))[0], os.getpid(), os.getppid())
    local_cmd = "bash -c '{}'".format(cmd)
    logging.debug('>>> {}'.format(local_cmd))
    try:
        output = subprocess.check_output(local_cmd, shell=True, stderr=subprocess.STDOUT)
        logging.debug('>>> {}'.format(output.decode('utf-8')))
    except subprocess.CalledProcessError as e:
        if e.returncode != 123:  # returncode 123: no stale process to kill
            raise
def transform(self):
    """Call the graph transformer to transform a graph item based on the strategy and cluster."""
    logging.info('Transforming the original graph to a distributed graph...')
    with context.graph_mode():
        graph_item = self.graph_item
        # Ensure the transformation happens under graph mode, regardless of whether the outer mode is eager or graph.
        visualization_util.log_graph(graph=graph_item.graph, name='0-original')

        graph_item, self._strategy.node_config = VariablePartitioner.apply(self._strategy.node_config, graph_item)
        visualization_util.log_graph(graph=graph_item.graph, name='1-after-partition')

        # Create Synchronizers for each node in the strategy
        self._initialize_synchronizers()

        # Replicate the graph (both in-graph and between-graph)
        new_graph_item = Replicator.apply(
            config=self._strategy.graph_config.replicas,
            cluster=self._cluster,
            graph_item=graph_item)

        # Apply synchronizers
        if self._num_local_replicas >= 1:
            new_graph_item = self._in_graph_apply(new_graph_item)
            logging.debug('Successfully applied local in-graph replication')
            visualization_util.log_graph(new_graph_item.graph, '2-after-in-graph')

        if self._num_workers >= 1:
            new_graph_item = self._between_graph_apply(new_graph_item)
            logging.debug('Successfully applied between-graph replication')

        final_item = new_graph_item
        logging.info('Successfully built the distributed graph.')
        visualization_util.log_graph(graph=final_item.graph, name='3-transformed')
    return final_item
def _parse_node(self, node, num_nodes):
    """Parse one node entry of the resource spec and register its devices, SSH group, and network bandwidth."""
    host_address = node['address']
    if is_loopback_address(host_address) and num_nodes > 1:
        raise ValueError("Can't (currently) use a loopback address when there are multiple nodes.")
    if node.get('chief') or num_nodes == 1:
        # Two cases for marking this node as chief:
        # 1) the node was explicitly marked as chief;
        # 2) if there is only one node, it is the chief by default.
        logging.info("Chief: %s" % host_address)
        self.__chief_address = host_address
    host_cpu = DeviceSpec(host_address, device_index=0)
    self._add_device(host_cpu)

    # Handle any other CPUs when no GPU is available
    if len(node.get('gpus', [])) == 0:
        for cpu_index in set(sorted(node.get('cpus', []))) - {0}:
            cpu = DeviceSpec(host_address, host_cpu, DeviceType.CPU, cpu_index)
            self._add_device(cpu)

    # Handle GPUs
    for gpu_index in set(sorted(node.get('gpus', []))):
        gpu = DeviceSpec(host_address, host_cpu, DeviceType.GPU, gpu_index)
        self._add_device(gpu)

    self.__ssh_group[host_address] = node.get('ssh_config')
    if self.__ssh_group[host_address] is None and self.__chief_address != host_address:
        raise ValueError("Need to define SSH groups for all non-chief nodes.")

    # Handle network bandwidth (optional)
    if node.get('network_bandwidth'):
        self.__network_bandwidth[host_address] = node.get('network_bandwidth')
    else:
        logging.debug('The bandwidth for {} is undefined and set to the default (1 GBE). '
                      'Caution: AutoStrategy might be inaccurate.'.format(host_address))
        self.__network_bandwidth[host_address] = 1
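# A hypothetical node entry, for illustration only, using the keys this parser reads
# ('address', 'chief', 'cpus', 'gpus', 'ssh_config', 'network_bandwidth'); real entries come from
# the user's resource spec.
node = {
    'address': '192.168.0.1',
    'chief': True,
    'gpus': [0, 1],            # GPU indices; if empty, 'cpus' entries beyond index 0 are used instead
    'cpus': [0],
    'ssh_config': 'conf1',     # may be omitted only for the chief node
    'network_bandwidth': 10,   # optional; defaults to 1 (GBE) when missing
}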
def terminate(self):
    """Terminate all server subprocesses started by this cluster."""
    logging.debug('Terminating cluster...')
    for p in self.subprocesses:
        os.killpg(os.getpgid(p.pid), signal.SIGTERM)
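# A minimal, self-contained sketch of why killpg works here: the servers are launched with
# preexec_fn=os.setsid (see start() below and remote_exec above), so each child leads its own
# process group and os.killpg signals the child together with anything it spawned.
# The command is illustrative; POSIX only.
import os
import signal
import subprocess

p = subprocess.Popen('sleep 30', shell=True, preexec_fn=os.setsid)
assert os.getpgid(p.pid) == p.pid               # setsid made the child its own group leader
os.killpg(os.getpgid(p.pid), signal.SIGTERM)    # terminates the whole group
p.wait()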
def start(self):
    """
    Start tf.servers on all nodes.

    Note that this only runs (and only should run) on the chief node.
    """
    # pylint: disable=import-outside-toplevel
    from autodist.utils import server_starter

    # The atexit registration should be placed
    # - before the start begins
    #   (to ensure clean termination if the start fails halfway through); and
    # - in the same module as the start
    #   (to follow the Python assumption that lower-level modules are normally imported
    #   before higher-level modules and thus must be cleaned up later).
    atexit.register(self.terminate)

    envs = {ENV.AUTODIST_MIN_LOG_LEVEL.name: 'ERROR'}
    envs = ['{}={}'.format(k, v) for k, v in envs.items()]
    module_name = server_starter.__name__
    module_file = server_starter.__file__

    for job_name, tasks in self.cluster_spec.items():
        for task_index, full_address in enumerate(tasks):
            address = full_address.split(':')[0]
            args = ['--job_name=%s' % job_name,
                    '--task_index=%d' % task_index,
                    '--cpu_device_num=%d' % len(self._cpu_devices[address])]
            if address in self._gpu_devices:
                envs_cuda = []
            else:
                envs_cuda = ['CUDA_VISIBLE_DEVICES=""']
            if self.is_chief(address):
                json.dump(self.cluster_spec,
                          open(os.path.join(DEFAULT_WORKING_DIR, 'cluster_spec.json'), 'w+'))
                cmd = envs + envs_cuda + [sys.executable, '-m', module_name] + args
                # pylint: disable=subprocess-popen-preexec-fn
                proc = subprocess.Popen(' '.join(cmd), shell=True, preexec_fn=os.setsid)
                self.subprocesses.append(proc)
                # The append immediately follows the Popen so there is no window in which
                # termination could fail due to an incomplete proc list.
                logging.debug('$ local tf.server started at {}: job_name={} task_index={}'.format(
                    full_address, job_name, task_index))
            else:  # remote
                self.remote_pre_start_tf_server(address, tf_server_starter_filepath=module_file)
                file = os.path.join(DEFAULT_WORKING_DIR, os.path.basename(module_file))
                bash = envs + envs_cuda + ['python', '-u', file] + args
                logging.info("Launching tf.server on %s" % address)
                proc = self.remote_exec(bash, hostname=address)
                self.subprocesses.append(proc)
                # As above, the append immediately follows the process launch.
def _del(sess=_distributed_session):
    """Ensure the session is closed before cluster termination in the atexit stack."""
    sess.close()
    logging.debug('Closing session...')
def join(self):
    """Wait for all subprocesses of remote workers to complete."""
    logging.debug('Joining workers...')
    for t in self.threads:
        t.join()
def _remap_fetch(self, fetch):
    """
    Remap the user-provided fetches to the right list of fetches after graph transformations.

    Cases:
        * If the original fetch exists (i.e. it is not affected by the graph transformation), fetch the original.
        * Otherwise, for fetches that are train_ops, fetch them on all replicas;
        * for other fetches, only fetch them on the master replica.
            * For example, for partitioned vars, this corresponds to the concatenated as_tensor on the first replica.
    """
    _remap_element = self._remap_element
    fetch_type = type(fetch)
    fetch_name = fetch if isinstance(fetch, str) else fetch.name
    contract_fn = lambda fetched_vals: fetched_vals[0]  # noqa: E731
    try:
        transformed_fetch = [_remap_element(fetch_type, fetch_name)]
    except KeyError:
        master_replica_name = ops.prepend_name_scope(fetch_name, replica_prefix(0))
        master_replica_fetch = _remap_element(fetch_type, master_replica_name)
        polymorphic_dim = self._polymorphic_dim(master_replica_fetch)

        def is_train_op(op):
            # In TF2: the train_op is an AssignAddVariableOp
            # In TF1 (being deprecated): a no_op with a group of stateful ops as control dependencies
            # TODO(unless deprecating): make the checking as strict as possible
            return isinstance(op, ops.Operation) and (op.op_def.is_stateful or op.op_def.name == 'NoOp')

        if is_train_op(master_replica_fetch):
            transformed_fetch = [
                _remap_element(fetch_type, ops.prepend_name_scope(fetch_name, replica_prefix(i)))
                for i in range(self._graph_transformer.num_local_replicas)
            ]
            ####################################################################
            # # For Debugging Local Replicas
            ####################################################################
            # transformed_fetch = [
            #     self._graph_item.graph.as_graph_element('AutoDist-Replica-0/emb/part_0_take_grad')
            # ]
            # transformed_fetch = [
            #     _remap_element(ops.Tensor, ops.prepend_name_scope(
            #         'Mean:0',
            #         replica_prefix(i)))
            #     for i in range(self._graph_transformer.num_local_replicas)
            # ]
            # transformed_fetch = [_remap_element(ops.Tensor,
            #                                     ops.prepend_name_scope(
            #                                         'sampled_softmax_loss/embedding_lookup:0',
            #                                         replica_prefix(1)))]
            ####################################################################
            logging.debug('Fetch mapped from {} to {}'.format(fetch, transformed_fetch))
        elif polymorphic_dim:
            transformed_fetch = [
                _remap_element(fetch_type, ops.prepend_name_scope(fetch_name, replica_prefix(i)))
                for i in range(self._graph_transformer.num_local_replicas)
            ]
            contract_fn = lambda fetch_vals: np.concatenate(fetch_vals, axis=polymorphic_dim)  # noqa: E731
        else:
            transformed_fetch = [master_replica_fetch]
    return transformed_fetch, contract_fn
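# A minimal sketch of how the two contract functions above recombine per-replica fetch results
# (values are made up): a plain fetch keeps only the master replica's value, while a fetch with a
# polymorphic (e.g. batch-split) dimension is concatenated across the local replicas.
import numpy as np

per_replica_vals = [np.array([1, 2]), np.array([3, 4])]      # one entry per local replica
master_only = (lambda fetched_vals: fetched_vals[0])(per_replica_vals)
concatenated = np.concatenate(per_replica_vals, axis=0)
assert master_only.tolist() == [1, 2]
assert concatenated.tolist() == [1, 2, 3, 4]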