Beispiel #1
0
def generate_classifier(config, data_handlers, device):
    """Build the classifier MLP and, if requested, its hypernetwork.

    Depending on the configuration, the created network is either a task
    inference classifier or a classifier that solves the task itself. If
    ``config.training_with_hnet`` is set, the classifier is built without
    weights of its own and a hypernetwork producing them is created.

    The following attributes of ``config`` are written by this function:
    ``input_dim``, ``out_dim``, ``num_weights_class_net``,
    ``num_weights_class_hyper_net`` and ``compression_ratio_class``.

    See :class:`mnets.mlp.MLP` for details on the network that is created.

    .. note::
        This function also initialises the weights of either the classifier
        or its hypernetwork (Xavier-uniform for non-1D parameters, zero for
        1D parameters, normal for task and chunk embeddings).

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers. Only ``in_shape[0]`` of the
            first handler is read here to derive the input dimension.
        device: Torch device.

    Returns:
        The classifier network ``net`` if ``config.training_with_hnet`` is
        false. Otherwise the tuple ``(net, class_hnet)`` consisting of the
        classifier and its hypernetwork.
    """
    side_len = data_handlers[0].in_shape[0]
    pad = config.padding * 2
    is_split_mnist = config.experiment == "splitMNIST"

    # Flattened input size; padding is only applied for permutedMNIST.
    if is_split_mnist:
        n_in = side_len * side_len
    else:
        n_in = (side_len + pad) * (side_len + pad)
    config.input_dim = n_in

    # Number of output neurons per task.
    if not is_split_mnist:  # permutedMNIST
        config.out_dim = 10
    elif config.class_incremental:
        config.out_dim = 1
    else:
        config.out_dim = 2

    task_inference = config.training_task_infer or config.class_incremental
    if task_inference:
        # Task inference network.
        config.out_dim = 1

    # Unless we are in CL scenario 2, the output neurons of all tasks are
    # created right away.
    n_out = config.out_dim * config.num_tasks if config.cl_scenario != 2 \
        else config.out_dim
    if task_inference:
        n_out = config.num_tasks

    # Build the classifier.
    print('For the Classifier: ')
    hidden_sizes = misc.str_to_ints(config.class_fc_arch)
    # A hypernetwork provides the weights, so the MLP must not own any.
    no_weights = bool(config.training_with_hnet)

    net = MLP(n_in=n_in,
              n_out=n_out,
              hidden_layers=hidden_sizes,
              activation_fn=misc.str_to_act(config.class_net_act),
              dropout_rate=config.class_dropout_rate,
              no_weights=no_weights).to(device)

    print('Constructed MLP with shapes: ', net.param_shapes)

    config.num_weights_class_net = MainNetInterface.shapes_to_num_weights(
        net.param_shapes)

    # Build the classifier's hypernetwork (if requested) and collect the
    # parameters that have to be initialised below.
    if config.training_with_hnet:
        class_hnet = sim_utils.get_hnet_model(config,
                                              config.num_tasks,
                                              device,
                                              net.param_shapes,
                                              cprefix='class_')
        init_params = list(class_hnet.parameters())

        num_trainable = 0
        for hparam in class_hnet.parameters():
            if hparam.requires_grad:
                num_trainable += hparam.numel()
        config.num_weights_class_hyper_net = num_trainable

        config.compression_ratio_class = \
            config.num_weights_class_hyper_net / config.num_weights_class_net
        print('Created classifier Hypernetwork with ratio: ',
              config.compression_ratio_class)
        if config.compression_ratio_class > 1:
            print('Note that the compression ratio is computed compared to '
                  'current target network, not might not be directly '
                  'comparable with the number of parameters of work we '
                  'compare against.')
    else:
        class_hnet = None
        init_params = list(net.parameters())
        config.num_weights_class_hyper_net = None
        config.compression_ratio_class = None

    # Initialise network weights: 1D parameters are treated as bias vectors.
    for param in init_params:
        if param.ndimension() != 1:
            torch.nn.init.xavier_uniform_(param)
        else:
            torch.nn.init.constant_(param, 0)

    # Task embeddings are initialised differently.
    if config.training_with_hnet:
        for task_emb in class_hnet.get_task_embs():
            torch.nn.init.normal_(task_emb, mean=0.,
                                  std=config.std_normal_temb)

    # ``hasattr`` is False for ``None``, i.e., when no hypernetwork exists.
    if hasattr(class_hnet, 'chunk_embeddings'):
        for chunk_emb in class_hnet.chunk_embeddings:
            torch.nn.init.normal_(chunk_emb, mean=0,
                                  std=config.std_normal_emb)

    if config.training_with_hnet:
        return net, class_hnet
    return net
Beispiel #2
0
def get_mnet_model(config,
                   net_type,
                   in_shape,
                   out_shape,
                   device,
                   cprefix=None,
                   no_weights=False,
                   **mnet_kwargs):
    """Generate a main network instance.

    A helper to generate a main network according to the given user
    configurations. Optional settings that are missing from ``config`` (or
    that are ``None``) fall back to the default value of the corresponding
    constructor argument of the selected network class.

    .. note::
        Generation of networks with context-modulation is not yet supported,
        since there is no global argument set in :mod:`utils.cli_args` yet.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available according
                to the function :func:`utils.cli_args.main_net_args`.
        net_type (str): The type of network. The following options are
            available:

            - ``mlp``: :class:`mnets.mlp.MLP`
            - ``lenet``: class ``LeNet``
            - ``resnet``: :class:`mnets.resnet.ResNet`
            - ``wrn``: :class:`mnets.wide_resnet.WRN`
            - ``iresnet``: :class:`mnets.resnet_imgnet.ResNetIN`
            - ``zenke``: :class:`mnets.zenkenet.ZenkeNet`
            - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet`
            - ``chunked_mlp``: :class:`mnets.chunk_squeezer.ChunkSqueezer`
            - ``simple_rnn``: :class:`mnets.simple_rnn.SimpleRNN`
        in_shape (list): Shape of network inputs. Can be ``None`` if not
            required by network type.

            For instance: For an MLP network :class:`mnets.mlp.MLP` with 100
            input neurons it should be :code:`in_shape=[100]`.
        out_shape (list): Shape of network outputs. See ``in_shape`` for more
            details.
        device: PyTorch device.
        cprefix (str, optional): A prefix of the config names. It might be, that
            the config names used in this method are prefixed, since several
            main networks should be generated (e.g., :code:`cprefix='gen_'` or
            ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.main_net_args`.
        no_weights (bool): Whether the main network should be generated without
            weights.
        **mnet_kwargs: Additional keyword arguments that will be passed to the
            main network constructor.

    Returns:
        The created main network model (already moved to ``device``).
    """
    assert (net_type in [
        'mlp', 'lenet', 'resnet', 'zenke', 'bio_conv_net', 'chunked_mlp',
        'simple_rnn', 'wrn', 'iresnet'
    ])

    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    def hc(name):
        """Check whether config exists."""
        return hasattr(config, '%s%s' % (cprefix, name))

    def get_val(name):
        """Get config value with that name, or ``None`` if it is missing."""
        return gc(name) if hc(name) else None

    def negate(flag):
        """Negate an optional boolean while keeping ``None`` as ``None``.

        A plain ``not flag`` would map an unspecified option (``None``) to
        ``True``, such that the constructor default could never be selected.
        """
        return None if flag is None else not flag

    def assign(value, default):
        """Return ``default`` if ``value`` was not specified (is ``None``)."""
        return default if value is None else value

    mnet = None

    if hc('net_act'):
        net_act = misc.str_to_act(gc('net_act'))
    else:
        net_act = None

    no_bias = get_val('no_bias')
    dropout_rate = get_val('dropout_rate')
    specnorm = get_val('specnorm')
    batchnorm = get_val('batchnorm')
    no_batchnorm = get_val('no_batchnorm')
    bn_no_running_stats = get_val('bn_no_running_stats')
    bn_distill_stats = get_val('bn_distill_stats')
    # This argument has to be handled during usage of the network and not during
    # construction.
    #bn_no_stats_checkpointing = get_val('bn_no_stats_checkpointing')

    # The config stores negated flags; translate them into the positive form
    # expected by the constructors (``None`` means "not specified").
    use_bias = negate(no_bias)
    bn_track_stats = negate(bn_no_running_stats)

    # Batchnorm can be configured via a positive or a negative flag, where the
    # positive flag takes precedence.
    use_bn = None
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm

    # NOTE: Context-modulation related constructor arguments are deliberately
    # not set in any of the branches below (see note in the docstring).

    if net_type == 'mlp':
        assert (hc('mlp_arch'))
        assert (len(in_shape) == 1 and len(out_shape) == 1)

        # Default keyword arguments of class MLP.
        dkws = misc.get_default_args(MLP.__init__)

        mnet = MLP(
            n_in=in_shape[0],
            n_out=out_shape[0],
            hidden_layers=misc.str_to_ints(gc('mlp_arch')),
            activation_fn=assign(net_act, dkws['activation_fn']),
            use_bias=assign(use_bias, dkws['use_bias']),
            no_weights=no_weights,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']),
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(bn_track_stats, dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            verbose=True,
            **mnet_kwargs).to(device)

    elif net_type == 'resnet':
        assert (len(out_shape) == 1)
        assert hc('resnet_block_depth') and hc('resnet_channel_sizes')

        # Default keyword arguments of class ResNet.
        dkws = misc.get_default_args(ResNet.__init__)

        mnet = ResNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            n=gc('resnet_block_depth'),
            use_bias=assign(use_bias, dkws['use_bias']),
            num_feature_maps=misc.str_to_ints(gc('resnet_channel_sizes')),
            verbose=True,
            no_weights=no_weights,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(bn_track_stats, dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            **mnet_kwargs).to(device)

    elif net_type == 'wrn':
        assert (len(out_shape) == 1)
        assert hc('wrn_block_depth') and hc('wrn_widening_factor')

        # Default keyword arguments of class WRN.
        dkws = misc.get_default_args(WRN.__init__)

        mnet = WRN(
            in_shape=in_shape,
            num_classes=out_shape[0],
            n=gc('wrn_block_depth'),
            use_bias=assign(use_bias, dkws['use_bias']),
            verbose=True,
            no_weights=no_weights,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(bn_track_stats, dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            k=gc('wrn_widening_factor'),
            use_fc_bias=gc('wrn_use_fc_bias'),
            dropout_rate=gc('dropout_rate'),
            **mnet_kwargs).to(device)

    elif net_type == 'iresnet':
        assert (len(out_shape) == 1)
        assert hc('iresnet_use_fc_bias') and hc('iresnet_channel_sizes') \
            and hc('iresnet_blocks_per_group') \
            and hc('iresnet_bottleneck_blocks') \
            and hc('iresnet_projection_shortcut')

        # Default keyword arguments of class ResNetIN.
        dkws = misc.get_default_args(ResNetIN.__init__)

        mnet = ResNetIN(
            in_shape=in_shape,
            num_classes=out_shape[0],
            use_bias=assign(use_bias, dkws['use_bias']),
            # Read the option that was asserted above (and not the
            # corresponding option of the WRN).
            use_fc_bias=gc('iresnet_use_fc_bias'),
            num_feature_maps=misc.str_to_ints(gc('iresnet_channel_sizes')),
            blocks_per_group=misc.str_to_ints(gc('iresnet_blocks_per_group')),
            projection_shortcut=gc('iresnet_projection_shortcut'),
            bottleneck_blocks=gc('iresnet_bottleneck_blocks'),
            no_weights=no_weights,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(bn_track_stats, dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            verbose=True,
            **mnet_kwargs).to(device)

    elif net_type == 'zenke':
        assert (len(out_shape) == 1)

        # Default keyword arguments of class ZenkeNet.
        dkws = misc.get_default_args(ZenkeNet.__init__)

        mnet = ZenkeNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            verbose=True,
            no_weights=no_weights,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            **mnet_kwargs).to(device)

    elif net_type == 'bio_conv_net':
        assert (len(out_shape) == 1)

        # No configurable options; all remaining arguments keep their
        # constructor defaults.
        mnet = BioConvNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            no_weights=no_weights,
            **mnet_kwargs).to(device)

    elif net_type == 'chunked_mlp':
        assert hc('cmlp_arch') and hc('cmlp_chunk_arch') and \
               hc('cmlp_in_cdim') and hc('cmlp_out_cdim') and \
               hc('cmlp_cemb_dim')
        assert len(in_shape) == 1 and len(out_shape) == 1

        # Default keyword arguments of class ChunkSqueezer.
        dkws = misc.get_default_args(ChunkSqueezer.__init__)

        mnet = ChunkSqueezer(
            n_in=in_shape[0],
            n_out=out_shape[0],
            inp_chunk_dim=gc('cmlp_in_cdim'),
            out_chunk_dim=gc('cmlp_out_cdim'),
            cemb_size=gc('cmlp_cemb_dim'),
            red_layers=misc.str_to_ints(gc('cmlp_chunk_arch')),
            net_layers=misc.str_to_ints(gc('cmlp_arch')),
            activation_fn=assign(net_act, dkws['activation_fn']),
            use_bias=assign(use_bias, dkws['use_bias']),
            no_weights=no_weights,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']),
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(bn_track_stats, dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            verbose=True,
            **mnet_kwargs).to(device)

    elif net_type == 'lenet':
        assert hc('lenet_type')
        assert len(out_shape) == 1

        # Default keyword arguments of class LeNet.
        dkws = misc.get_default_args(LeNet.__init__)

        mnet = LeNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            verbose=True,
            arch=gc('lenet_type'),
            no_weights=no_weights,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            # TODO Context-mod weights.
            **mnet_kwargs).to(device)

    else:
        assert (net_type == 'simple_rnn')
        assert hc('srnn_rec_layers') and hc('srnn_pre_fc_layers') and \
            hc('srnn_post_fc_layers') and hc('srnn_no_fc_out') and \
            hc('srnn_rec_type')
        assert len(in_shape) == 1 and len(out_shape) == 1

        if gc('srnn_rec_type') == 'lstm':
            use_lstm = True
        else:
            assert gc('srnn_rec_type') == 'elman'
            use_lstm = False

        # Default keyword arguments of class SimpleRNN.
        dkws = misc.get_default_args(SimpleRNN.__init__)

        rnn_layers = misc.str_to_ints(gc('srnn_rec_layers'))
        fc_layers = misc.str_to_ints(gc('srnn_post_fc_layers'))
        # The output layer is either the last recurrent layer or the last
        # fully-connected layer.
        if gc('srnn_no_fc_out'):
            rnn_layers.append(out_shape[0])
        else:
            fc_layers.append(out_shape[0])

        mnet = SimpleRNN(n_in=in_shape[0],
                         rnn_layers=rnn_layers,
                         fc_layers_pre=misc.str_to_ints(
                             gc('srnn_pre_fc_layers')),
                         fc_layers=fc_layers,
                         activation=assign(net_act, dkws['activation']),
                         use_lstm=use_lstm,
                         use_bias=assign(use_bias, dkws['use_bias']),
                         no_weights=no_weights,
                         verbose=True,
                         **mnet_kwargs).to(device)

    return mnet
Beispiel #3
0
def get_mnet_model(config,
                   net_type,
                   in_shape,
                   out_shape,
                   device,
                   cprefix=None,
                   no_weights=False):
    """Build the main network selected by ``net_type``.

    The network is configured from the (optionally prefixed) attributes of
    ``config``. Options that are absent from ``config`` or ``None`` are
    replaced by the fallback values hard-coded in this function.

    .. note::
        Generation of networks with context-modulation is not yet supported,
        since there is no global argument set in :mod:`utils.cli_args` yet.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available according
                to the function :func:`utils.cli_args.main_net_args`.
        net_type (str): The type of network. The following options are
            available:

            - ``mlp``: :class:`mnets.mlp.MLP`
            - ``resnet``: :class:`mnets.resnet.ResNet`
            - ``zenke``: :class:`mnets.zenkenet.ZenkeNet`
            - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet` (not
              available, see "Raises")
        in_shape (list): Shape of network inputs. Can be ``None`` if not
            required by network type.

            For instance: For an MLP network :class:`mnets.mlp.MLP` with 100
            input neurons it should be :code:`in_shape=[100]`.
        out_shape (list): Shape of network outputs. See ``in_shape`` for more
            details.
        device: PyTorch device.
        cprefix (str, optional): A prefix of the config names. It might be, that
            the config names used in this method are prefixed, since several
            main networks should be generated (e.g., :code:`cprefix='gen_'` or
            ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.main_net_args`.
        no_weights (bool): Whether the main network should be generated without
            weights.

    Returns:
        The created main network model.

    Raises:
        NotImplementedError: If ``net_type`` is ``'bio_conv_net'``.
    """
    assert (net_type in ['mlp', 'resnet', 'zenke', 'bio_conv_net'])

    prefix = '' if cprefix is None else cprefix

    def key(name):
        """Name of the config attribute including the prefix."""
        return '{}{}'.format(prefix, name)

    def optional(name):
        """Value of the prefixed option, ``None`` if it does not exist."""
        return getattr(config, key(name), None)

    def or_default(value, default):
        """Fall back to ``default`` if the option was not specified."""
        return default if value is None else value

    # The activation function is only translated if the option is present.
    if hasattr(config, key('net_act')):
        net_act = misc.str_to_act(getattr(config, key('net_act')))
    else:
        net_act = None

    no_bias = optional('no_bias')
    dropout_rate = optional('dropout_rate')
    specnorm = optional('specnorm')
    batchnorm = optional('batchnorm')
    no_batchnorm = optional('no_batchnorm')
    bn_no_running_stats = optional('bn_no_running_stats')
    bn_distill_stats = optional('bn_distill_stats')

    # The positive batchnorm flag takes precedence over the negative one.
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm
    else:
        use_bn = None

    # Negated flags: an unspecified (``None``) option yields ``True``, which
    # coincides with the fallback values used for these arguments.
    use_bias = not no_bias
    track_stats = not bn_no_running_stats
    distill_stats = or_default(bn_distill_stats, False)

    if net_type == 'mlp':
        assert (hasattr(config, key('mlp_arch')))
        assert (len(in_shape) == 1 and len(out_shape) == 1)

        mlp_net = MLP(
            n_in=in_shape[0],
            n_out=out_shape[0],
            hidden_layers=misc.str_to_ints(getattr(config, key('mlp_arch'))),
            activation_fn=or_default(net_act, torch.nn.ReLU()),
            use_bias=use_bias,
            no_weights=no_weights,
            dropout_rate=or_default(dropout_rate, -1),
            use_spectral_norm=or_default(specnorm, False),
            use_batch_norm=or_default(use_bn, False),
            bn_track_stats=track_stats,
            distill_bn_stats=distill_stats,
            verbose=True)
        return mlp_net.to(device)

    if net_type == 'resnet':
        assert (len(out_shape) == 1)

        res_net = ResNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            verbose=True,
            no_weights=no_weights,
            use_batch_norm=or_default(use_bn, True),
            bn_track_stats=track_stats,
            distill_bn_stats=distill_stats)
        return res_net.to(device)

    if net_type == 'zenke':
        assert (len(out_shape) == 1)

        zenke_net = ZenkeNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            verbose=True,
            no_weights=no_weights,
            dropout_rate=or_default(dropout_rate, 0.25))
        return zenke_net.to(device)

    # Only one network type remains at this point.
    assert (net_type == 'bio_conv_net')
    assert (len(out_shape) == 1)

    raise NotImplementedError('Implementation not publicly available!')
Beispiel #4
0
def generate_replay_networks(config,
                             data_handlers,
                             device,
                             create_rp_hnet=True,
                             only_train_replay=False):
    """Create the replay model and, optionally, the decoder's hypernetwork.

    The replay model consists of an encoder/decoder or a
    discriminator/generator pair (both are MLPs). If ``create_rp_hnet`` is
    set, a hypernetwork for the decoder/generator is created in addition.

    The following attributes of ``config`` are written by this function:
    ``input_dim``, ``out_dim``, ``num_weights_enc``, ``num_weights_dec``,
    ``num_weights_rp_net``, ``num_embeddings``, ``num_weights_rp_hyper_net``
    and ``compression_ratio_rp``.

    .. note::
        This function also initialises the weights of the encoder and of
        either the decoder or its hypernetwork.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers. Only ``in_shape[0]`` of the
            first handler is read here to derive the input dimension.
        device: Torch device.
        create_rp_hnet: Whether a hypernetwork for the replay model should be
            constructed. If not, the parameters of the decoder/generator
            itself are initialised.
        only_train_replay: We normally do not train on the last task since we
            do not need to replay this last task's data. But if we want a
            replay method to be able to generate data from all tasks then we
            set this option to true.

    Returns:
        (tuple): Tuple containing:

        - **enc**: Encoder/discriminator network instance.
        - **dec**: Decoder/generator network instance.
        - **dec_hnet**: Hypernetwork instance for the decoder/generator. This
          return value is ``None`` if no hypernetwork should be constructed.
    """
    # The discriminator of a GAN has a single output; otherwise the encoder
    # outputs two values per latent dimension.
    enc_out_dim = 1 if config.replay_method == 'gan' \
        else config.latent_dim * 2

    side_len = data_handlers[0].in_shape[0]
    pad = config.padding * 2
    is_split_mnist = config.experiment == "splitMNIST"

    # Flattened input size; padding is only applied for permutedMNIST.
    if is_split_mnist:
        n_in = side_len * side_len
    else:
        n_in = (side_len + pad) * (side_len + pad)
    config.input_dim = n_in

    if not is_split_mnist:  # permutedMNIST
        config.out_dim = 10
    elif config.single_class_replay:
        config.out_dim = 1
    else:
        config.out_dim = 2

    if config.infer_task_id:
        # Task inference network.
        config.out_dim = 1

    # Build the encoder.
    print('For the replay encoder/discriminator: ')
    enc = MLP(n_in=n_in,
              n_out=enc_out_dim,
              hidden_layers=misc.str_to_ints(config.enc_fc_arch),
              activation_fn=misc.str_to_act(config.enc_net_act),
              dropout_rate=config.enc_dropout_rate,
              no_weights=False).to(device)
    print('Constructed MLP with shapes: ', enc.param_shapes)
    init_params = list(enc.parameters())

    # Build the decoder.
    print('For the replay decoder/generator: ')
    dec_hidden = misc.str_to_ints(config.dec_fc_arch)
    # The decoder input is the latent code, extended by the conditional
    # input if requested.
    dec_in_dim = config.latent_dim
    if config.conditional_replay:
        dec_in_dim += config.conditional_dim

    dec = MLP(n_in=dec_in_dim,
              n_out=n_in,
              hidden_layers=dec_hidden,
              activation_fn=misc.str_to_act(config.dec_net_act),
              use_bias=True,
              no_weights=config.rp_beta > 0,
              dropout_rate=config.dec_dropout_rate).to(device)
    print('Constructed MLP with shapes: ', dec.param_shapes)

    config.num_weights_enc = MainNetInterface.shapes_to_num_weights(
        enc.param_shapes)
    config.num_weights_dec = MainNetInterface.shapes_to_num_weights(
        dec.param_shapes)
    config.num_weights_rp_net = config.num_weights_enc + config.num_weights_dec

    # We do not need a replay model for the last task, unless replay is
    # explicitly trained on all tasks.
    skipped_tasks = 0 if only_train_replay else 1
    multiple_tasks = config.num_tasks > 1

    if config.single_class_replay:
        if multiple_tasks:
            replayed_tasks = config.num_tasks - skipped_tasks
        else:
            replayed_tasks = config.num_tasks
        num_embeddings = config.out_dim * replayed_tasks
    elif multiple_tasks:
        num_embeddings = config.num_tasks - skipped_tasks
    else:
        num_embeddings = 1
    config.num_embeddings = num_embeddings

    # Build the decoder's hypernetwork (if requested) and extend the list of
    # parameters that have to be initialised below.
    if create_rp_hnet:
        print('For the decoder/generator hypernetwork: ')
        d_hnet = sim_utils.get_hnet_model(config,
                                          num_embeddings,
                                          device,
                                          dec.hyper_shapes_learned,
                                          cprefix='rp_')
        init_params += list(d_hnet.parameters())

        num_trainable = 0
        for hparam in d_hnet.parameters():
            if hparam.requires_grad:
                num_trainable += hparam.numel()
        config.num_weights_rp_hyper_net = num_trainable

        config.compression_ratio_rp = \
            config.num_weights_rp_hyper_net / config.num_weights_dec
        print('Created replay hypernetwork with ratio: ',
              config.compression_ratio_rp)
        if config.compression_ratio_rp > 1:
            print('Note that the compression ratio is computed compared to '
                  'current target network,\nthis might not be directly '
                  'comparable with the number of parameters of methods we '
                  'compare against.')
    else:
        d_hnet = None
        init_params += list(dec.parameters())
        config.num_weights_rp_hyper_net = 0
        config.compression_ratio_rp = 0

    # Initialise network weights: 1D parameters are treated as bias vectors.
    for param in init_params:
        if param.ndimension() != 1:
            torch.nn.init.xavier_uniform_(param)
        else:
            torch.nn.init.constant_(param, 0)

    # Task embeddings are initialised differently.
    if create_rp_hnet:
        for task_emb in d_hnet.get_task_embs():
            torch.nn.init.normal_(task_emb, mean=0.,
                                  std=config.std_normal_temb)

    # ``hasattr`` is False for ``None``, i.e., when no hypernetwork exists.
    if hasattr(d_hnet, 'chunk_embeddings'):
        for chunk_emb in d_hnet.chunk_embeddings:
            torch.nn.init.normal_(chunk_emb, mean=0,
                                  std=config.std_normal_emb)

    return enc, dec, d_hnet
Beispiel #5
0
class BiRNN(nn.Module, MainNetInterface):
    r"""Implementation of a bidirectional RNN.

    Note:
        The output is non-linear if the last layer is recurrent! Otherwise,
        logits are returned (cmp. attribute
        :attr:`mnets.mnet_interface.MainNetInterface.has_fc_out`).

    Example:
        Here is an example instantiation of a BiLSTM with a single bidirectional
        layer of dimensionality 256, assuming 100 dimensional inputs and
        10 dimensional outputs.

        .. code-block:: python

            net = BiRNN(rnn_args={'n_in': 100, 'rnn_layers': [256],
                                  'use_lstm': True, 'fc_layers_pre': [],
                                  'fc_layers': []},
                        mlp_args={'n_in': 512, 'n_out': 10,
                                  'hidden_layers': []},
                        no_weights=False)

    Attributes:
        preprocess_fct (func, optional): See constructor argument
            ``preprocess_fct``.
        num_rec_layers (int): See attribute
            :attr:`mnets.simple_rnn.SimpleRNN.num_rec_layers`. Total number of
            recurrent layers, where each bidirectional layer consists of at
            least two recurrent layers (forward and backward layer).
        use_lstm (bool): See attribute
            :attr:`mnets.simple_rnn.SimpleRNN.use_lstm`.

    Args:
        rnn_args (dict or list): A dictionary of arguments for an instance of
            class :class:`mnets.simple_rnn.SimpleRNN`. These arguments will be
            used to create two instances of this class, one representing the
            forward RNN and one the backward RNN.

            Note, each of these instances may contain multiple layers, even
            non-recurrent layers. The outputs of such an instance are considered
            the hidden activations :math:`\hat{h}_{1:T}^{(f)}` or
            :math:`\hat{h}_{1:T}^{(b)}`, respectively.

            To realize multiple bidirectional layers (which in itself can be
            multi-layer RNNs), one may provide a list of dictionaries. Each
            entry in such list will be used to generate a single bidirectional
            layer (i.e., consisting of two instances of class
            :class:`mnets.simple_rnn.SimpleRNN`). Note, the input size of
            each new layer has to be twice the size of :math:`\hat{h}_t^{(f)}`
            from the previous layer.

            The passed dictionaries are copied and never modified.
        mlp_args (dict, optional): A dictionary of arguments for class
            :class:`mnets.mlp.MLP`. The input size of such an MLP should be
            twice the size of :math:`\hat{h}_t^{(f)}`. If ``None``, then the
            output of the last bidirectional layer is considered the output of
            the network. The passed dictionary is copied and never modified.
        preprocess_fct (func, optional): A function handle can be provided,
            that will process inputs ``x`` passed to the method :meth:`forward`.
            An example use case could be the translation or selection of word
            embeddings.

            The function handle must have the signature:
            ``preprocess_fct(x, seq_lengths=None)``. See the corresponding
            argument descriptions of method :meth:`forward`. The function is
            expected to return the preprocessed ``x``.
        no_weights (bool): See parameter ``no_weights`` of class
            :class:`mnets.mlp.MLP`.
        verbose (bool): See parameter ``verbose`` of class
            :class:`mnets.mlp.MLP`.

    Raises:
        ValueError: If a ``no_weights`` entry in ``rnn_args`` or ``mlp_args``
            contradicts the constructor argument ``no_weights``, or if the
            internal networks disagree on whether they maintain their
            context-mod weights.
    """
    def __init__(self,
                 rnn_args={},
                 mlp_args=None,
                 preprocess_fct=None,
                 no_weights=False,
                 verbose=True):
        # NOTE The mutable default of `rnn_args` is kept for interface
        # compatibility. It is harmless, since all argument dictionaries are
        # copied below before defaults are filled in.

        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        assert isinstance(rnn_args, (dict, list, tuple))
        assert mlp_args is None or isinstance(mlp_args, dict)

        if isinstance(rnn_args, dict):
            rnn_args = [rnn_args]

        self._forward_rnns = []
        self._backward_rnns = []
        self._out_mlp = None
        self._preprocess_fct = preprocess_fct
        self._forward_called = False

        # FIXME At the moment we do not control input and output size of
        # individual networks and need to assume that the user sets them
        # correctly.

        ### Create all forward and backward nets for each bidirectional layer.
        for rargs in rnn_args:
            assert isinstance(rargs, dict)
            # Work on a copy, such that neither the caller's dictionary nor
            # the shared default argument are modified.
            rargs = dict(rargs)
            rargs.setdefault('verbose', False)
            # `setdefault` returns an existing entry, or stores and returns
            # `no_weights` if the key was missing.
            if rargs.setdefault('no_weights', no_weights) != no_weights:
                raise ValueError('Keyword argument "no_weights" of ' +
                                 'bidirectional layer is in conflict with ' +
                                 'constructor argument "no_weights".')

            self._forward_rnns.append(SimpleRNN(**rargs))
            self._backward_rnns.append(SimpleRNN(**rargs))

        ### Create output network.
        if mlp_args is not None:
            mlp_args = dict(mlp_args)  # Don't modify the caller's dict.
            mlp_args.setdefault('verbose', False)
            if mlp_args.setdefault('no_weights', no_weights) != no_weights:
                raise ValueError('Keyword argument "no_weights" of ' +
                                 'output MLP is in conflict with ' +
                                 'constructor argument "no_weights".')

            self._out_mlp = MLP(**mlp_args)

        ### Set all interface attributes correctly.
        if self._out_mlp is None:
            self._has_fc_out = self._forward_rnns[-1].has_fc_out
            # We can't set the following attribute to true, as the output is
            # a concatenation of the outputs from two networks. Therefore, the
            # weights used to compute the outputs are at different locations
            # in the `param_shapes` list.
            self._mask_fc_out = False
            self._has_linear_out = self._forward_rnns[-1].has_linear_out
        else:
            self._has_fc_out = self._out_mlp.has_fc_out
            self._mask_fc_out = self._out_mlp.mask_fc_out
            self._has_linear_out = self._out_mlp.has_linear_out

        # Collect all internal net objects from which we need to collect
        # attributes. The order (forward, backward per layer, then MLP)
        # determines the order of entries in `param_shapes`.
        nets = []
        for i, (fnet, bnet) in enumerate(zip(self._forward_rnns,
                                             self._backward_rnns)):
            nets.append((fnet, 'forward_rnn', i))
            nets.append((bnet, 'backward_rnn', i))
        if self._out_mlp is not None:
            nets.append((self._out_mlp, 'out_mlp', -1))

        # Iterate over all nets to collect their attribute values.
        self._param_shapes = []
        self._param_shapes_meta = []
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        for i, (net, net_type, net_id) in enumerate(nets):
            # Note, it is important to convert lists into new object and not
            # just copy references!
            # Note, we have to adapt all references if `i > 0`.

            # Sanity check:
            if i == 0:
                cm_nw = net._context_mod_no_weights
            elif cm_nw != net._context_mod_no_weights:
                raise ValueError('Network expect that either all internal ' +
                                 'networks maintain their context-mod ' +
                                 'weights or non of them does!')

            ps_len_old = len(self._param_shapes)

            # Offset translating indices into `net._internal_params` to
            # indices into `self._internal_params`. It must be defined even
            # if `net` maintains no parameters (e.g., `no_weights=True`).
            ip_len_old = 0
            if net._internal_params is not None:
                if self._internal_params is None:
                    self._internal_params = nn.ParameterList()
                ip_len_old = len(self._internal_params)
                self._internal_params.extend(
                    nn.ParameterList(net._internal_params))
            self._param_shapes.extend(list(net._param_shapes))
            for meta in net.param_shapes_meta:
                assert 'birnn_layer_type' not in meta
                assert 'birnn_layer_id' not in meta

                new_meta = dict(meta)
                new_meta['birnn_layer_type'] = net_type
                new_meta['birnn_layer_id'] = net_id
                if i > 0:
                    # FIXME We should properly adjust colliding `layer` IDs.
                    new_meta['layer'] = -1
                # An index of -1 marks a parameter that is not internally
                # maintained. This sentinel must not be shifted, as other
                # methods of this class test for `index != -1`.
                if meta['index'] != -1:
                    new_meta['index'] = meta['index'] + ip_len_old
                self._param_shapes_meta.append(new_meta)

            if net._hyper_shapes_learned is not None:
                if self._hyper_shapes_learned is None:
                    self._hyper_shapes_learned = []
                    self._hyper_shapes_learned_ref = []
                self._hyper_shapes_learned.extend(
                    list(net._hyper_shapes_learned))
                # References point into `param_shapes` and thus have to be
                # shifted by the number of previously collected shapes.
                for ref in net._hyper_shapes_learned_ref:
                    self._hyper_shapes_learned_ref.append(ref + ps_len_old)
            if net._hyper_shapes_distilled is not None:
                if self._hyper_shapes_distilled is None:
                    self._hyper_shapes_distilled = []
                self._hyper_shapes_distilled.extend(
                    list(net._hyper_shapes_distilled))

            if self._has_bias is None:
                self._has_bias = net._has_bias
            elif self._has_bias != net._has_bias:
                self._has_bias = False
                # FIXME We should overwrite the getter and throw an error!
                warn('Some internally maintained networks use biases, ' +
                     'while others don\'t. Setting attribute "has_bias" to ' +
                     'False.')

            self._layer_weight_tensors.extend(
                nn.ParameterList(net._layer_weight_tensors))
            self._layer_bias_vectors.extend(
                nn.ParameterList(net._layer_bias_vectors))
            if net._batchnorm_layers is not None:
                if self._batchnorm_layers is None:
                    self._batchnorm_layers = nn.ModuleList()
                self._batchnorm_layers.extend(
                    nn.ModuleList(net._batchnorm_layers))
            if net._context_mod_layers is not None:
                if self._context_mod_layers is None:
                    self._context_mod_layers = nn.ModuleList()
                self._context_mod_layers.extend(
                    nn.ModuleList(net._context_mod_layers))

        self._is_properly_setup()

        ### Print user information.
        if verbose:
            print('Constructed Bidirectional RNN with %d weights.'
                  % self.num_params)

    @property
    def preprocess_fct(self):
        """Getter for attribute :attr:`preprocess_fct`."""
        return self._preprocess_fct

    @preprocess_fct.setter
    def preprocess_fct(self, value):
        """Setter for attribute :attr:`preprocess_fct`.

        Note:
            This setter may only be called before the first call of the
            :meth:`forward` method.

        Raises:
            RuntimeError: If method :meth:`forward` has been called already.
        """
        if self._forward_called:
            raise RuntimeError('Attribute "preprocess_fct" cannot be ' +
                               'modified after method "forward" has been ' +
                               'called.')
        self._preprocess_fct = value

    @property
    def num_rec_layers(self):
        """Getter for read-only attribute :attr:`num_rec_layers`."""
        return sum(net.num_rec_layers
                   for net in self._forward_rnns + self._backward_rnns)

    @property
    def use_lstm(self):
        """Getter for read-only attribute :attr:`use_lstm`.

        Raises:
            RuntimeError: If the forward networks of different bidirectional
                layers disagree in their attribute ``use_lstm``.
        """
        use_lstm = self._forward_rnns[0].use_lstm
        for fnet in self._forward_rnns[1:]:
            if fnet.use_lstm != use_lstm:
                raise RuntimeError('Attribute "use_lstm" not applicable to ' +
                                   'this network as layers use mixed types ' +
                                   'of RNNs.')
        return use_lstm

    def distillation_targets(self):
        """Targets to be distilled after training.

        See docstring of abstract super method
        :meth:`mnets.mnet_interface.MainNetInterface.distillation_targets`.

        Returns:
            The distillation targets of the output MLP, or ``None`` if there
            is no output MLP.

        Raises:
            RuntimeError: If any of the internal recurrent networks reports
                distillation targets.
        """
        # SimpleRNNs should not have any distillation targets.
        for net in self._forward_rnns + self._backward_rnns:
            # `distillation_targets` is a method and has to be called.
            # Comparing the bound method itself with `None` is always true.
            if net.distillation_targets() is not None:
                raise RuntimeError('Recurrent networks are not expected to ' +
                                   'have distillation targets.')

        if self._out_mlp is not None:
            return self._out_mlp.distillation_targets()
        return None

    def forward(self,
                x,
                weights=None,
                distilled_params=None,
                condition=None,
                seq_lengths=None):
        """Compute the output :math:`y` of this network given the input
        :math:`x`.

        Note:
            If constructor argument ``preprocess_fct`` was set, then all
            inputs ``x`` are first processed by this function.

        Args:
            (....): See docstring of method
                :meth:`mnets.mnet_interface.MainNetInterface.forward`. We
                provide some more specific information below.
            weights (list or dict): See argument ``weights`` of method
                :meth:`mnets.mlp.MLP.forward`.
            distilled_params: Will only be passed to the underlying instance
                of class :class:`mnets.mlp.MLP`. Has to be ``None`` if there
                is no output MLP.
            condition (int or dict, optional): If provided, then this argument
                will be passed as argument ``ckpt_id`` to the method
                :meth:`utils.context_mod_layer.ContextModLayer.forward`.

                When providing as dict, see argument ``condition`` of method
                :meth:`mnets.mlp.MLP.forward` for more details. The dict is
                forwarded to the output MLP; the recurrent networks only
                receive the entry ``'cmod_ckpt_id'`` (if present).
            seq_lengths (numpy.ndarray, optional): List of sequence
                lengths. The length of the list has to match the batch size of
                inputs ``x``. The entries will correspond to the unpadded
                sequence lengths. If this option is provided, then the
                bidirectional layers will reverse its input sequences according
                to the unpadded sequence lengths.

                Example:
                    ``x = [[a,b,0,0], [a,b,c,0]].T``. If
                    ``seq_lengths = [2, 3]`` is provided, then the reverse
                    sequences ``[[b,a,0,0], [c,b,a,0]].T`` are fed into the
                    first bidirectional layer (and similarly for all subsequent
                    bidirectional layers). Otherwise reverse sequences
                    ``[[0,0,b,a], [0,c,b,a]].T`` are used.

                Caution:
                    If this option is not provided but padded input sequences
                    are used, the output of a bidirectional layer will depend
                    on the padding. I.e., different padding lengths will lead
                    to different results.

        Returns:
            The output of the network, i.e., the output of the output MLP if
            one exists, otherwise the concatenated forward and backward
            activations of the last bidirectional layer. Hidden activations
            are not returned (not implemented yet).
        """
        # FIXME Delete warning below.
        if seq_lengths is None:
            warn('"seq_lengths" has not been provided to BiRNN.')

        if self._out_mlp is None:
            assert distilled_params is None

        #######################
        ### Parse condition ###
        #######################
        rnn_cmod_cond = None
        mlp_cond = None

        if condition is not None:
            if isinstance(condition, dict):
                if 'cmod_ckpt_id' in condition:
                    rnn_cmod_cond = condition['cmod_ckpt_id']
                # Always hand the dict to the MLP, as it may carry entries
                # other than 'cmod_ckpt_id' that would otherwise be dropped.
                mlp_cond = condition
            else:
                rnn_cmod_cond = condition
                mlp_cond = {'cmod_ckpt_id': condition}

        ########################################
        ### Extract-weights for each network ###
        ########################################
        forward_weights = [None] * len(self._forward_rnns)
        backward_weights = [None] * len(self._backward_rnns)
        mlp_weights = None

        n_cm = self._num_context_mod_shapes()
        int_weights = None
        cm_weights = None
        all_weights = None
        if isinstance(weights, dict):
            if 'internal_weights' in weights:
                int_weights = weights['internal_weights']
            if 'mod_weights' in weights:
                cm_weights = weights['mod_weights']
        elif weights is not None:
            # A plain list either contains only the context-mod weights or
            # all weights (ordered as in `param_shapes`).
            if len(weights) == n_cm:
                cm_weights = weights
            else:
                assert len(weights) == len(self.param_shapes)
                all_weights = weights

        if weights is not None:

            def gather_weights(want_cm):
                """Collect all context-mod (``want_cm=True``) or all other
                weights, either from `all_weights` or from the internally
                maintained parameters."""
                gathered = []
                for ii, meta in enumerate(self.param_shapes_meta):
                    if meta['name'].startswith('cm_') != want_cm:
                        continue
                    if all_weights is not None:
                        gathered.append(all_weights[ii])
                    else:
                        # Weights that are not passed have to exist
                        # internally.
                        assert meta['index'] != -1
                        gathered.append(self.internal_params[meta['index']])
                return gathered

            # Collect all context-mod and internal weights if not explicitly
            # passed.
            if n_cm > 0 and cm_weights is None:
                cm_weights = gather_weights(True)
            if int_weights is None:
                int_weights = gather_weights(False)

            # Now that we have all context-mod and internal weights, we need to
            # distribute them across networks. Therefore, note that the order
            # in which they appear in `param_shapes` matches the order of
            # `cm_weights` and `int_weights`.
            cm_ind = 0
            int_ind = 0
            for meta in self.param_shapes_meta:
                net_type = meta['birnn_layer_type']
                net_id = meta['birnn_layer_id']

                if net_type == 'forward_rnn':
                    if forward_weights[net_id] is None:
                        forward_weights[net_id] = dict()
                    curr_weights = forward_weights[net_id]
                elif net_type == 'backward_rnn':
                    if backward_weights[net_id] is None:
                        backward_weights[net_id] = dict()
                    curr_weights = backward_weights[net_id]
                else:
                    assert net_type == 'out_mlp'
                    if mlp_weights is None:
                        mlp_weights = dict()
                    curr_weights = mlp_weights

                if meta['name'].startswith('cm_'):
                    curr_weights.setdefault('mod_weights', []).append(
                        cm_weights[cm_ind])
                    cm_ind += 1
                else:
                    curr_weights.setdefault('internal_weights', []).append(
                        int_weights[int_ind])
                    int_ind += 1

        #####################################
        ### Apply potential preprocessing ###
        #####################################
        # From now on, the setter of `preprocess_fct` refuses modifications.
        self._forward_called = True
        if self._preprocess_fct is not None:
            x = self._preprocess_fct(x, seq_lengths=seq_lengths)

        ####################################
        ### Process bidirectional layers ###
        ####################################
        # Create reverse input sequence for backward network.
        if seq_lengths is not None:
            # Dimension 1 of `x` is indexed per sequence below.
            assert seq_lengths.size == x.shape[1]

        def revert_order(inp):
            """Revert `inp` along dimension 0 (time). If `seq_lengths` is
            given, only the unpadded part of each sequence is reverted and
            all remaining timesteps are set to zero."""
            if seq_lengths is None:
                return torch.flip(inp, [0])

            inp_back = torch.zeros_like(inp)
            for ii in range(seq_lengths.size):
                n_steps = int(seq_lengths[ii])
                inp_back[:n_steps, ii, :] = \
                    torch.flip(inp[:n_steps, ii, :], [0])
            return inp_back

        h = x

        for ll, fnet in enumerate(self._forward_rnns):
            bnet = self._backward_rnns[ll]

            # Revert inputs in time before processing them by the backward RNN.
            h_rev = revert_order(h)

            h_f = fnet.forward(h,
                               weights=forward_weights[ll],
                               condition=rnn_cmod_cond,
                               return_hidden=False,
                               return_hidden_int=False)
            h_b = bnet.forward(h_rev,
                               weights=backward_weights[ll],
                               condition=rnn_cmod_cond,
                               return_hidden=False,
                               return_hidden_int=False)

            # Revert outputs in time from the backward RNN before concatenation.
            # NOTE If `seq_lengths` are given, then this function will also set
            # the hidden timesteps corresponding to "padded timesteps" to zero.
            h_b = revert_order(h_b)

            # Set hidden states of `h_f` corresponding to padded timesteps to
            # zero to ensure consistency. Note, will only ever affect those
            # "padded timesteps", i.e., all timesteps at or beyond the
            # unpadded sequence length.
            if seq_lengths is not None:
                for ii in range(seq_lengths.size):
                    h_f[int(seq_lengths[ii]):, ii, :] = 0

            # Concatenate forward and backward activations along the feature
            # dimension.
            h = torch.cat([h_f, h_b], dim=2)

        ##############################
        ### Compute network output ###
        ##############################
        if self._out_mlp is not None:
            h = self._out_mlp.forward(h,
                                      weights=mlp_weights,
                                      distilled_params=distilled_params,
                                      condition=mlp_cond)

        return h

    def init_hh_weights_orthogonal(self):
        """Initialize hidden-to-hidden weights orthogonally.

        This method will call method
        :meth:`mnets.simple_rnn.SimpleRNN.init_hh_weights_orthogonal` of all
        internally maintained instances of class
        :class:`mnets.simple_rnn.SimpleRNN`.
        """
        for net in self._forward_rnns + self._backward_rnns:
            net.init_hh_weights_orthogonal()

    def get_cm_weights(self):
        """Get internal maintained weights that are associated with context-
        modulation.

        Returns:
            (list): List of weights from
            :attr:`mnets.mnet_interface.MainNetInterface.internal_params` that
            are belonging to context-mod layers.
        """
        # Only parameters with `index != -1` are internally maintained.
        return [self.internal_params[meta['index']]
                for meta in self.param_shapes_meta
                if meta['name'] in ('cm_shift', 'cm_scale')
                and meta['index'] != -1]

    def get_non_cm_weights(self):
        """Get internal weights that are not associated with context-modulation.

        Returns:
            (list): List of weights from
            :attr:`mnets.mnet_interface.MainNetInterface.internal_params` that
            are not belonging to context-mod layers. If the network has no
            context-mod shapes at all, the attribute ``internal_params``
            itself is returned.
        """
        if self._num_context_mod_shapes() == 0:
            return self.internal_params

        # Only parameters with `index != -1` are internally maintained.
        return [self.internal_params[meta['index']]
                for meta in self.param_shapes_meta
                if meta['name'] not in ('cm_shift', 'cm_scale')
                and meta['index'] != -1]
Beispiel #6
0
    def __init__(self,
                 rnn_args={},
                 mlp_args=None,
                 preprocess_fct=None,
                 no_weights=False,
                 verbose=True):
        """Construct all forward/backward RNNs and the optional output MLP.

        Afterwards, the attributes of all internal networks (parameter shapes
        and their meta information, internal parameters, batchnorm and
        context-mod layers, ...) are merged into the attributes of this
        object.

        Args:
            rnn_args (dict or list): Keyword arguments for class ``SimpleRNN``.
                Each dictionary is used to create one forward and one backward
                network. A list (or tuple) of dictionaries creates one such
                pair per entry. The dictionaries are copied, never modified.
            mlp_args (dict, optional): Keyword arguments for class ``MLP``. If
                ``None``, no output network is created. The dictionary is
                copied, never modified.
            preprocess_fct (func, optional): Stored as attribute
                ``_preprocess_fct``.
            no_weights (bool): Used as default for the keyword ``no_weights``
                of all internal networks.
            verbose (bool): Whether the number of weights is printed after
                construction. Internal networks default to ``verbose=False``.

        Raises:
            ValueError: If a ``no_weights`` entry in ``rnn_args`` or
                ``mlp_args`` contradicts argument ``no_weights``, or if the
                internal networks disagree in ``_context_mod_no_weights``.
        """
        # NOTE The mutable default of `rnn_args` is kept for interface
        # compatibility. It is harmless, since all argument dictionaries are
        # copied below before defaults are filled in.

        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        assert isinstance(rnn_args, (dict, list, tuple))
        assert mlp_args is None or isinstance(mlp_args, dict)

        if isinstance(rnn_args, dict):
            rnn_args = [rnn_args]

        self._forward_rnns = []
        self._backward_rnns = []
        self._out_mlp = None
        self._preprocess_fct = preprocess_fct
        self._forward_called = False

        # FIXME At the moment we do not control input and output size of
        # individual networks and need to assume that the user sets them
        # correctly.

        ### Create all forward and backward nets for each bidirectional layer.
        for rargs in rnn_args:
            assert isinstance(rargs, dict)
            # Work on a copy, such that neither the caller's dictionary nor
            # the shared default argument are modified.
            rargs = dict(rargs)
            rargs.setdefault('verbose', False)
            # `setdefault` returns an existing entry, or stores and returns
            # `no_weights` if the key was missing.
            if rargs.setdefault('no_weights', no_weights) != no_weights:
                raise ValueError('Keyword argument "no_weights" of ' +
                                 'bidirectional layer is in conflict with ' +
                                 'constructor argument "no_weights".')

            self._forward_rnns.append(SimpleRNN(**rargs))
            self._backward_rnns.append(SimpleRNN(**rargs))

        ### Create output network.
        if mlp_args is not None:
            mlp_args = dict(mlp_args)  # Don't modify the caller's dict.
            mlp_args.setdefault('verbose', False)
            if mlp_args.setdefault('no_weights', no_weights) != no_weights:
                raise ValueError('Keyword argument "no_weights" of ' +
                                 'output MLP is in conflict with ' +
                                 'constructor argument "no_weights".')

            self._out_mlp = MLP(**mlp_args)

        ### Set all interface attributes correctly.
        if self._out_mlp is None:
            self._has_fc_out = self._forward_rnns[-1].has_fc_out
            # We can't set the following attribute to true, as the output is
            # a concatenation of the outputs from two networks. Therefore, the
            # weights used to compute the outputs are at different locations
            # in the `param_shapes` list.
            self._mask_fc_out = False
            self._has_linear_out = self._forward_rnns[-1].has_linear_out
        else:
            self._has_fc_out = self._out_mlp.has_fc_out
            self._mask_fc_out = self._out_mlp.mask_fc_out
            self._has_linear_out = self._out_mlp.has_linear_out

        # Collect all internal net objects from which we need to collect
        # attributes. The order (forward, backward per layer, then MLP)
        # determines the order of entries in `param_shapes`.
        nets = []
        for i, (fnet, bnet) in enumerate(zip(self._forward_rnns,
                                             self._backward_rnns)):
            nets.append((fnet, 'forward_rnn', i))
            nets.append((bnet, 'backward_rnn', i))
        if self._out_mlp is not None:
            nets.append((self._out_mlp, 'out_mlp', -1))

        # Iterate over all nets to collect their attribute values.
        self._param_shapes = []
        self._param_shapes_meta = []
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        for i, (net, net_type, net_id) in enumerate(nets):
            # Note, it is important to convert lists into new object and not
            # just copy references!
            # Note, we have to adapt all references if `i > 0`.

            # Sanity check:
            if i == 0:
                cm_nw = net._context_mod_no_weights
            elif cm_nw != net._context_mod_no_weights:
                raise ValueError('Network expect that either all internal ' +
                                 'networks maintain their context-mod ' +
                                 'weights or non of them does!')

            ps_len_old = len(self._param_shapes)

            # Offset translating indices into `net._internal_params` to
            # indices into `self._internal_params`. It must be defined even
            # if `net` maintains no parameters (e.g., `no_weights=True`).
            ip_len_old = 0
            if net._internal_params is not None:
                if self._internal_params is None:
                    self._internal_params = nn.ParameterList()
                ip_len_old = len(self._internal_params)
                self._internal_params.extend(
                    nn.ParameterList(net._internal_params))
            self._param_shapes.extend(list(net._param_shapes))
            for meta in net.param_shapes_meta:
                assert 'birnn_layer_type' not in meta
                assert 'birnn_layer_id' not in meta

                new_meta = dict(meta)
                new_meta['birnn_layer_type'] = net_type
                new_meta['birnn_layer_id'] = net_id
                if i > 0:
                    # FIXME We should properly adjust colliding `layer` IDs.
                    new_meta['layer'] = -1
                # NOTE(review): an index of -1 presumably marks a parameter
                # that is not internally maintained. Such a sentinel must
                # not be shifted by the offset.
                if meta['index'] != -1:
                    new_meta['index'] = meta['index'] + ip_len_old
                self._param_shapes_meta.append(new_meta)

            if net._hyper_shapes_learned is not None:
                if self._hyper_shapes_learned is None:
                    self._hyper_shapes_learned = []
                    self._hyper_shapes_learned_ref = []
                self._hyper_shapes_learned.extend(
                    list(net._hyper_shapes_learned))
                # References point into `param_shapes` and thus have to be
                # shifted by the number of previously collected shapes.
                for ref in net._hyper_shapes_learned_ref:
                    self._hyper_shapes_learned_ref.append(ref + ps_len_old)
            if net._hyper_shapes_distilled is not None:
                if self._hyper_shapes_distilled is None:
                    self._hyper_shapes_distilled = []
                self._hyper_shapes_distilled.extend(
                    list(net._hyper_shapes_distilled))

            if self._has_bias is None:
                self._has_bias = net._has_bias
            elif self._has_bias != net._has_bias:
                self._has_bias = False
                # FIXME We should overwrite the getter and throw an error!
                warn('Some internally maintained networks use biases, ' +
                     'while others don\'t. Setting attribute "has_bias" to ' +
                     'False.')

            self._layer_weight_tensors.extend(
                nn.ParameterList(net._layer_weight_tensors))
            self._layer_bias_vectors.extend(
                nn.ParameterList(net._layer_bias_vectors))
            if net._batchnorm_layers is not None:
                if self._batchnorm_layers is None:
                    self._batchnorm_layers = nn.ModuleList()
                self._batchnorm_layers.extend(
                    nn.ModuleList(net._batchnorm_layers))
            if net._context_mod_layers is not None:
                if self._context_mod_layers is None:
                    self._context_mod_layers = nn.ModuleList()
                self._context_mod_layers.extend(
                    nn.ModuleList(net._context_mod_layers))

        self._is_properly_setup()

        ### Print user information.
        if verbose:
            print('Constructed Bidirectional RNN with %d weights.'
                  % self.num_params)
class ChunkSqueezer(nn.Module, MainNetInterface):
    """An MLP that first reduces the dimensionality of its inputs.

    The input dimensionality ``n_in`` is first reduced by a `reducer` network
    (which is an instance of class :class:`mnets.mlp.MLP`) using a chunking
    strategy. The reduced input will be then passed to the actual `network`
    (which is another instance of :class:`mnets.mlp.MLP`) to compute an output.

    Args:
        n_in (int): Input dimensionality.
        n_out (int): Number of output neurons.
        inp_chunk_dim (int): The input (dimensionality ``n_in``) will be split
            into chunks of size ``inp_chunk_dim``. Thus, there will be
            ``np.ceil(n_in/inp_chunk_dim)`` input chunks that are individually
            squeezed through the `reducer` network.

            Note:
                If ``n_in`` is not a multiple of ``inp_chunk_dim``, the last
                chunk will be zero-padded.
        out_chunk_dim (int): The output size of the `reducer` network. The
            input size of the actual `network` is then
            ``np.ceil(n_in/inp_chunk_dim) * out_chunk_dim``.
        cemb_size (int): The `reducer` network processes every chunk
            individually. In order to do so, it needs to know which chunk it is
            processing. Therefore, it is conditioned on a learned chunk
            embedding (there will be ``np.ceil(n_in/inp_chunk_dim)`` chunk
            embeddings). The dimensionality of these chunk embeddings is
            determined by this argument.
        cemb_init_std (float): Standard deviation used for the normal
            initialization of the chunk embeddings.
        red_layers (list or tuple): The architecture of the `reducer` network.
            See argument ``hidden_layers`` of class :class:`mnets.mlp.MLP`.
        net_layers (list or tuple): The architecture of the actual `network`.
            See argument ``hidden_layers`` of class :class:`mnets.mlp.MLP`.
        activation_fn: The nonlinearity used in hidden layers. If ``None``, no
            nonlinearity will be applied.
        use_bias: Will be passed as option ``use_bias`` to the underlying MLPs
            (see :class:`mnets.mlp.MLP`).
        dynamic_biases (list, optional): This option determines the hidden
            layers of the `reducer` networks that receive the chunk embedding as
            dynamic biases. It is a list of indexes with the first hidden layer
            having index 0 and the output of the `reducer` would have index
            ``len(red_layers)``. The chunk embeddings will be transformed
            through a fully connected layer (no bias) and then added as
            "dynamic" bias to the output of the corresponding hidden layer.

            Note:
                If left unspecified, the chunk embeddings will just be another
                input to the `reducer` network.

            Note:
                Dynamic biases are not yet implemented. Specifying this
                option causes the constructor to raise a
                ``NotImplementedError``.
        no_weights (bool): If set to ``True``, no trainable parameters will be
            constructed, i.e., weights are assumed to be produced ad-hoc
            by a hypernetwork and passed to the :meth:`forward` method.
        init_weights (optional): This option is for convenience reasons.
            The option expects a list of parameter values that are used to
            initialize the network weights. As such, it provides a
            convenient way of initializing a network with a weight draw
            produced by the hypernetwork.

            Note, internal weights (see
            :attr:`mnets.mnet_interface.MainNetInterface.weights`) will be
            affected by this argument only.
        dropout_rate (float): Will be passed as option ``dropout_rate`` to the
            underlying MLPs (see :class:`mnets.mlp.MLP`).
        use_spectral_norm (bool): Will be passed as option ``use_spectral_norm``
            to the underlying MLPs (see :class:`mnets.mlp.MLP`).
        use_batch_norm (bool): Will be passed as option ``use_batch_norm``
            to the underlying MLPs (see :class:`mnets.mlp.MLP`).
        bn_track_stats (bool): Will be passed as option ``bn_track_stats``
            to the underlying MLPs (see :class:`mnets.mlp.MLP`).
        distill_bn_stats (bool): Will be passed as option ``distill_bn_stats``
            to the underlying MLPs (see :class:`mnets.mlp.MLP`).
    """
    def __init__(self,
                 n_in,
                 n_out=1,
                 inp_chunk_dim=100,
                 out_chunk_dim=10,
                 cemb_size=8,
                 cemb_init_std=1.,
                 red_layers=(10, 10),
                 net_layers=(10, 10),
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True,
                 dynamic_biases=None,
                 no_weights=False,
                 init_weights=None,
                 dropout_rate=-1,
                 use_spectral_norm=False,
                 use_batch_norm=False,
                 bn_track_stats=True,
                 distill_bn_stats=False,
                 verbose=True):
        """Set up the chunk embeddings, the `reducer` and the `network`.

        The constructor arguments are described in the class docstring.
        Parameter shapes (and, unless ``no_weights`` is set, the parameters
        themselves) are registered in the order: chunk embeddings, `reducer`
        parameters, `network` parameters. Method :meth:`forward` relies on
        this order when splitting a list of weights.

        Args:
            verbose (bool): If ``False``, the construction message of this
                class is not printed. The value is also passed as option
                ``verbose`` to the two underlying MLPs.

        Raises:
            NotImplementedError: If ``dynamic_biases`` is specified.
            ValueError: If ``init_weights`` is provided although
                ``no_weights`` is set (there are no internal weights that
                could be initialized).
        """
        # NOTE(review): a second `__init__` definition further down in this
        # class body overrides this one at class-creation time. One of the two
        # copies should be removed.

        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        self._n_in = n_in
        self._n_out = n_out
        self._inp_chunk_dim = inp_chunk_dim
        self._out_chunk_dim = out_chunk_dim
        self._cemb_size = cemb_size
        self._a_fun = activation_fn
        self._no_weights = no_weights

        self._has_bias = use_bias
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        self._has_linear_out = True  # Ensure that `out_fn` is `None`!

        self._param_shapes = []
        #self._param_shapes_meta = [] # TODO implement!
        self._weights = None if no_weights else nn.ParameterList()
        self._hyper_shapes_learned = [] if no_weights else None
        #self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
        #    is None else [] # TODO implement.

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        self._context_mod_layers = None
        self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

        #################################
        ### Generate Chunk Embeddings ###
        #################################
        self._num_cembs = int(np.ceil(n_in / inp_chunk_dim))
        last_chunk_size = n_in % inp_chunk_dim
        # `-1` is the sentinel for "no padding required" (see `forward`).
        if last_chunk_size != 0:
            self._pad = inp_chunk_dim - last_chunk_size
        else:
            self._pad = -1

        cemb_shape = [self._num_cembs, cemb_size]
        self._param_shapes.append(cemb_shape)
        if no_weights:
            self._cembs = None
            self._hyper_shapes_learned.append(cemb_shape)
        else:
            self._cembs = nn.Parameter(data=torch.Tensor(*cemb_shape),
                                       requires_grad=True)
            nn.init.normal_(self._cembs, mean=0., std=cemb_init_std)

            self._weights.append(self._cembs)

        ############################
        ### Setup Dynamic Biases ###
        ############################
        self._has_dyn_bias = None
        if dynamic_biases is not None:
            assert np.all(np.array(dynamic_biases) >= 0) and \
                   np.all(np.array(dynamic_biases) < len(red_layers) + 1)
            # `np.unique` already returns its result sorted.
            dynamic_biases = np.unique(dynamic_biases)

            # For each layer in the `reducer`, where we want to apply a dynamic
            # bias, we have to create a weight matrix for a corresponding
            # linear layer without bias.
            # A plain list is used since the entries are either parameters or
            # `None`, and parameters cannot be stored in an `nn.ModuleList`.
            # The parameters are registered with this module via
            # `self._weights`.
            self._dyn_bias_weights = []
            self._has_dyn_bias = []

            for i in range(len(red_layers) + 1):
                if i not in dynamic_biases:
                    self._has_dyn_bias.append(False)
                    self._dyn_bias_weights.append(None)
                    continue

                self._has_dyn_bias.append(True)

                # Index `len(red_layers)` refers to the reducer's output.
                trgt_dim = out_chunk_dim
                if i < len(red_layers):
                    trgt_dim = red_layers[i]
                trgt_shape = [trgt_dim, cemb_size]

                self._param_shapes.append(trgt_shape)
                if no_weights:
                    # Weights will be provided externally (hypernetwork).
                    self._dyn_bias_weights.append(None)
                    self._hyper_shapes_learned.append(trgt_shape)
                else:
                    dyn_weight = nn.Parameter(torch.Tensor(*trgt_shape),
                                              requires_grad=True)
                    self._dyn_bias_weights.append(dyn_weight)
                    self._weights.append(dyn_weight)

                    init_params(dyn_weight)

                    self._layer_weight_tensors.append(dyn_weight)
                    self._layer_bias_vectors.append(None)

        ################################
        ### Create `Reducer` Network ###
        ################################
        # Without dynamic biases, the chunk embedding is concatenated to each
        # chunk and thus enlarges the reducer's input.
        red_inp_dim = inp_chunk_dim + \
            (cemb_size if dynamic_biases is None else 0)
        self._reducer = MLP(
            n_in=red_inp_dim,
            n_out=out_chunk_dim,
            hidden_layers=red_layers,
            activation_fn=activation_fn,
            use_bias=use_bias,
            no_weights=no_weights,
            init_weights=None,
            dropout_rate=dropout_rate,
            use_spectral_norm=use_spectral_norm,
            use_batch_norm=use_batch_norm,
            bn_track_stats=bn_track_stats,
            distill_bn_stats=distill_bn_stats,
            # We use context modulation to realize dynamic biases, since they
            # allow a different modulation per sample in the input mini-batch.
            # Hence, we can process several chunks in parallel with the reducer
            # network.
            use_context_mod=dynamic_biases is not None,
            context_mod_inputs=False,
            no_last_layer_context_mod=False,
            context_mod_no_weights=True,
            context_mod_post_activation=False,
            context_mod_gain_offset=False,
            context_mod_gain_softplus=False,
            out_fn=None,
            verbose=verbose)

        if dynamic_biases is not None:
            # FIXME We have to extract the param shapes from
            # `self._reducer.param_shapes`, as well as from
            # `self._reducer._hyper_shapes_learned` that belong to context-mod
            # layers. We may not add them to our own `param_shapes` attribute,
            # as these are not parameters (due to our misuse of the context-mod
            # layers).
            # Note, in the `forward` method, we need to supply context-mod
            # weights for all reducer networks, independent on whether they have
            # a dynamic bias or not. We can do so, by providing constant ones
            # for all gains and constant zero-shift for all layers without
            # dynamic biases (note, we need to ensure the correct batch dim!).
            raise NotImplementedError(
                'Dynamic biases are not yet implemented!')

        assert self._reducer._context_mod_layers is None

        ### Overtake all attributes from the underlying MLP.
        self._param_shapes.extend(self._reducer.param_shapes)
        if no_weights:
            self._hyper_shapes_learned.extend(
                self._reducer._hyper_shapes_learned)
        else:
            self._weights.extend(self._reducer._weights)

        self._layer_weight_tensors.extend(self._reducer._layer_weight_tensors)
        self._layer_bias_vectors.extend(self._reducer._layer_bias_vectors)

        if use_batch_norm:
            self._batchnorm_layers.extend(self._reducer._batchnorm_layers)

        if self._reducer._hyper_shapes_distilled is not None:
            self._hyper_shapes_distilled = \
                list(self._reducer._hyper_shapes_distilled)

        ###############################
        ### Create Actual `Network` ###
        ###############################
        # The network consumes the concatenated reducer outputs of all chunks.
        net_inp_dim = out_chunk_dim * self._num_cembs
        self._network = MLP(n_in=net_inp_dim,
                            n_out=n_out,
                            hidden_layers=net_layers,
                            activation_fn=activation_fn,
                            use_bias=use_bias,
                            no_weights=no_weights,
                            init_weights=None,
                            dropout_rate=dropout_rate,
                            use_spectral_norm=use_spectral_norm,
                            use_batch_norm=use_batch_norm,
                            bn_track_stats=bn_track_stats,
                            distill_bn_stats=distill_bn_stats,
                            use_context_mod=False,
                            out_fn=None,
                            verbose=verbose)

        ### Overtake all attributes from the underlying MLP.
        self._param_shapes.extend(self._network.param_shapes)
        if no_weights:
            self._hyper_shapes_learned.extend(
                self._network._hyper_shapes_learned)
        else:
            self._weights.extend(self._network._weights)

        self._layer_weight_tensors.extend(self._network._layer_weight_tensors)
        self._layer_bias_vectors.extend(self._network._layer_bias_vectors)

        if use_batch_norm:
            self._batchnorm_layers.extend(self._network._batchnorm_layers)

        if self._hyper_shapes_distilled is not None:
            # Both MLPs are built with the same options, so they either both
            # have distillation targets or none of them has.
            assert self._network._hyper_shapes_distilled is not None
            self._hyper_shapes_distilled.extend(
                self._network._hyper_shapes_distilled)

        #####################################
        ### Takeover given Initialization ###
        #####################################
        if init_weights is not None:
            if no_weights:
                raise ValueError('Argument "init_weights" cannot be used ' +
                                 'if "no_weights" is set, as there are no ' +
                                 'internal weights to initialize.')
            assert len(init_weights) == len(self._weights)
            for i, init_w in enumerate(init_weights):
                assert np.all(
                    np.equal(list(init_w.shape), self._param_shapes[i]))
                self._weights[i].data = init_w

        ######################
        ### Finalize Setup ###
        ######################
        if verbose:
            num_weights = MainNetInterface.shapes_to_num_weights( \
                self.param_shapes)
            print('Constructed MLP that processes dimensionality reduced ' +
                  'inputs through chunking. The network has a total of %d ' \
                  'weights.' % num_weights)

        self._is_properly_setup()

    def distillation_targets(self):
        """Targets to be distilled after training.

        See docstring of abstract super method
        :meth:`mnets.mnet_interface.MainNetInterface.distillation_targets`.

        This method will return the distillation targets from the 2 underlying
        networks, see method :meth:`mnets.mlp.MLP.distillation_targets`. The
        targets of the `reducer` come first, followed by those of the
        `network`, which is the order in which the shapes were added to
        attribute :attr:`hyper_shapes_distilled` in the constructor.

        Returns:
            The target tensors corresponding to the shapes specified in
            attribute :attr:`hyper_shapes_distilled`, or ``None`` if that
            attribute is ``None``.
        """
        if self.hyper_shapes_distilled is None:
            return None

        # `distillation_targets` is a method of the underlying MLPs. It has to
        # be called; adding the bound methods themselves raises a `TypeError`.
        return self._reducer.distillation_targets() + \
            self._network.distillation_targets()

    def forward(self, x, weights=None, distilled_params=None, condition=None):
        """Compute the output :math:`y` of this network given the input
        :math:`x`.

        The input is split into chunks along dimension 1. All chunks are
        stacked along the batch dimension, concatenated with their chunk
        embedding and processed in parallel by the `reducer`. The reducer
        outputs of all chunks belonging to one sample are then concatenated
        and fed into the actual `network`.

        Args:
            (....): See docstring of method
                :meth:`mnets.mnet_interface.MainNetInterface.forward`. We
                provide some more specific information below.
            x: The input, indexed as ``[batch_size, n_in]``.
            weights (optional): List of tensors matching attribute
                ``param_shapes``, i.e., the chunk embeddings followed by the
                `reducer` weights and the `network` weights. Must be given if
                the network was constructed without internal weights.
            distilled_params: Will be split and passed as distillation targets
                to the underlying instances of class :class:`mnets.mlp.MLP` if
                specified.
            condition (optional, int or dict): Will be passed to the underlying
                instances of class :class:`mnets.mlp.MLP`.

        Returns:
            The output :math:`y` of the network.

        Raises:
            NotImplementedError: If the network uses dynamic biases.
            ValueError: If ``distilled_params`` is given although the network
                has no distillation targets.
        """
        if self._no_weights and weights is None:
            raise Exception('Network was generated without weights. ' +
                            'Hence, "weights" option may not be None.')

        if weights is None:
            weights = self._weights
        else:
            assert len(weights) == len(self.param_shapes)
            for i, s in enumerate(self.param_shapes):
                assert np.all(np.equal(s, list(weights[i].shape)))

        if self._has_dyn_bias is not None:
            # TODO Dynamic biases require to
            #   - extract the dynamic bias weights from `weights` (they follow
            #     the chunk embeddings),
            #   - use them to construct weights for the context-mod layers,
            #   - pass these context-mod weights to the `reducer`.
            raise NotImplementedError

        #########################################
        ### Extract parameters from `weights` ###
        #########################################
        # Order: chunk embeddings, reducer weights, network weights.
        cembs = weights[0]
        red_end = 1 + len(self._reducer.param_shapes)
        net_end = red_end + len(self._network.param_shapes)
        red_weights = weights[1:red_end]
        net_weights = weights[red_end:net_end]

        red_distilled_params = None
        net_distilled_params = None
        if distilled_params is not None:
            if self.hyper_shapes_distilled is None:
                raise ValueError(
                    'Argument "distilled_params" can only be ' +
                    'provided if the return value of ' +
                    'method "distillation_targets()" is not None.')

            assert len(distilled_params) == len(self.hyper_shapes_distilled)
            num_red_distilled = len(self._reducer.hyper_shapes_distilled)
            red_distilled_params = distilled_params[:num_red_distilled]
            net_distilled_params = distilled_params[num_red_distilled:]

        ###########################
        ### Chunk network input ###
        ###########################
        assert x.shape[1] == self._n_in

        if self._pad != -1:
            # Zero-pad the last chunk to the full chunk size.
            x = F.pad(x, (0, self._pad))
            # The padded input must consist of complete *input* chunks.
            assert x.shape[1] % self._inp_chunk_dim == 0

        batch_size = x.shape[0]
        # We now split the input `x` into chunks and convert them into
        # separate samples, i.e., the `batch_size` will be multiplied by the
        # number of chunks.
        # So, we parallel process a huge batch with a small network rather than
        # processing a huge input with a huge network.
        chunks = torch.split(x, self._inp_chunk_dim, dim=1)
        # Concatenate the chunks along the batch dimension.
        chunks = torch.cat(chunks, dim=0)

        # Within a chunk the same chunk embedding is used. The embeddings are
        # repeated such that they line up with the stacked chunks (all
        # samples of chunk 0, then all samples of chunk 1, ...).
        cembs = torch.split(cembs, 1, dim=0)
        cembs = [emb.expand(batch_size, -1) for emb in cembs]
        cembs = torch.cat(cembs, dim=0)

        chunks = torch.cat([chunks, cembs], dim=1)

        ###################################
        ### Reduce input dimensionality ###
        ###################################
        chunks = self._reducer.forward(chunks,
                                       weights=red_weights,
                                       distilled_params=red_distilled_params,
                                       condition=condition)

        ### Reformat `reducer` output into the input of the actual `network`.
        # Undo the stacking along the batch dimension: every split contains
        # the reduced version of one chunk for all samples.
        chunks = torch.split(chunks, batch_size, dim=0)
        net_input = torch.cat(chunks, dim=1)
        assert net_input.shape[0] == batch_size

        ##############################
        ### Compute network output ###
        ##############################
        return self._network.forward(net_input,
                                     weights=net_weights,
                                     distilled_params=net_distilled_params,
                                     condition=condition)
    def __init__(self,
                 n_in,
                 n_out=1,
                 inp_chunk_dim=100,
                 out_chunk_dim=10,
                 cemb_size=8,
                 cemb_init_std=1.,
                 red_layers=(10, 10),
                 net_layers=(10, 10),
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True,
                 dynamic_biases=None,
                 no_weights=False,
                 init_weights=None,
                 dropout_rate=-1,
                 use_spectral_norm=False,
                 use_batch_norm=False,
                 bn_track_stats=True,
                 distill_bn_stats=False,
                 verbose=True):
        """Set up the chunk embeddings, the `reducer` and the `network`.

        The constructor arguments are described in the class docstring.
        Parameter shapes (and, unless ``no_weights`` is set, the parameters
        themselves) are registered in the order: chunk embeddings, `reducer`
        parameters, `network` parameters. Method :meth:`forward` relies on
        this order when splitting a list of weights.

        Args:
            verbose (bool): If ``False``, the construction message of this
                class is not printed. The value is also passed as option
                ``verbose`` to the two underlying MLPs.

        Raises:
            NotImplementedError: If ``dynamic_biases`` is specified.
            ValueError: If ``init_weights`` is provided although
                ``no_weights`` is set (there are no internal weights that
                could be initialized).
        """
        # NOTE(review): this is the second `__init__` definition in this class
        # body; it overrides the earlier one at class-creation time. One of the
        # two copies should be removed.

        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        self._n_in = n_in
        self._n_out = n_out
        self._inp_chunk_dim = inp_chunk_dim
        self._out_chunk_dim = out_chunk_dim
        self._cemb_size = cemb_size
        self._a_fun = activation_fn
        self._no_weights = no_weights

        self._has_bias = use_bias
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        self._has_linear_out = True  # Ensure that `out_fn` is `None`!

        self._param_shapes = []
        #self._param_shapes_meta = [] # TODO implement!
        self._weights = None if no_weights else nn.ParameterList()
        self._hyper_shapes_learned = [] if no_weights else None
        #self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
        #    is None else [] # TODO implement.

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        self._context_mod_layers = None
        self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

        #################################
        ### Generate Chunk Embeddings ###
        #################################
        self._num_cembs = int(np.ceil(n_in / inp_chunk_dim))
        last_chunk_size = n_in % inp_chunk_dim
        # `-1` is the sentinel for "no padding required" (see `forward`).
        if last_chunk_size != 0:
            self._pad = inp_chunk_dim - last_chunk_size
        else:
            self._pad = -1

        cemb_shape = [self._num_cembs, cemb_size]
        self._param_shapes.append(cemb_shape)
        if no_weights:
            self._cembs = None
            self._hyper_shapes_learned.append(cemb_shape)
        else:
            self._cembs = nn.Parameter(data=torch.Tensor(*cemb_shape),
                                       requires_grad=True)
            nn.init.normal_(self._cembs, mean=0., std=cemb_init_std)

            self._weights.append(self._cembs)

        ############################
        ### Setup Dynamic Biases ###
        ############################
        self._has_dyn_bias = None
        if dynamic_biases is not None:
            assert np.all(np.array(dynamic_biases) >= 0) and \
                   np.all(np.array(dynamic_biases) < len(red_layers) + 1)
            # `np.unique` already returns its result sorted.
            dynamic_biases = np.unique(dynamic_biases)

            # For each layer in the `reducer`, where we want to apply a dynamic
            # bias, we have to create a weight matrix for a corresponding
            # linear layer without bias.
            # A plain list is used since the entries are either parameters or
            # `None`, and parameters cannot be stored in an `nn.ModuleList`.
            # The parameters are registered with this module via
            # `self._weights`.
            self._dyn_bias_weights = []
            self._has_dyn_bias = []

            for i in range(len(red_layers) + 1):
                if i not in dynamic_biases:
                    self._has_dyn_bias.append(False)
                    self._dyn_bias_weights.append(None)
                    continue

                self._has_dyn_bias.append(True)

                # Index `len(red_layers)` refers to the reducer's output.
                trgt_dim = out_chunk_dim
                if i < len(red_layers):
                    trgt_dim = red_layers[i]
                trgt_shape = [trgt_dim, cemb_size]

                self._param_shapes.append(trgt_shape)
                if no_weights:
                    # Weights will be provided externally (hypernetwork).
                    self._dyn_bias_weights.append(None)
                    self._hyper_shapes_learned.append(trgt_shape)
                else:
                    dyn_weight = nn.Parameter(torch.Tensor(*trgt_shape),
                                              requires_grad=True)
                    self._dyn_bias_weights.append(dyn_weight)
                    self._weights.append(dyn_weight)

                    init_params(dyn_weight)

                    self._layer_weight_tensors.append(dyn_weight)
                    self._layer_bias_vectors.append(None)

        ################################
        ### Create `Reducer` Network ###
        ################################
        # Without dynamic biases, the chunk embedding is concatenated to each
        # chunk and thus enlarges the reducer's input.
        red_inp_dim = inp_chunk_dim + \
            (cemb_size if dynamic_biases is None else 0)
        self._reducer = MLP(
            n_in=red_inp_dim,
            n_out=out_chunk_dim,
            hidden_layers=red_layers,
            activation_fn=activation_fn,
            use_bias=use_bias,
            no_weights=no_weights,
            init_weights=None,
            dropout_rate=dropout_rate,
            use_spectral_norm=use_spectral_norm,
            use_batch_norm=use_batch_norm,
            bn_track_stats=bn_track_stats,
            distill_bn_stats=distill_bn_stats,
            # We use context modulation to realize dynamic biases, since they
            # allow a different modulation per sample in the input mini-batch.
            # Hence, we can process several chunks in parallel with the reducer
            # network.
            use_context_mod=dynamic_biases is not None,
            context_mod_inputs=False,
            no_last_layer_context_mod=False,
            context_mod_no_weights=True,
            context_mod_post_activation=False,
            context_mod_gain_offset=False,
            context_mod_gain_softplus=False,
            out_fn=None,
            verbose=verbose)

        if dynamic_biases is not None:
            # FIXME We have to extract the param shapes from
            # `self._reducer.param_shapes`, as well as from
            # `self._reducer._hyper_shapes_learned` that belong to context-mod
            # layers. We may not add them to our own `param_shapes` attribute,
            # as these are not parameters (due to our misuse of the context-mod
            # layers).
            # Note, in the `forward` method, we need to supply context-mod
            # weights for all reducer networks, independent on whether they have
            # a dynamic bias or not. We can do so, by providing constant ones
            # for all gains and constant zero-shift for all layers without
            # dynamic biases (note, we need to ensure the correct batch dim!).
            raise NotImplementedError(
                'Dynamic biases are not yet implemented!')

        assert self._reducer._context_mod_layers is None

        ### Overtake all attributes from the underlying MLP.
        self._param_shapes.extend(self._reducer.param_shapes)
        if no_weights:
            self._hyper_shapes_learned.extend(
                self._reducer._hyper_shapes_learned)
        else:
            self._weights.extend(self._reducer._weights)

        self._layer_weight_tensors.extend(self._reducer._layer_weight_tensors)
        self._layer_bias_vectors.extend(self._reducer._layer_bias_vectors)

        if use_batch_norm:
            self._batchnorm_layers.extend(self._reducer._batchnorm_layers)

        if self._reducer._hyper_shapes_distilled is not None:
            self._hyper_shapes_distilled = \
                list(self._reducer._hyper_shapes_distilled)

        ###############################
        ### Create Actual `Network` ###
        ###############################
        # The network consumes the concatenated reducer outputs of all chunks.
        net_inp_dim = out_chunk_dim * self._num_cembs
        self._network = MLP(n_in=net_inp_dim,
                            n_out=n_out,
                            hidden_layers=net_layers,
                            activation_fn=activation_fn,
                            use_bias=use_bias,
                            no_weights=no_weights,
                            init_weights=None,
                            dropout_rate=dropout_rate,
                            use_spectral_norm=use_spectral_norm,
                            use_batch_norm=use_batch_norm,
                            bn_track_stats=bn_track_stats,
                            distill_bn_stats=distill_bn_stats,
                            use_context_mod=False,
                            out_fn=None,
                            verbose=verbose)

        ### Overtake all attributes from the underlying MLP.
        self._param_shapes.extend(self._network.param_shapes)
        if no_weights:
            self._hyper_shapes_learned.extend(
                self._network._hyper_shapes_learned)
        else:
            self._weights.extend(self._network._weights)

        self._layer_weight_tensors.extend(self._network._layer_weight_tensors)
        self._layer_bias_vectors.extend(self._network._layer_bias_vectors)

        if use_batch_norm:
            self._batchnorm_layers.extend(self._network._batchnorm_layers)

        if self._hyper_shapes_distilled is not None:
            # Both MLPs are built with the same options, so they either both
            # have distillation targets or none of them has.
            assert self._network._hyper_shapes_distilled is not None
            self._hyper_shapes_distilled.extend(
                self._network._hyper_shapes_distilled)

        #####################################
        ### Takeover given Initialization ###
        #####################################
        if init_weights is not None:
            if no_weights:
                raise ValueError('Argument "init_weights" cannot be used ' +
                                 'if "no_weights" is set, as there are no ' +
                                 'internal weights to initialize.')
            assert len(init_weights) == len(self._weights)
            for i, init_w in enumerate(init_weights):
                assert np.all(
                    np.equal(list(init_w.shape), self._param_shapes[i]))
                self._weights[i].data = init_w

        ######################
        ### Finalize Setup ###
        ######################
        if verbose:
            num_weights = MainNetInterface.shapes_to_num_weights( \
                self.param_shapes)
            print('Constructed MLP that processes dimensionality reduced ' +
                  'inputs through chunking. The network has a total of %d ' \
                  'weights.' % num_weights)

        self._is_properly_setup()