Example #1
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter


class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):
    def __init__(self,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True):
        super(_LazyBatchNorm, self).__init__(0, eps, momentum, affine,
                                             track_running_stats)
        if self.affine:
            self.weight = UninitializedParameter()
            self.bias = UninitializedParameter()
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer()
            self.running_var = UninitializedBuffer()

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features, ))
                self.bias.materialize((self.num_features, ))
            if self.track_running_stats:
                self.running_mean.materialize((self.num_features, ))
                self.running_var.materialize((self.num_features, ))
            self.reset_parameters()
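
For reference, a minimal usage sketch of the public subclass built on top of this lazy base class (assuming a PyTorch version where nn.LazyBatchNorm2d is available; the input shape is illustrative):

import torch
from torch import nn

bn = nn.LazyBatchNorm2d()      # num_features is not known yet
x = torch.randn(2, 3, 8, 8)    # channel dimension is input.shape[1] == 3
y = bn(x)                      # the first forward materializes weight, bias and the running stats
print(bn.num_features)         # 3
print(type(bn).__name__)       # BatchNorm2d: the module has become a regular batch norm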
Example #2
class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):

    def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        super(_LazyBatchNorm, self).__init__(
            # affine and track_running_stats are hardcoded to False to
            # avoid creating tensors that will soon be overwritten.
            0, eps, momentum, False, False)
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter()
            self.bias = UninitializedParameter()
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer()
            self.running_var = UninitializedBuffer()
            self.num_batches_tracked = torch.tensor(0, dtype=torch.long)

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features,))
                self.bias.materialize((self.num_features,))
            if self.track_running_stats:
                self.running_mean.materialize((self.num_features,))
                self.running_var.materialize((self.num_features,))
            self.reset_parameters()
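
The num_batches_tracked buffer added in this variant behaves exactly as in an eager batch norm once the module has been materialized; a hedged sketch with the public nn.LazyBatchNorm2d (shapes are illustrative):

import torch
from torch import nn

bn = nn.LazyBatchNorm2d()        # training mode by default
bn(torch.randn(2, 3, 4, 4))      # materializes parameters and updates running stats
print(bn.num_batches_tracked)    # tensor(1)
bn(torch.randn(2, 3, 4, 4))
print(bn.num_batches_tracked)    # tensor(2)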
Example #3
import torch
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.modules.linear import Linear
from torch.nn.parameter import UninitializedParameter


class LazyLinear(LazyModuleMixin, Linear):
    r"""A :class:`torch.nn.Linear` module where `in_features` is inferred.

    In this module, the ``weight`` and ``bias`` are instances of
    :class:`torch.nn.UninitializedParameter`. They are initialized after the first call to
    ``forward``, at which point the module becomes a regular :class:`torch.nn.Linear` module.
    The ``in_features`` argument of the :class:`Linear` is inferred from ``input.shape[-1]``.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True``, the values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                :math:`k = \frac{1}{\text{in\_features}}`
    """

    cls_to_become = Linear  # type: ignore[assignment]
    weight: UninitializedParameter
    bias: UninitializedParameter  # type: ignore[assignment]

    def __init__(self,
                 out_features: int,
                 bias: bool = True,
                 device=None,
                 dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        # bias is hardcoded to False to avoid creating tensor
        # that will soon be overwritten.
        super().__init__(0, 0, False)
        self.weight = UninitializedParameter(**factory_kwargs)
        self.out_features = out_features
        if bias:
            self.bias = UninitializedParameter(**factory_kwargs)

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.in_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore[override]
        if self.has_uninitialized_params():
            with torch.no_grad():
                self.in_features = input.shape[-1]
                self.weight.materialize((self.out_features, self.in_features))
                if self.bias is not None:
                    self.bias.materialize((self.out_features, ))
                self.reset_parameters()
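
To illustrate the inference step in initialize_parameters, a minimal usage sketch against the public nn.LazyLinear API (the sizes are illustrative assumptions):

import torch
from torch import nn

layer = nn.LazyLinear(out_features=8)   # in_features is unknown at construction time
x = torch.randn(4, 16)
y = layer(x)                            # the first forward sets in_features = x.shape[-1] == 16
print(layer.weight.shape)               # torch.Size([8, 16])
print(layer)                            # Linear(in_features=16, out_features=8, bias=True)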
Example #4
class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):

    weight: UninitializedParameter  # type: ignore[assignment]
    bias: UninitializedParameter  # type: ignore[assignment]

    def __init__(self,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True,
                 device=None,
                 dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_LazyBatchNorm, self).__init__(
            # affine and track_running_stats are hardcoded to False to
            # avoid creating tensors that will soon be overwritten.
            0,
            eps,
            momentum,
            False,
            False,
            **factory_kwargs,
        )
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter(**factory_kwargs)
            self.bias = UninitializedParameter(**factory_kwargs)
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer(**factory_kwargs)
            self.running_var = UninitializedBuffer(**factory_kwargs)
            self.num_batches_tracked = torch.tensor(
                0,
                dtype=torch.long,
                **{k: v
                   for k, v in factory_kwargs.items() if k != 'dtype'})

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore[override]
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features, ))
                self.bias.materialize((self.num_features, ))
            if self.track_running_stats:
                self.running_mean.materialize(
                    (self.num_features, ))  # type:ignore[union-attr]
                self.running_var.materialize(
                    (self.num_features, ))  # type:ignore[union-attr]
            self.reset_parameters()
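
This variant threads device and dtype through factory_kwargs, so the uninitialized tensors already carry their target placement and are materialized there directly. A hedged sketch with the public nn.LazyBatchNorm1d (the dtype choice is illustrative):

import torch
from torch import nn

bn = nn.LazyBatchNorm1d(dtype=torch.float64)
y = bn(torch.randn(4, 10, dtype=torch.float64))   # (N, C) input, so num_features becomes 10
print(bn.weight.dtype)                            # torch.float64
print(bn.running_mean.dtype)                      # torch.float64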
Example #5
class LazyLinear(LazyModuleMixin, Linear):
    r"""A :class:`torch.nn.Linear` module with lazy initialization.

    In this module, the ``weight`` and ``bias`` are instances of
    :class:`torch.nn.UninitializedParameter`. They are initialized after the first call to
    ``forward``, at which point the module becomes a regular :class:`torch.nn.Linear` module.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True``, the values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                :math:`k = \frac{1}{\text{in\_features}}`
    """

    cls_to_become = Linear  # type: ignore[assignment]
    weight: UninitializedParameter

    def __init__(self, out_features: int, bias: bool = True) -> None:
        super().__init__(0, out_features, bias)
        self.weight = UninitializedParameter()

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.in_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore
        if self.has_uninitialized_params():
            with torch.no_grad():
                self.in_features = input.shape[-1]
                self.weight.materialize((self.out_features, self.in_features))  # type: ignore
                self.reset_parameters()
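
As a follow-up to the docstring's promise that the module "will become a regular Linear", a hedged sketch of the class swap driven by cls_to_become, using the public nn.LazyLinear:

import torch
from torch import nn
from torch.nn.modules.lazy import LazyModuleMixin

layer = nn.LazyLinear(3)
assert isinstance(layer, LazyModuleMixin)      # still lazy before the first forward
layer(torch.randn(2, 5))                       # materialization triggers the class swap
assert isinstance(layer, nn.Linear)
assert not isinstance(layer, nn.LazyLinear)    # the instance is now a plain Linear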
Example #6
# `Lazy` and `Hypers` come from the example's own project and are assumed to be
# in scope; the remaining imports are standard.
import math

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.parameter import UninitializedParameter


class LazyLin(Lazy):
    hs = Hypers({"d_in", "d_out"}, {"bias": True})

    def __init__(self, d_out=None, ps={}, hs=[], **kw):
        if d_out is not None:
            kw.update(d_out=d_out)
        super().__init__(ps, [self.hs] + hs, **kw)
        cfg = self.cfg
        kw = {"dtype": cfg.dtype, "device": cfg.device}
        self.weight = UninitializedParameter(**kw)
        if cfg.bias:
            self.bias = UninitializedParameter(**kw)
        else:
            self.register_parameter("bias", None)

    def build(self, x):
        cfg = self.cfg
        if not self.is_built():
            with torch.no_grad():
                cfg.d_in = x.shape[-1]
                self.weight.materialize((cfg.d_out, cfg.d_in))
                if cfg.bias:
                    self.bias.materialize((cfg.d_out, ))
                self.reset_params()

    def reset_params(self):
        cfg = self.cfg
        if self.is_built():
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            b = 1 / math.sqrt(cfg.d_in)
            # nn.init.uniform_(self.weight, -b, b)
            if self.bias is not None:
                nn.init.uniform_(self.bias, -b, b)

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)