Example #1
    def apply_chunked_hyperfan_init(self,
                                    method='in',
                                    use_xavier=False,
                                    temb_var=1.,
                                    ext_inp_var=1.,
                                    eps=1e-5,
                                    cemb_normal_init=False):
        r"""Initialize the network using a chunked hyperfan init.

        Inspired by the method
        `Hyperfan Init <https://openreview.net/forum?id=H1lma24tPB>`__ which we
        implemented for the full hypernetwork in method
        :meth:`toy_example.hyper_model.HyperNetwork.apply_hyperfan_init`, we
        heuristically developed a better initialization method for chunked
        hypernetworks.

        Unfortunately, the `Hyperfan Init` method does not apply to this kind of
        hypernetwork, since we reuse the same hypernet output head for the whole
        main network.

        Luckily, we can provide a simple heuristic. Similar to
        `Meyerson & Miikkulainen <https://arxiv.org/abs/1906.00097>`__ we play
        with the variance of the input embeddings to affect the variance of the
        output weights.

        In a chunked hypernetwork, the input for each chunk is identical except
        for the chunk embeddings :math:`\mathbf{c}`. Let :math:`\mathbf{e}`
        denote the remaining inputs to the hypernetwork, which are identical
        for all chunks. Then, assuming the hypernetwork was initialized via
        fan-in init, the variance of the hypernetwork output :math:`\mathbf{v}`
        can be written as follows (see documentation of method
        :meth:`toy_example.hyper_model.HyperNetwork.apply_hyperfan_init`):

        .. math::

            \text{Var}(v) = \frac{n_e}{n_e+n_c} \text{Var}(e) + \
                \frac{n_c}{n_e+n_c} \text{Var}(c)

        Hence, we can achieve a desired output variance :math:`\text{Var}(v)`
        by initializing the chunk embeddings :math:`\mathbf{c}` via the
        following variance:

        .. math::

            \text{Var}(c) = \max \Big\{ 0, \
                \frac{1}{n_c} \big[ (n_e+n_c) \text{Var}(v) - \
                n_e \text{Var}(e) \big] \Big\}

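        As a purely illustrative sanity check of this relation (all numbers
        are arbitrary):

        .. code-block:: python

            n_e, n_c = 32, 8           # sizes of shared input and chunk emb.
            var_e, var_v = 0.01, 0.02  # Var(e) and desired Var(v)

            var_c = max(0, ((n_e + n_c) * var_v - n_e * var_e) / n_c)
            # Plugging Var(c) back in recovers Var(v) (no clipping occurred).
            var_v_check = (n_e * var_e + n_c * var_c) / (n_e + n_c)
            assert abs(var_v_check - var_v) < 1e-6
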
        Now, one important question remains. How do we pick a desired output
        variance :math:`\text{Var}(v)` for a chunk?

        Note, a chunk may include weights from several layers. The likelihood
        for this to happen depends on the main net architecture and the chunk
        size (see constructor argument ``chunk_dim``). The smaller the chunk
        size, the less likely it is that a chunk will contain elements from
        multiple main net weight tensors.

        If each chunk contained only weights from one main net weight tensor,
        we could simply pick the variance :math:`\text{Var}(v)` that would have
        been chosen by a main net initialization method (such as Xavier).

        In case a chunk contains contributions from several main net weight
        tensors, we apply the following heuristic. If a chunk contains
        contributions of a set of main network weight tensors
        :math:`W_1, \dots, W_K` with relative contribution sizes
        :math:`n_1, \dots, n_K` such that :math:`n_1 + \dots + n_K = n_v`,
        where :math:`n_v` denotes the chunk size, and if the corresponding main
        network initialization method would require init variances
        :math:`\text{Var}(w_1), \dots, \text{Var}(w_K)`, then we simply request
        a weighted average as follows:

        .. math::

            \text{Var}(v) = \frac{1}{n_v} \sum_{k=1}^K n_k \text{Var}(w_k)

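        As a concrete, purely illustrative example: if a chunk of size
        :math:`n_v = 1000` covers 400 elements of a weight tensor with target
        variance :math:`0.01` and 600 elements of a weight tensor with target
        variance :math:`0.02`, then its target variance becomes

        .. math::

            \text{Var}(v) = \frac{400 \cdot 0.01 + 600 \cdot 0.02}{1000} = 0.016
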
        What about bias vectors? Usually, the variance analysis applied to
        Xavier or Kaiming init assumes that biases are initialized to zero. This
        is not possible in this setting, as it would require assigning a
        negative variance to :math:`\mathbf{c}`. Instead, we follow the default
        PyTorch initialization (e.g., see method ``reset_parameters`` in class
        :class:`torch.nn.Linear`). There, bias vectors are initialized uniformly
        within a range of :math:`\pm \frac{1}{\sqrt{f_{\text{in}}}}` where
        :math:`f_{\text{in}}` refers to the fan-in of the layer. This type of
        initialization corresponds to a variance of
        :math:`\text{Var}(v) = \frac{1}{3 f_{\text{in}}}`.

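        This value follows from the variance of a uniform distribution: for
        :math:`v \sim \mathcal{U}(-a, a)` with
        :math:`a = \frac{1}{\sqrt{f_{\text{in}}}}`, one obtains

        .. math::

            \text{Var}(v) = \frac{(2a)^2}{12} = \frac{a^2}{3} = \
                \frac{1}{3 f_{\text{in}}}
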
        Warning:
            Note, in order to compute the fan-in of layers with bias vectors, we
            need access to the corresponding weight tensor in the same layer.
            Since there is no clean way of matching a bias shape to its
            corresponging weight tensor shape we use the following heuristic,
            which should be correct for most main networks. We assume that the
            shape directly preceding a bias shape in the constructor argument
            ``target_shapes`` is the corresponding weight tensor.

        Note:
            Constructor argument ``noise_dim`` is automatically considered by
            this method.

        Note:
            All hypernet inputs are assumed to be zero-mean random variables.

        Warning:
            This method considers all 1D target weight tensors as bias vectors.

        Note:
            To avoid that the variances with which chunks are initialized
            have to be clipped (because they are too small or even negative),
            the variance of the remaining hypernet inputs should be properly
            scaled. In general, one should adhere to the following rule

            .. math::

                \text{Var}(e) < \frac{n_e+n_c}{n_e} \text{Var}(v)

            This method will calculate and print the maximum value that should
            be chosen for :math:`\text{Var}(e)` and will print warnings if
            variances have to be clipped.

        Args:
            method (str): The type of initialization that should be applied.
                Possible options are:

                - ``in``: Use `Chunked Hyperfan-in`, i.e., the target output
                  variances of the hypernetwork correspond to fan-in variances.
                - ``out``: Use `Chunked Hyperfan-out`, i.e., the target output
                  variances of the hypernetwork correspond to fan-out
                  variances.
                - ``harmonic``: Use the harmonic mean of the fan-in and fan-out
                  variance as target variance of the hypernetwork output.
            use_xavier (bool): Whether Kaiming (``False``) or Xavier (``True``)
                init should be used.
            temb_var (float): The initial variance of the task embeddings.

                .. note::
                    If ``temb_std`` was set in the constructor, then this method
                    will automatically correct the provided ``temb_var`` as 
                    follows: :code:`temb_var += temb_std**2`.
            ext_inp_var (float): The initial variance of the external input.
                Only needs to be specified if external inputs are provided.
                
                .. note::
                    Not supported yet by this hypernetwork type, but should soon
                    be included as a feature.
            eps (float): The minimum variance with which a chunk embedding is
                initialized.
            cemb_normal_init (bool): Use normal init for chunk embeddings
                rather than uniform init.
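
        Example:
            A minimal, hypothetical call (``hnet`` denotes an instance of this
            chunked hypernetwork class; the argument values are arbitrary):

            .. code-block:: python

                hnet.apply_chunked_hyperfan_init(method='in', use_xavier=False,
                                                 temb_var=1., eps=1e-5)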
        """
        if method not in ['in', 'out', 'harmonic']:
            raise ValueError('Invalid value for argument "method".')
        if not self.has_theta:
            raise ValueError('Hypernet without internal weights can\'t be ' +
                             'initialized.')

        ### Compute input variance ###
        # The input variance does not include the variance of chunk embeddings!
        # Instead, it is the variance of the inputs that are shared across all
        # chunks.
        if self._temb_std != -1:
            # Sum of uncorrelated variables.
            temb_var += self._temb_std**2

        assert self._noise_dim == -1 or self._noise_dim > 0

        # TODO external inputs are not yet considered.
        inp_dim = self._te_dim + \
            (self._noise_dim if self._noise_dim != -1 else 0)
        #(self._size_ext_input if self._size_ext_input is not None else 0) \

        inp_var = (self._te_dim / inp_dim) * temb_var
        #if self._size_ext_input is not None:
        #    inp_var += (self._size_ext_input  / inp_dim) * ext_inp_var
        if self._noise_dim != -1:
            inp_var += (self._noise_dim / inp_dim) * 1.

        c_dim = self._ce_dim

        ### Compute target variance of each output tensor ###
        target_vars = []

        for i, s in enumerate(self.target_shapes):
            # FIXME 1D shape is not necessarily bias vector.
            if len(s) == 1:  # Assume it's a bias vector
                # Assume that last shape has been the corresponding weight
                # tensor.
                if i > 0 and len(self.target_shapes[i - 1]) > 1:
                    fan_in, _ = iutils.calc_fan_in_and_out( \
                        self.target_shapes[i-1])
                else:
                    # FIXME Quick-fix, use fan-out instead.
                    fan_in = s[0]

                target_vars.append(1. / (3. * fan_in))

            else:
                fan_in, fan_out = iutils.calc_fan_in_and_out(s)

                c_relu = 1 if use_xavier else 2

                var_in = c_relu / fan_in
                var_out = c_relu / fan_out

                if method == 'in':
                    var = var_in
                elif method == 'out':
                    var = var_out
                else:
                    # Harmonic mean of fan-in and fan-out variances.
                    var = 2. / (1. / var_in + 1. / var_out)

                target_vars.append(var)

        ### Target variance per chunk ###
        chunk_vars = []
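        # `i` indexes the target tensor currently being distributed across
        # chunks; `n` holds the number of its elements not yet assigned to a
        # chunk.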
        i = 0
        n = np.prod(self.target_shapes[i])

        for j in range(self._num_chunks):
            m = self._chunk_dim
            var = 0

            while m > 0:
                # Special treatment to fill up last chunk.
                if j == self._num_chunks - 1 and i == len(target_vars) - 1:
                    assert n <= m
                    o = m
                else:
                    o = min(m, n)

                var += o / self._chunk_dim * target_vars[i]
                m -= o
                n -= o

                if n == 0:
                    i += 1
                    if i < len(target_vars):
                        n = np.prod(self.target_shapes[i])

            chunk_vars.append(var)

        max_inp_var = (inp_dim + c_dim) / inp_dim * min(chunk_vars)
        max_inp_std = math.sqrt(max_inp_var)
        print('Initializing hypernet with Chunked Hyperfan Init ...')
        if inp_var >= max_inp_var:
            warn('Note, hypernetwork inputs should have an initial total ' +
                 'variance (std) smaller than %f (%f) in order for this ' \
                 % (max_inp_var, max_inp_std) + 'method to work properly.')

        ### Compute variances of chunk embeddings ###
        # We could have done that in the previous loop. But I think the code is
        # more readable this way.
        c_vars = []
        n_clipped = 0
        for i, var in enumerate(chunk_vars):
            c_var = 1. / c_dim * ((inp_dim + c_dim) * var - inp_dim * inp_var)
            if c_var < eps:
                n_clipped += 1
                #warn('Initial variance of chunk embedding %d has to ' % i + \
                #     'be clipped.')

            c_vars.append(max(eps, c_var))

        if n_clipped > 0:
            warn('Initial variance of %d/%d ' % (n_clipped, len(chunk_vars)) + \
                 'chunk embeddings had to be clipped.')

        ### Initialize chunk embeddings ###
        for i, c_emb in enumerate(self.chunk_embeddings):
            c_std = math.sqrt(c_vars[i])
            if cemb_normal_init:
                torch.nn.init.normal_(c_emb, mean=0, std=c_std)
            else:
                a = math.sqrt(3.0) * c_std
                torch.nn.init._no_grad_uniform_(c_emb, -a, a)

        ### Initialize hypernet with fan-in init ###
        for i, w in enumerate(self._hypernet.theta):
            if w.ndim == 1:  # bias
                assert i % 2 == 1
                torch.nn.init.constant_(w, 0)

            else:
                if use_xavier:
                    iutils.xavier_fan_in_(w)
                else:
                    torch.nn.init.kaiming_uniform_(w,
                                                   mode='fan_in',
                                                   nonlinearity='relu')
Example #2
    def apply_hyperfan_init(self,
                            method='in',
                            use_xavier=False,
                            uncond_var=1.,
                            cond_var=1.,
                            mnet=None,
                            w_val=None,
                            w_var=None,
                            b_val=None,
                            b_var=None):
        r"""Initialize the network using `hyperfan init`.

        Hyperfan initialization was developed in the following paper for this
        kind of hypernetwork

            "Principled Weight Initialization for Hypernetworks"
            https://openreview.net/forum?id=H1lma24tPB

        The initialization is based on the following idea: if the main network
        were initialized using Xavier or Kaiming init, then the variance of
        activations (fan-in) or gradients (fan-out) would be preserved by
        choosing a proper variance for the initial weight distribution
        (assuming certain assumptions hold at initialization, which differ
        between Xavier and Kaiming).

        When using this kind of initialization in the hypernetwork, the
        variance of the initial main net weight distribution would simply equal
        the variance of the input embeddings (which can lead to exploding
        activations, e.g., for fan-in inits).

        The above-mentioned paper proposes a quick fix for the type of hypernet
        that resembles the simple MLP hnet implemented in this class, i.e.,
        hypernets with a separate output head per weight tensor in the main
        network.

        Assuming that input embeddings are initialized with a certain variance
        (e.g., 1) and we use Xavier or Kaiming init for the hypernet, then the
        variance of the last hidden activation will also be 1.

        Then, we can modify the variance of the weights of each output head in
        the hypernet to obtain the same variance per main net weight tensor that
        we would typically obtain when applying Xavier or Kaiming to the main
        network directly.

        Note:
            If ``mnet`` is not provided or the corresponding attribute
            :attr:`mnets.mnet_interface.MainNetInterface.param_shapes_meta` is
            not implemented, then this method assumes that 1D target tensors
            (cf. constructor argument ``target_shapes``) represent bias vectors
            in the main network.

        Note:
            To compute the hyperfan-out initialization of bias vectors, we need
            access to the fan-in of the layer, which we can only compute based
            on the corresponding weight tensor in the same layer. This is only
            possible if ``mnet`` is provided. Otherwise, the following
            heuristic is applied. We assume that the shape directly preceding
            a bias shape in the constructor argument ``target_shapes`` is the
            corresponding weight tensor.

        Note:
            All hypernet inputs are assumed to be zero-mean random variables.

        **Variance of the hypernet input**

        In general, the input to the hypernetwork can be a concatenation of
        multiple embeddings (see description of arguments ``uncond_var`` and
        ``cond_var``).

        Let's denote the complete hypernetwork input by
        :math:`\mathbf{x} \in \mathbb{R}^n`, which consists of a conditional
        embedding :math:`\mathbf{e} \in \mathbb{R}^{n_e}` and an unconditional
        input :math:`\mathbf{c} \in \mathbb{R}^{n_c}`, i.e.,

        .. math::

            \mathbf{x} = \begin{bmatrix} \
            \mathbf{e} \\ \
            \mathbf{c} \
            \end{bmatrix}

        We simply define the variance of an input :math:`\text{Var}(x_j)` as
        the weighted average of the individual variances, i.e.,

        .. math::

            \text{Var}(x_j) \equiv \frac{n_e}{n_e+n_c} \text{Var}(e) + \
                \frac{n_c}{n_e+n_c} \text{Var}(c)

        To see that this is correct, consider a linear layer
        :math:`\mathbf{y} = W \mathbf{x}` or

        .. math::

            y_i &= \sum_j w_{ij} x_j \\ \
                &= \sum_{j=1}^{n_e} w_{ij} e_j + \
                   \sum_{j=n_e+1}^{n_e+n_c} w_{ij} c_{j-n_e}

        Hence, we can compute the variance of :math:`y_i` as follows (assuming
        the typical Xavier assumptions):

        .. math::

            \text{Var}(y) &= n_e \text{Var}(w) \text{Var}(e) + \
                             n_c \text{Var}(w) \text{Var}(c) \\ \
                          &= \frac{n_e}{n_e+n_c} \text{Var}(e) + \
                             \frac{n_c}{n_e+n_c} \text{Var}(c)

        Note, that Xavier would have initialized :math:`W` using
        :math:`\text{Var}(w) = \frac{1}{n} = \frac{1}{n_e+n_c}`.

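        The following small, self-contained snippet (not part of this class)
        empirically checks this relation for arbitrarily chosen sizes and
        variances:

        .. code-block:: python

            import torch

            n_e, n_c = 8, 32          # embedding sizes (arbitrary)
            var_e, var_c = 1.0, 0.25  # input variances (arbitrary)
            n = n_e + n_c
            N = 100000                # number of samples

            # Zero-mean inputs and a weight matrix with Var(w) = 1/n.
            x = torch.cat([var_e**0.5 * torch.randn(N, n_e),
                           var_c**0.5 * torch.randn(N, n_c)], dim=1)
            W = (1. / n)**0.5 * torch.randn(64, n)
            y = x @ W.t()

            expected = (n_e * var_e + n_c * var_c) / n
            print(y.var().item(), expected)  # Both values should be close.
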
        Args:
            method (str): The type of initialization that should be applied.
                Possible options are:

                - ``'in'``: Use `Hyperfan-in`.
                - ``'out'``: Use `Hyperfan-out`.
                - ``'harmonic'``: Use the harmonic mean of the `Hyperfan-in` and
                  `Hyperfan-out` init.
            use_xavier (bool): Whether Kaiming (``False``) or Xavier (``True``)
                init should be used.
            uncond_var (float): The variance of unconditional embeddings. This
                value is only taken into consideration if ``uncond_in_size > 0``
                (cf. constructor arguments).
            cond_var (float): The initial variance of conditional embeddings.
                This value is only taken into consideration if
                ``cond_in_size > 0`` (cf. constructor arguments).
            mnet (mnets.mnet_interface.MainNetInterface, optional): If
                applicable, the user should provide the main (or target)
                network, whose weights are generated by this hypernetwork. The
                ``mnet`` instance is used to extract valuable information that
                improve the initialization result. If provided, it is assumed
                that ``target_shapes`` (cf. constructor arguments) corresponds
                either to
                :attr:`mnets.mnet_interface.MainNetInterface.param_shapes` or
                :attr:`mnets.mnet_interface.MainNetInterface.hyper_shapes_learned`.
            w_val (list or dict, optional): The mean of the distribution with
                which output head weight matrices are initialized. Note, each
                weight tensor prescribed by
                :attr:`hnets.hnet_interface.HyperNetInterface.target_shapes` is
                produced via an independent linear output head.

                One may either specify a list of numbers having the same length
                as :attr:`hnets.hnet_interface.HyperNetInterface.target_shapes`
                or specify a dictionary which may have as keys the tensor names
                occurring in
                :attr:`mnets.mnet_interface.MainNetInterface.param_shapes_meta`
                and the corresponding mean value for the weight matrices of all
                output heads producing this type of tensor.
                If a list is provided, entries may be ``None`` and if a
                dictionary is provided, not all types of parameter tensors need
                to be specified. For tensors for which no value is specified,
                the default value will be used. The default values for tensor
                types ``'weight'`` and ``'bias'`` are calculated based on the
                proposed hyperfan-initialization. For other tensor types the
                actual hypernet outputs should be drawn from the following
                distributions

                  - ``'bn_scale'``: :math:`w \sim \delta(w - 1)`
                  - ``'bn_shift'``: :math:`w \sim \delta(w)`
                  - ``'cm_scale'``: :math:`w \sim \delta(w - 1)`
                  - ``'cm_shift'``: :math:`w \sim \delta(w)`
                  - ``'embedding'``: :math:`w \sim \mathcal{N}(0, 1)`

                This would correspond to passing the following arguments:

                .. code-block:: python

                    w_val = {
                        'bn_scale': 0,
                        'bn_shift': 0,
                        'cm_scale': 0,
                        'cm_shift': 0,
                        'embedding': 0
                    }
                    w_var = {
                        'bn_scale': 0,
                        'bn_shift': 0,
                        'cm_scale': 0,
                        'cm_shift': 0,
                        'embedding': 0
                    }
                    b_val = {
                        'bn_scale': 1,
                        'bn_shift': 0,
                        'cm_scale': 1,
                        'cm_shift': 0,
                        'embedding': 0
                    }
                    b_var = {
                        'bn_scale': 0,
                        'bn_shift': 0,
                        'cm_scale': 0,
                        'cm_shift': 0,
                        'embedding': 1
                    }
            w_var (list or dict, optional): The variance of the distribution
                with which output head weight matrices are initialized. A
                variance value of zero means that the weights are set to the
                constant defined by ``w_val``. See the description of argument
                ``w_val`` for more details.
            b_val (list or dict, optional): The mean of the distribution
                with which output head bias vectors are initialized.
                See description of argument ``w_val`` for more details.
            b_var (list or dict, optional): The variance of the distribution
                with which output head bias vectors are initialized.
                See description of argument ``w_val`` for more details.
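
        Example:
            A minimal, hypothetical call (``hnet`` denotes an instance of this
            hypernetwork class and ``mnet`` the corresponding main network;
            both names are placeholders):

            .. code-block:: python

                hnet.apply_hyperfan_init(method='in', use_xavier=False,
                                         uncond_var=1., cond_var=1., mnet=mnet)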
        """
        if method not in ['in', 'out', 'harmonic']:
            raise ValueError('Invalid value "%s" for argument "method".' %
                             method)
        if self.unconditional_params is None:
            assert self._no_uncond_weights
            raise ValueError('Hypernet without internal weights can\'t be ' +
                             'initialized.')

        ### Extract meta-information about target shapes ###
        meta = None
        if mnet is not None:
            assert isinstance(mnet, MainNetInterface)

            try:
                meta = mnet.param_shapes_meta
            except:
                meta = None

            if meta is not None:
                if len(self.target_shapes) == len(mnet.param_shapes):
                    pass
                    # meta = mnet.param_shapes_meta
                elif len(self.target_shapes) == len(mnet.hyper_shapes_learned):
                    meta = []
                    for ii in mnet.hyper_shapes_learned_ref:
                        meta.append(mnet.param_shapes_meta[ii])
                else:
                    warn('Target shapes of this hypernetwork could not be ' +
                         'matched to the meta information provided to the ' +
                         'initialization.')
                    meta = None

        # TODO If the user doesn't (or can't) provide an `mnet` instance, we
        # should alternatively allow them to pass meta information directly.
        if meta is None:
            meta = []

            # Heuristical approach to derive meta information from given shapes.
            layer_ind = 0
            for i, s in enumerate(self.target_shapes):
                curr_meta = dict()

                if len(s) > 1:
                    curr_meta['name'] = 'weight'
                    curr_meta['layer'] = layer_ind
                    layer_ind += 1
                else:  # just a heuristic, we can't know
                    curr_meta['name'] = 'bias'
                    if i > 0 and meta[-1]['name'] == 'weight':
                        curr_meta['layer'] = meta[-1]['layer']
                    else:
                        curr_meta['layer'] = -1

                meta.append(curr_meta)

        assert len(meta) == len(self.target_shapes)

        # Mapping from layer index to the corresponding shape.
        layer_shapes = dict()
        # Mapping from layer index to whether the layer has a bias vector.
        layer_has_bias = defaultdict(lambda: False)
        for i, m in enumerate(meta):
            if m['name'] == 'weight' and m['layer'] != -1:
                assert len(self.target_shapes[i]) > 1
                layer_shapes[m['layer']] = self.target_shapes[i]
            if m['name'] == 'bias' and m['layer'] != -1:
                layer_has_bias[m['layer']] = True

        ### Compute input variance ###
        cond_dim = self._cond_in_size
        uncond_dim = self._uncond_in_size
        inp_dim = cond_dim + uncond_dim

        input_variance = 0
        if cond_dim > 0:
            input_variance += (cond_dim / inp_dim) * cond_var
        if uncond_dim > 0:
            input_variance += (uncond_dim / inp_dim) * uncond_var

        ### Initialize hidden layers to preserve variance ###
        # Note, if batchnorm layers are used, they will simply be initialized to
        # have no effect after initialization. This does not affect the
        # performed whitening operation.
        if self.batchnorm_layers is not None:
            for bn_layer in self.batchnorm_layers:
                if hasattr(bn_layer, 'scale'):
                    nn.init.ones_(bn_layer.scale)
                if hasattr(bn_layer, 'bias'):
                    nn.init.zeros_(bn_layer.bias)

            # Since batchnorm layers whiten the statistics of hidden
            # activities, the variance of the input will not be preserved by
            # Xavier/Kaiming.
            if len(self.batchnorm_layers) > 0:
                input_variance = 1.

        # We initialize biases with 0 (see Xavier assumption 4 in the Hyperfan
        # paper). Otherwise, we couldn't ignore the biases when computing the
        # output variance of a layer.
        # Note, we have to use fan-in init for the hidden layer to ensure the
        # property, that we preserve the input variance.
        assert len(self._layers) + 1 == len(self.layer_weight_tensors)
        for i, w_tensor in enumerate(self.layer_weight_tensors[:-1]):
            if use_xavier:
                iutils.xavier_fan_in_(w_tensor)
            else:
                torch.nn.init.kaiming_uniform_(w_tensor,
                                               mode='fan_in',
                                               nonlinearity='relu')

            if self.has_bias:
                nn.init.zeros_(self.layer_bias_vectors[i])

        ### Define default parameters of weight init distributions ###
        w_val_list = []
        w_var_list = []
        b_val_list = []
        b_var_list = []

        for i, m in enumerate(meta):

            def extract_val(user_arg):
                curr = None
                if isinstance(user_arg, (list, tuple)) and \
                        user_arg[i] is not None:
                    curr = user_arg[i]
                elif isinstance(user_arg, (dict)) and \
                        m['name'] in user_arg.keys():
                    curr = user_arg[m['name']]
                return curr

            curr_w_val = extract_val(w_val)
            curr_w_var = extract_val(w_var)
            curr_b_val = extract_val(b_val)
            curr_b_var = extract_val(b_var)

            if m['name'] == 'weight' or m['name'] == 'bias':
                if None in [curr_w_val, curr_w_var, curr_b_val, curr_b_var]:
                    # If distribution not fully specified, then we just fall
                    # back to hyper-fan init.
                    curr_w_val = None
                    curr_w_var = None
                    curr_b_val = None
                    curr_b_var = None
            else:
                assert m['name'] in [
                    'bn_scale', 'bn_shift', 'cm_scale', 'cm_shift', 'embedding'
                ]
                if curr_w_val is None:
                    curr_w_val = 0
                if curr_w_var is None:
                    curr_w_var = 0
                if curr_b_val is None:
                    curr_b_val = 1 if m['name'] in ['bn_scale', 'cm_scale'] \
                                   else 0
                if curr_b_var is None:
                    curr_b_var = 1 if m['name'] in ['embedding'] else 0

            w_val_list.append(curr_w_val)
            w_var_list.append(curr_w_var)
            b_val_list.append(curr_b_val)
            b_var_list.append(curr_b_var)

        ### Initialize output heads ###
        # Note, that all output heads are realized internally via one large
        # fully-connected layer.
        # All output heads are linear layers. The biases of these linear
        # layers (called gamma and beta in the paper) are simply initialized
        # to zero. Note, that we allow deviations from this below.
        if self.has_bias:
            nn.init.zeros_(self.layer_bias_vectors[-1])

        c_relu = 1 if use_xavier else 2

        # We are not interested in the fan-out, since the fan-out is just
        # the number of elements in the main network.
        # `fan-in` is called `d_k` in the paper and is just the size of the
        # last hidden layer.
        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(\
            self.layer_weight_tensors[-1])

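        # Rows `s_ind:e_ind` of the shared output weight matrix (and the
        # corresponding entries of its bias vector) form the linear output head
        # that produces target tensor `i`.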
        s_ind = 0
        for i, out_shape in enumerate(self.target_shapes):
            m = meta[i]
            e_ind = s_ind + int(np.prod(out_shape))

            curr_w_val = w_val_list[i]
            curr_w_var = w_var_list[i]
            curr_b_val = b_val_list[i]
            curr_b_var = b_var_list[i]

            if curr_w_val is None:
                c_bias = 2 if layer_has_bias[m['layer']] else 1

                if m['name'] == 'bias':
                    m_fan_out = out_shape[0]

                    # NOTE For the hyperfan-out init, we also need to know the
                    # fan-in of the layer.
                    if m['layer'] != -1:
                        m_fan_in, _ = iutils.calc_fan_in_and_out( \
                            layer_shapes[m['layer']])
                    else:
                        # FIXME Quick-fix.
                        m_fan_in = m_fan_out

                    var_in = c_relu / (2. * fan_in * input_variance)
                    num = c_relu * (1. - m_fan_in / m_fan_out)
                    denom = fan_in * input_variance
                    var_out = max(0, num / denom)

                else:
                    assert m['name'] == 'weight'
                    m_fan_in, m_fan_out = iutils.calc_fan_in_and_out(out_shape)

                    var_in = c_relu / (c_bias * m_fan_in * fan_in * \
                                       input_variance)
                    var_out = c_relu / (m_fan_out * fan_in * input_variance)

                if method == 'in':
                    var = var_in
                elif method == 'out':
                    var = var_out
                elif method == 'harmonic':
                    # Harmonic mean of fan-in and fan-out variances.
                    var = 2. / (1. / var_in + 1. / var_out)
                else:
                    raise ValueError('Method %s invalid.' % method)

                # Initialize output head weight tensor using `var`.
                std = math.sqrt(var)
                a = math.sqrt(3.0) * std
                torch.nn.init._no_grad_uniform_( \
                    self.layer_weight_tensors[-1][s_ind:e_ind, :], -a, a)
            else:
                if curr_w_var == 0:
                    nn.init.constant_(
                        self.layer_weight_tensors[-1][s_ind:e_ind, :],
                        curr_w_val)
                else:
                    std = math.sqrt(curr_w_var)
                    a = math.sqrt(3.0) * std
                    torch.nn.init._no_grad_uniform_( \
                        self.layer_weight_tensors[-1][s_ind:e_ind, :],
                        curr_w_val-a, curr_w_val+a)

                if curr_b_var == 0:
                    nn.init.constant_(self.layer_bias_vectors[-1][s_ind:e_ind],
                                      curr_b_val)
                else:
                    std = math.sqrt(curr_b_var)
                    a = math.sqrt(3.0) * std
                    torch.nn.init._no_grad_uniform_( \
                        self.layer_bias_vectors[-1][s_ind:e_ind],
                        curr_b_val-a, curr_b_val+a)

            s_ind = e_ind
Example #3
    def apply_hyperfan_init(self,
                            method='in',
                            use_xavier=False,
                            temb_var=1.,
                            ext_inp_var=1.):
        r"""Initialize the network using hyperfan init.

        Hyperfan initialization was developed in the following paper for this
        kind of hypernetwork

            "Principled Weight Initialization for Hypernetworks"
            https://openreview.net/forum?id=H1lma24tPB

        The initialization is based on the following idea: if the main network
        were initialized using Xavier or Kaiming init, then the variance of
        activations (fan-in) or gradients (fan-out) would be preserved by
        choosing a proper variance for the initial weight distribution
        (assuming certain assumptions hold at initialization, which differ
        between Xavier and Kaiming).

        When using this kind of initialization in the hypernetwork, the
        variance of the initial main net weight distribution would simply equal
        the variance of the input embeddings (which can lead to exploding
        activations, e.g., for fan-in inits).

        The above-mentioned paper proposes a quick fix for the type of hypernet
        that has a separate output head per weight tensor in the main network
        (which is the case for this hypernetwork class).

        Assuming that input embeddings are initialized with a certain variance
        (e.g., 1) and we use Xavier or Kaiming init for the hypernet, then the
        variance of the last hidden activation will also be 1.

        Then, we can modify the variance of the weights of each output head in
        the hypernet to obtain the variance for the main net weight tensors that
        we would typically obtain when applying Xavier or Kaiming to the main
        network directly.

        Warning:
            This method currently assumes that 1D target tensors (cf.
            constructor argument ``target_shapes``) are bias vectors in the
            main network.

        Warning:
            To compute the hyperfan-out initialization of bias vectors, we need
            access to the fan-in of the layer, which we can only compute based
            on the corresponding weight tensor in the same layer. Since there is
            no clean way of matching a bias shape to its corresponding weight
            tensor shape we use the following heuristic, which should be correct
            for most main networks. We assume that the shape directly preceding
            a bias shape in the constructor argument ``target_shapes`` is the
            corresponding weight tensor.

        **Variance of the hypernet input**

        In general, the input to the hypernetwork can be a concatenation of
        multiple embeddings (see description of arguments ``temb_var`` and
        ``ext_inp_var``).

        Let's denote the complete hypernetwork input by
        :math:`\mathbf{x} \in \mathbb{R}^n`, which consists of a task embedding
        :math:`\mathbf{e} \in \mathbb{R}^{n_e}` and an external input
        :math:`\mathbf{c} \in \mathbb{R}^{n_c}`, i.e.,

        .. math::

            \mathbf{x} = \begin{bmatrix} \
            \mathbf{e} \\ \
            \mathbf{c} \
            \end{bmatrix}

        We simply define the variance of an input :math:`\text{Var}(x_j)` as
        the weighted average of the individual variances, i.e.,

        .. math::

            \text{Var}(x_j) \equiv \frac{n_e}{n_e+n_c} \text{Var}(e) + \
                \frac{n_c}{n_e+n_c} \text{Var}(c)

        To see that this is correct, consider a linear layer
        :math:`\mathbf{y} = W \mathbf{x}` or

        .. math::

            y_i &= \sum_j w_{ij} x_j \\ \
                &= \sum_{j=1}^{n_e} w_{ij} e_j + \
                   \sum_{j=n_e+1}^{n_e+n_c} w_{ij} c_{j-n_e}

        Hence, we can compute the variance of :math:`y_i` as follows (assuming
        the typical Xavier assumptions):

        .. math::

            \text{Var}(y) &= n_e \text{Var}(w) \text{Var}(e) + \
                             n_c \text{Var}(w) \text{Var}(c) \\ \
                          &= \frac{n_e}{n_e+n_c} \text{Var}(e) + \
                             \frac{n_c}{n_e+n_c} \text{Var}(c)

        Note, that Xavier would have initialized :math:`W` using
        :math:`\text{Var}(w) = \frac{1}{n} = \frac{1}{n_e+n_c}`.

        Note:
            This method will automatically incorporate the noise embedding that
            is input to the network if constructor argument ``noise_dim``
            was set.

        Note:
            All hypernet inputs are assumed to be zero-mean random variables.

        Args:
            method (str): The type of initialization that should be applied.
                Possible options are:

                - ``in``: Use `Hyperfan-in`.
                - ``out``: Use `Hyperfan-out`.
                - ``harmonic``: Use the harmonic mean of the `Hyperfan-in` and
                  `Hyperfan-out` init.
            use_xavier (bool): Whether Kaiming (``False``) or Xavier (``True``)
                init should be used.
            temb_var (float): The initial variance of the task embeddings.

                .. note::
                    If ``temb_std`` was set in the constructor, then this method
                    will automatically correct the provided ``temb_var`` as 
                    follows: :code:`temb_var += temb_std**2`.
            ext_inp_var (float): The initial variance of the external input.
                Only needs to be specified if external inputs are provided
                (see argument ``ce_dim`` of constructor).
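
        Example:
            A minimal, hypothetical call (``hnet`` denotes an instance of this
            hypernetwork class; the argument values are arbitrary):

            .. code-block:: python

                hnet.apply_hyperfan_init(method='in', use_xavier=False,
                                         temb_var=1., ext_inp_var=1.)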
        """
        # FIXME If the network has external inputs and task embeddings, then
        # both these inputs might have different variances. Thus, a single
        # parameter `input_variance` might not be sufficient.
        # Now, we assume that the user provides a proper variance. We could
        # simplify the job for the user by providing multiple arguments and
        # computing the weighting ourselves.

        # FIXME Handle constructor arguments `noise_dim` and `temb_std`.
        # Note, we would just need to add `temb_std**2` to the variance of
        # task embeddings, since the variance of a sum of uncorrelated RVs is
        # just the sum of the individual variances.

        if method not in ['in', 'out', 'harmonic']:
            raise ValueError('Invalid value for argument "method".')
        if not self.has_theta:
            raise ValueError('Hypernet without internal weights can\'t be ' +
                             'initialized.')

        ### Compute input variance ###
        if self._temb_std != -1:
            # Sum of uncorrelated variables.
            temb_var += self._temb_std**2

        assert self._size_ext_input is None or self._size_ext_input > 0
        assert self._noise_dim == -1 or self._noise_dim > 0

        inp_dim = self._te_dim + \
            (self._size_ext_input if self._size_ext_input is not None else 0) \
            + (self._noise_dim if self._noise_dim != -1 else 0)

        input_variance = (self._te_dim / inp_dim) * temb_var
        if self._size_ext_input is not None:
            input_variance += (self._size_ext_input / inp_dim) * ext_inp_var
        if self._noise_dim != -1:
            input_variance += (self._noise_dim / inp_dim) * 1.

        ### Initialize hidden layers to preserve variance ###
        # We initialize biases with 0 (see Xavier assumption 4 in the Hyperfan
        # paper). Otherwise, we couldn't ignore the biases when computing the
        # output variance of a layer.
        # Note, we have to use fan-in init for the hidden layer to ensure the
        # property, that we preserve the input variance.

        for i in range(0, len(self._hidden_dims), 2 if self._use_bias else 1):
            #W = self.theta[i]
            if use_xavier:
                iutils.xavier_fan_in_(self.theta[i])
            else:
                torch.nn.init.kaiming_uniform_(self.theta[i],
                                               mode='fan_in',
                                               nonlinearity='relu')

            if self._use_bias:
                #b = self.theta[i+1]
                torch.nn.init.constant_(self.theta[i + 1], 0)

        ### Initialize output heads ###
        c_relu = 1 if use_xavier else 2
        # FIXME Not a proper way to figure out whether the hnet produces
        # bias vectors in the mnet.
        c_bias = 1
        for s in self.target_shapes:
            if len(s) == 1:
                c_bias = 2
                break
        # This is how we should do it instead.
        #c_bias = 2 if mnet.has_bias else 1

        j = 0
        for i in range(len(self._hidden_dims), len(self._theta_shapes),
                       2 if self._use_bias else 1):

            # All output heads are linear layers. The biases of these linear
            # layers (called gamma and beta in the paper) are simply initialized
            # to zero.
            if self._use_bias:
                #b = self.theta[i+1]
                torch.nn.init.constant_(self.theta[i + 1], 0)

            # We are not interested in the fan-out, since the fan-out is just
            # the number of elements in the corresponding main network tensor.
            # `fan-in` is called `d_k` in the paper and is just the size of the
            # last hidden layer.
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out( \
                self.theta[i])

            out_shape = self.target_shapes[j]

            # FIXME 1D output tensors don't need to be bias vectors. They can
            # be arbitrary embeddings or, for instance, batchnorm weights.
            if len(out_shape) == 1:  # Assume output is bias vector.
                m_fan_out = out_shape[0]

                # NOTE For the hyperfan-out init, we also need to know the
                # fan-in of the layer.
                # FIXME We have no proper way at the moment to get the correct
                # fan-in of the layer this bias vector belongs to.
                if j > 0 and len(self.target_shapes[j - 1]) > 1:
                    m_fan_in, _ = iutils.calc_fan_in_and_out( \
                        self.target_shapes[j-1])
                else:
                    # FIXME Quick-fix.
                    m_fan_in = m_fan_out

                var_in = c_relu / (2. * fan_in * input_variance)
                num = c_relu * (1. - m_fan_in / m_fan_out)
                denom = fan_in * input_variance
                var_out = max(0, num / denom)

            else:
                m_fan_in, m_fan_out = iutils.calc_fan_in_and_out(out_shape)

                var_in = c_relu / (c_bias * m_fan_in * fan_in * input_variance)
                var_out = c_relu / (m_fan_out * fan_in * input_variance)

            if method == 'in':
                var = var_in
            elif method == 'out':
                var = var_out
            elif method == 'harmonic':
                # Harmonic mean of fan-in and fan-out variances.
                var = 2. / (1. / var_in + 1. / var_out)
            else:
                raise ValueError('Method %s invalid.' % method)

            # Initialize output head weight tensor using `var`.
            std = math.sqrt(var)
            a = math.sqrt(3.0) * std
            torch.nn.init._no_grad_uniform_(self.theta[i], -a, a)

    def apply_chunked_hyperfan_init(self,
                                    method='in',
                                    use_xavier=False,
                                    uncond_var=1.,
                                    cond_var=1.,
                                    eps=1e-5,
                                    cemb_normal_init=False,
                                    mnet=None,
                                    target_vars=None):
        r"""Initialize the network using a chunked hyperfan init.

        Inspired by the method
        `Hyperfan Init <https://openreview.net/forum?id=H1lma24tPB>`__ which we
        implemented for the MLP hypernetwork in method
        :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`, we heuristically
        developed a better initialization method for chunked hypernetworks.

        Unfortunately, the `Hyperfan Init` method from the paper does not apply
        to this kind of hypernetwork, since we reuse the same hypernet output
        head for the whole main network.

        Luckily, we can provide a simple heuristic. Similar to
        `Meyerson & Miikkulainen <https://arxiv.org/abs/1906.00097>`__ we play
        with the variance of the input embeddings to affect the variance of the
        output weights.

        In a chunked hypernetwork, the input for each chunk is identical except
        for the chunk embeddings :math:`\mathbf{c}`. Let :math:`\mathbf{e}`
        denote the remaining inputs to the hypernetwork, which are identical
        for all chunks. Then, assuming the hypernetwork was initialized via
        fan-in init, the variance of the hypernetwork output :math:`\mathbf{v}`
        can be written as follows (see documentation of method
        :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`):

        .. math::

            \text{Var}(v) = \frac{n_e}{n_e+n_c} \text{Var}(e) + \
                \frac{n_c}{n_e+n_c} \text{Var}(c)

        Hence, we can achieve a desired output variance :math:`\text{Var}(v)`
        by initializing the chunk embeddings :math:`\mathbf{c}` via the
        following variance:

        .. math::

            \text{Var}(c) = \max \Big\{ 0, \
                \frac{1}{n_c} \big[ (n_e+n_c) \text{Var}(v) - \
                n_e \text{Var}(e) \big] \Big\}

        Now, one important question remains. How do we pick a desired output
        variance :math:`\text{Var}(v)` for a chunk?

        Note, a chunk may include weights from several layers. The likelihood
        for this to happen depends on the main net architecture and the chunk
        size (see constructor argument ``chunk_size``). The smaller the chunk
        size, the less likely it is that a chunk will contain elements from
        multiple main net weight tensors.

        If each chunk contained only weights from one main net weight tensor,
        we could simply pick the variance :math:`\text{Var}(v)` that would have
        been chosen by a main net initialization method (such as Xavier).

        In case a chunk contains contributions from several main net weight
        tensors, we apply the following heuristic. If a chunk contains
        contributions of a set of main network weight tensors
        :math:`W_1, \dots, W_K` with relative contribution sizes
        :math:`n_1, \dots, n_K` such that :math:`n_1 + \dots + n_K = n_v`,
        where :math:`n_v` denotes the chunk size, and if the corresponding main
        network initialization method would require init variances
        :math:`\text{Var}(w_1), \dots, \text{Var}(w_K)`, then we simply request
        a weighted average as follows:

        .. math::

            \text{Var}(v) = \frac{1}{n_v} \sum_{k=1}^K n_k \text{Var}(w_k)

        What about bias vectors? Usually, the variance analysis applied to
        Xavier or Kaiming init assumes that biases are initialized to zero. This
        is not possible in this setting, as it would require assigning a
        negative variance to :math:`\mathbf{c}`. Instead, we follow the default
        PyTorch initialization (e.g., see method ``reset_parameters`` in class
        :class:`torch.nn.Linear`). There, bias vectors are initialized uniformly
        within a range of :math:`\pm \frac{1}{\sqrt{f_{\text{in}}}}` where
        :math:`f_{\text{in}}` refers to the fan-in of the layer. This type of
        initialization corresponds to a variance of
        :math:`\text{Var}(v) = \frac{1}{3 f_{\text{in}}}`.

        Note:
            All hypernet inputs are assumed to be zero-mean random variables.

        Note:
            To avoid that the variances with which chunks are initialized
            have to be clipped (because they are too small or even negative),
            the variance of the remaining hypernet inputs should be properly
            scaled. In general, one should adhere to the following rule

            .. math::

                \text{Var}(e) < \frac{n_e+n_c}{n_e} \text{Var}(v)

            This method will calculate and print the maximum value that should
            be chosen for :math:`\text{Var}(e)` and will print warnings if
            variances have to be clipped.

        Args:
            (....): See arguments of method
                :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`.
            method (str): The type of initialization that should be applied.
                Possible options are:

                - ``in``: Use `Chunked Hyperfan-in`, i.e., the target output
                  variances of the hypernetwork correspond to fan-in variances.
                - ``out``: Use `Chunked Hyperfan-out`, i.e., the target output
                  variances of the hypernetwork correspond to fan-out
                  variances.
                - ``harmonic``: Use the harmonic mean of the fan-in and fan-out
                  variance as target variance of the hypernetwork output.
            eps (float): The minimum variance with which a chunk embedding is
                initialized.
            cemb_normal_init (bool): Use normal init for chunk embeddings
                rather than uniform init.
            target_vars (list or dict, optional): The variance of the
                distribution for each parameter tensor generated by this
                hypernetwork. Target variance values can either be provided as
                list of length ``len(hnet.target_shapes)`` or as dictionary.
                The usage is analogous to the usage of parameter ``w_val`` of
                method :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init` (see
                also the example at the end of this docstring).

                Note:
                    This method currently does not allow initial output
                    distributions with non-zero mean. However, the docstring of
                    method
                    :meth:`probabilistic.gauss_hnet_init.gauss_hyperfan_init`
                    describes how this is in principle feasible and might be
                    incorporated in the future.
                
                Note:
                    Unspecified target variances for parameter tensors of type
                    ``'weight'`` or ``'bias'`` are computed as described above.
                    Default target variances for all other parameter tensor
                    types are simply ``1``.
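
        Example:
            A minimal, hypothetical call (``hnet`` denotes an instance of this
            chunked hypernetwork class, ``mnet`` the corresponding main
            network, and the ``'bn_scale'`` entry assumes that the main network
            actually contains parameters of this type):

            .. code-block:: python

                hnet.apply_chunked_hyperfan_init(
                    method='in', mnet=mnet, target_vars={'bn_scale': 1e-5})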
        """
        if method not in ['in', 'out', 'harmonic']:
            raise ValueError('Invalid value "%s" for argument "method".' %
                             method)
        if self.unconditional_params is None:
            assert self._no_uncond_weights
            raise ValueError('Hypernet without internal weights can\'t be ' +
                             'initialized.')
        if self.unconditional_params is None and self._cond_chunk_embs:
            assert self._no_cond_weights
            raise ValueError('Chunked hyperfan init cannot be applied if ' +
                             'chunk embeddings are not internally maintained.')

        ### Extract meta-information about target shapes ###
        # FIXME This section is copied from the HMLP implementation.
        meta = None
        if mnet is not None:
            assert isinstance(mnet, MainNetInterface)

            try:
                meta = mnet.param_shapes_meta
            except:
                meta = None

            if meta is not None:
                if len(self.target_shapes) == len(mnet.param_shapes):
                    pass
                    # meta = mnet.param_shapes_meta
                elif len(self.target_shapes) == len(mnet.hyper_shapes_learned):
                    meta = []
                    for ii in mnet.hyper_shapes_learned_ref:
                        meta.append(mnet.param_shapes_meta[ii])
                else:
                    warn('Target shapes of this hypernetwork could not be ' +
                         'matched to the meta information provided to the ' +
                         'initialization.')
                    meta = None

        # TODO If the user doesn't (or can't) provide an `mnet` instance, we
        # should alternatively allow them to pass meta information directly.
        if meta is None:
            meta = []

            # Heuristical approach to derive meta information from given shapes.
            layer_ind = 0
            for i, s in enumerate(self.target_shapes):
                curr_meta = dict()

                if len(s) > 1:
                    curr_meta['name'] = 'weight'
                    curr_meta['layer'] = layer_ind
                    layer_ind += 1
                else:  # just a heuristic, we can't know
                    curr_meta['name'] = 'bias'
                    if i > 0 and meta[-1]['name'] == 'weight':
                        curr_meta['layer'] = meta[-1]['layer']
                    else:
                        curr_meta['layer'] = -1

                meta.append(curr_meta)

        assert len(meta) == len(self.target_shapes)

        # Mapping from layer index to the corresponding shape.
        layer_shapes = dict()
        # Mapping from layer index to whether the layer has a bias vector.
        layer_has_bias = defaultdict(lambda: False)
        for i, m in enumerate(meta):
            if m['name'] == 'weight' and m['layer'] != -1:
                assert len(self.target_shapes[i]) > 1
                layer_shapes[m['layer']] = self.target_shapes[i]
            if m['name'] == 'bias' and m['layer'] != -1:
                layer_has_bias[m['layer']] = True

        ### Compute input variance ###
        # The input variance does not include the variance of chunk embeddings!
        # Instead, it is the variance of the inputs that are shared across all
        # chunks.
        cond_dim = self._cond_in_size
        uncond_dim = self._uncond_in_size
        # Note, `inp_dim` can be zero if conditional chunk embeddings are used.
        inp_dim = cond_dim + uncond_dim

        inp_var = 0
        if cond_dim > 0:
            inp_var += (cond_dim / inp_dim) * cond_var
        if uncond_dim > 0:
            inp_var += (uncond_dim / inp_dim) * uncond_var

        c_dim = self.chunk_emb_size

        ### Initialize hypernet with fan-in init ###
        if self.batchnorm_layers is not None and len(
                self.batchnorm_layers) > 0:
            # Note, batchnorm layers simply whiten the incoming statistics.
            # Thus, if we tune the variance of chunk embeddings, this variance
            # is normalized by a batchnorm layer and thus vanishes.
            raise RuntimeError('Chunked hyperfan init not applicable if a ' +
                               'hypernetwork with batchnorm layers is used.')

        # Note, the whole internal hypernetwork is initialized with fan-in init
        # to simply pass the variance of all inputs to the hypernet output.
        for i, w_tensor in enumerate(self.layer_weight_tensors):
            if use_xavier:
                iutils.xavier_fan_in_(w_tensor)
            else:
                torch.nn.init.kaiming_uniform_(w_tensor,
                                               mode='fan_in',
                                               nonlinearity='relu')

            if self.has_bias:
                nn.init.zeros_(self.layer_bias_vectors[i])

        ### Compute target variance of each output tensor ###
        if target_vars is None:
            target_vars = [None] * len(self.target_shapes)
        elif isinstance(target_vars, dict):
            target_vars_d = target_vars
            target_vars = []

            for i, m in enumerate(meta):
                if m['name'] in target_vars_d.keys():
                    target_vars.append(target_vars_d[m['name']])
                else:
                    target_vars.append(None)
        else:
            assert isinstance(target_vars, (list, tuple))
            assert len(target_vars) == len(self.target_shapes)

        for i, s in enumerate(self.target_shapes):
            if target_vars[i] is not None:
                # Use user specified target variance.
                continue

            m = meta[i]

            if m['name'] == 'bias':
                if m['layer'] != -1:
                    fan_in, _ = iutils.calc_fan_in_and_out( \
                        layer_shapes[m['layer']])
                else:
                    # FIXME Quick-fix, use fan-out instead.
                    fan_in = s[0]

                target_vars[i] = 1. / (3. * fan_in)

            elif m['name'] == 'weight':
                fan_in, fan_out = iutils.calc_fan_in_and_out(s)

                c_relu = 1 if use_xavier else 2

                var_in = c_relu / fan_in
                var_out = c_relu / fan_out

                if method == 'in':
                    var = var_in
                elif method == 'out':
                    var = var_out
                else:
                    # Harmonic mean of fan-in and fan-out variances.
                    var = 2. / (1. / var_in + 1. / var_out)

                target_vars[i] = var

            else:
                target_vars[i] = 1.

        ### Target variance per chunk ###
        chunk_vars = []
        i = 0
        n = np.prod(self.target_shapes[i])

        for j in range(self.num_chunks):
            m = self._chunk_size
            var = 0

            while m > 0:
                # Special treatment to fill up last chunk.
                if j == self.num_chunks - 1 and i == len(target_vars) - 1:
                    assert n <= m
                    o = m
                else:
                    o = min(m, n)

                var += o / self._chunk_size * target_vars[i]
                m -= o
                n -= o

                if n == 0:
                    i += 1
                    if i < len(target_vars):
                        n = np.prod(self.target_shapes[i])

            chunk_vars.append(var)

        if inp_dim > 0:
            max_inp_var = (inp_dim + c_dim) / inp_dim * min(chunk_vars)
            max_inp_std = math.sqrt(max_inp_var)
            print('Initializing hypernet with Chunked Hyperfan Init ...')
            if inp_var >= max_inp_var:
                warn('Note, hypernetwork inputs should have an initial total ' +
                     'variance (std) smaller than %f (%f) in order for this ' \
                     % (max_inp_var, max_inp_std) + 'method to work properly.')

        ### Compute variances of chunk embeddings ###
        # We could have done that in the previous loop. But I think the code is
        # more readable this way.
        c_vars = []
        n_clipped = 0
        for i, var in enumerate(chunk_vars):
            c_var = 1. / c_dim * ((inp_dim + c_dim) * var - inp_dim * inp_var)
            if c_var < eps:
                n_clipped += 1
                #warn('Initial variance of chunk embedding %d has to ' % i + \
                #     'be clipped.')

            c_vars.append(max(eps, c_var))

        if n_clipped > 0:
            warn('Initial variance of %d/%d ' % (n_clipped, len(chunk_vars)) + \
                 'chunk embeddings had to be clipped.')

        ### Initialize chunk embeddings ###
        for i in range(self.num_chunks):
            c_std = math.sqrt(c_vars[i])

            num_conds = self.num_known_conds if self._cond_chunk_embs else 1
            for j in range(num_conds):
                cond_id = j if self._cond_chunk_embs else None

                c_emb = self.get_chunk_emb(chunk_id=i, cond_id=cond_id)

                if cemb_normal_init:
                    torch.nn.init.normal_(c_emb, mean=0, std=c_std)
                else:
                    a = math.sqrt(3.0) * c_std
                    torch.nn.init._no_grad_uniform_(c_emb, -a, a)