示例#1
0
 def __str__(self):
     """Return a human-readable summary of this hypernetwork.

     The summary states the total number of weights (``num_params``), the
     number of outputs (``num_outputs``), their ratio (reported as the
     compression ratio), and how many unconditional / conditional weights
     exist in total versus how many are held in ``unconditional_params`` /
     ``conditional_params`` (reported as "internally maintained").

     Returns:
         (str): The formatted summary.
     """
     num_uncond = MainNetInterface.shapes_to_num_weights(
         self.unconditional_param_shapes)
     num_cond = MainNetInterface.shapes_to_num_weights(
         self.conditional_param_shapes)

     # Either parameter list may be ``None``; each internal count must be
     # guarded by a check on the very list it reads.
     num_uncond_internal = 0
     if self.unconditional_params is not None:
         num_uncond_internal = MainNetInterface.shapes_to_num_weights(
             [p.shape for p in self.unconditional_params])
     num_cond_internal = 0
     if self.conditional_params is not None:
         num_cond_internal = MainNetInterface.shapes_to_num_weights(
             [p.shape for p in self.conditional_params])

     msg = 'Hypernetwork with %d weights and %d outputs (compression ' + \
         'ratio: %.2f).\nThe network consists of %d unconditional ' + \
         'weights (%d internally maintained) and %d conditional '+ \
         'weights (%d internally maintained).'
     return msg % (self.num_params, self.num_outputs,
                   self.num_params / self.num_outputs, num_uncond,
                   num_uncond_internal, num_cond, num_cond_internal)
示例#2
0
def generate_classifier(config, data_handlers, device):
    """Create a classifier network and, optionally, its hypernetwork.

    Depending on the experiment and method, this builds either a classifier
    for task inference or a classifier that solves the tasks themselves. If
    ``config.training_with_hnet`` is set, the classifier is created without
    weights of its own and a hypernetwork that generates them is created as
    well; otherwise the classifier owns its weights.

    The following are determined here:

    * input, output and hidden-layer dimensions of the classifier;
    * architecture, chunk- and task-embedding details of the hypernetwork
      (via ``sim_utils.get_hnet_model``).

    See :class:`mnets.mlp.MLP` for details on the network that is created as
    the classifier.

    .. note::
        This function also initialises the weights of either the classifier
        or its hypernetwork: 1-D tensors are set to zero, all other tensors
        are Xavier-uniform initialised, and the hypernetwork's task and chunk
        embeddings are then re-drawn from normal distributions. This will
        change in the near future.

    Side effects:
        Writes ``input_dim``, ``out_dim``, ``num_weights_class_net``,
        ``num_weights_class_hyper_net`` and ``compression_ratio_class`` to
        ``config``.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Only
            ``data_handlers[0].in_shape`` is used, to determine the input
            size of the classifier.
        device: Torch device.

    Returns:
        If ``config.training_with_hnet`` is falsy, only the classifier
        network **net**. Otherwise a tuple ``(net, class_hnet)`` where
        **class_hnet** is the classifier's hypernetwork.
    """
    n_in = data_handlers[0].in_shape[0]
    pd = config.padding * 2

    if config.experiment == "splitMNIST":
        n_in = n_in * n_in
        config.out_dim = 1 if config.class_incremental else 2
    else:  # permutedMNIST
        n_in = (n_in + pd) * (n_in + pd)
        config.out_dim = 10
    config.input_dim = n_in

    # A task-inference classifier has one output per task (see `n_out`).
    task_infer = config.training_task_infer or config.class_incremental
    if task_infer:
        config.out_dim = 1

    # Outside CL scenario 2 all task heads (`out_dim` units each) are built
    # up front; scenario 2 uses `out_dim` outputs only.
    if config.cl_scenario != 2:
        n_out = config.out_dim * config.num_tasks
    else:
        n_out = config.out_dim

    if task_infer:
        n_out = config.num_tasks

    # Build the classifier.
    print('For the Classifier: ')
    class_arch = misc.str_to_ints(config.class_fc_arch)
    # With a hypernetwork the classifier's weights are generated by it, so
    # the MLP must not create its own.
    no_weights = bool(config.training_with_hnet)

    net = MLP(n_in=n_in,
              n_out=n_out,
              hidden_layers=class_arch,
              activation_fn=misc.str_to_act(config.class_net_act),
              dropout_rate=config.class_dropout_rate,
              no_weights=no_weights).to(device)

    print('Constructed MLP with shapes: ', net.param_shapes)

    config.num_weights_class_net = \
        MainNetInterface.shapes_to_num_weights(net.param_shapes)

    # Build the classifier hypernetwork (`training_with_hnet` is set in the
    # run method in train.py).
    if config.training_with_hnet:
        class_hnet = sim_utils.get_hnet_model(config,
                                              config.num_tasks,
                                              device,
                                              net.param_shapes,
                                              cprefix='class_')
        init_params = list(class_hnet.parameters())

        config.num_weights_class_hyper_net = sum(
            p.numel() for p in class_hnet.parameters() if p.requires_grad)
        config.compression_ratio_class = config.num_weights_class_hyper_net / \
                                         config.num_weights_class_net
        print('Created classifier Hypernetwork with ratio: ',
              config.compression_ratio_class)
        if config.compression_ratio_class > 1:
            print('Note that the compression ratio is computed compared to ' +
                  'current target network, not might not be directly ' +
                  'comparable with the number of parameters of work we ' +
                  'compare against.')
    else:
        class_hnet = None
        init_params = list(net.parameters())
        config.num_weights_class_hyper_net = None
        config.compression_ratio_class = None

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:
            torch.nn.init.xavier_uniform_(W)

    if class_hnet is None:
        return net

    # The task embeddings are initialized differently.
    for temb in class_hnet.get_task_embs():
        torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

    if hasattr(class_hnet, 'chunk_embeddings'):
        for emb in class_hnet.chunk_embeddings:
            torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    return net, class_hnet
    def __init__(self, in_shape=(32, 32, 3),
                 num_classes=10, verbose=True, arch='cifar', no_weights=False,
                 init_weights=None, dropout_rate=0.25):
        """Build the ZenkeNet classifier.

        Args:
            in_shape: Input shape; the first two entries must be 32 (checked
                by an assertion).
            num_classes: Number of outputs; written into the shapes of the
                last two parameter tensors (output weight matrix and bias).
            verbose: If True, print the number of weights on construction.
            arch: Key into ``ZenkeNet._architectures`` selecting the list of
                parameter shapes.
            no_weights: If True, no parameters are created; all parameter
                shapes are listed in ``_hyper_shapes_learned`` instead.
            init_weights (optional): List of tensors, one per parameter shape,
                used to initialize the weights. Must not be combined with
                ``no_weights``.
            dropout_rate: If -1, dropout is disabled. Otherwise it must be at
                most 0.5; ``nn.Dropout2d`` is created with this rate and
                ``nn.Dropout`` (``_drop_fc1``) with twice this rate.

        Raises:
            ValueError: If dropout is enabled with ``dropout_rate`` > 0.5.
        """
        super(ZenkeNet, self).__init__(num_classes, verbose)

        assert(in_shape[0] == 32 and in_shape[1] == 32)
        self._in_shape = in_shape

        assert(arch in ZenkeNet._architectures.keys())
        # Copy the shape template: the class-level `_architectures` entry is
        # shared by all instances, so writing `num_classes` into it in place
        # would silently change the shapes of every other ZenkeNet built with
        # the same `arch`.
        self._param_shapes = [list(s) for s in ZenkeNet._architectures[arch]]
        self._param_shapes[-2][0] = num_classes
        self._param_shapes[-1][0] = num_classes

        assert(init_weights is None or no_weights is False)
        self._no_weights = no_weights

        self._use_dropout = dropout_rate != -1

        self._has_bias = True
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        # We don't use any output non-linearity.
        self._has_linear_out = True

        self._num_weights = MainNetInterface.shapes_to_num_weights(
            self._param_shapes)
        if verbose:
            print('Creating a ZenkeNet with %d weights' \
                  % (self._num_weights)
                  + (', that uses dropout.' if self._use_dropout else '.'))

        if self._use_dropout:
            if dropout_rate > 0.5:
                # FIXME not a pretty solution, but we aim to follow the original
                # paper.
                raise ValueError('Dropout rate must be smaller equal 0.5.')
            self._drop_conv = nn.Dropout2d(p=dropout_rate)
            self._drop_fc1 = nn.Dropout(p=dropout_rate * 2.)

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        if no_weights:
            self._weights = None
            self._hyper_shapes_learned = self._param_shapes
            self._hyper_shapes_learned_ref = \
                list(range(len(self._param_shapes)))
            self._is_properly_setup()
            return

        ### Define and initialize network weights.
        # Entries at even indices (0, 2, ...) of this list hold weight tensors
        # and entries at odd indices hold bias vectors.
        self._weights = nn.ParameterList()

        for i, dims in enumerate(self._param_shapes):
            self._weights.append(nn.Parameter(torch.Tensor(*dims),
                                              requires_grad=True))

            if i % 2 == 0:
                self._layer_weight_tensors.append(self._weights[i])
            else:
                assert(len(dims) == 1)
                self._layer_bias_vectors.append(self._weights[i])

        if init_weights is not None:
            assert(len(init_weights) == len(self._param_shapes))
            for w_init, w in zip(init_weights, self._weights):
                assert(np.all(np.equal(list(w_init.shape), list(w.shape))))
                w.data = w_init
        else:
            for i in range(len(self._layer_weight_tensors)):
                init_params(self._layer_weight_tensors[i],
                            self._layer_bias_vectors[i])

        self._is_properly_setup()
示例#4
0
    def __init__(self,
                 target_shapes,
                 num_tasks,
                 chunk_dim=2586,
                 layers=None,
                 te_dim=8,
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True,
                 no_weights=False,
                 ce_dim=None,
                 init_weights=None,
                 dropout_rate=-1,
                 noise_dim=-1,
                 temb_std=-1,
                 verbose=True):
        """Build a hypernetwork that produces the target weights in chunks.

        An internal :class:`HyperNetwork` with a single output of size
        ``chunk_dim`` is created. Its extra input has size ``ce_dim`` (plus
        ``noise_dim`` if noise is enabled). One learnable chunk embedding of
        size ``ce_dim`` is created per chunk. The number of chunks is
        ``ceil(num_outputs / chunk_dim)``, where ``num_outputs`` is the total
        number of weights described by ``target_shapes``.

        Args:
            target_shapes: List of parameter shapes of the target network.
                Must be non-empty.
            num_tasks: Number of task embeddings, forwarded to the internal
                hypernetwork.
            chunk_dim: Number of weights generated per chunk.
            layers (optional): Hidden layer sizes of the internal
                hypernetwork. Defaults to ``[50, 100]``.
            te_dim: Task embedding dimensionality.
            activation_fn: Nonlinearity of the internal hypernetwork.
            use_bias: Whether the internal hypernetwork uses biases.
            no_weights: If True, the internal hypernetwork is built without
                weights and no chunk embeddings are created.
            ce_dim: Chunk embedding dimensionality. Required (asserted to not
                be ``None``).
            init_weights (optional): Initial weights forwarded to the internal
                hypernetwork. Chunk embeddings are not initialized from it (a
                warning is emitted). Must not be combined with ``no_weights``.
            dropout_rate: Forwarded to the internal hypernetwork.
            noise_dim: If not -1, the internal hypernetwork's extra input is
                enlarged by ``noise_dim``. The internal hypernetwork itself is
                built with ``noise_dim=-1``. NOTE(review): presumably the noise
                is concatenated to the chunk embeddings by this class's forward
                method (not visible here); confirm.
            temb_std: Forwarded to the internal hypernetwork.
            verbose: If True (default), print the number of parameters and
                outputs after construction.
        """
        # FIXME find a way using super to handle multiple inheritence.
        #super(ChunkedHyperNetworkHandler, self).__init__()
        nn.Module.__init__(self)
        CLHyperNetInterface.__init__(self)

        # `None` avoids a mutable list default that would be shared by all
        # instances (the list is stored here and in the internal hypernet).
        if layers is None:
            layers = [50, 100]

        assert (len(target_shapes) > 0)
        assert (init_weights is None or no_weights is False)
        assert (ce_dim is not None)
        self._target_shapes = target_shapes
        self._num_tasks = num_tasks
        self._ce_dim = ce_dim
        self._chunk_dim = chunk_dim
        self._layers = layers
        self._use_bias = use_bias
        self._act_fn = activation_fn
        self._init_weights = init_weights
        self._no_weights = no_weights
        self._te_dim = te_dim
        self._noise_dim = noise_dim
        self._temb_std = temb_std
        self._shifts = None  # FIXME temporary test.

        # FIXME: weights should incorporate chunk embeddings as they are part of
        # theta.
        if init_weights is not None:
            warn('Argument "init_weights" does not yet allow initialization ' +
                 'of chunk embeddings.')

        ### Generate Hypernet with chunk_dim output.
        # Note, we can safely pass "temb_std" to the full hypernetwork, as we
        # process all chunks in one big batch and the hypernet will use the same
        # perturbed task embeddings for that reason (i.e., noise is shared).
        self._hypernet = HyperNetwork([[chunk_dim]],
                                      num_tasks,
                                      verbose=False,
                                      layers=layers,
                                      te_dim=te_dim,
                                      activation_fn=activation_fn,
                                      use_bias=use_bias,
                                      no_weights=no_weights,
                                      init_weights=init_weights,
                                      ce_dim=ce_dim +
                                      (noise_dim if noise_dim != -1 else 0),
                                      dropout_rate=dropout_rate,
                                      noise_dim=-1,
                                      temb_std=temb_std)

        self._num_outputs = MainNetInterface.shapes_to_num_weights(
            self._target_shapes)
        ### Generate embeddings for all weight chunks.
        self._num_chunks = int(np.ceil(self._num_outputs / chunk_dim))
        if no_weights:
            self._embs = None
        else:
            self._embs = nn.Parameter(data=torch.Tensor(
                self._num_chunks, ce_dim),
                                      requires_grad=True)
            nn.init.normal_(self._embs, mean=0., std=1.)

        # Note, the chunk embeddings are part of theta.
        hdims = self._hypernet.theta_shapes
        ntheta = MainNetInterface.shapes_to_num_weights(hdims) + \
            (self._embs.numel() if not no_weights else 0)

        ntembs = int(np.sum([t.numel() for t in self.get_task_embs()]))
        self._num_weights = ntheta + ntembs
        if verbose:
            print('Constructed hypernetwork with %d parameters ' % (ntheta \
                  + ntembs) + '(%d network weights + %d task embedding weights).'
                  % (ntheta, ntembs))

            print('The hypernetwork has a total of %d outputs.' %
                  self._num_outputs)

        self._theta_shapes = [[self._num_chunks, ce_dim]] + \
            self._hypernet.theta_shapes

        self._is_properly_setup()
示例#5
0
文件: mlp.py 项目: ZixuanKe/CAT
    def __init__(self,
                 n_in=1,
                 n_out=1,
                 hidden_layers=[2000, 2000],
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True,
                 no_weights=False,
                 init_weights=None,
                 dropout_rate=-1,
                 use_spectral_norm=False,
                 use_batch_norm=False,
                 bn_track_stats=True,
                 distill_bn_stats=False,
                 use_context_mod=False,
                 context_mod_inputs=False,
                 no_last_layer_context_mod=False,
                 context_mod_no_weights=False,
                 context_mod_post_activation=False,
                 context_mod_gain_offset=False,
                 out_fn=None,
                 verbose=True):
        """Build a fully-connected network (multilayer perceptron).

        Parameters are registered in the following order, which is also the
        order expected for ``init_weights``: context-modulation weights (if
        ``use_context_mod``), batch-norm weights (if ``use_batch_norm``), and
        finally the weights and biases of the linear layers.

        Args:
            n_in: Number of inputs.
            n_out: Number of outputs.
            hidden_layers: Sizes of the hidden layers.
            activation_fn: Nonlinearity, stored as ``_a_fun``.
            use_bias: Whether linear layers have bias vectors.
            no_weights: If True, no linear or batch-norm weights are created;
                their shapes are listed in ``_hyper_shapes_learned`` instead.
            init_weights (optional): List of tensors used to initialize the
                created parameters, in the order described above.
            dropout_rate: If -1, no dropout. Otherwise a value in [0, 1].
            use_spectral_norm: Not implemented; must be False.
            use_batch_norm: If True, one batch-norm layer is created per hidden
                layer.
            bn_track_stats: Forwarded to the batch-norm layers as
                ``track_running_stats``.
            distill_bn_stats: Only effective together with ``use_batch_norm``;
                records the shapes of the batch-norm statistics in
                ``_hyper_shapes_distilled``.
            use_context_mod: If True, context-modulation layers are created for
                the hidden layers (and optionally inputs/outputs).
            context_mod_inputs: Also create a context-modulation layer of size
                ``n_in``.
            no_last_layer_context_mod: Do not create a context-modulation layer
                of size ``n_out``.
            context_mod_no_weights: If True, context-modulation weights are not
                created; their shapes are listed in ``_hyper_shapes_learned``.
            context_mod_post_activation: Stored flag. NOTE(review): presumably
                selects whether modulation is applied after the nonlinearity in
                ``forward`` (not visible here); confirm.
            context_mod_gain_offset: Forwarded to ``ContextModLayer`` as
                ``apply_gain_offset``.
            out_fn (optional): Output function; if given, the network is not
                marked as having a linear output.
            verbose: If True, print a summary of the constructed network.

        Raises:
            NotImplementedError: If ``use_spectral_norm`` is True.
        """
        # FIXME find a way using super to handle multiple inheritence.
        #super(MainNetwork, self).__init__()
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        if use_spectral_norm:
            raise NotImplementedError('Spectral normalization not yet ' +
                                      'implemented for this network.')

        if use_batch_norm and use_context_mod:
            # FIXME Does it make sense to have both enabled?
            # I.e., should we produce a warning or error?
            pass

        self._a_fun = activation_fn
        assert(init_weights is None or \
               (not no_weights or not context_mod_no_weights))
        self._no_weights = no_weights
        self._dropout_rate = dropout_rate
        #self._use_spectral_norm = use_spectral_norm
        self._use_batch_norm = use_batch_norm
        self._bn_track_stats = bn_track_stats
        self._distill_bn_stats = distill_bn_stats and use_batch_norm
        self._use_context_mod = use_context_mod
        self._context_mod_inputs = context_mod_inputs
        self._no_last_layer_context_mod = no_last_layer_context_mod
        self._context_mod_no_weights = context_mod_no_weights
        self._context_mod_post_activation = context_mod_post_activation
        self._context_mod_gain_offset = context_mod_gain_offset
        self._out_fn = out_fn

        self._has_bias = use_bias
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        self._has_linear_out = True if out_fn is None else False

        if use_spectral_norm and no_weights:
            raise ValueError('Cannot use spectral norm in a network without ' +
                             'parameters.')

        # FIXME make sure that this implementation is correct in all situations
        # (e.g., what to do if weights are passed to the forward method?).
        if use_spectral_norm:
            self._spec_norm = nn.utils.spectral_norm
        else:
            self._spec_norm = lambda x: x  # identity

        self._param_shapes = []
        self._weights = None if no_weights and context_mod_no_weights \
            else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not context_mod_no_weights else []

        if dropout_rate != -1:
            assert (dropout_rate >= 0. and dropout_rate <= 1.)
            self._dropout = nn.Dropout(p=dropout_rate)

        ### Define and initialize context mod weights.
        self._context_mod_layers = nn.ModuleList() if use_context_mod else None
        self._context_mod_shapes = [] if use_context_mod else None

        if use_context_mod:
            cm_ind = 0
            cm_sizes = []
            if context_mod_inputs:
                cm_sizes.append(n_in)
            cm_sizes.extend(hidden_layers)
            if not no_last_layer_context_mod:
                cm_sizes.append(n_out)

            for i, n in enumerate(cm_sizes):
                cmod_layer = ContextModLayer(
                    n,
                    no_weights=context_mod_no_weights,
                    apply_gain_offset=context_mod_gain_offset)
                self._context_mod_layers.append(cmod_layer)

                self._param_shapes.extend(cmod_layer.param_shapes)
                self._context_mod_shapes.extend(cmod_layer.param_shapes)
                if context_mod_no_weights:
                    self._hyper_shapes_learned.extend(cmod_layer.param_shapes)
                else:
                    self._weights.extend(cmod_layer.weights)

                # FIXME ugly code. Move initialization somewhere else.
                if not context_mod_no_weights and init_weights is not None:
                    assert (len(cmod_layer.weights) == 2)
                    for ii in range(2):
                        # Compare against this layer's own parameter shape
                        # (`cm_ind` is only the running index into
                        # `init_weights`).
                        assert(np.all(np.equal( \
                                list(init_weights[cm_ind].shape),
                                list(cmod_layer.weights[ii].shape))))
                        cmod_layer.weights[ii].data = init_weights[cm_ind]
                        cm_ind += 1

            if init_weights is not None:
                init_weights = init_weights[cm_ind:]
        if verbose:
            print('hidden_layers', hidden_layers)

        ### Define and initialize batch norm weights.
        self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

        if use_batch_norm:
            if distill_bn_stats:
                self._hyper_shapes_distilled = []

            bn_ind = 0
            for i, n in enumerate(hidden_layers):
                bn_layer = BatchNormLayer(n,
                                          affine=not no_weights,
                                          track_running_stats=bn_track_stats)
                self._batchnorm_layers.append(bn_layer)
                self._param_shapes.extend(bn_layer.param_shapes)

                if no_weights:
                    self._hyper_shapes_learned.extend(bn_layer.param_shapes)
                else:
                    self._weights.extend(bn_layer.weights)

                if distill_bn_stats:
                    self._hyper_shapes_distilled.extend( \
                        [list(p.shape) for p in bn_layer.get_stats(0)])

                # FIXME ugly code. Move initialization somewhere else.
                if not no_weights and init_weights is not None:
                    assert (len(bn_layer.weights) == 2)
                    for ii in range(2):
                        assert(np.all(np.equal( \
                                list(init_weights[bn_ind].shape),
                                list(bn_layer.weights[ii].shape))))
                        bn_layer.weights[ii].data = init_weights[bn_ind]
                        bn_ind += 1

            if init_weights is not None:
                init_weights = init_weights[bn_ind:]

        # Compute shapes of linear layers.
        linear_shapes = MLP.weight_shapes(n_in=n_in,
                                          n_out=n_out,
                                          hidden_layers=hidden_layers,
                                          use_bias=use_bias)
        self._param_shapes.extend(linear_shapes)

        num_weights = MainNetInterface.shapes_to_num_weights(
            self._param_shapes)

        if verbose:
            if use_context_mod:
                cm_num_weights = 0
                for cm_layer in self._context_mod_layers:
                    cm_num_weights += MainNetInterface.shapes_to_num_weights(
                        cm_layer.param_shapes)

            print(
                'Creating an MLP with %d weights' % num_weights +
                (' (including %d weights associated with-' % cm_num_weights +
                 'context modulation)' if use_context_mod else '') + '.' +
                (' The network uses dropout.' if dropout_rate != -1 else '') +
                (' The network uses batchnorm.' if use_batch_norm else ''))

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        if no_weights:
            self._hyper_shapes_learned.extend(linear_shapes)
            self._is_properly_setup()
            return

        ### Define and initialize linear weights.
        for i, dims in enumerate(linear_shapes):
            self._weights.append(
                nn.Parameter(torch.Tensor(*dims), requires_grad=True))
            if len(dims) == 1:
                self._layer_bias_vectors.append(self._weights[-1])
            else:
                self._layer_weight_tensors.append(self._weights[-1])

        if init_weights is not None:
            assert (len(init_weights) == len(linear_shapes))
            for i in range(len(init_weights)):
                assert (np.all(
                    np.equal(list(init_weights[i].shape), linear_shapes[i])))
                if use_bias:
                    # With biases, `linear_shapes` alternates weight, bias.
                    if i % 2 == 0:
                        self._layer_weight_tensors[i //
                                                   2].data = init_weights[i]
                    else:
                        self._layer_bias_vectors[i // 2].data = init_weights[i]
                else:
                    self._layer_weight_tensors[i].data = init_weights[i]
        else:
            for i in range(len(self._layer_weight_tensors)):
                if use_bias:
                    init_params(self._layer_weight_tensors[i],
                                self._layer_bias_vectors[i])
                else:
                    init_params(self._layer_weight_tensors[i])

        self._is_properly_setup()
示例#6
0
    def __init__(self,
                 target_shapes,
                 num_tasks,
                 layers=None,
                 verbose=True,
                 te_dim=8,
                 no_te_embs=False,
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True,
                 no_weights=False,
                 init_weights=None,
                 ce_dim=None,
                 dropout_rate=-1,
                 use_spectral_norm=False,
                 create_feedback_matrix=False,
                 target_net_out_dim=10,
                 random_scale_feedback_matrix=1.,
                 use_batch_norm=False,
                 noise_dim=-1,
                 temb_std=-1):
        """Build the network. The network will consist of several hidden layers
        and a dedicated output layer for each weight matrix/bias vector.

        The input to the network will be a learned task embedding.

        Args:
            target_shapes: A list of list of integers, denoting the shape of
                each parameter tensor in the main network (hence, determining
                the output of this network).
            num_tasks: The number of task embeddings needed.
            layers (optional): A list of integers, each indicating the size of
                a hidden layer in this network. Defaults to ``[50, 100]``.
            verbose: If True, print the number of parameters and outputs after
                construction.
            te_dim: The dimensionality of the task embeddings.
            no_te_embs: If this option is True, no class internal task
                embeddings are constructed and are instead expected to be
                provided to the forward method.
            activation_fn: The nonlinearity used in hidden layers. If None, no
                nonlinearity will be applied.
            use_bias: Whether layers may have bias terms.
            no_weights: If set to True, no trainable parameters will be
                constructed, i.e., weights are assumed to be produced ad-hoc
                by a hypernetwork and passed to the forward function.
                Does not affect task embeddings.
            init_weights (optional): This option is for convenience reasons.
                The option expects a list of parameter values that are used to
                initialize the network weights. As such, it provides a
                convenient way of initializing a network with, for instance, a
                weight draw produced by the hypernetwork.
                Does not affect task embeddings.
            ce_dim (optional): The dimensionality of any additional embeddings,
                (in addition to the task embedding) that will be used as input
                to the hypernetwork. If this option is ``None``, no additional
                input is expected. Otherwise, an additional embedding has to be
                passed to the :meth:`forward` method (see argument
                ``ext_inputs``).
                A typical usecase would be a chunk embedding.
            dropout_rate (optional): If -1, no dropout will be applied.
                Otherwise a number between 0 and 1 is expected, denoting the
                dropout of hidden layers.
            use_spectral_norm: Enable spectral normalization for all layers.
                Currently not implemented (raises NotImplementedError).
            create_feedback_matrix: A feedback matrix for credit assignment in
                the main network will be created. See attribute
                :attr:`feedback_matrix`.
            target_net_out_dim: Target network output dimension. We need this
                information to create feedback matrices that can be used in
                conjunction with Direct Feedback Alignment. Only needs to be
                specified when enabling ``create_feedback_matrix``.
            random_scale_feedback_matrix: Scale of uniform distribution used
                to create the feedback matrix. Only needs to be specified when
                enabling ``create_feedback_matrix``.
            use_batch_norm: If True, batchnorm will be applied to all hidden
                layers. Currently not implemented (raises ValueError if
                ``ce_dim`` is None, otherwise NotImplementedError).
            noise_dim (optional): If -1, no noise will be applied.
                Otherwise the hypernetwork will receive as additional input
                zero-mean Gaussian noise with unit variance during training
                (zeroes will be inputted during eval-mode). Note, if a batch of
                inputs is given, then a different noise vector is generated for
                every sample in the batch.
            temb_std (optional): If not -1, the task embeddings will be
                perturbed by zero-mean Gaussian noise with the given std
                (additive noise). The perturbation is only applied if the
                network is in training mode. Note, per batch of external inputs,
                the perturbation of the task embedding will be shared.
        """
        # FIXME find a way using super to handle multiple inheritence.
        #super(HyperNetwork, self).__init__()
        nn.Module.__init__(self)
        CLHyperNetInterface.__init__(self)

        # `None` avoids a mutable list default shared by all instances (the
        # list is stored as `self._layers`).
        if layers is None:
            layers = [50, 100]

        if use_spectral_norm:
            raise NotImplementedError(
                'Spectral normalization not yet ' +
                'implemented for this hypernetwork type.')
        if use_batch_norm:
            # Note, batch normalization only makes sense when batch processing
            # is applied during training (i.e., batch size > 1).
            # As long as we only support processing of 1 task embedding, that
            # means that external inputs are required.
            if ce_dim is None:
                raise ValueError('Can\'t use batchnorm as long as ' +
                                 'hypernetwork process more than 1 sample ' +
                                 '("ce_dim" must be specified).')
            raise NotImplementedError(
                'Batch normalization not yet ' +
                'implemented for this hypernetwork type.')

        assert (len(target_shapes) > 0)
        assert (no_te_embs or num_tasks > 0)
        self._num_tasks = num_tasks

        assert (init_weights is None or no_weights is False)
        self._no_weights = no_weights
        self._no_te_embs = no_te_embs
        self._te_dim = te_dim
        self._size_ext_input = ce_dim
        self._layers = layers
        self._target_shapes = target_shapes
        self._use_bias = use_bias
        self._act_fn = activation_fn
        self._init_weights = init_weights
        self._dropout_rate = dropout_rate
        self._noise_dim = noise_dim
        self._temb_std = temb_std
        self._shifts = None  # FIXME temporary test.

        ### Hidden layers
        # NOTE(review): `_gen_layers` (not visible here) is expected to set
        # `self._hidden_dims` and `self._out_dims`, which are read below.
        self._gen_layers(layers, te_dim, use_bias, no_weights, init_weights,
                         ce_dim, noise_dim)

        if create_feedback_matrix:
            self._create_feedback_matrix(target_shapes, target_net_out_dim,
                                         random_scale_feedback_matrix)

        self._dropout = None
        if dropout_rate != -1:
            assert (dropout_rate >= 0 and dropout_rate <= 1)
            self._dropout = nn.Dropout(dropout_rate)

        # Task embeddings.
        if no_te_embs:
            self._task_embs = None
        else:
            self._task_embs = nn.ParameterList()
            for _ in range(num_tasks):
                self._task_embs.append(
                    nn.Parameter(data=torch.Tensor(te_dim),
                                 requires_grad=True))
                torch.nn.init.normal_(self._task_embs[-1], mean=0., std=1.)

        self._theta_shapes = self._hidden_dims + self._out_dims

        ntheta = MainNetInterface.shapes_to_num_weights(self._theta_shapes)
        ntembs = int(np.sum([t.numel() for t in self._task_embs])) \
                if not no_te_embs else 0
        self._num_weights = ntheta + ntembs

        self._num_outputs = MainNetInterface.shapes_to_num_weights(
            self.target_shapes)

        if verbose:
            print('Constructed hypernetwork with %d parameters (' % (ntheta \
                  + ntembs) + '%d network weights + %d task embedding weights).'
                  % (ntheta, ntembs))
            print('The hypernetwork has a total of %d outputs.' %
                  self._num_outputs)

        self._is_properly_setup()
示例#7
0
    def __init__(self,
                 main_dims,
                 num_tasks,
                 out_size=None,
                 num_layers=5,
                 num_filters=None,
                 kernel_size=5,
                 sa_units=None,
                 rem_layers=None,
                 te_dim=8,
                 ce_dim=8,
                 no_theta=False,
                 init_theta=None,
                 use_batch_norm=False,
                 use_spectral_norm=False,
                 dropout_rate=-1,
                 discard_remainder=False,
                 noise_dim=-1,
                 temb_std=-1):
        """Construct a self-attention hypernetwork.

        The flattened weights of the main network are produced in chunks of
        ``prod(out_size)`` values by a self-attention hypernet part
        (:class:`SAHnetPart`). Its input is a task embedding concatenated with
        a learned chunk embedding (plus ``noise_dim`` extra inputs if
        ``noise_dim != -1``). Weights that do not fill a whole chunk are
        either generated by an extra fully-connected :class:`HyperNetwork`,
        or, if ``discard_remainder`` is set, covered by one additional chunk
        whose surplus outputs are meant to be discarded.

        Args:
            main_dims: List of weight shapes of the main network, i.e., the
                target shapes of this hypernetwork.
            num_tasks: Number of task embeddings to create.
            out_size: Shape of one generated chunk; must have 2 or 3 entries
                and ``prod(out_size)`` must not exceed the total number of
                main-network weights. Defaults to ``[64, 64]``.
            num_layers, num_filters, kernel_size: Forwarded to
                :class:`SAHnetPart`.
            sa_units: Forwarded to :class:`SAHnetPart` (presumably the
                indices of layers using self-attention; confirm there).
                Defaults to ``[1, 3]``.
            rem_layers: Hidden layer sizes of the remainder hypernetwork.
                Defaults to ``[50, 50, 50]``.
            te_dim: Dimensionality of the task embeddings.
            ce_dim: Dimensionality of the chunk embeddings.
            no_theta: If ``True``, no chunk embeddings are created internally
                and the sub-hypernetworks are built without own weights.
            init_theta: Must be ``None``; custom initialization is not
                implemented.
            use_batch_norm, use_spectral_norm: Forwarded to both
                sub-hypernetworks.
            dropout_rate: Forwarded to the remainder hypernetwork only.
            discard_remainder: See above.
            noise_dim: Size of an additional (noise) input; ``-1`` disables
                it.
            temb_std: Stored as ``self._temb_std``; not used in this
                constructor.

        Raises:
            NotImplementedError: If ``init_theta`` is not ``None``.
        """
        # FIXME find a way using super to handle multiple inheritence.
        #super(SAHyperNetwork, self).__init__()
        nn.Module.__init__(self)
        CLHyperNetInterface.__init__(self)

        # Resolve list defaults here instead of in the signature: a list
        # default is created once and would be shared by every instance (and
        # `out_size` is stored on the instance below).
        if out_size is None:
            out_size = [64, 64]
        if sa_units is None:
            sa_units = [1, 3]
        if rem_layers is None:
            rem_layers = [50, 50, 50]

        if init_theta is not None:
            # FIXME I would need to know the number of parameter tensors in each
            # hypernet before creating them to split the list init_theta.
            raise NotImplementedError(
                'Argument "init_theta" not implemented ' + 'yet!')

        self._no_theta = no_theta
        self._te_dim = te_dim
        self._discard_remainder = discard_remainder

        self._target_shapes = main_dims
        self._num_outputs = MainNetInterface.shapes_to_num_weights(main_dims)
        print('Building a self-attention hypernet for a network with %d '% \
              self._num_outputs + 'weights.')
        assert (len(out_size) in [2, 3])
        self._out_size = out_size
        num_outs = np.prod(out_size)
        assert (num_outs <= self._num_outputs)
        self._noise_dim = noise_dim
        self._temb_std = temb_std

        # Number of full chunks and number of weights left over.
        num_embs = self._num_outputs // num_outs
        rem_weights = self._num_outputs % num_outs

        if rem_weights > 0 and not discard_remainder:
            print('%d remaining weights (%.2f%%) are generated by a fully-' \
                  % (rem_weights, 100.0 * rem_weights / self._num_outputs) + \
                  'connected hypernetwork.')
        elif rem_weights > 0:
            # One extra chunk covers the remainder; its surplus is dropped.
            num_embs += 1

            print('%d weights generated by the last chunk of the self-' %
                  (num_outs - rem_weights) + 'attention hypernet will be ' +
                  'discarded.')

        self._num_embs = num_embs

        ### Generate Hypernet.
        self._hypernet = SAHnetPart(out_size=out_size,
                                    num_layers=num_layers,
                                    num_filters=num_filters,
                                    kernel_size=kernel_size,
                                    sa_units=sa_units,
                                    input_dim=te_dim + ce_dim +
                                    (noise_dim if noise_dim != -1 else 0),
                                    use_batch_norm=use_batch_norm,
                                    use_spectral_norm=use_spectral_norm,
                                    no_theta=no_theta,
                                    init_theta=None)

        self._rem_hypernet = None
        self._remainder = rem_weights
        if rem_weights > 0 and not discard_remainder:
            print('A second hypernet for the remainder of the weights has ' +
                  'to be created, as %d is not dividable by %d ' %
                  (self._num_outputs, num_outs) +
                  '(remainder %d)' % rem_weights)
            # The noise input (if any) is fed to the remainder net through
            # its `ce_dim` slot.
            self._rem_hypernet = HyperNetwork(
                [[rem_weights]],
                None,
                layers=rem_layers,
                te_dim=te_dim,
                no_te_embs=True,
                no_weights=no_theta,
                ce_dim=(noise_dim if noise_dim != -1 else None),
                dropout_rate=dropout_rate,
                use_batch_norm=use_batch_norm,
                use_spectral_norm=use_spectral_norm,
                noise_dim=-1,
                temb_std=None)

        ### Generate embeddings for all weight chunks.
        if no_theta:
            self._embs = None
        else:
            self._embs = nn.Parameter(data=torch.Tensor(num_embs, ce_dim),
                                      requires_grad=True)
            torch.nn.init.normal_(self._embs, mean=0., std=1.)

        # There is no need for a chunk embedding, as this network always
        # produces the same chunk.
        #if self._remainder > 0  and not discard_remainder:
        #    self._rem_emb = nn.Parameter(data=torch.Tensor(1, ce_dim),
        #                                 requires_grad=True)
        #    torch.nn.init.normal_(self._rem_emb, mean=0., std=1.)

        ### Generate task embeddings.
        self._task_embs = nn.ParameterList()
        # We store individual task embeddings as it makes it easier to pass
        # only subsets of task embeddings to an optimizer.
        for _ in range(num_tasks):
            self._task_embs.append(
                nn.Parameter(data=torch.Tensor(te_dim), requires_grad=True))
            torch.nn.init.normal_(self._task_embs[-1], mean=0., std=1.)

        # Count every registered parameter (chunk embeddings, sub-hypernets
        # and task embeddings).
        self._num_weights = 0
        for p in list(self.parameters()):
            self._num_weights += np.prod(p.shape)
        print('Total number of parameters in the hypernetwork: %d' %
              self._num_weights)

        # Shapes of theta: chunk embeddings first, then the SA hypernet,
        # then (optionally) the remainder hypernet.
        self._theta_shapes = [[num_embs, ce_dim]] + \
            self._hypernet.theta_shapes
        if self._rem_hypernet is not None:
            self._theta_shapes += self._rem_hypernet.theta_shapes
示例#8
0
    def __init__(self,
                 n_in=1,
                 n_out=1,
                 hidden_layers=(10, 10),
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True,
                 no_weights=False,
                 init_weights=None,
                 dropout_rate=-1,
                 use_spectral_norm=False,
                 use_batch_norm=False,
                 bn_track_stats=True,
                 distill_bn_stats=False,
                 use_context_mod=False,
                 context_mod_inputs=False,
                 no_last_layer_context_mod=False,
                 context_mod_no_weights=False,
                 context_mod_post_activation=False,
                 context_mod_gain_offset=False,
                 context_mod_gain_softplus=False,
                 out_fn=None,
                 verbose=True):
        """Construct a fully-connected network (MLP).

        Parameter shapes are appended to ``param_shapes`` (and internal
        parameters to ``weights``) in this order:

        1. context-modulation gains/shifts (if ``use_context_mod``),
        2. batch-norm scales/shifts (if ``use_batch_norm``),
        3. linear-layer weights and biases, so that the last entries belong
           to the output layer.

        Args:
            n_in (int): Input dimensionality.
            n_out (int): Output dimensionality.
            hidden_layers (tuple): Sizes of the hidden layers.
            activation_fn: Activation function, stored as ``self._a_fun``.
            use_bias (bool): Whether linear layers have bias vectors.
            no_weights (bool): If ``True``, batch-norm and linear parameters
                are not created internally. Their shapes are added to
                ``hyper_shapes_learned`` instead.
            init_weights (list, optional): Tensors used to initialize the
                internally maintained parameters, in the order listed above.
            dropout_rate (float): Dropout probability; ``-1`` disables
                dropout.
            use_spectral_norm (bool): Not implemented; ``True`` raises.
            use_batch_norm (bool): Add one batch-norm layer per hidden layer.
            bn_track_stats (bool): Passed as ``track_running_stats`` to the
                batch-norm layers.
            distill_bn_stats (bool): Record the shapes of the batch-norm
                statistics in ``_hyper_shapes_distilled`` (only effective
                together with ``use_batch_norm``).
            use_context_mod (bool): Add context-modulation layers.
            context_mod_inputs (bool): Also modulate the network input.
            no_last_layer_context_mod (bool): Do not modulate the output
                layer.
            context_mod_no_weights (bool): Context-mod parameters are not
                created internally. Their shapes are added to
                ``hyper_shapes_learned`` instead.
            context_mod_post_activation (bool): Determines whether
                context-mod layers are ordered after (``True``) or before
                (``False``) the batch-norm layers in the ``'layer'``
                meta-information.
            context_mod_gain_offset (bool): Forwarded to
                :class:`ContextModLayer` as ``apply_gain_offset``.
            context_mod_gain_softplus (bool): Forwarded to
                :class:`ContextModLayer` as ``apply_gain_softplus``.
            out_fn (callable, optional): Output function. If given, the
                network is not marked as having a linear output.
            verbose (bool): Print a summary of the created network.

        Raises:
            NotImplementedError: If ``use_spectral_norm`` is ``True``.
        """
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        # FIXME Spectral norm is incorrectly implemented. Function
        # `nn.utils.spectral_norm` needs to be called in the constructor, such
        # that spec norm is wrapped around a module.
        if use_spectral_norm:
            raise NotImplementedError('Spectral normalization not yet ' +
                                      'implemented for this network.')

        if use_batch_norm and use_context_mod:
            # FIXME Does it make sense to have both enabled?
            # I.e., should we produce a warning or error?
            pass

        self._a_fun = activation_fn
        assert(init_weights is None or \
               (not no_weights or not context_mod_no_weights))
        self._no_weights = no_weights
        self._dropout_rate = dropout_rate
        #self._use_spectral_norm = use_spectral_norm
        self._use_batch_norm = use_batch_norm
        self._bn_track_stats = bn_track_stats
        self._distill_bn_stats = distill_bn_stats and use_batch_norm
        self._use_context_mod = use_context_mod
        self._context_mod_inputs = context_mod_inputs
        self._no_last_layer_context_mod = no_last_layer_context_mod
        self._context_mod_no_weights = context_mod_no_weights
        self._context_mod_post_activation = context_mod_post_activation
        self._context_mod_gain_offset = context_mod_gain_offset
        self._context_mod_gain_softplus = context_mod_gain_softplus
        self._out_fn = out_fn

        self._has_bias = use_bias
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        self._has_linear_out = True if out_fn is None else False

        # Currently unreachable, since spectral norm raises above.
        if use_spectral_norm and no_weights:
            raise ValueError('Cannot use spectral norm in a network without ' +
                             'parameters.')

        # FIXME make sure that this implementation is correct in all situations
        # (e.g., what to do if weights are passed to the forward method?).
        if use_spectral_norm:
            self._spec_norm = nn.utils.spectral_norm
        else:
            self._spec_norm = lambda x: x  # identity

        self._param_shapes = []
        self._param_shapes_meta = []
        # No internal parameters at all only if neither the main weights nor
        # the context-mod weights are maintained internally.
        self._weights = None if no_weights and context_mod_no_weights \
            else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not context_mod_no_weights else []
        # Indices into `param_shapes` of the entries in `hyper_shapes_learned`.
        self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
            is None else []

        if dropout_rate != -1:
            assert (dropout_rate >= 0. and dropout_rate <= 1.)
            self._dropout = nn.Dropout(p=dropout_rate)

        ### Define and initialize context mod weights.
        self._context_mod_layers = nn.ModuleList() if use_context_mod else None
        self._context_mod_shapes = [] if use_context_mod else None

        if use_context_mod:
            cm_ind = 0
            cm_sizes = []
            if context_mod_inputs:
                cm_sizes.append(n_in)
            cm_sizes.extend(hidden_layers)
            if not no_last_layer_context_mod:
                cm_sizes.append(n_out)

            for i, n in enumerate(cm_sizes):
                cmod_layer = ContextModLayer(
                    n,
                    no_weights=context_mod_no_weights,
                    apply_gain_offset=context_mod_gain_offset,
                    apply_gain_softplus=context_mod_gain_softplus)
                self._context_mod_layers.append(cmod_layer)

                self._param_shapes.extend(cmod_layer.param_shapes)
                assert len(cmod_layer.param_shapes) == 2
                self._param_shapes_meta.extend([
                    {'name': 'cm_scale',
                     'index': -1 if context_mod_no_weights else \
                         len(self._weights),
                     'layer': -1}, # 'layer' is set later.
                    {'name': 'cm_shift',
                     'index': -1 if context_mod_no_weights else \
                         len(self._weights)+1,
                     'layer': -1}, # 'layer' is set later.
                ])

                self._context_mod_shapes.extend(cmod_layer.param_shapes)
                if context_mod_no_weights:
                    self._hyper_shapes_learned.extend(cmod_layer.param_shapes)
                else:
                    self._weights.extend(cmod_layer.weights)

                # FIXME ugly code. Move initialization somewhere else.
                if not context_mod_no_weights and init_weights is not None:
                    assert (len(cmod_layer.weights) == 2)
                    for ii in range(2):
                        # Compare against the layer's own parameter (`cm_ind`
                        # is only a running index into `init_weights`).
                        assert(np.all(np.equal( \
                                list(init_weights[cm_ind].shape),
                                list(cmod_layer.weights[ii].shape))))
                        cmod_layer.weights[ii].data = init_weights[cm_ind]
                        cm_ind += 1

            if init_weights is not None:
                # Drop the consumed context-mod initializations.
                init_weights = init_weights[cm_ind:]

        if context_mod_no_weights:
            # So far, `param_shapes` only contains context-mod shapes.
            self._hyper_shapes_learned_ref = \
                list(range(len(self._param_shapes)))

        ### Define and initialize batch norm weights.
        self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

        if use_batch_norm:
            if distill_bn_stats:
                self._hyper_shapes_distilled = []

            bn_ind = 0
            for i, n in enumerate(hidden_layers):
                bn_layer = BatchNormLayer(n,
                                          affine=not no_weights,
                                          track_running_stats=bn_track_stats)
                self._batchnorm_layers.append(bn_layer)

                self._param_shapes.extend(bn_layer.param_shapes)
                assert len(bn_layer.param_shapes) == 2
                self._param_shapes_meta.extend([
                    {
                        'name': 'bn_scale',
                        'index': -1 if no_weights else len(self._weights),
                        'layer': -1
                    },  # 'layer' is set later.
                    {
                        'name': 'bn_shift',
                        'index': -1 if no_weights else len(self._weights) + 1,
                        'layer': -1
                    },  # 'layer' is set later.
                ])

                if no_weights:
                    self._hyper_shapes_learned.extend(bn_layer.param_shapes)
                else:
                    self._weights.extend(bn_layer.weights)

                if distill_bn_stats:
                    self._hyper_shapes_distilled.extend( \
                        [list(p.shape) for p in bn_layer.get_stats(0)])

                # FIXME ugly code. Move initialization somewhere else.
                if not no_weights and init_weights is not None:
                    assert (len(bn_layer.weights) == 2)
                    for ii in range(2):
                        assert(np.all(np.equal( \
                                list(init_weights[bn_ind].shape),
                                list(bn_layer.weights[ii].shape))))
                        bn_layer.weights[ii].data = init_weights[bn_ind]
                        bn_ind += 1

            if init_weights is not None:
                # Drop the consumed batch-norm initializations.
                init_weights = init_weights[bn_ind:]

        ### Compute shapes of linear layers.
        linear_shapes = MLP.weight_shapes(n_in=n_in,
                                          n_out=n_out,
                                          hidden_layers=hidden_layers,
                                          use_bias=use_bias)
        self._param_shapes.extend(linear_shapes)

        for i, s in enumerate(linear_shapes):
            self._param_shapes_meta.append({
                'name':
                'weight' if len(s) != 1 else 'bias',
                'index':
                -1 if no_weights else len(self._weights) + i,
                'layer':
                -1  # 'layer' is set later.
            })

        num_weights = MainNetInterface.shapes_to_num_weights(
            self._param_shapes)

        ### Set missing meta information of param_shapes.
        # Layer 0 is reserved for input context-mod, if used. Each hidden
        # "stage" spans `shift` consecutive layer indices (linear, plus
        # optional batch-norm and context-mod).
        offset = 1 if use_context_mod and context_mod_inputs else 0
        shift = 1
        if use_batch_norm:
            shift += 1
        if use_context_mod:
            shift += 1

        cm_offset = 2 if context_mod_post_activation else 1
        bn_offset = 1 if context_mod_post_activation else 2

        cm_ind = 0
        bn_ind = 0
        layer_ind = 0

        for i, dd in enumerate(self._param_shapes_meta):
            if dd['name'].startswith('cm'):
                if offset == 1 and i in [0, 1]:
                    dd['layer'] = 0
                else:
                    if cm_ind < len(hidden_layers):
                        dd['layer'] = offset + cm_ind * shift + cm_offset
                    else:
                        assert cm_ind == len(hidden_layers) and \
                            not no_last_layer_context_mod
                        # No batchnorm in output layer.
                        dd['layer'] = offset + cm_ind * shift + 1

                    if dd['name'] == 'cm_shift':
                        cm_ind += 1

            elif dd['name'].startswith('bn'):
                dd['layer'] = offset + bn_ind * shift + bn_offset
                if dd['name'] == 'bn_shift':
                    bn_ind += 1

            else:
                dd['layer'] = offset + layer_ind * shift
                # A linear layer is complete after its bias (or after its
                # weight if there are no biases).
                if not use_bias or dd['name'] == 'bias':
                    layer_ind += 1

        ### User information
        if verbose:
            if use_context_mod:
                cm_num_weights = 0
                for cm_layer in self._context_mod_layers:
                    cm_num_weights += MainNetInterface.shapes_to_num_weights( \
                        cm_layer.param_shapes)

            print(
                'Creating an MLP with %d weights' % num_weights +
                (' (including %d weights associated with ' % cm_num_weights +
                 'context modulation)' if use_context_mod else '') + '.' +
                (' The network uses dropout.' if dropout_rate != -1 else '') +
                (' The network uses batchnorm.' if use_batch_norm else ''))

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        if no_weights:
            self._hyper_shapes_learned.extend(linear_shapes)

            if use_context_mod:
                if context_mod_no_weights:
                    self._hyper_shapes_learned_ref = \
                        list(range(len(self._param_shapes)))
                else:
                    # Context-mod weights are internal; everything after
                    # them is hyper-learned.
                    ncm = len(self._context_mod_shapes)
                    self._hyper_shapes_learned_ref = \
                        list(range(ncm, len(self._param_shapes)))
            else:
                # Without context-mod, every shape is hyper-learned (and in
                # the same order as `param_shapes`).
                self._hyper_shapes_learned_ref = \
                    list(range(len(self._param_shapes)))

            self._is_properly_setup()
            return

        ### Define and initialize linear weights.
        for i, dims in enumerate(linear_shapes):
            self._weights.append(
                nn.Parameter(torch.Tensor(*dims), requires_grad=True))
            if len(dims) == 1:
                self._layer_bias_vectors.append(self._weights[-1])
            else:
                self._layer_weight_tensors.append(self._weights[-1])

        if init_weights is not None:
            assert (len(init_weights) == len(linear_shapes))
            for i in range(len(init_weights)):
                assert (np.all(
                    np.equal(list(init_weights[i].shape), linear_shapes[i])))
                if use_bias:
                    # With biases, shapes alternate: weight, bias, weight, ...
                    if i % 2 == 0:
                        self._layer_weight_tensors[i //
                                                   2].data = init_weights[i]
                    else:
                        self._layer_bias_vectors[i // 2].data = init_weights[i]
                else:
                    self._layer_weight_tensors[i].data = init_weights[i]
        else:
            for i in range(len(self._layer_weight_tensors)):
                if use_bias:
                    init_params(self._layer_weight_tensors[i],
                                self._layer_bias_vectors[i])
                else:
                    init_params(self._layer_weight_tensors[i])

        if self._num_context_mod_shapes() == 0:
            # Note, that might be the case if no hidden layers exist and no
            # input or output modulation is used.
            self._use_context_mod = False

        self._is_properly_setup()
    def __init__(self, in_shape=(32, 32, 3), num_classes=10, no_weights=False,
                 init_weights=None, use_context_mod=False,
                 context_mod_inputs=False, no_last_layer_context_mod=False,
                 context_mod_no_weights=False,
                 context_mod_post_activation=False,
                 context_mod_gain_offset=False,
                 context_mod_gain_softplus=False,
                 context_mod_apply_pixel_wise=False):
        """Construct the bio-plausible (locally connected) convnet.

        The architecture consists of three :class:`LocalConv2dLayer` layers
        with 64, 128 and 256 output channels, kernel sizes 5, 5 and 3,
        strides 2, 2 and 1, and paddings 0, 0 and 1. They are followed by
        two fully-connected layers with 1024 and ``num_classes`` units.

        If context modulation is used, its parameters are placed at the
        beginning of ``param_shapes`` and ``weights``.

        Args:
            in_shape (tuple): Input shape in ``(H, W, C)`` layout with
                ``C in [1, 3]``. The flattened output of the last conv layer
                must have 6400 units (asserted). This presumably only holds
                for 32x32 inputs; confirm against
                :class:`LocalConv2dLayer`.
            num_classes (int): Number of output units.
            no_weights (bool): If ``True``, conv and linear parameters are
                not created internally. Their shapes are added to
                ``hyper_shapes_learned``.
            init_weights (list, optional): Tensors overwriting all internal
                parameters, in the order of ``weights``.
            use_context_mod (bool): Add context-modulation layers.
            context_mod_inputs (bool): Also modulate the input image.
            no_last_layer_context_mod (bool): Do not modulate the output
                layer.
            context_mod_no_weights (bool): Context-mod parameters are not
                created internally.
            context_mod_post_activation (bool): Stored as an attribute; not
                used in this constructor.
            context_mod_gain_offset (bool): Forwarded to
                :class:`ContextModLayer` as ``apply_gain_offset``.
            context_mod_gain_softplus (bool): Forwarded to
                :class:`ContextModLayer` as ``apply_gain_softplus``.
            context_mod_apply_pixel_wise (bool): If ``False``, conv-layer
                modulation uses one gain/shift per feature map (shape
                ``[C, 1, 1]``) instead of one per pixel.
        """
        super(BioConvNet, self).__init__(num_classes, True)

        assert(len(in_shape) == 3)
        # FIXME This assertion is not mandatory but a sanity check that the user
        # uses the Tensorflow layout.
        assert(in_shape[2] in [1, 3])
        assert(init_weights is None or \
               (not no_weights or not context_mod_no_weights))
        self._in_shape = in_shape
        self._no_weights = no_weights
        self._use_context_mod = use_context_mod
        self._context_mod_inputs = context_mod_inputs
        self._no_last_layer_context_mod = no_last_layer_context_mod
        self._context_mod_no_weights = context_mod_no_weights
        self._context_mod_post_activation = context_mod_post_activation
        self._context_mod_gain_offset = context_mod_gain_offset
        self._context_mod_gain_softplus = context_mod_gain_softplus
        self._context_mod_apply_pixel_wise = context_mod_apply_pixel_wise

        self._has_bias = True
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        self._has_linear_out = True

        self._param_shapes = []
        self._weights = None if no_weights and context_mod_no_weights \
            else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not context_mod_no_weights else []
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        # Shapes of output activities for context-modulation, if used.
        cm_shapes = [] # Output shape of all layers.
        if context_mod_inputs:
            # Modulate the input in (C, H, W) layout.
            cm_shapes.append([in_shape[2], *in_shape[:2]])

        ### Define and initialize all conv and linear layers
        ### Bio-conv layers.
        H = in_shape[0]
        W = in_shape[1]
        C_in = in_shape[2]

        # Channels, kernel sizes, strides and paddings of the conv layers.
        C = [64, 128, 256]
        K = [5, 5, 3]
        S = [2, 2, 1]
        P = [0, 0, 1]

        # NOTE(review): plain list, so these layers are not registered as
        # submodules; their parameters are registered via `self._weights`.
        self._conv_layer = []

        for i, C_out in enumerate(C):
            self._conv_layer.append(LocalConv2dLayer(C_in, C_out, H, W, K[i],
                stride=S[i], padding=P[i], no_weights=no_weights))
            H = self._conv_layer[-1].out_height
            W = self._conv_layer[-1].out_width

            cm_shapes.append([C_out, H, W])

            C_in = C_out

            self._param_shapes.extend(self._conv_layer[-1].param_shapes)
            if no_weights:
                self._hyper_shapes_learned.extend( \
                    self._conv_layer[-1].param_shapes)
            else:
                self._weights.extend(self._conv_layer[-1].weights)

                assert(len(self._conv_layer[-1].weights) == 2)
                self._layer_weight_tensors.append( \
                    self._conv_layer[-1].filters)
                self._layer_bias_vectors.append( \
                    self._conv_layer[-1].bias)

        ### Linear layers
        n_in = H * W * C_out
        assert(n_in == 6400)
        n = [1024, num_classes]
        for i, n_out in enumerate(n):
            W_shape = [n_out, n_in]
            b_shape = [n_out]

            # Note, that the last layer shape might not be used for context-
            # modulation.
            if i < (len(n)-1) or not no_last_layer_context_mod:
                cm_shapes.append([n_out])

            n_in = n_out

            self._param_shapes.extend([W_shape, b_shape])
            if no_weights:
                self._hyper_shapes_learned.extend([W_shape, b_shape])
            else:
                W = nn.Parameter(torch.Tensor(*W_shape), requires_grad=True)
                b = nn.Parameter(torch.Tensor(*b_shape), requires_grad=True)

                init_params(W, b)

                self._weights.extend([W, b])
                self._layer_weight_tensors.append(W)
                self._layer_bias_vectors.append(b)

        ### Define and initialize context mod weights.
        self._context_mod_layers = nn.ModuleList() if use_context_mod else None
        self._context_mod_shapes = [] if use_context_mod else None
        self._context_mod_weights = nn.ParameterList() if use_context_mod \
            else None

        if use_context_mod:
            if not context_mod_apply_pixel_wise:
                # Only scalar gain and shift per feature map!
                for i, s in enumerate(cm_shapes):
                    if len(s) == 3:
                        cm_shapes[i] = [s[0], 1, 1]

            for i, s in enumerate(cm_shapes):
                cmod_layer = ContextModLayer(s,
                    no_weights=context_mod_no_weights,
                    apply_gain_offset=context_mod_gain_offset,
                    apply_gain_softplus=context_mod_gain_softplus)
                self._context_mod_layers.append(cmod_layer)

                self._context_mod_shapes.extend(cmod_layer.param_shapes)
                if not context_mod_no_weights:
                    self._context_mod_weights.extend(cmod_layer.weights)

            # We always had the context mod weights/shapes at the beginning of
            # our list attributes.
            self._param_shapes = self._context_mod_shapes + self._param_shapes
            if context_mod_no_weights:
                self._hyper_shapes_learned = self._context_mod_shapes + \
                    self._hyper_shapes_learned
            else:
                # Rebuild the list so context-mod weights come first.
                tmp = self._weights
                self._weights = nn.ParameterList(self._context_mod_weights)
                for w in tmp:
                    self._weights.append(w)

        ### Apply custom init if given.
        if init_weights is not None:
            assert(len(self.weights) == len(init_weights))
            for i in range(len(init_weights)):
                assert(np.all(np.equal(list(init_weights[i].shape),
                                       list(self._weights[i].shape))))
                self._weights[i].data = init_weights[i]

        ### Print user info.
        num_weights = MainNetInterface.shapes_to_num_weights( \
            self._param_shapes)
        if use_context_mod:
            cm_num_weights = MainNetInterface.shapes_to_num_weights( \
                self._context_mod_shapes)

        print('Creating bio-plausible convnet with %d weights' % num_weights
              + (' (including %d weights associated with ' % cm_num_weights
                 + 'context modulation)' if use_context_mod else '') + '.')

        self._is_properly_setup()
示例#10
0
 def num_outputs(self):
     """Getter for the attribute :attr:`num_outputs`.

     Returns:
         The total number of scalar outputs, obtained by counting the
         entries of all shapes in ``self.target_shapes`` via
         :meth:`MainNetInterface.shapes_to_num_weights`.
     """
     shapes = self.target_shapes
     total = MainNetInterface.shapes_to_num_weights(shapes)
     return total
示例#11
0
    def __init__(self,
                 in_shape=(224, 224, 3),
                 num_classes=1000,
                 use_bias=True,
                 use_fc_bias=None,
                 num_feature_maps=(64, 64, 128, 256, 512),
                 blocks_per_group=(2, 2, 2, 2),
                 projection_shortcut=False,
                 bottleneck_blocks=False,
                 cutout_mod=False,
                 no_weights=False,
                 use_batch_norm=True,
                 bn_track_stats=True,
                 distill_bn_stats=False,
                 chw_input_format=False,
                 verbose=True,
                 **kwargs):
        super(ResNetIN, self).__init__(num_classes, verbose)

        ### Parse or set context-mod arguments ###
        rem_kwargs = MainNetInterface._parse_context_mod_args(kwargs)
        if 'context_mod_apply_pixel_wise' in rem_kwargs:
            rem_kwargs.remove('context_mod_apply_pixel_wise')
        if len(rem_kwargs) > 0:
            raise ValueError('Keyword arguments %s unknown.' % str(rem_kwargs))
        # Since this is a conv-net, we may also want to add the following.
        if 'context_mod_apply_pixel_wise' not in kwargs.keys():
            kwargs['context_mod_apply_pixel_wise'] = False

        self._use_context_mod = kwargs['use_context_mod']
        self._context_mod_inputs = kwargs['context_mod_inputs']
        self._no_last_layer_context_mod = kwargs['no_last_layer_context_mod']
        self._context_mod_no_weights = kwargs['context_mod_no_weights']
        self._context_mod_post_activation = \
            kwargs['context_mod_post_activation']
        self._context_mod_gain_offset = kwargs['context_mod_gain_offset']
        self._context_mod_gain_softplus = kwargs['context_mod_gain_softplus']
        self._context_mod_apply_pixel_wise = \
            kwargs['context_mod_apply_pixel_wise']

        ### Check or parse remaining arguments ###
        self._in_shape = in_shape
        self._projection_shortcut = projection_shortcut
        self._bottleneck_blocks = bottleneck_blocks
        self._cutout_mod = cutout_mod
        if use_fc_bias is None:
            use_fc_bias = use_bias
        # Also, checkout attribute `_has_bias` below.
        self._use_bias = use_bias
        self._use_fc_bias = use_fc_bias
        self._no_weights = no_weights
        assert not use_batch_norm or (not distill_bn_stats or bn_track_stats)
        self._use_batch_norm = use_batch_norm
        self._bn_track_stats = bn_track_stats
        self._distill_bn_stats = distill_bn_stats and use_batch_norm
        self._chw_input_format = chw_input_format

        if len(blocks_per_group) != 4:
            raise ValueError('Option "blocks_per_group" must be a list of 4 ' +
                             'integers.')
        self._num_blocks = blocks_per_group
        if len(num_feature_maps) != 5:
            raise ValueError('Option "num_feature_maps" must be a list of 5 ' +
                             'integers.')
        self._filter_sizes = list(num_feature_maps)

        # The first layer of group 3, 4 and 5 uses a strided convolution, so
        # the shorcut connections need to perform a downsampling operation. In
        # addition, whenever traversing from one group to the next, the number
        # of feature maps might change. In all these cases, the network might
        # benefit from smart shortcut connections, which means using projection
        # shortcuts, where a 1x1 conv is used for the mentioned skip connection.
        self._num_non_ident_skips = 3  # Strided convs: 2->3, 3->4 and 4->5
        fs1 = self._filter_sizes[1]
        if self._bottleneck_blocks:
            fs1 *= 4
        if self._filter_sizes[0] != fs1:
            self._num_non_ident_skips += 1  # Also handle 1->2.
        self._group_has_1x1 = [False] * 4
        if self._projection_shortcut:
            for i in range(3, 3 - self._num_non_ident_skips, -1):
                self._group_has_1x1[i] = True
        # Number of conv layers (excluding skip connections)
        self._num_main_conv_layers = 1 + int(np.sum([self._num_blocks[i] * \
            (3 if self._bottleneck_blocks else 2) for i in range(4)]))

        # The original architecture uses a 7x7 kernel in the first conv layer
        # and 3x3 or 1x1 kernels in all remaining layers.
        self._init_kernel_size = (7, 7)
        # All 3x3 layers have padding 1 and 1x1 layers have padding 0.
        self._init_padding = 3
        self._init_stride = 2

        if self._cutout_mod:
            self._init_kernel_size = (3, 3)
            self._init_padding = 1
            self._init_stride = 1

        ### Set required class attributes ###
        # Note, we did overwrite the getter for attribute `has_bias`, as it is
        # not applicable if the values of `use_bias` and `use_fc_bias` differ.
        self._has_bias = use_bias if use_bias == use_fc_bias else False
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer!
        self._mask_fc_out = True
        self._has_linear_out = True

        self._param_shapes = []
        self._param_shapes_meta = []
        self._internal_params = None if no_weights and \
            self._context_mod_no_weights else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not self._context_mod_no_weights else []
        self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
            is None else []
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        #################################
        ### Create context mod layers ###
        #################################
        self._context_mod_layers = nn.ModuleList() if self._use_context_mod \
            else None

        if self._use_context_mod:
            cm_layer_inds = []
            cm_shapes = []  # Output shape of all layers.
            if self._context_mod_inputs:
                cm_shapes.append([in_shape[2], *in_shape[:2]])
                # We reserve layer zero for input context-mod. Otherwise, there
                # is no layer zero.
                cm_layer_inds.append(0)

            layer_out_shapes = self._compute_layer_out_sizes()
            cm_shapes.extend(layer_out_shapes)
            # All layer indices `l` with `l mod 3 == 0` are context-mod layers.
            cm_layer_inds.extend(range(3, 3 * len(layer_out_shapes) + 1, 3))
            if self._no_last_layer_context_mod:
                cm_shapes = cm_shapes[:-1]
                cm_layer_inds = cm_layer_inds[:-1]
            if not self._context_mod_apply_pixel_wise:
                # Only scalar gain and shift per feature map!
                for i, s in enumerate(cm_shapes):
                    if len(s) == 3:
                        cm_shapes[i] = [s[0], 1, 1]

            self._add_context_mod_layers(cm_shapes, cm_layers=cm_layer_inds)

        ###############################
        ### Create batchnorm layers ###
        ###############################
        # We just use even numbers starting from 2 as layer indices for
        # batchnorm layers.
        if use_batch_norm:
            bn_sizes = []
            for i, s in enumerate(self._filter_sizes):
                if i == 0:
                    bn_sizes.append(s)
                else:
                    for _ in range(self._num_blocks[i - 1]):
                        if self._bottleneck_blocks:
                            bn_sizes.extend([s, s, 4 * s])
                        else:
                            bn_sizes.extend([s, s])

            # All layer indices `l` with `l mod 3 == 2` are batchnorm layers.
            bn_layers = list(range(2, 3 * len(bn_sizes) + 1, 3))

            # We also need a batchnorm layer per skip connection that uses 1x1
            # projections.
            if self._projection_shortcut:
                bn_layer_ind_skip = 3 * (self._num_main_conv_layers + 1) + 2

                factor = 4 if self._bottleneck_blocks else 1
                for i in range(4):  # For each transition between conv groups.
                    if self._group_has_1x1[i]:
                        bn_sizes.append(self._filter_sizes[i + 1] * factor)
                        bn_layers.append(bn_layer_ind_skip)

                        bn_layer_ind_skip += 3

            self._add_batchnorm_layers(bn_sizes,
                                       no_weights,
                                       bn_layers=bn_layers,
                                       distill_bn_stats=distill_bn_stats,
                                       bn_track_stats=bn_track_stats)

        ######################################
        ### Create skip connection weights ###
        ######################################
        if self._projection_shortcut:
            layer_ind_skip = 3 * (self._num_main_conv_layers + 1) + 1

            factor = 4 if self._bottleneck_blocks else 1

            n_in = self._filter_sizes[0]
            for i in range(4):  # For each transition between conv groups.
                if not self._group_has_1x1[i]:
                    continue

                n_out = self._filter_sizes[i + 1] * factor

                skip_1x1_shape = [n_out, n_in, 1, 1]

                if not no_weights:
                    self._internal_params.append(nn.Parameter( \
                        torch.Tensor(*skip_1x1_shape), requires_grad=True))
                    self._layer_weight_tensors.append(
                        self._internal_params[-1])
                    self._layer_bias_vectors.append(None)
                    init_params(self._layer_weight_tensors[-1])
                else:
                    self._hyper_shapes_learned.append(skip_1x1_shape)
                    self._hyper_shapes_learned_ref.append( \
                        len(self.param_shapes))

                self._param_shapes.append(skip_1x1_shape)
                self._param_shapes_meta.append({
                    'name': 'weight',
                    'index': -1 if no_weights else \
                        len(self._internal_params)-1,
                    'layer': layer_ind_skip
                })

                layer_ind_skip += 3
                n_in = n_out

        ############################################################
        ### Create convolutional layers and final linear weights ###
        ############################################################
        # Convolutional layers will get IDs `l` such that `l mod 3 == 1`.
        layer_id = 1
        n_per_block = 3 if self._bottleneck_blocks else 2
        for i in range(6):
            if i == 0:  ### Fist layer.
                num = 1
                prev_fs = self._in_shape[2]
                curr_fs = self._filter_sizes[0]

                kernel_size = self._init_kernel_size
                #stride = self._init_stride
            elif i == 5:  ### Final fully-connected layer.
                num = 1
                curr_fs = num_classes

                kernel_size = None
            else:  # Group of residual blocks.
                num = self._num_blocks[i - 1] * n_per_block
                curr_fs = self._filter_sizes[i]

                kernel_size = (3, 3)  # depends on block structure!

            for n in range(num):
                if i == 5:
                    layer_shapes = [[curr_fs, prev_fs]]
                    if use_fc_bias:
                        layer_shapes.append([curr_fs])

                    prev_fs = curr_fs
                else:
                    if i > 0 and self._bottleneck_blocks:
                        if n % 3 == 0:
                            fs = curr_fs
                            ks = (1, 1)
                        elif n % 3 == 1:
                            fs = curr_fs
                            ks = kernel_size
                        else:
                            fs = 4 * curr_fs
                            ks = (1, 1)
                    elif i > 0 and not self._bottleneck_blocks:
                        fs = curr_fs
                        ks = kernel_size
                    else:
                        fs = curr_fs
                        ks = kernel_size

                    layer_shapes = [[fs, prev_fs, *ks]]
                    if use_bias:
                        layer_shapes.append([fs])

                    prev_fs = fs

                for s in layer_shapes:
                    if not no_weights:
                        self._internal_params.append(nn.Parameter( \
                            torch.Tensor(*s), requires_grad=True))
                        if len(s) == 1:
                            self._layer_bias_vectors.append( \
                                self._internal_params[-1])
                        else:
                            self._layer_weight_tensors.append( \
                                self._internal_params[-1])
                    else:
                        self._hyper_shapes_learned.append(s)
                        self._hyper_shapes_learned_ref.append( \
                            len(self.param_shapes))

                    self._param_shapes.append(s)
                    self._param_shapes_meta.append({
                        'name': 'weight' if len(s) != 1 else 'bias',
                        'index': -1 if no_weights else \
                            len(self._internal_params)-1,
                        'layer': layer_id
                    })

                layer_id += 3

                # Initialize_weights
                if not no_weights:
                    init_params(self._layer_weight_tensors[-1],
                        self._layer_bias_vectors[-1] \
                        if len(layer_shapes) == 2 else None)

        ###########################
        ### Print infos to user ###
        ###########################
        if verbose:
            if self._use_context_mod:
                cm_param_shapes = []
                for cm_layer in self.context_mod_layers:
                    cm_param_shapes.extend(cm_layer.param_shapes)
                cm_num_params = \
                    MainNetInterface.shapes_to_num_weights(cm_param_shapes)

            print('Creating a "%s" with %d weights' \
                  % (str(self), self.num_params)
                  + (' (including %d weights associated with-' % cm_num_params
                     + 'context modulation)' if self._use_context_mod else '')
                  + '.'
                  + (' The network uses batchnorm.' if use_batch_norm  else ''))

        self._is_properly_setup(check_has_bias=False)
示例#12
0
def generate_replay_networks(config,
                             data_handlers,
                             device,
                             create_rp_hnet=True,
                             only_train_replay=False):
    """Create a replay model and, optionally, a hypernetwork for its decoder.

    The replay model consists of either an encoder/decoder pair or, if
    ``config.replay_method == 'gan'``, a discriminator/generator pair.
    Additionally, this function manages the creation of a hypernetwork for the
    generator/decoder.

    The following important configurations are determined in order to create
    the replay model:

    * in- and output and hidden layer dimensions of the encoder/decoder.
    * architecture, chunk- and task-embedding details of the decoder's
      hypernetwork.

    Several derived values are written back into ``config``: ``input_dim``,
    ``out_dim``, ``num_weights_enc``, ``num_weights_dec``,
    ``num_weights_rp_net``, ``num_embeddings``, ``num_weights_rp_hyper_net``
    and ``compression_ratio_rp``.

    .. note::
        This function also handles the initialisation of the weights of the
        replay networks (the encoder, plus either the decoder or the decoder's
        hypernetwork). This will change in the near future.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network, and to
            infer the number of tasks.
        device: Torch device.
        create_rp_hnet: Whether a hypernetwork for the replay should be
            constructed. If not, the decoder/generator will have
            trainable weights on its own.
        only_train_replay: We normally do not train on the last task since we
            do not need to replay this last task's data. But if we want a
            replay method to be able to generate data from all tasks then we
            set this option to true.

    Returns:
        (tuple): Tuple containing:

        - **enc**: Encoder/discriminator network instance.
        - **dec**: Decoder/generator network instance.
        - **dec_hnet**: Hypernetwork instance for the decoder/generator. This
            return value is None if no hypernetwork should be constructed.
    """
    # Output size of the encoder/discriminator.
    if config.replay_method == 'gan':
        enc_n_out = 1
    else:
        # Two output values per latent dimension.
        enc_n_out = config.latent_dim * 2

    n_in = data_handlers[0].in_shape[0]
    pd = config.padding * 2
    if config.experiment == "splitMNIST":
        n_in = n_in * n_in
    else:  # permutedMNIST
        # The image side length is enlarged by `config.padding` on each side.
        n_in = (n_in + pd) * (n_in + pd)

    config.input_dim = n_in
    if config.experiment == "splitMNIST":
        if config.single_class_replay:
            config.out_dim = 1
        else:
            config.out_dim = 2
    else:  # permutedMNIST
        config.out_dim = 10

    if config.infer_task_id:
        # task inference network
        config.out_dim = 1

    # Build encoder.
    print('For the replay encoder/discriminator: ')
    enc_arch = misc.str_to_ints(config.enc_fc_arch)
    enc = MLP(n_in=n_in,
              n_out=enc_n_out,
              hidden_layers=enc_arch,
              activation_fn=misc.str_to_act(config.enc_net_act),
              dropout_rate=config.enc_dropout_rate,
              no_weights=False).to(device)
    print('Constructed MLP with shapes: ', enc.param_shapes)
    # Parameters that are (re-)initialized at the end of this function.
    params_to_init = list(enc.parameters())

    # Build decoder.
    print('For the replay decoder/generator: ')
    dec_arch = misc.str_to_ints(config.dec_fc_arch)
    # The decoder input is the latent code, optionally extended by the
    # dimensions of the conditional input.
    dec_n_in = config.latent_dim
    if config.conditional_replay:
        dec_n_in += config.conditional_dim

    dec = MLP(n_in=dec_n_in,
              n_out=n_in,
              hidden_layers=dec_arch,
              activation_fn=misc.str_to_act(config.dec_net_act),
              use_bias=True,
              no_weights=config.rp_beta > 0,
              dropout_rate=config.dec_dropout_rate).to(device)

    print('Constructed MLP with shapes: ', dec.param_shapes)
    config.num_weights_enc = \
        MainNetInterface.shapes_to_num_weights(enc.param_shapes)
    config.num_weights_dec = \
        MainNetInterface.shapes_to_num_weights(dec.param_shapes)
    config.num_weights_rp_net = config.num_weights_enc + config.num_weights_dec

    # Unless `only_train_replay` is set, no replay model (and hence no
    # embedding) is needed for the last task.
    subtr = 0 if only_train_replay else 1

    num_embeddings = config.num_tasks - subtr if config.num_tasks > 1 else 1

    if config.single_class_replay:
        # One embedding per class of each (replayed) task.
        if config.num_tasks > 1:
            num_embeddings = config.out_dim * (config.num_tasks - subtr)
        else:
            num_embeddings = config.out_dim * (config.num_tasks)

    config.num_embeddings = num_embeddings

    # Build decoder hnet.
    if create_rp_hnet:
        print('For the decoder/generator hypernetwork: ')
        d_hnet = sim_utils.get_hnet_model(config,
                                          num_embeddings,
                                          device,
                                          dec.hyper_shapes_learned,
                                          cprefix='rp_')

        params_to_init += list(d_hnet.parameters())

        config.num_weights_rp_hyper_net = sum(p.numel()
                                              for p in d_hnet.parameters()
                                              if p.requires_grad)
        config.compression_ratio_rp = config.num_weights_rp_hyper_net / \
                                      config.num_weights_dec
        print('Created replay hypernetwork with ratio: ',
              config.compression_ratio_rp)
        if config.compression_ratio_rp > 1:
            print('Note that the compression ratio is computed compared to ' +
                  'current target network,\nthis might not be directly ' +
                  'comparable with the number of parameters of methods we ' +
                  'compare against.')
    else:
        d_hnet = None
        params_to_init += list(dec.parameters())
        config.num_weights_rp_hyper_net = 0
        config.compression_ratio_rp = 0

    ### Initialize network weights.
    for W in params_to_init:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The hypernetwork embeddings are initialized differently (this overrides
    # the generic initialization above). `d_hnet` only exists if
    # `create_rp_hnet` is set.
    if create_rp_hnet:
        for temb in d_hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

        if hasattr(d_hnet, 'chunk_embeddings'):
            for emb in d_hnet.chunk_embeddings:
                torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    return enc, dec, d_hnet
示例#13
0
    def __init__(
            self,
            in_shape=(28, 28, 1),
            num_classes=10,
            verbose=True,
            arch='mnist_large',
            no_weights=False,
            init_weights=None,
            dropout_rate=-1,  #0.5
            **kwargs):
        """Build a LeNet whose layer shapes come from ``LeNet._ARCHITECTURES``.

        Args:
            in_shape: Input image shape. ``in_shape[:2]`` are the spatial
                dimensions and ``in_shape[2]`` is the number of channels.
                Architectures whose name starts with ``'mnist'`` require
                28x28 inputs, all others require 32x32 inputs.
            num_classes: Number of output units. If it differs from 10, the
                weight and bias shapes of the last layer are adapted.
            verbose: Whether to print information about the created network.
            arch: Key into ``LeNet._ARCHITECTURES``.
            no_weights: If ``True``, no main-network parameters are created;
                their shapes are recorded in ``hyper_shapes_learned`` instead.
            init_weights: Optional list of tensors used to initialize
                ``self.weights``. Must match them in number and shape.
            dropout_rate: ``-1`` disables dropout, otherwise a probability in
                ``[0, 1]``.
            **kwargs: Context-modulation options, parsed by
                ``MainNetInterface._parse_context_mod_args``.

        Raises:
            ValueError: If the input image size does not match the chosen
                architecture, or if ``kwargs`` contains unknown keys.
        """
        import copy

        super(LeNet, self).__init__(num_classes, verbose)

        self._in_shape = in_shape
        assert arch in LeNet._ARCHITECTURES.keys()
        # Work on a private deep copy. The shape lists are modified below to
        # match `num_classes`. Modifying the class-level table in place would
        # leak into every LeNet constructed afterwards.
        self._chosen_arch = copy.deepcopy(LeNet._ARCHITECTURES[arch])
        if num_classes != 10:
            # The last two entries are the output layer's weight and bias.
            self._chosen_arch[-2][0] = num_classes
            self._chosen_arch[-1][0] = num_classes

        # Sanity check, given current implementation.
        if arch.startswith('mnist'):
            if not in_shape[0] == in_shape[1] == 28:
                raise ValueError('MNIST LeNet architectures expect input ' +
                                 'images of size 28x28.')
        else:
            if not in_shape[0] == in_shape[1] == 32:
                raise ValueError('CIFAR LeNet architectures expect input ' +
                                 'images of size 32x32.')

        ### Parse or set context-mod arguments ###
        rem_kwargs = MainNetInterface._parse_context_mod_args(kwargs)
        if len(rem_kwargs) > 0:
            raise ValueError('Keyword arguments %s unknown.' % str(rem_kwargs))
        # Since this is a conv-net, we may also want to add the following.
        if 'context_mod_apply_pixel_wise' not in kwargs.keys():
            kwargs['context_mod_apply_pixel_wise'] = False

        self._use_context_mod = kwargs['use_context_mod']
        self._context_mod_inputs = kwargs['context_mod_inputs']
        self._no_last_layer_context_mod = kwargs['no_last_layer_context_mod']
        self._context_mod_no_weights = kwargs['context_mod_no_weights']
        self._context_mod_post_activation = \
            kwargs['context_mod_post_activation']
        self._context_mod_gain_offset = kwargs['context_mod_gain_offset']
        self._context_mod_gain_softplus = kwargs['context_mod_gain_softplus']
        self._context_mod_apply_pixel_wise = \
            kwargs['context_mod_apply_pixel_wise']

        ### Setup class attributes ###
        # `init_weights` is only usable if some weights are held internally.
        assert(init_weights is None or \
               (not no_weights or not self._context_mod_no_weights))
        self._no_weights = no_weights
        self._dropout_rate = dropout_rate

        self._has_bias = True
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer!
        self._mask_fc_out = True
        self._has_linear_out = True

        self._param_shapes = []
        self._param_shapes_meta = []
        # Internal parameters only exist if at least one kind of weights
        # (main-net or context-mod) is maintained by this network itself.
        self._internal_params = None if no_weights and \
            self._context_mod_no_weights else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not self._context_mod_no_weights else []
        self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
            is None else []
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        if dropout_rate != -1:
            assert (dropout_rate >= 0. and dropout_rate <= 1.)
            # FIXME `nn.Dropout2d` zeroes out whole feature maps. Is that really
            # desired here?
            self._drop_conv1 = nn.Dropout2d(p=dropout_rate)
            self._drop_conv2 = nn.Dropout2d(p=dropout_rate)
            self._drop_fc1 = nn.Dropout(p=dropout_rate)

        ### Define and initialize context mod layers/weights ###
        self._context_mod_layers = nn.ModuleList() if self._use_context_mod \
            else None

        if self._use_context_mod:
            cm_layer_inds = []
            cm_shapes = []  # Output shape of all context-mod layers.
            if self._context_mod_inputs:
                cm_shapes.append([in_shape[2], *in_shape[:2]])
                # We reserve layer zero for input context-mod. Otherwise, there
                # is no layer zero.
                cm_layer_inds.append(0)

            layer_out_shapes = self._compute_layer_out_sizes()
            # Context-modulation is applied after the pooling layers.
            # So we delete the shapes of the conv-layer outputs and keep the
            # ones of the pooling layer outputs.
            del layer_out_shapes[2]
            del layer_out_shapes[0]
            cm_shapes.extend(layer_out_shapes)
            # Context-mod layers get even layer indices (2, 4, ...).
            cm_layer_inds.extend(range(2, 2 * len(layer_out_shapes) + 1, 2))
            if self._no_last_layer_context_mod:
                cm_shapes = cm_shapes[:-1]
                cm_layer_inds = cm_layer_inds[:-1]

            if not self._context_mod_apply_pixel_wise:
                # Only scalar gain and shift per feature map!
                for i, s in enumerate(cm_shapes):
                    if len(s) == 3:
                        cm_shapes[i] = [s[0], 1, 1]

            self._add_context_mod_layers(cm_shapes, cm_layers=cm_layer_inds)

        ### Define and add conv- and fc-layer weights.
        for i, s in enumerate(self._chosen_arch):
            if not no_weights:
                self._internal_params.append(
                    nn.Parameter(torch.Tensor(*s), requires_grad=True))
                if len(s) == 1:
                    self._layer_bias_vectors.append(self._internal_params[-1])
                else:
                    self._layer_weight_tensors.append(
                        self._internal_params[-1])
            else:
                self._hyper_shapes_learned.append(s)
                self._hyper_shapes_learned_ref.append(len(self.param_shapes))

            self._param_shapes.append(s)
            self._param_shapes_meta.append({
                'name':
                'weight' if len(s) != 1 else 'bias',
                'index':
                -1 if no_weights else len(self._internal_params) - 1,
                # Weight/bias pairs share an odd layer index (1, 3, ...).
                'layer':
                2 * (i // 2) + 1
            })

        ### Initialize weights.
        if init_weights is not None:
            assert len(init_weights) == len(self.weights)
            for i in range(len(init_weights)):
                assert np.all(
                    np.equal(list(init_weights[i].shape),
                             self.weights[i].shape))
                self.weights[i].data = init_weights[i]
        else:
            for i in range(len(self._layer_weight_tensors)):
                init_params(self._layer_weight_tensors[i],
                            self._layer_bias_vectors[i])

        ### Print user info.
        if verbose:
            if self._use_context_mod:
                cm_param_shapes = []
                for cm_layer in self.context_mod_layers:
                    cm_param_shapes.extend(cm_layer.param_shapes)
                cm_num_weights = \
                    MainNetInterface.shapes_to_num_weights(cm_param_shapes)

            # Note, `cm_num_weights` is only evaluated if context-mod is used.
            print('Creating a LeNet with %d weights' % self.num_params
                  + (' (including %d weights associated with ' % cm_num_weights
                     + 'context modulation)' if self._use_context_mod else '')
                  + '.'
                  + (' The network uses dropout.' if dropout_rate != -1 \
                     else ''))

        self._is_properly_setup()
示例#14
0
文件: resnet.py 项目: limberc/hypercl
    def __init__(self,
                 in_shape=(32, 32, 3),
                 num_classes=10,
                 verbose=True,
                 n=5,
                 no_weights=False,
                 init_weights=None,
                 use_batch_norm=True,
                 bn_track_stats=True,
                 distill_bn_stats=False,
                 use_context_mod=False,
                 context_mod_inputs=False,
                 no_last_layer_context_mod=False,
                 context_mod_no_weights=False,
                 context_mod_post_activation=False,
                 context_mod_gain_offset=False,
                 context_mod_apply_pixel_wise=False):
        """Build a ResNet with ``6 * n + 2`` weight layers.

        Args:
            in_shape: Input image shape. ``in_shape[2]`` is the number of input
                channels and ``in_shape[:2]`` are the spatial dimensions.
            num_classes: Number of output units of the final linear layer.
            verbose: Whether to print the number of weights of the network.
            n: Each of the three convolutional groups following the first conv
                layer contains ``2 * n`` conv layers.
            no_weights: If ``True``, no layer weights are created internally;
                their shapes are recorded in ``hyper_shapes_learned`` instead.
            init_weights: Optional list of tensors used to initialize the
                internally maintained weights. The order is: context-mod
                weights, then batchnorm weights, then layer weights.
            use_batch_norm: Whether batchnorm layers are created.
            bn_track_stats: Passed as ``track_running_stats`` to each
                batchnorm layer.
            distill_bn_stats: Whether the shapes of the batchnorm statistics
                are collected in ``_hyper_shapes_distilled``. Requires
                ``bn_track_stats`` when batchnorm is used.
            use_context_mod: Whether context-modulation layers are created.
            context_mod_inputs: Whether the network input is also
                context-modulated.
            no_last_layer_context_mod: Whether the context-mod layer for the
                last layer output is omitted.
            context_mod_no_weights: If ``True``, context-mod weights are not
                maintained internally.
            context_mod_post_activation: Stored as a class attribute.
            context_mod_gain_offset: Passed as ``apply_gain_offset`` to each
                context-mod layer.
            context_mod_apply_pixel_wise: If ``False``, only one gain and shift
                per feature map is used.
        """
        super(ResNet, self).__init__(num_classes, verbose)

        self._in_shape = in_shape
        self._n = n

        # `init_weights` is only usable if some weights are held internally.
        assert (init_weights is None or \
                (not no_weights or not context_mod_no_weights))
        self._no_weights = no_weights

        assert (not use_batch_norm or (not distill_bn_stats or bn_track_stats))

        self._use_batch_norm = use_batch_norm
        self._bn_track_stats = bn_track_stats
        self._distill_bn_stats = distill_bn_stats and use_batch_norm

        self._use_context_mod = use_context_mod
        self._context_mod_inputs = context_mod_inputs
        self._no_last_layer_context_mod = no_last_layer_context_mod
        self._context_mod_no_weights = context_mod_no_weights
        self._context_mod_post_activation = context_mod_post_activation
        self._context_mod_gain_offset = context_mod_gain_offset
        self._context_mod_apply_pixel_wise = context_mod_apply_pixel_wise

        self._kernel_size = [3, 3]
        self._filter_sizes = [16, 16, 32, 64]

        self._has_bias = True
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        # We don't use any output non-linearity.
        self._has_linear_out = True

        self._param_shapes = []
        # Internal weights only exist if at least one kind of weights
        # (main-net or context-mod) is maintained by this network itself.
        self._weights = None if no_weights and context_mod_no_weights \
            else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not context_mod_no_weights else []

        #################################################
        ### Define and initialize context mod weights ###
        #################################################
        self._context_mod_layers = nn.ModuleList() if use_context_mod else None
        self._context_mod_shapes = [] if use_context_mod else None

        if use_context_mod:
            # Position of the next context-mod tensor within `init_weights`.
            cm_ind = 0
            cm_shapes = []  # Output shape of all layers.
            if context_mod_inputs:
                cm_shapes.append([in_shape[2], *in_shape[:2]])
            layer_out_shapes = self._compute_layer_out_sizes()
            if no_last_layer_context_mod:
                cm_shapes.extend(layer_out_shapes[:-1])
            else:
                cm_shapes.extend(layer_out_shapes)

            if not context_mod_apply_pixel_wise:
                # Only scalar gain and shift per feature map!
                for i, s in enumerate(cm_shapes):
                    if len(s) == 3:
                        cm_shapes[i] = [s[0], 1, 1]

            for s in cm_shapes:
                cmod_layer = ContextModLayer(
                    s,
                    no_weights=context_mod_no_weights,
                    apply_gain_offset=context_mod_gain_offset)
                self._context_mod_layers.append(cmod_layer)

                self.param_shapes.extend(cmod_layer.param_shapes)
                self._context_mod_shapes.extend(cmod_layer.param_shapes)
                if context_mod_no_weights:
                    self._hyper_shapes_learned.extend(cmod_layer.param_shapes)
                else:
                    self._weights.extend(cmod_layer.weights)

                # FIXME ugly code. Move initialization somewhere else.
                if not context_mod_no_weights and init_weights is not None:
                    assert (len(cmod_layer.weights) == 2)
                    for ii in range(2):
                        # Compare against this layer's own weight tensor.
                        assert (np.all(np.equal( \
                            list(init_weights[cm_ind].shape),
                            list(cmod_layer.weights[ii].shape))))
                        cmod_layer.weights[ii].data = init_weights[cm_ind]
                        cm_ind += 1

            if init_weights is not None:
                # Drop the consumed context-mod tensors.
                init_weights = init_weights[cm_ind:]

        ###########################
        ### Print infos to user ###
        ###########################
        # Compute the total number of weights in this network and display
        # them to the user.
        # Note, this complicated calculation is not necessary as we can simply
        # count the number of weights afterwards. But it's an additional sanity
        # check for us.
        fs = self._filter_sizes
        num_weights = np.prod(self._kernel_size) * \
                      (in_shape[2] * fs[0] + np.sum([fs[i] * fs[i + 1] + \
                                                     (2 * n - 1) * fs[i + 1] ** 2 for i in range(3)])) + \
                      (fs[0] + 2 * n * np.sum([fs[i] for i in range(1, 4)])) + \
                      (fs[-1] * num_classes + num_classes)

        cm_num_weights = MainNetInterface.shapes_to_num_weights( \
            self._context_mod_shapes) if use_context_mod else 0
        num_weights += cm_num_weights

        if use_batch_norm:
            # The gamma and beta parameters of a batch norm layer are
            # learned as well.
            num_weights += 2 * (fs[0] + \
                                2 * n * np.sum([fs[i] for i in range(1, 4)]))

        if verbose:
            print('A ResNet with %d layers and %d weights is created' \
                  % (6 * n + 2, num_weights)
                  + (' (including %d context-mod weights).' % cm_num_weights \
                         if cm_num_weights > 0 else '.'))

        ################################################
        ### Define and initialize batch norm weights ###
        ################################################
        self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

        if use_batch_norm:
            if distill_bn_stats:
                self._hyper_shapes_distilled = []

            # Position of the next batchnorm tensor within `init_weights`.
            bn_ind = 0
            for i, s in enumerate(self._filter_sizes):
                # One batchnorm layer after the first conv layer, then one per
                # conv layer of each group.
                num = 1 if i == 0 else 2 * n

                for _ in range(num):
                    bn_layer = BatchNormLayer(
                        s,
                        affine=not no_weights,
                        track_running_stats=bn_track_stats)
                    self._batchnorm_layers.append(bn_layer)

                    if distill_bn_stats:
                        self._hyper_shapes_distilled.extend( \
                            [list(p.shape) for p in bn_layer.get_stats(0)])

                    if not no_weights and init_weights is not None:
                        assert (len(bn_layer.weights) == 2)
                        for ii in range(2):
                            assert (np.all(np.equal( \
                                list(init_weights[bn_ind].shape),
                                list(bn_layer.weights[ii].shape))))
                            bn_layer.weights[ii].data = init_weights[bn_ind]
                            bn_ind += 1

            if init_weights is not None:
                # Drop the consumed batchnorm tensors.
                init_weights = init_weights[bn_ind:]

        # Note, method `_compute_hyper_shapes` doesn't take context-mod into
        # consideration.
        self._param_shapes.extend(self._compute_hyper_shapes(no_weights=True))
        assert (num_weights == \
                MainNetInterface.shapes_to_num_weights(self._param_shapes))

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        if no_weights:
            if self._hyper_shapes_learned is None:
                self._hyper_shapes_learned = self._compute_hyper_shapes()
            else:
                # Context-mod weights are already included.
                self._hyper_shapes_learned.extend(self._compute_hyper_shapes())

            self._is_properly_setup()
            return

        if use_batch_norm:
            for bn_layer in self._batchnorm_layers:
                self._weights.extend(bn_layer.weights)

        ############################################
        ### Define and initialize layer weights ###
        ###########################################
        ### Does not include context-mod or batchnorm weights.
        # First layer.
        self._layer_weight_tensors.append(
            nn.Parameter(torch.Tensor(self._filter_sizes[0], self._in_shape[2],
                                      *self._kernel_size),
                         requires_grad=True))
        self._layer_bias_vectors.append(
            nn.Parameter(torch.Tensor(self._filter_sizes[0]),
                         requires_grad=True))

        # Each block consists of 2n layers.
        for i in range(1, len(self._filter_sizes)):
            in_filters = self._filter_sizes[i - 1]
            out_filters = self._filter_sizes[i]

            for _ in range(2 * n):
                self._layer_weight_tensors.append(
                    nn.Parameter(torch.Tensor(out_filters, in_filters,
                                              *self._kernel_size),
                                 requires_grad=True))
                self._layer_bias_vectors.append(
                    nn.Parameter(torch.Tensor(out_filters),
                                 requires_grad=True))
                # Note, that the first layer in this block has potentially a
                # different number of input filters.
                in_filters = out_filters

        # After the average pooling, there is one more dense layer.
        self._layer_weight_tensors.append(
            nn.Parameter(torch.Tensor(num_classes, self._filter_sizes[-1]),
                         requires_grad=True))
        self._layer_bias_vectors.append(
            nn.Parameter(torch.Tensor(num_classes), requires_grad=True))

        # Index in `self._weights` at which the layer weights start. All
        # internally maintained context-mod and batchnorm weights precede them.
        layer_weights_offset = len(self._weights)

        # We add the weights interleaved, such that there are always consecutive
        # weight tensor and bias vector per layer. This fulfils the requirements
        # of attribute `mask_fc_out`.
        for weight, bias in zip(self._layer_weight_tensors,
                                self._layer_bias_vectors):
            self._weights.append(weight)
            self._weights.append(bias)

        ### Initialize weights.
        if init_weights is not None:
            # At this point, `init_weights` only contains layer weights.
            num_layers = 6 * n + 2
            assert (len(init_weights) == 2 * num_layers)
            assert (len(self._weights) == layer_weights_offset + 2 * num_layers)
            for i, w_init in enumerate(init_weights):
                j = layer_weights_offset + i
                assert (np.all(
                    np.equal(list(w_init.shape),
                             list(self._weights[j].shape))))
                self._weights[j].data = w_init
        else:
            for weight, bias in zip(self._layer_weight_tensors,
                                    self._layer_bias_vectors):
                init_params(weight, bias)

        self._is_properly_setup()
    def __init__(self, in_shape=(32, 32, 3), num_classes=10, n=4, k=10,
                 num_feature_maps=(16, 16, 32, 64), use_bias=True,
                 use_fc_bias=None, no_weights=False, use_batch_norm=True,
                 bn_track_stats=True, distill_bn_stats=False, dropout_rate=-1,
                 chw_input_format=False, verbose=True, **kwargs):
        """Construct a Wide Residual Network (WRN).

        The network consists of a first convolutional layer, three groups of
        ``2*n`` convolutional layers each, and a final fully-connected layer
        with ``num_classes`` outputs. Depending on the configuration,
        context-modulation layers, batchnorm layers and 1x1 convolutional
        skip-connection weights are added as well.

        Args:
            in_shape: Input shape. ``in_shape[2]`` is used as the number of
                input channels of the first layer (presumably the layout is
                ``(height, width, channels)`` -- NOTE(review): confirm).
            num_classes: Number of outputs of the final linear layer.
            n: Each of the three residual groups contains ``2*n``
                convolutional layers.
            k: Widening factor. Entries 1 to 3 of ``num_feature_maps`` are
                multiplied by ``k``; the first entry is left unchanged.
            num_feature_maps: Exactly 4 integers: the number of filters of the
                first layer and of the three residual groups.
            use_bias: Whether convolutional layers have bias vectors.
            use_fc_bias: Whether the final linear layer has a bias vector. If
                ``None``, the value of ``use_bias`` is used.
            no_weights: If ``True``, no layer weights are created internally;
                their shapes are collected in ``_hyper_shapes_learned``
                instead.
            use_batch_norm: Whether batchnorm layers are added.
            bn_track_stats: Forwarded to ``_add_batchnorm_layers``.
            distill_bn_stats: Forwarded to ``_add_batchnorm_layers``. When
                batchnorm is used, it requires ``bn_track_stats`` (asserted).
            dropout_rate: ``-1`` disables dropout; otherwise a value in
                ``[0, 1]`` used to create an ``nn.Dropout`` module.
            chw_input_format: Stored as ``_chw_input_format``; not otherwise
                used in this constructor.
            verbose: Whether to print a summary of the constructed network.
            **kwargs: Context-modulation options, parsed by
                ``MainNetInterface._parse_context_mod_args``.

        Raises:
            ValueError: If ``kwargs`` contains unknown keys or if
                ``num_feature_maps`` does not have exactly 4 entries.
        """
        super(WRN, self).__init__(num_classes, verbose)

        ### Parse or set context-mod arguments ###
        rem_kwargs = MainNetInterface._parse_context_mod_args(kwargs)
        if len(rem_kwargs) > 0:
            raise ValueError('Keyword arguments %s unknown.' % str(rem_kwargs))
        # Since this is a conv-net, we may also want to add the following.
        if 'context_mod_apply_pixel_wise' not in kwargs:
            kwargs['context_mod_apply_pixel_wise'] = False

        self._use_context_mod = kwargs['use_context_mod']
        self._context_mod_inputs = kwargs['context_mod_inputs']
        self._no_last_layer_context_mod = kwargs['no_last_layer_context_mod']
        self._context_mod_no_weights = kwargs['context_mod_no_weights']
        self._context_mod_post_activation = \
            kwargs['context_mod_post_activation']
        self._context_mod_gain_offset = kwargs['context_mod_gain_offset']
        self._context_mod_gain_softplus = kwargs['context_mod_gain_softplus']
        self._context_mod_apply_pixel_wise = \
            kwargs['context_mod_apply_pixel_wise']

        ### Check or parse remaining arguments ###
        self._in_shape = in_shape
        self._n = n
        self._k = k
        if use_fc_bias is None:
            use_fc_bias = use_bias
        # Also, checkout attribute `_has_bias` below.
        self._use_bias = use_bias
        self._use_fc_bias = use_fc_bias
        self._no_weights = no_weights
        assert not use_batch_norm or (not distill_bn_stats or bn_track_stats)
        self._use_batch_norm = use_batch_norm
        self._bn_track_stats = bn_track_stats
        self._distill_bn_stats = distill_bn_stats and use_batch_norm
        self._dropout_rate = dropout_rate
        self._chw_input_format = chw_input_format

        # The original authors found that the best configuration uses this
        # kernel in all convolutional layers.
        self._kernel_size = (3, 3)
        if len(num_feature_maps) != 4:
            raise ValueError('Option "num_feature_maps" must be a list of 4 ' +
                             'integers.')
        self._filter_sizes = list(num_feature_maps)
        if k != 1:
            for i in range(1, 4):
                self._filter_sizes[i] = k * num_feature_maps[i]
        # Strides used in the first layer of each convolutional group.
        self._strides = (1, 1, 2, 2)

        ### Set required class attributes ###
        # Note, we did overwrite the getter for attribute `has_bias`, as it is
        # not applicable if the values of `use_bias` and `use_fc_bias` differ.
        self._has_bias = use_bias if use_bias == use_fc_bias else False
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer!
        self._mask_fc_out = True
        self._has_linear_out = True

        self._param_shapes = []
        self._param_shapes_meta = []
        self._internal_params = None if no_weights and \
            self._context_mod_no_weights else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not self._context_mod_no_weights else []
        self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
            is None else []
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        if dropout_rate != -1:
            assert dropout_rate >= 0. and dropout_rate <= 1.
            self._dropout = nn.Dropout(p=dropout_rate)

        #################################
        ### Create context mod layers ###
        #################################
        self._context_mod_layers = nn.ModuleList() if self._use_context_mod \
            else None

        if self._use_context_mod:
            cm_layer_inds = []
            cm_shapes = [] # Output shape of all layers.
            if self._context_mod_inputs:
                cm_shapes.append([in_shape[2], *in_shape[:2]])
                # We reserve layer zero for input context-mod. Otherwise, there
                # is no layer zero.
                cm_layer_inds.append(0)

            layer_out_shapes = self._compute_layer_out_sizes()
            cm_shapes.extend(layer_out_shapes)
            # All layer indices `l` with `l mod 3 == 0` are context-mod layers.
            cm_layer_inds.extend(range(3, 3*len(layer_out_shapes)+1, 3))
            if self._no_last_layer_context_mod:
                cm_shapes = cm_shapes[:-1]
                cm_layer_inds = cm_layer_inds[:-1]
            if not self._context_mod_apply_pixel_wise:
                # Only scalar gain and shift per feature map!
                for i, s in enumerate(cm_shapes):
                    if len(s) == 3:
                        cm_shapes[i] = [s[0], 1, 1]

            self._add_context_mod_layers(cm_shapes, cm_layers=cm_layer_inds)

        ###############################
        ### Create batchnorm layers ###
        ###############################
        # We just use even numbers starting from 2 as layer indices for
        # batchnorm layers.
        if use_batch_norm:
            bn_sizes = []
            for i, s in enumerate(self._filter_sizes):
                if i == 0:
                    bn_sizes.append(s)
                else:
                    bn_sizes.extend([s] * (2*n))

            # All layer indices `l` with `l mod 3 == 2` are batchnorm layers.
            self._add_batchnorm_layers(bn_sizes, no_weights,
                bn_layers=list(range(2, 3*len(bn_sizes)+1, 3)),
                distill_bn_stats=distill_bn_stats,
                bn_track_stats=bn_track_stats)

        ######################################
        ### Create skip connection weights ###
        ######################################
        # We use 1x1 convolutional layers for residual blocks in case the
        # number of input and output feature maps disagrees. We also use 1x1
        # convolutions whenever a stride greater than 1 is applied. This is not
        # necessary in my opinion (as it adds extra weights that do not affect
        # the downsampling itself), but commonly done; for instance, in the
        # original PyTorch implementation.
        # Note, there may be maximally 3 1x1 layers added to the network.
        # Note, we use 1x1 conv layers without biases.
        skip_1x1_shapes = []
        self._group_has_1x1 = [False] * 3
        for i in range(1, 4):
            if self._filter_sizes[i-1] != self._filter_sizes[i] or \
                    self._strides[i] != 1:
                skip_1x1_shapes.append([self._filter_sizes[i],
                                        self._filter_sizes[i-1], 1, 1])
                self._group_has_1x1[i-1] = True

        for s in skip_1x1_shapes:
            if not no_weights:
                self._internal_params.append(nn.Parameter( \
                    torch.Tensor(*s), requires_grad=True))
                self._layer_weight_tensors.append(self._internal_params[-1])
                init_params(self._layer_weight_tensors[-1])
            else:
                self._hyper_shapes_learned.append(s)
                self._hyper_shapes_learned_ref.append(len(self.param_shapes))

            self._param_shapes.append(s)
            self._param_shapes_meta.append({
                'name': 'weight',
                'index': -1 if no_weights else \
                    len(self._internal_params)-1,
                'layer': -1
            })

        ############################################################
        ### Create convolutional layers and final linear weights ###
        ############################################################
        # Convolutional layers will get IDs `l` such that `l mod 3 == 1`.
        layer_id = 1
        for i in range(5):
            if i == 0: ### First layer.
                num = 1
                prev_fs = self._in_shape[2]
                curr_fs = self._filter_sizes[0]
            elif i == 4: ### Final fully-connected layer.
                num = 1
                curr_fs = num_classes
            else: # Group of residual blocks.
                num = 2 * n
                curr_fs = self._filter_sizes[i]

            for _ in range(num):
                if i == 4:
                    layer_shapes = [[curr_fs, prev_fs]]
                    if use_fc_bias:
                        layer_shapes.append([curr_fs])
                else:
                    layer_shapes = [[curr_fs, prev_fs, *self._kernel_size]]
                    if use_bias:
                        layer_shapes.append([curr_fs])

                for s in layer_shapes:
                    if not no_weights:
                        self._internal_params.append(nn.Parameter( \
                            torch.Tensor(*s), requires_grad=True))
                        if len(s) == 1:
                            self._layer_bias_vectors.append( \
                                self._internal_params[-1])
                        else:
                            self._layer_weight_tensors.append( \
                                self._internal_params[-1])
                    else:
                        self._hyper_shapes_learned.append(s)
                        self._hyper_shapes_learned_ref.append( \
                            len(self.param_shapes))

                    self._param_shapes.append(s)
                    self._param_shapes_meta.append({
                        'name': 'weight' if len(s) != 1 else 'bias',
                        'index': -1 if no_weights else \
                            len(self._internal_params)-1,
                        'layer': layer_id
                    })

                prev_fs = curr_fs
                layer_id += 3

                # Initialize_weights
                if not no_weights:
                    init_params(self._layer_weight_tensors[-1],
                        self._layer_bias_vectors[-1] \
                        if len(layer_shapes) == 2 else None)

        ###########################
        ### Print infos to user ###
        ###########################
        if verbose:
            # Optional note about how many weights belong to context-mod
            # layers. It is only computed if context-mod is used.
            cm_msg = ''
            if self._use_context_mod:
                cm_param_shapes = []
                for cm_layer in self.context_mod_layers:
                    cm_param_shapes.extend(cm_layer.param_shapes)
                cm_num_params = \
                    MainNetInterface.shapes_to_num_weights(cm_param_shapes)
                # Previously, the concatenation produced the typo
                # "associated with-context modulation".
                cm_msg = ' (including %d weights associated with ' \
                    'context modulation)' % cm_num_params

            print('Creating a WideResnet "%s" with %d weights'
                  % (str(self), self.num_params)
                  + cm_msg
                  + '.'
                  + (' The network uses batchnorm.' if use_batch_norm else '')
                  + (' The network uses dropout.' if dropout_rate != -1
                     else ''))

        self._is_properly_setup(check_has_bias=False)
示例#16
0
    def __init__(self, weight_shapes, activation_fn=torch.nn.ReLU(),
                 use_bias=True, no_weights=False, init_weights=None,
                 dropout_rate=-1, out_fn=None, verbose=True,
                 use_spectral_norm=False, use_batch_norm=False):
        """Set up the (deprecated) fully-connected main network.

        Emits a ``DeprecationWarning``; ``mnets.mlp.MLP`` should be used
        instead.

        Args:
            weight_shapes: Non-empty list of parameter shapes (lists of
                integers). If ``no_weights`` is False, one parameter is
                created per shape. Otherwise the shapes are only recorded as
                ``_hyper_shapes`` (weights are then expected to be passed to
                the forward method).
            activation_fn: Nonlinearity of the hidden layers (``None`` means
                no nonlinearity).
            use_bias: Whether layers have bias terms. If True, parameters are
                expected to alternate between weight tensor and bias vector.
            no_weights: If True, no trainable parameters are created, e.g.,
                because a hypernetwork produces them.
            init_weights (optional): List of tensors used to initialize the
                created parameters. Must match ``weight_shapes`` in number and
                shape, and cannot be combined with ``no_weights``.
            dropout_rate: ``-1`` disables dropout; otherwise a rate in
                ``[0, 1]`` for the hidden layers.
            out_fn (optional): Function applied to the network output.
            verbose: Whether to print the number of weights of the network.
            use_spectral_norm: Not implemented for this network; must be
                False.
            use_batch_norm: Not implemented for this network; must be False.

        Raises:
            NotImplementedError: If ``use_spectral_norm`` or
                ``use_batch_norm`` is set.
        """
        # FIXME find a way using super to handle multiple inheritence.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        warn('Please use class "mnets.mlp.MLP" instead.', DeprecationWarning)

        # Neither option is supported by this legacy implementation.
        if use_spectral_norm:
            raise NotImplementedError('Spectral normalization not yet ' +
                                      'implemented for this network.')
        if use_batch_norm:
            raise NotImplementedError('Batch normalization not yet ' +
                                      'implemented for this network.')

        assert len(weight_shapes) > 0
        self._all_shapes = weight_shapes
        self._has_bias = use_bias
        self._a_fun = activation_fn
        assert init_weights is None or no_weights is False
        self._no_weights = no_weights
        self._dropout_rate = dropout_rate
        self._out_fn = out_fn
        self._has_fc_out = True

        # Kept for a future spectral-norm implementation. Currently this is
        # unreachable because of the `NotImplementedError` raised above.
        if use_spectral_norm and no_weights:
            raise ValueError('Cannot use spectral norm in a network without ' +
                             'parameters.')
        self._spec_norm = nn.utils.spectral_norm if use_spectral_norm \
            else (lambda x: x)  # Identity.

        if verbose:
            total = MnetAPIV2.shapes_to_num_weights(self._all_shapes)
            ending = ', that uses dropout.' if dropout_rate != -1 else '.'
            print('Creating an MLP with %d weights' % total + ending)

        if dropout_rate != -1:
            assert 0. <= dropout_rate <= 1.
            self._dropout = nn.Dropout(p=dropout_rate)

        self._weights = None
        if no_weights:
            # Weights come from outside; only remember their shapes.
            self._hyper_shapes = self._all_shapes
            self._is_properly_setup()
            return

        # One parameter per shape. When biases are used, even indices hold
        # weight tensors and odd indices the matching bias vectors.
        self._weights = nn.ParameterList(
            [nn.Parameter(torch.Tensor(*shape), requires_grad=True)
             for shape in self._all_shapes])

        if init_weights is not None:
            assert len(init_weights) == len(self._all_shapes)
            for given, param in zip(init_weights, self._weights):
                assert np.all(np.equal(list(given.shape), list(param.shape)))
                param.data = given
        elif use_bias:
            # Initialize each (weight, bias) pair jointly.
            for idx in range(0, len(self._weights), 2):
                init_params(self._weights[idx], self._weights[idx + 1])
        else:
            for param in self._weights:
                init_params(param)

        self._is_properly_setup()
    def __init__(self,
                 n_in,
                 n_out=1,
                 inp_chunk_dim=100,
                 out_chunk_dim=10,
                 cemb_size=8,
                 cemb_init_std=1.,
                 red_layers=(10, 10),
                 net_layers=(10, 10),
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True,
                 dynamic_biases=None,
                 no_weights=False,
                 init_weights=None,
                 dropout_rate=-1,
                 use_spectral_norm=False,
                 use_batch_norm=False,
                 bn_track_stats=True,
                 distill_bn_stats=False,
                 verbose=True):
        """Construct an MLP that processes its input in chunks.

        ``ceil(n_in / inp_chunk_dim)`` chunk embeddings of size ``cemb_size``
        are created. A "reducer" MLP maps ``inp_chunk_dim`` inputs (plus
        ``cemb_size`` inputs when ``dynamic_biases`` is None) to
        ``out_chunk_dim`` outputs. A second MLP maps
        ``out_chunk_dim * num_chunks`` inputs to ``n_out`` outputs. Judging
        by these sizes, the reducer presumably processes each (embedded)
        chunk and the second MLP the concatenated reducer outputs
        (NOTE(review): the forward pass is not visible here).

        Args:
            n_in: Input dimensionality. If it is not divisible by
                ``inp_chunk_dim``, ``_pad`` stores the number of values
                missing in the last chunk (otherwise ``_pad`` is -1).
            n_out: Number of outputs of the second MLP.
            inp_chunk_dim: Size of an input chunk.
            out_chunk_dim: Output size of the reducer per chunk.
            cemb_size: Size of each chunk embedding.
            cemb_init_std: Std. of the normal initialization of the chunk
                embeddings.
            red_layers: Hidden layer sizes of the reducer MLP.
            net_layers: Hidden layer sizes of the second MLP.
            activation_fn: Forwarded to both MLPs.
            use_bias: Forwarded to both MLPs.
            dynamic_biases: Indices of reducer layers (in
                ``[0, len(red_layers)]``) that should get a dynamic bias. Not
                yet implemented; a non-None value raises
                ``NotImplementedError``.
            no_weights: If True, no weights are created internally; their
                shapes are collected in ``_hyper_shapes_learned``.
            init_weights (optional): List of tensors used to overwrite the
                internal weights, in the order of ``_weights``. Requires
                ``no_weights=False``.
            dropout_rate: Forwarded to both MLPs.
            use_spectral_norm: Forwarded to both MLPs.
            use_batch_norm: Forwarded to both MLPs; their batchnorm layers are
                collected in ``_batchnorm_layers``.
            bn_track_stats: Forwarded to both MLPs.
            distill_bn_stats: Forwarded to both MLPs.
            verbose: Whether the sub-networks and this network print
                information about their construction.

        Raises:
            NotImplementedError: If ``dynamic_biases`` is not None.
        """
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        self._n_in = n_in
        self._n_out = n_out
        self._inp_chunk_dim = inp_chunk_dim
        self._out_chunk_dim = out_chunk_dim
        self._cemb_size = cemb_size
        self._a_fun = activation_fn
        self._no_weights = no_weights

        self._has_bias = use_bias
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        self._has_linear_out = True  # Ensure that `out_fn` is `None`!

        self._param_shapes = []
        #self._param_shapes_meta = [] # TODO implement!
        self._weights = None if no_weights else nn.ParameterList()
        self._hyper_shapes_learned = None if not no_weights else []
        #self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
        #    is None else [] # TODO implement.

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        self._context_mod_layers = None
        self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

        #################################
        ### Generate Chunk Embeddings ###
        #################################
        self._num_cembs = int(np.ceil(n_in / inp_chunk_dim))
        last_chunk_size = n_in % inp_chunk_dim
        if last_chunk_size != 0:
            self._pad = inp_chunk_dim - last_chunk_size
        else:
            self._pad = -1

        cemb_shape = [self._num_cembs, cemb_size]
        self._param_shapes.append(cemb_shape)
        if no_weights:
            self._cembs = None
            self._hyper_shapes_learned.append(cemb_shape)
        else:
            self._cembs = nn.Parameter(data=torch.Tensor(*cemb_shape),
                                       requires_grad=True)
            nn.init.normal_(self._cembs, mean=0., std=cemb_init_std)

            self._weights.append(self._cembs)

        ############################
        ### Setup Dynamic Biases ###
        ############################
        self._has_dyn_bias = None
        if dynamic_biases is not None:
            assert np.all(np.array(dynamic_biases) >= 0) and \
                   np.all(np.array(dynamic_biases) < len(red_layers) + 1)
            dynamic_biases = np.sort(np.unique(dynamic_biases))

            # For each layer in the `reducer`, where we want to apply a dynamic
            # bias, we have to create a weight matrix for a corresponding
            # linear layer (we just ignore)
            # Note, entries are `nn.Parameter`s or `None`, hence a
            # `ParameterList` (a `ModuleList` rejects parameters).
            self._dyn_bias_weights = nn.ParameterList()
            self._has_dyn_bias = []

            for i in range(len(red_layers) + 1):
                if i in dynamic_biases:
                    self._has_dyn_bias.append(True)

                    trgt_dim = out_chunk_dim
                    if i < len(red_layers):
                        trgt_dim = red_layers[i]
                    trgt_shape = [trgt_dim, cemb_size]

                    self._param_shapes.append(trgt_shape)
                    # Previously this condition was inverted, which accessed
                    # `_hyper_shapes_learned`/`_weights` while being `None`.
                    if no_weights:
                        self._dyn_bias_weights.append(None)
                        self._hyper_shapes_learned.append(trgt_shape)
                    else:
                        self._dyn_bias_weights.append(nn.Parameter( \
                            torch.Tensor(*trgt_shape), requires_grad=True))
                        self._weights.append(self._dyn_bias_weights[-1])

                        init_params(self._dyn_bias_weights[-1])

                        self._layer_weight_tensors.append( \
                            self._dyn_bias_weights[-1])
                        self._layer_bias_vectors.append(None)
                else:
                    self._has_dyn_bias.append(False)
                    self._dyn_bias_weights.append(None)

        ################################
        ### Create `Reducer` Network ###
        ################################
        red_inp_dim = inp_chunk_dim + \
            (cemb_size if dynamic_biases is None else 0)
        self._reducer = MLP(
            n_in=red_inp_dim,
            n_out=out_chunk_dim,
            hidden_layers=red_layers,
            activation_fn=activation_fn,
            use_bias=use_bias,
            no_weights=no_weights,
            init_weights=None,
            dropout_rate=dropout_rate,
            use_spectral_norm=use_spectral_norm,
            use_batch_norm=use_batch_norm,
            bn_track_stats=bn_track_stats,
            distill_bn_stats=distill_bn_stats,
            # We use context modulation to realize dynamic biases, since they
            # allow a different modulation per sample in the input mini-batch.
            # Hence, we can process several chunks in parallel with the reducer
            # network.
            use_context_mod=dynamic_biases is not None,
            context_mod_inputs=False,
            no_last_layer_context_mod=False,
            context_mod_no_weights=True,
            context_mod_post_activation=False,
            context_mod_gain_offset=False,
            context_mod_gain_softplus=False,
            out_fn=None,
            verbose=verbose)

        if dynamic_biases is not None:
            # FIXME We have to extract the param shapes from
            # `self._reducer.param_shapes`, as well as from
            # `self._reducer._hyper_shapes_learned` that belong to context-mod
            # layers. We may not add them to our own `param_shapes` attribute,
            # as these are not parameters (due to our misuse of the context-mod
            # layers).
            # Note, in the `forward` method, we need to supply context-mod
            # weights for all reducer networks, independent on whether they have
            # a dynamic bias or not. We can do so, by providing constant ones
            # for all gains and constance zero-shift for all layers without
            # dynamic biases (note, we need to ensure the correct batch dim!).
            raise NotImplementedError(
                'Dynamic biases are not yet implemented!')

        assert self._reducer._context_mod_layers is None

        ### Overtake all attributes from the underlying MLP.
        for s in self._reducer.param_shapes:
            self._param_shapes.append(s)
        if no_weights:
            for s in self._reducer._hyper_shapes_learned:
                self._hyper_shapes_learned.append(s)
        else:
            for p in self._reducer._weights:
                self._weights.append(p)

        for p in self._reducer._layer_weight_tensors:
            self._layer_weight_tensors.append(p)
        for p in self._reducer._layer_bias_vectors:
            self._layer_bias_vectors.append(p)

        if use_batch_norm:
            for p in self._reducer._batchnorm_layers:
                self._batchnorm_layers.append(p)

        if self._reducer._hyper_shapes_distilled is not None:
            self._hyper_shapes_distilled = []
            for s in self._reducer._hyper_shapes_distilled:
                self._hyper_shapes_distilled.append(s)

        ###############################
        ### Create Actual `Network` ###
        ###############################
        net_inp_dim = out_chunk_dim * self._num_cembs
        self._network = MLP(n_in=net_inp_dim,
                            n_out=n_out,
                            hidden_layers=net_layers,
                            activation_fn=activation_fn,
                            use_bias=use_bias,
                            no_weights=no_weights,
                            init_weights=None,
                            dropout_rate=dropout_rate,
                            use_spectral_norm=use_spectral_norm,
                            use_batch_norm=use_batch_norm,
                            bn_track_stats=bn_track_stats,
                            distill_bn_stats=distill_bn_stats,
                            use_context_mod=False,
                            out_fn=None,
                            verbose=verbose)

        ### Overtake all attributes from the underlying MLP.
        for s in self._network.param_shapes:
            self._param_shapes.append(s)
        if no_weights:
            for s in self._network._hyper_shapes_learned:
                self._hyper_shapes_learned.append(s)
        else:
            for p in self._network._weights:
                self._weights.append(p)

        for p in self._network._layer_weight_tensors:
            self._layer_weight_tensors.append(p)
        for p in self._network._layer_bias_vectors:
            self._layer_bias_vectors.append(p)

        if use_batch_norm:
            for p in self._network._batchnorm_layers:
                self._batchnorm_layers.append(p)

        if self._hyper_shapes_distilled is not None:
            assert self._network._hyper_shapes_distilled is not None
            for s in self._network._hyper_shapes_distilled:
                self._hyper_shapes_distilled.append(s)

        #####################################
        ### Takeover given Initialization ###
        #####################################
        if init_weights is not None:
            # Without internal weights (`_weights is None`) there is nothing
            # to initialize.
            assert not no_weights, \
                'Option "init_weights" requires "no_weights=False".'
            assert len(init_weights) == len(self._weights)
            for i, w_init in enumerate(init_weights):
                assert np.all(
                    np.equal(list(w_init.shape), self._param_shapes[i]))
                self._weights[i].data = w_init

        ######################
        ### Finalize Setup ###
        ######################
        if verbose:
            num_weights = \
                MainNetInterface.shapes_to_num_weights(self.param_shapes)
            print('Constructed MLP that processes dimensionality reduced ' +
                  'inputs through chunking. The network has a total of %d ' %
                  num_weights + 'weights.')

        self._is_properly_setup()