Example #1
import torch
from torch.nn.modules.instancenorm import _InstanceNorm
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter


class _LazyInstanceNorm(LazyModuleMixin, _InstanceNorm):
    def __init__(
        self,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            # affine and track_running_stats are hardcoded to False to
            # avoid creating tensors that will soon be overwritten.
            0,  # num_features placeholder, inferred from the first input
            eps,
            momentum,
            False,
            False,
            **factory_kwargs,
        )
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter(**factory_kwargs)
            self.bias = UninitializedParameter(**factory_kwargs)
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer(**factory_kwargs)
            self.running_var = UninitializedBuffer(**factory_kwargs)
            # num_batches_tracked is always an integer counter, so the dtype
            # entry is filtered out and only the device is forwarded.
            self.num_batches_tracked = torch.tensor(
                0,
                dtype=torch.long,
                **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
            )
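The same deferred-initialization pattern backs the public lazy wrappers. A quick usage sketch, assuming a PyTorch release that ships torch.nn.LazyInstanceNorm2d (1.9+): num_features is left at 0 until the first forward pass materializes it from the channel dimension.

import torch
import torch.nn as nn

norm = nn.LazyInstanceNorm2d(affine=True, track_running_stats=True)
x = torch.randn(4, 3, 8, 8)     # (N, C, H, W)
norm(x)                         # materializes parameters from C = 3
print(norm.num_features)        # 3
print(norm.weight.shape)        # torch.Size([3])
print(norm.running_mean.shape)  # torch.Size([3])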
Example #2
class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):
    def __init__(self,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True):
        super(_LazyBatchNorm, self).__init__(0, eps, momentum, affine,
                                             track_running_stats)
        if self.affine:
            self.weight = UninitializedParameter()
            self.bias = UninitializedParameter()
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer()
            self.running_var = UninitializedBuffer()

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features, ))
                self.bias.materialize((self.num_features, ))
            if self.track_running_stats:
                self.running_mean.materialize((self.num_features, ))
                self.running_var.materialize((self.num_features, ))
            self.reset_parameters()
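LazyModuleMixin registers initialize_parameters as a forward pre-hook, so materialization happens automatically on the first call. A minimal sketch of that flow, assuming the class above plus the usual imports (torch, _BatchNorm, LazyModuleMixin, UninitializedParameter, UninitializedBuffer); the Demo1d subclass is only illustrative, since _BatchNorm leaves _check_input_dim abstract:

class Demo1d(_LazyBatchNorm):
    # _BatchNorm._check_input_dim raises NotImplementedError, so a concrete
    # subclass must supply the dimensionality check (as BatchNorm1d does).
    def _check_input_dim(self, input):
        if input.dim() not in (2, 3):
            raise ValueError(f"expected 2D or 3D input (got {input.dim()}D)")

bn = Demo1d()
assert bn.has_uninitialized_params()   # weight/bias not materialized yet
bn(torch.randn(4, 10))                 # pre-hook infers num_features = 10
assert not bn.has_uninitialized_params()
print(bn.num_features)                 # 10
print(bn.weight.shape)                 # torch.Size([10])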
Example #3
class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):

    def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        super(_LazyBatchNorm, self).__init__(
            # affine and track_running_stats are hardcoded to False to
            # avoid creating tensors that will soon be overwritten.
            0, eps, momentum, False, False)
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter()
            self.bias = UninitializedParameter()
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer()
            self.running_var = UninitializedBuffer()
            self.num_batches_tracked = torch.tensor(0, dtype=torch.long)

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features,))
                self.bias.materialize((self.num_features,))
            if self.track_running_stats:
                self.running_mean.materialize((self.num_features,))
                self.running_var.materialize((self.num_features,))
            self.reset_parameters()
Example #4
class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):

    weight: UninitializedParameter  # type: ignore[assignment]
    bias: UninitializedParameter  # type: ignore[assignment]

    def __init__(self,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True,
                 device=None,
                 dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_LazyBatchNorm, self).__init__(
            # affine and track_running_stats are hardcoded to False to
            # avoid creating tensors that will soon be overwritten.
            0,
            eps,
            momentum,
            False,
            False,
            **factory_kwargs,
        )
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter(**factory_kwargs)
            self.bias = UninitializedParameter(**factory_kwargs)
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer(**factory_kwargs)
            self.running_var = UninitializedBuffer(**factory_kwargs)
            self.num_batches_tracked = torch.tensor(
                0,
                dtype=torch.long,
                **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore[override]
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features, ))
                self.bias.materialize((self.num_features, ))
            if self.track_running_stats:
                self.running_mean.materialize(
                    (self.num_features, ))  # type:ignore[union-attr]
                self.running_var.materialize(
                    (self.num_features, ))  # type:ignore[union-attr]
            self.reset_parameters()
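A short sketch of how the factory_kwargs plumbing above behaves through the public torch.nn.LazyBatchNorm1d (assuming PyTorch 1.9+, where the device/dtype arguments landed): the dtype reaches weight, bias, and the running stats, while num_batches_tracked stays torch.long because the dict comprehension filters dtype out.

import torch
import torch.nn as nn

bn = nn.LazyBatchNorm1d(dtype=torch.float64)
bn(torch.randn(8, 5, dtype=torch.float64))  # infers num_features = 5
print(bn.weight.dtype)                # torch.float64, from factory_kwargs
print(bn.running_mean.dtype)          # torch.float64
print(bn.num_batches_tracked.dtype)  # torch.int64, dtype filtered out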
Example #5
class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):

    def __init__(self,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True):
        # num_features is passed as 0 and inferred on the first forward pass.
        super(_LazyBatchNorm, self).__init__(0, eps, momentum, affine,
                                             track_running_stats)
        if self.affine:
            self.weight = UninitializedParameter()
            self.bias = UninitializedParameter()
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer()
            self.running_var = UninitializedBuffer()
Example #6
class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):

    def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        super(_LazyBatchNorm, self).__init__(
            # affine and track_running_stats are hardcoded to False to
            # avoid creating tensors that will soon be overwritten.
            0, eps, momentum, False, False)
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter()
            self.bias = UninitializedParameter()
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer()
            self.running_var = UninitializedBuffer()
            self.num_batches_tracked = torch.tensor(0, dtype=torch.long)
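Finally, a sketch of the track_running_stats branch in action via the public torch.nn.LazyBatchNorm1d (hypothetical session, PyTorch 1.9+): the num_batches_tracked counter created above starts at 0 and increments once per training-mode forward.

import torch
import torch.nn as nn

bn = nn.LazyBatchNorm1d()          # modules start in training mode
bn(torch.randn(16, 3))
bn(torch.randn(16, 3))
print(bn.num_batches_tracked)      # tensor(2)
print(bn.running_mean.shape)       # torch.Size([3])
bn.eval()
bn(torch.randn(16, 3))
print(bn.num_batches_tracked)      # still tensor(2); eval() freezes stats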