Example 1
def get_hnet_model(config, num_tasks, device, mnet_shapes, cprefix=None):
    """Generate a hypernetwork instance.

    A helper to generate the hypernetwork according to the given user
    configuration.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available according
                to the function :func:`utils.cli_args.hypernet_args`.
        num_tasks (int): The number of task embeddings the hypernetwork should
            have.
        device: PyTorch device.
        mnet_shapes: Dimensions of the weight tensors of the main network.
            See main net argument
            :attr:`mnets.mnet_interface.MainNetInterface.param_shapes`.
        cprefix (str, optional): A prefix of the config names. It might be that
            the config names used in this method are prefixed, since several
            hypernetworks should be generated (e.g., :code:`cprefix='gen_'` or
            ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.hypernet_args`.

    Returns:
        The created hypernet model.
    """
    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    hyper_chunks = misc.str_to_ints(gc('hyper_chunks'))
    assert (len(hyper_chunks) in [1, 2, 3])
    if len(hyper_chunks) == 1:
        hyper_chunks = hyper_chunks[0]

    hnet_arch = misc.str_to_ints(gc('hnet_arch'))
    sa_hnet_filters = misc.str_to_ints(gc('sa_hnet_filters'))
    sa_hnet_kernels = misc.str_to_ints(gc('sa_hnet_kernels'))
    sa_hnet_attention_layers = misc.str_to_ints(gc('sa_hnet_attention_layers'))

    hnet_act = misc.str_to_act(gc('hnet_act'))

    if isinstance(hyper_chunks, list):  # Chunked self-attention hypernet
        if len(sa_hnet_kernels) == 1:
            sa_hnet_kernels = sa_hnet_kernels[0]
        # Note that the user can specify the kernel size for each dimension and
        # layer separately.
        elif len(sa_hnet_kernels) > 2 and \
            len(sa_hnet_kernels) == gc('sa_hnet_num_layers') * 2:
            tmp = sa_hnet_kernels
            sa_hnet_kernels = []
            for i in range(0, len(tmp), 2):
                sa_hnet_kernels.append([tmp[i], tmp[i + 1]])

        if gc('hnet_dropout_rate') != -1:
            warn('SA-Hypernet doesn\'t use dropout. Dropout rate will be ' +
                 'ignored.')
        if gc('hnet_act') != 'relu':
            warn('SA-Hypernet doesn\'t support non-linearities other than ' +
                 'ReLU yet. Option "%shnet_act" (%s) will be ignored.' %
                 (cprefix, gc('hnet_act')))

        hnet = SAHyperNetwork(
            mnet_shapes,
            num_tasks,
            out_size=hyper_chunks,
            num_layers=gc('sa_hnet_num_layers'),
            num_filters=sa_hnet_filters,
            kernel_size=sa_hnet_kernels,
            sa_units=sa_hnet_attention_layers,
            # Note, we don't use an additional hypernet for the remaining
            # weights!
            #rem_layers=hnet_arch,
            te_dim=gc('temb_size'),
            ce_dim=gc('emb_size'),
            no_theta=False,
            # Batchnorm and spectral norm are not yet implemented.
            #use_batch_norm=gc('hnet_batchnorm'),
            #use_spectral_norm=gc('hnet_specnorm'),
            # Dropout would only be used for the additional network, which we
            # don't use.
            #dropout_rate=gc('hnet_dropout_rate'),
            discard_remainder=True,
            noise_dim=gc('hnet_noise_dim'),
            temb_std=gc('temb_std')).to(device)

    elif hyper_chunks != -1:  # Chunked fully-connected hypernet
        hnet = ChunkedHyperNetworkHandler(mnet_shapes,
                                          num_tasks,
                                          chunk_dim=hyper_chunks,
                                          layers=hnet_arch,
                                          activation_fn=hnet_act,
                                          te_dim=gc('temb_size'),
                                          ce_dim=gc('emb_size'),
                                          dropout_rate=gc('hnet_dropout_rate'),
                                          noise_dim=gc('hnet_noise_dim'),
                                          temb_std=gc('temb_std')).to(device)

    else:  # Fully-connected hypernet.
        hnet = HyperNetwork(mnet_shapes,
                            num_tasks,
                            layers=hnet_arch,
                            te_dim=gc('temb_size'),
                            activation_fn=hnet_act,
                            dropout_rate=gc('hnet_dropout_rate'),
                            noise_dim=gc('hnet_noise_dim'),
                            temb_std=gc('temb_std')).to(device)

    return hnet
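
# Usage sketch (illustrative, not part of the original module): the option
# names below mirror those read via `gc()` above and would normally be parsed
# by `utils.cli_args.hypernet_args`; the values are hypothetical.
import argparse

_example_cfg = argparse.Namespace(
    hyper_chunks='-1',       # -1 -> plain fully-connected hypernetwork.
    hnet_arch='100,100',
    sa_hnet_filters='32',
    sa_hnet_kernels='5',
    sa_hnet_attention_layers='1',
    hnet_act='relu',
    temb_size=8,
    emb_size=8,
    hnet_dropout_rate=-1,
    hnet_noise_dim=-1,
    temb_std=-1,
)
# hnet = get_hnet_model(_example_cfg, num_tasks=5, device=device,
#                       mnet_shapes=[[10, 784], [10]])
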
Example 2
def generate_classifier(config, data_handlers, device):
    """Create a classifier network. Depending on the experiment and method, 
    the method manages to build either a classifier for task inference 
    or a classifier that solves our task is build. This also implies if the
    network will receive weights from a hypernetwork or will have weights 
    on its own.
    Following important configurations will be determined in order to create
    the classifier: \n 
    * in- and output and hidden layer dimensions of the classifier. \n
    * architecture, chunk- and task-embedding details of the hypernetwork. 


    See :class:`mnets.mlp.MLP` for details on the network that will be created
        to be a classifier. 

    .. note::
        This module also handles the initialisation of the weights of either 
        the classifier or its hypernetwork. This will change in the near future.
        
    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network and to
            infer the number of tasks.
        device: Torch device.
    
    Returns:
        (tuple): Tuple containing:

        - **net**: The classifier network.
        - **class_hnet** (optional): The classifier's hypernetwork.
    """
    n_in = data_handlers[0].in_shape[0]
    pd = config.padding * 2

    if config.experiment == "splitMNIST":
        n_in = n_in * n_in
    else:  # permutedMNIST
        n_in = (n_in + pd) * (n_in + pd)

    config.input_dim = n_in
    if config.experiment == "splitMNIST":
        if config.class_incremental:
            config.out_dim = 1
        else:
            config.out_dim = 2
    else:  # permutedMNIST
        config.out_dim = 10

    if config.training_task_infer or config.class_incremental:
        # task inference network
        config.out_dim = 1

    # Have all output neurons already built up for CL scenario 2.
    if config.cl_scenario != 2:
        n_out = config.out_dim * config.num_tasks
    else:
        n_out = config.out_dim

    if config.training_task_infer or config.class_incremental:
        n_out = config.num_tasks

    # Build the classifier.
    print('For the Classifier: ')
    class_arch = misc.str_to_ints(config.class_fc_arch)
    if config.training_with_hnet:
        no_weights = True
    else:
        no_weights = False

    net = MLP(n_in=n_in,
              n_out=n_out,
              hidden_layers=class_arch,
              activation_fn=misc.str_to_act(config.class_net_act),
              dropout_rate=config.class_dropout_rate,
              no_weights=no_weights).to(device)

    print('Constructed MLP with shapes: ', net.param_shapes)

    config.num_weights_class_net = \
        MainNetInterface.shapes_to_num_weights(net.param_shapes)
    # build classifier hnet
    # this is set in the run method in train.py
    if config.training_with_hnet:

        class_hnet = sim_utils.get_hnet_model(config,
                                              config.num_tasks,
                                              device,
                                              net.param_shapes,
                                              cprefix='class_')
        init_params = list(class_hnet.parameters())

        config.num_weights_class_hyper_net = sum(
            p.numel() for p in class_hnet.parameters() if p.requires_grad)
        config.compression_ratio_class = config.num_weights_class_hyper_net / \
                                         config.num_weights_class_net
        print('Created classifier Hypernetwork with ratio: ',
              config.compression_ratio_class)
        if config.compression_ratio_class > 1:
            print('Note that the compression ratio is computed compared to ' +
                  'the current target network and might not be directly ' +
                  'comparable with the number of parameters of the works we ' +
                  'compare against.')
    else:
        class_hnet = None
        init_params = list(net.parameters())
        config.num_weights_class_hyper_net = None
        config.compression_ratio_class = None

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if config.training_with_hnet:
        for temb in class_hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

    if hasattr(class_hnet, 'chunk_embeddings'):
        for emb in class_hnet.chunk_embeddings:
            torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    if not config.training_with_hnet:
        return net
    else:
        return net, class_hnet
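
# Usage sketch (illustrative, not part of the original module): build a
# stand-alone classifier (no hypernetwork) for splitMNIST. The attribute names
# mirror the config fields read above; the values are hypothetical and
# `data_handlers` would come from the experiment setup.
import argparse

_example_cfg = argparse.Namespace(
    experiment='splitMNIST', padding=0, class_incremental=False,
    training_task_infer=False, cl_scenario=1, num_tasks=5,
    class_fc_arch='400,400', class_net_act='relu', class_dropout_rate=-1,
    training_with_hnet=False,
)
# net = generate_classifier(_example_cfg, data_handlers, device)
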
Example 3
def get_hypernet(config,
                 device,
                 net_type,
                 target_shapes,
                 num_conds,
                 no_cond_weights=False,
                 no_uncond_weights=False,
                 uncond_in_size=0,
                 shmlp_chunk_shapes=None,
                 shmlp_num_per_chunk=None,
                 shmlp_assembly_fct=None,
                 verbose=True,
                 cprefix=None):
    """Generate a hypernetwork instance.

    A helper to generate the hypernetwork according to the given user
    configuration.

    Args:
        config (argparse.Namespace): Command-line arguments.

            Note:
                The function expects command-line arguments available according
                to the function :func:`utils.cli_args.hnet_args`.
        device: PyTorch device.
        net_type (str): The type of network. The following options are
            available:

            - ``'hmlp'``
            - ``'chunked_hmlp'``
            - ``'structured_hmlp'``
            - ``'hdeconv'``
            - ``'chunked_hdeconv'``
        target_shapes (list): See argument ``target_shapes`` of
            :class:`hnets.mlp_hnet.HMLP`.
        num_conds (int): Number of conditions that should be known to the
            hypernetwork.
        no_cond_weights (bool): See argument ``no_cond_weights`` of
            :class:`hnets.mlp_hnet.HMLP`.
        no_uncond_weights (bool): See argument ``no_uncond_weights`` of
            :class:`hnets.mlp_hnet.HMLP`.
        uncond_in_size (int): See argument ``uncond_in_size`` of
            :class:`hnets.mlp_hnet.HMLP`.
        shmlp_chunk_shapes (list, optional): Argument ``chunk_shapes`` of
            :class:`hnets.structured_mlp_hnet.StructuredHMLP`.
        shmlp_num_per_chunk (list, optional): Argument ``num_per_chunk`` of
            :class:`hnets.structured_mlp_hnet.StructuredHMLP`.
        shmlp_assembly_fct (func, optional): Argument ``assembly_fct`` of
            :class:`hnets.structured_mlp_hnet.StructuredHMLP`.
        verbose (bool): Argument ``verbose`` of :class:`hnets.mlp_hnet.HMLP`.
        cprefix (str, optional): A prefix of the config names. It might be that
            the config names used in this function are prefixed, since several
            hypernetworks should be generated.

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.hnet_args`.
    """
    assert net_type in [
        'hmlp', 'chunked_hmlp', 'structured_hmlp', 'hdeconv', 'chunked_hdeconv'
    ]

    hnet = None

    ### FIXME Code almost identically copied from `get_mnet_model` ###
    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    def hc(name):
        """Check whether config exists."""
        return hasattr(config, '%s%s' % (cprefix, name))

    if hc('hnet_net_act'):
        net_act = gc('hnet_net_act')
        net_act = misc.str_to_act(net_act)
    else:
        net_act = None

    def get_val(name):
        ret = None
        if hc(name):
            ret = gc(name)
        return ret

    no_bias = get_val('hnet_no_bias')
    dropout_rate = get_val('hnet_dropout_rate')
    specnorm = get_val('hnet_specnorm')
    batchnorm = get_val('hnet_batchnorm')
    no_batchnorm = get_val('hnet_no_batchnorm')
    #bn_no_running_stats = get_val('hnet_bn_no_running_stats')
    #n_distill_stats = get_val('hnet_bn_distill_stats')

    use_bn = None
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm

    # If an argument wasn't specified, then we use the default value that
    # is currently in the constructor.
    assign = lambda x, y: y if x is None else x
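    # E.g., `assign(None, 0.5)` yields 0.5 (the constructor default), whereas
    # `assign(0.2, 0.5)` yields the user-specified 0.2.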
    ### FIXME Code copied until here                               ###

    if hc('hmlp_arch'):
        hmlp_arch_is_list = False
        hmlp_arch = gc('hmlp_arch')
        if ';' in hmlp_arch:
            hmlp_arch_is_list = True
            if net_type != 'structured_hmlp':
                raise ValueError('Option "%shmlp_arch" may only ' % (cprefix) +
                                 'contain semicolons for network type ' +
                                 '"structured_hmlp"!')
            hmlp_arch = [misc.str_to_ints(ar) for ar in hmlp_arch.split(';')]
        else:
            hmlp_arch = misc.str_to_ints(hmlp_arch)
    if hc('chunk_emb_size'):
        chunk_emb_size = gc('chunk_emb_size')
        chunk_emb_size = misc.str_to_ints(chunk_emb_size)
        if len(chunk_emb_size) == 1:
            chunk_emb_size = chunk_emb_size[0]
        else:
            if net_type != 'structured_hmlp':
                raise ValueError('Option "%schunk_emb_size" may ' % (cprefix) +
                                 'only contain multiple values for network ' +
                                 'type "structured_hmlp"!')

    if hc('cond_emb_size'):
        cond_emb_size = gc('cond_emb_size')
    else:
        cond_emb_size = 0

    if net_type == 'hmlp':
        assert hc('hmlp_arch')

        # Default keyword arguments of class HMLP.
        dkws = misc.get_default_args(HMLP.__init__)

        hnet = HMLP(target_shapes,
                    uncond_in_size=uncond_in_size,
                    cond_in_size=cond_emb_size,
                    layers=hmlp_arch,
                    verbose=verbose,
                    activation_fn=assign(net_act, dkws['activation_fn']),
                    use_bias=assign(not no_bias, dkws['use_bias']),
                    no_uncond_weights=no_uncond_weights,
                    no_cond_weights=no_cond_weights,
                    num_cond_embs=num_conds,
                    dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
                    use_spectral_norm=assign(specnorm,
                                             dkws['use_spectral_norm']),
                    use_batch_norm=assign(use_bn,
                                          dkws['use_batch_norm'])).to(device)

    elif net_type == 'chunked_hmlp':
        assert hc('hmlp_arch')
        assert hc('chmlp_chunk_size')
        assert hc('chunk_emb_size')
        cond_chunk_embs = get_val('use_cond_chunk_embs')

        # Default keyword arguments of class ChunkedHMLP.
        dkws = misc.get_default_args(ChunkedHMLP.__init__)

        hnet = ChunkedHMLP(
            target_shapes,
            gc('chmlp_chunk_size'),
            chunk_emb_size=chunk_emb_size,
            cond_chunk_embs=assign(cond_chunk_embs, dkws['cond_chunk_embs']),
            uncond_in_size=uncond_in_size,
            cond_in_size=cond_emb_size,
            layers=hmlp_arch,
            verbose=verbose,
            activation_fn=assign(net_act, dkws['activation_fn']),
            use_bias=assign(not no_bias, dkws['use_bias']),
            no_uncond_weights=no_uncond_weights,
            no_cond_weights=no_cond_weights,
            num_cond_embs=num_conds,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']),
            use_batch_norm=assign(use_bn, dkws['use_batch_norm'])).to(device)

    elif net_type == 'structured_hmlp':
        assert hc('hmlp_arch')
        assert hc('chunk_emb_size')
        cond_chunk_embs = get_val('use_cond_chunk_embs')

        assert shmlp_chunk_shapes is not None and \
            shmlp_num_per_chunk is not None and \
            shmlp_assembly_fct is not None

        # Default keyword arguments of class StructuredHMLP.
        dkws = misc.get_default_args(StructuredHMLP.__init__)
        dkws_hmlp = misc.get_default_args(HMLP.__init__)

        shmlp_hmlp_kwargs = []
        if not hmlp_arch_is_list:
            hmlp_arch = [hmlp_arch]
        for i, arch in enumerate(hmlp_arch):
            shmlp_hmlp_kwargs.append({
                'layers': arch,
                'activation_fn': assign(net_act, dkws_hmlp['activation_fn']),
                'use_bias': assign(not no_bias, dkws_hmlp['use_bias']),
                'dropout_rate': assign(dropout_rate, dkws_hmlp['dropout_rate']),
                'use_spectral_norm': \
                    assign(specnorm, dkws_hmlp['use_spectral_norm']),
                'use_batch_norm': assign(use_bn, dkws_hmlp['use_batch_norm'])
            })
        if len(shmlp_hmlp_kwargs) == 1:
            shmlp_hmlp_kwargs = shmlp_hmlp_kwargs[0]

        hnet = StructuredHMLP(target_shapes,
                              shmlp_chunk_shapes,
                              shmlp_num_per_chunk,
                              chunk_emb_size,
                              shmlp_hmlp_kwargs,
                              shmlp_assembly_fct,
                              cond_chunk_embs=assign(cond_chunk_embs,
                                                     dkws['cond_chunk_embs']),
                              uncond_in_size=uncond_in_size,
                              cond_in_size=cond_emb_size,
                              verbose=verbose,
                              no_uncond_weights=no_uncond_weights,
                              no_cond_weights=no_cond_weights,
                              num_cond_embs=num_conds).to(device)

    elif net_type == 'hdeconv':
        #HDeconv
        raise NotImplementedError
    else:
        assert net_type == 'chunked_hdeconv'
        #ChunkedHDeconv
        raise NotImplementedError

    return hnet
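
# Usage sketch (illustrative, not part of the original module): create a plain
# 'hmlp' hypernetwork. The option names mirror those queried via `gc()`/`hc()`
# above and would normally be parsed by `utils.cli_args.hnet_args`; values and
# target shapes are hypothetical.
import argparse

_example_cfg = argparse.Namespace(
    hmlp_arch='100,100',
    cond_emb_size=8,
    hnet_net_act='relu',
    hnet_dropout_rate=-1,
)
# Target shapes as given by a main network's `param_shapes` attribute:
_example_target_shapes = [[20, 10], [20], [5, 20], [5]]
# hnet = get_hypernet(_example_cfg, device, 'hmlp', _example_target_shapes,
#                     num_conds=5)
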
Example 4
def get_mnet_model(config,
                   net_type,
                   in_shape,
                   out_shape,
                   device,
                   cprefix=None,
                   no_weights=False):
    """Generate a main network instance.

    A helper to generate a main network according to the given user
    configuration.

    .. note::
        Generation of networks with context-modulation is not yet supported,
        since there is no global argument set in :mod:`utils.cli_args` yet.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available according
                to the function :func:`utils.cli_args.main_net_args`.
        net_type (str): The type of network. The following options are
            available:
            
            - ``mlp``: :class:`mnets.mlp.MLP`
            - ``resnet``: :class:`mnets.resnet.ResNet`
            - ``zenke``: :class:`mnets.zenkenet.ZenkeNet`
            - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet`
        in_shape (list): Shape of network inputs. Can be ``None`` if not
            required by network type.

            For instance: For an MLP network :class:`mnets.mlp.MLP` with 100
            input neurons it should be :code:`in_shape=[100]`.
        out_shape (list): Shape of network outputs. See ``in_shape`` for more
            details.
        device: PyTorch device.
        cprefix (str, optional): A prefix of the config names. It might be that
            the config names used in this method are prefixed, since several
            main networks should be generated (e.g., :code:`cprefix='gen_'` or
            ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.main_net_args`.
        no_weights (bool): Whether the main network should be generated without
            weights.

    Returns:
        The created main network model.
    """
    assert (net_type in ['mlp', 'resnet', 'zenke', 'bio_conv_net'])

    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    def hc(name):
        """Check whether config exists."""
        return hasattr(config, '%s%s' % (cprefix, name))

    mnet = None

    if hc('net_act'):
        net_act = gc('net_act')
        net_act = misc.str_to_act(net_act)
    else:
        net_act = None

    def get_val(name):
        ret = None
        if hc(name):
            ret = gc(name)
        return ret

    no_bias = get_val('no_bias')
    dropout_rate = get_val('dropout_rate')
    specnorm = get_val('specnorm')
    batchnorm = get_val('batchnorm')
    no_batchnorm = get_val('no_batchnorm')
    bn_no_running_stats = get_val('bn_no_running_stats')
    bn_distill_stats = get_val('bn_distill_stats')
    #bn_no_stats_checkpointing = get_val('bn_no_stats_checkpointing')

    use_bn = None
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm

    # FIXME if an argument wasn't specified, then we use the default value that
    # is currently (at time of implementation) in the constructor.
    assign = lambda x, y: y if x is None else x

    if net_type == 'mlp':
        assert (hc('mlp_arch'))
        assert (len(in_shape) == 1 and len(out_shape) == 1)

        mnet = MLP(
            n_in=in_shape[0],
            n_out=out_shape[0],
            hidden_layers=misc.str_to_ints(gc('mlp_arch')),
            activation_fn=assign(net_act, torch.nn.ReLU()),
            use_bias=not assign(no_bias, False),
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, -1),
            use_spectral_norm=assign(specnorm, False),
            use_batch_norm=assign(use_bn, False),
            bn_track_stats=assign(not bn_no_running_stats, True),
            distill_bn_stats=assign(bn_distill_stats, False),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #out_fn=None,
            verbose=True).to(device)

    elif net_type == 'resnet':
        assert (len(out_shape) == 1)

        mnet = ResNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            verbose=True,  #n=5,
            no_weights=no_weights,
            #init_weights=None,
            use_batch_norm=assign(use_bn, True),
            bn_track_stats=assign(not bn_no_running_stats, True),
            distill_bn_stats=assign(bn_distill_stats, False),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
        ).to(device)

    elif net_type == 'zenke':
        assert (len(out_shape) == 1)

        mnet = ZenkeNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            verbose=True,  #arch='cifar',
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, 0.25)).to(device)
    else:
        assert (net_type == 'bio_conv_net')
        assert (len(out_shape) == 1)

        raise NotImplementedError('Implementation not publicly available!')

    return mnet
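
# Usage sketch (illustrative, not part of the original module): create a small
# MLP main network. The option names mirror those read via `gc()`/`hc()` above
# and would normally be parsed by `utils.cli_args.main_net_args`; values are
# hypothetical.
import argparse

_example_cfg = argparse.Namespace(
    net_act='relu',
    no_bias=False,
    dropout_rate=-1,
    specnorm=False,
    batchnorm=False,
    mlp_arch='100,100',
)
# mnet = get_mnet_model(_example_cfg, 'mlp', in_shape=[784], out_shape=[10],
#                       device=device)
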
Example 5
def get_mnet_model(config,
                   net_type,
                   in_shape,
                   out_shape,
                   device,
                   cprefix=None,
                   no_weights=False,
                   **mnet_kwargs):
    """Generate a main network instance.

    A helper to generate a main network according to the given user
    configuration.

    .. note::
        Generation of networks with context-modulation is not yet supported,
        since there is no global argument set in :mod:`utils.cli_args` yet.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available according
                to the function :func:`utils.cli_args.main_net_args`.
        net_type (str): The type of network. The following options are
            available:

            - ``mlp``: :class:`mnets.mlp.MLP`
            - ``resnet``: :class:`mnets.resnet.ResNet`
            - ``wrn``: :class:`mnets.wide_resnet.WRN`
            - ``iresnet``: :class:`mnets.resnet_imgnet.ResNetIN`
            - ``zenke``: :class:`mnets.zenkenet.ZenkeNet`
            - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet`
            - ``chunked_mlp``: :class:`mnets.chunk_squeezer.ChunkSqueezer`
            - ``simple_rnn``: :class:`mnets.simple_rnn.SimpleRNN`
            - ``lenet``: :class:`mnets.lenet.LeNet`
        in_shape (list): Shape of network inputs. Can be ``None`` if not
            required by network type.

            For instance: For an MLP network :class:`mnets.mlp.MLP` with 100
            input neurons it should be :code:`in_shape=[100]`.
        out_shape (list): Shape of network outputs. See ``in_shape`` for more
            details.
        device: PyTorch device.
        cprefix (str, optional): A prefix of the config names. It might be that
            the config names used in this method are prefixed, since several
            main networks should be generated (e.g., :code:`cprefix='gen_'` or
            ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.main_net_args`.
        no_weights (bool): Whether the main network should be generated without
            weights.
        **mnet_kwargs: Additional keyword arguments that will be passed to the
            main network constructor.

    Returns:
        The created main network model.
    """
    assert (net_type in [
        'mlp', 'lenet', 'resnet', 'zenke', 'bio_conv_net', 'chunked_mlp',
        'simple_rnn', 'wrn', 'iresnet'
    ])

    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    def hc(name):
        """Check whether config exists."""
        return hasattr(config, '%s%s' % (cprefix, name))

    mnet = None

    if hc('net_act'):
        net_act = gc('net_act')
        net_act = misc.str_to_act(net_act)
    else:
        net_act = None

    def get_val(name):
        ret = None
        if hc(name):
            ret = gc(name)
        return ret

    no_bias = get_val('no_bias')
    dropout_rate = get_val('dropout_rate')
    specnorm = get_val('specnorm')
    batchnorm = get_val('batchnorm')
    no_batchnorm = get_val('no_batchnorm')
    bn_no_running_stats = get_val('bn_no_running_stats')
    bn_distill_stats = get_val('bn_distill_stats')
    # This argument has to be handled during usage of the network and not during
    # construction.
    #bn_no_stats_checkpointing = get_val('bn_no_stats_checkpointing')

    use_bn = None
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm

    # If an argument wasn't specified, then we use the default value that
    # is currently in the constructor.
    assign = lambda x, y: y if x is None else x

    if net_type == 'mlp':
        assert (hc('mlp_arch'))
        assert (len(in_shape) == 1 and len(out_shape) == 1)

        # Default keyword arguments of class MLP.
        dkws = misc.get_default_args(MLP.__init__)

        mnet = MLP(
            n_in=in_shape[0],
            n_out=out_shape[0],
            hidden_layers=misc.str_to_ints(gc('mlp_arch')),
            activation_fn=assign(net_act, dkws['activation_fn']),
            use_bias=assign(not no_bias, dkws['use_bias']),
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']),
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #out_fn=None,
            verbose=True,
            **mnet_kwargs).to(device)

    elif net_type == 'resnet':
        assert (len(out_shape) == 1)
        assert hc('resnet_block_depth') and hc('resnet_channel_sizes')

        # Default keyword arguments of class ResNet.
        dkws = misc.get_default_args(ResNet.__init__)

        mnet = ResNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            n=gc('resnet_block_depth'),
            use_bias=assign(not no_bias, dkws['use_bias']),
            num_feature_maps=misc.str_to_ints(gc('resnet_channel_sizes')),
            verbose=True,  #n=5,
            no_weights=no_weights,
            #init_weights=None,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
            **mnet_kwargs).to(device)

    elif net_type == 'wrn':
        assert (len(out_shape) == 1)
        assert hc('wrn_block_depth') and hc('wrn_widening_factor')

        # Default keyword arguments of class WRN.
        dkws = misc.get_default_args(WRN.__init__)

        mnet = WRN(
            in_shape=in_shape,
            num_classes=out_shape[0],
            n=gc('wrn_block_depth'),
            use_bias=assign(not no_bias, dkws['use_bias']),
            #num_feature_maps=misc.str_to_ints(gc('wrn_channel_sizes')),
            verbose=True,
            no_weights=no_weights,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            k=gc('wrn_widening_factor'),
            use_fc_bias=gc('wrn_use_fc_bias'),
            dropout_rate=gc('dropout_rate'),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
            **mnet_kwargs).to(device)

    elif net_type == 'iresnet':
        assert (len(out_shape) == 1)
        assert hc('iresnet_use_fc_bias') and hc('iresnet_channel_sizes') \
            and hc('iresnet_blocks_per_group') \
            and hc('iresnet_bottleneck_blocks') \
            and hc('iresnet_projection_shortcut')

        # Default keyword arguments of class ResNetIN.
        dkws = misc.get_default_args(ResNetIN.__init__)

        mnet = ResNetIN(
            in_shape=in_shape,
            num_classes=out_shape[0],
            use_bias=assign(not no_bias, dkws['use_bias']),
            use_fc_bias=gc('iresnet_use_fc_bias'),
            num_feature_maps=misc.str_to_ints(gc('iresnet_channel_sizes')),
            blocks_per_group=misc.str_to_ints(gc('iresnet_blocks_per_group')),
            projection_shortcut=gc('iresnet_projection_shortcut'),
            bottleneck_blocks=gc('iresnet_bottleneck_blocks'),
            #cutout_mod=False,
            no_weights=no_weights,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            #chw_input_format=False,
            verbose=True,
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
            **mnet_kwargs).to(device)

    elif net_type == 'zenke':
        assert (len(out_shape) == 1)

        # Default keyword arguments of class ZenkeNet.
        dkws = misc.get_default_args(ZenkeNet.__init__)

        mnet = ZenkeNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            verbose=True,  #arch='cifar',
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            **mnet_kwargs).to(device)

    elif net_type == 'bio_conv_net':
        assert (len(out_shape) == 1)

        # Default keyword arguments of class BioConvNet.
        #dkws = misc.get_default_args(BioConvNet.__init__)

        mnet = BioConvNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            no_weights=no_weights,
            #init_weights=None,
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
            **mnet_kwargs).to(device)

    elif net_type == 'chunked_mlp':
        assert hc('cmlp_arch') and hc('cmlp_chunk_arch') and \
               hc('cmlp_in_cdim') and hc('cmlp_out_cdim') and \
               hc('cmlp_cemb_dim')
        assert len(in_shape) == 1 and len(out_shape) == 1

        # Default keyword arguments of class ChunkSqueezer.
        dkws = misc.get_default_args(ChunkSqueezer.__init__)

        mnet = ChunkSqueezer(
            n_in=in_shape[0],
            n_out=out_shape[0],
            inp_chunk_dim=gc('cmlp_in_cdim'),
            out_chunk_dim=gc('cmlp_out_cdim'),
            cemb_size=gc('cmlp_cemb_dim'),
            #cemb_init_std=1.,
            red_layers=misc.str_to_ints(gc('cmlp_chunk_arch')),
            net_layers=misc.str_to_ints(gc('cmlp_arch')),
            activation_fn=assign(net_act, dkws['activation_fn']),
            use_bias=assign(not no_bias, dkws['use_bias']),
            #dynamic_biases=None,
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']),
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            verbose=True,
            **mnet_kwargs).to(device)

    elif net_type == 'lenet':
        assert hc('lenet_type')
        assert len(out_shape) == 1

        # Default keyword arguments of class LeNet.
        dkws = misc.get_default_args(LeNet.__init__)

        mnet = LeNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            verbose=True,
            arch=gc('lenet_type'),
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            # TODO Context-mod weights.
            **mnet_kwargs).to(device)

    else:
        assert (net_type == 'simple_rnn')
        assert hc('srnn_rec_layers') and hc('srnn_pre_fc_layers') and \
            hc('srnn_post_fc_layers')  and hc('srnn_no_fc_out') and \
            hc('srnn_rec_type')
        assert len(in_shape) == 1 and len(out_shape) == 1

        if gc('srnn_rec_type') == 'lstm':
            use_lstm = True
        else:
            assert gc('srnn_rec_type') == 'elman'
            use_lstm = False

        # Default keyword arguments of class SimpleRNN.
        dkws = misc.get_default_args(SimpleRNN.__init__)

        rnn_layers = misc.str_to_ints(gc('srnn_rec_layers'))
        fc_layers = misc.str_to_ints(gc('srnn_post_fc_layers'))
        if gc('srnn_no_fc_out'):
            rnn_layers.append(out_shape[0])
        else:
            fc_layers.append(out_shape[0])

        mnet = SimpleRNN(n_in=in_shape[0],
                         rnn_layers=rnn_layers,
                         fc_layers_pre=misc.str_to_ints(
                             gc('srnn_pre_fc_layers')),
                         fc_layers=fc_layers,
                         activation=assign(net_act, dkws['activation']),
                         use_lstm=use_lstm,
                         use_bias=assign(not no_bias, dkws['use_bias']),
                         no_weights=no_weights,
                         verbose=True,
                         **mnet_kwargs).to(device)

    return mnet
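
# Usage sketch (illustrative, not part of the original module): create a
# recurrent main network via the 'simple_rnn' branch. The option names mirror
# those read via `gc()`/`hc()` above; values are hypothetical.
import argparse

_example_cfg = argparse.Namespace(
    net_act='tanh',
    no_bias=False,
    srnn_rec_layers='32',
    srnn_pre_fc_layers='',
    srnn_post_fc_layers='',
    srnn_no_fc_out=False,
    srnn_rec_type='lstm',
)
# mnet = get_mnet_model(_example_cfg, 'simple_rnn', in_shape=[10],
#                       out_shape=[2], device=device)
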
Example 6
def _generate_networks(config,
                       data_handlers,
                       device,
                       create_hnet=True,
                       create_rnet=False,
                       no_replay=False):
    """Create the main-net, hypernetwork and recognition network.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network and to
            infer the number of tasks.
        device: Torch device.
        create_hnet: Whether a hypernetwork should be constructed. If not, the
            main network will have trainable weights.
        create_rnet: Whether a task-recognition autoencoder should be created.
        no_replay: Whether the recognition network should be an instance of
            class MainModel rather than of class RecognitionNet (note that for
            multitask learning, no replay network is required).

    Returns:
        mnet: Main network instance.
        hnet: Hypernetwork instance. This return value is None if no
            hypernetwork should be constructed.
        rnet: RecognitionNet instance. This return value is None if no
            recognition network should be constructed.
    """
    num_tasks = len(data_handlers)

    n_x = data_handlers[0].in_shape[0]
    n_y = data_handlers[0].out_shape[0]
    if config.multi_head:
        n_y = n_y * num_tasks

    main_arch = misc.str_to_ints(config.main_arch)
    main_shapes = MainNetwork.weight_shapes(n_in=n_x,
                                            n_out=n_y,
                                            hidden_layers=main_arch)
    mnet = MainNetwork(main_shapes,
                       activation_fn=misc.str_to_act(config.main_act),
                       use_bias=True,
                       no_weights=create_hnet).to(device)
    if create_hnet:
        hnet_arch = misc.str_to_ints(config.hnet_arch)
        hnet = HyperNetwork(main_shapes,
                            num_tasks,
                            layers=hnet_arch,
                            te_dim=config.emb_size,
                            activation_fn=misc.str_to_act(
                                config.hnet_act)).to(device)
        init_params = list(hnet.parameters())
    else:
        hnet = None
        init_params = list(mnet.parameters())

    if create_rnet:
        ae_arch = misc.str_to_ints(config.ae_arch)
        if no_replay:
            rnet_shapes = MainNetwork.weight_shapes(n_in=n_x,
                                                    n_out=num_tasks,
                                                    hidden_layers=ae_arch,
                                                    use_bias=True)
            rnet = MainNetwork(rnet_shapes,
                               activation_fn=misc.str_to_act(config.ae_act),
                               use_bias=True,
                               no_weights=False,
                               dropout_rate=-1,
                               out_fn=lambda x: F.softmax(x, dim=1))
        else:
            rnet = RecognitionNet(n_x,
                                  num_tasks,
                                  dim_z=config.ae_dim_z,
                                  enc_layers=ae_arch,
                                  activation_fn=misc.str_to_act(config.ae_act),
                                  use_bias=True).to(device)
        init_params += list(rnet.parameters())
    else:
        rnet = None

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        elif config.normal_init:
            torch.nn.init.normal_(W, mean=0, std=config.std_normal_init)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if create_hnet:
        for temb in hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

    if create_hnet and config.use_hyperfan_init:
        hnet.apply_hyperfan_init(temb_var=config.std_normal_temb**2)

    return mnet, hnet, rnet
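
# Usage sketch (illustrative, not part of the original module): the attribute
# names mirror the config fields read above; values are hypothetical and
# `data_handlers` would come from the experiment setup.
import argparse

_example_cfg = argparse.Namespace(
    multi_head=True, main_arch='100,100', main_act='relu',
    hnet_arch='50,50', emb_size=8, hnet_act='relu',
    normal_init=False, std_normal_init=0.02, std_normal_temb=1.0,
    use_hyperfan_init=False,
)
# mnet, hnet, rnet = _generate_networks(_example_cfg, data_handlers, device,
#                                       create_hnet=True, create_rnet=False)
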
Example 7
def get_hnet_model(config, num_tasks, device, mnet_shapes, cprefix=None,
                   no_weights=False, no_tembs=False, temb_size=None):
    """Generate a hypernetwork instance.

    A helper to generate the hypernetwork according to the given user
    configuration.

    .. deprecated:: 1.0
        Please use function :func:`get_hypernet` instead, as this function
        creates deprecated hypernetworks.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available according
                to the function :func:`utils.cli_args.hypernet_args`.
        num_tasks (int): The number of task embeddings the hypernetwork should
            have.
        device: PyTorch device.
        mnet_shapes: Dimensions of the weight tensors of the main network.
            See main net argument
            :attr:`mnets.mnet_interface.MainNetInterface.param_shapes`.
        cprefix (str, optional): A prefix of the config names. It might be that
            the config names used in this method are prefixed, since several
            hypernetworks should be generated (e.g., :code:`cprefix='gen_'` or
            ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.hypernet_args`.
        no_weights (bool): Whether the hypernetwork should be generated without
            internal weights (excluding task embeddings).
        no_tembs (bool): Whether the hypernetwork should be generated without
            internally maintained task embeddings.
        temb_size (int, optional): If the user config should be overridden,
            this option can be used to specify the dimensionality of task
            embeddings.

    Returns:
        The created hypernet model.
    """
    warn('Please use function "utils.sim_utils.get_hypernet" instead, as ' +
         'this function creates deprecated hypernetworks.',
         DeprecationWarning)

    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    hyper_chunks = misc.str_to_ints(gc('hyper_chunks'))
    assert(len(hyper_chunks) in [1,2,3])
    if len(hyper_chunks) == 1:
        hyper_chunks = hyper_chunks[0]

    hnet_arch = misc.str_to_ints(gc('hnet_arch'))
    sa_hnet_filters = misc.str_to_ints(gc('sa_hnet_filters'))
    sa_hnet_kernels = misc.str_to_ints(gc('sa_hnet_kernels'))
    sa_hnet_attention_layers = misc.str_to_ints(gc('sa_hnet_attention_layers'))

    hnet_act = misc.str_to_act(gc('hnet_act'))

    if temb_size is None:
        temb_size = gc('temb_size')

    if isinstance(hyper_chunks, list): # Chunked self-attention hypernet
        raise NotImplementedError('Not publicly available')

    elif hyper_chunks != -1: # Chunked fully-connected hypernet
        hnet = ChunkedHyperNetworkHandler(mnet_shapes, num_tasks,
            chunk_dim=hyper_chunks, layers=hnet_arch,
            activation_fn=hnet_act, te_dim=temb_size, no_te_embs=no_tembs,
            ce_dim=gc('emb_size'), dropout_rate=gc('hnet_dropout_rate'),
            noise_dim=gc('hnet_noise_dim'), no_weights=no_weights,
            temb_std=gc('temb_std')).to(device)

    else: # Fully-connected hypernet.
        hnet = HyperNetwork(mnet_shapes, num_tasks, layers=hnet_arch,
            te_dim=temb_size, no_te_embs=no_tembs, activation_fn=hnet_act,
            dropout_rate=gc('hnet_dropout_rate'),
            noise_dim=gc('hnet_noise_dim'), no_weights=no_weights,
            temb_std=gc('temb_std')).to(device)

    return hnet
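
# Usage sketch (illustrative, not part of the original module): compared to
# the earlier variant, this deprecated helper also allows creating the
# hypernetwork without internal weights or task embeddings and overriding the
# task-embedding size from the config. Option names mirror those read via
# `gc()`; values are hypothetical.
import argparse

_example_cfg = argparse.Namespace(
    hyper_chunks='1000',     # A single positive value -> chunked hypernetwork.
    hnet_arch='50,50', sa_hnet_filters='32', sa_hnet_kernels='5',
    sa_hnet_attention_layers='1', hnet_act='relu', temb_size=8, emb_size=8,
    hnet_dropout_rate=-1, hnet_noise_dim=-1, temb_std=-1,
)
# hnet = get_hnet_model(_example_cfg, num_tasks=5, device=device,
#                       mnet_shapes=[[10, 784], [10]], no_weights=True,
#                       temb_size=16)
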
Example 8
def generate_replay_networks(config,
                             data_handlers,
                             device,
                             create_rp_hnet=True,
                             only_train_replay=False):
    """Create a replay model that consists of either a encoder/decoder or
    a discriminator/generator pair. Additionally, this method manages the 
    creation of a hypernetwork for the generator/decoder. 
    Following important configurations will be determined in order to create
    the replay model: 
    * in- and output and hidden layer dimensions of the encoder/decoder. 
    * architecture, chunk- and task-embedding details of decoder's hypernetwork. 

    .. note::
        This module also handles the initialisation of the weights of either 
        the classifier or its hypernetwork. This will change in the near future.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network and to
            infer the number of tasks.
        device: Torch device.
        create_rp_hnet: Whether a hypernetwork for the replay should be 
            constructed. If not, the decoder/generator will have 
            trainable weights on its own.
        only_train_replay: We normally do not train on the last task since we
            do not need to replay this last task's data. But if we want a
            replay method to be able to generate data from all tasks, then we
            set this option to ``True``.

    Returns:
        (tuple): Tuple containing:

        - **enc**: Encoder/discriminator network instance.
        - **dec**: Decoder/generator network instance.
        - **dec_hnet**: Hypernetwork instance for the decoder/generator. This 
            return value is None if no hypernetwork should be constructed.
    """

    if config.replay_method == 'gan':
        n_out = 1
    else:
        n_out = config.latent_dim * 2

    n_in = data_handlers[0].in_shape[0]
    pd = config.padding * 2
    if config.experiment == "splitMNIST":
        n_in = n_in * n_in
    else:  # permutedMNIST
        n_in = (n_in + pd) * (n_in + pd)

    config.input_dim = n_in
    if config.experiment == "splitMNIST":
        if config.single_class_replay:
            config.out_dim = 1
        else:
            config.out_dim = 2
    else:  # permutedMNIST
        config.out_dim = 10

    if config.infer_task_id:
        # task inference network
        config.out_dim = 1

    # Build the encoder.
    print('For the replay encoder/discriminator: ')
    enc_arch = misc.str_to_ints(config.enc_fc_arch)
    enc = MLP(n_in=n_in,
              n_out=n_out,
              hidden_layers=enc_arch,
              activation_fn=misc.str_to_act(config.enc_net_act),
              dropout_rate=config.enc_dropout_rate,
              no_weights=False).to(device)
    print('Constructed MLP with shapes: ', enc.param_shapes)
    init_params = list(enc.parameters())
    # Build the decoder.
    print('For the replay decoder/generator: ')
    dec_arch = misc.str_to_ints(config.dec_fc_arch)
    # add dimensions for conditional input
    n_out = config.latent_dim

    if config.conditional_replay:
        n_out += config.conditional_dim

    dec = MLP(n_in=n_out,
              n_out=n_in,
              hidden_layers=dec_arch,
              activation_fn=misc.str_to_act(config.dec_net_act),
              use_bias=True,
              no_weights=config.rp_beta > 0,
              dropout_rate=config.dec_dropout_rate).to(device)

    print('Constructed MLP with shapes: ', dec.param_shapes)
    config.num_weights_enc = \
                        MainNetInterface.shapes_to_num_weights(enc.param_shapes)

    config.num_weights_dec = \
                        MainNetInterface.shapes_to_num_weights(dec.param_shapes)
    config.num_weights_rp_net = config.num_weights_enc + config.num_weights_dec
    # we do not need a replay model for the last task

    # train on last task or not
    if only_train_replay:
        subtr = 0
    else:
        subtr = 1

    num_embeddings = config.num_tasks - subtr if config.num_tasks > 1 else 1

    if config.single_class_replay:
        # we do not need a replay model for the last task
        if config.num_tasks > 1:
            num_embeddings = config.out_dim * (config.num_tasks - subtr)
        else:
            num_embeddings = config.out_dim * (config.num_tasks)

    config.num_embeddings = num_embeddings
    # build decoder hnet
    if create_rp_hnet:
        print('For the decoder/generator hypernetwork: ')
        d_hnet = sim_utils.get_hnet_model(config,
                                          num_embeddings,
                                          device,
                                          dec.hyper_shapes_learned,
                                          cprefix='rp_')

        init_params += list(d_hnet.parameters())

        config.num_weights_rp_hyper_net = sum(p.numel()
                                              for p in d_hnet.parameters()
                                              if p.requires_grad)
        config.compression_ratio_rp = config.num_weights_rp_hyper_net / \
                                                        config.num_weights_dec
        print('Created replay hypernetwork with ratio: ',
              config.compression_ratio_rp)
        if config.compression_ratio_rp > 1:
            print('Note that the compression ratio is computed compared to ' +
                  'the current target network,\nthis might not be directly ' +
                  'comparable with the number of parameters of methods we ' +
                  'compare against.')
    else:
        num_embeddings = config.num_tasks - subtr
        d_hnet = None
        init_params += list(dec.parameters())
        config.num_weights_rp_hyper_net = 0
        config.compression_ratio_rp = 0

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if create_rp_hnet:
        for temb in d_hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

    if hasattr(d_hnet, 'chunk_embeddings'):
        for emb in d_hnet.chunk_embeddings:
            torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    return enc, dec, d_hnet
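
# Usage sketch (illustrative, not part of the original module): build a
# VAE-style replay model whose decoder keeps its own weights (no decoder
# hypernetwork). The attribute names mirror the config fields read above;
# values are hypothetical and `data_handlers` would come from the experiment
# setup.
import argparse

_example_cfg = argparse.Namespace(
    replay_method='vae', latent_dim=100, padding=0, experiment='splitMNIST',
    single_class_replay=False, infer_task_id=False,
    enc_fc_arch='400,400', enc_net_act='relu', enc_dropout_rate=-1,
    dec_fc_arch='400,400', dec_net_act='relu', dec_dropout_rate=-1,
    conditional_replay=True, conditional_dim=10, rp_beta=0.0, num_tasks=5,
)
# enc, dec, d_hnet = generate_replay_networks(_example_cfg, data_handlers,
#                                             device, create_rp_hnet=False)
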
Example 9
def generate_gauss_networks(config,
                            logger,
                            data_handlers,
                            device,
                            no_mnet_weights=None,
                            create_hnet=True,
                            in_shape=None,
                            out_shape=None,
                            net_type='mlp',
                            non_gaussian=False):
    """Create main network and potentially the corresponding hypernetwork.

    The function will first create a normal MLP and then convert it into a
    network with Gaussian weight distribution by using the wrapper
    :class:`probabilistic.gauss_mnet_interface.GaussianBNNWrapper`.

    This function also takes care of weight initialization.

    Args:
        config: Command-line arguments.
        logger: Console (and file) logger.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network and to
            infer the number of tasks.
        device: Torch device.
        no_mnet_weights (bool, optional): Whether the main network should not
            have trainable weights. If left unspecified, then the main network
            will only have trainable weights if ``create_hnet`` is ``False``.
        create_hnet (bool): Whether a hypernetwork should be constructed.
        in_shape (list, optional): Input shape that is passed to function
            :func:`utils.sim_utils.get_mnet_model` as argument ``in_shape``.
            If not specified, it is set to ``[data_handlers[0].in_shape[0]]``.
        out_shape (list, optional): Output shape that is passed to function
            :func:`utils.sim_utils.get_mnet_model` as argument ``out_shape``.
            If not specified, it is set to ``[data_handlers[0].out_shape[0]]``.
        net_type (str): See argument ``net_type`` of function
            :func:`utils.sim_utils.get_mnet_model`.
        non_gaussian (bool): If ``True``, then the main network will not be
            converted into a Gaussian network. Hence, networks remain
            deterministic.

    Returns:
        (tuple): Tuple containing:

        - **mnet**: Main network instance.
        - **hnet** (optional): Hypernetwork instance. This return value is
          ``None`` if no hypernetwork should be constructed.
    """
    assert not hasattr(config, 'mean_only') or config.mean_only == non_gaussian
    assert not non_gaussian or not config.local_reparam_trick
    assert not non_gaussian or not config.hyper_gauss_init

    num_tasks = len(data_handlers)

    # Should be set, except for regression.
    if in_shape is None or out_shape is None:
        assert in_shape is None and out_shape is None
        assert net_type == 'mlp'
        assert hasattr(config, 'multi_head')

        n_x = data_handlers[0].in_shape[0]
        n_y = data_handlers[0].out_shape[0]
        if config.multi_head:
            n_y = n_y * num_tasks

        in_shape = [n_x]
        out_shape = [n_y]

    ### Main network.
    logger.info('Creating main network ...')

    if no_mnet_weights is None:
        no_mnet_weights = create_hnet

    if config.local_reparam_trick:
        if net_type != 'mlp':
            raise NotImplementedError('The local reparametrization trick is ' +
                                      'only implemented for MLPs so far!')
        assert len(in_shape) == 1 and len(out_shape) == 1

        mlp_arch = utils.str_to_ints(config.mlp_arch)
        net_act = utils.str_to_act(config.net_act)
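        # With the local reparametrization trick (Kingma et al., 2015), the
        # Gaussian weight posterior is sampled in the space of layer
        # pre-activations rather than in weight space, which reduces gradient
        # variance; this branch therefore uses the dedicated ``GaussianMLP``.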
        mnet = GaussianMLP(n_in=in_shape[0],
                           n_out=out_shape[0],
                           hidden_layers=mlp_arch,
                           activation_fn=net_act,
                           use_bias=not config.no_bias,
                           no_weights=no_mnet_weights).to(device)
    else:
        mnet_kwargs = {}
        if net_type == 'iresnet':
            mnet_kwargs['cutout_mod'] = True
        mnet = sutils.get_mnet_model(config,
                                     net_type,
                                     in_shape,
                                     out_shape,
                                     device,
                                     no_weights=no_mnet_weights,
                                     **mnet_kwargs)

    # Initialize main net weights, if any.
    assert (not hasattr(config, 'custom_network_init'))
    mnet.custom_init(normal_init=config.normal_init,
                     normal_std=config.std_normal_init,
                     zero_bias=True)

    # Convert main net into Gaussian BNN.
    orig_mnet = mnet
    if not non_gaussian:
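        # The wrapper maintains a mean and a rho (or log-variance) tensor for
        # every original weight tensor, i.e., the set of parameter shapes
        # effectively doubles compared to the deterministic network.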
        mnet = GaussianBNNWrapper(mnet,
                                  no_mean_reinit=config.keep_orig_init,
                                  logvar_encoding=config.use_logvar_enc,
                                  apply_rho_offset=True,
                                  is_radial=config.radial_bnn).to(device)
    else:
        logger.debug('The created main network will not be converted into ' +
                     'a Gaussian main network.')

    ### Hypernet.
    hnet = None
    if create_hnet:
        logger.info('Creating hypernetwork ...')

        chunk_shapes, num_per_chunk, assembly_fct = None, None, None
        if config.hnet_type == 'structured_hmlp':
            if net_type == 'resnet':
                chunk_shapes, num_per_chunk, orig_assembly_fct = \
                    resnet_chunking(orig_mnet,
                                    gcd_chunking=config.shmlp_gcd_chunking)
            elif net_type == 'wrn':
                chunk_shapes, num_per_chunk, orig_assembly_fct = \
                    wrn_chunking(orig_mnet,
                        gcd_chunking=config.shmlp_gcd_chunking,
                        ignore_bn_weights=False, ignore_out_weights=False)
            else:
                raise NotImplementedError(
                    '"structured_hmlp" not implemented ' +
                    'for network of type %s.' % net_type)

            if non_gaussian:
                assembly_fct = orig_assembly_fct
            else:
                chunk_shapes = chunk_shapes + chunk_shapes
                num_per_chunk = num_per_chunk + num_per_chunk
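                # Since the Gaussian wrapper holds a mean and a rho tensor per
                # original weight tensor, the hypernet has to produce twice as
                # many chunks: the first half is assembled into the means, the
                # second half into the rhos (see ``assembly_fct_gauss`` below).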

                def assembly_fct_gauss(list_of_chunks):
                    n = len(list_of_chunks)
                    mean_chunks = list_of_chunks[:n // 2]
                    rho_chunks = list_of_chunks[n // 2:]

                    return orig_assembly_fct(mean_chunks) + \
                        orig_assembly_fct(rho_chunks)

                assembly_fct = assembly_fct_gauss

        # For now, we either produce all or no weights with the hypernet.
        # Note, the mnet may have been created with its own internal weights.
        assert mnet.hyper_shapes_learned is None or \
            len(mnet.param_shapes) == len(mnet.hyper_shapes_learned)

        hnet = sutils.get_hypernet(config,
                                   device,
                                   config.hnet_type,
                                   mnet.param_shapes,
                                   num_tasks,
                                   shmlp_chunk_shapes=chunk_shapes,
                                   shmlp_num_per_chunk=num_per_chunk,
                                   shmlp_assembly_fct=assembly_fct)

        if config.hnet_out_masking != 0:
            logger.info('Generating binary masks to select task-specific ' +
                        'subnetworks from hypernetwork.')
            # Add a wrapper around the hypernetwork that masks its outputs
            # using a task-specific binary mask per layer. Note that output
            # weights are not masked.

            # Use a fixed seed so that the masks are deterministic for a given
            # hyperparameter config/task.
            mask_gen = torch.Generator()
            mask_gen = mask_gen.manual_seed(42)

            # Generate a random binary mask per task.
            assert len(mnet.param_shapes) == len(hnet.target_shapes)
            hnet_out_masks = []
            for tid in range(config.num_tasks):
                hnet_out_mask = []
                for layer_shapes, is_output in zip(mnet.param_shapes, \
                        mnet.get_output_weight_mask()):
                    layer_mask = torch.ones(layer_shapes)
                    if is_output is None:
                        # We only mask weights that are not output weights.
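                        # Entries above the threshold are kept (set to 1), the
                        # rest are zeroed, so on average a fraction
                        # ``hnet_out_masking`` of these weights is masked out.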
                        layer_mask = torch.rand(layer_shapes,
                                                generator=mask_gen)
                        layer_mask[layer_mask > config.hnet_out_masking] = 1
                        layer_mask[layer_mask <= config.hnet_out_masking] = 0
                    hnet_out_mask.append(layer_mask)
                hnet_out_masks.append(hnet_out_mask)

            hnet_out_masks = hnet.convert_out_format(hnet_out_masks,
                                                     'sequential', 'flattened')
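            # In the 'flattened' format, the masks form one row per condition
            # (task) ID, which allows the output handler below to index them
            # via ``cond_id``.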

            def hnet_out_masking_func(hnet_out_int,
                                      uncond_input=None,
                                      cond_input=None,
                                      cond_id=None):
                assert isinstance(cond_id, (int, list))
                if isinstance(cond_id, int):
                    cond_id = [cond_id]

                hnet_out_int[hnet_out_masks[cond_id, :] == 0] = 0
                return hnet_out_int

            def hnet_inp_handler(uncond_input=None,
                                 cond_input=None,
                                 cond_id=None):  # Identity
                return uncond_input, cond_input, cond_id

            hnet = HPerturbWrapper(hnet,
                                   output_handler=hnet_out_masking_func,
                                   input_handler=hnet_inp_handler)

        #if config.hnet_type == 'structured_hmlp':
        #    print(num_per_chunk)
        #    for ii, int_hnet in enumerate(hnet.internal_hnets):
        #        print('   Internal hnet %d with %d outputs.' % \
        #              (ii, int_hnet.num_outputs))

        ### Initialize hypernetwork.
        if not config.hyper_gauss_init:
            apply_custom_hnet_init(config, logger, hnet)
        else:
            # Initialize task embeddings, if any.
            hnet_helpers.init_conditional_embeddings(
                hnet, normal_std=config.std_normal_temb)

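            # Hyperfan init (cf. Chang et al., 2020, "Principled Weight
            # Initialization for Hypernetworks") scales the hypernet so that
            # the main-net weights it generates have a fan-in-like variance at
            # initialization; ``gauss_hyperfan_init`` is presumably the
            # Gaussian-posterior variant of this scheme used by this repo.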
            gauss_hyperfan_init(hnet,
                                mnet=mnet,
                                use_xavier=True,
                                cond_var=config.std_normal_temb**2,
                                keep_hyperfan_mean=config.keep_orig_init)

    return mnet, hnet
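For reference, a self-contained sketch of the per-task binary-mask construction performed above when ``hnet_out_masking`` is non-zero; all names and values below are illustrative placeholders rather than the repo's configuration:

import torch

num_tasks = 3
param_shapes = [[4, 5], [5]]   # stand-in for mnet.param_shapes
masking_prob = 0.2             # stand-in for config.hnet_out_masking

# Fixed seed -> identical masks for a given config, as in the function above.
mask_gen = torch.Generator()
mask_gen.manual_seed(42)

masks = []
for _ in range(num_tasks):
    task_mask = []
    for shape in param_shapes:
        m = torch.rand(shape, generator=mask_gen)
        # Keep entries above the threshold, zero out the rest.
        task_mask.append((m > masking_prob).float())
    masks.append(task_mask)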