def average_parameters(self):
    r"""
    Advance the averaging schedule by one training iteration.

    Parameters of ``self.module`` are averaged over ``self.process_group``
    (via ``utils.average_parameters``) only when both conditions hold:

    * ``self.step`` is no less than ``self.warmup_steps``, and
    * the number of steps taken since the end of warm-up, i.e.
      ``self.step - self.warmup_steps``, can be divided by ``self.period``.

    ``self.step`` is increased by 1 on every call, whether or not
    averaging took place.
    """
    warmup_done = self.step >= self.warmup_steps
    # Short-circuits during warm-up, so the modulo is only evaluated once
    # the warm-up stage is over.
    on_period_boundary = (
        warmup_done and (self.step - self.warmup_steps) % self.period == 0
    )
    if on_period_boundary:
        utils.average_parameters(self.module, self.process_group)
    # Counted unconditionally: one call corresponds to one iteration.
    self.step += 1
def average_parameters(self, step: int):
    r"""
    Average parameters for ``step`` when the schedule calls for it.

    Averaging of ``self.module`` over ``self.process_group`` (via
    ``utils.average_parameters``) runs when the given step is less than
    ``warmup_steps``, or when it can be divided by ``period``. For any
    other step this method does nothing.

    Args:
        step (int): Training step.
    """
    in_warmup = step < self.warmup_steps
    if not in_warmup and step % self.period != 0:
        # Past warm-up and between two period boundaries: skip this step.
        return
    utils.average_parameters(self.module, self.process_group)
def average_parameters(self, params):
    r"""
    Advance the hierarchical averaging schedule by one training iteration.

    While ``self.step`` is less than ``self.warmup_steps`` nothing is
    averaged. Afterwards, ``self._find_process_group()`` is asked for a
    ``(found, group)`` pair; when ``found`` is true, ``params`` are
    averaged over ``group`` via ``utils.average_parameters``.

    The group selection itself lives in ``_find_process_group``, which is
    not shown here. Per this method's original documentation, when
    ``step`` can be divided by multiple periods in the keys of
    ``period_process_group_dict``, only the largest period is used and
    its process group does the averaging -- confirm against the helper.

    ``self.step`` is increased by 1 on every call, whether or not
    averaging took place.
    """
    warmup_done = self.step >= self.warmup_steps
    # The group lookup is skipped entirely during the warm-up stage.
    matched, process_group = (
        self._find_process_group() if warmup_done else (False, None)
    )
    if matched:
        utils.average_parameters(iter(params), process_group)
    # Counted unconditionally: one call corresponds to one iteration.
    self.step += 1