def forward_pre_hook(self, bn: BatchNorm, inputs: Tuple[Tensor, ...]) -> None:
    if not (bn.training and bn.track_running_stats):
        return

    # Don't track the running stats of this batch. It is already deferred.
    bn.track_running_stats = False
    bn.momentum_ = bn.momentum
    bn.momentum = None

    # Skip if this forward pass is triggered by checkpoint recomputation.
    if is_recomputing():
        return

    input, = inputs

    # Detach from the autograd graph.
    input = input.detach()

    # Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d.
    dim = [0]
    dim.extend(range(2, input.dim()))

    bn.sum += input.sum(dim)
    bn.sum_squares += (input**2).sum(dim)

    # Number of elements per channel; integer division keeps the long counter's dtype.
    size = input.size().numel() // input.size(1)
    bn.counter += size

    # Enable the backward hook.
    self.tracked = True
def forward_hook(self, bn: BatchNorm, input: Tensor, output: Tensor) -> None:
    # Any internal state modified by this hook should not be visible to users.
    bn.track_running_stats = True
    try:
        # Restore the momentum stashed by the forward pre-hook.
        bn.momentum = bn.momentum_
    except AttributeError:
        # The pre-hook returned early (e.g. eval mode), so nothing was stashed.
        pass
def backward_hook(self, bn: BatchNorm, grad_input: Tensor, grad_output: Tensor) -> None:  # pragma: no cover
    if not self.tracked:
        return

    new_mean = bn.sum / bn.counter
    new_var = bn.sum_squares / bn.counter - new_mean**2

    # Calculate the exponential moving average here.
    bn.running_mean = bn.running_mean * (1 - bn.momentum) + new_mean * bn.momentum
    bn.running_var = bn.running_var * (1 - bn.momentum) + new_var * bn.momentum

    bn.sum.zero_()
    bn.sum_squares.zero_()
    bn.counter.zero_()

    # Disable the backward hook until the next forward pass.
    self.tracked = False
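For reference, the deferred sums reduce to the ordinary per-channel batch statistics (mean and biased variance). A quick standalone check, assuming a BatchNorm2d-shaped input:

import torch

x = torch.randn(4, 3, 8, 8)        # hypothetical (N, C, H, W) activation
dim = [0, 2, 3]                    # every dimension except the channel one
counter = x.numel() // x.size(1)   # N * H * W elements per channel
mean = x.sum(dim) / counter
var = (x ** 2).sum(dim) / counter - mean ** 2
assert torch.allclose(mean, x.mean(dim))
assert torch.allclose(var, x.var(dim, unbiased=False), atol=1e-5)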
def __call__(self, bn: BatchNorm) -> None:
    if not bn.track_running_stats or bn.momentum is None:
        # The given batch norm doesn't track running stats.
        return

    bn.register_buffer('sum', torch.zeros_like(bn.running_mean))
    bn.register_buffer('sum_squares', torch.zeros_like(bn.running_var))
    bn.register_buffer('counter', torch.tensor(0, dtype=torch.long))

    bn.register_forward_pre_hook(self.forward_pre_hook)
    bn.register_forward_hook(self.forward_hook)
    bn.register_backward_hook(self.backward_hook)
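As a usage sketch (the class name DeferredBatchNormHook and the toy model are assumptions, not taken from the source), the callable above would be applied once to every batch-norm layer of a network:

from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
hook = DeferredBatchNormHook()  # hypothetical name for the class defined above
for module in model.modules():
    if isinstance(module, nn.modules.batchnorm._BatchNorm):
        hook(module)  # registers the buffers and the three hooks on this layer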
def post_forward(module: _BatchNorm, input: Tensor, result: Tensor) -> None:
    if torch.is_grad_enabled():
        return
    module.track_running_stats = module._track_running_stats_backup
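post_forward restores track_running_stats from a backup attribute, which implies a matching pre-forward hook that stashes the flag and disables stat tracking during the no-grad pass of a checkpointed forward. A sketch of what that counterpart and its registration might look like (an assumption inferred from the restore above, not shown in the source):

import torch
from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm

def pre_forward(module: _BatchNorm, input: Tensor) -> None:
    if torch.is_grad_enabled():
        return
    # Back up the flag and stop updating running stats for this no-grad pass.
    module._track_running_stats_backup = module.track_running_stats
    module.track_running_stats = False

def patch_batchnorm(model: nn.Module) -> None:
    # Hypothetical registration helper: attach both hooks to every batch-norm layer.
    for child in model.modules():
        if isinstance(child, _BatchNorm):
            child.register_forward_pre_hook(pre_forward)
            child.register_forward_hook(post_forward)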