예제 #1
0
    def _register_hooks(self):
        """Register a gradient hook on every parameter that requires grad.

        If ``self._num_groups > 0``, the trainable parameters are first put in
        the order broadcast from rank 0 (so every worker forms identical
        groups), split into ``self._num_groups`` groups with ``split_list``,
        and recorded in ``self._p_to_group`` (parameter -> group tuple) and
        ``self._group_counts`` (group tuple -> 0).

        Then every parameter with ``requires_grad`` gets a zero-filled
        ``grad``, is added to ``self._requires_update``, and has the hook
        returned by ``self._make_hook(p)`` registered on its autograd
        gradient-accumulator node, which is kept in ``self._grad_accs``.

        Raises:
            ValueError: if a local parameter's name is missing from the name
                list broadcast from rank 0 (e.g. the model differs between
                ranks).
        """
        if self._num_groups > 0:
            # Get list of parameters with grads
            p_list = [
                p
                for param_group in self.param_groups
                for p in param_group['params']
                if p.requires_grad
            ]

            # To ensure parameter order and group formation is consistent,
            # broadcast p_list order from rank 0 and use it on every worker.
            p_list_names = [self._parameter_names.get(p) for p in p_list]
            p_list_names = broadcast_object(p_list_names, root_rank=0)

            # Position of each name in rank 0's list (first occurrence wins,
            # matching list.index). A dict lookup keeps the sort O(n log n)
            # instead of the O(n^2) of calling list.index for every key.
            name_to_pos = {}
            for pos, name in enumerate(p_list_names):
                name_to_pos.setdefault(name, pos)

            def _rank0_position(p):
                name = self._parameter_names.get(p)
                try:
                    return name_to_pos[name]
                except KeyError:
                    raise ValueError(
                        '{!r} is not in the parameter list broadcast from '
                        'rank 0'.format(name)) from None

            p_list = sorted(p_list, key=_rank0_position)

            # Form groups
            p_groups = [tuple(g) for g in split_list(p_list, self._num_groups)]
            for group in p_groups:
                for p in group:
                    self._p_to_group[p] = group
                self._group_counts[group] = 0

        for param_group in self.param_groups:
            for p in param_group['params']:
                if p.requires_grad:
                    # Start from a zero gradient of p's size.
                    # NOTE(review): presumably so p.grad is never None when the
                    # hook / allreduce runs — confirm before removing.
                    p.grad = p.data.new(p.size()).zero_()
                    self._requires_update.add(p)
                    # expand_as yields a non-leaf view whose grad_fn's first
                    # input is the gradient accumulator of the leaf p.
                    p_tmp = p.expand_as(p)
                    grad_acc = p_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_hook(p))
                    # Keep a reference to the accumulator node; presumably so
                    # it (and its registered hook) is not garbage-collected.
                    self._grad_accs.append(grad_acc)
예제 #2
0
    def sync(self):
        """Merge processed indices across workers, then sync state from root.

        Each worker contributes ``self.value.processed_indices`` to
        ``allgather_object``; the gathered result is merged with ``_union``
        and written into this worker's state dict under
        ``'processed_indices'``. That dict is passed through
        ``broadcast_object`` and the returned state is loaded back into
        ``self.value`` so that all workers are in sync.
        """
        gathered = allgather_object(self.value.processed_indices)
        merged_indices = _union(gathered)

        # Swap in the global view of processed indices before broadcasting.
        local_state = self.value.state_dict()
        local_state['processed_indices'] = merged_indices

        synced_state = broadcast_object(local_state)
        self.value.load_state_dict(synced_state)
예제 #3
0
    def _register_hooks(self):
        """Register a gradient hook on every parameter that requires grad.

        If ``self._groups`` is not None, the trainable parameters are first
        put in the order broadcast from rank 0 over ``self.process_set`` (so
        every worker forms identical groups) and then grouped:

        * if ``self._groups`` is a list of parameter collections, each
          collection becomes one group (keeping only members that require
          grad, in the collection's own order), and every trainable parameter
          not named in any collection becomes a single-element group;
        * otherwise ``split_list(p_list, self._groups)`` forms the groups
          (presumably ``self._groups`` is then a group count — confirm).

        Groups are recorded in ``self._p_to_group`` (parameter -> group tuple)
        and ``self._group_counts`` (group tuple -> 0).

        Then every parameter with ``requires_grad`` is added to
        ``self._requires_update`` and has the hook returned by
        ``self._make_hook(p)`` registered on its autograd gradient-accumulator
        node, which is kept in ``self._grad_accs``.

        Raises:
            ValueError: if a local parameter's name is missing from the name
                list broadcast from rank 0 (e.g. the model differs between
                ranks).
        """
        if self._groups is not None:
            # Get list of parameters with grads
            p_list = [
                p
                for param_group in self.param_groups
                for p in param_group['params']
                if p.requires_grad
            ]

            # To ensure parameter order and group formation is consistent,
            # broadcast p_list order from rank 0 and use it on every worker.
            p_list_names = [self._parameter_names.get(p) for p in p_list]
            p_list_names = broadcast_object(p_list_names,
                                            root_rank=0,
                                            process_set=self.process_set)

            # Position of each name in rank 0's list (first occurrence wins,
            # matching list.index). A dict lookup keeps the sort O(n log n)
            # instead of the O(n^2) of calling list.index for every key.
            name_to_pos = {}
            for pos, name in enumerate(p_list_names):
                name_to_pos.setdefault(name, pos)

            def _rank0_position(p):
                name = self._parameter_names.get(p)
                try:
                    return name_to_pos[name]
                except KeyError:
                    raise ValueError(
                        '{!r} is not in the parameter list broadcast from '
                        'rank 0'.format(name)) from None

            p_list = sorted(p_list, key=_rank0_position)

            # Form groups
            if isinstance(self._groups, list):
                # Set of ids for O(1) membership tests (was a list scan).
                p_list_ids = {id(p) for p in p_list}
                p_groups = []
                grouped_ids = set()
                for group in self._groups:
                    members = [p for p in group if id(p) in p_list_ids]
                    p_groups.append(members)
                    grouped_ids.update(id(p) for p in members)
                # NOTE(review): a parameter listed in more than one user group
                # ends up in several groups while _p_to_group keeps only the
                # last one; presumably the caller validates groups are
                # disjoint — confirm.
                # Trainable parameters not in any user group each form their
                # own single-element group, in rank-0 order.
                p_groups.extend([p] for p in p_list if id(p) not in grouped_ids)
            else:
                p_groups = split_list(p_list, self._groups)

            p_groups = [tuple(g) for g in p_groups]
            for group in p_groups:
                for p in group:
                    self._p_to_group[p] = group
                self._group_counts[group] = 0

        for param_group in self.param_groups:
            for p in param_group['params']:
                if p.requires_grad:
                    self._requires_update.add(p)
                    # expand_as yields a non-leaf view whose grad_fn's first
                    # input is the gradient accumulator of the leaf p.
                    p_tmp = p.expand_as(p)
                    grad_acc = p_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_hook(p))
                    # Keep a reference to the accumulator node; presumably so
                    # it (and its registered hook) is not garbage-collected.
                    self._grad_accs.append(grad_acc)
예제 #4
0
 def broadcast(self):
     """Broadcast model, optimizer and completed-batch counter to all workers.

     Model parameters and optimizer state are broadcast from rank 0; the
     batch counter goes through ``broadcast_object`` with its default root
     (presumably rank 0 — confirm) under the name "GlobalCompletedBatchNum",
     and the received value replaces ``self.global_completed_batch_num``.
     """
     root = 0
     model_state = self._model.state_dict()
     broadcast_parameters(model_state, root_rank=root)
     broadcast_optimizer_state(self._optimizer, root_rank=root)
     synced_count = broadcast_object(self.global_completed_batch_num,
                                     name="GlobalCompletedBatchNum")
     self.global_completed_batch_num = synced_count
예제 #5
0
파일: state.py 프로젝트: rongou/horovod
    def sync(self):
        """Replace this worker's state with the broadcast copy.

        The current ``self.value.state_dict()`` is sent through
        ``broadcast_object`` (called with its default arguments) and the
        returned state is loaded back into ``self.value``, so that all
        workers are in sync.
        """
        current = self.value.state_dict()
        synced = broadcast_object(current)
        self.value.load_state_dict(synced)