Example #1
def generate_classifier(config, data_handlers, device):
    """Create a classifier network. Depending on the experiment and method, 
    the method manages to build either a classifier for task inference 
    or a classifier that solves our task is build. This also implies if the
    network will receive weights from a hypernetwork or will have weights 
    on its own.
    Following important configurations will be determined in order to create
    the classifier: \n 
    * in- and output and hidden layer dimensions of the classifier. \n
    * architecture, chunk- and task-embedding details of the hypernetwork. 


    See :class:`mnets.mlp.MLP` for details on the network that will be
    created as the classifier.

    .. note::
        This module also handles the initialisation of the weights of either 
        the classifier or its hypernetwork. This will change in the near future.
        
    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network and to
            infer the number of tasks.
        device: Torch device.
    
    Returns: 
        (tuple): Tuple containing:
        - **net**: The classifier network.
        - **class_hnet**: (optional) The classifier's hypernetwork.
    """
    n_in = data_handlers[0].in_shape[0]
    pd = config.padding * 2

    if config.experiment == "splitMNIST":
        n_in = n_in * n_in
    else:  # permutedMNIST
        n_in = (n_in + pd) * (n_in + pd)

    config.input_dim = n_in
    if config.experiment == "splitMNIST":
        if config.class_incremental:
            config.out_dim = 1
        else:
            config.out_dim = 2
    else:  # permutedMNIST
        config.out_dim = 10

    if config.training_task_infer or config.class_incremental:
        # task inference network
        config.out_dim = 1

    # For all CL scenarios except CL2, the output layer contains neurons for
    # all tasks.
    if config.cl_scenario != 2:
        n_out = config.out_dim * config.num_tasks
    else:
        n_out = config.out_dim

    if config.training_task_infer or config.class_incremental:
        n_out = config.num_tasks

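    # Worked example (assumed values): permutedMNIST with 10 tasks yields
    # out_dim = 10, hence n_out = 10 * 10 = 100 for CL scenarios other than 2
    # and n_out = 10 for CL scenario 2; a pure task-inference network instead
    # ends up with n_out = num_tasks = 10.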
    # Build the classifier.
    print('For the Classifier: ')
    class_arch = misc.str_to_ints(config.class_fc_arch)
    if config.training_with_hnet:
        no_weights = True
    else:
        no_weights = False

    net = MLP(n_in=n_in,
              n_out=n_out,
              hidden_layers=class_arch,
              activation_fn=misc.str_to_act(config.class_net_act),
              dropout_rate=config.class_dropout_rate,
              no_weights=no_weights).to(device)

    print('Constructed MLP with shapes: ', net.param_shapes)

    config.num_weights_class_net = \
        MainNetInterface.shapes_to_num_weights(net.param_shapes)
    # build classifier hnet
    # this is set in the run method in train.py
    if config.training_with_hnet:

        class_hnet = sim_utils.get_hnet_model(config,
                                              config.num_tasks,
                                              device,
                                              net.param_shapes,
                                              cprefix='class_')
        init_params = list(class_hnet.parameters())

        config.num_weights_class_hyper_net = sum(
            p.numel() for p in class_hnet.parameters() if p.requires_grad)
        config.compression_ratio_class = config.num_weights_class_hyper_net / \
                                         config.num_weights_class_net
        print('Created classifier Hypernetwork with ratio: ',
              config.compression_ratio_class)
        if config.compression_ratio_class > 1:
        print('Note that the compression ratio is computed compared to ' +
              'the current target network; this might not be directly ' +
              'comparable with the number of parameters of methods we ' +
              'compare against.')
    else:
        class_hnet = None
        init_params = list(net.parameters())
        config.num_weights_class_hyper_net = None
        config.compression_ratio_class = None

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if config.training_with_hnet:
        for temb in class_hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

    if hasattr(class_hnet, 'chunk_embeddings'):
        for emb in class_hnet.chunk_embeddings:
            torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    if not config.training_with_hnet:
        return net
    else:
        return net, class_hnet
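
# Usage sketch (illustrative, not the project's actual training script): the
# config fields mirror those read above, while the concrete values and the
# tiny fake data handler are assumptions.
if __name__ == '__main__':  # pragma: no cover
    from argparse import Namespace
    from types import SimpleNamespace
    import torch

    # Minimal stand-in for a data handler; only `in_shape` is accessed above.
    fake_data = [SimpleNamespace(in_shape=[28]) for _ in range(5)]

    cfg = Namespace(
        experiment='splitMNIST', padding=0, class_incremental=False,
        training_task_infer=False, cl_scenario=1, num_tasks=5,
        class_fc_arch='400,400', class_net_act='relu', class_dropout_rate=-1,
        training_with_hnet=False)

    # Without a hypernetwork, only the classifier itself is returned.
    net = generate_classifier(cfg, fake_data, torch.device('cpu'))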
Example #2
def get_mnet_model(config,
                   net_type,
                   in_shape,
                   out_shape,
                   device,
                   cprefix=None,
                   no_weights=False):
    """Generate a main network instance.

    A helper to generate a main network according to the given user
    configuration.

    .. note::
        Generation of networks with context-modulation is not yet supported,
        since there is no global argument set in :mod:`utils.cli_args` yet.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available according
                to the function :func:`utils.cli_args.main_net_args`.
        net_type (str): The type of network. The following options are
            available:
            
            - ``mlp``: :class:`mnets.mlp.MLP`
            - ``resnet``: :class:`mnets.resnet.ResNet`
            - ``zenke``: :class:`mnets.zenkenet.ZenkeNet`
            - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet`
        in_shape (list): Shape of network inputs. Can be ``None`` if not
            required by network type.

            For instance: For an MLP network :class:`mnets.mlp.MLP` with 100
            input neurons it should be :code:`in_shape=[100]`.
        out_shape (list): Shape of network outputs. See ``in_shape`` for more
            details.
        device: PyTorch device.
        cprefix (str, optional): A prefix of the config names. It might be
            that the config names used in this method are prefixed, since
            several main networks have to be generated (e.g.,
            :code:`cprefix='gen_'` or ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.main_net_args`.
        no_weights (bool): Whether the main network should be generated without
            weights.

    Returns:
        The created main network model.
    """
    assert (net_type in ['mlp', 'resnet', 'zenke', 'bio_conv_net'])

    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    def hc(name):
        """Check whether config exists."""
        return hasattr(config, '%s%s' % (cprefix, name))

    mnet = None

    if hc('net_act'):
        net_act = gc('net_act')
        net_act = misc.str_to_act(net_act)
    else:
        net_act = None

    def get_val(name):
        ret = None
        if hc(name):
            ret = gc(name)
        return ret

    no_bias = get_val('no_bias')
    dropout_rate = get_val('dropout_rate')
    specnorm = get_val('specnorm')
    batchnorm = get_val('batchnorm')
    no_batchnorm = get_val('no_batchnorm')
    bn_no_running_stats = get_val('bn_no_running_stats')
    bn_distill_stats = get_val('bn_distill_stats')
    #bn_no_stats_checkpointing = get_val('bn_no_stats_checkpointing')

    use_bn = None
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm

    # FIXME if an argument wasn't specified, then we use the default value that
    # is currently (at time of implementation) in the constructor.
    assign = lambda x, y: y if x is None else x
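    # E.g., assign(dropout_rate, -1) returns -1 when `dropout_rate` was not
    # specified (i.e., is None) and the user-provided value otherwise.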

    if net_type == 'mlp':
        assert (hc('mlp_arch'))
        assert (len(in_shape) == 1 and len(out_shape) == 1)

        mnet = MLP(
            n_in=in_shape[0],
            n_out=out_shape[0],
            hidden_layers=misc.str_to_ints(gc('mlp_arch')),
            activation_fn=assign(net_act, torch.nn.ReLU()),
            use_bias=not assign(no_bias, False),
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, -1),
            use_spectral_norm=assign(specnorm, False),
            use_batch_norm=assign(use_bn, False),
            bn_track_stats=assign(not bn_no_running_stats, True),
            distill_bn_stats=assign(bn_distill_stats, False),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #out_fn=None,
            verbose=True).to(device)

    elif net_type == 'resnet':
        assert (len(out_shape) == 1)

        mnet = ResNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            verbose=True,  #n=5,
            no_weights=no_weights,
            #init_weights=None,
            use_batch_norm=assign(use_bn, True),
            bn_track_stats=assign(not bn_no_running_stats, True),
            distill_bn_stats=assign(bn_distill_stats, False),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
        ).to(device)

    elif net_type == 'zenke':
        assert (len(out_shape) == 1)

        mnet = ZenkeNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            verbose=True,  #arch='cifar',
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, 0.25)).to(device)
    else:
        assert (net_type == 'bio_conv_net')
        assert (len(out_shape) == 1)

        raise NotImplementedError('Implementation not publicly available!')

    return mnet
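
# Usage sketch (illustrative values; the config fields follow the naming of
# `utils.cli_args.main_net_args`, but this is not the project's CLI setup):
if __name__ == '__main__':  # pragma: no cover
    from argparse import Namespace
    import torch

    cfg = Namespace(mlp_arch='100,100', net_act='relu', no_bias=False,
                    dropout_rate=-1, specnorm=False, batchnorm=False)
    mnet = get_mnet_model(cfg, 'mlp', [784], [10], torch.device('cpu'))
    # With `cprefix='gen_'`, the same call would instead read prefixed fields
    # such as `cfg.gen_mlp_arch` and `cfg.gen_net_act`.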
Example #3
def get_mnet_model(config,
                   net_type,
                   in_shape,
                   out_shape,
                   device,
                   cprefix=None,
                   no_weights=False,
                   **mnet_kwargs):
    """Generate a main network instance.

    A helper to generate a main network according to the given user
    configuration.

    .. note::
        Generation of networks with context-modulation is not yet supported,
        since there is no global argument set in :mod:`utils.cli_args` yet.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available according
                to the function :func:`utils.cli_args.main_net_args`.
        net_type (str): The type of network. The following options are
            available:

            - ``mlp``: :class:`mnets.mlp.MLP`
            - ``resnet``: :class:`mnets.resnet.ResNet`
            - ``wrn``: :class:`mnets.wide_resnet.WRN`
            - ``iresnet``: :class:`mnets.resnet_imgnet.ResNetIN`
            - ``zenke``: :class:`mnets.zenkenet.ZenkeNet`
            - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet`
            - ``lenet``: :class:`mnets.lenet.LeNet`
            - ``chunked_mlp``: :class:`mnets.chunk_squeezer.ChunkSqueezer`
            - ``simple_rnn``: :class:`mnets.simple_rnn.SimpleRNN`
        in_shape (list): Shape of network inputs. Can be ``None`` if not
            required by network type.

            For instance: For an MLP network :class:`mnets.mlp.MLP` with 100
            input neurons it should be :code:`in_shape=[100]`.
        out_shape (list): Shape of network outputs. See ``in_shape`` for more
            details.
        device: PyTorch device.
        cprefix (str, optional): A prefix of the config names. It might be
            that the config names used in this method are prefixed, since
            several main networks have to be generated (e.g.,
            :code:`cprefix='gen_'` or ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.main_net_args`.
        no_weights (bool): Whether the main network should be generated without
            weights.
        **mnet_kwargs: Additional keyword arguments that will be passed to the
            main network constructor.

    Returns:
        The created main network model.
    """
    assert (net_type in [
        'mlp', 'lenet', 'resnet', 'zenke', 'bio_conv_net', 'chunked_mlp',
        'simple_rnn', 'wrn', 'iresnet'
    ])

    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    def hc(name):
        """Check whether config exists."""
        return hasattr(config, '%s%s' % (cprefix, name))

    mnet = None

    if hc('net_act'):
        net_act = gc('net_act')
        net_act = misc.str_to_act(net_act)
    else:
        net_act = None

    def get_val(name):
        ret = None
        if hc(name):
            ret = gc(name)
        return ret

    no_bias = get_val('no_bias')
    dropout_rate = get_val('dropout_rate')
    specnorm = get_val('specnorm')
    batchnorm = get_val('batchnorm')
    no_batchnorm = get_val('no_batchnorm')
    bn_no_running_stats = get_val('bn_no_running_stats')
    bn_distill_stats = get_val('bn_distill_stats')
    # This argument has to be handled during usage of the network and not during
    # construction.
    #bn_no_stats_checkpointing = get_val('bn_no_stats_checkpointing')

    use_bn = None
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm

    # If an argument wasn't specified, then we use the default value that
    # is currently in the constructor.
    assign = lambda x, y: y if x is None else x

    if net_type == 'mlp':
        assert (hc('mlp_arch'))
        assert (len(in_shape) == 1 and len(out_shape) == 1)

        # Default keyword arguments of class MLP.
        dkws = misc.get_default_args(MLP.__init__)

        mnet = MLP(
            n_in=in_shape[0],
            n_out=out_shape[0],
            hidden_layers=misc.str_to_ints(gc('mlp_arch')),
            activation_fn=assign(net_act, dkws['activation_fn']),
            use_bias=assign(not no_bias, dkws['use_bias']),
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']),
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #out_fn=None,
            verbose=True,
            **mnet_kwargs).to(device)

    elif net_type == 'resnet':
        assert (len(out_shape) == 1)
        assert hc('resnet_block_depth') and hc('resnet_channel_sizes')

        # Default keyword arguments of class ResNet.
        dkws = misc.get_default_args(ResNet.__init__)

        mnet = ResNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            n=gc('resnet_block_depth'),
            use_bias=assign(not no_bias, dkws['use_bias']),
            num_feature_maps=misc.str_to_ints(gc('resnet_channel_sizes')),
            verbose=True,  #n=5,
            no_weights=no_weights,
            #init_weights=None,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
            **mnet_kwargs).to(device)

    elif net_type == 'wrn':
        assert (len(out_shape) == 1)
        assert hc('wrn_block_depth') and hc('wrn_widening_factor')

        # Default keyword arguments of class WRN.
        dkws = misc.get_default_args(WRN.__init__)

        mnet = WRN(
            in_shape=in_shape,
            num_classes=out_shape[0],
            n=gc('wrn_block_depth'),
            use_bias=assign(not no_bias, dkws['use_bias']),
            #num_feature_maps=misc.str_to_ints(gc('wrn_channel_sizes')),
            verbose=True,
            no_weights=no_weights,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            k=gc('wrn_widening_factor'),
            use_fc_bias=gc('wrn_use_fc_bias'),
            dropout_rate=gc('dropout_rate'),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
            **mnet_kwargs).to(device)

    elif net_type == 'iresnet':
        assert (len(out_shape) == 1)
        assert hc('iresnet_use_fc_bias') and hc('iresnet_channel_sizes') \
            and hc('iresnet_blocks_per_group') \
            and hc('iresnet_bottleneck_blocks') \
            and hc('iresnet_projection_shortcut')

        # Default keyword arguments of class ResNetIN.
        dkws = misc.get_default_args(ResNetIN.__init__)

        mnet = ResNetIN(
            in_shape=in_shape,
            num_classes=out_shape[0],
            use_bias=assign(not no_bias, dkws['use_bias']),
            use_fc_bias=gc('iresnet_use_fc_bias'),
            num_feature_maps=misc.str_to_ints(gc('iresnet_channel_sizes')),
            blocks_per_group=misc.str_to_ints(gc('iresnet_blocks_per_group')),
            projection_shortcut=gc('iresnet_projection_shortcut'),
            bottleneck_blocks=gc('iresnet_bottleneck_blocks'),
            #cutout_mod=False,
            no_weights=no_weights,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            #chw_input_format=False,
            verbose=True,
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
            **mnet_kwargs).to(device)

    elif net_type == 'zenke':
        assert (len(out_shape) == 1)

        # Default keyword arguments of class ZenkeNet.
        dkws = misc.get_default_args(ZenkeNet.__init__)

        mnet = ZenkeNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            verbose=True,  #arch='cifar',
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            **mnet_kwargs).to(device)

    elif net_type == 'bio_conv_net':
        assert (len(out_shape) == 1)

        # Default keyword arguments of class BioConvNet.
        #dkws = misc.get_default_args(BioConvNet.__init__)

        mnet = BioConvNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            no_weights=no_weights,
            #init_weights=None,
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
            **mnet_kwargs).to(device)

    elif net_type == 'chunked_mlp':
        assert hc('cmlp_arch') and hc('cmlp_chunk_arch') and \
               hc('cmlp_in_cdim') and hc('cmlp_out_cdim') and \
               hc('cmlp_cemb_dim')
        assert len(in_shape) == 1 and len(out_shape) == 1

        # Default keyword arguments of class ChunkSqueezer.
        dkws = misc.get_default_args(ChunkSqueezer.__init__)

        mnet = ChunkSqueezer(
            n_in=in_shape[0],
            n_out=out_shape[0],
            inp_chunk_dim=gc('cmlp_in_cdim'),
            out_chunk_dim=gc('cmlp_out_cdim'),
            cemb_size=gc('cmlp_cemb_dim'),
            #cemb_init_std=1.,
            red_layers=misc.str_to_ints(gc('cmlp_chunk_arch')),
            net_layers=misc.str_to_ints(gc('cmlp_arch')),
            activation_fn=assign(net_act, dkws['activation_fn']),
            use_bias=assign(not no_bias, dkws['use_bias']),
            #dynamic_biases=None,
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']),
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            verbose=True,
            **mnet_kwargs).to(device)

    elif net_type == 'lenet':
        assert hc('lenet_type')
        assert len(out_shape) == 1

        # Default keyword arguments of class LeNet.
        dkws = misc.get_default_args(LeNet.__init__)

        mnet = LeNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            verbose=True,
            arch=gc('lenet_type'),
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            # TODO Context-mod weights.
            **mnet_kwargs).to(device)

    else:
        assert (net_type == 'simple_rnn')
        assert hc('srnn_rec_layers') and hc('srnn_pre_fc_layers') and \
            hc('srnn_post_fc_layers')  and hc('srnn_no_fc_out') and \
            hc('srnn_rec_type')
        assert len(in_shape) == 1 and len(out_shape) == 1

        if gc('srnn_rec_type') == 'lstm':
            use_lstm = True
        else:
            assert gc('srnn_rec_type') == 'elman'
            use_lstm = False

        # Default keyword arguments of class SimpleRNN.
        dkws = misc.get_default_args(SimpleRNN.__init__)

        rnn_layers = misc.str_to_ints(gc('srnn_rec_layers'))
        fc_layers = misc.str_to_ints(gc('srnn_post_fc_layers'))
        if gc('srnn_no_fc_out'):
            rnn_layers.append(out_shape[0])
        else:
            fc_layers.append(out_shape[0])

        mnet = SimpleRNN(n_in=in_shape[0],
                         rnn_layers=rnn_layers,
                         fc_layers_pre=misc.str_to_ints(
                             gc('srnn_pre_fc_layers')),
                         fc_layers=fc_layers,
                         activation=assign(net_act, dkws['activation']),
                         use_lstm=use_lstm,
                         use_bias=assign(not no_bias, dkws['use_bias']),
                         no_weights=no_weights,
                         verbose=True,
                         **mnet_kwargs).to(device)

    return mnet
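
# The `misc.get_default_args` calls above pull keyword defaults directly from
# the constructors, so options that were not specified on the command line
# fall back to whatever the class currently defines. A minimal sketch of how
# such a helper could look (an assumed implementation, not the project's):
import inspect

def _default_kwargs(func):
    """Map keyword-argument names of `func` to their default values."""
    return {name: p.default
            for name, p in inspect.signature(func).parameters.items()
            if p.default is not inspect.Parameter.empty}

# E.g., `_default_kwargs(MLP.__init__).get('use_bias')` yields the current
# default of `use_bias`, which `assign(...)` falls back to whenever `no_bias`
# was not given on the command line.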
Example #4
def generate_replay_networks(config,
                             data_handlers,
                             device,
                             create_rp_hnet=True,
                             only_train_replay=False):
    """Create a replay model that consists of either a encoder/decoder or
    a discriminator/generator pair. Additionally, this method manages the 
    creation of a hypernetwork for the generator/decoder. 
    Following important configurations will be determined in order to create
    the replay model: 
    * in- and output and hidden layer dimensions of the encoder/decoder. 
    * architecture, chunk- and task-embedding details of decoder's hypernetwork. 

    .. note::
        This module also handles the initialisation of the weights of either
        the replay networks or the decoder's hypernetwork. This will change in
        the near future.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network and to
            infer the number of tasks.
        device: Torch device.
        create_rp_hnet: Whether a hypernetwork for the replay should be
            constructed. If not, the decoder/generator will have trainable
            weights of its own.
        only_train_replay: We normally do not train on the last task since we
            do not need to replay this last task's data. But if we want a
            replay method to be able to generate data from all tasks, we set
            this option to true.

    Returns:
        (tuple): Tuple containing:

        - **enc**: Encoder/discriminator network instance.
        - **dec**: Decoder/generator network instance.
        - **dec_hnet**: Hypernetwork instance for the decoder/generator. This 
            return value is None if no hypernetwork should be constructed.
    """

    if config.replay_method == 'gan':
        n_out = 1
    else:
        n_out = config.latent_dim * 2

    n_in = data_handlers[0].in_shape[0]
    pd = config.padding * 2
    if config.experiment == "splitMNIST":
        n_in = n_in * n_in
    else:  # permutedMNIST
        n_in = (n_in + pd) * (n_in + pd)

    config.input_dim = n_in
    if config.experiment == "splitMNIST":
        if config.single_class_replay:
            config.out_dim = 1
        else:
            config.out_dim = 2
    else:  # permutedMNIST
        config.out_dim = 10

    if config.infer_task_id:
        # task inference network
        config.out_dim = 1

    # Build the encoder.
    print('For the replay encoder/discriminator: ')
    enc_arch = misc.str_to_ints(config.enc_fc_arch)
    enc = MLP(n_in=n_in,
              n_out=n_out,
              hidden_layers=enc_arch,
              activation_fn=misc.str_to_act(config.enc_net_act),
              dropout_rate=config.enc_dropout_rate,
              no_weights=False).to(device)
    print('Constructed MLP with shapes: ', enc.param_shapes)
    init_params = list(enc.parameters())
    # Build the decoder.
    print('For the replay decoder/generator: ')
    dec_arch = misc.str_to_ints(config.dec_fc_arch)
    # The decoder input is the latent code, possibly extended by dimensions
    # for a conditional input.
    n_out = config.latent_dim

    if config.conditional_replay:
        n_out += config.conditional_dim

    dec = MLP(n_in=n_out,
              n_out=n_in,
              hidden_layers=dec_arch,
              activation_fn=misc.str_to_act(config.dec_net_act),
              use_bias=True,
              no_weights=config.rp_beta > 0,
              dropout_rate=config.dec_dropout_rate).to(device)

    print('Constructed MLP with shapes: ', dec.param_shapes)
    config.num_weights_enc = \
                        MainNetInterface.shapes_to_num_weights(enc.param_shapes)

    config.num_weights_dec = \
                        MainNetInterface.shapes_to_num_weights(dec.param_shapes)
    config.num_weights_rp_net = config.num_weights_enc + config.num_weights_dec
    # Normally we do not need a replay model for the last task, so its
    # embedding is subtracted; if `only_train_replay` is set, we keep an
    # embedding for the last task as well.
    if only_train_replay:
        subtr = 0
    else:
        subtr = 1

    num_embeddings = config.num_tasks - subtr if config.num_tasks > 1 else 1

    if config.single_class_replay:
        # we do not need a replay model for the last task
        if config.num_tasks > 1:
            num_embeddings = config.out_dim * (config.num_tasks - subtr)
        else:
            num_embeddings = config.out_dim * (config.num_tasks)

    config.num_embeddings = num_embeddings
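    # E.g. (assumed values): with 5 tasks and `subtr = 1` this gives 4 task
    # embeddings, or 2 * (5 - 1) = 8 embeddings for single-class replay on
    # splitMNIST (out_dim = 2).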
    # build decoder hnet
    if create_rp_hnet:
        print('For the decoder/generator hypernetwork: ')
        d_hnet = sim_utils.get_hnet_model(config,
                                          num_embeddings,
                                          device,
                                          dec.hyper_shapes_learned,
                                          cprefix='rp_')

        init_params += list(d_hnet.parameters())

        config.num_weights_rp_hyper_net = sum(p.numel()
                                              for p in d_hnet.parameters()
                                              if p.requires_grad)
        config.compression_ratio_rp = config.num_weights_rp_hyper_net / \
                                                        config.num_weights_dec
        print('Created replay hypernetwork with ratio: ',
              config.compression_ratio_rp)
        if config.compression_ratio_rp > 1:
            print('Note that the compression ratio is computed compared to ' +
                  'current target network,\nthis might not be directly ' +
                  'comparable with the number of parameters of methods we ' +
                  'compare against.')
    else:
        num_embeddings = config.num_tasks - subtr
        d_hnet = None
        init_params += list(dec.parameters())
        config.num_weights_rp_hyper_net = 0
        config.compression_ratio_rp = 0

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if create_rp_hnet:
        for temb in d_hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

    if hasattr(d_hnet, 'chunk_embeddings'):
        for emb in d_hnet.chunk_embeddings:
            torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    return enc, dec, d_hnet
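
# Usage sketch (illustrative values and a tiny fake data handler; only the
# config fields read above are set, the non-``gan`` `replay_method` value is
# an assumption, and `create_rp_hnet=False` keeps the decoder's weights local
# instead of building a hypernetwork):
if __name__ == '__main__':  # pragma: no cover
    from argparse import Namespace
    from types import SimpleNamespace
    import torch

    fake_data = [SimpleNamespace(in_shape=[28]) for _ in range(5)]
    cfg = Namespace(
        replay_method='vae', latent_dim=8, padding=0, experiment='splitMNIST',
        single_class_replay=False, infer_task_id=False, num_tasks=5,
        enc_fc_arch='400,400', enc_net_act='relu', enc_dropout_rate=-1,
        dec_fc_arch='400,400', dec_net_act='relu', dec_dropout_rate=-1,
        conditional_replay=False, rp_beta=0)

    enc, dec, d_hnet = generate_replay_networks(
        cfg, fake_data, torch.device('cpu'), create_rp_hnet=False)
    assert d_hnet is None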
Example #5
    def __init__(self,
                 rnn_args={},
                 mlp_args=None,
                 preprocess_fct=None,
                 no_weights=False,
                 verbose=True):
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        assert isinstance(rnn_args, (dict, list, tuple))
        assert mlp_args is None or isinstance(mlp_args, dict)

        if isinstance(rnn_args, dict):
            rnn_args = [rnn_args]

        self._forward_rnns = []
        self._backward_rnns = []
        self._out_mlp = None
        self._preprocess_fct = preprocess_fct
        self._forward_called = False

        # FIXME At the moment we do not control input and output size of
        # individual networks and need to assume that the user sets them
        # correctly.

        ### Create all forward and backward nets for each bidirectional layer.
        for rargs in rnn_args:
            assert isinstance(rargs, dict)
            if 'verbose' not in rargs.keys():
                rargs['verbose'] = False
            if 'no_weights' in rargs.keys() and \
                    rargs['no_weights'] != no_weights:
                raise ValueError('Keyword argument "no_weights" of ' +
                                 'bidirectional layer is in conflict with ' +
                                 'constructor argument "no_weights".')
            elif 'no_weights' not in rargs.keys():
                rargs['no_weights'] = no_weights

            self._forward_rnns.append(SimpleRNN(**rargs))
            self._backward_rnns.append(SimpleRNN(**rargs))

        ### Create output network.
        if mlp_args is not None:
            if 'verbose' not in mlp_args.keys():
                mlp_args['verbose'] = False
            if 'no_weights' in mlp_args.keys() and \
                    mlp_args['no_weights'] != no_weights:
                raise ValueError('Keyword argument "no_weights" of ' +
                                 'output MLP is in conflict with ' +
                                 'constructor argument "no_weights".')
            elif 'no_weights' not in mlp_args.keys():
                mlp_args['no_weights'] = no_weights

            self._out_mlp = MLP(**mlp_args)

        ### Set all interface attributes correctly.
        if self._out_mlp is None:
            self._has_fc_out = self._forward_rnns[-1].has_fc_out
            # We can't set the following attribute to true, as the output is
            # a concatenation of the outputs from two networks. Therefore, the
            # weights used to compute the outputs are at different locations
            # in the `param_shapes` list.
            self._mask_fc_out = False
            self._has_linear_out = self._forward_rnns[-1].has_linear_out
        else:
            self._has_fc_out = self._out_mlp.has_fc_out
            self._mask_fc_out = self._out_mlp.mask_fc_out
            self._has_linear_out = self._out_mlp.has_linear_out

        # Collect all internal net objects from which we need to collect
        # attributes.
        nets = []
        for i, fnet in enumerate(self._forward_rnns):
            bnet = self._backward_rnns[i]

            nets.append((fnet, 'forward_rnn', i))
            nets.append((bnet, 'backward_rnn', i))
        if self._out_mlp is not None:
            nets.append((self._out_mlp, 'out_mlp', -1))

        # Iterate over all nets to collect their attribute values.
        self._param_shapes = []
        self._param_shapes_meta = []
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        for i, net_tup in enumerate(nets):
            net, net_type, net_id = net_tup
            # Note, it is important to convert lists into new object and not
            # just copy references!
            # Note, we have to adapt all references if `i > 0`.

            # Sanity check:
            if i == 0:
                cm_nw = net._context_mod_no_weights
            elif cm_nw != net._context_mod_no_weights:
                raise ValueError('The network expects that either all ' +
                                 'internal networks maintain their context-' +
                                 'mod weights or none of them does!')

            ps_len_old = len(self._param_shapes)

            if net._internal_params is not None:
                if self._internal_params is None:
                    self._internal_params = nn.ParameterList()
                ip_len_old = len(self._internal_params)
                self._internal_params.extend( \
                    nn.ParameterList(net._internal_params))
            self._param_shapes.extend(list(net._param_shapes))
            for meta in net.param_shapes_meta:
                assert 'birnn_layer_type' not in meta.keys()
                assert 'birnn_layer_id' not in meta.keys()

                new_meta = dict(meta)
                new_meta['birnn_layer_type'] = net_type
                new_meta['birnn_layer_id'] = net_id
                if i > 0:
                    # FIXME We should properly adjust colliding `layer` IDs.
                    new_meta['layer'] = -1
                new_meta['index'] = meta['index'] + ip_len_old
                self._param_shapes_meta.append(new_meta)

            if net._hyper_shapes_learned is not None:
                if self._hyper_shapes_learned is None:
                    self._hyper_shapes_learned = []
                    self._hyper_shapes_learned_ref = []
                self._hyper_shapes_learned.extend( \
                    list(net._hyper_shapes_learned))
                for ref in net._hyper_shapes_learned_ref:
                    self._hyper_shapes_learned_ref.append(ref + ps_len_old)
            if net._hyper_shapes_distilled is not None:
                if self._hyper_shapes_distilled is None:
                    self._hyper_shapes_distilled = []
                self._hyper_shapes_distilled.extend( \
                    list(net._hyper_shapes_distilled))

            if self._has_bias is None:
                self._has_bias = net._has_bias
            elif self._has_bias != net._has_bias:
                self._has_bias = False
                # FIXME We should overwrite the getter and throw an error!
                warn('Some internally maintained networks use biases, ' +
                     'while others don\'t. Setting attribute "has_bias" to ' +
                     'False.')

            self._layer_weight_tensors.extend( \
                nn.ParameterList(net._layer_weight_tensors))
            self._layer_bias_vectors.extend( \
                nn.ParameterList(net._layer_bias_vectors))
            if net._batchnorm_layers is not None:
                if self._batchnorm_layers is None:
                    self._batchnorm_layers = nn.ModuleList()
                self._batchnorm_layers.extend( \
                    nn.ModuleList(net._batchnorm_layers))
            if net._context_mod_layers is not None:
                if self._context_mod_layers is None:
                    self._context_mod_layers = nn.ModuleList()
                self._context_mod_layers.extend( \
                    nn.ModuleList(net._context_mod_layers))

        self._is_properly_setup()

        ### Print user information.
        if verbose:
            print('Constructed Bidirectional RNN with %d weights.' \
                  % self.num_params)
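
        # Note on the construction above (hedged; the class name and the
        # `forward` logic are not shown in this excerpt): a hypothetical
        # two-layer setup
        #     rnn_args=[{...layer-1 kwargs...}, {...layer-2 kwargs...}]
        # creates one forward and one backward `SimpleRNN` per entry, and the
        # optional output MLP given by `mlp_args` presumably consumes the
        # concatenated forward/backward outputs (cf. the `_mask_fc_out`
        # comment above), so its input size must be chosen accordingly by the
        # user.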
    def __init__(self,
                 n_in,
                 n_out=1,
                 inp_chunk_dim=100,
                 out_chunk_dim=10,
                 cemb_size=8,
                 cemb_init_std=1.,
                 red_layers=(10, 10),
                 net_layers=(10, 10),
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True,
                 dynamic_biases=None,
                 no_weights=False,
                 init_weights=None,
                 dropout_rate=-1,
                 use_spectral_norm=False,
                 use_batch_norm=False,
                 bn_track_stats=True,
                 distill_bn_stats=False,
                 verbose=True):
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        self._n_in = n_in
        self._n_out = n_out
        self._inp_chunk_dim = inp_chunk_dim
        self._out_chunk_dim = out_chunk_dim
        self._cemb_size = cemb_size
        self._a_fun = activation_fn
        self._no_weights = no_weights

        self._has_bias = use_bias
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        self._has_linear_out = True  # Ensure that `out_fn` is `None`!

        self._param_shapes = []
        #self._param_shapes_meta = [] # TODO implement!
        self._weights = None if no_weights else nn.ParameterList()
        self._hyper_shapes_learned = None if not no_weights else []
        #self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
        #    is None else [] # TODO implement.

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        self._context_mod_layers = None
        self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

        #################################
        ### Generate Chunk Embeddings ###
        #################################
        self._num_cembs = int(np.ceil(n_in / inp_chunk_dim))
        last_chunk_size = n_in % inp_chunk_dim
        if last_chunk_size != 0:
            self._pad = inp_chunk_dim - last_chunk_size
        else:
            self._pad = -1

        cemb_shape = [self._num_cembs, cemb_size]
        self._param_shapes.append(cemb_shape)
        if no_weights:
            self._cembs = None
            self._hyper_shapes_learned.append(cemb_shape)
        else:
            self._cembs = nn.Parameter(data=torch.Tensor(*cemb_shape),
                                       requires_grad=True)
            nn.init.normal_(self._cembs, mean=0., std=cemb_init_std)

            self._weights.append(self._cembs)

        ############################
        ### Setup Dynamic Biases ###
        ############################
        self._has_dyn_bias = None
        if dynamic_biases is not None:
            assert np.all(np.array(dynamic_biases) >= 0) and \
                   np.all(np.array(dynamic_biases) < len(red_layers) + 1)
            dynamic_biases = np.sort(np.unique(dynamic_biases))

            # For each layer in the `reducer` where we want to apply a dynamic
            # bias, we have to create a weight matrix for a corresponding
            # linear layer (whose own bias term is simply not used).
            self._dyn_bias_weights = nn.ModuleList()
            self._has_dyn_bias = []

            for i in range(len(red_layers) + 1):
                if i in dynamic_biases:
                    self._has_dyn_bias.append(True)

                    trgt_dim = out_chunk_dim
                    if i < len(red_layers):
                        trgt_dim = red_layers[i]
                    trgt_shape = [trgt_dim, cemb_size]

                    self._param_shapes.append(trgt_shape)
                    if no_weights:
                        self._dyn_bias_weights.append(None)
                        self._hyper_shapes_learned.append(trgt_shape)
                    else:
                        self._dyn_bias_weights.append(nn.Parameter( \
                            torch.Tensor(*trgt_shape), requires_grad=True))
                        self._weights.append(self._dyn_bias_weights[-1])

                        init_params(self._dyn_bias_weights[-1])

                        self._layer_weight_tensors.append( \
                            self._dyn_bias_weights[-1])
                        self._layer_bias_vectors.append(None)
                else:
                    self._has_dyn_bias.append(False)
                    self._dyn_bias_weights.append(None)

        ################################
        ### Create `Reducer` Network ###
        ################################
        red_inp_dim = inp_chunk_dim + \
            (cemb_size if dynamic_biases is None else 0)
        self._reducer = MLP(
            n_in=red_inp_dim,
            n_out=out_chunk_dim,
            hidden_layers=red_layers,
            activation_fn=activation_fn,
            use_bias=use_bias,
            no_weights=no_weights,
            init_weights=None,
            dropout_rate=dropout_rate,
            use_spectral_norm=use_spectral_norm,
            use_batch_norm=use_batch_norm,
            bn_track_stats=bn_track_stats,
            distill_bn_stats=distill_bn_stats,
            # We use context modulation to realize dynamic biases, since they
            # allow a different modulation per sample in the input mini-batch.
            # Hence, we can process several chunks in parallel with the reducer
            # network.
            use_context_mod=dynamic_biases is not None,
            context_mod_inputs=False,
            no_last_layer_context_mod=False,
            context_mod_no_weights=True,
            context_mod_post_activation=False,
            context_mod_gain_offset=False,
            context_mod_gain_softplus=False,
            out_fn=None,
            verbose=True)

        if dynamic_biases is not None:
            # FIXME We have to extract the param shapes from
            # `self._reducer.param_shapes`, as well as from
            # `self._reducer._hyper_shapes_learned` that belong to context-mod
            # layers. We may not add them to our own `param_shapes` attribute,
            # as these are not parameters (due to our misuse of the context-mod
            # layers).
            # Note, in the `forward` method, we need to supply context-mod
            # weights for all reducer networks, independent on whether they have
            # a dynamic bias or not. We can do so, by providing constant ones
            # for all gains and constance zero-shift for all layers without
            # dynamic biases (note, we need to ensure the correct batch dim!).
            raise NotImplementedError(
                'Dynamic biases are not yet implemented!')

        assert self._reducer._context_mod_layers is None

        ### Take over all attributes from the underlying MLP.
        for s in self._reducer.param_shapes:
            self._param_shapes.append(s)
        if no_weights:
            for s in self._reducer._hyper_shapes_learned:
                self._hyper_shapes_learned.append(s)
        else:
            for p in self._reducer._weights:
                self._weights.append(p)

        for p in self._reducer._layer_weight_tensors:
            self._layer_weight_tensors.append(p)
        for p in self._reducer._layer_bias_vectors:
            self._layer_bias_vectors.append(p)

        if use_batch_norm:
            for p in self._reducer._batchnorm_layers:
                self._batchnorm_layers.append(p)

        if self._reducer._hyper_shapes_distilled is not None:
            self._hyper_shapes_distilled = []
            for s in self._reducer._hyper_shapes_distilled:
                self._hyper_shapes_distilled.append(s)

        ###############################
        ### Create Actual `Network` ###
        ###############################
        net_inp_dim = out_chunk_dim * self._num_cembs
        self._network = MLP(n_in=net_inp_dim,
                            n_out=n_out,
                            hidden_layers=net_layers,
                            activation_fn=activation_fn,
                            use_bias=use_bias,
                            no_weights=no_weights,
                            init_weights=None,
                            dropout_rate=dropout_rate,
                            use_spectral_norm=use_spectral_norm,
                            use_batch_norm=use_batch_norm,
                            bn_track_stats=bn_track_stats,
                            distill_bn_stats=distill_bn_stats,
                            use_context_mod=False,
                            out_fn=None,
                            verbose=True)

        ### Take over all attributes from the underlying MLP.
        for s in self._network.param_shapes:
            self._param_shapes.append(s)
        if no_weights:
            for s in self._network._hyper_shapes_learned:
                self._hyper_shapes_learned.append(s)
        else:
            for p in self._network._weights:
                self._weights.append(p)

        for p in self._network._layer_weight_tensors:
            self._layer_weight_tensors.append(p)
        for p in self._network._layer_bias_vectors:
            self._layer_bias_vectors.append(p)

        if use_batch_norm:
            for p in self._network._batchnorm_layers:
                self._batchnorm_layers.append(p)

        if self._hyper_shapes_distilled is not None:
            assert self._network._hyper_shapes_distilled is not None
            for s in self._network._hyper_shapes_distilled:
                self._hyper_shapes_distilled.append(s)

        #####################################
        ### Takeover given Initialization ###
        #####################################
        if init_weights is not None:
            assert len(init_weights) == len(self._weights)
            for i in range(len(init_weights)):
                assert np.all(
                    np.equal(list(init_weights[i].shape),
                             self._param_shapes[i]))
                self._weights[i].data = init_weights[i]

        ######################
        ### Finalize Setup ###
        ######################
        num_weights = MainNetInterface.shapes_to_num_weights(self.param_shapes)
        print('Constructed MLP that processes dimensionality reduced inputs ' +
              'through chunking. The network has a total of %d weights.' %
              num_weights)

        self._is_properly_setup()
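
# The chunking arithmetic used above, as a standalone sketch (the concrete
# numbers are assumed for illustration): an input of size `n_in` is split into
# ceil(n_in / inp_chunk_dim) chunks, and a partial last chunk is padded up to
# the full chunk size.
import numpy as np

n_in, inp_chunk_dim, out_chunk_dim = 784, 100, 10
num_cembs = int(np.ceil(n_in / inp_chunk_dim))   # 8 chunks
last = n_in % inp_chunk_dim
pad = -1 if last == 0 else inp_chunk_dim - last  # 16 padded positions
net_inp_dim = out_chunk_dim * num_cembs          # 80 inputs to the final MLP
assert (num_cembs, pad, net_inp_dim) == (8, 16, 80)
# The `reducer` maps each chunk (concatenated with its chunk embedding) down
# to `out_chunk_dim` values; the final `network` MLP then operates on the
# concatenation of all reduced chunks.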