Esempio n. 1
0
def get_hypernet(config,
                 device,
                 net_type,
                 target_shapes,
                 num_conds,
                 no_cond_weights=False,
                 no_uncond_weights=False,
                 uncond_in_size=0,
                 shmlp_chunk_shapes=None,
                 shmlp_num_per_chunk=None,
                 shmlp_assembly_fct=None,
                 verbose=True,
                 cprefix=None):
    """Generate a hypernetwork instance.

    A helper to generate the hypernetwork according to the given user
    configurations.

    Options that are not present in ``config`` are not passed explicitly;
    instead, the current default value of the corresponding constructor
    argument is used.

    Args:
        config (argparse.Namespace): Command-line arguments.

            Note:
                The function expects command-line arguments available according
                to the function :func:`utils.cli_args.hnet_args`.
        device: PyTorch device.
        net_type (str): The type of network. The following options are
            available:

            - ``'hmlp'``
            - ``'chunked_hmlp'``
            - ``'structured_hmlp'``
            - ``'hdeconv'`` (not implemented yet)
            - ``'chunked_hdeconv'`` (not implemented yet)
        target_shapes (list): See argument ``target_shapes`` of
            :class:`hnets.mlp_hnet.HMLP`.
        num_conds (int): Number of conditions that should be known to the
            hypernetwork.
        no_cond_weights (bool): See argument ``no_cond_weights`` of
            :class:`hnets.mlp_hnet.HMLP`.
        no_uncond_weights (bool): See argument ``no_uncond_weights`` of
            :class:`hnets.mlp_hnet.HMLP`.
        uncond_in_size (int): See argument ``uncond_in_size`` of
            :class:`hnets.mlp_hnet.HMLP`.
        shmlp_chunk_shapes (list, optional): Argument ``chunk_shapes`` of
            :class:`hnets.structured_mlp_hnet.StructuredHMLP`.
        shmlp_num_per_chunk (list, optional): Argument ``num_per_chunk`` of
            :class:`hnets.structured_mlp_hnet.StructuredHMLP`.
        shmlp_assembly_fct (func, optional): Argument ``assembly_fct`` of
            :class:`hnets.structured_mlp_hnet.StructuredHMLP`.
        verbose (bool): Argument ``verbose`` of :class:`hnets.mlp_hnet.HMLP`.
        cprefix (str, optional): A prefix of the config names. It might be, that
            the config names used in this function are prefixed, since several
            hypernetworks should be generated.

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.hnet_args`.

    Returns:
        The generated hypernetwork (an instance of ``HMLP``, ``ChunkedHMLP``
        or ``StructuredHMLP``, depending on ``net_type``), moved to ``device``.

    Raises:
        ValueError: If option ``hmlp_arch`` contains semicolons or option
            ``chunk_emb_size`` contains multiple values while ``net_type`` is
            not ``'structured_hmlp'``.
        NotImplementedError: For the network types ``'hdeconv'`` and
            ``'chunked_hdeconv'``.
    """
    assert net_type in [
        'hmlp', 'chunked_hmlp', 'structured_hmlp', 'hdeconv', 'chunked_hdeconv'
    ]

    hnet = None

    ### FIXME Code almost identically copied from `get_mnet_model` ###
    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    def hc(name):
        """Check whether config exists."""
        return hasattr(config, '%s%s' % (cprefix, name))

    def get_val(name):
        """Get config value with that name (``None`` if it doesn't exist)."""
        return gc(name) if hc(name) else None

    if hc('hnet_net_act'):
        net_act = misc.str_to_act(gc('hnet_net_act'))
    else:
        net_act = None

    no_bias = get_val('hnet_no_bias')
    dropout_rate = get_val('hnet_dropout_rate')
    specnorm = get_val('hnet_specnorm')
    batchnorm = get_val('hnet_batchnorm')
    no_batchnorm = get_val('hnet_no_batchnorm')

    # Note, `not None` evaluates to `True`. Hence, the negation may only be
    # applied if the option was actually specified. Otherwise the constructor
    # default could never be selected by `assign` below.
    use_bias = None if no_bias is None else not no_bias

    use_bn = None
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm

    def assign(x, y):
        """Return ``x`` unless it is ``None``, in which case ``y`` is used.

        If an argument wasn't specified, then we use the default value that
        is currently in the constructor.
        """
        return y if x is None else x

    def shared_kwargs(dkws):
        """Keyword arguments every (internal) ``HMLP`` receives.

        Args:
            dkws (dict): Default keyword arguments of the constructor that
                will consume the returned dictionary.
        """
        return {
            'activation_fn': assign(net_act, dkws['activation_fn']),
            'use_bias': assign(use_bias, dkws['use_bias']),
            'dropout_rate': assign(dropout_rate, dkws['dropout_rate']),
            'use_spectral_norm': assign(specnorm, dkws['use_spectral_norm']),
            'use_batch_norm': assign(use_bn, dkws['use_batch_norm'])
        }
    ### FIXME Code copied until here                               ###

    # Parse the architecture option. A semicolon-separated list of
    # architectures is only meaningful for a `StructuredHMLP`, which may use
    # a different architecture per internal hypernetwork.
    hmlp_arch = None
    hmlp_arch_is_list = False
    if hc('hmlp_arch'):
        hmlp_arch = gc('hmlp_arch')
        if ';' in hmlp_arch:
            hmlp_arch_is_list = True
            if net_type != 'structured_hmlp':
                raise ValueError('Option "%shmlp_arch" may only ' % (cprefix) +
                                 'contain semicolons for network type ' +
                                 '"structured_hmlp"!')
            hmlp_arch = [misc.str_to_ints(ar) for ar in hmlp_arch.split(';')]
        else:
            hmlp_arch = misc.str_to_ints(hmlp_arch)

    # Parse the chunk embedding size(s). A single value is unpacked to an
    # integer, multiple values are only allowed for a `StructuredHMLP`.
    chunk_emb_size = None
    if hc('chunk_emb_size'):
        chunk_emb_size = misc.str_to_ints(gc('chunk_emb_size'))
        if len(chunk_emb_size) == 1:
            chunk_emb_size = chunk_emb_size[0]
        elif net_type != 'structured_hmlp':
            raise ValueError('Option "%schunk_emb_size" may ' % (cprefix) +
                             'only contain multiple values for network ' +
                             'type "structured_hmlp"!')

    cond_emb_size = gc('cond_emb_size') if hc('cond_emb_size') else 0

    if net_type == 'hmlp':
        assert hc('hmlp_arch')

        # Default keyword arguments of class HMLP.
        dkws = misc.get_default_args(HMLP.__init__)

        hnet = HMLP(target_shapes,
                    uncond_in_size=uncond_in_size,
                    cond_in_size=cond_emb_size,
                    layers=hmlp_arch,
                    verbose=verbose,
                    no_uncond_weights=no_uncond_weights,
                    no_cond_weights=no_cond_weights,
                    num_cond_embs=num_conds,
                    **shared_kwargs(dkws)).to(device)

    elif net_type == 'chunked_hmlp':
        assert hc('hmlp_arch')
        assert hc('chmlp_chunk_size')
        assert hc('chunk_emb_size')
        cond_chunk_embs = get_val('use_cond_chunk_embs')

        # Default keyword arguments of class ChunkedHMLP.
        dkws = misc.get_default_args(ChunkedHMLP.__init__)

        hnet = ChunkedHMLP(
            target_shapes,
            gc('chmlp_chunk_size'),
            chunk_emb_size=chunk_emb_size,
            cond_chunk_embs=assign(cond_chunk_embs, dkws['cond_chunk_embs']),
            uncond_in_size=uncond_in_size,
            cond_in_size=cond_emb_size,
            layers=hmlp_arch,
            verbose=verbose,
            no_uncond_weights=no_uncond_weights,
            no_cond_weights=no_cond_weights,
            num_cond_embs=num_conds,
            **shared_kwargs(dkws)).to(device)

    elif net_type == 'structured_hmlp':
        assert hc('hmlp_arch')
        assert hc('chunk_emb_size')
        cond_chunk_embs = get_val('use_cond_chunk_embs')

        assert shmlp_chunk_shapes is not None and \
            shmlp_num_per_chunk is not None and \
            shmlp_assembly_fct is not None

        # Default keyword arguments of class StructuredHMLP.
        dkws = misc.get_default_args(StructuredHMLP.__init__)
        dkws_hmlp = misc.get_default_args(HMLP.__init__)

        if not hmlp_arch_is_list:
            hmlp_arch = [hmlp_arch]

        # One set of keyword arguments per specified architecture.
        shmlp_hmlp_kwargs = []
        for arch in hmlp_arch:
            hmlp_kwargs = shared_kwargs(dkws_hmlp)
            hmlp_kwargs['layers'] = arch
            shmlp_hmlp_kwargs.append(hmlp_kwargs)
        # A single dictionary will be used for all internal hypernetworks.
        if len(shmlp_hmlp_kwargs) == 1:
            shmlp_hmlp_kwargs = shmlp_hmlp_kwargs[0]

        hnet = StructuredHMLP(target_shapes,
                              shmlp_chunk_shapes,
                              shmlp_num_per_chunk,
                              chunk_emb_size,
                              shmlp_hmlp_kwargs,
                              shmlp_assembly_fct,
                              cond_chunk_embs=assign(cond_chunk_embs,
                                                     dkws['cond_chunk_embs']),
                              uncond_in_size=uncond_in_size,
                              cond_in_size=cond_emb_size,
                              verbose=verbose,
                              no_uncond_weights=no_uncond_weights,
                              no_cond_weights=no_cond_weights,
                              num_cond_embs=num_conds).to(device)

    elif net_type == 'hdeconv':
        #HDeconv
        raise NotImplementedError
    else:
        assert net_type == 'chunked_hdeconv'
        #ChunkedHDeconv
        raise NotImplementedError

    return hnet
Esempio n. 2
0
    def __init__(self,
                 target_shapes,
                 chunk_shapes,
                 num_per_chunk,
                 chunk_emb_sizes,
                 hmlp_kwargs,
                 assembly_fct,
                 cond_chunk_embs=False,
                 uncond_in_size=0,
                 cond_in_size=8,
                 verbose=True,
                 no_uncond_weights=False,
                 no_cond_weights=False,
                 num_cond_embs=1):
        """Set up one internal ``HMLP`` per chunk type plus chunk embeddings.

        Args:
            target_shapes (list): Shapes of the weights this network produces.
            chunk_shapes (list): One non-empty list of shapes per internal
                hypernetwork.
            num_per_chunk (list): Number of chunks each internal hypernetwork
                has to produce. Must have the same length as ``chunk_shapes``
                and may not contain zeros.
            chunk_emb_sizes (int or list): Chunk embedding size per internal
                hypernetwork. A single integer is used for all internal
                hypernetworks. A size of ``0`` is only allowed if the
                corresponding entry of ``num_per_chunk`` is ``1``.
            hmlp_kwargs (dict or list): Keyword arguments passed to the
                constructor of the internal ``HMLP`` instances. A single
                dictionary is used for all internal hypernetworks. The passed
                dictionaries are not modified.
            assembly_fct (func): Stored in ``self._assembly_fct``.
            cond_chunk_embs (bool): Whether chunk embeddings are conditional,
                i.e., whether one set of chunk embeddings is created per
                condition.
            uncond_in_size (int): Size of the unconditional input.
            cond_in_size (int): Size of the conditional input.
            verbose (bool): Whether information about the created network is
                printed.
            no_uncond_weights (bool): Passed to the internal ``HMLP``
                instances. If ``True``, unconditional chunk embeddings are
                not maintained internally.
            no_cond_weights (bool): Passed to the first internal ``HMLP``. If
                ``True``, conditional chunk embeddings are not maintained
                internally.
            num_cond_embs (int): Number of conditions known to this network.

        Raises:
            ValueError: If the provided options are inconsistent.
            NotImplementedError: If an internal hypernetwork requires
                parameter distillation.
        """
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        HyperNetInterface.__init__(self)

        ### Basic checks for user inputs ###
        assert isinstance(chunk_shapes,
                          (list, tuple)) and len(chunk_shapes) > 0
        num_chunk_weights = 0
        for chunk in chunk_shapes:  # Each chunk is a list of shapes!
            assert isinstance(chunk, (list, tuple)) and len(chunk) > 0
            num_chunk_weights += StructuredHMLP.shapes_to_num_weights(chunk)
        num_trgt_weights = StructuredHMLP.shapes_to_num_weights(target_shapes)
        if num_trgt_weights > num_chunk_weights:
            # TODO Should we display a warning? The user might actively want
            # to reuse the same weights in the target network. In the end, the
            # user should be completely free on how he assembles the chunks to
            # weights within the `assembly_fct`.
            pass
        assert isinstance(num_per_chunk, (list, tuple)) and \
            len(num_per_chunk) == len(chunk_shapes)
        if 0 in num_per_chunk:
            raise ValueError('Option "num_per_chunk" may not contains 0s. ' +
                             'Each internal hypernetwork must create at ' +
                             'least one chunk!')
        assert isinstance(chunk_emb_sizes, (int, list, tuple))
        if isinstance(chunk_emb_sizes, int):
            chunk_emb_sizes = [chunk_emb_sizes] * len(chunk_shapes)
        assert len(chunk_emb_sizes) == len(chunk_shapes)
        if 0 in chunk_emb_sizes and uncond_in_size == 0 and cond_in_size == 0:
            raise ValueError('Argument "chunk_emb_sizes" may not contain ' +
                             '0s if "uncond_in_size" and "cond_in_size" are ' +
                             '0!')
        for i, s in enumerate(chunk_emb_sizes):
            if s == 0 and num_per_chunk[i] != 1:
                raise ValueError('Option "chunk_emb_sizes" may only contain ' +
                                 'zeroes if the corresponding entry in ' +
                                 '"num_per_chunk" is 1.')
        assert isinstance(hmlp_kwargs, (dict, list, tuple))
        if isinstance(hmlp_kwargs, dict):
            hmlp_kwargs = [hmlp_kwargs] * len(chunk_shapes)
        assert len(hmlp_kwargs) == len(chunk_shapes)
        forbidden = (
            'uncond_in_size', 'cond_in_size', 'no_uncond_weights',
            'no_cond_weights', 'num_cond_embs'
        )
        checked_kwargs = []
        for hkwargs in hmlp_kwargs:
            assert isinstance(hkwargs, dict)
            for kw in forbidden:
                if kw in hkwargs:
                    raise ValueError('Key %s may not be passed with argument ' \
                                     % kw + '"hmlp_kwargs"!')

            # Work on a copy, such that the dictionaries provided by the
            # caller are never modified.
            hkwargs = dict(hkwargs)
            hkwargs.setdefault('verbose', False)
            checked_kwargs.append(hkwargs)
        hmlp_kwargs = checked_kwargs

        ### Make constructor arguments internally available ###
        self._chunk_shapes = chunk_shapes
        self._num_per_chunk = num_per_chunk
        self._chunk_emb_sizes = chunk_emb_sizes
        #self._hkwargs = hkwargs
        self._assembly_fct = assembly_fct
        self._cond_chunk_embs = cond_chunk_embs
        self._uncond_in_size = uncond_in_size
        self._cond_in_size = cond_in_size
        self._no_uncond_weights = no_uncond_weights
        self._no_cond_weights = no_cond_weights
        self._num_cond_embs = num_cond_embs

        ### Create underlying full hypernets ###
        num_hnets = len(chunk_shapes)
        self._hnets = []

        for i in range(num_hnets):
            # Note, even if chunk embeddings are considered conditional, they
            # are maintained in this object and just fed as an external input
            # to the underlying hnet.
            hnet_uncond_in_size = uncond_in_size + chunk_emb_sizes[i]

            # Conditional inputs (`cond_in_size`) will be maintained by the
            # first internal hypernetwork.
            if i == 0:
                hnet_no_cond_weights = no_cond_weights
                hnet_num_cond_embs = num_cond_embs
                if cond_chunk_embs and cond_in_size == 0:
                    # If there are no other conditional embeddings except the
                    # chunk embeddings, we tell the first underlying hnet
                    # explicitly that it doesn't need to maintain any
                    # conditional weights to avoid that it will throw a warning.
                    hnet_num_cond_embs = 0
            else:
                # All other hypernetworks will be passed the conditional
                # embeddings from the first hypernet as input.
                hnet_no_cond_weights = True
                hnet_num_cond_embs = 0

            self._hnets.append(
                HMLP(chunk_shapes[i],
                     uncond_in_size=hnet_uncond_in_size,
                     cond_in_size=cond_in_size,
                     no_uncond_weights=no_uncond_weights,
                     no_cond_weights=hnet_no_cond_weights,
                     num_cond_embs=hnet_num_cond_embs,
                     **hmlp_kwargs[i]))

        ### Setup attributes required by interface ###
        # Most of these attributes are taken over from the internally
        # maintained hypernetworks.
        self._target_shapes = target_shapes
        self._num_known_conds = self._num_cond_embs

        # As we just append the weights of the internal hypernets we will have
        # output weights all over the place.
        # Additionally, it would be complicated to assign outputs to target
        # outputs, as we do not know, what is happening in the `assembly_fct`.
        # Also, keep in mind that we will append chunk embeddings at the end
        # of `param_shapes`.
        self._mask_fc_out = False

        self._unconditional_param_shapes_ref = []
        self._param_shapes = []
        self._param_shapes_meta = []
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        for i, hnet in enumerate(self._hnets):
            # Note, it is important to convert lists into new object and not
            # just copy references!
            # Note, we have to adapt all references if `i > 0`.

            ps_len_old = len(self._param_shapes)
            # Offset of this hnet's parameters within `_internal_params`.
            # It has to be determined for every hnet (before the list is
            # extended), even if this hnet doesn't maintain any parameters
            # internally.
            ip_len_old = 0
            if self._internal_params is not None:
                ip_len_old = len(self._internal_params)

            for ref in hnet._unconditional_param_shapes_ref:
                self._unconditional_param_shapes_ref.append(ref + ps_len_old)

            if hnet._internal_params is not None:
                if self._internal_params is None:
                    self._internal_params = nn.ParameterList()
                self._internal_params.extend( \
                    nn.ParameterList(hnet._internal_params))
            self._param_shapes.extend(list(hnet._param_shapes))
            for meta in hnet.param_shapes_meta:
                assert 'hnet_ind' not in meta.keys()
                assert 'layer' in meta.keys()
                assert 'index' in meta.keys()
                new_meta = dict(meta)
                new_meta['hnet_ind'] = i
                if i > 0:
                    # FIXME We should properly adjust colliding `layer` IDs.
                    new_meta['layer'] = -1
                # An index of -1 marks parameters that are not internally
                # maintained. This marker must not be shifted.
                if meta['index'] != -1:
                    new_meta['index'] = meta['index'] + ip_len_old
                self._param_shapes_meta.append(new_meta)

            if hnet._hyper_shapes_learned is not None:
                if self._hyper_shapes_learned is None:
                    self._hyper_shapes_learned = []
                    self._hyper_shapes_learned_ref = []
                self._hyper_shapes_learned.extend( \
                    list(hnet._hyper_shapes_learned))
                for ref in hnet._hyper_shapes_learned_ref:
                    self._hyper_shapes_learned_ref.append(ref + ps_len_old)
            if hnet._hyper_shapes_distilled is not None:
                if self._hyper_shapes_distilled is None:
                    self._hyper_shapes_distilled = []
                self._hyper_shapes_distilled.extend( \
                    list(hnet._hyper_shapes_distilled))

            if self._has_bias is None:
                self._has_bias = hnet._has_bias
            elif self._has_bias != hnet._has_bias:
                self._has_bias = False
                # FIXME We should overwrite the getter and throw an error!
                warn('Some internally maintained hypernetworks use biases, ' +
                     'while others don\'t. Setting attribute "has_bias" to ' +
                     'False.')

            if self._has_fc_out is None:
                self._has_fc_out = hnet._has_fc_out
            else:
                assert self._has_fc_out == hnet._has_fc_out

            if self._has_linear_out is None:
                self._has_linear_out = hnet._has_linear_out
            else:
                assert self._has_linear_out == hnet._has_linear_out

            self._layer_weight_tensors.extend( \
                nn.ParameterList(hnet._layer_weight_tensors))
            self._layer_bias_vectors.extend( \
                nn.ParameterList(hnet._layer_bias_vectors))
            if hnet._batchnorm_layers is not None:
                if self._batchnorm_layers is None:
                    self._batchnorm_layers = nn.ModuleList()
                self._batchnorm_layers.extend( \
                    nn.ModuleList(hnet._batchnorm_layers))
            if hnet._context_mod_layers is not None:
                if self._context_mod_layers is None:
                    self._context_mod_layers = nn.ModuleList()
                self._context_mod_layers.extend( \
                    nn.ModuleList(hnet._context_mod_layers))

        if self._hyper_shapes_distilled is not None:
            raise NotImplementedError('Distillation of parameters not ' +
                                      'supported yet!')

        ### Create chunk embeddings ###
        # Note, the case that an internal hypernetwork has neither chunk
        # embeddings nor any other input has already been rejected by the
        # input checks above.
        if cond_in_size == 0 and uncond_in_size == 0 and not cond_chunk_embs:
            # Note, we could also allow this case. It would be analoguous to
            # creating a full hypernet with no unconditional input and one
            # conditional embedding. But the user can explicitly achieve that
            # as noted below.
            raise ValueError('If no external (conditional or unconditional) ' +
                             'input is provided to the hypernetwork, then ' +
                             'it can only learn a fixed output. If this ' +
                             'behavior is desired, please enable ' +
                             '"cond_chunk_embs" and set "num_cond_embs=1".')

        chunk_emb_shapes = []
        # To which internal hnet does the corresponding chunk shape belong to.
        chunk_emb_refs = []
        for i, size in enumerate(chunk_emb_sizes):
            if size == 0:
                # No chunk embeddings for internal hnet `i`.
                continue

            chunk_emb_refs.append(i)
            assert num_per_chunk[i] > 0
            chunk_emb_shapes.append([num_per_chunk[i], size])
        self._chunk_emb_shapes = chunk_emb_shapes
        self._chunk_emb_refs = chunk_emb_refs

        # How often do we have to instantiate the chunk embeddings prescribed by
        # `chunk_emb_shapes`?
        num_cemb_weights = 1
        no_cemb_weights = no_uncond_weights
        if cond_chunk_embs:
            num_cemb_weights = num_cond_embs
            no_cemb_weights = no_cond_weights

        # Number of conditional and unconditional parameters so far.
        tmp_num_uncond = len(self._unconditional_param_shapes_ref)
        tmp_num_cond = len(self._param_shapes) - tmp_num_uncond

        # List of lists of inds.
        # Indices of chunk embedding per condition within
        # `conditional_param_shapes`, if chunk embeddings are conditional.
        # Otherwise, indices of chunk embeddings within
        # `unconditional_param_shapes`.
        self._chunk_emb_inds = [[] for _ in range(num_cemb_weights)]

        for i in range(num_cemb_weights):
            for j, shape in enumerate(chunk_emb_shapes):
                if not no_cemb_weights:
                    # The internal hnets might not maintain any parameters,
                    # in which case the container doesn't exist yet.
                    if self._internal_params is None:
                        self._internal_params = nn.ParameterList()
                    self._internal_params.append(nn.Parameter( \
                        data=torch.Tensor(*shape), requires_grad=True))
                    torch.nn.init.normal_(self._internal_params[-1],
                                          mean=0.,
                                          std=1.)
                else:
                    # The internal hnets might maintain all of their
                    # parameters, in which case the lists don't exist yet.
                    if self._hyper_shapes_learned is None:
                        self._hyper_shapes_learned = []
                        self._hyper_shapes_learned_ref = []
                    self._hyper_shapes_learned.append(shape)
                    self._hyper_shapes_learned_ref.append( \
                        len(self.param_shapes))

                if not cond_chunk_embs:
                    self._unconditional_param_shapes_ref.append( \
                        len(self.param_shapes))

                self._param_shapes.append(shape)
                # In principle, these embeddings also belong to the input, so we
                # just assign them as "layer" 0 (note, the underlying hnets use
                # the same layer ID for its embeddings.
                self._param_shapes_meta.append({
                    'name': 'embedding',
                    'index': -1 if no_cemb_weights else \
                        len(self._internal_params)-1,
                    'layer': 0,
                    'info': 'chunk embeddings',
                    'hnet_ind': chunk_emb_refs[j],
                    'cond_id': i if cond_chunk_embs else -1
                })

                if cond_chunk_embs:
                    self._chunk_emb_inds[i].append(tmp_num_cond)
                    tmp_num_cond += 1
                else:
                    self._chunk_emb_inds[i].append(tmp_num_uncond)
                    tmp_num_uncond += 1

        assert len(self.param_shapes) == tmp_num_uncond + tmp_num_cond

        ### Finalize construction ###
        self._is_properly_setup()

        if verbose:
            print('Created Structured Chunked MLP Hypernet.')
            print('It manages %d full hypernetworks internally that produce ' \
                  % (num_hnets) + '%s chunks in total.' % (self.num_chunks))
            print('The internal hypernetworks have a combined output size of ' +
                  '%d compared to %d weights produced by this network.' \
                  % (num_chunk_weights, self.num_outputs))
            print(self)
    def __init__(self,
                 target_shapes,
                 chunk_size,
                 chunk_emb_size=8,
                 cond_chunk_embs=False,
                 uncond_in_size=0,
                 cond_in_size=8,
                 layers=(100, 100),
                 verbose=True,
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True,
                 no_uncond_weights=False,
                 no_cond_weights=False,
                 num_cond_embs=1,
                 dropout_rate=-1,
                 use_spectral_norm=False,
                 use_batch_norm=False):
        """Set up the internal ``HMLP`` and the chunk embeddings.

        A single ``HMLP`` with output size ``chunk_size`` is created, whose
        unconditional input is extended by ``chunk_emb_size``. Its attributes
        are taken over by this object and the chunk embeddings are appended
        at the end of ``param_shapes``.

        Args:
            target_shapes (list): Shapes of the weights this network produces.
            chunk_size (int): Output size of the internal ``HMLP``. Must be
                positive.
            chunk_emb_size (int): Size of a chunk embedding. Must be positive.
            cond_chunk_embs (bool): Whether one set of chunk embeddings is
                created per condition (``num_cond_embs``) rather than a single
                unconditional set.
            uncond_in_size (int): Size of the unconditional input.
            cond_in_size (int): Size of the conditional input.
            verbose (bool): Whether information about the created network is
                printed.
            no_uncond_weights (bool): Passed to the internal ``HMLP``. If
                ``True``, unconditional chunk embeddings are not maintained
                internally.
            no_cond_weights (bool): Passed to the internal ``HMLP``. If
                ``True``, conditional chunk embeddings are not maintained
                internally.
            num_cond_embs (int): Number of conditions known to this network.
            layers, activation_fn, use_bias, dropout_rate, use_spectral_norm,
                use_batch_norm: Passed on unchanged to the internal ``HMLP``.

        Raises:
            ValueError: If ``cond_chunk_embs`` is set but ``num_cond_embs`` is
                ``0``, or if the network would not receive any input.
        """
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        HyperNetInterface.__init__(self)

        assert isinstance(chunk_size, int) and chunk_size > 0
        assert isinstance(chunk_emb_size, int) and chunk_emb_size > 0

        ### Make constructor arguments internally available ###
        self._chunk_size = chunk_size
        self._chunk_emb_size = chunk_emb_size
        self._cond_chunk_embs = cond_chunk_embs
        self._uncond_in_size = uncond_in_size
        self._cond_in_size = cond_in_size
        self._no_uncond_weights = no_uncond_weights
        self._no_cond_weights = no_cond_weights
        self._num_cond_embs = num_cond_embs

        ### Create underlying full hypernet ###
        # Note, even if chunk embeddings are considered conditional, they
        # are maintained in this object and just fed as an external input to the
        # underlying hnet.
        hnet_uncond_in_size = uncond_in_size + chunk_emb_size
        hnet_num_cond_embs = num_cond_embs
        if cond_chunk_embs and num_cond_embs == 0:
            raise ValueError('Conditional chunk embeddings can only be used ' +
                             'if conditions are known to the hypernetwork!')
        if cond_chunk_embs and cond_in_size == 0:
            # If there are no other conditional embeddings except the chunk
            # embeddings, we tell the underlying hnet explicitly that it doesn't
            # need to maintain any conditional weights to avoid that it will
            # throw a warning.
            hnet_num_cond_embs = 0
        self._hnet = HMLP([[chunk_size]],
                          uncond_in_size=hnet_uncond_in_size,
                          cond_in_size=cond_in_size,
                          layers=layers,
                          verbose=False,
                          activation_fn=activation_fn,
                          use_bias=use_bias,
                          no_uncond_weights=no_uncond_weights,
                          no_cond_weights=no_cond_weights,
                          num_cond_embs=hnet_num_cond_embs,
                          dropout_rate=dropout_rate,
                          use_spectral_norm=use_spectral_norm,
                          use_batch_norm=use_batch_norm)

        ### Setup attributes required by interface ###
        # Most of these attributes are taken over from `self._hnet`
        self._target_shapes = target_shapes
        self._num_known_conds = self._num_cond_embs
        self._unconditional_param_shapes_ref = \
            list(self._hnet._unconditional_param_shapes_ref)

        if self._hnet._internal_params is not None:
            self._internal_params = \
                nn.ParameterList(self._hnet._internal_params)
        self._param_shapes = list(self._hnet._param_shapes)
        self._param_shapes_meta = list(self._hnet._param_shapes_meta)
        if self._hnet._hyper_shapes_learned is not None:
            self._hyper_shapes_learned = list(self._hnet._hyper_shapes_learned)
            self._hyper_shapes_learned_ref = \
                list(self._hnet._hyper_shapes_learned_ref)
        if self._hnet._hyper_shapes_distilled is not None:
            self._hyper_shapes_distilled = \
                list(self._hnet._hyper_shapes_distilled)
        self._has_bias = self._hnet._has_bias
        self._has_fc_out = self._hnet._has_fc_out
        # Just to make that clear explicitly. We will additionally append
        # the chunk embeddings at the end of `param_shapes`.
        # We don't prepend it to the beginning, to keep conditional input
        # embeddings at the beginning.
        self._mask_fc_out = False
        self._has_linear_out = self._hnet._has_linear_out
        self._layer_weight_tensors = \
            nn.ParameterList(self._hnet._layer_weight_tensors)
        self._layer_bias_vectors = \
            nn.ParameterList(self._hnet._layer_bias_vectors)
        if self._hnet._batchnorm_layers is not None:
            self._batchnorm_layers = nn.ModuleList(
                self._hnet._batchnorm_layers)
        if self._hnet._context_mod_layers is not None:
            self._context_mod_layers = \
                nn.ModuleList(self._hnet._context_mod_layers)

        ### Create chunk embeddings ###
        if cond_in_size == 0 and uncond_in_size == 0 and not cond_chunk_embs:
            # Note, we could also allow this case. It would be analoguous to
            # creating a full hypernet with no unconditional input and one
            # conditional embedding. But the user can explicitly achieve that
            # as noted below.
            raise ValueError('If no external (conditional or unconditional) ' +
                             'input is provided to the hypernetwork, then ' +
                             'it can only learn a fixed output. If this ' +
                             'behavior is desired, please enable ' +
                             '"cond_chunk_embs" and set "num_cond_embs=1".')

        # How many chunk embedding matrices have to be created, and are they
        # maintained internally?
        num_cemb_mats = 1
        no_cemb_weights = no_uncond_weights
        if cond_chunk_embs:
            num_cemb_mats = num_cond_embs
            no_cemb_weights = no_cond_weights

        self._cemb_shape = [self.num_chunks, chunk_emb_size]

        for _ in range(num_cemb_mats):
            if not no_cemb_weights:
                # The underlying hnet might not maintain any parameters, in
                # which case the container has not been created above.
                if self._internal_params is None:
                    self._internal_params = nn.ParameterList()
                self._internal_params.append(nn.Parameter( \
                    data=torch.Tensor(*self._cemb_shape), requires_grad=True))
                torch.nn.init.normal_(self._internal_params[-1],
                                      mean=0.,
                                      std=1.)
            else:
                # The underlying hnet might maintain all of its parameters,
                # in which case the lists have not been created above.
                if self._hyper_shapes_learned is None:
                    self._hyper_shapes_learned = []
                    self._hyper_shapes_learned_ref = []
                self._hyper_shapes_learned.append(self._cemb_shape)
                self._hyper_shapes_learned_ref.append(len(self.param_shapes))

            if not cond_chunk_embs:
                self._unconditional_param_shapes_ref.append( \
                    len(self.param_shapes))

            self._param_shapes.append(self._cemb_shape)
            # In principle, these embeddings also belong to the input, so we
            # just assign them as "layer" 0 (note, the underlying hnet uses the
            # same layer ID for its embeddings.
            self._param_shapes_meta.append({
                'name': 'embedding',
                'index': -1 if no_cemb_weights else \
                    len(self._internal_params)-1,
                'layer': 0,
                'info': 'chunk embeddings'
            })

        ### Finalize construction ###
        self._is_properly_setup()

        if verbose:
            print('Created Chunked MLP Hypernet with %d chunk(s) of size %d.' \
                  % (self.num_chunks, chunk_size))
            print(self)
class ChunkedHMLP(nn.Module, HyperNetInterface):
    """Implementation of a `chunked fully-connected hypernet`.

    The ``target_shapes`` will be flattened and split into chunks of size
    ``chunk_size``. In total, there will be
    ``np.ceil(self.num_outputs/chunk_size)`` chunks, where the last chunk
    produced might contain a remainder that is discarded.

    Each chunk has its own `chunk embedding` that is fed into the underlying
    hypernetwork.

    Note:
        It is possible to set ``uncond_in_size`` and ``cond_in_size`` to zero
        if ``cond_chunk_embs`` is ``True``.

    Attributes:
        num_chunks (int): The number of chunks that make up the final hypernet
            output. This also corresponds to the number of chunk embeddings
            required per forward sweep.
        chunk_emb_size (int): See constructor argument ``chunk_emb_size``.

    Args:
        (....): See constructor arguments of class
            :class:`hnets.mlp_hnet.HMLP`.
        chunk_size (int): The chunk size, i.e., the number of weights produced
            by the single internally maintained instance of a full hypernet
            (see :class:`hnets.mlp_hnet.HMLP`) at a time (i.e., per chunk
            embedding).
        chunk_emb_size (int): The size of a chunk embedding.
        cond_chunk_embs (bool): Whether chunk embeddings are unconditional
            (``False``) or conditional (``True``) parameters. See constructor
            argument ``cond_chunk_embs``.

            Note:
                Embeddings will be initialized with a normal distribution using
                zero mean and unit variance.
        cond_chunk_embs (bool): Consider chunk embeddings to be conditional.
            In this case, there will be a different set of chunk embeddings per
            condition (specified via ``num_cond_embs``).

            If ``False``, there will be a total of :attr:`num_chunks` chunk
            embeddings that are maintained within :attr:`hnets.hnet_interface.\
HyperNetInterface.unconditional_param_shapes`. If ``True``, there will be
            ``num_cond_embs * self.num_chunks`` chunk embeddings that are
            maintained within :attr:`hnets.hnet_interface.\
HyperNetInterface.conditional_param_shapes`. However, if ``num_cond_embs == 0``,
            then chunk embeddings have to be provided in a special way to the
            :meth:`forward` method (see the corresponding argument ``weights``).
    """
    def __init__(self,
                 target_shapes,
                 chunk_size,
                 chunk_emb_size=8,
                 cond_chunk_embs=False,
                 uncond_in_size=0,
                 cond_in_size=8,
                 layers=(100, 100),
                 verbose=True,
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True,
                 no_uncond_weights=False,
                 no_cond_weights=False,
                 num_cond_embs=1,
                 dropout_rate=-1,
                 use_spectral_norm=False,
                 use_batch_norm=False):
        """Build the underlying full hypernet and the chunk embeddings.

        See the class docstring for a description of all arguments.

        Raises:
            ValueError: If ``cond_chunk_embs`` is ``True`` while
                ``num_cond_embs == 0``, or if ``cond_in_size`` and
                ``uncond_in_size`` are both zero while chunk embeddings are
                unconditional.
        """
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        HyperNetInterface.__init__(self)

        assert isinstance(chunk_size, int) and chunk_size > 0
        assert isinstance(chunk_emb_size, int) and chunk_emb_size > 0

        ### Make constructor arguments internally available ###
        self._chunk_size = chunk_size
        self._chunk_emb_size = chunk_emb_size
        self._cond_chunk_embs = cond_chunk_embs
        self._uncond_in_size = uncond_in_size
        self._cond_in_size = cond_in_size
        self._no_uncond_weights = no_uncond_weights
        self._no_cond_weights = no_cond_weights
        self._num_cond_embs = num_cond_embs

        ### Create underlying full hypernet ###
        # Note, even if chunk embeddings are considered conditional, they
        # are maintained in this object and just fed as an external input to the
        # underlying hnet.
        hnet_uncond_in_size = uncond_in_size + chunk_emb_size
        hnet_num_cond_embs = num_cond_embs
        if cond_chunk_embs and num_cond_embs == 0:
            raise ValueError('Conditional chunk embeddings can only be used ' +
                             'if conditions are known to the hypernetwork!')
        if cond_chunk_embs and cond_in_size == 0:
            # If there are no other conditional embeddings except the chunk
            # embeddings, we tell the underlying hnet explicitly that it doesn't
            # need to maintain any conditional weights to avoid that it will
            # throw a warning.
            hnet_num_cond_embs = 0
        self._hnet = HMLP([[chunk_size]],
                          uncond_in_size=hnet_uncond_in_size,
                          cond_in_size=cond_in_size,
                          layers=layers,
                          verbose=False,
                          activation_fn=activation_fn,
                          use_bias=use_bias,
                          no_uncond_weights=no_uncond_weights,
                          no_cond_weights=no_cond_weights,
                          num_cond_embs=hnet_num_cond_embs,
                          dropout_rate=dropout_rate,
                          use_spectral_norm=use_spectral_norm,
                          use_batch_norm=use_batch_norm)

        ### Setup attributes required by interface ###
        # Most of these attributes are taken over from `self._hnet`
        self._target_shapes = target_shapes
        self._num_known_conds = self._num_cond_embs
        self._unconditional_param_shapes_ref = \
            list(self._hnet._unconditional_param_shapes_ref)

        if self._hnet._internal_params is not None:
            self._internal_params = \
                nn.ParameterList(self._hnet._internal_params)
        self._param_shapes = list(self._hnet._param_shapes)
        self._param_shapes_meta = list(self._hnet._param_shapes_meta)
        if self._hnet._hyper_shapes_learned is not None:
            self._hyper_shapes_learned = list(self._hnet._hyper_shapes_learned)
            self._hyper_shapes_learned_ref = \
                list(self._hnet._hyper_shapes_learned_ref)
        if self._hnet._hyper_shapes_distilled is not None:
            self._hyper_shapes_distilled = \
                list(self._hnet._hyper_shapes_distilled)
        self._has_bias = self._hnet._has_bias
        self._has_fc_out = self._hnet._has_fc_out
        # Just to make that clear explicitly. We will additionally append
        # the chunk embeddings at the end of `param_shapes`.
        # We don't prepend it to the beginning, to keep conditional input
        # embeddings at the beginning.
        self._mask_fc_out = False
        self._has_linear_out = self._hnet._has_linear_out
        self._layer_weight_tensors = \
            nn.ParameterList(self._hnet._layer_weight_tensors)
        self._layer_bias_vectors = \
            nn.ParameterList(self._hnet._layer_bias_vectors)
        if self._hnet._batchnorm_layers is not None:
            self._batchnorm_layers = nn.ModuleList(
                self._hnet._batchnorm_layers)
        if self._hnet._context_mod_layers is not None:
            self._context_mod_layers = \
                nn.ModuleList(self._hnet._context_mod_layers)

        ### Create chunk embeddings ###
        if cond_in_size == 0 and uncond_in_size == 0 and not cond_chunk_embs:
            # Note, we could also allow this case. It would be analoguous to
            # creating a full hypernet with no unconditional input and one
            # conditional embedding. But the user can explicitly achieve that
            # as noted below.
            raise ValueError('If no external (conditional or unconditional) ' +
                             'input is provided to the hypernetwork, then ' +
                             'it can only learn a fixed output. If this ' +
                             'behavior is desired, please enable ' +
                             '"cond_chunk_embs" and set "num_cond_embs=1".')

        # Unconditional chunk embeddings form a single matrix that follows the
        # `no_uncond_weights` policy; conditional ones form one matrix per
        # condition and follow the `no_cond_weights` policy.
        num_cemb_mats = 1
        no_cemb_weights = no_uncond_weights
        if cond_chunk_embs:
            num_cemb_mats = num_cond_embs
            no_cemb_weights = no_cond_weights

        self._cemb_shape = [self.num_chunks, chunk_emb_size]

        for _ in range(num_cemb_mats):
            if not no_cemb_weights:
                # The containers above are only taken over from the underlying
                # hnet if they exist there. The chunk embeddings may be the
                # only internally maintained parameters, so create the list
                # lazily instead of appending to `None`.
                if self._internal_params is None:
                    self._internal_params = nn.ParameterList()
                self._internal_params.append(nn.Parameter( \
                    data=torch.Tensor(*self._cemb_shape), requires_grad=True))
                torch.nn.init.normal_(self._internal_params[-1],
                                      mean=0.,
                                      std=1.)
            else:
                # Same reasoning as above for externally provided embeddings.
                if self._hyper_shapes_learned is None:
                    self._hyper_shapes_learned = []
                if self._hyper_shapes_learned_ref is None:
                    self._hyper_shapes_learned_ref = []
                self._hyper_shapes_learned.append(self._cemb_shape)
                # Note, the reference index is read BEFORE the shape is
                # appended to `param_shapes` below.
                self._hyper_shapes_learned_ref.append(len(self.param_shapes))

            if not cond_chunk_embs:
                self._unconditional_param_shapes_ref.append( \
                    len(self.param_shapes))

            self._param_shapes.append(self._cemb_shape)
            # In principle, these embeddings also belong to the input, so we
            # just assign them as "layer" 0 (note, the underlying hnet uses the
            # same layer ID for its embeddings.
            self._param_shapes_meta.append({
                'name': 'embedding',
                'index': -1 if no_cemb_weights else \
                    len(self._internal_params)-1,
                'layer': 0,
                'info': 'chunk embeddings'
            })

        ### Finalize construction ###
        self._is_properly_setup()

        if verbose:
            print('Created Chunked MLP Hypernet with %d chunk(s) of size %d.' \
                  % (self.num_chunks, chunk_size))
            print(self)

    @property
    def num_chunks(self):
        """Getter for read-only attribute :attr:`num_chunks`."""
        return int(np.ceil(self.num_outputs / self._chunk_size))

    @property
    def chunk_emb_size(self):
        """Getter for read-only attribute :attr:`chunk_emb_size`."""
        return self._chunk_emb_size

    @property
    def cond_chunk_embs(self):
        """Getter for read-only attribute :attr:`cond_chunk_embs`."""
        return self._cond_chunk_embs

    def forward(self,
                uncond_input=None,
                cond_input=None,
                cond_id=None,
                weights=None,
                distilled_params=None,
                condition=None,
                ret_format='squeezed',
                ext_inputs=None,
                task_emb=None,
                task_id=None,
                theta=None,
                dTheta=None):
        """Compute the weights of a target network.

        Args:
            (....): See docstring of method
                :meth:`hnets.mlp_hnet.HMLP.forward`.
            weights (list or dict, optional): If provided as ``dict`` and
                chunk embeddings are considered conditional (see constructor
                argument ``cond_chunk_embs``), then the additional key
                ``chunk_embs`` can be used to pass a batch of chunk embeddings.
                This option is mutually exclusive with the option of passing
                ``cond_id``. Note, if conditional inputs via ``cond_input`` are
                expected, then the batch sizes must agree.

                A batch of chunk embeddings is expected to be tensor of shape
                ``[B, num_chunks, chunk_emb_size]``, where ``B`` denotes the
                batch size.

                Note:
                    The passed dictionary is not modified by this method.

        Returns:
            (list or torch.Tensor): See docstring of method
            :meth:`hnets.hnet_interface.HyperNetInterface.forward`.

        Raises:
            ValueError: If key ``chunk_embs`` is used although chunk embeddings
                are unconditional, if it is combined with ``cond_id``, if
                ``cond_input`` is given without conditional chunk embeddings
                although they are required, or if an invalid ``cond_id`` is
                passed.
        """
        cond_chunk_embs = None
        if isinstance(weights, dict) and 'chunk_embs' in weights:
            cond_chunk_embs = weights['chunk_embs']
            if not self._cond_chunk_embs:
                raise ValueError('Key "chunk_embs" for argument ' +
                                 '"weights" is only allowed if chunk ' +
                                 'embeddings are conditional.')
            assert len(cond_chunk_embs.shape) == 3 and \
                np.all(np.equal(cond_chunk_embs.shape[1:],
                                [self.num_chunks, self.chunk_emb_size]))

            if cond_id is not None:
                raise ValueError(
                    'Option "cond_id" is mutually exclusive ' +
                    'with key "chunk_embs" for argument ' + '"weights".')
            assert cond_input is None or \
                cond_input.shape[0] == cond_chunk_embs.shape[0]

            # The upper class parser doesn't know how to deal with key
            # `chunk_embs`. We hand it a filtered copy rather than deleting
            # the key in place, such that the caller's dictionary stays
            # intact and can be reused for subsequent calls.
            weights = {k: v for k, v in weights.items() if k != 'chunk_embs'}
            if not weights:  # Empty dictionary.
                weights = None

        if cond_input is not None and self._cond_chunk_embs and \
                cond_chunk_embs is None:
            raise ValueError('Conditional chunk embeddings have to be ' +
                             'provided via "weights" if "cond_input" is ' +
                             'specified.')

        _input_required = self._cond_in_size > 0 or self._uncond_in_size > 0
        # We parse `cond_id` afterwards if chunk embeddings are also
        # conditional.
        if self._cond_chunk_embs:
            _parse_cond_id_fct = lambda x, y, z: None
        else:
            _parse_cond_id_fct = None

        uncond_input, cond_input, uncond_weights, cond_weights = \
            self._preprocess_forward_args(_input_required=_input_required,
                _parse_cond_id_fct=_parse_cond_id_fct,
                uncond_input=uncond_input, cond_input=cond_input,
                cond_id=cond_id, weights=weights,
                distilled_params=distilled_params, condition=condition,
                ret_format=ret_format, ext_inputs=ext_inputs, task_emb=task_emb,
                task_id=task_id, theta=theta, dTheta=dTheta)

        ### Translate IDs to conditional inputs ###
        if cond_id is not None and self._cond_chunk_embs:
            assert cond_input is None and cond_chunk_embs is None
            cond_id = [cond_id] if isinstance(cond_id, int) else cond_id

            if cond_weights is None:
                raise ValueError('Forward option "cond_id" can only be ' +
                                 'used if conditional parameters are ' +
                                 'maintained internally or passed to the ' +
                                 'forward method via option "weights".')

            cond_chunk_embs = []
            cond_input = [] if self._cond_in_size > 0 else None

            for cid in cond_id:
                if cid < 0 or cid >= self._num_cond_embs:
                    raise ValueError('Condition %d not existing!' % (cid))

                # Note, we do not necessarily have conditional embeddings.
                if self._cond_in_size > 0:
                    cond_input.append(cond_weights[cid])

                # The conditional chunk embeddings are the last
                # `num_cond_embs` entries of the conditional weights.
                cond_chunk_embs.append( \
                    cond_weights[-self._num_cond_embs+cid])

            if self._cond_in_size > 0:
                cond_input = torch.stack(cond_input, dim=0)
            cond_chunk_embs = torch.stack(cond_chunk_embs, dim=0)

        ### Assemble hypernetwork input ###
        batch_size = None
        if cond_input is not None:
            batch_size = cond_input.shape[0]
        if cond_chunk_embs is not None:
            assert batch_size is None or batch_size == cond_chunk_embs.shape[0]
            batch_size = cond_chunk_embs.shape[0]
        if uncond_input is not None:
            if batch_size is None:
                batch_size = uncond_input.shape[0]
            else:
                assert batch_size == uncond_input.shape[0]
        assert batch_size is not None

        chunk_embs = None
        if self._cond_chunk_embs:
            assert cond_chunk_embs is not None and \
                len(cond_chunk_embs.shape) == 3
            assert self._cond_in_size == 0 or cond_input is not None
            chunk_embs = cond_chunk_embs
        else:
            assert cond_chunk_embs is None
            # Unconditional chunk embeddings are the last unconditional
            # weight tensor.
            chunk_embs = uncond_weights[-1]
            # Insert batch dimension.
            chunk_embs = chunk_embs.expand(batch_size, self.num_chunks,
                                           self.chunk_emb_size)

        # We now have the following setup:
        # cond_input: [batch_size, cond_in_size] or None
        # uncond_input: [batch_size, uncond_in_size] or None
        # chunk_embs: [batch_size, num_chunks, chunk_emb_size]

        # We now first copy the hypernet inputs for each chunk, arriving at
        # cond_input: [batch_size, num_chunks, cond_in_size] or None
        # uncond_input: [batch_size, num_chunks, uncond_in_size] or None
        if cond_input is not None:
            cond_input = cond_input.reshape(batch_size, 1, -1)
            cond_input = cond_input.expand(batch_size, self.num_chunks,
                                           self._cond_in_size)
        if uncond_input is not None:
            uncond_input = uncond_input.reshape(batch_size, 1, -1)
            uncond_input = uncond_input.expand(batch_size, self.num_chunks,
                                               self._uncond_in_size)
            # The chunk embeddings are considered unconditional inputs to the
            # underlying hypernetwork.
            uncond_input = torch.cat([uncond_input, chunk_embs], dim=2)
        else:
            uncond_input = chunk_embs

        # Now we build one big batch for the underlying hypernetwork, with
        # batch size: batch_size * num_chunks.
        if cond_input is not None:
            cond_input = cond_input.reshape(batch_size * self.num_chunks, -1)
        uncond_input = uncond_input.reshape(batch_size * self.num_chunks, -1)

        ### Weight of underlying hypernetwork ###
        # The chunk embeddings are stripped from the weight lists, since the
        # underlying hypernetwork receives them as part of its input.
        weights = dict()
        if cond_weights is not None and self._cond_chunk_embs:
            weights['cond_weights'] = cond_weights[:-self._num_cond_embs]
        elif cond_weights is not None:
            weights['cond_weights'] = cond_weights
        assert uncond_weights is not None
        if self._cond_chunk_embs:
            weights['uncond_weights'] = uncond_weights
        else:
            weights['uncond_weights'] = uncond_weights[:-1]

        ### Process chunks ###
        hnet_out = self._hnet.forward(uncond_input=uncond_input,
                                      cond_input=cond_input,
                                      cond_id=None,
                                      weights=weights,
                                      distilled_params=distilled_params,
                                      condition=condition,
                                      ret_format='flattened')
        assert np.all(
            np.equal(hnet_out.shape,
                     [batch_size * self.num_chunks, self._chunk_size]))

        # Concatenate individual chunks, i.e., go from shape
        # [batch_size * num_chunks, chunk_size] to
        # [batch_size, num_chunks * chunk_size].
        hnet_out = hnet_out.view(batch_size,
                                 self.num_chunks * self._chunk_size)

        # Throw away unused part of last chunk.
        hnet_out = hnet_out[:, :self.num_outputs]

        ### Assemble hypernet output ###
        ret = self._flat_to_ret_format(hnet_out, ret_format)

        return ret

    def distillation_targets(self):
        """Targets to be distilled after training.

        See docstring of abstract super method
        :meth:`mnets.mnet_interface.MainNetInterface.distillation_targets`.

        Returns:
            See :meth:`hnets.mlp_hnet.HMLP.distillation_targets`.
        """
        # We don't have any additional distillation targets. We also just pass
        # `distilled_params` to the underlying hypernetwork in the `forward`
        # method.
        # Note, the method of the underlying hypernetwork has to be called;
        # returning the bound method itself would hand a (truthy) callable to
        # the caller instead of the actual targets.
        return self._hnet.distillation_targets()

    def apply_chunked_hyperfan_init(self,
                                    method='in',
                                    use_xavier=False,
                                    uncond_var=1.,
                                    cond_var=1.,
                                    eps=1e-5,
                                    cemb_normal_init=False,
                                    mnet=None,
                                    target_vars=None):
        r"""Initialize the network using a chunked hyperfan init.

        Inspired by the method
        `Hyperfan Init <https://openreview.net/forum?id=H1lma24tPB>`__ which we
        implemented for the MLP hypernetwork in method
        :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`, we heuristically
        developed a better initialization method for chunked hypernetworks.

        Unfortunately, the `Hyperfan Init` method from the paper does not apply
        to this kind of hypernetwork, since we reuse the same hypernet output
        head for the whole main network.

        Luckily, we can provide a simple heuristic. Similar to
        `Meyerson & Miikkulainen <https://arxiv.org/abs/1906.00097>`__ we play
        with the variance of the input embeddings to affect the variance of the
        output weights.

        In a chunked hypernetwork, the input for each chunk is identical except
        for the chunk embeddings :math:`\mathbf{c}`. Let :math:`\mathbf{e}`
        denote the remaining inputs to the hypernetwork, which are identical
        for all chunks. Then, assuming the hypernetwork was initialized via
        fan-in init, the variance of the hypernetwork output :math:`\mathbf{v}`
        can be written as follows (see documentation of method
        :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`):

        .. math::

            \text{Var}(v) = \frac{n_e}{n_e+n_c} \text{Var}(e) + \
                \frac{n_c}{n_e+n_c} \text{Var}(c)

        Hence, we can achieve a desired output variance :math:`\text{Var}(v)`
        by initializing the chunk embeddings :math:`\mathbf{c}` via the
        following variance:

        .. math::

            \text{Var}(c) = \max \Big\{ 0, \
                \frac{1}{n_c} \big[ (n_e+n_c) \text{Var}(v) - \
                n_e \text{Var}(e) \big] \Big\}

        Now, one important question remains. How do we pick a desired output
        variance :math:`\text{Var}(v)` for a chunk?

        Note, a chunk may include weights from several layers. The likelihood
        for this to happen depends on the main net architecture and the chunk
        size (see constructor argument ``chunk_size``). The smaller the chunk
        size, the less likely it is that a chunk will contain elements from
        multiple main net weight tensors.

        In case each chunk would contain only weights from one main net weight
        tensor, we could simply pick the variance :math:`\text{Var}(v)` that
        would have been chosen by a main net initialization method (such as
        Xavier).

        In case a chunk contains contributions from several main net weight
        tensors, we apply the following heuristic. If a chunk contains
        contributions of a set of main network weight tensors
        :math:`W_1, \dots, W_K` with relative contribution sizes\
        :math:`n_1, \dots, n_K` such that :math:`n_1 + \dots + n_K = n_v` where
        :math:`n_v` denotes the chunk size and if the corresponding main network
        initialization method would require init variances
        :math:`\text{Var}(w_1), \dots, \text{Var}(w_K)`, then we simply request
        a weighted average as follow:

        .. math::

            \text{Var}(v) = \frac{1}{n_v} \sum_{k=1}^K n_k \text{Var}(w_k)

        What about bias vectors? Usually, the variance analysis applied to
        Xavier or Kaiming init assumes that biases are initialized to zero. This
        is not possible in this setting, as it would require assigning a
        negative variance to :math:`\mathbf{c}`. Instead, we follow the default
        PyTorch initialization (e.g., see method ``reset_parameters`` in class
        :class:`torch.nn.Linear`). There, bias vectors are initialized uniformly
        within a range of :math:`\pm \frac{1}{\sqrt{f_{\text{in}}}}` where
        :math:`f_{\text{in}}` refers to the fan-in of the layer. This type of
        initialization corresponds to a variance of
        :math:`\text{Var}(v) = \frac{1}{3 f_{\text{in}}}`.

        Note:
            All hypernet inputs are assumed to be zero-mean random variables.

        Note:
            To avoid that the variances with which chunks are initialized
            have to be clipped (because they are too small or even negative),
            the variance of the remaining hypernet inputs should be properly
            scaled. In general, one should adhere the following rule

            .. math::

                \text{Var}(e) < \frac{n_e+n_c}{n_e} \text{Var}(v)

            This method will calculate and print the maximum value that should
            be chosen for :math:`\text{Var}(e)` and will print warnings if
            variances have to be clipped.

        Args:
            (....): See arguments of method
                :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`.
            method (str): The type of initialization that should be applied.
                Possible options are:

                - ``in``: Use `Chunked Hyperfan-in`, i.e., rather the output
                  variances of the hypernetwork should correspond to fan-in
                  variances.
                - ``out``: Use `Chunked Hyperfan-out`, i.e., rather the output
                  variances of the hypernetwork should correspond to fan-out
                  variances.
                - ``harmonic``: Use the harmonic mean of the fan-in and fan-out
                  variance as target variance of the hypernetwork output.
            eps (float): The minimum variance with which a chunk embedding is
                initialized.
            cemb_normal_init (bool): Use normal init for chunk embeddings
                rather than uniform init.
            target_vars (list or dict, optional): The variance of the
                distribution for each parameter tensor generated by this
                hypernetwork. Target variance values can either be provided as
                list of length ``len(hnet.target_shapes)`` or as dictionary.
                The usage is analoguous to the usage of parameter ``w_val`` of
                method :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`.

                Note:
                    This method currently does not allow initial output
                    distributions with non-zero mean. However, the docstring of
                    method
                    :meth:`probabilistic.gauss_hnet_init.gauss_hyperfan_init`
                    describes how this is in principle feasible and might be
                    incorporated in the future.
                
                Note:
                    Unspecified target variances for parameter tensors of type
                    ``'weight'`` or ``'bias'`` are computed as described above.
                    Default target variances for all other parameter tensor
                    types are simply ``1``.
        """
        if method not in ['in', 'out', 'harmonic']:
            raise ValueError('Invalid value "%s" for argument "method".' %
                             method)
        if self.unconditional_params is None:
            assert self._no_uncond_weights
            raise ValueError('Hypernet without internal weights can\'t be ' +
                             'initialized.')
        if self.unconditional_params is None and self._cond_chunk_embs:
            assert self._no_cond_weights
            raise ValueError('Chunked hyperfan init cannot be applied if ' +
                             'chunk embeddings are not internally maintained.')

        ### Extract meta-information about target shapes ###
        # FIXME This section is copied from the HMLP implementation.
        meta = None
        if mnet is not None:
            assert isinstance(mnet, MainNetInterface)

            try:
                meta = mnet.param_shapes_meta
            except:
                meta = None

            if meta is not None:
                if len(self.target_shapes) == len(mnet.param_shapes):
                    pass
                    # meta = mnet.param_shapes_meta
                elif len(self.target_shapes) == len(mnet.hyper_shapes_learned):
                    meta = []
                    for ii in mnet.hyper_shapes_learned_ref:
                        meta.append(mnet.param_shapes_meta[ii])
                else:
                    warn('Target shapes of this hypernetwork could not be ' +
                         'matched to the meta information provided to the ' +
                         'initialization.')
                    meta = None

        # TODO If the user doesn't (or can't) provide an `mnet` instance, we
        # should alternatively allow him to pass meta information directly.
        if meta is None:
            meta = []

            # Heuristical approach to derive meta information from given shapes.
            layer_ind = 0
            for i, s in enumerate(self.target_shapes):
                curr_meta = dict()

                if len(s) > 1:
                    curr_meta['name'] = 'weight'
                    curr_meta['layer'] = layer_ind
                    layer_ind += 1
                else:  # just a heuristic, we can't know
                    curr_meta['name'] = 'bias'
                    if i > 0 and meta[-1]['name'] == 'weight':
                        curr_meta['layer'] = meta[-1]['layer']
                    else:
                        curr_meta['layer'] = -1

                meta.append(curr_meta)

        assert len(meta) == len(self.target_shapes)

        # Mapping from layer index to the corresponding shape.
        layer_shapes = dict()
        # Mapping from layer index to whether the layer has a bias vector.
        layer_has_bias = defaultdict(lambda: False)
        for i, m in enumerate(meta):
            if m['name'] == 'weight' and m['layer'] != -1:
                assert len(self.target_shapes[i]) > 1
                layer_shapes[m['layer']] = self.target_shapes[i]
            if m['name'] == 'bias' and m['layer'] != -1:
                layer_has_bias[m['layer']] = True

        ### Compute input variance ###
        # The input variance does not include the variance of chunk embeddings!
        # Instead, it is the variance of the inputs that are shared across all
        # chunks.
        cond_dim = self._cond_in_size
        uncond_dim = self._uncond_in_size
        # Note, `inp_dim` can be zero if conditional chunk embeddings are used.
        inp_dim = cond_dim + uncond_dim

        inp_var = 0
        if cond_dim > 0:
            inp_var += (cond_dim / inp_dim) * cond_var
        if uncond_dim > 0:
            inp_var += (uncond_dim / inp_dim) * uncond_var

        c_dim = self.chunk_emb_size

        ### Initialize hypernet with fan-in init ###
        if self.batchnorm_layers is not None and len(
                self.batchnorm_layers) > 0:
            # Note, batchnorm layers simply whiten the incoming statistics.
            # Thus, if we tune the variance of chunk embeddings, this variance
            # is normalized by a batchnorm layer and thus vanishes.
            raise RuntimeError('Chunked hyperfan init not applicable if a ' +
                               'hypernetwork with batchnorm layers is used.')

        # Note, the whole internal hypernetwork is initialized with fan-in init
        # to simply pass the variance of all inputs to the hypernet output.
        for i, w_tensor in enumerate(self.layer_weight_tensors):
            if use_xavier:
                iutils.xavier_fan_in_(w_tensor)
            else:
                torch.nn.init.kaiming_uniform_(w_tensor,
                                               mode='fan_in',
                                               nonlinearity='relu')

            if self.has_bias:
                nn.init.zeros_(self.layer_bias_vectors[i])

        ### Compute target variance of each output tensor ###
        if target_vars is None:
            target_vars = [None] * len(self.target_shapes)
        elif isinstance(target_vars, dict):
            target_vars_d = target_vars
            target_vars = []

            for i, m in enumerate(meta):
                if m['name'] in target_vars_d.keys():
                    target_vars.append(target_vars_d[m['name']])
                else:
                    target_vars.append(None)
        else:
            assert isinstance(target_vars, (list, tuple))
            assert len(target_vars) == len(self.target_shapes)

        for i, s in enumerate(self.target_shapes):
            if target_vars[i] is not None:
                # Use user specified target variance.
                continue

            m = meta[i]

            if m['name'] == 'bias':
                if m['layer'] != -1:
                    fan_in, _ = iutils.calc_fan_in_and_out( \
                        layer_shapes[m['layer']])
                else:
                    # FIXME Quick-fix, use fan-out instead.
                    fan_in = s[0]

                target_vars[i] = 1. / (3. * fan_in)

            elif m['name'] == 'weight':
                fan_in, fan_out = iutils.calc_fan_in_and_out(s)

                c_relu = 1 if use_xavier else 2

                var_in = c_relu / fan_in
                var_out = c_relu / fan_out

                if method == 'in':
                    var = var_in
                elif method == 'out':
                    var = var_out
                else:
                    var = 2 * (1. / var_in + 1. / var_out)

                target_vars[i] = var

            else:
                target_vars[i] = 1.

        ### Target variance per chunk ###
        chunk_vars = []
        i = 0
        n = np.prod(self.target_shapes[i])

        for j in range(self.num_chunks):
            m = self._chunk_size
            var = 0

            while m > 0:
                # Special treatment to fill up last chunk.
                if j == self.num_chunks - 1 and i == len(target_vars) - 1:
                    assert n <= m
                    o = m
                else:
                    o = min(m, n)

                var += o / self._chunk_size * target_vars[i]
                m -= o
                n -= o

                if n == 0:
                    i += 1
                    if i < len(target_vars):
                        n = np.prod(self.target_shapes[i])

            chunk_vars.append(var)

        if inp_dim > 0:
            max_inp_var = (inp_dim + c_dim) / inp_dim * min(chunk_vars)
            max_inp_std = math.sqrt(max_inp_var)
            print('Initializing hypernet with Chunked Hyperfan Init ...')
            if inp_var >= max_inp_var:
                warn('Note, hypernetwork inputs should have an initial total ' +
                     'variance (std) smaller than %f (%f) in order for this ' \
                     % (max_inp_var, max_inp_std) + 'method to work properly.')

        ### Compute variances of chunk embeddings ###
        # We could have done that in the previous loop. But I think the code is
        # more readable this way.
        c_vars = []
        n_clipped = 0
        for i, var in enumerate(chunk_vars):
            c_var = 1. / c_dim * ((inp_dim + c_dim) * var - inp_dim * inp_var)
            if c_var < eps:
                n_clipped += 1
                #warn('Initial variance of chunk embedding %d has to ' % i + \
                #     'be clipped.')

            c_vars.append(max(eps, c_var))

        if n_clipped > 0:
            warn('Initial variance of %d/%d ' % (n_clipped, len(chunk_vars)) + \
                 'chunk embeddings had to be clipped.')

        ### Initialize chunk embeddings ###
        for i in range(self.num_chunks):
            c_std = math.sqrt(c_vars[i])

            num_conds = self.num_known_conds if self._cond_chunk_embs else 1
            for j in range(num_conds):
                cond_id = j if self._cond_chunk_embs else None

                c_emb = self.get_chunk_emb(chunk_id=i, cond_id=cond_id)

                if cemb_normal_init:
                    torch.nn.init.normal_(c_emb, mean=0, std=c_std)
                else:
                    a = math.sqrt(3.0) * c_std
                    torch.nn.init._no_grad_uniform_(c_emb, -a, a)

    def get_cond_in_emb(self, cond_id):
        """Return the (conditional) input embedding belonging to ``cond_id``.

        The request is forwarded unchanged to the internally maintained
        hypernetwork ``self._hnet``.

        Args:
            (....): See docstring of method
                :meth:`hnets.mlp_hnet.HMLP.get_cond_in_emb`.

        Returns:
            (torch.nn.Parameter)
        """
        inner_hnet = self._hnet
        return inner_hnet.get_cond_in_emb(cond_id)

    def get_chunk_emb(self, chunk_id=None, cond_id=None):
        """Look up chunk embeddings that are maintained by this object.

        Args:
            chunk_id (int, optional): A number between 0 and :attr:`num_chunks`
                - 1. If not specified, a full chunk matrix with shape
                ``[num_chunks, chunk_emb_size]`` is returned. Otherwise,
                the ``chunk_id``-th row is returned.
            cond_id (int): Is mandatory if constructor argument
                ``cond_chunk_embs`` was set. Determines the set of chunk
                embeddings to be considered.

        Returns:
            (torch.nn.Parameter)

        Raises:
            RuntimeError: If ``cond_id`` is missing or out of range although
                chunk embeddings are conditional, if the required parameters
                are not internally maintained, or if ``chunk_id`` is not a
                valid chunk index.
        """
        if not self._cond_chunk_embs:
            # Unconditional case: the chunk embeddings are the very last
            # unconditional parameter tensor.
            assert cond_id is None
            if self.unconditional_params is None:
                msg = 'Chunk embeddings are not internally ' + \
                    'maintained!'
                raise RuntimeError(msg)
            emb_matrix = self.unconditional_params[-1]
        else:
            if cond_id is None:
                msg = 'Option "cond_id" has to be set if chunk ' + \
                    'embeddings are conditional parameters!'
                raise RuntimeError(msg)
            if self.conditional_params is None:
                msg = 'Conditional chunk embeddings are not ' + \
                    'internally maintained!'
                raise RuntimeError(msg)
            cond_ok = isinstance(cond_id, int) and \
                0 <= cond_id < self._num_cond_embs
            if not cond_ok:
                msg = 'Option "cond_id" must be between 0 and ' + \
                    '%d!' % (self._num_cond_embs - 1)
                raise RuntimeError(msg)
            # Note, the last `self._num_cond_embs` params are chunk embeddings.
            emb_matrix = self.conditional_params[cond_id - self._num_cond_embs]

        if chunk_id is None:
            return emb_matrix

        chunk_ok = isinstance(chunk_id, int) and \
            0 <= chunk_id < self.num_chunks
        if not chunk_ok:
            msg = 'Option "chunk_id" must be between 0 and ' + \
                '%d!' % (self.num_chunks - 1)
            raise RuntimeError(msg)

        return emb_matrix[chunk_id, :]
Esempio n. 5
0
class ChunkedHMLP(nn.Module, HyperNetInterface):
    """Implementation of a `chunked fully-connected hypernet`.

    The ``target_shapes`` will be flattened and split into chunks of size
    ``chunk_size``. In total, there will be
    ``np.ceil(self.num_outputs/chunk_size)`` chunks, where the last chunk
    produced might contain a remainder that is discarded.

    Each chunk has its own `chunk embedding` that is fed into the underlying
    hypernetwork.

    Note:
        It is possible to set ``uncond_in_size`` and ``cond_in_size`` to zero
        if ``cond_chunk_embs`` is ``True``.

    Attributes:
        num_chunks (int): The number of chunks that make up the final hypernet
            output. This also corresponds to the number of chunk embeddings
            required per forward sweep.
        chunk_emb_size (int): See constructor argument ``chunk_emb_size``.

    Args:
        (....): See constructor arguments of class
            :class:`hnets.mlp_hnet.HMLP`.
        chunk_size (int): The chunk size, i.e., the number of weights produced
            by the single internally maintained instance of a full hypernet
            (see :class:`hnets.mlp_hnet.HMLP`) at a time (i.e., per chunk
            embedding).
        chunk_emb_size (int): The size of a chunk embedding.

            Note:
                Embeddings will be initialized with a normal distribution using
                zero mean and unit variance.
        cond_chunk_embs (bool): Consider chunk embeddings to be conditional.
            In this case, there will be a different set of chunk embeddings per
            condition (specified via ``num_cond_embs``).

            If ``False``, there will be a total of :attr:`num_chunks` chunk
            embeddings that are maintained within :attr:`hnets.hnet_interface.\
HyperNetInterface.unconditional_param_shapes`. If ``True``, there will be
            ``num_cond_embs * self.num_chunks`` chunk embeddings that are
            maintained within :attr:`hnets.hnet_interface.\
HyperNetInterface.conditional_param_shapes`. However, if ``num_cond_embs == 0``,
            then chunk embeddings have to be provided in a special way to the
            :meth:`forward` method (see the corresponding argument ``weights``).
    """
    def __init__(self, target_shapes, chunk_size, chunk_emb_size=8,
                 cond_chunk_embs=False, uncond_in_size=0, cond_in_size=8,
                 layers=(100, 100), verbose=True, activation_fn=torch.nn.ReLU(),
                 use_bias=True, no_uncond_weights=False, no_cond_weights=False,
                 num_cond_embs=1, dropout_rate=-1, use_spectral_norm=False,
                 use_batch_norm=False):
        """Create the underlying full hypernet and the chunk embeddings.

        All arguments are described in the class docstring.

        Raises:
            ValueError: If ``cond_in_size`` and ``uncond_in_size`` are both
                zero while ``cond_chunk_embs`` is ``False``.
        """
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        HyperNetInterface.__init__(self)

        assert isinstance(chunk_size, int) and chunk_size > 0
        assert isinstance(chunk_emb_size, int) and chunk_emb_size > 0

        ### Make constructor arguments internally available ###
        self._chunk_size = chunk_size
        self._chunk_emb_size = chunk_emb_size
        self._cond_chunk_embs = cond_chunk_embs
        self._uncond_in_size = uncond_in_size
        self._cond_in_size = cond_in_size
        self._no_uncond_weights = no_uncond_weights
        self._no_cond_weights = no_cond_weights
        self._num_cond_embs = num_cond_embs

        ### Create underlying full hypernet ###
        # Note, even if chunk embeddings are considered conditional, they
        # are maintained in this object and just fed as an external input to the
        # underlying hnet.
        hnet_uncond_in_size = uncond_in_size + chunk_emb_size
        hnet_num_cond_embs = num_cond_embs
        if cond_chunk_embs and cond_in_size == 0:
            # If there are no other conditional embeddings except the chunk
            # embeddings, we tell the underlying hnet explicitly that it doesn't
            # need to maintain any conditional weights to avoid that it will
            # throw a warning.
            hnet_num_cond_embs = 0
        self._hnet = HMLP([[chunk_size]], uncond_in_size=hnet_uncond_in_size,
            cond_in_size=cond_in_size, layers=layers, verbose=False,
            activation_fn=activation_fn, use_bias=use_bias,
            no_uncond_weights=no_uncond_weights,
            no_cond_weights=no_cond_weights, num_cond_embs=hnet_num_cond_embs,
            dropout_rate=dropout_rate, use_spectral_norm=use_spectral_norm,
            use_batch_norm=use_batch_norm)

        ### Setup attributes required by interface ###
        # Most of these attributes are taken over from `self._hnet`
        self._target_shapes = target_shapes
        self._num_known_conds = self._num_cond_embs
        self._unconditional_param_shapes_ref = \
            list(self._hnet._unconditional_param_shapes_ref)

        if self._hnet._internal_params is not None:
            self._internal_params = \
                nn.ParameterList(self._hnet._internal_params)
        self._param_shapes = list(self._hnet._param_shapes)
        self._param_shapes_meta = list(self._hnet._param_shapes_meta)
        if self._hnet._hyper_shapes_learned is not None:
            self._hyper_shapes_learned = list(self._hnet._hyper_shapes_learned)
            self._hyper_shapes_learned_ref = \
                list(self._hnet._hyper_shapes_learned_ref)
        if self._hnet._hyper_shapes_distilled is not None:
            self._hyper_shapes_distilled = \
                list(self._hnet._hyper_shapes_distilled)
        self._has_bias = self._hnet._has_bias
        self._has_fc_out = self._hnet._has_fc_out
        # Just to make that clear explicitly. We will additionally append
        # the chunk embeddings at the end of `param_shapes`.
        # We don't prepend it to the beginning, to keep conditional input
        # embeddings at the beginning.
        self._mask_fc_out = False
        self._has_linear_out = self._hnet._has_linear_out
        self._layer_weight_tensors = \
            nn.ParameterList(self._hnet._layer_weight_tensors)
        self._layer_bias_vectors = \
            nn.ParameterList(self._hnet._layer_bias_vectors)
        if self._hnet._batchnorm_layers is not None:
            self._batchnorm_layers = nn.ModuleList(self._hnet._batchnorm_layers)
        if self._hnet._context_mod_layers is not None:
            self._context_mod_layers = \
                nn.ModuleList(self._hnet._context_mod_layers)

        ### Create chunk embeddings ###
        if cond_in_size == 0 and uncond_in_size == 0 and not cond_chunk_embs:
            # Note, we could also allow this case. It would be analogous to
            # creating a full hypernet with no unconditional input and one
            # conditional embedding. But the user can explicitly achieve that
            # as noted below.
            raise ValueError('If no external (conditional or unconditional) ' +
                             'input is provided to the hypernetwork, then ' +
                             'it can only learn a fixed output. If this ' +
                             'behavior is desired, please enable ' +
                             '"cond_chunk_embs" and set "num_cond_embs=1".')

        # Unconditional chunk embeddings form a single matrix; conditional ones
        # need one matrix per condition. Whether they are maintained internally
        # follows the corresponding `no_*_weights` option.
        num_cemb_mats = 1
        no_cemb_weights = no_uncond_weights
        if cond_chunk_embs:
            num_cemb_mats = num_cond_embs
            no_cemb_weights = no_cond_weights

        self._cemb_shape = [self.num_chunks, chunk_emb_size]

        for _ in range(num_cemb_mats):
            if not no_cemb_weights:
                # The parameter list above is only created if the underlying
                # hnet maintains parameters itself. If it doesn't (e.g., all
                # its weights are provided externally), the chunk embeddings
                # are the first internally maintained parameters.
                if self._internal_params is None:
                    self._internal_params = nn.ParameterList()
                self._internal_params.append(nn.Parameter( \
                    data=torch.Tensor(*self._cemb_shape), requires_grad=True))
                torch.nn.init.normal_(self._internal_params[-1], mean=0.,
                                      std=1.)
            else:
                # Analogously, the bookkeeping of externally provided weights
                # is only taken over above if the underlying hnet expects any.
                if self._hyper_shapes_learned is None:
                    self._hyper_shapes_learned = []
                    self._hyper_shapes_learned_ref = []
                self._hyper_shapes_learned.append(self._cemb_shape)
                self._hyper_shapes_learned_ref.append(len(self.param_shapes))

            if not cond_chunk_embs:
                self._unconditional_param_shapes_ref.append( \
                    len(self.param_shapes))

            self._param_shapes.append(self._cemb_shape)
            # In principle, these embeddings also belong to the input, so we
            # just assign them as "layer" 0 (note, the underlying hnet uses the
            # same layer ID for its embeddings).
            self._param_shapes_meta.append({
                'name': 'embedding',
                'index': -1 if no_cemb_weights else \
                    len(self._internal_params)-1,
                'layer': 0,
                'info': 'chunk embeddings'
            })

        ### Finalize construction ###
        self._is_properly_setup()

        if verbose:
            print('Created Chunked MLP Hypernet with %d chunk(s) of size %d.' \
                  % (self.num_chunks, chunk_size))
            print(self)

    @property
    def num_chunks(self):
        """Read-only attribute :attr:`num_chunks`.

        Returns:
            (int): ``self.num_outputs`` divided by the chunk size, rounded up.
        """
        chunks_needed = np.ceil(self.num_outputs / self._chunk_size)
        return int(chunks_needed)

    @property
    def chunk_emb_size(self):
        """Getter for read-only attribute :attr:`chunk_emb_size`.

        Returns:
            (int): The size of a single chunk embedding, i.e., the value that
            was passed as constructor argument ``chunk_emb_size``.
        """
        return self._chunk_emb_size

    def forward(self, uncond_input=None, cond_input=None, cond_id=None,
                weights=None, distilled_params=None, condition=None,
                ret_format='squeezed', ext_inputs=None, task_emb=None,
                task_id=None, theta=None, dTheta=None):
        """Compute the weights of a target network.

        The hypernet inputs are replicated for every chunk and concatenated
        with the corresponding chunk embeddings. The underlying hypernet is
        then evaluated once on a batch of size ``batch_size * num_chunks`` and
        its outputs are concatenated per sample.

        Args:
            (....): See docstring of method
                :meth:`hnets.mlp_hnet.HMLP.forward`.
            weights (list or dict, optional): If provided as ``dict`` and
                chunk embeddings are considered conditional (see constructor
                argument ``cond_chunk_embs``), then the additional key
                ``chunk_embs`` can be used to pass a batch of chunk embeddings.
                This option is mutually exclusive with the option of passing
                ``cond_id``. Note, if conditional inputs via ``cond_input`` are
                expected, then the batch sizes must agree.

                A batch of chunk embeddings is expected to be tensor of shape
                ``[B, num_chunks, chunk_emb_size]``, where ``B`` denotes the
                batch size.

                The passed dictionary is not modified by this method.

        Returns:
            (list or torch.Tensor): See docstring of method
            :meth:`hnets.hnet_interface.HyperNetInterface.forward`.

        Raises:
            ValueError: If key ``chunk_embs`` is used although chunk
                embeddings are unconditional or together with ``cond_id``, if
                ``cond_input`` is given without chunk embeddings although
                those are conditional, or if ``cond_id`` is used without
                available conditional weights or refers to an unknown
                condition.
        """
        cond_chunk_embs = None
        if isinstance(weights, dict) and 'chunk_embs' in weights:
            cond_chunk_embs = weights['chunk_embs']
            if not self._cond_chunk_embs:
                raise ValueError('Key "chunk_embs" for argument ' +
                                 '"weights" is only allowed if chunk ' +
                                 'embeddings are conditional.')
            assert len(cond_chunk_embs.shape) == 3 and \
                np.all(np.equal(cond_chunk_embs.shape[1:],
                                [self.num_chunks, self.chunk_emb_size]))

            if cond_id is not None:
                raise ValueError('Option "cond_id" is mutually exclusive ' +
                                 'with key "chunk_embs" for argument ' +
                                 '"weights".')
            assert cond_input is None or \
                cond_input.shape[0] == cond_chunk_embs.shape[0]

            # Remove `chunk_embs` from the dictionary, since the upper class
            # parser doesn't know how to deal with it. We filter into a new
            # dictionary rather than deleting the key in-place, such that the
            # caller's dictionary stays intact and can be reused.
            weights = {k: v for k, v in weights.items() if k != 'chunk_embs'}
            if len(weights) == 0: # Empty dictionary.
                weights = None

        if cond_input is not None and self._cond_chunk_embs and \
                cond_chunk_embs is None:
            raise ValueError('Conditional chunk embeddings have to be ' +
                             'provided via "weights" if "cond_input" is ' +
                             'specified.')

        _input_required = self._cond_in_size > 0 or self._uncond_in_size > 0
        # We parse `cond_id` afterwards if chunk embeddings are also
        # conditional.
        if self._cond_chunk_embs:
            _parse_cond_id_fct = lambda x, y, z: None
        else:
            _parse_cond_id_fct = None

        uncond_input, cond_input, uncond_weights, cond_weights = \
            self._preprocess_forward_args(_input_required=_input_required,
                _parse_cond_id_fct=_parse_cond_id_fct,
                uncond_input=uncond_input, cond_input=cond_input,
                cond_id=cond_id, weights=weights,
                distilled_params=distilled_params, condition=condition,
                ret_format=ret_format, ext_inputs=ext_inputs, task_emb=task_emb,
                task_id=task_id, theta=theta, dTheta=dTheta)

        ### Translate IDs to conditional inputs ###
        if cond_id is not None and self._cond_chunk_embs:
            assert cond_input is None and cond_chunk_embs is None
            cond_id = [cond_id] if isinstance(cond_id, int) else cond_id

            if cond_weights is None:
                raise ValueError('Forward option "cond_id" can only be ' +
                                 'used if conditional parameters are ' +
                                 'maintained internally or passed to the ' +
                                 'forward method via option "weights".')

            cond_chunk_embs = []
            cond_input = [] if self._cond_in_size > 0 else None

            for cid in cond_id:
                if cid < 0 or cid >= self._num_cond_embs:
                    raise ValueError('Condition %d not existing!' % (cid))

                # Note, we do not necessarily have conditional embeddings.
                if self._cond_in_size > 0:
                    cond_input.append(cond_weights[cid])

                # The last `self._num_cond_embs` conditional weights are the
                # chunk embedding matrices.
                cond_chunk_embs.append( \
                    cond_weights[-self._num_cond_embs+cid])

            if self._cond_in_size > 0:
                cond_input = torch.stack(cond_input, dim=0)
            cond_chunk_embs = torch.stack(cond_chunk_embs, dim=0)

        ### Assemble hypernetwork input ###
        # Infer the batch size from whatever inputs are available and make
        # sure that all of them agree.
        batch_size = None
        if cond_input is not None:
            batch_size = cond_input.shape[0]
        if cond_chunk_embs is not None:
            assert batch_size is None or batch_size == cond_chunk_embs.shape[0]
            batch_size = cond_chunk_embs.shape[0]
        if uncond_input is not None:
            if batch_size is None:
                batch_size = uncond_input.shape[0]
            else:
                assert batch_size == uncond_input.shape[0]
        assert batch_size is not None

        chunk_embs = None
        if self._cond_chunk_embs:
            assert cond_chunk_embs is not None and \
                len(cond_chunk_embs.shape) == 3
            assert self._cond_in_size == 0 or cond_input is not None
            chunk_embs = cond_chunk_embs
        else:
            assert cond_chunk_embs is None
            # Unconditional chunk embeddings are the last unconditional weight.
            chunk_embs = uncond_weights[-1]
            # Insert batch dimension.
            chunk_embs = chunk_embs.expand(batch_size, self.num_chunks,
                                           self.chunk_emb_size)

        # We now have the following setup:
        # cond_input: [batch_size, cond_in_size] or None
        # uncond_input: [batch_size, uncond_in_size] or None
        # chunk_embs: [batch_size, num_chunks, chunk_emb_size]

        # We now first copy the hypernet inputs for each chunk, arriving at
        # cond_input: [batch_size, num_chunks, cond_in_size] or None
        # uncond_input: [batch_size, num_chunks, uncond_in_size] or None
        if cond_input is not None:
            cond_input = cond_input.reshape(batch_size, 1, -1)
            cond_input = cond_input.expand(batch_size, self.num_chunks,
                                           self._cond_in_size)
        if uncond_input is not None:
            uncond_input = uncond_input.reshape(batch_size, 1, -1)
            uncond_input = uncond_input.expand(batch_size, self.num_chunks,
                                               self._uncond_in_size)
            # The chunk embeddings are considered unconditional inputs to the
            # underlying hypernetwork.
            uncond_input = torch.cat([uncond_input, chunk_embs], dim=2)
        else:
            uncond_input = chunk_embs

        # Now we build one big batch for the underlying hypernetwork, with
        # batch size: batch_size * num_chunks.
        if cond_input is not None:
            cond_input = cond_input.reshape(batch_size * self.num_chunks, -1)
        uncond_input = uncond_input.reshape(batch_size * self.num_chunks, -1)

        ### Weight of underlying hypernetwork ###
        # The chunk embeddings are maintained by this object, so they are
        # stripped before the weights are handed to the underlying hnet.
        weights = dict()
        if cond_weights is not None and self._cond_chunk_embs:
            weights['cond_weights'] = cond_weights[:-self._num_cond_embs]
        elif cond_weights is not None:
            weights['cond_weights'] = cond_weights
        assert uncond_weights is not None
        if self._cond_chunk_embs:
            weights['uncond_weights'] = uncond_weights
        else:
            weights['uncond_weights'] = uncond_weights[:-1]

        ### Process chunks ###
        hnet_out = self._hnet.forward(uncond_input=uncond_input,
                cond_input=cond_input, cond_id=None, weights=weights,
                distilled_params=distilled_params, condition=condition,
                ret_format='flattened')
        assert np.all(np.equal(hnet_out.shape, [batch_size * self.num_chunks,
                                                self._chunk_size]))

        # Concatenate individual chunks. The rows of `hnet_out` are ordered
        # sample-major (all chunks of the first sample first), so the chunks
        # of a sample can be merged by a single view.
        hnet_out = hnet_out.view(batch_size, self.num_chunks * self._chunk_size)

        ### Assemble hypernet output ###
        ret = self._flat_to_ret_format(hnet_out, ret_format)

        return ret

    def distillation_targets(self):
        """Targets to be distilled after training.

        See docstring of abstract super method
        :meth:`mnets.mnet_interface.MainNetInterface.distillation_targets`.

        Returns:
            See :meth:`hnets.mlp_hnet.HMLP.distillation_targets`.
        """
        # We don't have any additional distillation targets. We also just pass
        # `distilled_params` to the underlying hypernetwork in the `forward`
        # method.
        # Note, `distillation_targets` is a method of the underlying hnet, so
        # it has to be called. Returning the attribute itself would hand a
        # bound method (rather than the targets) to the caller.
        return self._hnet.distillation_targets()

    def apply_chunked_hyperfan_init(self, method='in', use_xavier=False,
                                    uncond_var=1., cond_var=1., eps=1e-5,
                                    cemb_normal_init=False):
        """Not implemented yet!

        This method is a placeholder. All arguments are accepted but ignored,
        and every call fails.

        Raises:
            NotImplementedError: Always.
        """
        # TODO Translate from old hypernet implementation and take meta
        # information of generated parameters into account.
        raise NotImplementedError()

    def get_cond_in_emb(self, cond_id):
        """Return the (conditional) input embedding belonging to ``cond_id``.

        The request is forwarded unchanged to the underlying hypernetwork
        ``self._hnet``, which maintains the conditional input embeddings.

        Args:
            (....): See docstring of method
                :meth:`hnets.mlp_hnet.HMLP.get_cond_in_emb`.

        Returns:
            (torch.nn.Parameter)
        """
        underlying_hnet = self._hnet
        return underlying_hnet.get_cond_in_emb(cond_id)

    def get_chunk_emb(self, chunk_id=None, cond_id=None):
        """Look up chunk embeddings that are maintained by this object.

        Args:
            chunk_id (int, optional): A number between 0 and :attr:`num_chunks`
                - 1. If not specified, a full chunk matrix with shape
                ``[num_chunks, chunk_emb_size]`` is returned. Otherwise,
                the ``chunk_id``-th row is returned.
            cond_id (int): Is mandatory if constructor argument
                ``cond_chunk_embs`` was set. Determines the set of chunk
                embeddings to be considered.

        Returns:
            (torch.nn.Parameter)

        Raises:
            RuntimeError: If ``cond_id`` is missing or out of range although
                chunk embeddings are conditional, if the required parameters
                are not internally maintained, or if ``chunk_id`` is not a
                valid chunk index.
        """
        if not self._cond_chunk_embs:
            # Unconditional case: the chunk embeddings are the very last
            # unconditional parameter tensor.
            assert cond_id is None
            if self.unconditional_params is None:
                err_msg = 'Chunk embeddings are not internally ' + \
                    'maintained!'
                raise RuntimeError(err_msg)
            emb_matrix = self.unconditional_params[-1]
        else:
            if cond_id is None:
                err_msg = 'Option "cond_id" has to be set if chunk ' + \
                    'embeddings are conditional parameters!'
                raise RuntimeError(err_msg)
            if self.conditional_params is None:
                err_msg = 'Conditional chunk embeddings are not ' + \
                    'internally maintained!'
                raise RuntimeError(err_msg)
            cond_is_valid = isinstance(cond_id, int) and \
                0 <= cond_id < self._num_cond_embs
            if not cond_is_valid:
                err_msg = 'Option "cond_id" must be between 0 and ' + \
                    '%d!' % (self._num_cond_embs-1)
                raise RuntimeError(err_msg)
            # Note, the last `self._num_cond_embs` params are chunk embeddings.
            emb_matrix = self.conditional_params[cond_id - self._num_cond_embs]

        if chunk_id is None:
            return emb_matrix

        chunk_is_valid = isinstance(chunk_id, int) and \
            0 <= chunk_id < self.num_chunks
        if not chunk_is_valid:
            err_msg = 'Option "chunk_id" must be between 0 and ' + \
                '%d!' % (self.num_chunks-1)
            raise RuntimeError(err_msg)

        return emb_matrix[chunk_id, :]