Example #1
    def forward_pre_hook(self, bn: BatchNorm, inputs: Tuple[Tensor, ...]) -> None:
        if not (bn.training and bn.track_running_stats):
            return

        # Don't track the running stats of this batch here; tracking is
        # deferred until the whole mini-batch has been accumulated.
        bn.track_running_stats = False
        bn.momentum_ = bn.momentum  # stash the original momentum so it can be restored
        bn.momentum = None

        # Skip if this forward pass is triggered by checkpoint recomputation.
        if is_recomputing():
            return

        input, = inputs

        # Detach from the autograd graph.
        input = input.detach()

        # Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d.
        dim = [0]
        dim.extend(range(2, input.dim()))

        bn.sum += input.sum(dim)
        bn.sum_squares += (input**2).sum(dim)

        # Number of elements per channel in this micro-batch.
        size = input.size().numel() // input.size(1)
        bn.counter += size

        # Enable the backward hook.
        self.tracked = True
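
A pre-hook like this is presumably attached to each BatchNorm layer with register_forward_pre_hook. Below is a minimal sketch of that wiring, assuming the accumulator buffers the hook writes to (sum, sum_squares, counter) are created up front; the helper name defer_batch_norm_stats and the tracker object are hypothetical, not part of the original:

import torch
from torch import nn

def defer_batch_norm_stats(module: nn.Module, tracker) -> None:
    # Hypothetical wiring: give every BatchNorm layer the accumulator
    # buffers used by forward_pre_hook above, then register the hook.
    for bn in module.modules():
        if isinstance(bn, nn.modules.batchnorm._BatchNorm):
            bn.register_buffer('sum', torch.zeros_like(bn.running_mean))
            bn.register_buffer('sum_squares', torch.zeros_like(bn.running_var))
            bn.counter = 0
            # PyTorch calls the hook as hook(module, inputs), which matches
            # the (bn, inputs) parameters of the bound method above.
            bn.register_forward_pre_hook(tracker.forward_pre_hook)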
Example #2
    def forward(self, input: Tensor) -> Tensor:  # type: ignore
        if not self.training:
            # In evaluation mode, normalize with the frozen running statistics.
            return F.batch_norm(
                input,
                running_mean=self.running_mean,
                running_var=self.running_var,
                weight=self.weight,
                bias=self.bias,
                training=False,
                momentum=0.0,
                eps=self.eps,
            )

        if not is_recomputing():
            # Track a micro-batch in training mode,
            # but not during a checkpoint recomputation.
            tracked_enough = self._track(input)

            # Update the running statistics for a mini-batch
            # if it has tracked enough micro-batches.
            if tracked_enough:
                self._commit()

        # Normalize this micro-batch with its own statistics and train the
        # parameters. With running_mean/running_var set to None, F.batch_norm
        # uses the batch statistics and leaves the running buffers untouched.
        return F.batch_norm(
            input,
            running_mean=None,
            running_var=None,
            weight=self.weight,
            bias=self.bias,
            training=True,
            momentum=0.0,
            eps=self.eps,
        )
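
The forward pass above relies on self._track and self._commit, which this example does not show. Here is a minimal sketch of what _commit could look like, assuming the accumulators from Example #1 (sum, sum_squares, counter) and a plain exponential moving average; the original's exact update rule may differ:

    def _commit(self) -> None:
        # Turn the accumulated per-channel sums into the mini-batch
        # mean and (biased) variance.
        mean = self.sum / self.counter
        var = self.sum_squares / self.counter - mean ** 2

        # Fold them into the running statistics; 0.1 is BatchNorm's
        # usual default when momentum is unset.
        m = self.momentum if self.momentum is not None else 0.1
        self.running_mean += m * (mean - self.running_mean)
        self.running_var += m * (var - self.running_var)

        # Reset the accumulators for the next mini-batch.
        self.sum.zero_()
        self.sum_squares.zero_()
        self.counter = 0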
Example #3
    def forward(self, input):
        # Record whether this pass runs under checkpointing or under the
        # recomputation that checkpointing triggers in the backward pass.
        logs.append((is_checkpointing(), is_recomputing()))
        return input
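
A module like this is typically used as a probe in tests: run it once under torchgpipe's checkpointing and the two flags flip between the forward pass and the recomputation. A hedged usage sketch, assuming torchgpipe's public GPipe API and that is_checkpointing/is_recomputing come from torchgpipe.checkpoint; the expected log contents are an assumption based on how checkpointing works, not output from the original:

import torch
from torch import nn
from torchgpipe import GPipe
from torchgpipe.checkpoint import is_checkpointing, is_recomputing

logs = []

class Log(nn.Module):
    def forward(self, input):
        logs.append((is_checkpointing(), is_recomputing()))
        return input

model = GPipe(nn.Sequential(Log()), balance=[1], devices=['cpu'],
              chunks=1, checkpoint='always')

output = model(torch.rand(1, 1, requires_grad=True))
output.sum().backward()  # the backward pass triggers recomputation

# Expected: the first entry is logged while checkpointing, the second
# while recomputing, i.e. logs == [(True, False), (False, True)].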