Example 1
    def remote_exec(self, args, hostname):
        """
        Execute bash commands remotely over SSH.

        Args:
            args (list): bash commands
            hostname (str): host name or address

        Returns:
            Process: process handle
        """
        cmd_list = []
        ssh_config = self._ssh_conf[hostname]
        if ssh_config.python_venv:
            cmd_list.append('%s;' % ssh_config.python_venv)
        if ssh_config.env:
            cmd_list.extend(
                ['%s=%s' % (k, v) for k, v in ssh_config.env.items()])
        full_cmd = ' '.join(cmd_list + args)

        remote_cmd = 'ssh -i {} -o StrictHostKeyChecking=no -tt -p {} {}@{} \'bash -c "{}"\' </dev/null' \
            .format(ssh_config.key_file, ssh_config.port, ssh_config.username, hostname, full_cmd)

        logging.debug('$ %s' % remote_cmd)

        if ENV.AUTODIST_DEBUG_REMOTE.val:
            return None

        # pylint: disable=subprocess-popen-preexec-fn
        proc = subprocess.Popen(remote_cmd, shell=True, preexec_fn=os.setsid)
        return proc
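
A minimal, self-contained sketch of how the SSH command above is assembled; the key file, port, user, host, venv activation, and arguments are hypothetical stand-ins for a real SSHConfig entry, and the sketch only prints the command instead of running it.

key_file, port, username, hostname = '~/.ssh/id_rsa', 22, 'ubuntu', '10.0.0.2'
cmd_list = ['source /opt/venv/bin/activate;', 'AUTODIST_MIN_LOG_LEVEL=ERROR']
args = ['python', '-u', 'server_starter.py', '--job_name=worker']
full_cmd = ' '.join(cmd_list + args)

remote_cmd = 'ssh -i {} -o StrictHostKeyChecking=no -tt -p {} {}@{} \'bash -c "{}"\' </dev/null' \
    .format(key_file, port, username, hostname, full_cmd)
print(remote_cmd)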
Example 2
 def _compile_strategy(self, strategy):
     logging.debug('Raw strategy: %s' % strategy)
     device_resolver = DeviceResolver(self._cluster)
     compiled_strategy = base.StrategyCompiler(self._original_graph_item) \
         .set_device_resolver(device_resolver.resolve_to_device_str) \
         .compile(strategy)
     logging.info('Compiled strategy: %s' % compiled_strategy)
     return compiled_strategy
Example 3
    def _apply(self, *args, **kwargs):
        """
        Apply replication to a graph.

        Returns:
            GraphItem
        """
        new_graph_item = self._graph_item
        if self._num_local_replicas >= 1:
            new_graph_item = self.replicate(new_graph_item)
            logging.debug('Successfully replicated operations')
        return new_graph_item
Example 4
 def wrapper(*args, **kwargs):
     # Assume grads_and_vars is an iterable of (gradient, variable) tuples.
     # Materialize it here: if it is a generator, we need to be able to iterate over it more than once.
     grads_and_vars = list(kwargs.get('grads_and_vars') or args[1])
     grads, variables = map(list, zip(*grads_and_vars))
     if _default_graph_item and kwargs.pop('update', True):
         _default_graph_item.extend_gradient_info(grads, variables)
         logging.debug('Registered grads: \n {} with targets: \n {}'.format(
             grads, variables))
     args = (args[0], grads_and_vars
             )  # Replace possible generator with definite list
     return fn(*args, **kwargs)
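
A tiny sketch of the unpacking step above: zip(*pairs) splits a list of (gradient, variable) tuples into two parallel lists; the string values are illustrative placeholders.

grads_and_vars = [('grad_0', 'var_0'), ('grad_1', 'var_1')]
grads, variables = map(list, zip(*grads_and_vars))
print(grads)      # ['grad_0', 'grad_1']
print(variables)  # ['var_0', 'var_1']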
Example 5
 def patch_optimizers():
     """Patch all instances of OptimizerV2 for us to store optimizer and gradient information."""
     q = deque(
         chain(OptimizerV2.__subclasses__(), OptimizerV1.__subclasses__()))
     while q:
         subclass = q.popleft()
         q.extend(list(subclass.__subclasses__()))
         subclass.__init__ = wrap_optimizer_init(subclass.__init__)
         subclass.apply_gradients = wrap_optimizer_apply_gradient(
             subclass.apply_gradients)
         logging.debug('Optimizer type: %s has been patched' %
                       subclass.__name__)
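
A framework-free sketch of the breadth-first subclass walk used above, with a toy class hierarchy standing in for OptimizerV1/OptimizerV2; it only prints the classes that would be patched.

from collections import deque

class Base:
    pass

class Child(Base):
    pass

class GrandChild(Child):
    pass

q = deque(Base.__subclasses__())
while q:
    subclass = q.popleft()
    q.extend(subclass.__subclasses__())
    print('Would patch:', subclass.__name__)
# Would patch: Child
# Would patch: GrandChild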
Example 6
def log_graph(graph, name):
    """
    Log the graph to TensorBoard.

    Args:
        graph: the TensorFlow graph to be plotted on TensorBoard
        name: the name of the graph
    """
    directory = os.path.join(autodist.const.DEFAULT_WORKING_DIR, "graphs")
    os.makedirs(directory, exist_ok=True)
    p = os.path.join(directory, name)
    writer.FileWriter(p, graph=graph)
    logging.debug('Graph summary written to: %s' % p)
Example 7
    def __init__(self, key, graph_item, config, cluster):
        super().__init__(key)
        self._graph_item = graph_item
        self._cluster = cluster

        self._replica_devices = {device_spec.DeviceSpecV2.from_string(s) for s in config}

        self._local_canonical_replica_devices = sorted({
            d.to_string() for d in self._replica_devices
            if self._cluster.get_local_address() == cluster.get_address_from_task(d.job, d.task)
        })
        logging.debug('Local replica devices: {}'.format(self._local_canonical_replica_devices))
        self._num_local_replicas = len(self._local_canonical_replica_devices)

        self._local_worker_id = self._cluster.get_local_worker_task_index()
        self._local_worker_device = '/job:worker/task:{}'.format(self._local_worker_id)
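
A small sketch of the device-string parsing relied on above; tf.DeviceSpec is the public TF2 alias of DeviceSpecV2, and the device string is an arbitrary example.

import tensorflow as tf

d = tf.DeviceSpec.from_string('/job:worker/task:1/device:GPU:0')
print(d.job, d.task, d.device_type, d.device_index)  # worker 1 GPU 0
print(d.to_string())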
Example 8
    def _initialize_synchronizers(self):
        self._synchronizers = {}
        for node in self._strategy.node_config:
            partitioner = getattr(node, 'partitioner')
            if partitioner:
                for part in node.part_config:
                    self._synchronizers[part.var_name] = \
                        Synchronizer.create(part.WhichOneof('synchronizer'),
                                            getattr(part, part.WhichOneof('synchronizer')))
            else:
                self._synchronizers[node.var_name] = \
                    Synchronizer.create(node.WhichOneof('synchronizer'),
                                        getattr(node, node.WhichOneof('synchronizer')))

        config = self._strategy.graph_config.replicas
        replica_devices = {
            device_spec.DeviceSpecV2.from_string(s)
            for s in config
        }
        replica_hosts = {
            self._cluster.get_address_from_task(d.job, d.task)
            for d in replica_devices
        }
        self._num_workers = len(replica_hosts)

        local_canonical_replica_devices = sorted({
            d.to_string()
            for d in replica_devices if self._cluster.get_local_address() ==
            self._cluster.get_address_from_task(d.job, d.task)
        })
        logging.debug('Local replica devices: {}'.format(
            local_canonical_replica_devices))
        self._num_local_replicas = len(local_canonical_replica_devices)

        local_worker_id = self._cluster.get_local_worker_task_index()
        local_worker_device = '/job:worker/task:{}'.format(local_worker_id)

        for synchronizer in self._synchronizers.values():
            synchronizer.assign_cluster_information(
                num_workers=self._num_workers,
                num_replicas=self._num_local_replicas,
                worker_device=local_worker_device,
                worker_id=local_worker_id,
                canonical_replica_devices=sorted(
                    {d.to_string()
                     for d in replica_devices}),
                is_chief=self._cluster.is_chief())
Example 9
    def wrapper(*args, **kwargs):
        # args[0] should be `self`, an instance of the optimizer class being wrapped
        containing_class = type(args[0])
        class_name = containing_class.__name__

        # For calls like super(AdamWeightDecay, self).__init__(*args, **kwargs), the containing_class.__name__
        # returns the current class (AdamWeightDecay) instead of the parent class (Adam).
        # Avoid patching this pattern by checking fn.__qualname__.
        if not fn.__qualname__.startswith(class_name):
            return fn(*args, **kwargs)
        if _default_graph_item and kwargs.pop('update', True):
            _default_graph_item.extend_optimizer_info(containing_class, *args,
                                                      **kwargs)
            logging.debug(
                'Registered optimizer: {} \nwith args: {} \nkwargs: {}'.format(
                    class_name, args, kwargs))
        return fn(*args, **kwargs)
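
A minimal sketch of the __qualname__ guard above: when a subclass __init__ calls super().__init__, the wrapped parent __init__ still sees an instance of the subclass, so fn.__qualname__ does not start with type(self).__name__ and the call falls through unrecorded. Parent and Child are illustrative classes.

class Parent:
    def __init__(self):
        pass

class Child(Parent):
    def __init__(self):
        super().__init__()

fn = Parent.__init__  # the function that would be wrapped
obj = Child()
print(fn.__qualname__, type(obj).__name__)              # Parent.__init__ Child
print(fn.__qualname__.startswith(type(obj).__name__))   # False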
Example 10
def _clean_stale_servers():
    # pylint: disable=anomalous-backslash-in-string
    cmd = """ps aux | awk "/{}/ && !/ssh/ && ! /{}/ && ! /{}/" | awk "{{print \$2}}" | xargs kill -9"""  # noqa: W605
    # Pipeline: list processes | match the stale local servers, excluding ssh and the current starter's pid/ppid | keep the pids | kill them
    cmd = cmd.format(
        os.path.splitext(os.path.basename(__file__))[0], os.getpid(),
        os.getppid())
    local_cmd = "bash -c '{}'".format(cmd)
    logging.debug('>>> {}'.format(local_cmd))
    try:
        output = subprocess.check_output(local_cmd,
                                         shell=True,
                                         stderr=subprocess.STDOUT)
        logging.debug('>>> {}'.format(output.decode('utf-8')))
    except subprocess.CalledProcessError as e:
        if e.returncode != 123:  # No stale process to kill
            raise
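
A safe sketch that renders the same pipeline with hypothetical values ('server_starter' stands in for the script name) so the quoting can be inspected; it only prints the command rather than executing it.

import os

cmd = """ps aux | awk "/{}/ && !/ssh/ && ! /{}/ && ! /{}/" | awk "{{print \$2}}" | xargs kill -9"""  # noqa: W605
rendered = cmd.format('server_starter', os.getpid(), os.getppid())
print("bash -c '{}'".format(rendered))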
Example 11
    def transform(self):
        """Call graph transformer to transform a graph item based on strategy and cluster."""
        logging.info(
            'Transforming the original graph to a distributed graph...')
        with context.graph_mode():
            # Ensure the transformation happens under graph mode, regardless of whether the outer context is eager or graph.
            graph_item = self.graph_item

            visualization_util.log_graph(graph=graph_item.graph,
                                         name='0-original')

            graph_item, self._strategy.node_config = VariablePartitioner.apply(
                self._strategy.node_config, graph_item)

            visualization_util.log_graph(graph=graph_item.graph,
                                         name='1-after-partition')

            # Create Synchronizers for each node in the strategy
            self._initialize_synchronizers()

            # Replicate the graph (both in-graph and between-graph)
            new_graph_item = Replicator.apply(
                config=self._strategy.graph_config.replicas,
                cluster=self._cluster,
                graph_item=graph_item)

            # Apply synchronizers
            if self._num_local_replicas >= 1:
                new_graph_item = self._in_graph_apply(new_graph_item)
                logging.debug(
                    'Successfully applied local in-graph replication')
                visualization_util.log_graph(new_graph_item.graph,
                                             '2-after-in-graph')

            if self._num_workers >= 1:
                new_graph_item = self._between_graph_apply(new_graph_item)
                logging.debug('Successfully applied between-graph replication')

            final_item = new_graph_item
            logging.info('Successfully built the distributed graph.')
            visualization_util.log_graph(graph=final_item.graph,
                                         name='3-transformed')

        return final_item
Example 12
 def _parse_node(self, node, num_nodes):
     host_address = node['address']
     if is_loopback_address(host_address) and num_nodes > 1:
         raise ValueError(
             "Can't (currently) use a loopback address when there are multiple nodes."
         )
     if node.get('chief') or num_nodes == 1:
         # 2 cases for marking this node as chief:
         # 1) The node was marked as chief
         # 2) If there is only one node, it is chief by default
         logging.info("Chief: %s" % host_address)
         self.__chief_address = host_address
     host_cpu = DeviceSpec(host_address, device_index=0)
     self._add_device(host_cpu)
     # handle any remaining CPUs only when no GPUs are available
     if len(node.get('gpus', [])) == 0:
         for cpu_index in set(sorted(node.get('cpus', []))) - {0}:
             cpu = DeviceSpec(host_address, host_cpu, DeviceType.CPU,
                              cpu_index)
             self._add_device(cpu)
     # handle GPUs
     for gpu_index in set(sorted(node.get('gpus', []))):
         gpu = DeviceSpec(host_address, host_cpu, DeviceType.GPU, gpu_index)
         self._add_device(gpu)
     self.__ssh_group[host_address] = node.get('ssh_config')
     if self.__ssh_group[host_address] is None and self.__chief_address != host_address:
         raise ValueError(
             "Need to define SSH groups for all non-chief nodes.")
     # handle network bandwidth (optional)
     if node.get('network_bandwidth'):
         self.__network_bandwidth[host_address] = node.get(
             'network_bandwidth')
     else:
         logging.debug(
             'The bandwidth for {} is undefined and set as default (1 GBE). '
             'Caution: AutoStrategy might be inaccurate.'.format(
                 host_address))
         self.__network_bandwidth[host_address] = 1
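
A hypothetical resource-spec entry of the shape _parse_node expects; the field names are inferred from the accesses above and the values are purely illustrative.

node = {
    'address': '192.168.0.10',
    'chief': True,              # optional; a single-node spec is chief by default
    'cpus': [0],                # extra CPUs are only used when no GPUs are listed
    'gpus': [0, 1],
    'ssh_config': 'group_1',    # required for non-chief nodes
    'network_bandwidth': 10,    # optional; defaults to 1 when omitted
}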
Example 13
 def terminate(self):
     """Terminate."""
     logging.debug('Terminating cluster...')
     for p in self.subprocesses:
         os.killpg(os.getpgid(p.pid), signal.SIGTERM)
Example 14
    def start(self):
        """
        Start tf.servers on all nodes.

        Note that this only runs (and only should run) on the chief node.
        """
        # pylint: disable=import-outside-toplevel
        from autodist.utils import server_starter

        # The atexit registration should be placed
        #   - before the start begins
        #     (to ensure clean termination if the start fails partway through); and
        #   - in the same module as the start
        #     (following the Python assumption that lower-level modules are
        #     normally imported before higher-level ones and must therefore
        #     be cleaned up later).
        atexit.register(self.terminate)
        envs = {ENV.AUTODIST_MIN_LOG_LEVEL.name: 'ERROR'}
        envs = ['{}={}'.format(k, v) for k, v in envs.items()]
        module_name = server_starter.__name__
        module_file = server_starter.__file__

        for job_name, tasks in self.cluster_spec.items():
            for task_index, full_address in enumerate(tasks):
                address = full_address.split(':')[0]
                args = [
                    '--job_name=%s' % job_name,
                    '--task_index=%d' % task_index,
                    '--cpu_device_num=%d' % len(self._cpu_devices[address])
                ]
                if address in self._gpu_devices:
                    envs_cuda = []
                else:
                    envs_cuda = ['CUDA_VISIBLE_DEVICES=""']
                if self.is_chief(address):
                    json.dump(
                        self.cluster_spec,
                        open(
                            os.path.join(DEFAULT_WORKING_DIR,
                                         'cluster_spec.json'), 'w+'))
                    cmd = envs + envs_cuda + [
                        sys.executable, '-m', module_name
                    ] + args
                    # pylint: disable=subprocess-popen-preexec-fn
                    proc = subprocess.Popen(' '.join(cmd),
                                            shell=True,
                                            preexec_fn=os.setsid)
                    self.subprocesses.append(proc)
                    # The append immediately follows the Popen so that terminate()
                    # never runs against an empty subprocess list.
                    logging.debug(
                        '$ local tf.server started at {}: job_name={} task_index={}'
                        .format(full_address, job_name, task_index))
                else:  # remote
                    self.remote_pre_start_tf_server(
                        address, tf_server_starter_filepath=module_file)
                    file = os.path.join(DEFAULT_WORKING_DIR,
                                        os.path.basename(module_file))
                    bash = envs + envs_cuda + ['python', '-u', file] + args
                    logging.info("Launching tf.server on %s" % address)
                    proc = self.remote_exec(bash, hostname=address)
                    # Append immediately after starting the remote process so that
                    # terminate() never runs against an empty subprocess list.
                    self.subprocesses.append(proc)
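
A small sketch of the cluster_spec iteration above, using a hypothetical two-worker spec (job name mapped to a list of "host:port" strings); it only prints what each loop iteration would operate on.

cluster_spec = {'worker': ['192.168.0.10:5000', '192.168.0.11:5000']}

for job_name, tasks in cluster_spec.items():
    for task_index, full_address in enumerate(tasks):
        address = full_address.split(':')[0]
        print('job_name={} task_index={} address={}'.format(job_name, task_index, address))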
Example 15
 def _del(sess=_distributed_session):
     """Enforce the sess to be closed before the cluster termination in the atexit stack."""
     sess.close()
     logging.debug('Closing session...')
Example 16
 def join(self):
     """Wait for all subprocesses of remote workers to be completed."""
     logging.debug('Joining workers...')
     for t in self.threads:
         t.join()
Example 17
    def _remap_fetch(self, fetch):
        """
        Remap the user-provided fetches to the right list of fetches after graph transformations.

        Cases:
            * If the original fetch exists (i.e. it is not affected by the graph transformation), fetch the original.
            * Otherwise, for fetches that are train_ops, fetch them on all replicas;
            * for other fetches, only fetch them on the master replica.
                * For example, for partitioned vars, this corresponds to the concatenated as_tensor on the first replica.
        """
        _remap_element = self._remap_element
        fetch_type = type(fetch)
        fetch_name = fetch if isinstance(fetch, str) else fetch.name
        contract_fn = lambda fetched_vals: fetched_vals[0]  # noqa: E731
        try:
            transformed_fetch = [_remap_element(fetch_type, fetch_name)]
        except KeyError:
            master_replica_name = ops.prepend_name_scope(
                fetch_name, replica_prefix(0))
            master_replica_fetch = _remap_element(fetch_type,
                                                  master_replica_name)
            polymorphic_dim = self._polymorphic_dim(master_replica_fetch)

            def is_train_op(op):
                # In TF2: train_op as AssignAddVariableOp
                # In TF1 (being deprecated): no_op with a group of stateful ops as control dependencies
                # TODO(unless deprecating): make the checking as strict as possible
                return isinstance(
                    op, ops.Operation) and (op.op_def.is_stateful
                                            or op.op_def.name == 'NoOp')

            if is_train_op(master_replica_fetch):
                transformed_fetch = [
                    _remap_element(
                        fetch_type,
                        ops.prepend_name_scope(fetch_name, replica_prefix(i)))
                    for i in range(self._graph_transformer.num_local_replicas)
                ]
                ####################################################################
                # # For Debugging Local Replicas
                ####################################################################
                # transformed_fetch = [
                #     self._graph_item.graph.as_graph_element('AutoDist-Replica-0/emb/part_0_take_grad')
                # ]
                # transformed_fetch = [
                #     _remap_element(ops.Tensor, ops.prepend_name_scope(
                #         'Mean:0',
                #         replica_prefix(i)))
                #     for i in range(self._graph_transformer.num_local_replicas)
                # ]
                # transformed_fetch = [_remap_element(ops.Tensor,
                #     ops.prepend_name_scope(
                #         'sampled_softmax_loss/embedding_lookup:0',
                #         replica_prefix(1)
                #     )
                # )]
                ####################################################################
                logging.debug('Fetch mapped from {} to {}'.format(
                    fetch, transformed_fetch))
            elif polymorphic_dim:
                transformed_fetch = [
                    _remap_element(
                        fetch_type,
                        ops.prepend_name_scope(fetch_name, replica_prefix(i)))
                    for i in range(self._graph_transformer.num_local_replicas)
                ]
                contract_fn = lambda fetch_vals: np.concatenate(
                    fetch_vals, axis=polymorphic_dim)  # noqa: E731
            else:
                transformed_fetch = [master_replica_fetch]
        return transformed_fetch, contract_fn
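
A small sketch of the name remapping used above: the fetch name is prefixed with a replica scope before the element lookup. ops.prepend_name_scope is TensorFlow's helper; replica_prefix below is a hypothetical stand-in whose format merely mirrors the 'AutoDist-Replica-0/...' names in the commented-out debug code.

from tensorflow.python.framework import ops

def replica_prefix(i):
    # hypothetical stand-in for AutoDist's replica_prefix helper
    return 'AutoDist-Replica-{}'.format(i)

print(ops.prepend_name_scope('Mean:0', replica_prefix(0)))
# e.g. AutoDist-Replica-0/Mean:0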