Example #1
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter
class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):
    def __init__(self,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True):
        super(_LazyBatchNorm, self).__init__(0, eps, momentum, affine,
                                             track_running_stats)
        if self.affine:
            self.weight = UninitializedParameter()
            self.bias = UninitializedParameter()
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer()
            self.running_var = UninitializedBuffer()

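    # Deferred on purpose: the base class calls reset_parameters() during
    # __init__, while num_features is still the placeholder value 0, so
    # the real reset only runs after the parameters are materialized.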
    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

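    # Called by LazyModuleMixin's forward pre-hook on the first forward
    # pass; infers num_features from the channel dimension of the input.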
    def initialize_parameters(self, input) -> None:  # type: ignore
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features, ))
                self.bias.materialize((self.num_features, ))
            if self.track_running_stats:
                self.running_mean.materialize((self.num_features, ))
                self.running_var.materialize((self.num_features, ))
            self.reset_parameters()
Example #2
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter
class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):

    def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        super(_LazyBatchNorm, self).__init__(
            # affine and track_running_stats are hardcoded to False to
            # avoid creating tensors that will soon be overwritten.
            0, eps, momentum, False, False)
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter()
            self.bias = UninitializedParameter()
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer()
            self.running_var = UninitializedBuffer()
            self.num_batches_tracked = torch.tensor(0, dtype=torch.long)

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features,))
                self.bias.materialize((self.num_features,))
            if self.track_running_stats:
                self.running_mean.materialize((self.num_features,))
                self.running_var.materialize((self.num_features,))
            self.reset_parameters()
Example #3
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter
class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):

    weight: UninitializedParameter  # type: ignore[assignment]
    bias: UninitializedParameter  # type: ignore[assignment]

    def __init__(self,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True,
                 device=None,
                 dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_LazyBatchNorm, self).__init__(
            # affine and track_running_stats are hardcoded to False to
            # avoid creating tensors that will soon be overwritten.
            0,
            eps,
            momentum,
            False,
            False,
            **factory_kwargs,
        )
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter(**factory_kwargs)
            self.bias = UninitializedParameter(**factory_kwargs)
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer(**factory_kwargs)
            self.running_var = UninitializedBuffer(**factory_kwargs)
            self.num_batches_tracked = torch.tensor(
                0,
                dtype=torch.long,
                **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore[override]
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features, ))
                self.bias.materialize((self.num_features, ))
            if self.track_running_stats:
                self.running_mean.materialize(
                    (self.num_features, ))  # type:ignore[union-attr]
                self.running_var.materialize(
                    (self.num_features, ))  # type:ignore[union-attr]
            self.reset_parameters()
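
This final revision adds the device/dtype factory kwargs: they are stored on the uninitialized placeholders, so the tensors produced by materialize() land on the requested device and dtype (dtype is stripped for num_batches_tracked, which must stay integral). A hedged sketch under the same illustrative-subclass assumption as before:

class LazyBatchNorm1d(_LazyBatchNorm):    # illustrative subclass, as in Example #1
    def _check_input_dim(self, input):
        if input.dim() not in (2, 3):
            raise ValueError(f'expected 2D or 3D input (got {input.dim()}D input)')

bn = LazyBatchNorm1d(dtype=torch.float64)
bn(torch.randn(8, 16, dtype=torch.float64))
print(bn.weight.dtype)                    # torch.float64, carried by the placeholder
print(bn.running_mean.dtype)              # torch.float64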