Example 1
def generate_classifier(config, data_handlers, device):
    """Create a classifier network. Depending on the experiment and method, 
    the method manages to build either a classifier for task inference 
    or a classifier that solves our task is build. This also implies if the
    network will receive weights from a hypernetwork or will have weights 
    on its own.
    Following important configurations will be determined in order to create
    the classifier: \n 
    * in- and output and hidden layer dimensions of the classifier. \n
    * architecture, chunk- and task-embedding details of the hypernetwork. 


    See :class:`mnets.mlp.MLP` for details on the network that will be created
        to be a classifier. 

    .. note::
        This module also handles the initialisation of the weights of either 
        the classifier or its hypernetwork. This will change in the near future.
        
    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network and to
            infer the number of tasks.
        device: Torch device.
    
    Returns:
        (tuple): Tuple containing:

        - **net**: The classifier network.
        - **class_hnet**: (optional) The classifier's hypernetwork.
    """
    n_in = data_handlers[0].in_shape[0]
    pd = config.padding * 2

    if config.experiment == "splitMNIST":
        n_in = n_in * n_in
    else:  # permutedMNIST
        n_in = (n_in + pd) * (n_in + pd)

    config.input_dim = n_in
    if config.experiment == "splitMNIST":
        if config.class_incremental:
            config.out_dim = 1
        else:
            config.out_dim = 2
    else:  # permutedMNIST
        config.out_dim = 10

    if config.training_task_infer or config.class_incremental:
        # task inference network
        config.out_dim = 1

    # Build up all output neurons at once, except for CL scenario 2.
    if config.cl_scenario != 2:
        n_out = config.out_dim * config.num_tasks
    else:
        n_out = config.out_dim

    if config.training_task_infer or config.class_incremental:
        n_out = config.num_tasks

    # Build the classifier.
    print('For the Classifier: ')
    class_arch = misc.str_to_ints(config.class_fc_arch)
    if config.training_with_hnet:
        no_weights = True
    else:
        no_weights = False

    net = MLP(n_in=n_in,
              n_out=n_out,
              hidden_layers=class_arch,
              activation_fn=misc.str_to_act(config.class_net_act),
              dropout_rate=config.class_dropout_rate,
              no_weights=no_weights).to(device)

    print('Constructed MLP with shapes: ', net.param_shapes)

    config.num_weights_class_net = \
        MainNetInterface.shapes_to_num_weights(net.param_shapes)
    # Build the classifier hypernetwork.
    # Note, `config.training_with_hnet` is set in the run method of train.py.
    if config.training_with_hnet:

        class_hnet = sim_utils.get_hnet_model(config,
                                              config.num_tasks,
                                              device,
                                              net.param_shapes,
                                              cprefix='class_')
        init_params = list(class_hnet.parameters())

        config.num_weights_class_hyper_net = sum(
            p.numel() for p in class_hnet.parameters() if p.requires_grad)
        config.compression_ratio_class = config.num_weights_class_hyper_net / \
                                         config.num_weights_class_net
        print('Created classifier Hypernetwork with ratio: ',
              config.compression_ratio_class)
        if config.compression_ratio_class > 1:
            print('Note that the compression ratio is computed with respect ' +
                  'to the current target network, so it might not be directly ' +
                  'comparable with the parameter counts of works we compare ' +
                  'against.')
    else:
        class_hnet = None
        init_params = list(net.parameters())
        config.num_weights_class_hyper_net = None
        config.compression_ratio_class = None

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if config.training_with_hnet:
        for temb in class_hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

    if hasattr(class_hnet, 'chunk_embeddings'):
        for emb in class_hnet.chunk_embeddings:
            torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    if not config.training_with_hnet:
        return net
    else:
        return net, class_hnet
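
# Usage sketch (illustration only, not part of the original module): how
# `generate_classifier` might be invoked. The config attributes below are
# assumptions inferred from the function body; in the real scripts, `config`
# comes from the CLI parser and `data_handlers` from the data utilities. The
# module-level imports of the original file (torch, misc, MLP, ...) are reused.
from types import SimpleNamespace

_cfg = SimpleNamespace(experiment='splitMNIST', padding=0,
                       class_incremental=False, training_task_infer=False,
                       cl_scenario=1, num_tasks=5, class_fc_arch='400,400',
                       class_net_act='relu', class_dropout_rate=-1,
                       training_with_hnet=False)
_dhandlers = [SimpleNamespace(in_shape=[28]) for _ in range(_cfg.num_tasks)]
_net = generate_classifier(_cfg, _dhandlers, torch.device('cpu'))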
Example 2
def get_mnet_model(config,
                   net_type,
                   in_shape,
                   out_shape,
                   device,
                   cprefix=None,
                   no_weights=False):
    """Generate a main network instance.

    A helper to generate a main network according to the given user
    configuration.

    .. note::
        Generation of networks with context-modulation is not yet supported,
        since there is no global argument set in :mod:`utils.cli_args` yet.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available according
                to the function :func:`utils.cli_args.main_net_args`.
        net_type (str): The type of network. The following options are
            available:
            
            - ``mlp``: :class:`mnets.mlp.MLP`
            - ``resnet``: :class:`mnets.resnet.ResNet`
            - ``zenke``: :class:`mnets.zenkenet.ZenkeNet`
            - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet`
        in_shape (list): Shape of network inputs. Can be ``None`` if not
            required by network type.

            For instance: For an MLP network :class:`mnets.mlp.MLP` with 100
            input neurons it should be :code:`in_shape=[100]`.
        out_shape (list): Shape of network outputs. See ``in_shape`` for more
            details.
        device: PyTorch device.
        cprefix (str, optional): A prefix of the config names. It might be that
            the config names used in this method are prefixed, since several
            main networks should be generated (e.g., :code:`cprefix='gen_'` or
            ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.main_net_args`.
        no_weights (bool): Whether the main network should be generated without
            weights.

    Returns:
        The created main network model.
    """
    assert (net_type in ['mlp', 'resnet', 'zenke', 'bio_conv_net'])

    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    def hc(name):
        """Check whether config exists."""
        return hasattr(config, '%s%s' % (cprefix, name))

    mnet = None

    if hc('net_act'):
        net_act = gc('net_act')
        net_act = misc.str_to_act(net_act)
    else:
        net_act = None

    def get_val(name):
        ret = None
        if hc(name):
            ret = gc(name)
        return ret

    no_bias = get_val('no_bias')
    dropout_rate = get_val('dropout_rate')
    specnorm = get_val('specnorm')
    batchnorm = get_val('batchnorm')
    no_batchnorm = get_val('no_batchnorm')
    bn_no_running_stats = get_val('bn_no_running_stats')
    bn_distill_stats = get_val('bn_distill_stats')
    #bn_no_stats_checkpointing = get_val('bn_no_stats_checkpointing')

    use_bn = None
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm

    # FIXME if an argument wasn't specified, then we use the default value that
    # is currently (at time of implementation) in the constructor.
    assign = lambda x, y: y if x is None else x

    if net_type == 'mlp':
        assert (hc('mlp_arch'))
        assert (len(in_shape) == 1 and len(out_shape) == 1)

        mnet = MLP(
            n_in=in_shape[0],
            n_out=out_shape[0],
            hidden_layers=misc.str_to_ints(gc('mlp_arch')),
            activation_fn=assign(net_act, torch.nn.ReLU()),
            use_bias=not assign(no_bias, False),
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, -1),
            use_spectral_norm=assign(specnorm, False),
            use_batch_norm=assign(use_bn, False),
            bn_track_stats=assign(not bn_no_running_stats, True),
            distill_bn_stats=assign(bn_distill_stats, False),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #out_fn=None,
            verbose=True).to(device)

    elif net_type == 'resnet':
        assert (len(out_shape) == 1)

        mnet = ResNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            verbose=True,  #n=5,
            no_weights=no_weights,
            #init_weights=None,
            use_batch_norm=assign(use_bn, True),
            bn_track_stats=assign(not bn_no_running_stats, True),
            distill_bn_stats=assign(bn_distill_stats, False),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
        ).to(device)

    elif net_type == 'zenke':
        assert (len(out_shape) == 1)

        mnet = ZenkeNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            verbose=True,  #arch='cifar',
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, 0.25)).to(device)
    else:
        assert (net_type == 'bio_conv_net')
        assert (len(out_shape) == 1)

        raise NotImplementedError('Implementation not publicly available!')

    return mnet
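
# Usage sketch (illustration only, not part of the original module): building
# an MLP main network through `get_mnet_model`. The config attributes are
# assumptions following the `main_net_args` naming referenced in the docstring;
# module-level imports (torch, misc, MLP, ...) are reused.
from types import SimpleNamespace

_cfg = SimpleNamespace(mlp_arch='100,100', net_act='relu', no_bias=False,
                       dropout_rate=-1, specnorm=False, batchnorm=False)
_mnet = get_mnet_model(_cfg, 'mlp', [784], [10], torch.device('cpu'))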
Example 3
def get_hnet_model(config, num_tasks, device, mnet_shapes, cprefix=None):
    """Generate a hypernetwork instance.

    A helper to generate the hypernetwork according to the given user
    configuration.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available according
                to the function :func:`utils.cli_args.hypernet_args`.
        num_tasks (int): The number of task embeddings the hypernetwork should
            have.
        device: PyTorch device.
        mnet_shapes: Dimensions of the weight tensors of the main network.
            See main net argument
            :attr:`mnets.mnet_interface.MainNetInterface.param_shapes`.
        cprefix (str, optional): A prefix of the config names. It might be that
            the config names used in this method are prefixed, since several
            hypernetworks should be generated (e.g., :code:`cprefix='gen_'` or
            ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.hypernet_args`.

    Returns:
        The created hypernet model.
    """
    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    hyper_chunks = misc.str_to_ints(gc('hyper_chunks'))
    assert (len(hyper_chunks) in [1, 2, 3])
    if len(hyper_chunks) == 1:
        hyper_chunks = hyper_chunks[0]

    hnet_arch = misc.str_to_ints(gc('hnet_arch'))
    sa_hnet_filters = misc.str_to_ints(gc('sa_hnet_filters'))
    sa_hnet_kernels = misc.str_to_ints(gc('sa_hnet_kernels'))
    sa_hnet_attention_layers = misc.str_to_ints(gc('sa_hnet_attention_layers'))

    hnet_act = misc.str_to_act(gc('hnet_act'))

    if isinstance(hyper_chunks, list):  # Chunked self-attention hypernet
        if len(sa_hnet_kernels) == 1:
            sa_hnet_kernels = sa_hnet_kernels[0]
        # Note that the user can specify the kernel size for each dimension and
        # layer separately.
        elif len(sa_hnet_kernels) > 2 and \
            len(sa_hnet_kernels) == gc('sa_hnet_num_layers') * 2:
            tmp = sa_hnet_kernels
            sa_hnet_kernels = []
            for i in range(0, len(tmp), 2):
                sa_hnet_kernels.append([tmp[i], tmp[i + 1]])

        if gc('hnet_dropout_rate') != -1:
            warn('SA-Hypernet doesn\'t use dropout. Dropout rate will be ' +
                 'ignored.')
        if gc('hnet_act') != 'relu':
            warn('SA-Hypernet doesn\'t support non-linearities other than ' +
                 'ReLU yet. Option "%shnet_act" (%s) will be ignored.' %
                 (cprefix, gc('hnet_act')))

        hnet = SAHyperNetwork(
            mnet_shapes,
            num_tasks,
            out_size=hyper_chunks,
            num_layers=gc('sa_hnet_num_layers'),
            num_filters=sa_hnet_filters,
            kernel_size=sa_hnet_kernels,
            sa_units=sa_hnet_attention_layers,
            # Note, we don't use an additional hypernet for the remaining
            # weights!
            #rem_layers=hnet_arch,
            te_dim=gc('temb_size'),
            ce_dim=gc('emb_size'),
            no_theta=False,
            # Batchnorm and spectral norm are not yet implemented.
            #use_batch_norm=gc('hnet_batchnorm'),
            #use_spectral_norm=gc('hnet_specnorm'),
            # Dropout would only be used for the additional network, which we
            # don't use.
            #dropout_rate=gc('hnet_dropout_rate'),
            discard_remainder=True,
            noise_dim=gc('hnet_noise_dim'),
            temb_std=gc('temb_std')).to(device)

    elif hyper_chunks != -1:  # Chunked fully-connected hypernet
        hnet = ChunkedHyperNetworkHandler(mnet_shapes,
                                          num_tasks,
                                          chunk_dim=hyper_chunks,
                                          layers=hnet_arch,
                                          activation_fn=hnet_act,
                                          te_dim=gc('temb_size'),
                                          ce_dim=gc('emb_size'),
                                          dropout_rate=gc('hnet_dropout_rate'),
                                          noise_dim=gc('hnet_noise_dim'),
                                          temb_std=gc('temb_std')).to(device)

    else:  # Fully-connected hypernet.
        hnet = HyperNetwork(mnet_shapes,
                            num_tasks,
                            layers=hnet_arch,
                            te_dim=gc('temb_size'),
                            activation_fn=hnet_act,
                            dropout_rate=gc('hnet_dropout_rate'),
                            noise_dim=gc('hnet_noise_dim'),
                            temb_std=gc('temb_std')).to(device)

    return hnet
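
# Usage sketch (illustration only, not part of the original module): creating a
# plain fully-connected hypernetwork for a tiny main network. The config
# attributes are assumptions following the `hypernet_args` naming referenced in
# the docstring (`hyper_chunks=-1` selects the non-chunked hypernetwork).
from types import SimpleNamespace
import torch

_cfg = SimpleNamespace(hyper_chunks='-1', hnet_arch='50,50', sa_hnet_filters='',
                       sa_hnet_kernels='', sa_hnet_attention_layers='',
                       hnet_act='relu', hnet_dropout_rate=-1, hnet_noise_dim=-1,
                       temb_size=8, temb_std=-1)
_mnet_shapes = [[10, 5], [10]]  # E.g., one weight matrix and one bias vector.
_hnet = get_hnet_model(_cfg, 3, torch.device('cpu'), _mnet_shapes)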
Example 4
def get_mnet_model(config,
                   net_type,
                   in_shape,
                   out_shape,
                   device,
                   cprefix=None,
                   no_weights=False,
                   **mnet_kwargs):
    """Generate a main network instance.

    A helper to generate a main network according to the given user
    configuration.

    .. note::
        Generation of networks with context-modulation is not yet supported,
        since there is no global argument set in :mod:`utils.cli_args` yet.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available according
                to the function :func:`utils.cli_args.main_net_args`.
        net_type (str): The type of network. The following options are
            available:

            - ``mlp``: :class:`mnets.mlp.MLP`
            - ``lenet``: :class:`mnets.lenet.LeNet`
            - ``resnet``: :class:`mnets.resnet.ResNet`
            - ``wrn``: :class:`mnets.wide_resnet.WRN`
            - ``iresnet``: :class:`mnets.resnet_imgnet.ResNetIN`
            - ``zenke``: :class:`mnets.zenkenet.ZenkeNet`
            - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet`
            - ``chunked_mlp``: :class:`mnets.chunk_squeezer.ChunkSqueezer`
            - ``simple_rnn``: :class:`mnets.simple_rnn.SimpleRNN`
        in_shape (list): Shape of network inputs. Can be ``None`` if not
            required by network type.

            For instance: For an MLP network :class:`mnets.mlp.MLP` with 100
            input neurons it should be :code:`in_shape=[100]`.
        out_shape (list): Shape of network outputs. See ``in_shape`` for more
            details.
        device: PyTorch device.
        cprefix (str, optional): A prefix of the config names. It might be that
            the config names used in this method are prefixed, since several
            main networks should be generated (e.g., :code:`cprefix='gen_'` or
            ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.main_net_args`.
        no_weights (bool): Whether the main network should be generated without
            weights.
        **mnet_kwargs: Additional keyword arguments that will be passed to the
            main network constructor.

    Returns:
        The created main network model.
    """
    assert (net_type in [
        'mlp', 'lenet', 'resnet', 'zenke', 'bio_conv_net', 'chunked_mlp',
        'simple_rnn', 'wrn', 'iresnet'
    ])

    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    def hc(name):
        """Check whether config exists."""
        return hasattr(config, '%s%s' % (cprefix, name))

    mnet = None

    if hc('net_act'):
        net_act = gc('net_act')
        net_act = misc.str_to_act(net_act)
    else:
        net_act = None

    def get_val(name):
        ret = None
        if hc(name):
            ret = gc(name)
        return ret

    no_bias = get_val('no_bias')
    dropout_rate = get_val('dropout_rate')
    specnorm = get_val('specnorm')
    batchnorm = get_val('batchnorm')
    no_batchnorm = get_val('no_batchnorm')
    bn_no_running_stats = get_val('bn_no_running_stats')
    bn_distill_stats = get_val('bn_distill_stats')
    # This argument has to be handled during usage of the network and not during
    # construction.
    #bn_no_stats_checkpointing = get_val('bn_no_stats_checkpointing')

    use_bn = None
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm

    # If an argument wasn't specified, then we use the default value that
    # is currently in the constructor.
    assign = lambda x, y: y if x is None else x

    if net_type == 'mlp':
        assert (hc('mlp_arch'))
        assert (len(in_shape) == 1 and len(out_shape) == 1)

        # Default keyword arguments of class MLP.
        dkws = misc.get_default_args(MLP.__init__)

        mnet = MLP(
            n_in=in_shape[0],
            n_out=out_shape[0],
            hidden_layers=misc.str_to_ints(gc('mlp_arch')),
            activation_fn=assign(net_act, dkws['activation_fn']),
            use_bias=assign(not no_bias, dkws['use_bias']),
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']),
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #out_fn=None,
            verbose=True,
            **mnet_kwargs).to(device)

    elif net_type == 'resnet':
        assert (len(out_shape) == 1)
        assert hc('resnet_block_depth') and hc('resnet_channel_sizes')

        # Default keyword arguments of class ResNet.
        dkws = misc.get_default_args(ResNet.__init__)

        mnet = ResNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            n=gc('resnet_block_depth'),
            use_bias=assign(not no_bias, dkws['use_bias']),
            num_feature_maps=misc.str_to_ints(gc('resnet_channel_sizes')),
            verbose=True,  #n=5,
            no_weights=no_weights,
            #init_weights=None,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
            **mnet_kwargs).to(device)

    elif net_type == 'wrn':
        assert (len(out_shape) == 1)
        assert hc('wrn_block_depth') and hc('wrn_widening_factor')

        # Default keyword arguments of class WRN.
        dkws = misc.get_default_args(WRN.__init__)

        mnet = WRN(
            in_shape=in_shape,
            num_classes=out_shape[0],
            n=gc('wrn_block_depth'),
            use_bias=assign(not no_bias, dkws['use_bias']),
            #num_feature_maps=misc.str_to_ints(gc('wrn_channel_sizes')),
            verbose=True,
            no_weights=no_weights,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            k=gc('wrn_widening_factor'),
            use_fc_bias=gc('wrn_use_fc_bias'),
            dropout_rate=gc('dropout_rate'),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
            **mnet_kwargs).to(device)

    elif net_type == 'iresnet':
        assert (len(out_shape) == 1)
        assert hc('iresnet_use_fc_bias') and hc('iresnet_channel_sizes') \
            and hc('iresnet_blocks_per_group') \
            and hc('iresnet_bottleneck_blocks') \
            and hc('iresnet_projection_shortcut')

        # Default keyword arguments of class ResNetIN.
        dkws = misc.get_default_args(ResNetIN.__init__)

        mnet = ResNetIN(
            in_shape=in_shape,
            num_classes=out_shape[0],
            use_bias=assign(not no_bias, dkws['use_bias']),
            use_fc_bias=gc('iresnet_use_fc_bias'),
            num_feature_maps=misc.str_to_ints(gc('iresnet_channel_sizes')),
            blocks_per_group=misc.str_to_ints(gc('iresnet_blocks_per_group')),
            projection_shortcut=gc('iresnet_projection_shortcut'),
            bottleneck_blocks=gc('iresnet_bottleneck_blocks'),
            #cutout_mod=False,
            no_weights=no_weights,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            #chw_input_format=False,
            verbose=True,
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
            **mnet_kwargs).to(device)

    elif net_type == 'zenke':
        assert (len(out_shape) == 1)

        # Default keyword arguments of class ZenkeNet.
        dkws = misc.get_default_args(ZenkeNet.__init__)

        mnet = ZenkeNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            verbose=True,  #arch='cifar',
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            **mnet_kwargs).to(device)

    elif net_type == 'bio_conv_net':
        assert (len(out_shape) == 1)

        # Default keyword arguments of class BioConvNet.
        #dkws = misc.get_default_args(BioConvNet.__init__)

        mnet = BioConvNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            no_weights=no_weights,
            #init_weights=None,
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
            **mnet_kwargs).to(device)

    elif net_type == 'chunked_mlp':
        assert hc('cmlp_arch') and hc('cmlp_chunk_arch') and \
               hc('cmlp_in_cdim') and hc('cmlp_out_cdim') and \
               hc('cmlp_cemb_dim')
        assert len(in_shape) == 1 and len(out_shape) == 1

        # Default keyword arguments of class ChunkSqueezer.
        dkws = misc.get_default_args(ChunkSqueezer.__init__)

        mnet = ChunkSqueezer(
            n_in=in_shape[0],
            n_out=out_shape[0],
            inp_chunk_dim=gc('cmlp_in_cdim'),
            out_chunk_dim=gc('cmlp_out_cdim'),
            cemb_size=gc('cmlp_cemb_dim'),
            #cemb_init_std=1.,
            red_layers=misc.str_to_ints(gc('cmlp_chunk_arch')),
            net_layers=misc.str_to_ints(gc('cmlp_arch')),
            activation_fn=assign(net_act, dkws['activation_fn']),
            use_bias=assign(not no_bias, dkws['use_bias']),
            #dynamic_biases=None,
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']),
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            verbose=True,
            **mnet_kwargs).to(device)

    elif net_type == 'lenet':
        assert hc('lenet_type')
        assert len(out_shape) == 1

        # Default keyword arguments of class LeNet.
        dkws = misc.get_default_args(LeNet.__init__)

        mnet = LeNet(
            in_shape=in_shape,
            num_classes=out_shape[0],
            verbose=True,
            arch=gc('lenet_type'),
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            # TODO Context-mod weights.
            **mnet_kwargs).to(device)

    else:
        assert (net_type == 'simple_rnn')
        assert hc('srnn_rec_layers') and hc('srnn_pre_fc_layers') and \
            hc('srnn_post_fc_layers')  and hc('srnn_no_fc_out') and \
            hc('srnn_rec_type')
        assert len(in_shape) == 1 and len(out_shape) == 1

        if gc('srnn_rec_type') == 'lstm':
            use_lstm = True
        else:
            assert gc('srnn_rec_type') == 'elman'
            use_lstm = False

        # Default keyword arguments of class SimpleRNN.
        dkws = misc.get_default_args(SimpleRNN.__init__)

        rnn_layers = misc.str_to_ints(gc('srnn_rec_layers'))
        fc_layers = misc.str_to_ints(gc('srnn_post_fc_layers'))
        if gc('srnn_no_fc_out'):
            rnn_layers.append(out_shape[0])
        else:
            fc_layers.append(out_shape[0])

        mnet = SimpleRNN(n_in=in_shape[0],
                         rnn_layers=rnn_layers,
                         fc_layers_pre=misc.str_to_ints(
                             gc('srnn_pre_fc_layers')),
                         fc_layers=fc_layers,
                         activation=assign(net_act, dkws['activation']),
                         use_lstm=use_lstm,
                         use_bias=assign(not no_bias, dkws['use_bias']),
                         no_weights=no_weights,
                         verbose=True,
                         **mnet_kwargs).to(device)

    return mnet
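
# Illustration (not from the original source) of the fallback pattern used
# above: `misc.get_default_args` is assumed to read a constructor's default
# keyword values, and `assign(x, y)` keeps the CLI value `x` unless it is None,
# in which case the constructor default `y` is used. Pure-stdlib sketch:
import inspect

def _get_default_args_sketch(func):
    """Minimal stand-in for misc.get_default_args (assumed behavior)."""
    sig = inspect.signature(func)
    return {n: p.default for n, p in sig.parameters.items()
            if p.default is not inspect.Parameter.empty}

_assign = lambda x, y: y if x is None else x

def _demo_ctor(n_in, use_bias=True, dropout_rate=-1):
    return use_bias, dropout_rate

_dkws = _get_default_args_sketch(_demo_ctor)
assert _assign(None, _dkws['use_bias']) is True    # Unspecified -> default.
assert _assign(0.5, _dkws['dropout_rate']) == 0.5  # Specified CLI value wins.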
Example 5
def get_hypernet(config,
                 device,
                 net_type,
                 target_shapes,
                 num_conds,
                 no_cond_weights=False,
                 no_uncond_weights=False,
                 uncond_in_size=0,
                 shmlp_chunk_shapes=None,
                 shmlp_num_per_chunk=None,
                 shmlp_assembly_fct=None,
                 verbose=True,
                 cprefix=None):
    """Generate a hypernetwork instance.

    A helper to generate the hypernetwork according to the given user
    configuration.

    Args:
        config (argparse.Namespace): Command-line arguments.

            Note:
                The function expects command-line arguments available according
                to the function :func:`utils.cli_args.hnet_args`.
        device: PyTorch device.
        net_type (str): The type of network. The following options are
            available:

            - ``'hmlp'``
            - ``'chunked_hmlp'``
            - ``'structured_hmlp'``
            - ``'hdeconv'``
            - ``'chunked_hdeconv'``
        target_shapes (list): See argument ``target_shapes`` of
            :class:`hnets.mlp_hnet.HMLP`.
        num_conds (int): Number of conditions that should be known to the
            hypernetwork.
        no_cond_weights (bool): See argument ``no_cond_weights`` of
            :class:`hnets.mlp_hnet.HMLP`.
        no_uncond_weights (bool): See argument ``no_uncond_weights`` of
            :class:`hnets.mlp_hnet.HMLP`.
        uncond_in_size (int): See argument ``uncond_in_size`` of
            :class:`hnets.mlp_hnet.HMLP`.
        shmlp_chunk_shapes (list, optional): Argument ``chunk_shapes`` of
            :class:`hnets.structured_mlp_hnet.StructuredHMLP`.
        shmlp_num_per_chunk (list, optional): Argument ``num_per_chunk`` of
            :class:`hnets.structured_mlp_hnet.StructuredHMLP`.
        shmlp_assembly_fct (func, optional): Argument ``assembly_fct`` of
            :class:`hnets.structured_mlp_hnet.StructuredHMLP`.
        verbose (bool): Argument ``verbose`` of :class:`hnets.mlp_hnet.HMLP`.
        cprefix (str, optional): A prefix of the config names. It might be that
            the config names used in this function are prefixed, since several
            hypernetworks should be generated.

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.hnet_args`.
    """
    assert net_type in [
        'hmlp', 'chunked_hmlp', 'structured_hmlp', 'hdeconv', 'chunked_hdeconv'
    ]

    hnet = None

    ### FIXME Code almost identically copied from `get_mnet_model` ###
    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    def hc(name):
        """Check whether config exists."""
        return hasattr(config, '%s%s' % (cprefix, name))

    if hc('hnet_net_act'):
        net_act = gc('hnet_net_act')
        net_act = misc.str_to_act(net_act)
    else:
        net_act = None

    def get_val(name):
        ret = None
        if hc(name):
            ret = gc(name)
        return ret

    no_bias = get_val('hnet_no_bias')
    dropout_rate = get_val('hnet_dropout_rate')
    specnorm = get_val('hnet_specnorm')
    batchnorm = get_val('hnet_batchnorm')
    no_batchnorm = get_val('hnet_no_batchnorm')
    #bn_no_running_stats = get_val('hnet_bn_no_running_stats')
    #n_distill_stats = get_val('hnet_bn_distill_stats')

    use_bn = None
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm

    # If an argument wasn't specified, then we use the default value that
    # is currently in the constructor.
    assign = lambda x, y: y if x is None else x
    ### FIXME Code copied until here                               ###

    if hc('hmlp_arch'):
        hmlp_arch_is_list = False
        hmlp_arch = gc('hmlp_arch')
        if ';' in hmlp_arch:
            hmlp_arch_is_list = True
            if net_type != 'structured_hmlp':
                raise ValueError('Option "%shmlp_arch" may only ' % (cprefix) +
                                 'contain semicolons for network type ' +
                                 '"structured_hmlp"!')
            hmlp_arch = [misc.str_to_ints(ar) for ar in hmlp_arch.split(';')]
        else:
            hmlp_arch = misc.str_to_ints(hmlp_arch)
    if hc('chunk_emb_size'):
        chunk_emb_size = gc('chunk_emb_size')
        chunk_emb_size = misc.str_to_ints(chunk_emb_size)
        if len(chunk_emb_size) == 1:
            chunk_emb_size = chunk_emb_size[0]
        else:
            if net_type != 'structured_hmlp':
                raise ValueError('Option "%schunk_emb_size" may ' % (cprefix) +
                                 'only contain multiple values for network ' +
                                 'type "structured_hmlp"!')

    if hc('cond_emb_size'):
        cond_emb_size = gc('cond_emb_size')
    else:
        cond_emb_size = 0

    if net_type == 'hmlp':
        assert hc('hmlp_arch')

        # Default keyword arguments of class HMLP.
        dkws = misc.get_default_args(HMLP.__init__)

        hnet = HMLP(target_shapes,
                    uncond_in_size=uncond_in_size,
                    cond_in_size=cond_emb_size,
                    layers=hmlp_arch,
                    verbose=verbose,
                    activation_fn=assign(net_act, dkws['activation_fn']),
                    use_bias=assign(not no_bias, dkws['use_bias']),
                    no_uncond_weights=no_uncond_weights,
                    no_cond_weights=no_cond_weights,
                    num_cond_embs=num_conds,
                    dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
                    use_spectral_norm=assign(specnorm,
                                             dkws['use_spectral_norm']),
                    use_batch_norm=assign(use_bn,
                                          dkws['use_batch_norm'])).to(device)

    elif net_type == 'chunked_hmlp':
        assert hc('hmlp_arch')
        assert hc('chmlp_chunk_size')
        assert hc('chunk_emb_size')
        cond_chunk_embs = get_val('use_cond_chunk_embs')

        # Default keyword arguments of class ChunkedHMLP.
        dkws = misc.get_default_args(ChunkedHMLP.__init__)

        hnet = ChunkedHMLP(
            target_shapes,
            gc('chmlp_chunk_size'),
            chunk_emb_size=chunk_emb_size,
            cond_chunk_embs=assign(cond_chunk_embs, dkws['cond_chunk_embs']),
            uncond_in_size=uncond_in_size,
            cond_in_size=cond_emb_size,
            layers=hmlp_arch,
            verbose=verbose,
            activation_fn=assign(net_act, dkws['activation_fn']),
            use_bias=assign(not no_bias, dkws['use_bias']),
            no_uncond_weights=no_uncond_weights,
            no_cond_weights=no_cond_weights,
            num_cond_embs=num_conds,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']),
            use_batch_norm=assign(use_bn, dkws['use_batch_norm'])).to(device)

    elif net_type == 'structured_hmlp':
        assert hc('hmlp_arch')
        assert hc('chunk_emb_size')
        cond_chunk_embs = get_val('use_cond_chunk_embs')

        assert shmlp_chunk_shapes is not None and \
            shmlp_num_per_chunk is not None and \
            shmlp_assembly_fct is not None

        # Default keyword arguments of class StructuredHMLP.
        dkws = misc.get_default_args(StructuredHMLP.__init__)
        dkws_hmlp = misc.get_default_args(HMLP.__init__)

        shmlp_hmlp_kwargs = []
        if not hmlp_arch_is_list:
            hmlp_arch = [hmlp_arch]
        for i, arch in enumerate(hmlp_arch):
            shmlp_hmlp_kwargs.append({
                'layers': arch,
                'activation_fn': assign(net_act, dkws_hmlp['activation_fn']),
                'use_bias': assign(not no_bias, dkws_hmlp['use_bias']),
                'dropout_rate': assign(dropout_rate, dkws_hmlp['dropout_rate']),
                'use_spectral_norm': \
                    assign(specnorm, dkws_hmlp['use_spectral_norm']),
                'use_batch_norm': assign(use_bn, dkws_hmlp['use_batch_norm'])
            })
        if len(shmlp_hmlp_kwargs) == 1:
            shmlp_hmlp_kwargs = shmlp_hmlp_kwargs[0]

        hnet = StructuredHMLP(target_shapes,
                              shmlp_chunk_shapes,
                              shmlp_num_per_chunk,
                              chunk_emb_size,
                              shmlp_hmlp_kwargs,
                              shmlp_assembly_fct,
                              cond_chunk_embs=assign(cond_chunk_embs,
                                                     dkws['cond_chunk_embs']),
                              uncond_in_size=uncond_in_size,
                              cond_in_size=cond_emb_size,
                              verbose=verbose,
                              no_uncond_weights=no_uncond_weights,
                              no_cond_weights=no_cond_weights,
                              num_cond_embs=num_conds).to(device)

    elif net_type == 'hdeconv':
        #HDeconv
        raise NotImplementedError
    else:
        assert net_type == 'chunked_hdeconv'
        #ChunkedHDeconv
        raise NotImplementedError

    return hnet
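
# Usage sketch (illustration only, not part of the original module): creating
# an HMLP hypernetwork for a small target network. The config attributes are
# assumptions following the `hnet_args` naming referenced in the docstring.
from types import SimpleNamespace
import torch

_cfg = SimpleNamespace(hmlp_arch='100,100', cond_emb_size=8,
                       hnet_net_act='relu', hnet_no_bias=False,
                       hnet_dropout_rate=-1, hnet_specnorm=False)
_target_shapes = [[10, 5], [10]]  # E.g., one weight matrix and one bias vector.
_hnet = get_hypernet(_cfg, torch.device('cpu'), 'hmlp', _target_shapes,
                     num_conds=3)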
Example 6
def _generate_networks(config,
                       data_handlers,
                       device,
                       create_hnet=True,
                       create_rnet=False,
                       no_replay=False):
    """Create the main-net, hypernetwork and recognition network.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network and to
            infer the number of tasks.
        device: Torch device.
        create_hnet: Whether a hypernetwork should be constructed. If not, the
            main network will have trainable weights.
        create_rnet: Whether a task-recognition autoencoder should be created.
        no_replay: Whether the recognition network should be an instance of
            class MainModel rather than of class RecognitionNet (note, for
            multitask learning, no replay network is required).

    Returns:
        mnet: Main network instance.
        hnet: Hypernetwork instance. This return value is None if no
            hypernetwork should be constructed.
        rnet: RecognitionNet instance. This return value is None if no
            recognition network should be constructed.
    """
    num_tasks = len(data_handlers)

    n_x = data_handlers[0].in_shape[0]
    n_y = data_handlers[0].out_shape[0]
    if config.multi_head:
        n_y = n_y * num_tasks

    main_arch = misc.str_to_ints(config.main_arch)
    main_shapes = MainNetwork.weight_shapes(n_in=n_x,
                                            n_out=n_y,
                                            hidden_layers=main_arch)
    mnet = MainNetwork(main_shapes,
                       activation_fn=misc.str_to_act(config.main_act),
                       use_bias=True,
                       no_weights=create_hnet).to(device)
    if create_hnet:
        hnet_arch = misc.str_to_ints(config.hnet_arch)
        hnet = HyperNetwork(main_shapes,
                            num_tasks,
                            layers=hnet_arch,
                            te_dim=config.emb_size,
                            activation_fn=misc.str_to_act(
                                config.hnet_act)).to(device)
        init_params = list(hnet.parameters())
    else:
        hnet = None
        init_params = list(mnet.parameters())

    if create_rnet:
        ae_arch = misc.str_to_ints(config.ae_arch)
        if no_replay:
            rnet_shapes = MainNetwork.weight_shapes(n_in=n_x,
                                                    n_out=num_tasks,
                                                    hidden_layers=ae_arch,
                                                    use_bias=True)
            rnet = MainNetwork(rnet_shapes,
                               activation_fn=misc.str_to_act(config.ae_act),
                               use_bias=True,
                               no_weights=False,
                               dropout_rate=-1,
                               out_fn=lambda x: F.softmax(x, dim=1))
        else:
            rnet = RecognitionNet(n_x,
                                  num_tasks,
                                  dim_z=config.ae_dim_z,
                                  enc_layers=ae_arch,
                                  activation_fn=misc.str_to_act(config.ae_act),
                                  use_bias=True).to(device)
        init_params += list(rnet.parameters())
    else:
        rnet = None

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        elif config.normal_init:
            torch.nn.init.normal_(W, mean=0, std=config.std_normal_init)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if create_hnet:
        for temb in hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

    if create_hnet and config.use_hyperfan_init:
        hnet.apply_hyperfan_init(temb_var=config.std_normal_temb**2)

    return mnet, hnet, rnet
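
# Usage sketch (illustration only, not part of the original module): how
# `_generate_networks` might be invoked for a 1D toy problem. The config fields
# are assumptions inferred from the function body; `SimpleNamespace` objects
# stand in for real data handlers, and the module-level imports (torch, misc,
# MainNetwork, HyperNetwork, ...) are reused.
from types import SimpleNamespace
import torch

_cfg = SimpleNamespace(multi_head=False, main_arch='10,10', main_act='relu',
                       hnet_arch='10,10', emb_size=2, hnet_act='relu',
                       normal_init=False, std_normal_init=0.02,
                       std_normal_temb=1., use_hyperfan_init=False)
_dhandlers = [SimpleNamespace(in_shape=[1], out_shape=[1]) for _ in range(3)]
_mnet, _hnet, _rnet = _generate_networks(_cfg, _dhandlers, torch.device('cpu'))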
Example 7
def analyse_single_run(out_dir, device, writer, logger, analysis_kwd,
        get_loss_func, accuracy_func, generate_tasks_func, n_samples=-1,
        redo_analyses=False, do_kernel_pca=False, do_supervised_dimred=False,
        timesteps_for_analysis=None, copy_task=True, num_tasks=-1,
        sup_dimred_criterion=None, sup_dimred_args={}):
    """Analyse the hidden dimensionality for an individual run.

    Args:
        out_dir (str): The path to the output directory.
        device: The device.
        writer: The tensorboard writer.
        logger: The logger.
        analysis_kwd (dict): The dictionary containing important keywords for
            the current analysis.
        get_loss_func (func): A handler to generate the loss function.
        accuracy_func (func): A handler to the accuracy function.
        generate_tasks_func (func): A handler to a datahandler generator.
        redo_analyses (boolean, optional): If ``True``, analyses will be redone
            even if they had been stored previously.
        do_kernel_pca (bool, optional): If ``True``, kernel PCA will also be
            used to compute the number of hidden dimensions.
        do_supervised_dimred (bool, optional): If ``True``, supervised linear
            dimensionality reduction will be used to compute the number of
            task-relevant hidden dimensions.
        n_samples (int): The number of samples to be used.
        timesteps_for_analysis (str, optional): The timesteps to be used for the
            PCA analyses.
        copy_task (bool, optional): Indicates whether we are analysing the
            Copy Task or not.
        num_tasks (int, optional): The number of tasks to be considered.
        sup_dimred_criterion (int, optional): If provided, this value will be
            used as a stopping criterion when searching for the number of
            supervised components necessary to describe the hidden activity.
        sup_dimred_args (dict): Optional arguments (e.g., optimization
            arguments) passed to the supervised dimensionality reduction
            :func:`sequential.ht_analyses.supervised_dimred_utils.\
get_loss_vs_supervised_n_dim`.

    Returns:
        (tuple): Tuple containing:

        - **results**: The dictionary of results for the current run.
        - **settings**: The dictionary with the values of the parameters that
          are specified in `analysis_kwd['fixed_params']`.

    """

    ### Prepare the data and the networks.
    # Load the config
    if not os.path.exists(out_dir):
        raise ValueError('The directory "%s" does not exist.'%out_dir)
    with open(os.path.join(out_dir, "config.pickle"), "rb") as f:
        config = pickle.load(f)
    # Overwrite the directory if it's not the same as the original.
    if config.out_dir != out_dir:
        config.out_dir = out_dir
    # Check for old command line arguments and make compatible with new version.
    config = train_args_sequential.update_cli_args(config)

    print('Working on output directory "%s".' % out_dir)

    # Overwrite the number of tasks.
    if num_tasks == -1:
        num_tasks = config.num_tasks

    if sup_dimred_criterion == -1:
        sup_dimred_criterion = None

    stop_bit = None
    if copy_task:
        # Get the index of the stop bit.
        #stop_bit = getattr(config, analysis_kwd['complexity_measure'])
        # If we do not enforce the condition below, we have to determine the
        # location of the stop bit on a sample-by-sample basis.
        assert config.input_len_step == 0 and config.input_len_variability == 0
        stop_bit = config.first_task_input_len
        if config.pad_after_stop:
            stop_bit = config.pat_len

    ### Sanity checks.
    # Do some sanity checks on the parameters.
    assert config.use_ewc or config.use_si
    if config.use_ewc:
        method = 'ewc'
    elif config.use_si:
        method = 'si'
    for key, value in analysis_kwd['forced_params']:
        assert getattr(config, key) == value
    # Ensure all runs have comparable properties
    if 'num_tasks' not in analysis_kwd['fixed_params']:
        analysis_kwd['fixed_params'].append('num_tasks')

    ### Create the settings dictionary.
    settings = {}
    for key in analysis_kwd['fixed_params']:
        settings[key] = getattr(config, key)
        if key == 'num_tasks':
            settings[key] = num_tasks

    ### Load or create the results dictionary.
    if os.path.exists(os.path.join(out_dir, "pca_results.pickle")) and \
            not redo_analyses:
        ### Load existing results.
        with open(os.path.join(out_dir, "pca_results.pickle"), "rb") as f:
            results = pickle.load(f)
        print('PCA analyses were stored previously; reloading results.')
        assert num_tasks == -1 or results['num_tasks'] == num_tasks

        if 'mean_fisher' in results:
            results['mean_importance'] = results['mean_fisher']
            results['mean_importance_ho'] = results['mean_fisher_ho']
    else:
        ### Prepare the environment.
        # Define functions.
        task_loss_func = get_loss_func(config, device, logger)
        accuracy_func = accuracy_func
        # Generate datahandlers
        dhandlers = generate_tasks_func(config, logger, writer=writer)
        config.show_plots = True
        plc.visualise_data(dhandlers, config, device)
        # Generate the networks
        shared = argparse.Namespace()
        # FIXME might not work for all datasets (e.g., PoS tagging).
        shared.feature_size = dhandlers[0].in_shape[0]
        target_net, hnet, _ = stu.generate_networks(config, shared, dhandlers,
                                                    device)

        ### Initialize the results dictionary.
        results = {}
        if copy_task:
            results['masked'] = config.pat_len
            results['pad_after_stop'] = config.pad_after_stop
            results['accs_per_ts'] = []
            results['permutation'] = []
        results['expl_var_per_ts'] = []
        results['kexpl_var_per_ts'] = []
        results['expl_var_per_ts_yt'] = []
        results['kexpl_var_per_ts_yt'] = []
        results['complexity_measure'] = getattr(config, \
            analysis_kwd['complexity_measure'])
        results['complexity_measure_name'] = \
            analysis_kwd['complexity_measure_name']
        results['num_tasks'] = num_tasks
        results['final_acc'] = []
        results['final_loss'] = []
        results['mean_importance'] = []
        results['mean_importance_ho'] = []
        results['expl_var'] = []
        results['kexpl_var'] = []
        results['expl_var_yt'] = []
        results['kexpl_var_yt'] = []
        if do_supervised_dimred:
            # Note, in the code 'loss_n_dim_supervised' plays, for the
            # supervised dimensionality reduction, the same role as 'expl_var'
            # for the standard PCA analysis, i.e. we store the explained
            # variance (resp. loss) as a function of how many dimensions are
            # taken into account, and then select a threshold for the explained
            # variance (resp. loss) to determine the number of intrinsic
            # dimensions.
            results['loss_n_dim_supervised'] = []
            results['accu_n_dim_supervised'] = []
            if copy_task:
                results['accu_n_dim_sup_at_stop'] = []
                results['loss_n_dim_sup_at_stop'] = []

        # Iterate over all tasks and accumulate results in lists within the
        # results dictionary values.
        all_during_act = []
        all_during_act_yt = []
        for task_id in range(num_tasks):

            if copy_task:
                results['permutation'].append(dhandlers[task_id].permutation)

            ### Load the checkpointed during model for the corresponding task.
            # Note, the return values of the function below are just references
            # to the variables `target_net` and `hnet`, which are modified in-
            # place.
            mnet, hnet = load_models(out_dir, device, logger, target_net, hnet,
                wembs=None, task_id=task_id, method=method)
            # FIXME Should we disentangle weight matrices and bias vectors?
            hh_imp_values = get_importance_values(mnet, connection_type='hh',
                method=method)
            results['mean_importance'].append(np.mean(hh_imp_values))
            ho_imp_values = get_importance_values(mnet, connection_type='ho',
                method=method)
            if ho_imp_values != []:
                results['mean_importance_ho'].append(np.mean(ho_imp_values))
            else:
                results['mean_importance_ho'].append(np.nan)

            ### Obtain hidden activations and performances.
            # We only measure the final accuracy up to the current task, since
            # we are simulating a continual learning setting with fewer tasks.
            loss, accs, accs_per_ts = test(dhandlers, device, config, None,
                logger, writer, mnet, hnet, store_activations=True, \
                accuracy_func=accuracy_func, task_loss_func=task_loss_func,
                num_trained=task_id, return_acc_per_ts=True)
            results['final_loss'].append(np.mean(loss[:task_id+1]))
            if accs is None:
                results['final_acc'].append(None)
            else:
                results['final_acc'].append(np.mean(accs[:task_id+1]))
            if copy_task:
                results['accs_per_ts'].append(accs_per_ts[task_id])

            ### Load the internal activations.
            tasks_act, act = get_activations(out_dir, task_id=task_id,
                vanilla_rnn=config.use_vanilla_rnn)
            n_hidden = np.sum(misc.str_to_ints(config.rnn_arch))
            assert act.shape[-1] == n_hidden
            all_during_act.append(act)
            tasks_act_yt, act_yt = get_activations(out_dir, task_id=task_id,
                internal=False, vanilla_rnn=config.use_vanilla_rnn)
            all_during_act_yt.append(act_yt)

            ### Do PCA analyses.
            # Do analyses on internal recurrent activations.
            results = pca_analysis_single_task(act, results,
                do_kernel_pca=do_kernel_pca, n_samples=n_samples,
                timesteps=timesteps_for_analysis, stop_bit=stop_bit,
                do_supervised_dimred=do_supervised_dimred)

            # Do analyses on output recurrent activations.
            results = pca_analysis_single_task(act_yt, results,
                do_kernel_pca=do_kernel_pca, n_samples=n_samples,
                timesteps=timesteps_for_analysis, stop_bit=stop_bit,
                internal=False, do_supervised_dimred=do_supervised_dimred)

            if do_supervised_dimred:
                if not copy_task:
                    raise NotImplementedError('TODO need to adapt the ' +
                        'loss computation for tasks other than the Copy Task.')
                # Do supervised dimensionality reduction on during models.
                loss_dim, accu_dim = get_loss_vs_supervised_n_dim(mnet,
                        hnet, task_loss_func, accuracy_func, dhandlers, config,
                        device, task_id=task_id, criterion=sup_dimred_criterion,
                        writer_dir=out_dir, **sup_dimred_args)
                results['loss_n_dim_supervised'].append(loss_dim)
                results['accu_n_dim_supervised'].append(accu_dim)
                if copy_task:
                    loss_dim, accu_dim = get_loss_vs_supervised_n_dim(mnet,
                            hnet, task_loss_func, accuracy_func, dhandlers,
                            config, device, stop_timestep=stop_bit,
                            task_id=task_id, criterion=sup_dimred_criterion,
                            writer_dir=out_dir, **sup_dimred_args)
                    results['loss_n_dim_sup_at_stop'].append(loss_dim)
                    results['accu_n_dim_sup_at_stop'].append(accu_dim)

        ### Get hidden dimensionality using the final model.
        # Note, here we overwrite the files "int_activations.pickle" and
        # "activations.pickle" that were generated when testing the model of
        # the current task.
        os.remove(os.path.join(out_dir, 'int_activations.pickle'))
        os.remove(os.path.join(out_dir, 'activations.pickle'))
        mnet, hnet = load_models(out_dir, device, logger, target_net, hnet,
                                 wembs=None, method=method)
        _ = test(dhandlers, device, config, shared, logger, writer, mnet,
            hnet, store_activations=True,
            accuracy_func=accuracy_func, task_loss_func=task_loss_func,
            num_trained=task_id, return_acc_per_ts=True)

        # Load internal activations.
        tasks_act, act = get_activations(out_dir, task_id=task_id,
            vanilla_rnn=config.use_vanilla_rnn)
        tasks_act_yt, act_yt = get_activations(out_dir, task_id=task_id,
            internal=False, vanilla_rnn=config.use_vanilla_rnn)

        ### Do PCA analyses on final models.
        results = pca_analysis_all_tasks(act, all_during_act, results,
                do_kernel_pca=do_kernel_pca, n_samples=n_samples,
                timesteps=timesteps_for_analysis, stop_bit=stop_bit,
                copy_task=copy_task)
        results = pca_analysis_all_tasks(act_yt, all_during_act_yt, results,
                do_kernel_pca=do_kernel_pca, n_samples=n_samples,
                timesteps=timesteps_for_analysis, stop_bit=stop_bit,
                copy_task=copy_task, internal=False)

        if do_supervised_dimred and len(all_during_act) > 1:
            ### Do supervised dimensionality reduction on final models.
            # Only do if we dealt with more than one task.
            loss_dim, accu_dim = get_loss_vs_supervised_n_dim(mnet, hnet,
                    task_loss_func, accuracy_func, dhandlers, config, device,
                    criterion=sup_dimred_criterion, writer_dir=out_dir,
                    **sup_dimred_args)
            results['loss_n_dim_supervised_all_tasks'] = loss_dim
            results['accu_n_dim_supervised_all_tasks'] = accu_dim

        # Store pickle results.
        with open(os.path.join(out_dir, 'pca_results.pickle'), 'wb') as handle:
            pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return results, settings
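The pickled results written above can later be reloaded for offline inspection. A minimal, hedged sketch (the directory below is a placeholder, not a path from the original code):

import os
import pickle

out_dir = './out/example_run'  # Placeholder: output directory of a finished run.
with open(os.path.join(out_dir, 'pca_results.pickle'), 'rb') as handle:
    results = pickle.load(handle)
print(sorted(results.keys()))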
Example no. 8
def get_hnet_model(config, num_tasks, device, mnet_shapes, cprefix=None,
                   no_weights=False, no_tembs=False, temb_size=None):
    """Generate a hypernetwork instance.

    A helper to generate the hypernetwork according to the given user
    configuration.

    .. deprecated:: 1.0
        Please use function :func:`get_hypernet` instead, as this function
        creates deprecated hypernetworks.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available according
                to the function :func:`utils.cli_args.hypernet_args`.
        num_tasks (int): The number of task embeddings the hypernetwork should
            have.
        device: PyTorch device.
        mnet_shapes: Dimensions of the weight tensors of the main network.
            See main net argument
            :attr:`mnets.mnet_interface.MainNetInterface.param_shapes`.
        cprefix (str, optional): A prefix of the config names. It might be that
            the config names used in this method are prefixed, since several
            hypernetworks should be generated (e.g., :code:`cprefix='gen_'` or
            ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.hypernet_args`.
        no_weights (bool): Whether the hypernetwork should be generated without
            internal weights (excluding task embeddings).
        no_tembs (bool): Whether the hypernetwork should be generated without
            internally maintained task embeddings.
        temb_size (int, optional): If user config should be overwritten, then
            this option can be used to specify the dimensionality of task
            embeddings.

    Returns:
        The created hypernet model.
    """
    warn('Please use function "utils.sim_utils.get_hypernet" instead, as ' + \
         'this function creates deprecated hypernetworks.',
         DeprecationWarning)

    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    hyper_chunks = misc.str_to_ints(gc('hyper_chunks'))
    assert len(hyper_chunks) in [1, 2, 3]
    if len(hyper_chunks) == 1:
        hyper_chunks = hyper_chunks[0]

    hnet_arch = misc.str_to_ints(gc('hnet_arch'))
    sa_hnet_filters = misc.str_to_ints(gc('sa_hnet_filters'))
    sa_hnet_kernels = misc.str_to_ints(gc('sa_hnet_kernels'))
    sa_hnet_attention_layers = misc.str_to_ints(gc('sa_hnet_attention_layers'))

    hnet_act = misc.str_to_act(gc('hnet_act'))

    if temb_size is None:
        temb_size = gc('temb_size')

    if isinstance(hyper_chunks, list): # Chunked self-attention hypernet
        raise NotImplementedError('Not publicly available')

    elif hyper_chunks != -1: # Chunked fully-connected hypernet
        hnet = ChunkedHyperNetworkHandler(mnet_shapes, num_tasks,
            chunk_dim=hyper_chunks, layers=hnet_arch,
            activation_fn=hnet_act, te_dim=temb_size, no_te_embs=no_tembs,
            ce_dim=gc('emb_size'), dropout_rate=gc('hnet_dropout_rate'),
            noise_dim=gc('hnet_noise_dim'), no_weights=no_weights,
            temb_std=gc('temb_std')).to(device)

    else: # Fully-connected hypernet.
        hnet = HyperNetwork(mnet_shapes, num_tasks, layers=hnet_arch,
            te_dim=temb_size, no_te_embs=no_tembs, activation_fn=hnet_act,
            dropout_rate=gc('hnet_dropout_rate'),
            noise_dim=gc('hnet_noise_dim'), no_weights=no_weights,
            temb_std=gc('temb_std')).to(device)

    return hnet
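A minimal, hedged usage sketch of the deprecated helper above. All config values below are assumptions chosen so that the fully-connected branch (`hyper_chunks = -1`) is exercised; they are not defaults taken from the original code base:

from argparse import Namespace

import torch

# Hypothetical configuration; each attribute mirrors one `gc(...)` lookup above.
cfg = Namespace(hyper_chunks='-1', hnet_arch='10,10', sa_hnet_filters='4',
                sa_hnet_kernels='5', sa_hnet_attention_layers='1',
                hnet_act='relu', temb_size=8, emb_size=8,
                hnet_dropout_rate=-1, hnet_noise_dim=-1, temb_std=-1)
# Toy main-network parameter shapes (one weight matrix and one bias vector).
mnet_shapes = [[20, 10], [20]]
hnet = get_hnet_model(cfg, num_tasks=3, device=torch.device('cpu'),
                      mnet_shapes=mnet_shapes)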
Example no. 9
def run(ref_module,
        results_dir='./out/random_seeds',
        config=None,
        ignore_kwds=None,
        forced_params=None):
    """Run the script.

    Args:
        ref_module (str): Name of the reference module which contains the 
            hyperparameter search config that can be modified to gather random 
            seeds.
        results_dir (str, optional): The path where to store the results.
        config: The Namespace object containing argument names and values.
            If provided, all random seeds will be gathered from zero, with no
            reference run.
        ignore_kwds (list, optional): The list of keywords in the config file
            to exclude from the grid.
        forced_params (dict, optional): Dict of key-value pairs specifying
            hyperparameter values that should be fixed across runs.
    """
    if ignore_kwds is None:
        ignore_kwds = []
    if forced_params is None:
        forced_params = {}

    ### Parse the command-line arguments.
    parser = argparse.ArgumentParser(description= \
        'Gathering random seeds for the specified experiment.')
    parser.add_argument('--out_dir',
                        type=str,
                        default='',
                        help='The output directory of the run or runs. ' +
                        'For single runs, the configuration will be ' +
                        'loaded and run with different seeds. ' +
                        'For multiple runs, i.e. results of ' +
                        'hyperparameter searches, the configuration ' +
                        'leading to the best mean final accuracy ' +
                        'will be selected and run with different seeds. ' +
                        'Default: %(default)s.')
    parser.add_argument('--config_name',
                        type=str,
                        default='hpsearch_random_seeds.py',
                        help='The name of the hpsearch config file. Since ' +
                        'multiple random seed gathering experiments ' +
                        'might be running in parallel, it is important ' +
                        'that this file has a unique name for each ' +
                        'experiment. Default: %(default)s.')
    parser.add_argument('--config_pickle',
                        type=str,
                        default='',
                        help='The path to a pickle file containing a run ' +
                        'config that will be loaded.')
    parser.add_argument('--num_seeds',
                        type=int,
                        default=10,
                        help='The number of different random seeds.')
    # FIXME `None` is not a valid default value.
    parser.add_argument('--seeds_list',
                        type=str,
                        default=None,
                        help='The list of seeds to use. If specified, ' +
                        '"num_seeds" will be ignored.')
    parser.add_argument('--vary_data_seed',
                        action='store_true',
                        help='If activated, "data_random_seed"s are set ' +
                        'equal to "random_seed"s. Otherwise only ' +
                        '"random_seed"s are varied.')
    parser.add_argument('--num_tot_hours',
                        type=int,
                        metavar='N',
                        default=120,
                        help='If "run_cluster" is activated, then this ' +
                        'option determines the maximum number of hours ' +
                        'the entire search may run on the cluster. ' +
                        'Default: %(default)s.')
    # FIXME Arguments below are copied from hpsearch.
    parser.add_argument('--run_cluster',
                        action='store_true',
                        help='This option would produce jobs for a GPU ' +
                        'cluster running a job scheduler (see option ' +
                        '"scheduler").')
    parser.add_argument('--scheduler',
                        type=str,
                        default='lsf',
                        choices=['lsf', 'slurm'],
                        help='The job scheduler used on the cluster. ' +
                        'Default: %(default)s.')
    parser.add_argument('--num_jobs',
                        type=int,
                        metavar='N',
                        default=8,
                        help='If "run_cluster" is activated, then this ' +
                        'option determines the maximum number of jobs ' +
                        'that can be submitted in parallel. ' +
                        'Default: %(default)s.')
    parser.add_argument('--num_hours',
                        type=int,
                        metavar='N',
                        default=24,
                        help='If "run_cluster" is activated, then this ' +
                        'option determines the maximum number of hours ' +
                        'a job may run on the cluster. ' +
                        'Default: %(default)s.')
    parser.add_argument('--resources',
                        type=str,
                        default='"rusage[mem=8000, ngpus_excl_p=1]"',
                        help='If "run_cluster" is activated and "scheduler" ' +
                        'is "lsf", then this option determines the ' +
                        'resources assigned to each job in the ' +
                        'hyperparameter search (option -R of bsub). ' +
                        'Default: %(default)s.')
    parser.add_argument('--slurm_mem',
                        type=str,
                        default='8G',
                        help='If "run_cluster" is activated and "scheduler" ' +
                        'is "slurm", then this value will be passed as ' +
                        'argument "mem" of "sbatch". An empty string ' +
                        'means that "mem" will not be specified. ' +
                        'Default: %(default)s.')
    parser.add_argument('--slurm_gres',
                        type=str,
                        default='gpu:1',
                        help='If "run_cluster" is activated and "scheduler" ' +
                        'is "slurm", then this value will be passed as ' +
                        'argument "gres" of "sbatch". An empty string ' +
                        'means that "gres" will not be specified. ' +
                        'Default: %(default)s.')
    parser.add_argument('--slurm_partition',
                        type=str,
                        default='',
                        help='If "run_cluster" is activated and "scheduler" ' +
                        'is "slurm", then this value will be passed as ' +
                        'argument "partition" of "sbatch". An empty ' +
                        'string means that "partition" will not be ' +
                        'specified. Default: %(default)s.')
    parser.add_argument('--slurm_qos',
                        type=str,
                        default='',
                        help='If "run_cluster" is activated and "scheduler" ' +
                        'is "slurm", then this value will be passed as ' +
                        'argument "qos" of "sbatch". An empty string ' +
                        'means that "qos" will not be specified. ' +
                        'Default: %(default)s.')
    parser.add_argument('--slurm_constraint',
                        type=str,
                        default='',
                        help='If "run_cluster" is activated and "scheduler" ' +
                        'is "slurm", then this value will be passed as ' +
                        'argument "constraint" of "sbatch". An empty ' +
                        'string means that "constraint" will not be ' +
                        'specified. Default: %(default)s.')
    parser.add_argument('--visible_gpus',
                        type=str,
                        default='',
                        help='If "run_cluster" is NOT activated, then this ' +
                        'option determines the CUDA devices visible to ' +
                        'the hyperparameter search. A string of comma ' +
                        'separated integers is expected. If the list is ' +
                        'empty, then all GPUs of the machine are used. ' +
                        'If "-1" is given, ' +
                        'the jobs will be executed sequentially and not ' +
                        'assigned to a particular GPU. ' +
                        'Default: %(default)s.')
    parser.add_argument('--allowed_load',
                        type=float,
                        default=0.5,
                        help='If "run_cluster" is NOT activated, then this ' +
                        'option determines the maximum load a GPU may ' +
                        'have such that another process may start on ' +
                        'it. The relative load is specified, i.e., a ' +
                        'number between 0 and 1. Default: %(default)s.')
    parser.add_argument('--allowed_memory',
                        type=float,
                        default=0.5,
                        help='If "run_cluster" is NOT activated, then this ' +
                        'option determines the maximum memory usage a ' +
                        'GPU may have such that another process may ' +
                        'start on it. Default: %(default)s.')
    parser.add_argument('--sim_startup_time',
                        type=int,
                        metavar='N',
                        default=60,
                        help='If "run_cluster" is NOT activated, then this ' +
                        'option determines the startup time of ' +
                        'simulations. If a job was assigned to a GPU, ' +
                        'then this time (in seconds) has to pass before ' +
                        'options "allowed_load" and "allowed_memory" ' +
                        'are checked to decide whether a new process ' +
                        'can be sent to a GPU. Default: %(default)s.')
    parser.add_argument('--max_num_jobs_per_gpu',
                        type=int,
                        metavar='N',
                        default=1,
                        help='If "run_cluster" is NOT activated, then this ' +
                        'option determines the maximum number of jobs ' +
                        'per GPU that can be submitted in parallel. ' +
                        'Note, this script does not validate whether ' +
                        'other processes are already assigned to a GPU. ' +
                        'Default: %(default)s.')
    cmd_args = parser.parse_args()
    out_dir = cmd_args.out_dir

    if cmd_args.out_dir == '' and cmd_args.config_pickle != '':
        with open(cmd_args.config_pickle, "rb") as f:
            config = pickle.load(f)

    # Either a config or an experiment folder need to be provided.
    assert config is not None or cmd_args.out_dir != ''
    if cmd_args.out_dir == '':
        out_dir = config.out_dir

    # Make sure that the provided hpsearch config file name does not exist.
    config_name = cmd_args.config_name
    if config_name[-3:] != '.py':
        config_name = config_name + '.py'
    if os.path.exists(config_name):
        overwrite = input('The config file "%s" ' % config_name + \
            'already exists! Do you want to overwrite the file? [y/n] ')
        if overwrite not in ['yes', 'y', 'Y']:
            exit()

    # The following ensures that we can safely use `basename` later on.
    out_dir = os.path.normpath(out_dir)

    ### Create directory for results.
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)
    # Define a subfolder for the current random seed runs.
    results_dir = os.path.join(results_dir, os.path.basename(out_dir))
    print('Random seeds will be gathered in folder %s.' % results_dir)

    if os.path.exists(results_dir):

        # If random seeds have been gathered already, simply get the results for
        # publication.
        write_seeds_summary(results_dir)

        raise RuntimeError('Output directory %s already exists! ' % results_dir + \
            'It seems like random seeds have already been gathered.')

    ### Get the experiments config.
    num_seeds = cmd_args.num_seeds
    if config is None:

        # Check if the current directory corresponds to a single run or not.
        # FIXME quick and dirty solution to figure out whether it's a single
        # run.
        single_run = False
        if not os.path.exists(os.path.join(out_dir, 'search_results.csv')) \
                and not os.path.exists(os.path.join(out_dir, \
                    'postprocessing_results.csv')):
            single_run = True

        # Get the configuration.
        if single_run:
            config = get_single_run_config(out_dir)
            best_out_dir = out_dir
        else:
            config, best_out_dir = get_hpsearch_config(out_dir)

        # Since we already have a reference run, we can run one seed less.
        num_seeds -= 1

    if cmd_args.seeds_list is not None:
        seeds_list = misc.str_to_ints(cmd_args.seeds_list)
        cmd_args.num_seeds = len(seeds_list)
    else:
        seeds_list = list(range(num_seeds))

    # Replace config values provided via `forced_params`.
    if len(forced_params.keys()) > 0:
        for kwd, value in forced_params.items():
            setattr(config, kwd, value)

    ### Write down the hp search grid module in its own file.
    # Split the module name into its package path and basename (i.e., the part
    # after the last dot).
    last_dot_idx = ref_module.rfind('.')
    ref_module_basename = ref_module[last_dot_idx+1:]
    ref_module_path = ref_module[:last_dot_idx+1]
    shutil.copy(ref_module_basename + '.py', config_name)

    # Define the kwds to be added to the grid.
    kwds = list(vars(config).keys())
    for kwd in ignore_kwds:
        if kwd in kwds:
            kwds.remove(kwd)

    # Remove old grid and write new grid, and remove conditions.
    grid_loc = delete_object_from_text(config_name, 'grid', '{', '}')
    random_seeds = write_new_grid_to_text(config_name, config, grid_loc, \
        seeds_list, cmd_args, kwds=kwds)
    cond_loc = delete_object_from_text(config_name, 'conditions', \
        '[', ']')
    write_new_conditions_to_text(config_name, cond_loc, random_seeds, cmd_args)

    ### Run the hpsearch code with different random seeds.
    hpsearch_module = ref_module_path + config_name[:-3]
    cmd_str = get_command_line(hpsearch_module, results_dir, cmd_args)
    print(cmd_str)

    if cmd_args.run_cluster and cmd_args.scheduler == 'slurm':
        # FIXME hacky solution to write SLURM job script.
        # FIXME might be wrong to give the same `slurm_qos` to the hpsearch,
        # as the job might have to run much longer.
        job_script_fn = hpsearch._write_slurm_script(
            Namespace(
                **{
                    'num_hours': cmd_args.num_tot_hours,
                    'slurm_mem': '8G',
                    'slurm_gres': '',
                    'slurm_partition': cmd_args.slurm_partition,
                    'slurm_qos': cmd_args.slurm_qos,
                    'slurm_constraint': cmd_args.slurm_constraint,
                }), cmd_str, 'random_seeds')

        cmd_str = 'sbatch %s' % job_script_fn
        print('We will execute command "%s".' % cmd_str)

    # Execute the program.
    print('Starting gathering random seeds...')
    ret = call(cmd_str, shell=True, executable='/bin/bash')
    print('Call finished with return code %d.' % ret)

    ### Add results of the reference run to our results folder.
    new_best_out_dir = os.path.join(results_dir, os.path.basename(out_dir))
    copy_tree(best_out_dir, new_best_out_dir)

    ### Store results of given run in CSV file.
    # FIXME Extremely ugly solution.
    imported_grid_module = importlib.import_module(hpsearch_module)
    hpsearch._read_config(imported_grid_module)

    results_file = os.path.join(results_dir, 'search_results.csv')
    cmd_dict = dict()
    for k in kwds:
        cmd_dict[k] = getattr(config, k)

    # Get training results.
    performance_dict = hpsearch._SUMMARY_PARSER_HANDLE(new_best_out_dir, -1)
    for k, v in performance_dict.items():
        cmd_dict[k] = v

    # Create or update the CSV file summarizing all runs.
    panda_frame = pd.DataFrame.from_dict(cmd_dict)
    if os.path.isfile(results_file):
        old_frame = pd.read_csv(results_file, sep=';')
        panda_frame = pd.concat([old_frame, panda_frame], sort=True)
    panda_frame.to_csv(results_file, sep=';', index=False)

    # Create a text file aggregating all results for publication.
    write_seeds_summary(results_dir)
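A hedged sketch of how a project-specific launcher could invoke the helper above; the module name and forced hyperparameters are placeholders, not taken from the original code:

if __name__ == '__main__':
    # Reference module and forced values are placeholders for illustration.
    run('my_project.hpsearch_config',
        results_dir='./out/random_seeds',
        ignore_kwds=['out_dir'],
        forced_params={'num_tasks': 5})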
Example no. 10
def generate_replay_networks(config,
                             data_handlers,
                             device,
                             create_rp_hnet=True,
                             only_train_replay=False):
    """Create a replay model that consists of either a encoder/decoder or
    a discriminator/generator pair. Additionally, this method manages the 
    creation of a hypernetwork for the generator/decoder. 
    Following important configurations will be determined in order to create
    the replay model: 
    * in- and output and hidden layer dimensions of the encoder/decoder. 
    * architecture, chunk- and task-embedding details of decoder's hypernetwork. 

    .. note::
        This module also handles the initialisation of the weights of either
        the replay networks or the decoder's hypernetwork. This will change in
        the near future.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network. And to
            infer the number of tasks.
        device: Torch device.
        create_rp_hnet: Whether a hypernetwork for the replay should be 
            constructed. If not, the decoder/generator will have 
            trainable weights on its own.
        only_train_replay: We normally do not train the replay model on the
            last task, since we do not need to replay the last task's data. But
            if we want a replay method to be able to generate data from all
            tasks, this option can be set to ``True``.

    Returns:
        (tuple): Tuple containing:

        - **enc**: Encoder/discriminator network instance.
        - **dec**: Decoder/generator network instance.
        - **dec_hnet**: Hypernetwork instance for the decoder/generator. This 
            return value is None if no hypernetwork should be constructed.
    """

    if config.replay_method == 'gan':
        n_out = 1
    else:
        n_out = config.latent_dim * 2

    n_in = data_handlers[0].in_shape[0]
    pd = config.padding * 2
    if config.experiment == "splitMNIST":
        n_in = n_in * n_in
    else:  # permutedMNIST
        n_in = (n_in + pd) * (n_in + pd)

    config.input_dim = n_in
    if config.experiment == "splitMNIST":
        if config.single_class_replay:
            config.out_dim = 1
        else:
            config.out_dim = 2
    else:  # permutedMNIST
        config.out_dim = 10

    if config.infer_task_id:
        # task inference network
        config.out_dim = 1

    # build encoder
    print('For the replay encoder/discriminator: ')
    enc_arch = misc.str_to_ints(config.enc_fc_arch)
    enc = MLP(n_in=n_in,
              n_out=n_out,
              hidden_layers=enc_arch,
              activation_fn=misc.str_to_act(config.enc_net_act),
              dropout_rate=config.enc_dropout_rate,
              no_weights=False).to(device)
    print('Constructed MLP with shapes: ', enc.param_shapes)
    init_params = list(enc.parameters())
    # build decoder
    print('For the replay decoder/generator: ')
    dec_arch = misc.str_to_ints(config.dec_fc_arch)
    # add dimensions for conditional input
    n_out = config.latent_dim

    if config.conditional_replay:
        n_out += config.conditional_dim

    dec = MLP(n_in=n_out,
              n_out=n_in,
              hidden_layers=dec_arch,
              activation_fn=misc.str_to_act(config.dec_net_act),
              use_bias=True,
              no_weights=config.rp_beta > 0,
              dropout_rate=config.dec_dropout_rate).to(device)

    print('Constructed MLP with shapes: ', dec.param_shapes)
    config.num_weights_enc = \
                        MainNetInterface.shapes_to_num_weights(enc.param_shapes)

    config.num_weights_dec = \
                        MainNetInterface.shapes_to_num_weights(dec.param_shapes)
    config.num_weights_rp_net = config.num_weights_enc + config.num_weights_dec
    # we do not need a replay model for the last task

    # train on last task or not
    if only_train_replay:
        subtr = 0
    else:
        subtr = 1

    num_embeddings = config.num_tasks - subtr if config.num_tasks > 1 else 1

    if config.single_class_replay:
        # we do not need a replay model for the last task
        if config.num_tasks > 1:
            num_embeddings = config.out_dim * (config.num_tasks - subtr)
        else:
            num_embeddings = config.out_dim * (config.num_tasks)

    config.num_embeddings = num_embeddings
    # build decoder hnet
    if create_rp_hnet:
        print('For the decoder/generator hypernetwork: ')
        d_hnet = sim_utils.get_hnet_model(config,
                                          num_embeddings,
                                          device,
                                          dec.hyper_shapes_learned,
                                          cprefix='rp_')

        init_params += list(d_hnet.parameters())

        config.num_weights_rp_hyper_net = sum(p.numel()
                                              for p in d_hnet.parameters()
                                              if p.requires_grad)
        config.compression_ratio_rp = config.num_weights_rp_hyper_net / \
                                                        config.num_weights_dec
        print('Created replay hypernetwork with ratio: ',
              config.compression_ratio_rp)
        if config.compression_ratio_rp > 1:
            print('Note that the compression ratio is computed compared to ' +
                  'the current target network,\nwhich might not be directly ' +
                  'comparable with the number of parameters of methods we ' +
                  'compare against.')
    else:
        num_embeddings = config.num_tasks - subtr
        d_hnet = None
        init_params += list(dec.parameters())
        config.num_weights_rp_hyper_net = 0
        config.compression_ratio_rp = 0

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if create_rp_hnet:
        for temb in d_hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

    if hasattr(d_hnet, 'chunk_embeddings'):
        for emb in d_hnet.chunk_embeddings:
            torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    return enc, dec, d_hnet
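A hedged usage sketch for the helper above; `config`, `dhandlers` and `device` are assumed to be produced by the experiment's CLI parser and data utilities, as in the surrounding code base:

enc, dec, d_hnet = generate_replay_networks(config, dhandlers, device,
                                            create_rp_hnet=True,
                                            only_train_replay=False)
print('Number of replay-model weights:', config.num_weights_rp_net)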
Example no. 11
def generate_gauss_networks(config,
                            logger,
                            data_handlers,
                            device,
                            no_mnet_weights=None,
                            create_hnet=True,
                            in_shape=None,
                            out_shape=None,
                            net_type='mlp',
                            non_gaussian=False):
    """Create main network and potentially the corresponding hypernetwork.

    The function will first create a normal MLP and then convert it into a
    network with Gaussian weight distribution by using the wrapper
    :class:`probabilistic.gauss_mnet_interface.GaussianBNNWrapper`.

    This function also takes care of weight initialization.

    Args:
        config: Command-line arguments.
        logger: Console (and file) logger.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network. And to
            infer the number of tasks.
        device: Torch device.
        no_mnet_weights (bool, optional): Whether the main network should not
            have trainable weights. If left unspecified, then the main network
            will only have trainable weights if ``create_hnet`` is ``False``.
        create_hnet (bool): Whether a hypernetwork should be constructed.
        in_shape (list, optional): Input shape that is passed to function
            :func:`utils.sim_utils.get_mnet_model` as argument ``in_shape``.
            If not specified, it is set to ``[data_handlers[0].in_shape[0]]``.
        out_shape (list, optional): Output shape that is passed to function
            :func:`utils.sim_utils.get_mnet_model` as argument ``out_shape``.
            If not specified, it is set to ``[data_handlers[0].out_shape[0]]``.
        net_type (str): See argument ``net_type`` of function
            :func:`utils.sim_utils.get_mnet_model`.
        non_gaussian (bool): If ``True``, then the main network will not be
            converted into a Gaussian network. Hence, networks remain
            deterministic.

    Returns:
        (tuple): Tuple containing:

        - **mnet**: Main network instance.
        - **hnet** (optional): Hypernetwork instance. This return value is
          ``None`` if no hypernetwork should be constructed.
    """
    assert not hasattr(config, 'mean_only') or config.mean_only == non_gaussian
    assert not non_gaussian or not config.local_reparam_trick
    assert not non_gaussian or not config.hyper_gauss_init

    num_tasks = len(data_handlers)

    # Should be set, except for regression.
    if in_shape is None or out_shape is None:
        assert in_shape is None and out_shape is None
        assert net_type == 'mlp'
        assert hasattr(config, 'multi_head')

        n_x = data_handlers[0].in_shape[0]
        n_y = data_handlers[0].out_shape[0]
        if config.multi_head:
            n_y = n_y * num_tasks

        in_shape = [n_x]
        out_shape = [n_y]

    ### Main network.
    logger.info('Creating main network ...')

    if no_mnet_weights is None:
        no_mnet_weights = create_hnet

    if config.local_reparam_trick:
        if net_type != 'mlp':
            raise NotImplementedError('The local reparametrization trick is ' +
                                      'only implemented for MLPs so far!')
        assert len(in_shape) == 1 and len(out_shape) == 1

        mlp_arch = utils.str_to_ints(config.mlp_arch)
        net_act = utils.str_to_act(config.net_act)
        mnet = GaussianMLP(n_in=in_shape[0],
                           n_out=out_shape[0],
                           hidden_layers=mlp_arch,
                           activation_fn=net_act,
                           use_bias=not config.no_bias,
                           no_weights=no_mnet_weights).to(device)
    else:
        mnet_kwargs = {}
        if net_type == 'iresnet':
            mnet_kwargs['cutout_mod'] = True
        mnet = sutils.get_mnet_model(config,
                                     net_type,
                                     in_shape,
                                     out_shape,
                                     device,
                                     no_weights=no_mnet_weights,
                                     **mnet_kwargs)

    # Initialize main net weights, if any.
    assert (not hasattr(config, 'custom_network_init'))
    mnet.custom_init(normal_init=config.normal_init,
                     normal_std=config.std_normal_init,
                     zero_bias=True)

    # Convert main net into Gaussian BNN.
    orig_mnet = mnet
    if not non_gaussian:
        mnet = GaussianBNNWrapper(mnet,
                                  no_mean_reinit=config.keep_orig_init,
                                  logvar_encoding=config.use_logvar_enc,
                                  apply_rho_offset=True,
                                  is_radial=config.radial_bnn).to(device)
    else:
        logger.debug('Created main network will not be converted into a ' +
                     'Gaussian main network.')

    ### Hypernet.
    hnet = None
    if create_hnet:
        logger.info('Creating hypernetwork ...')

        chunk_shapes, num_per_chunk, assembly_fct = None, None, None
        if config.hnet_type == 'structured_hmlp':
            if net_type == 'resnet':
                chunk_shapes, num_per_chunk, orig_assembly_fct = \
                    resnet_chunking(orig_mnet,
                                    gcd_chunking=config.shmlp_gcd_chunking)
            elif net_type == 'wrn':
                chunk_shapes, num_per_chunk, orig_assembly_fct = \
                    wrn_chunking(orig_mnet,
                        gcd_chunking=config.shmlp_gcd_chunking,
                        ignore_bn_weights=False, ignore_out_weights=False)
            else:
                raise NotImplementedError(
                    '"structured_hmlp" not implemented ' +
                    'for network of type %s.' % net_type)

            if non_gaussian:
                assembly_fct = orig_assembly_fct
            else:
                chunk_shapes = chunk_shapes + chunk_shapes
                num_per_chunk = num_per_chunk + num_per_chunk

                def assembly_fct_gauss(list_of_chunks):
                    n = len(list_of_chunks)
                    mean_chunks = list_of_chunks[:n // 2]
                    rho_chunks = list_of_chunks[n // 2:]

                    return orig_assembly_fct(mean_chunks) + \
                        orig_assembly_fct(rho_chunks)

                assembly_fct = assembly_fct_gauss

        # For now, we either produce all or no weights with the hypernet.
        # Note, it can be that the mnet was produced with internal weights.
        assert mnet.hyper_shapes_learned is None or \
            len(mnet.param_shapes) == len(mnet.hyper_shapes_learned)

        hnet = sutils.get_hypernet(config,
                                   device,
                                   config.hnet_type,
                                   mnet.param_shapes,
                                   num_tasks,
                                   shmlp_chunk_shapes=chunk_shapes,
                                   shmlp_num_per_chunk=num_per_chunk,
                                   shmlp_assembly_fct=assembly_fct)

        if config.hnet_out_masking != 0:
            logger.info('Generating binary masks to select task-specific ' +
                        'subnetworks from hypernetwork.')
            # Add a wrapper around the hypernetwork that masks its outputs
            # using a task-specific binary mask per layer. Note that output
            # weights are not masked.

            # Ensure that masks are kind of deterministic for a given hyper-
            # param config/task.
            mask_gen = torch.Generator()
            mask_gen = mask_gen.manual_seed(42)

            # Generate a random binary mask per task.
            assert len(mnet.param_shapes) == len(hnet.target_shapes)
            hnet_out_masks = []
            for tid in range(config.num_tasks):
                hnet_out_mask = []
                for layer_shapes, is_output in zip(mnet.param_shapes, \
                        mnet.get_output_weight_mask()):
                    layer_mask = torch.ones(layer_shapes)
                    if is_output is None:
                        # We only mask weights that are not output weights.
                        layer_mask = torch.rand(layer_shapes,
                                                generator=mask_gen)
                        layer_mask[layer_mask > config.hnet_out_masking] = 1
                        layer_mask[layer_mask <= config.hnet_out_masking] = 0
                    hnet_out_mask.append(layer_mask)
                hnet_out_masks.append(hnet_out_mask)

            hnet_out_masks = hnet.convert_out_format(hnet_out_masks,
                                                     'sequential', 'flattened')

            def hnet_out_masking_func(hnet_out_int,
                                      uncond_input=None,
                                      cond_input=None,
                                      cond_id=None):
                assert isinstance(cond_id, (int, list))
                if isinstance(cond_id, int):
                    cond_id = [cond_id]

                hnet_out_int[hnet_out_masks[cond_id, :] == 0] = 0
                return hnet_out_int

            def hnet_inp_handler(uncond_input=None,
                                 cond_input=None,
                                 cond_id=None):  # Identity
                return uncond_input, cond_input, cond_id

            hnet = HPerturbWrapper(hnet,
                                   output_handler=hnet_out_masking_func,
                                   input_handler=hnet_inp_handler)

        #if config.hnet_type == 'structured_hmlp':
        #    print(num_per_chunk)
        #    for ii, int_hnet in enumerate(hnet.internal_hnets):
        #        print('   Internal hnet %d with %d outputs.' % \
        #              (ii, int_hnet.num_outputs))

        ### Initialize hypernetwork.
        if not config.hyper_gauss_init:
            apply_custom_hnet_init(config, logger, hnet)
        else:
            # Initialize task embeddings, if any.
            hnet_helpers.init_conditional_embeddings(
                hnet, normal_std=config.std_normal_temb)

            gauss_hyperfan_init(hnet,
                                mnet=mnet,
                                use_xavier=True,
                                cond_var=config.std_normal_temb**2,
                                keep_hyperfan_mean=config.keep_orig_init)

    return mnet, hnet
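A hedged usage sketch for the helper above; the objects `config`, `logger`, `dhandlers` and `device` are assumed to exist in the calling training script:

mnet, hnet = generate_gauss_networks(config, logger, dhandlers, device,
                                     create_hnet=True, net_type='mlp',
                                     non_gaussian=False)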
Example no. 12
def run(grid_module=None,
        results_dir='./out/random_seeds',
        config=None,
        ignore_kwds=None,
        forced_params=None,
        summary_keys=None,
        summary_sem=False,
        summary_precs=None,
        hpmod_path=None):
    """Run the script.

    Args:
        grid_module (str, optional): Name of the reference module which contains
            the hyperparameter search config that can be modified to gather
            random seeds.
        results_dir (str, optional): The path where the hpsearch should store
            its results.
        config: The Namespace object containing argument names and values.
            If provided, all random seeds will be gathered from zero, with no
            reference run.
        ignore_kwds (list, optional): A list of keywords in the config file
            to exclude from the grid.
        forced_params (dict, optional): Dict of key-value pairs specifying
            hyperparameter values that should be fixed across runs.
        summary_keys (list, optional): If provided, those mean and std of those
            summary keys will be written by function
            :func:`write_seeds_summary`. Otherwise, the performance key defined
            in ``grid_module`` will be used.
        summary_sem (bool): Whether SEM or SD should be calculated in function
            :func:`write_seeds_summary`.
        summary_precs (list or int, optional): The precision with which the
            summary statistics according to ``summary_keys`` should be listed.
        hpmod_path (str, optional): If the hpsearch doesn't reside in the same
            directory as the calling script, then we need to know from where to
            start the hpsearch.
    """
    if ignore_kwds is None:
        ignore_kwds = []
    if forced_params is None:
        forced_params = {}

    ### Parse the command-line arguments.
    parser = argparse.ArgumentParser(description= \
        'Gathering random seeds for the specified experiment.')
    parser.add_argument('--seeds_dir',
                        type=str,
                        default='',
                        help='If provided, all other arguments (except ' +
                        '"grid_module") are ignored! ' +
                        'This is supposed to be the output folder of a ' +
                        'random seed gathering experiment. If provided, ' +
                        'the results (for different seeds) within this ' +
                        'directory are gathered and written to a human-' +
                        'readable text file.')
    parser.add_argument('--run_dir',
                        type=str,
                        default='',
                        help='The output directory of a simulation or a ' +
                        'hyperparameter search. '
                        'For single runs, the configuration will be ' +
                        'loaded and run with different seeds. ' +
                        'For multiple runs, i.e. results of ' +
                        'hyperparameter searches, the configuration ' +
                        'leading to the best performance will be ' +
                        'selected and run with different seeds.')
    parser.add_argument('--config_name',
                        type=str,
                        default='hpsearch_random_seeds',
                        help='A name for this call of gathering random ' +
                        'seeds. As multiple gatherings might be running ' +
                        'in parallel, it is important that this name is ' +
                        'unique for each experiment. ' +
                        'Default: %(default)s.')
    parser.add_argument('--grid_module', type=str, default=grid_module,
                        help='See CLI argument "grid_module" of ' +
                             'hyperparameter search script "hpsearch". ' +
                             ('Default: %(default)s.' \
                              if grid_module is not None else ''))
    parser.add_argument('--num_seeds',
                        type=int,
                        default=10,
                        help='The number of different random seeds.')
    parser.add_argument('--seeds_list',
                        type=str,
                        default='',
                        help='The list of seeds to use. If specified, ' +
                        '"num_seeds" will be ignored.')
    parser.add_argument('--vary_data_seed',
                        action='store_true',
                        help='If activated, "data_random_seed"s are set ' +
                        'equal to "random_seed"s. Otherwise only ' +
                        '"random_seed"s are varied.')
    parser.add_argument('--start_gathering',
                        action='store_true',
                        help='If activated, the actual gathering of random ' +
                        'seeds is started via the "hpsearch.py" script.')

    # Arguments only required if `start_gathering`.
    hpgroup = parser.add_argument_group('Hpsearch call options')
    hpgroup.add_argument('--hps_num_hours',
                         type=int,
                         metavar='N',
                         default=24,
                         help='If "run_cluster" is activated, then this ' +
                         'option determines the maximum number of hours ' +
                         'the entire search may run on the cluster. ' +
                         'Default: %(default)s.')
    hpgroup.add_argument(
        '--hps_resources',
        type=str,
        default='"rusage[mem=8000]"',
        help='If "run_cluster" is activated and "scheduler" ' +
        'is "lsf", then this option determines the ' +
        'resources assigned to the entire ' +
        'hyperparameter search (option -R of bsub). ' +
        'Default: %(default)s.')
    hpgroup.add_argument('--hps_slurm_mem',
                         type=str,
                         default='8G',
                         help='See option "slurm_mem". This argument affects ' +
                         'the hyperparameter search itself. '
                         'Default: %(default)s.')

    rsgroup = parser.add_argument_group('Random seed hpsearch options')
    hpsearch.hpsearch_cli_arguments(rsgroup,
                                    show_out_dir=False,
                                    show_grid_module=False)
    cmd_args = parser.parse_args()

    grid_module = cmd_args.grid_module
    if grid_module is None:
        raise ValueError('"grid_module" needs to be specified.')
    grid_module = importlib.import_module(grid_module)
    hpsearch._read_config(grid_module, require_perf_eval_handle=True)

    if summary_keys is None:
        summary_keys = [hpsearch._PERFORMANCE_KEY]

    ####################################################
    ### Aggregate results of random seed experiments ###
    ####################################################
    if len(cmd_args.seeds_dir):
        print('Writing seed summary ...')
        write_seeds_summary(cmd_args.seeds_dir, summary_keys, summary_sem,
                            summary_precs)
        exit(0)

    #######################################################
    ### Create hp config grid for random seed gathering ###
    #######################################################
    if len(cmd_args.seeds_list) > 0:
        seeds_list = misc.str_to_ints(cmd_args.seeds_list)
        cmd_args.num_seeds = len(seeds_list)
    else:
        seeds_list = list(range(cmd_args.num_seeds))

    if config is not None and cmd_args.run_dir != '':
        raise ValueError('"run_dir" may not be specified if configuration ' +
                         'is provided directly.')

    # The directory in which the hpsearch results should be written. Will only
    # be specified if the `config` is read from a finished simulation.
    hpsearch_dir = None
    # Get config if not provided.
    if config is None:
        if not os.path.exists(cmd_args.run_dir):
            raise_error = True
            # FIXME hacky solution.
            if cmd_args.run_cwd != '':
                tmp_dir = os.path.join(cmd_args.run_cwd, cmd_args.run_dir)
                if os.path.exists(tmp_dir):
                    cmd_args.run_dir = tmp_dir
                    raise_error = False
            if raise_error:
                raise ValueError('Directory "%s" does not exist!' % \
                                 cmd_args.run_dir)

        # FIXME A bit of a shady decision.
        single_run = False
        if os.path.exists(os.path.join(cmd_args.run_dir, 'config.pickle')):
            single_run = True

        # Get the configuration.
        if single_run:
            config = get_single_run_config(cmd_args.run_dir)
            run_dir = cmd_args.run_dir
        else:
            config, run_dir = get_best_hpsearch_config(cmd_args.run_dir)

        # We should already have one random seed.
        try:
            performance_dict = hpsearch._SUMMARY_PARSER_HANDLE(run_dir, -1)
            has_finished = int(performance_dict['finished'][0])
            if not has_finished:
                raise Exception()

            use_run = True

        except Exception:
            use_run = False

        if use_run:
            # The following ensures that we can safely use `basename` later on.
            run_dir = os.path.normpath(run_dir)

            if not os.path.isabs(results_dir):
                if os.path.isdir(cmd_args.run_cwd):
                    results_dir = os.path.join(cmd_args.run_cwd, results_dir)
            results_dir = os.path.abspath(results_dir)
            hpsearch_dir = os.path.join(results_dir, os.path.basename(run_dir))

            if os.path.exists(hpsearch_dir):
                # TODO attempt to write summary and exclude existing seeds.
                warn('Folder "%s" already exists.' % hpsearch_dir)
                print('Attempting to aggregate random seed results ...')

                gathered_seeds = write_seeds_summary(hpsearch_dir,
                                                     summary_keys,
                                                     summary_sem,
                                                     summary_precs,
                                                     ret_seeds=True)

                if len(gathered_seeds) >= len(seeds_list):
                    print('Already enough seeds have been gathered!')
                    exit(0)

                for gs in gathered_seeds:
                    if gs in seeds_list:
                        seeds_list.remove(gs)
                    else:
                        ignored_seed = seeds_list.pop()
                        if len(cmd_args.seeds_list) > 0:
                            print('Seed %d is ignored as seed %d already ' \
                                  % (ignored_seed, gs) + 'exists.')

            else:
                os.makedirs(hpsearch_dir)
                # We utilize the already existing random seed.
                shutil.copytree(
                    run_dir,
                    os.path.join(hpsearch_dir, os.path.basename(run_dir)))
                if config.random_seed in seeds_list:
                    seeds_list.remove(config.random_seed)
                else:
                    ignored_seed = seeds_list.pop()
                    if len(cmd_args.seeds_list) > 0:
                        print('Seed %d is ignored as seed %d already exists.' \
                              % (ignored_seed, config.random_seed))

    print('%d random seeds will be gathered!' % len(seeds_list))

    ### Which attributes of the `config` should be ignored?
    # We never set the output directory.
    if hpsearch._OUT_ARG not in ignore_kwds:
        ignore_kwds.append(hpsearch._OUT_ARG)

    for kwd in ignore_kwds:
        delattr(config, kwd)

    ### Replace config values provided via `forced_params`.
    if len(forced_params.keys()) > 0:
        for kwd, value in forced_params.items():
            setattr(config, kwd, value)

    ### Get a filename for where to store the search grid.
    config_dn, config_bn = os.path.split(cmd_args.config_name)
    if len(config_dn) == 0:  # No relative path given, store only temporarily.
        config_dn = tempfile.gettempdir()
    else:
        config_dn = os.path.abspath(config_dn)
    config_fn_prefix = os.path.splitext(config_bn)[0]
    config_name = os.path.join(config_dn, config_fn_prefix + '.pickle')
    if os.path.exists(config_name):
        if len(config_dn) > 0:
            overwrite = input('The config file "%s" ' % config_name + \
                'already exists! Do you want to overwrite the file? [y/n] ')
            if overwrite not in ['yes', 'y', 'Y']:
                exit(1)
        else:  # Get random temporary filename.
            config_name_temp = tempfile.NamedTemporaryFile( \
                prefix=config_fn_prefix, suffix=".pickle")
            print('Search grid "%s" already exists, using name "%s" instead!' \
                  % (config_name, config_name_temp.name))
            config_name = config_name_temp.name
            config_name_temp.close()

    ### Build and store hpconfig for random seed gathering!
    grid, conditions = build_grid_and_conditions(cmd_args, config, seeds_list)

    rseed_config = {'grid': grid, 'conditions': conditions}
    with open(config_name, 'wb') as f:
        pickle.dump(rseed_config, f)

    ### Gather random seeds.
    if cmd_args.start_gathering:

        cmd_str = get_hpsearch_call(cmd_args,
                                    len(seeds_list),
                                    config_name,
                                    hpsearch_dir=hpsearch_dir)
        print(cmd_str)

        ### Start hpsearch.
        if hpmod_path is not None:
            backup_curr_path = os.getcwd()
            os.chdir(hpmod_path)
        if cmd_args.run_cluster and cmd_args.scheduler == 'slurm':
            # FIXME hacky solution to write SLURM job script.
            # FIXME might be wrong to give the same `slurm_qos` to the hpsearch,
            # as the job might have to run much longer.
            job_script_fn = hpsearch._write_slurm_script(
                Namespace(
                    **{
                        'num_hours': cmd_args.hps_num_hours,
                        'slurm_mem': cmd_args.hps_slurm_mem,
                        'slurm_gres': 'gpu:0',
                        'slurm_partition': cmd_args.slurm_partition,
                        'slurm_qos': cmd_args.slurm_qos,
                        'slurm_constraint': cmd_args.slurm_constraint,
                    }), cmd_str, 'random_seeds')

            cmd_str = 'sbatch %s' % job_script_fn
            print('We will execute command "%s".' % cmd_str)

        # Execute the program.
        print('Starting gathering random seeds...')
        ret = call(cmd_str, shell=True, executable='/bin/bash')
        print('Call finished with return code %d.' % ret)
        if hpmod_path is not None:
            os.chdir(backup_curr_path)

        # If we run the hpsearch on the cluster, then we just submitted a job
        # and the search didn't actually run yet.
        if not cmd_args.run_cluster and hpsearch_dir is not None:
            write_seeds_summary(hpsearch_dir, summary_keys, summary_sem,
                                summary_precs)

        print('Random seed gathering finished successfully!')
        exit(0)

    ### Random seeds not gathered yet - finalize program.
    print(hpsearch_dir is None)
    if hpsearch_dir is not None:
        print('IMPORTANT: At least one random seed has already been ' + \
              'gathered! Please ensure that the hpsearch forces the correct ' +
              'output path.')

    print('Below is a possible hpsearch call:')
    call_appendix = ''
    if hpsearch_dir is not None:
        call_appendix = '--force_out_dir --dont_force_new_dir ' + \
            '--out_dir=%s' % hpsearch_dir
    print()
    print('python3 hpsearch.py --grid_module=%s --grid_config=%s %s' % \
          (cmd_args.grid_module, config_name, call_appendix))
    print()

    # We print the individual paths to allow easy parsing via `awk` and `xargs`.
    if hpsearch_dir is None:
        print('Below is the "grid_module" name and the path to the ' +
              '"grid_config".')
        print(cmd_args.grid_module, config_name)
    else:
        print(
            'Below is the "grid_module" name, the path to the ' +
            '"grid_config" and the output path that should be used for the ' +
            'hpsearch.')
        print(cmd_args.grid_module, config_name, hpsearch_dir)
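A hedged sketch of a wrapper script around the helper above; the grid module and summary key below are placeholders, not taken from the original code:

if __name__ == '__main__':
    run(grid_module='my_project.hpsearch_config',
        results_dir='./out/random_seeds',
        summary_keys=['acc_final'],
        summary_sem=True)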