Example #1
    def _start_reduction_threads(self):
        num_buckets = len(self.bucket_sizes)
        self._reduction_queues = [queue.Queue() for _ in range(num_buckets)]
        self._reduction_threads = []
        self._reduction_streams = [[] for _ in range(num_buckets)]
        self._nccl_streams = []
        self._default_streams = []
        for dev_id in self.device_ids:
            with torch.cuda.device(dev_id):
                # TODO: don't assume we're on a default stream
                self._default_streams.append(torch.cuda.current_stream())
                self._nccl_streams.append(torch.cuda.Stream())
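        # one reduction thread and one set of per-device streams per gradient bucket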
        for reduction_queue, reduction_streams in zip(self._reduction_queues,
                                                      self._reduction_streams):
            for dev_id in self.device_ids:
                with torch.cuda.device(dev_id):
                    reduction_streams.append(torch.cuda.Stream())
            # We only use the first device for distributed reductions
            dist._register_stream(reduction_streams[0])
            group_id = dist.new_group()

            self._reduction_threads.append(
                threading.Thread(target=self._reduction_thread_fn,
                                 args=(reduction_queue, group_id,
                                       self.device_ids, reduction_streams,
                                       self._nccl_streams)))
            self._reduction_threads[-1].start()
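The thread target `_reduction_thread_fn` lives elsewhere in the same class, so this snippet only shows the bookkeeping that sets the threads up. For orientation, a minimal single-device sketch of such a worker follows; the queue payload (a list of gradient tensors plus a `None` shutdown sentinel) and the averaging step are assumptions for illustration, not the actual PyTorch implementation.

import torch
import torch.distributed as dist


def reduction_worker(reduction_queue, group_id, reduction_stream):
    # Drain gradient buckets pushed from backward hooks and all-reduce each
    # one on a dedicated stream so communication overlaps with the rest of
    # the backward pass.
    while True:
        bucket = reduction_queue.get()
        if bucket is None:  # assumed shutdown sentinel
            break
        with torch.cuda.stream(reduction_stream):
            flat = torch.cat([g.view(-1) for g in bucket])
            dist.all_reduce(flat, group=group_id)
            flat.div_(dist.get_world_size())
            offset = 0
            for g in bucket:  # scatter the averaged values back in place
                g.copy_(flat[offset:offset + g.numel()].view_as(g))
                offset += g.numel()

In the real module each bucket is first reduced across the local GPUs with NCCL and only then all-reduced across processes, which is why the constructor above keeps separate `_nccl_streams` and per-bucket `reduction_streams`.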
Example #2
    def _gossip_target(dist_config, gossip_flag, train_flag, gossip_params,
                       gossip_device_buffer):
        """ Gossip thread, which performs push-sum on model params """
        logger = make_logger(dist_config['rank'], dist_config['verbose'])

        # gossip on a dedicated CUDA stream when communicating GPU tensors;
        # CPU-side communication just reuses whatever stream is current
        if dist_config['comm_device'].type != 'cpu':
            gossip_stream = torch.cuda.Stream()
            dist._register_stream(gossip_stream)
        else:
            gossip_stream = torch.cuda.current_stream()

        gossip_flag.set()

        # gossip loop
        while True:
            train_flag.wait()
            logger.debug('received train-flag')
            try:
                with torch.cuda.stream(gossip_stream):
                    # construct gossip tensor
                    out_msg = _flatten_tensors(gossip_params)
                    dist.all_reduce(out_msg)
                    # update gossip variables with result
                    for r, g in zip(_unflatten_tensors(out_msg, gossip_params),
                                    gossip_device_buffer):
                        g.copy_(r, non_blocking=True)
            except RuntimeError as e:
                logger.warning('received runtime error {}'.format(e))
            finally:
                # give main thread go-ahead to read our gossip buffer
                train_flag.clear()
                gossip_flag.set()
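`_flatten_tensors` and `_unflatten_tensors` are small helpers that pack the parameter list into one contiguous buffer, so the loop above pays for a single `all_reduce` instead of one collective per tensor. A rough stand-in, assuming all tensors share a dtype and device (the project's own helpers may differ in detail):

import torch


def flatten_tensors(tensors):
    # Pack same-dtype tensors into one flat 1-D buffer.
    return torch.cat([t.contiguous().view(-1) for t in tensors])


def unflatten_tensors(flat, tensors):
    # Slice the flat buffer back into views shaped like the originals.
    outputs, offset = [], 0
    for t in tensors:
        numel = t.numel()
        outputs.append(flat.narrow(0, offset, numel).view_as(t))
        offset += numel
    return outputs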
Example #3
    def _start_reduction_threads(self):
        num_buckets = len(self.bucket_sizes)
        self._reduction_queues = [queue.Queue() for _ in range(num_buckets)]
        self._reduction_threads = []
        self._reduction_streams = [[] for _ in range(num_buckets)]
        self._nccl_streams = []
        self._default_streams = []
        for dev_id in self.device_ids:
            with torch.cuda.device(dev_id):
                # TODO: don't assume we're on a default stream
                self._default_streams.append(torch.cuda.current_stream())
                self._nccl_streams.append(torch.cuda.Stream())
        for reduction_queue, reduction_streams in zip(self._reduction_queues, self._reduction_streams):
            for dev_id in self.device_ids:
                with torch.cuda.device(dev_id):
                    reduction_streams.append(torch.cuda.Stream())
            # We only use the first device for distributed reductions
            dist._register_stream(reduction_streams[0])

            if dist._backend == dist.dist_backend.NCCL:
                group_id = dist.group.WORLD
            else:
                group_id = dist.new_group()

            self._reduction_threads.append(threading.Thread(
                target=self._reduction_thread_fn,
                args=(reduction_queue, group_id, self.device_ids, reduction_streams, self._nccl_streams)))
            self._reduction_threads[-1].daemon = True
            self._reduction_threads[-1].start()
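Compared with Example #1, this version marks the reduction threads as daemons so they never block interpreter shutdown, and it keeps the default world group when the NCCL backend is in use, since the legacy NCCL backend did not support creating extra process groups. A minimal sketch of the process-group setup such code assumes, written against the current torch.distributed API (address, rank and world size are placeholder values):

import torch.distributed as dist

# Placeholder single-process setup; real jobs take rank/world_size from the launcher.
dist.init_process_group(backend='nccl',
                        init_method='tcp://127.0.0.1:23456',
                        rank=0, world_size=1)

# Extra groups are fine for backends such as Gloo; the NCCL path above sticks
# to the default world group instead.
group = dist.group.WORLD if dist.get_backend() == 'nccl' else dist.new_group()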
Example #4
    def _gossip_target(dist_config, gossip_flag, train_flag, gossip_lock,
                       gossip_params, gossip_device_buffer, gossip_ps_weight,
                       gossip_ps_factor):
        """ Gossip thread, which performs push-sum on model params """
        logger = make_logger(dist_config['rank'], dist_config['verbose'])

        if dist_config['comm_device'].type != 'cpu':
            gossip_stream = torch.cuda.Stream()
            dist._register_stream(gossip_stream)
        else:
            gossip_stream = torch.cuda.current_stream()

        # init gossip instance
        if dist_config['push_sum']:
            gossiper = PushSum(_flatten_tensors(gossip_params),
                               device=dist_config['comm_device'],
                               graph=dist_config['graph'],
                               mixing=dist_config['mixing'],
                               rank=dist_config['rank'],
                               world_size=dist_config['world_size'],
                               logger=logger)
        else:
            gossiper = PushPull(_flatten_tensors(gossip_params),
                                device=dist_config['comm_device'],
                                graph=dist_config['graph'],
                                rank=dist_config['rank'],
                                world_size=dist_config['world_size'],
                                mixing=dist_config['mixing'],
                                logger=logger)
        dist_config['graph'] = gossiper._graph_manager
        dist_config['mixing'] = gossiper._mixing_manager
        dist_config['gossiper'] = gossiper
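        # share the gossiper's current mixing weight with the training thread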
        gossip_ps_factor[0] = gossiper.mixing_weights['lo']
        gossip_flag.set()

        # gossip loop
        while True:
            train_flag.wait()
            logger.debug('received train-flag')
            try:
                with torch.cuda.stream(gossip_stream):
                    # construct gossip tensor
                    out_msg = _flatten_tensors(gossip_params)
                    # gossip step
                    with gossip_lock:
                        in_msg, psw = gossiper.mix(out_msg,
                                                   gossip_ps_weight[0],
                                                   residual=True)
                        gossip_ps_factor[0] = gossiper.mixing_weights['lo']
                    # update gossip variables with residuals
                    gossip_ps_weight[0] = psw
                    for r, g in zip(_unflatten_tensors(in_msg, gossip_params),
                                    gossip_device_buffer):
                        g.copy_(r, non_blocking=True)
            except RuntimeError as e:
                logger.warning('received runtime error {}'.format(e))
                gossiper.clean_msg_buffers_()
                gossip_ps_weight[0] = -1
            finally:
                # give main thread go-ahead to read our gossip buffer
                train_flag.clear()
                gossip_flag.set()
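Both gossip threads hand results back to the trainer through the pair of events: the trainer publishes fresh parameters into `gossip_params`, sets `train_flag`, and later waits on `gossip_flag` before touching `gossip_device_buffer`. A simplified sketch of that handshake from the training thread's side; how the buffer is folded back in (a plain copy versus adding a push-sum residual) depends on the variant.

import threading

gossip_flag = threading.Event()   # set by the gossip thread when its buffer is ready
train_flag = threading.Event()    # set by the trainer when new params are published


def gossip_step(params, gossip_params, gossip_device_buffer):
    # Publish the current parameters for the gossip thread to mix.
    for p, g in zip(params, gossip_params):
        g.data.copy_(p.data)
    gossip_flag.clear()
    train_flag.set()        # wakes the `train_flag.wait()` in the gossip loop
    gossip_flag.wait()      # the gossip thread sets this in its `finally` block
    # Fold the mixed values back in; Example #4 returns residuals, so it adds.
    for p, r in zip(params, gossip_device_buffer):
        p.data.add_(r)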