Example #1
def get_main_model(config, shared, logger, device, no_weights=False):
    """Helper function to generate the main network.

    This function uses :func:`utils.sim_utils.get_mnet_model` to generate the
    main network.

    The function also takes care of weight initialization, if configured.

    Args:
        (....): See docstring of function :func:`load_datasets`.
        device: The PyTorch device.
        no_weights (bool): If ``True``, the main network is generated without
            internal weights.

    Returns:
        The main network.
    """
    net_type = 'mlp'
    logger.info('Building an MLP ...')

    if config.cl_scenario == 1 and 'zixuan' in config.note and \
            ('mixemnist' in config.note or 'mixceleba' in config.note):
        # Both mixed datasets derive the output size from the per-task
        # dimensions.
        num_outputs = sum(config.dims)
    else:
        raise ValueError('Cannot infer the number of output neurons for ' +
                         'this configuration.')
    print('config.cl_scenario: ', config.cl_scenario)
    logger.info('The network will have %d output neurons.' % num_outputs)

    if 'mixceleba' in config.note:
        in_shape = [32 * 32 * 3]
    elif 'mixemnist' in config.note:
        in_shape = [28 * 28 * 1]

    out_shape = [num_outputs]

    print('in_shape: ', in_shape)
    print('out_shape: ', out_shape)

    # TODO Allow main net only training.
    mnet = sutils.get_mnet_model(config,
                                 net_type,
                                 in_shape,
                                 out_shape,
                                 device,
                                 no_weights=no_weights)

    init_network_weights(mnet.weights, config, logger, net=mnet)

    return mnet
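A minimal sketch of how this helper might be called, assuming a plain namespace stands in for the parsed command-line config; the field values, as well as the `shared`, `logger`, and `device` placeholders, are illustrative only:

from types import SimpleNamespace
import logging

# Hypothetical config; only the fields read by get_main_model are set.
config = SimpleNamespace(cl_scenario=1,
                         note='zixuan-mixemnist',
                         dims=[62, 62])   # per-task output dimensions
logger = logging.getLogger('example')

# With this config the function would build an MLP with
# in_shape == [28 * 28 * 1] and sum(config.dims) == 124 output neurons.
# mnet = get_main_model(config, shared, logger, device)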
Example #2
def get_main_model(config, shared, logger, device, no_weights=False):
    """Helper function to generate the main network.

    This function uses :func:`utils.sim_utils.get_mnet_model` to generate the
    main network.

    The function also takes care of weight initialization, if configured.

    Args:
        (....): See docstring of function :func:`load_datasets`.
        device: The PyTorch device.
        no_weights (bool): If ``True``, the main network is generated without
            internal weights.

    Returns:
        The main network.
    """
    if shared.experiment == 'zenke':
        net_type = 'zenke'
        logger.info('Building a ZenkeNet ...')

    else:
        net_type = 'resnet'
        logger.info('Building a ResNet ...')

    num_outputs = 10

    if config.cl_scenario == 1 or config.cl_scenario == 3:
        num_outputs *= config.num_tasks

    logger.info('The network will have %d output neurons.' % num_outputs)

    in_shape = [32, 32, 3]
    out_shape = [num_outputs]

    # TODO Allow main net only training.
    mnet = sutils.get_mnet_model(config,
                                 net_type,
                                 in_shape,
                                 out_shape,
                                 device,
                                 no_weights=no_weights)

    init_network_weights(mnet.weights, config, logger, net=mnet)

    return mnet
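The sizing logic above follows the usual multi-head convention for the CL1/CL3 scenarios: 10 classes per task, so the shared output layer grows linearly with the number of tasks. A quick arithmetic check with illustrative values:

# Illustrative check of the CL1/CL3 output sizing above.
cl_scenario = 1     # hypothetical config.cl_scenario
num_tasks = 6       # hypothetical config.num_tasks
num_outputs = 10
if cl_scenario == 1 or cl_scenario == 3:
    num_outputs *= num_tasks
assert num_outputs == 60  # one 10-neuron head per task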
Example #3
def get_main_model(config, shared, logger, device, no_weights=False):
    """Helper function to generate the main network.

    This function uses :func:`utils.sim_utils.get_mnet_model` to generate the
    main network.

    The function also takes care of weight initialization, if configured.

    Args:
        (....): See docstring of function :func:`load_datasets`.
        device: The PyTorch device.
        no_weights (bool): If ``True``, the main network is generated without
            internal weights.

    Returns:
        The main network.
    """
    if shared.experiment == 'zenke':
        net_type = 'zenke'
        logger.info('Building a ZenkeNet ...')

    elif shared.experiment == 'resnet':
        net_type = 'resnet'
        logger.info('Building a ResNet ...')

    elif shared.experiment == 'mlp':
        net_type = 'mlp'
        logger.info('Building an MLP ...')

    num_outputs = 10

    if config.cl_scenario == 1 or config.cl_scenario == 3:
        num_outputs *= config.num_tasks
        if 'zixuan' in config.note:
            if 'sep-emnist' in config.note and config.classptask == 5:
                num_outputs = config.classptask * config.ntasks

            elif 'sep-emnist' in config.note and config.classptask == 2:
                num_outputs = 7 * config.ntasks

            elif 'sep-femnist' in config.note:
                num_outputs = 62 * config.ntasks

            elif 'mixemnist' in config.note:
                num_outputs = 62 * config.ntasks * 2

            elif 'sep-cifar100' in config.note:
                num_outputs = config.classptask * config.ntasks

            elif 'sep-celeba' in config.note:
                num_outputs = 2 * config.ntasks

            elif 'mixceleba' in config.note:
                num_outputs = config.classptask * config.ntasks * 2

    logger.info('The network will have %d output neurons.' % num_outputs)

    if 'mixceleba' in config.note:
        in_shape = [32 * 32 * 3]
    elif 'mixemnist' in config.note:
        in_shape = [28 * 28 * 1]
    else:
        raise NotImplementedError('No input shape defined for "%s".'
                                  % config.note)

    out_shape = [num_outputs]

    print('in_shape: ', in_shape)
    print('out_shape: ', out_shape)

    # TODO Allow main net only training.
    mnet = sutils.get_mnet_model(config,
                                 net_type,
                                 in_shape,
                                 out_shape,
                                 device,
                                 no_weights=no_weights)

    init_network_weights(mnet.weights, config, logger, net=mnet)

    return mnet
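Here the `note` string acts as an ad-hoc dispatch key for the head sizing. A small sketch of the head sizes the branches above imply, for a hypothetical number of tasks:

# Head sizes implied by the branches above (CL1/CL3, hypothetical ntasks).
ntasks = 10
head_sizes = {
    'zixuan-sep-femnist': 62 * ntasks,      # 62 classes per FEMNIST task
    'zixuan-mixemnist': 62 * ntasks * 2,    # the EMNIST mixture doubles it
    'zixuan-sep-celeba': 2 * ntasks,        # one binary attribute per task
}
for note, n_out in head_sizes.items():
    print('%s -> %d output neurons' % (note, n_out))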
Example #4
def generate_gauss_networks(config,
                            logger,
                            data_handlers,
                            device,
                            no_mnet_weights=None,
                            create_hnet=True,
                            in_shape=None,
                            out_shape=None,
                            net_type='mlp',
                            non_gaussian=False):
    """Create main network and potentially the corresponding hypernetwork.

    The function will first create a normal MLP and then convert it into a
    network with Gaussian weight distribution by using the wrapper
    :class:`probabilistic.gauss_mnet_interface.GaussianBNNWrapper`.

    This function also takes care of weight initialization.

    Args:
        config: Command-line arguments.
        logger: Console (and file) logger.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network. And to
            infer the number of tasks.
        device: Torch device.
        no_mnet_weights (bool, optional): Whether the main network should not
            have trainable weights. If left unspecified, then the main network
            will only have trainable weights if ``create_hnet`` is ``False``.
        create_hnet (bool): Whether a hypernetwork should be constructed.
        in_shape (list, optional): Input shape that is passed to function
            :func:`utils.sim_utils.get_mnet_model` as argument ``in_shape``.
            If not specified, it is set to ``[data_handlers[0].in_shape[0]]``.
        out_shape (list, optional): Output shape that is passed to function
            :func:`utils.sim_utils.get_mnet_model` as argument ``out_shape``.
            If not specified, it is set to ``[data_handlers[0].out_shape[0]]``.
        net_type (str): See argument ``net_type`` of function
            :func:`utils.sim_utils.get_mnet_model`.
        non_gaussian (bool): If ``True``, then the main network will not be
            converted into a Gaussian network. Hence, networks remain
            deterministic.

    Returns:
        (tuple): Tuple containing:

        - **mnet**: Main network instance.
        - **hnet** (optional): Hypernetwork instance. This return value is
          ``None`` if no hypernetwork should be constructed.
    """
    assert not hasattr(config, 'mean_only') or config.mean_only == non_gaussian
    assert not non_gaussian or not config.local_reparam_trick
    assert not non_gaussian or not config.hyper_gauss_init

    num_tasks = len(data_handlers)

    # Should be set, except for regression.
    if in_shape is None or out_shape is None:
        assert in_shape is None and out_shape is None
        assert net_type == 'mlp'
        assert hasattr(config, 'multi_head')

        n_x = data_handlers[0].in_shape[0]
        n_y = data_handlers[0].out_shape[0]
        if config.multi_head:
            n_y = n_y * num_tasks

        in_shape = [n_x]
        out_shape = [n_y]

    ### Main network.
    logger.info('Creating main network ...')

    if no_mnet_weights is None:
        no_mnet_weights = create_hnet

    if config.local_reparam_trick:
        if net_type != 'mlp':
            raise NotImplementedError('The local reparametrization trick is ' +
                                      'only implemented for MLPs so far!')
        assert len(in_shape) == 1 and len(out_shape) == 1

        mlp_arch = utils.str_to_ints(config.mlp_arch)
        net_act = utils.str_to_act(config.net_act)
        mnet = GaussianMLP(n_in=in_shape[0],
                           n_out=out_shape[0],
                           hidden_layers=mlp_arch,
                           activation_fn=net_act,
                           use_bias=not config.no_bias,
                           no_weights=no_mnet_weights).to(device)
    else:
        mnet_kwargs = {}
        if net_type == 'iresnet':
            mnet_kwargs['cutout_mod'] = True
        mnet = sutils.get_mnet_model(config,
                                     net_type,
                                     in_shape,
                                     out_shape,
                                     device,
                                     no_weights=no_mnet_weights,
                                     **mnet_kwargs)

    # Initialize main net weights, if any.
    assert not hasattr(config, 'custom_network_init')
    mnet.custom_init(normal_init=config.normal_init,
                     normal_std=config.std_normal_init,
                     zero_bias=True)

    # Convert main net into Gaussian BNN.
    orig_mnet = mnet
    if not non_gaussian:
        mnet = GaussianBNNWrapper(mnet,
                                  no_mean_reinit=config.keep_orig_init,
                                  logvar_encoding=config.use_logvar_enc,
                                  apply_rho_offset=True,
                                  is_radial=config.radial_bnn).to(device)
    else:
        logger.debug('Created main network will not be converted into a ' +
                     'Gaussian main network.')

    ### Hypernet.
    hnet = None
    if create_hnet:
        logger.info('Creating hypernetwork ...')

        chunk_shapes, num_per_chunk, assembly_fct = None, None, None
        if config.hnet_type == 'structured_hmlp':
            if net_type == 'resnet':
                chunk_shapes, num_per_chunk, orig_assembly_fct = \
                    resnet_chunking(orig_mnet,
                                    gcd_chunking=config.shmlp_gcd_chunking)
            elif net_type == 'wrn':
                chunk_shapes, num_per_chunk, orig_assembly_fct = \
                    wrn_chunking(orig_mnet,
                        gcd_chunking=config.shmlp_gcd_chunking,
                        ignore_bn_weights=False, ignore_out_weights=False)
            else:
                raise NotImplementedError(
                    '"structured_hmlp" not implemented ' +
                    'for network of type %s.' % net_type)

            if non_gaussian:
                assembly_fct = orig_assembly_fct
            else:
                # The Gaussian wrapper doubles the parameter list (means
                # followed by rho values), so the chunking must be doubled
                # accordingly.
                chunk_shapes = chunk_shapes + chunk_shapes
                num_per_chunk = num_per_chunk + num_per_chunk

                def assembly_fct_gauss(list_of_chunks):
                    n = len(list_of_chunks)
                    mean_chunks = list_of_chunks[:n // 2]
                    rho_chunks = list_of_chunks[n // 2:]

                    return orig_assembly_fct(mean_chunks) + \
                        orig_assembly_fct(rho_chunks)

                assembly_fct = assembly_fct_gauss

        # For now, we either produce all or no weights with the hypernet.
        # Note, it can be that the mnet was produced with internal weights.
        assert mnet.hyper_shapes_learned is None or \
            len(mnet.param_shapes) == len(mnet.hyper_shapes_learned)

        hnet = sutils.get_hypernet(config,
                                   device,
                                   config.hnet_type,
                                   mnet.param_shapes,
                                   num_tasks,
                                   shmlp_chunk_shapes=chunk_shapes,
                                   shmlp_num_per_chunk=num_per_chunk,
                                   shmlp_assembly_fct=assembly_fct)

        if config.hnet_out_masking != 0:
            logger.info('Generating binary masks to select task-specific ' +
                        'subnetworks from hypernetwork.')
            # Add a wrapper around the hypernetwork that masks its outputs
            # with a task-specific binary mask per layer. Note that output
            # weights are not masked.

            # Use a fixed seed so that the masks are deterministic for a
            # given hyperparameter config/task.
            mask_gen = torch.Generator()
            mask_gen = mask_gen.manual_seed(42)

            # Generate a random binary mask per task.
            assert len(mnet.param_shapes) == len(hnet.target_shapes)
            hnet_out_masks = []
            for tid in range(config.num_tasks):
                hnet_out_mask = []
                for layer_shapes, is_output in zip(mnet.param_shapes, \
                        mnet.get_output_weight_mask()):
                    layer_mask = torch.ones(layer_shapes)
                    if is_output is None:
                        # We only mask weights that are not output weights.
                        layer_mask = torch.rand(layer_shapes,
                                                generator=mask_gen)
                        layer_mask[layer_mask > config.hnet_out_masking] = 1
                        layer_mask[layer_mask <= config.hnet_out_masking] = 0
                    hnet_out_mask.append(layer_mask)
                hnet_out_masks.append(hnet_out_mask)

            hnet_out_masks = hnet.convert_out_format(hnet_out_masks,
                                                     'sequential', 'flattened')

            def hnet_out_masking_func(hnet_out_int,
                                      uncond_input=None,
                                      cond_input=None,
                                      cond_id=None):
                assert isinstance(cond_id, (int, list))
                if isinstance(cond_id, int):
                    cond_id = [cond_id]

                hnet_out_int[hnet_out_masks[cond_id, :] == 0] = 0
                return hnet_out_int

            def hnet_inp_handler(uncond_input=None,
                                 cond_input=None,
                                 cond_id=None):  # Identity
                return uncond_input, cond_input, cond_id

            hnet = HPerturbWrapper(hnet,
                                   output_handler=hnet_out_masking_func,
                                   input_handler=hnet_inp_handler)

        #if config.hnet_type == 'structured_hmlp':
        #    print(num_per_chunk)
        #    for ii, int_hnet in enumerate(hnet.internal_hnets):
        #        print('   Internal hnet %d with %d outputs.' % \
        #              (ii, int_hnet.num_outputs))

        ### Initialize hypernetwork.
        if not config.hyper_gauss_init:
            apply_custom_hnet_init(config, logger, hnet)
        else:
            # Initialize task embeddings, if any.
            hnet_helpers.init_conditional_embeddings(
                hnet, normal_std=config.std_normal_temb)

            gauss_hyperfan_init(hnet,
                                mnet=mnet,
                                use_xavier=True,
                                cond_var=config.std_normal_temb**2,
                                keep_hyperfan_mean=config.keep_orig_init)

    return mnet, hnet
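A hedged sketch of how the returned pair is typically consumed downstream: the hypernetwork generates the (Gaussian) main-network weights for a task embedding, and those weights are fed to the main network's forward pass. The call convention below follows the interfaces referenced in this function but is an assumption, not verified code:

# Hypothetical downstream usage of the returned (mnet, hnet) pair.
# mnet, hnet = generate_gauss_networks(config, logger, data_handlers, device)
# for t in range(len(data_handlers)):
#     weights = hnet.forward(cond_id=t)     # task-conditioned weight list
#     y = mnet.forward(x, weights=weights)  # main net with generated weights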