def _register_hooks(self):
    """Form parameter groups (when enabled) and hook every trainable parameter.

    When ``self._num_groups`` is positive, the trainable parameters are put
    into the name order broadcast from rank 0, split into
    ``self._num_groups`` groups with ``split_list``, and recorded in
    ``self._p_to_group`` (parameter -> group tuple) and
    ``self._group_counts`` (group tuple -> 0).

    Regardless of grouping, every trainable parameter gets a zero-filled
    ``grad``, is added to ``self._requires_update``, and has the hook built
    by ``self._make_hook`` registered on the node kept in
    ``self._grad_accs``.

    Raises:
        ValueError: if a local parameter's name is missing from the name
            list received from rank 0.
    """
    # Collect the trainable parameters once, in param_group order; both the
    # grouping phase and the hook phase work from this list.
    trainable = [
        p
        for param_group in self.param_groups
        for p in param_group['params']
        if p.requires_grad
    ]

    if self._num_groups > 0:
        # To ensure parameter order and group formation is consistent,
        # broadcast the name order from rank 0 and use it on every worker.
        names = [self._parameter_names.get(p) for p in trainable]
        names = broadcast_object(names, root_rank=0)

        # Map each name to its FIRST position (what list.index() returned)
        # so the sort key is a dict lookup instead of a linear scan.
        # NOTE(review): assumes names are hashable (strings or None).
        position = {}
        for idx, name in enumerate(names):
            position.setdefault(name, idx)

        def rank0_position(p):
            name = self._parameter_names.get(p)
            if name not in position:
                # Same exception type list.index() raised for a miss.
                raise ValueError(
                    'parameter name %r is not in the order broadcast '
                    'from rank 0' % (name,))
            return position[name]

        ordered = sorted(trainable, key=rank0_position)

        # Form groups; tuples so a group can be used as a dict key.
        groups = [tuple(chunk)
                  for chunk in split_list(ordered, self._num_groups)]
        for group in groups:
            for p in group:
                self._p_to_group[p] = group
            self._group_counts[group] = 0

    for p in trainable:
        # Start from an explicit zero gradient of the parameter's size.
        p.grad = p.data.new(p.size()).zero_()
        self._requires_update.add(p)
        # The node reached via grad_fn.next_functions[0][0] of the expanded
        # view is what this code treats as p's gradient accumulator.
        p_tmp = p.expand_as(p)
        grad_acc = p_tmp.grad_fn.next_functions[0][0]
        grad_acc.register_hook(self._make_hook(p))
        # Keep a reference to the node alongside its registered hook.
        self._grad_accs.append(grad_acc)
def sync(self):
    """Merge processed indices across workers, then sync the wrapped state.

    The ``processed_indices`` of every worker are gathered and combined with
    ``_union``; the combined value replaces the local entry in the state
    dict, which is then passed through ``broadcast_object`` and loaded back
    into ``self.value``.
    """
    # Get the set of processed indices from all workers.
    gathered_indices = allgather_object(self.value.processed_indices)
    merged_indices = _union(gathered_indices)

    # Replace local processed indices with the global ones.
    local_state = self.value.state_dict()
    local_state['processed_indices'] = merged_indices

    # Broadcast and load the state to make sure every worker is in sync.
    synced_state = broadcast_object(local_state)
    self.value.load_state_dict(synced_state)
def _register_hooks(self):
    """Form parameter groups (when configured) and hook trainable parameters.

    When ``self._groups`` is not ``None``, the trainable parameters are put
    into the name order broadcast from rank 0 (within
    ``self.process_set``) and grouped:

    * if ``self._groups`` is a list, each entry is filtered down to the
      trainable parameters it contains, and every trainable parameter not
      covered by any entry gets a group of its own;
    * otherwise ``self._groups`` is handed to ``split_list`` together with
      the ordered parameters.

    Groups are stored as tuples in ``self._p_to_group`` (parameter -> group)
    and ``self._group_counts`` (group -> 0).

    Regardless of grouping, every trainable parameter is added to
    ``self._requires_update`` and has the hook built by ``self._make_hook``
    registered on the node kept in ``self._grad_accs``.

    Raises:
        ValueError: if a local parameter's name is missing from the name
            list received from rank 0.
    """
    # Collect the trainable parameters once, in param_group order; both the
    # grouping phase and the hook phase work from this list.
    trainable = [
        p
        for param_group in self.param_groups
        for p in param_group['params']
        if p.requires_grad
    ]

    if self._groups is not None:
        # To ensure parameter order and group formation is consistent,
        # broadcast the name order from rank 0 and use it on every worker.
        names = [self._parameter_names.get(p) for p in trainable]
        names = broadcast_object(names, root_rank=0,
                                 process_set=self.process_set)

        # Map each name to its FIRST position (what list.index() returned)
        # so the sort key is a dict lookup instead of a linear scan.
        # NOTE(review): assumes names are hashable (strings or None).
        position = {}
        for idx, name in enumerate(names):
            position.setdefault(name, idx)

        def rank0_position(p):
            name = self._parameter_names.get(p)
            if name not in position:
                # Same exception type list.index() raised for a miss.
                raise ValueError(
                    'parameter name %r is not in the order broadcast '
                    'from rank 0' % (name,))
            return position[name]

        ordered = sorted(trainable, key=rank0_position)

        # Form groups.
        if isinstance(self._groups, list):
            # Identity (id) is used for membership; a set makes each check
            # constant-time instead of scanning a list of ids.
            trainable_ids = {id(p) for p in ordered}
            grouped_ids = set()
            groups = []
            for user_group in self._groups:
                # Drop members that are not trainable parameters.
                kept = [p for p in user_group if id(p) in trainable_ids]
                groups.append(kept)
                grouped_ids.update(id(p) for p in kept)
            # Every trainable parameter left ungrouped gets its own group.
            for p in ordered:
                if id(p) not in grouped_ids:
                    groups.append([p])
        else:
            groups = split_list(ordered, self._groups)

        # Tuples so a group can be used as a dict key.
        groups = [tuple(members) for members in groups]
        for group in groups:
            for p in group:
                self._p_to_group[p] = group
            self._group_counts[group] = 0

    for p in trainable:
        self._requires_update.add(p)
        # The node reached via grad_fn.next_functions[0][0] of the expanded
        # view is what this code treats as p's gradient accumulator.
        p_tmp = p.expand_as(p)
        grad_acc = p_tmp.grad_fn.next_functions[0][0]
        grad_acc.register_hook(self._make_hook(p))
        # Keep a reference to the node alongside its registered hook.
        self._grad_accs.append(grad_acc)
def broadcast(self):
    """Broadcast model state, optimizer state and the completed-batch count.

    The model's ``state_dict()`` and the optimizer are broadcast with
    ``root_rank=0``.  ``self.global_completed_batch_num`` is then replaced
    by the value returned from ``broadcast_object`` (called with its
    default root) under the name ``"GlobalCompletedBatchNum"``.
    """
    model_state = self._model.state_dict()
    broadcast_parameters(model_state, root_rank=0)

    broadcast_optimizer_state(self._optimizer, root_rank=0)

    local_batch_num = self.global_completed_batch_num
    self.global_completed_batch_num = broadcast_object(
        local_batch_num, name="GlobalCompletedBatchNum")
def sync(self):
    """Pass the wrapped value's state through a broadcast and load the result.

    ``self.value.state_dict()`` is handed to ``broadcast_object`` and
    whatever comes back is loaded into ``self.value``.
    """
    local_state = self.value.state_dict()

    # Broadcast and load the state to make sure every worker is in sync.
    synced_state = broadcast_object(local_state)
    self.value.load_state_dict(synced_state)