コード例 #1
0
ファイル: optimizer.py プロジェクト: winnerineast/byteps
    def __init__(self, model, byteps_opt, num_steps=10**6):
        """Build a ScheduledOptimizer that delegates gradient averaging across
        all workers to a byteps optimizer.

        Args:
            model: The model being trained. ByteScheduler registers its hooks
                on this object.
            byteps_opt: The optimizer that averages gradients and applies the
                parameter updates.
            num_steps: Upper bound on the number of training steps; ByteScheduler
                uses it to know when cross-iteration scheduling must stop.
        """
        self._model = model
        self._opt = byteps_opt
        logger = logging.getLogger("ByteScheduler")
        self._logger = logger
        logger.debug("byteps size {}, rank {}".format(size(), rank()))
        self._desc = "rank {}".format(rank())

        # Current step counter and the step at which scheduling ends.
        self._step, self._final_step = 0, num_steps

        # One lock per parameter; holding a lock blocks that parameter's
        # forward propagation.
        self._locks = {
            param: threading.Lock()
            for group in self.param_groups
            for param in group['params']
        }

        # Hooks are registered only when size() > 1 (i.e. presumably more
        # than one worker participates).
        if size() > 1:
            self._register_forward_hooks()
            self._register_hooks()

        # Background thread that polls whether tensor push-pulls have finished.
        self._event_queue = queue.Queue()
        poller = threading.Thread(target=self._poll, args=())
        self._poller = poller
        poller.start()
コード例 #2
0
def broadcast_parameters(params, root_rank):
    """
    Broadcasts the parameters from root rank to all other processes.
    Typical usage is to broadcast the `model.state_dict()`,
    `model.named_parameters()`, or `model.parameters()`.
    Arguments:
        params: One of the following:
            - dict of parameters to broadcast (e.g. `model.state_dict()`);
              entries are processed in sorted key order
            - list of parameters, or of `(name, parameter)` tuples, to
              broadcast
            - generator yielding either of the list element forms above
              (e.g. `model.parameters()` or `model.named_parameters()`)
        root_rank: The rank of the process from which parameters will be
                   broadcasted to all other processes.
    Raises:
        ValueError: if `params` is not a dict, a list, or a generator.
    """
    import types

    if isinstance(params, dict):
        # Sorting gives a deterministic order (presumably so that every
        # worker issues its push-pulls in the same sequence).
        params = sorted(params.items())
    elif isinstance(params, (list, types.GeneratorType)):
        # support both named_parameters() and regular parameters()
        params = [p if isinstance(p, tuple) else (None, p) for p in params]
    else:
        raise ValueError('invalid params of type: %s' % type(params))

    # Run synchronous broadcasts.
    for name, p in params:
        # Broadcast is implemented as push + pull in BytePS.
        # To make it a real broadcast, we set the non-root tensors all 0.
        if rank() != root_rank:
            p.fill_(0)
        # Remember to disable averaging because we are doing broadcast.
        if name is None:
            # Unnamed entries (plain parameters()) cannot be concatenated with
            # the "Parameter." prefix; push-pull them without an explicit
            # name (assumes byteps_push_pull's `name` is optional - confirm).
            handle = byteps_push_pull(p, average=False)
        else:
            handle = byteps_push_pull(p, average=False,
                                      name="Parameter." + name)
        synchronize(handle)