예제 #1
0
    def __init__(self, num_classes, verbose):
        """Set up the state shared by this network.

        Args:
            num_classes: Number of output neurons. Must be strictly positive.
            verbose: Whether general information about the generated network
                (such as the number of weights) may be printed.
        """
        # Both base classes are initialised explicitly rather than through
        # `super()`.
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        assert num_classes > 0

        self._verbose = verbose
        self._num_classes = num_classes
예제 #2
0
    def __init__(self,
                 mnet,
                 no_mean_reinit=False,
                 logvar_encoding=False,
                 apply_rho_offset=False,
                 is_radial=False):
        """Wrap a main network such that each weight gets a mean and a rho.

        The internal parameters of ``mnet`` are reused as mean parameters and
        one additional ``rho`` parameter of identical size is created per
        mean parameter. The interface attributes of ``mnet`` (shapes, meta
        information, hyper shapes) are duplicated accordingly: the first half
        always refers to the means, the second half to the rhos.

        Args:
            mnet: The network to wrap. Must be an instance of
                ``MainNetInterface`` and must not already be a
                ``GaussianBNNWrapper``.
            no_mean_reinit: If ``False``, the mean parameters (i.e., the
                internal parameters of ``mnet``) are re-initialised uniformly
                in ``[-0.1, 0.1]``. If ``True``, they are left untouched.
            logvar_encoding: Stored in ``_logvar_encoding``; not otherwise
                used by this constructor.
            apply_rho_offset: If ``True``, the rho parameters are initialised
                in a range shifted by ``-self._rho_offset`` (the inline
                comment below states the offset is applied again in the
                forward method).
            is_radial: Stored in ``_is_radial``; selects which info message
                is printed.
        """
        # FIXME find a way using super to handle multiple inheritance.
        #super(GaussianBNNWrapper, self).__init__()
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        assert isinstance(mnet, MainNetInterface)
        assert not isinstance(mnet, GaussianBNNWrapper)

        if is_radial:
            print('Converting network into BNN with radial weight ' +
                  'distribution ...')
        else:
            print(
                'Converting network into BNN with diagonal Gaussian weight ' +
                'distribution ...')

        self._mnet = mnet
        self._logvar_encoding = logvar_encoding
        self._apply_rho_offset = apply_rho_offset
        self._rho_offset = -2.5
        self._is_radial = is_radial

        # Take over attributes of `mnet` and modify them if necessary.
        self._mean_params = None
        self._rho_params = None
        if mnet.internal_params is not None:
            self._mean_params = mnet.internal_params
            self._rho_params = nn.ParameterList()

            # One (still uninitialized) rho tensor per mean tensor.
            for p in self._mean_params:
                self._rho_params.append(
                    nn.Parameter(torch.Tensor(p.size()), requires_grad=True))

            # Initialize weights.
            if not no_mean_reinit:
                for p in self._mean_params:
                    p.data.uniform_(-0.1, 0.1)

            for p in self._rho_params:
                if apply_rho_offset:
                    # We will subtract 2.5 from `rho` in the forward method.
                    #p.data.uniform_(-0.5, 0.5)
                    p.data.uniform_(-3 - self._rho_offset,
                                    -2 - self._rho_offset)
                else:
                    p.data.uniform_(-3, -2)

            # Internal parameters: all means first, followed by all rhos.
            self._internal_params = nn.ParameterList()
            for p in self._mean_params:
                self._internal_params.append(p)
            for p in self._rho_params:
                self._internal_params.append(p)

        # Simply duplicate `param_shapes` and `hyper_shapes_learned`.
        self._param_shapes = mnet.param_shapes + mnet.param_shapes
        if mnet._param_shapes_meta is not None:
            self._param_shapes_meta = []
            old_wlen = 0 if self.internal_params is None \
                else len(mnet.internal_params)

            # Work on copies, so that the meta information of the wrapped
            # network is neither modified nor aliased by this wrapper.
            for dd_old in mnet._param_shapes_meta:
                dd = dict(dd_old)
                dd['dist_param'] = 'mean'
                self._param_shapes_meta.append(dd)

            for dd_old in mnet._param_shapes_meta:
                dd = dict(dd_old)
                # An index of -1 is assumed to mark a shape without an
                # internal parameter (TODO confirm against
                # `MainNetInterface`); such a marker must not be shifted.
                if dd['index'] != -1:
                    dd['index'] += old_wlen
                dd['dist_param'] = 'rho'
                self._param_shapes_meta.append(dd)

        if mnet._hyper_shapes_learned is not None:
            self._hyper_shapes_learned = mnet._hyper_shapes_learned + \
                mnet._hyper_shapes_learned
        if mnet._hyper_shapes_learned_ref is not None:
            self._hyper_shapes_learned_ref = \
                list(mnet._hyper_shapes_learned_ref)
            # References of the rho half point into the second copy of
            # `param_shapes`.
            old_plen = len(mnet.param_shapes)
            for ii in mnet._hyper_shapes_learned_ref:
                self._hyper_shapes_learned_ref.append(ii + old_plen)

        self._hyper_shapes_distilled = mnet._hyper_shapes_distilled
        if self._hyper_shapes_distilled is not None:
            # In general, that shouldn't be an issue, as those distilled values
            # are just things like batchnorm stats. But it might be good to
            # inform the user about the fact that we are not considering this
            # attribute as special.
            warn('Class "GaussianBNNWrapper" doesn\'t modify the existing ' +
                 'attribute "hyper_shapes_distilled".')

        self._has_bias = mnet._has_bias
        self._has_fc_out = mnet._has_fc_out
        # Note, it's still true that the last two entries of
        # `hyper_shapes_learned` are belonging to the output layer. But those
        # are only the variance weights. So, we would forget about the mean
        # weights when setting this quantity to true.
        self._mask_fc_out = False  #mnet._mask_fc_out
        self._has_linear_out = mnet._has_linear_out

        # We don't modify the following attributes, but generate warnings
        # when using them.
        self._layer_weight_tensors = mnet._layer_weight_tensors
        self._layer_bias_vectors = mnet._layer_bias_vectors
        self._batchnorm_layers = mnet._batchnorm_layers
        self._context_mod_layers = mnet._context_mod_layers

        self._is_properly_setup(check_has_bias=False)
예제 #3
0
파일: mlp.py 프로젝트: ZixuanKe/CAT
    def __init__(self,
                 n_in=1,
                 n_out=1,
                 hidden_layers=[2000, 2000],
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True,
                 no_weights=False,
                 init_weights=None,
                 dropout_rate=-1,
                 use_spectral_norm=False,
                 use_batch_norm=False,
                 bn_track_stats=True,
                 distill_bn_stats=False,
                 use_context_mod=False,
                 context_mod_inputs=False,
                 no_last_layer_context_mod=False,
                 context_mod_no_weights=False,
                 context_mod_post_activation=False,
                 context_mod_gain_offset=False,
                 out_fn=None,
                 verbose=True):
        """Build a fully-connected network and register its parameters.

        Parameter shapes are collected in the order context-mod layers,
        batchnorm layers, linear layers. ``init_weights`` (if given) has to
        follow the same order.

        Args:
            n_in: Input size; passed to ``MLP.weight_shapes`` and used as size
                of the input context-mod layer if ``context_mod_inputs``.
            n_out: Output size; passed to ``MLP.weight_shapes`` and used as
                size of the last context-mod layer unless
                ``no_last_layer_context_mod``.
            hidden_layers: Sizes of the hidden layers. This constructor only
                reads the sequence.
            activation_fn: Stored in ``_a_fun``.
            use_bias: Whether linear layers have bias vectors.
            no_weights: If ``True``, no batchnorm/linear parameters are
                created; their shapes are appended to
                ``_hyper_shapes_learned`` instead.
            init_weights: Optional list of tensors used to initialise the
                internally maintained parameters (shapes are asserted).
            dropout_rate: ``-1`` disables dropout; otherwise a value in
                ``[0, 1]`` for which an ``nn.Dropout`` module is created.
            use_spectral_norm: Not supported; ``True`` raises
                ``NotImplementedError``.
            use_batch_norm: Create one ``BatchNormLayer`` per hidden layer.
            bn_track_stats: Passed as ``track_running_stats`` to the
                batchnorm layers.
            distill_bn_stats: If set (and batchnorm is used), the shapes of
                ``bn_layer.get_stats(0)`` are collected in
                ``_hyper_shapes_distilled``.
            use_context_mod: Create ``ContextModLayer`` instances.
            context_mod_inputs: Also add a context-mod layer of size ``n_in``.
            no_last_layer_context_mod: Skip the context-mod layer of size
                ``n_out``.
            context_mod_no_weights: If ``True``, context-mod layers are created
                without weights and their shapes are appended to
                ``_hyper_shapes_learned``.
            context_mod_post_activation: Stored in
                ``_context_mod_post_activation``.
            context_mod_gain_offset: Passed as ``apply_gain_offset`` to the
                context-mod layers.
            out_fn: Stored in ``_out_fn``; ``_has_linear_out`` is ``True`` only
                if this is ``None``.
            verbose: Whether to print information about the created network.

        Raises:
            NotImplementedError: If ``use_spectral_norm`` is ``True``.
        """
        # FIXME find a way using super to handle multiple inheritance.
        #super(MainNetwork, self).__init__()
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        if use_spectral_norm:
            raise NotImplementedError('Spectral normalization not yet ' +
                                      'implemented for this network.')

        if use_batch_norm and use_context_mod:
            # FIXME Does it make sense to have both enabled?
            # I.e., should we produce a warning or error?
            pass

        self._a_fun = activation_fn
        # Initial weights only make sense if at least some weights are
        # maintained internally.
        assert(init_weights is None or \
               (not no_weights or not context_mod_no_weights))
        self._no_weights = no_weights
        self._dropout_rate = dropout_rate
        #self._use_spectral_norm = use_spectral_norm
        self._use_batch_norm = use_batch_norm
        self._bn_track_stats = bn_track_stats
        self._distill_bn_stats = distill_bn_stats and use_batch_norm
        self._use_context_mod = use_context_mod
        self._context_mod_inputs = context_mod_inputs
        self._no_last_layer_context_mod = no_last_layer_context_mod
        self._context_mod_no_weights = context_mod_no_weights
        self._context_mod_post_activation = context_mod_post_activation
        self._context_mod_gain_offset = context_mod_gain_offset
        self._out_fn = out_fn

        self._has_bias = use_bias
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        self._has_linear_out = out_fn is None

        if use_spectral_norm and no_weights:
            raise ValueError('Cannot use spectral norm in a network without ' +
                             'parameters.')

        # FIXME make sure that this implementation is correct in all situations
        # (e.g., what to do if weights are passed to the forward method?).
        if use_spectral_norm:
            self._spec_norm = nn.utils.spectral_norm
        else:
            self._spec_norm = lambda x: x  # identity

        self._param_shapes = []
        self._weights = None if no_weights and context_mod_no_weights \
            else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not context_mod_no_weights else []

        if dropout_rate != -1:
            assert (dropout_rate >= 0. and dropout_rate <= 1.)
            self._dropout = nn.Dropout(p=dropout_rate)

        ### Define and initialize context mod weights.
        self._context_mod_layers = nn.ModuleList() if use_context_mod else None
        self._context_mod_shapes = [] if use_context_mod else None

        if use_context_mod:
            cm_ind = 0  # Number of entries of `init_weights` consumed so far.
            cm_sizes = []
            if context_mod_inputs:
                cm_sizes.append(n_in)
            cm_sizes.extend(hidden_layers)
            if not no_last_layer_context_mod:
                cm_sizes.append(n_out)

            for n in cm_sizes:
                cmod_layer = ContextModLayer(
                    n,
                    no_weights=context_mod_no_weights,
                    apply_gain_offset=context_mod_gain_offset)
                self._context_mod_layers.append(cmod_layer)

                self.param_shapes.extend(cmod_layer.param_shapes)
                self._context_mod_shapes.extend(cmod_layer.param_shapes)
                if context_mod_no_weights:
                    self._hyper_shapes_learned.extend(cmod_layer.param_shapes)
                else:
                    self._weights.extend(cmod_layer.weights)

                # FIXME ugly code. Move initialization somewhere else.
                if not context_mod_no_weights and init_weights is not None:
                    assert (len(cmod_layer.weights) == 2)
                    for ii in range(2):
                        # Compare against the layer's weight tensor (not the
                        # integer counter `cm_ind`).
                        assert(np.all(np.equal( \
                                list(init_weights[cm_ind].shape),
                                list(cmod_layer.weights[ii].shape))))
                        cmod_layer.weights[ii].data = init_weights[cm_ind]
                        cm_ind += 1

            # Remaining entries belong to batchnorm and linear layers.
            if init_weights is not None:
                init_weights = init_weights[cm_ind:]

        # Debug output; only emitted when the user asked for verbosity.
        if verbose:
            print('hidden_layers', hidden_layers)

        ### Define and initialize batch norm weights.
        self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

        if use_batch_norm:
            if distill_bn_stats:
                self._hyper_shapes_distilled = []

            bn_ind = 0  # Number of entries of `init_weights` consumed so far.
            for n in hidden_layers:
                bn_layer = BatchNormLayer(n,
                                          affine=not no_weights,
                                          track_running_stats=bn_track_stats)
                self._batchnorm_layers.append(bn_layer)
                self._param_shapes.extend(bn_layer.param_shapes)

                if no_weights:
                    self._hyper_shapes_learned.extend(bn_layer.param_shapes)
                else:
                    self._weights.extend(bn_layer.weights)

                if distill_bn_stats:
                    self._hyper_shapes_distilled.extend( \
                        [list(p.shape) for p in bn_layer.get_stats(0)])

                # FIXME ugly code. Move initialization somewhere else.
                if not no_weights and init_weights is not None:
                    assert (len(bn_layer.weights) == 2)
                    for ii in range(2):
                        assert(np.all(np.equal( \
                                list(init_weights[bn_ind].shape),
                                list(bn_layer.weights[ii].shape))))
                        bn_layer.weights[ii].data = init_weights[bn_ind]
                        bn_ind += 1

            # Remaining entries belong to the linear layers.
            if init_weights is not None:
                init_weights = init_weights[bn_ind:]

        # Compute shapes of linear layers.
        linear_shapes = MLP.weight_shapes(n_in=n_in,
                                          n_out=n_out,
                                          hidden_layers=hidden_layers,
                                          use_bias=use_bias)
        self._param_shapes.extend(linear_shapes)

        num_weights = MainNetInterface.shapes_to_num_weights(
            self._param_shapes)

        if verbose:
            if use_context_mod:
                cm_num_weights = 0
                for cm_layer in self._context_mod_layers:
                    cm_num_weights += MainNetInterface.shapes_to_num_weights( \
                        cm_layer.param_shapes)

            print(
                'Creating an MLP with %d weights' % num_weights +
                (' (including %d weights associated with-' % cm_num_weights +
                 'context modulation)' if use_context_mod else '') + '.' +
                (' The network uses dropout.' if dropout_rate != -1 else '') +
                (' The network uses batchnorm.' if use_batch_norm else ''))

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        if no_weights:
            # Linear weights are not maintained internally.
            self._hyper_shapes_learned.extend(linear_shapes)
            self._is_properly_setup()
            return

        ### Define and initialize linear weights.
        for dims in linear_shapes:
            self._weights.append(
                nn.Parameter(torch.Tensor(*dims), requires_grad=True))
            # One-dimensional shapes are bias vectors.
            if len(dims) == 1:
                self._layer_bias_vectors.append(self._weights[-1])
            else:
                self._layer_weight_tensors.append(self._weights[-1])

        if init_weights is not None:
            assert (len(init_weights) == len(linear_shapes))
            for i, init_w in enumerate(init_weights):
                assert (np.all(
                    np.equal(list(init_w.shape), linear_shapes[i])))
                if use_bias:
                    # With biases, weights and biases alternate.
                    if i % 2 == 0:
                        self._layer_weight_tensors[i // 2].data = init_w
                    else:
                        self._layer_bias_vectors[i // 2].data = init_w
                else:
                    self._layer_weight_tensors[i].data = init_w
        else:
            for i in range(len(self._layer_weight_tensors)):
                if use_bias:
                    init_params(self._layer_weight_tensors[i],
                                self._layer_bias_vectors[i])
                else:
                    init_params(self._layer_weight_tensors[i])

        self._is_properly_setup()
예제 #4
0
    def __init__(self,
                 rnn_args={},
                 mlp_args=None,
                 preprocess_fct=None,
                 no_weights=False,
                 verbose=True):
        """Build a bidirectional RNN from forward/backward ``SimpleRNN`` nets.

        For every entry of ``rnn_args`` one forward and one backward
        ``SimpleRNN`` are created from the same keyword arguments. Optionally,
        an ``MLP`` output network is added. The interface attributes of all
        internal networks are then merged into this object.

        Args:
            rnn_args: A dict of keyword arguments for ``SimpleRNN`` or a
                list/tuple of such dicts (one per bidirectional layer). The
                passed dicts are copied and never modified.
            mlp_args: Optional dict of keyword arguments for the output
                ``MLP``. The passed dict is copied and never modified.
            preprocess_fct: Stored in ``_preprocess_fct``.
            no_weights: Value of the ``no_weights`` keyword that is passed to
                all internal networks (unless they explicitly specify the same
                value themselves).
            verbose: Whether to print the number of weights after
                construction. Internal networks default to ``verbose=False``.

        Raises:
            ValueError: If an argument dict specifies a ``no_weights`` value
                that differs from ``no_weights``, or if the internal networks
                disagree on ``_context_mod_no_weights``.
        """
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        assert isinstance(rnn_args, (dict, list, tuple))
        assert mlp_args is None or isinstance(mlp_args, dict)

        if isinstance(rnn_args, dict):
            rnn_args = [rnn_args]

        self._forward_rnns = []
        self._backward_rnns = []
        self._out_mlp = None
        self._preprocess_fct = preprocess_fct
        self._forward_called = False

        # FIXME At the moment we do not control input and output size of
        # individual networks and need to assume that the user sets them
        # correctly.

        ### Create all forward and backward nets for each bidirectional layer.
        for rargs in rnn_args:
            assert isinstance(rargs, dict)
            # Work on a copy: otherwise the defaults filled in below would
            # leak into the caller's dict and into the shared default
            # argument `rnn_args={}`, which would make a later call with a
            # different `no_weights` fail.
            rargs = dict(rargs)
            if 'verbose' not in rargs:
                rargs['verbose'] = False
            if 'no_weights' in rargs and rargs['no_weights'] != no_weights:
                raise ValueError('Keyword argument "no_weights" of ' +
                                 'bidirectional layer is in conflict with ' +
                                 'constructor argument "no_weights".')
            elif 'no_weights' not in rargs:
                rargs['no_weights'] = no_weights

            self._forward_rnns.append(SimpleRNN(**rargs))
            self._backward_rnns.append(SimpleRNN(**rargs))

        ### Create output network.
        if mlp_args is not None:
            # Copy for the same reason as above.
            mlp_args = dict(mlp_args)
            if 'verbose' not in mlp_args:
                mlp_args['verbose'] = False
            if 'no_weights' in mlp_args and \
                    mlp_args['no_weights'] != no_weights:
                raise ValueError('Keyword argument "no_weights" of ' +
                                 'output MLP is in conflict with ' +
                                 'constructor argument "no_weights".')
            elif 'no_weights' not in mlp_args:
                mlp_args['no_weights'] = no_weights

            self._out_mlp = MLP(**mlp_args)

        ### Set all interface attributes correctly.
        if self._out_mlp is None:
            self._has_fc_out = self._forward_rnns[-1].has_fc_out
            # We can't set the following attribute to true, as the output is
            # a concatenation of the outputs from two networks. Therefore, the
            # weights used to compute the outputs are at different locations
            # in the `param_shapes` list.
            self._mask_fc_out = False
            self._has_linear_out = self._forward_rnns[-1].has_linear_out
        else:
            self._has_fc_out = self._out_mlp.has_fc_out
            self._mask_fc_out = self._out_mlp.mask_fc_out
            self._has_linear_out = self._out_mlp.has_linear_out

        # Collect all internal net objects from which we need to collect
        # attributes, as tuples (network, type, layer id).
        nets = []
        for i, fnet in enumerate(self._forward_rnns):
            bnet = self._backward_rnns[i]

            nets.append((fnet, 'forward_rnn', i))
            nets.append((bnet, 'backward_rnn', i))
        if self._out_mlp is not None:
            nets.append((self._out_mlp, 'out_mlp', -1))

        # Iterate over all nets to collect their attribute values.
        self._param_shapes = []
        self._param_shapes_meta = []
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        for i, net_tup in enumerate(nets):
            net, net_type, net_id = net_tup
            # Note, it is important to convert lists into new object and not
            # just copy references!
            # Note, we have to adapt all references if `i > 0`.

            # Sanity check:
            if i == 0:
                cm_nw = net._context_mod_no_weights
            elif cm_nw != net._context_mod_no_weights:
                raise ValueError('Network expect that either all internal ' +
                                 'networks maintain their context-mod ' +
                                 'weights or non of them does!')

            ps_len_old = len(self._param_shapes)
            # Number of internal parameters collected before this network.
            # Computed in every iteration, such that it is always defined and
            # never stale, even if `net` has no internal parameters.
            ip_len_old = 0 if self._internal_params is None \
                else len(self._internal_params)

            if net._internal_params is not None:
                if self._internal_params is None:
                    self._internal_params = nn.ParameterList()
                self._internal_params.extend( \
                    nn.ParameterList(net._internal_params))
            self._param_shapes.extend(list(net._param_shapes))
            for meta in net.param_shapes_meta:
                assert 'birnn_layer_type' not in meta.keys()
                assert 'birnn_layer_id' not in meta.keys()

                new_meta = dict(meta)
                new_meta['birnn_layer_type'] = net_type
                new_meta['birnn_layer_id'] = net_id
                if i > 0:
                    # FIXME We should properly adjust colliding `layer` IDs.
                    new_meta['layer'] = -1
                # An index of -1 is assumed to mark a shape without an
                # internal parameter (TODO confirm against
                # `MainNetInterface`); such a marker must not be shifted.
                if meta['index'] != -1:
                    new_meta['index'] = meta['index'] + ip_len_old
                self._param_shapes_meta.append(new_meta)

            if net._hyper_shapes_learned is not None:
                if self._hyper_shapes_learned is None:
                    self._hyper_shapes_learned = []
                    self._hyper_shapes_learned_ref = []
                self._hyper_shapes_learned.extend( \
                    list(net._hyper_shapes_learned))
                # References are relative to the net's own `param_shapes`.
                for ref in net._hyper_shapes_learned_ref:
                    self._hyper_shapes_learned_ref.append(ref + ps_len_old)
            if net._hyper_shapes_distilled is not None:
                if self._hyper_shapes_distilled is None:
                    self._hyper_shapes_distilled = []
                self._hyper_shapes_distilled.extend( \
                    list(net._hyper_shapes_distilled))

            if self._has_bias is None:
                self._has_bias = net._has_bias
            elif self._has_bias != net._has_bias:
                self._has_bias = False
                # FIXME We should overwrite the getter and throw an error!
                warn('Some internally maintained networks use biases, ' +
                     'while others don\'t. Setting attribute "has_bias" to ' +
                     'False.')

            self._layer_weight_tensors.extend( \
                nn.ParameterList(net._layer_weight_tensors))
            self._layer_bias_vectors.extend( \
                nn.ParameterList(net._layer_bias_vectors))
            if net._batchnorm_layers is not None:
                if self._batchnorm_layers is None:
                    self._batchnorm_layers = nn.ModuleList()
                self._batchnorm_layers.extend( \
                    nn.ModuleList(net._batchnorm_layers))
            if net._context_mod_layers is not None:
                if self._context_mod_layers is None:
                    self._context_mod_layers = nn.ModuleList()
                self._context_mod_layers.extend( \
                    nn.ModuleList(net._context_mod_layers))

        self._is_properly_setup()

        ### Print user information.
        if verbose:
            print('Constructed Bidirectional RNN with %d weights.' \
                  % self.num_params)
예제 #5
0
    def __init__(self,
                 n_in=1,
                 n_out=1,
                 hidden_layers=(10, 10),
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True,
                 no_weights=False,
                 init_weights=None,
                 dropout_rate=-1,
                 use_spectral_norm=False,
                 use_batch_norm=False,
                 bn_track_stats=True,
                 distill_bn_stats=False,
                 use_context_mod=False,
                 context_mod_inputs=False,
                 no_last_layer_context_mod=False,
                 context_mod_no_weights=False,
                 context_mod_post_activation=False,
                 context_mod_gain_offset=False,
                 context_mod_gain_softplus=False,
                 out_fn=None,
                 verbose=True):
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        # FIXME Spectral norm is incorrectly implemented. Function
        # `nn.utils.spectral_norm` needs to be called in the constructor, such
        # that sepc norm is wrapped around a module.
        if use_spectral_norm:
            raise NotImplementedError('Spectral normalization not yet ' +
                                      'implemented for this network.')

        if use_batch_norm and use_context_mod:
            # FIXME Does it make sense to have both enabled?
            # I.e., should we produce a warning or error?
            pass

        self._a_fun = activation_fn
        assert(init_weights is None or \
               (not no_weights or not context_mod_no_weights))
        self._no_weights = no_weights
        self._dropout_rate = dropout_rate
        #self._use_spectral_norm = use_spectral_norm
        self._use_batch_norm = use_batch_norm
        self._bn_track_stats = bn_track_stats
        self._distill_bn_stats = distill_bn_stats and use_batch_norm
        self._use_context_mod = use_context_mod
        self._context_mod_inputs = context_mod_inputs
        self._no_last_layer_context_mod = no_last_layer_context_mod
        self._context_mod_no_weights = context_mod_no_weights
        self._context_mod_post_activation = context_mod_post_activation
        self._context_mod_gain_offset = context_mod_gain_offset
        self._context_mod_gain_softplus = context_mod_gain_softplus
        self._out_fn = out_fn

        self._has_bias = use_bias
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        self._has_linear_out = True if out_fn is None else False

        if use_spectral_norm and no_weights:
            raise ValueError('Cannot use spectral norm in a network without ' +
                             'parameters.')

        # FIXME make sure that this implementation is correct in all situations
        # (e.g., what to do if weights are passed to the forward method?).
        if use_spectral_norm:
            self._spec_norm = nn.utils.spectral_norm
        else:
            self._spec_norm = lambda x: x  # identity

        self._param_shapes = []
        self._param_shapes_meta = []
        self._weights = None if no_weights and context_mod_no_weights \
            else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not context_mod_no_weights else []
        self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
            is None else []

        if dropout_rate != -1:
            assert (dropout_rate >= 0. and dropout_rate <= 1.)
            self._dropout = nn.Dropout(p=dropout_rate)

        ### Define and initialize context mod weights.
        self._context_mod_layers = nn.ModuleList() if use_context_mod else None
        self._context_mod_shapes = [] if use_context_mod else None

        if use_context_mod:
            cm_ind = 0
            cm_sizes = []
            if context_mod_inputs:
                cm_sizes.append(n_in)
            cm_sizes.extend(hidden_layers)
            if not no_last_layer_context_mod:
                cm_sizes.append(n_out)

            for i, n in enumerate(cm_sizes):
                cmod_layer = ContextModLayer(
                    n,
                    no_weights=context_mod_no_weights,
                    apply_gain_offset=context_mod_gain_offset,
                    apply_gain_softplus=context_mod_gain_softplus)
                self._context_mod_layers.append(cmod_layer)

                self.param_shapes.extend(cmod_layer.param_shapes)
                assert len(cmod_layer.param_shapes) == 2
                self._param_shapes_meta.extend([
                    {'name': 'cm_scale',
                     'index': -1 if context_mod_no_weights else \
                         len(self._weights),
                     'layer': -1}, # 'layer' is set later.
                    {'name': 'cm_shift',
                     'index': -1 if context_mod_no_weights else \
                         len(self._weights)+1,
                     'layer': -1}, # 'layer' is set later.
                ])

                self._context_mod_shapes.extend(cmod_layer.param_shapes)
                if context_mod_no_weights:
                    self._hyper_shapes_learned.extend(cmod_layer.param_shapes)
                else:
                    self._weights.extend(cmod_layer.weights)

                # FIXME ugly code. Move initialization somewhere else.
                if not context_mod_no_weights and init_weights is not None:
                    assert (len(cmod_layer.weights) == 2)
                    for ii in range(2):
                        assert(np.all(np.equal( \
                                list(init_weights[cm_ind].shape),
                                list(cm_ind.weights[ii].shape))))
                        cmod_layer.weights[ii].data = init_weights[cm_ind]
                        cm_ind += 1

            if init_weights is not None:
                init_weights = init_weights[cm_ind:]

        if context_mod_no_weights:
            self._hyper_shapes_learned_ref = \
                list(range(len(self._param_shapes)))

        ### Define and initialize batch norm weights.
        self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

        if use_batch_norm:
            if distill_bn_stats:
                self._hyper_shapes_distilled = []

            bn_ind = 0
            for i, n in enumerate(hidden_layers):
                bn_layer = BatchNormLayer(n,
                                          affine=not no_weights,
                                          track_running_stats=bn_track_stats)
                self._batchnorm_layers.append(bn_layer)

                self._param_shapes.extend(bn_layer.param_shapes)
                assert len(bn_layer.param_shapes) == 2
                self._param_shapes_meta.extend([
                    {
                        'name': 'bn_scale',
                        'index': -1 if no_weights else len(self._weights),
                        'layer': -1
                    },  # 'layer' is set later.
                    {
                        'name': 'bn_shift',
                        'index': -1 if no_weights else len(self._weights) + 1,
                        'layer': -1
                    },  # 'layer' is set later.
                ])

                if no_weights:
                    self._hyper_shapes_learned.extend(bn_layer.param_shapes)
                else:
                    self._weights.extend(bn_layer.weights)

                if distill_bn_stats:
                    self._hyper_shapes_distilled.extend( \
                        [list(p.shape) for p in bn_layer.get_stats(0)])

                # FIXME ugly code. Move initialization somewhere else.
                if not no_weights and init_weights is not None:
                    assert (len(bn_layer.weights) == 2)
                    for ii in range(2):
                        assert(np.all(np.equal( \
                                list(init_weights[bn_ind].shape),
                                list(bn_layer.weights[ii].shape))))
                        bn_layer.weights[ii].data = init_weights[bn_ind]
                        bn_ind += 1

            if init_weights is not None:
                init_weights = init_weights[bn_ind:]

        ### Compute shapes of linear layers.
        linear_shapes = MLP.weight_shapes(n_in=n_in,
                                          n_out=n_out,
                                          hidden_layers=hidden_layers,
                                          use_bias=use_bias)
        self._param_shapes.extend(linear_shapes)

        for i, s in enumerate(linear_shapes):
            self._param_shapes_meta.append({
                'name':
                'weight' if len(s) != 1 else 'bias',
                'index':
                -1 if no_weights else len(self._weights) + i,
                'layer':
                -1  # 'layer' is set later.
            })

        num_weights = MainNetInterface.shapes_to_num_weights(
            self._param_shapes)

        ### Set missing meta information of param_shapes.
        offset = 1 if use_context_mod and context_mod_inputs else 0
        shift = 1
        if use_batch_norm:
            shift += 1
        if use_context_mod:
            shift += 1

        cm_offset = 2 if context_mod_post_activation else 1
        bn_offset = 1 if context_mod_post_activation else 2

        cm_ind = 0
        bn_ind = 0
        layer_ind = 0

        for i, dd in enumerate(self._param_shapes_meta):
            if dd['name'].startswith('cm'):
                if offset == 1 and i in [0, 1]:
                    dd['layer'] = 0
                else:
                    if cm_ind < len(hidden_layers):
                        dd['layer'] = offset + cm_ind * shift + cm_offset
                    else:
                        assert cm_ind == len(hidden_layers) and \
                            not no_last_layer_context_mod
                        # No batchnorm in output layer.
                        dd['layer'] = offset + cm_ind * shift + 1

                    if dd['name'] == 'cm_shift':
                        cm_ind += 1

            elif dd['name'].startswith('bn'):
                dd['layer'] = offset + bn_ind * shift + bn_offset
                if dd['name'] == 'bn_shift':
                    bn_ind += 1

            else:
                dd['layer'] = offset + layer_ind * shift
                if not use_bias or dd['name'] == 'bias':
                    layer_ind += 1

        ### User information
        if verbose:
            if use_context_mod:
                cm_num_weights = 0
                for cm_layer in self._context_mod_layers:
                    cm_num_weights += MainNetInterface.shapes_to_num_weights( \
                        cm_layer.param_shapes)

            print(
                'Creating an MLP with %d weights' % num_weights +
                (' (including %d weights associated with-' % cm_num_weights +
                 'context modulation)' if use_context_mod else '') + '.' +
                (' The network uses dropout.' if dropout_rate != -1 else '') +
                (' The network uses batchnorm.' if use_batch_norm else ''))

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        if no_weights:
            self._hyper_shapes_learned.extend(linear_shapes)

            if use_context_mod:
                if context_mod_no_weights:
                    self._hyper_shapes_learned_ref = \
                        list(range(len(self._param_shapes)))
                else:
                    ncm = len(self._context_mod_shapes)
                    self._hyper_shapes_learned_ref = \
                        list(range(ncm, len(self._param_shapes)))

            self._is_properly_setup()
            return

        ### Define and initialize linear weights.
        for i, dims in enumerate(linear_shapes):
            self._weights.append(
                nn.Parameter(torch.Tensor(*dims), requires_grad=True))
            if len(dims) == 1:
                self._layer_bias_vectors.append(self._weights[-1])
            else:
                self._layer_weight_tensors.append(self._weights[-1])

        if init_weights is not None:
            assert (len(init_weights) == len(linear_shapes))
            for i in range(len(init_weights)):
                assert (np.all(
                    np.equal(list(init_weights[i].shape), linear_shapes[i])))
                if use_bias:
                    if i % 2 == 0:
                        self._layer_weight_tensors[i //
                                                   2].data = init_weights[i]
                    else:
                        self._layer_bias_vectors[i // 2].data = init_weights[i]
                else:
                    self._layer_weight_tensors[i].data = init_weights[i]
        else:
            for i in range(len(self._layer_weight_tensors)):
                if use_bias:
                    init_params(self._layer_weight_tensors[i],
                                self._layer_bias_vectors[i])
                else:
                    init_params(self._layer_weight_tensors[i])

        if self._num_context_mod_shapes() == 0:
            # Note, that might be the case if no hidden layers exist and no
            # input or output modulation is used.
            self._use_context_mod = False

        self._is_properly_setup()
    def __init__(self,
                 n_in,
                 n_out=1,
                 inp_chunk_dim=100,
                 out_chunk_dim=10,
                 cemb_size=8,
                 cemb_init_std=1.,
                 red_layers=(10, 10),
                 net_layers=(10, 10),
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True,
                 dynamic_biases=None,
                 no_weights=False,
                 init_weights=None,
                 dropout_rate=-1,
                 use_spectral_norm=False,
                 use_batch_norm=False,
                 bn_track_stats=True,
                 distill_bn_stats=False,
                 verbose=True):
        """Build an MLP that processes its input in chunks.

        The constructor sets up three groups of parameters (or only their
        shapes if ``no_weights`` is set), in this order:

        1. One embedding of size ``cemb_size`` per input chunk.
        2. A `reducer` :class:`MLP` with ``out_chunk_dim`` outputs. Its input
           size is ``inp_chunk_dim + cemb_size`` as long as no dynamic biases
           are requested.
        3. A `network` :class:`MLP` with input size
           ``out_chunk_dim * num_chunks`` and ``n_out`` outputs.

        The shapes, parameters and batchnorm layers of both MLPs are appended
        to the corresponding attributes of this instance.

        Args:
            n_in: Input dimensionality. The number of chunks is
                ``ceil(n_in / inp_chunk_dim)``.
            n_out: Number of outputs of the `network` MLP.
            inp_chunk_dim: Size of one input chunk. If ``n_in`` is not a
                multiple of it, the number of missing entries of the last
                chunk is stored in ``self._pad`` (``-1`` otherwise).
            out_chunk_dim: Number of outputs of the `reducer` MLP.
            cemb_size: Size of each chunk embedding.
            cemb_init_std: Standard deviation of the normal distribution
                used to initialize the chunk embeddings.
            red_layers: Hidden layer sizes of the `reducer` MLP.
            net_layers: Hidden layer sizes of the `network` MLP.
            activation_fn: Forwarded to both MLPs.
            use_bias: Forwarded to both MLPs.
            dynamic_biases: Indices in ``[0, len(red_layers)]`` of reducer
                layers that should receive a dynamic bias. Not implemented
                yet: any value other than ``None`` ends in a
                ``NotImplementedError``.
            no_weights: If ``True``, no parameters are created by this
                constructor (``self._weights`` is ``None``) and the shapes
                are collected in ``self._hyper_shapes_learned`` instead.
            init_weights: Optional list with one tensor per entry of
                ``self._weights`` (same order, same shapes) whose values are
                taken over as initialization.
            dropout_rate: Forwarded to both MLPs.
            use_spectral_norm: Forwarded to both MLPs.
            use_batch_norm: Forwarded to both MLPs. If set, their batchnorm
                layers are also collected in ``self._batchnorm_layers``.
            bn_track_stats: Forwarded to both MLPs.
            distill_bn_stats: Forwarded to both MLPs.
            verbose: Whether to print information about the constructed
                network. Also forwarded to both MLPs.

        Raises:
            ValueError: If ``init_weights`` is given although ``no_weights``
                is set.
            NotImplementedError: If ``dynamic_biases`` is not ``None``.
        """
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        # Without internal weights there is nothing that could be
        # initialized. Fail early with a clear message instead of crashing on
        # `len(None)` at the very end of the constructor.
        if init_weights is not None and no_weights:
            raise ValueError('Argument "init_weights" cannot be used if ' +
                             '"no_weights" is set.')

        self._n_in = n_in
        self._n_out = n_out
        self._inp_chunk_dim = inp_chunk_dim
        self._out_chunk_dim = out_chunk_dim
        self._cemb_size = cemb_size
        self._a_fun = activation_fn
        self._no_weights = no_weights

        self._has_bias = use_bias
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        self._has_linear_out = True  # Ensure that `out_fn` is `None`!

        self._param_shapes = []
        #self._param_shapes_meta = [] # TODO implement!
        self._weights = None if no_weights else nn.ParameterList()
        self._hyper_shapes_learned = None if not no_weights else []
        #self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
        #    is None else [] # TODO implement.

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        self._context_mod_layers = None
        self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

        def take_over(mlp):
            """Append shapes, weights and batchnorm layers of ``mlp``."""
            self._param_shapes.extend(mlp.param_shapes)
            if no_weights:
                self._hyper_shapes_learned.extend(mlp._hyper_shapes_learned)
            else:
                self._weights.extend(mlp._weights)

            self._layer_weight_tensors.extend(mlp._layer_weight_tensors)
            self._layer_bias_vectors.extend(mlp._layer_bias_vectors)

            if use_batch_norm:
                self._batchnorm_layers.extend(mlp._batchnorm_layers)

        #################################
        ### Generate Chunk Embeddings ###
        #################################
        self._num_cembs = int(np.ceil(n_in / inp_chunk_dim))
        last_chunk_size = n_in % inp_chunk_dim
        if last_chunk_size != 0:
            self._pad = inp_chunk_dim - last_chunk_size
        else:
            self._pad = -1

        cemb_shape = [self._num_cembs, cemb_size]
        self._param_shapes.append(cemb_shape)
        if no_weights:
            self._cembs = None
            self._hyper_shapes_learned.append(cemb_shape)
        else:
            self._cembs = nn.Parameter(data=torch.Tensor(*cemb_shape),
                                       requires_grad=True)
            nn.init.normal_(self._cembs, mean=0., std=cemb_init_std)

            self._weights.append(self._cembs)

        ############################
        ### Setup Dynamic Biases ###
        ############################
        self._has_dyn_bias = None
        if dynamic_biases is not None:
            assert np.all(np.array(dynamic_biases) >= 0) and \
                   np.all(np.array(dynamic_biases) < len(red_layers) + 1)
            # `np.unique` already returns its result sorted.
            dynamic_biases = np.unique(dynamic_biases)

            # For each layer in the `reducer`, where we want to apply a dynamic
            # bias, we have to create a weight matrix for a corresponding
            # linear layer (we just ignore)
            # Note, a `ParameterList` is required here, since a `ModuleList`
            # does not accept `nn.Parameter` entries.
            self._dyn_bias_weights = nn.ParameterList()
            self._has_dyn_bias = []

            for i in range(len(red_layers) + 1):
                if i not in dynamic_biases:
                    self._has_dyn_bias.append(False)
                    self._dyn_bias_weights.append(None)
                    continue

                self._has_dyn_bias.append(True)

                trgt_dim = out_chunk_dim
                if i < len(red_layers):
                    trgt_dim = red_layers[i]
                trgt_shape = [trgt_dim, cemb_size]

                self._param_shapes.append(trgt_shape)
                if no_weights:
                    # Only the shape is registered; `self._weights` does not
                    # exist in this case.
                    self._dyn_bias_weights.append(None)
                    self._hyper_shapes_learned.append(trgt_shape)
                else:
                    dyn_weight = nn.Parameter(torch.Tensor(*trgt_shape),
                                              requires_grad=True)
                    self._dyn_bias_weights.append(dyn_weight)
                    self._weights.append(dyn_weight)

                    init_params(dyn_weight)

                    self._layer_weight_tensors.append(dyn_weight)
                    self._layer_bias_vectors.append(None)

        ################################
        ### Create `Reducer` Network ###
        ################################
        red_inp_dim = inp_chunk_dim + \
            (cemb_size if dynamic_biases is None else 0)
        self._reducer = MLP(
            n_in=red_inp_dim,
            n_out=out_chunk_dim,
            hidden_layers=red_layers,
            activation_fn=activation_fn,
            use_bias=use_bias,
            no_weights=no_weights,
            init_weights=None,
            dropout_rate=dropout_rate,
            use_spectral_norm=use_spectral_norm,
            use_batch_norm=use_batch_norm,
            bn_track_stats=bn_track_stats,
            distill_bn_stats=distill_bn_stats,
            # We use context modulation to realize dynamic biases, since they
            # allow a different modulation per sample in the input mini-batch.
            # Hence, we can process several chunks in parallel with the reducer
            # network.
            use_context_mod=dynamic_biases is not None,
            context_mod_inputs=False,
            no_last_layer_context_mod=False,
            context_mod_no_weights=True,
            context_mod_post_activation=False,
            context_mod_gain_offset=False,
            context_mod_gain_softplus=False,
            out_fn=None,
            verbose=verbose)

        if dynamic_biases is not None:
            # FIXME We have to extract the param shapes from
            # `self._reducer.param_shapes`, as well as from
            # `self._reducer._hyper_shapes_learned` that belong to context-mod
            # layers. We may not add them to our own `param_shapes` attribute,
            # as these are not parameters (due to our misuse of the context-mod
            # layers).
            # Note, in the `forward` method, we need to supply context-mod
            # weights for all reducer networks, independent on whether they have
            # a dynamic bias or not. We can do so, by providing constant ones
            # for all gains and constance zero-shift for all layers without
            # dynamic biases (note, we need to ensure the correct batch dim!).
            raise NotImplementedError(
                'Dynamic biases are not yet implemented!')

        assert self._reducer._context_mod_layers is None

        ### Overtake all attributes from the underlying MLP.
        take_over(self._reducer)

        if self._reducer._hyper_shapes_distilled is not None:
            self._hyper_shapes_distilled = \
                list(self._reducer._hyper_shapes_distilled)

        ###############################
        ### Create Actual `Network` ###
        ###############################
        net_inp_dim = out_chunk_dim * self._num_cembs
        self._network = MLP(n_in=net_inp_dim,
                            n_out=n_out,
                            hidden_layers=net_layers,
                            activation_fn=activation_fn,
                            use_bias=use_bias,
                            no_weights=no_weights,
                            init_weights=None,
                            dropout_rate=dropout_rate,
                            use_spectral_norm=use_spectral_norm,
                            use_batch_norm=use_batch_norm,
                            bn_track_stats=bn_track_stats,
                            distill_bn_stats=distill_bn_stats,
                            use_context_mod=False,
                            out_fn=None,
                            verbose=verbose)

        ### Overtake all attributes from the underlying MLP.
        take_over(self._network)

        if self._hyper_shapes_distilled is not None:
            assert self._network._hyper_shapes_distilled is not None
            self._hyper_shapes_distilled.extend(
                self._network._hyper_shapes_distilled)

        #####################################
        ### Takeover given Initialization ###
        #####################################
        if init_weights is not None:
            assert len(init_weights) == len(self._weights)
            for param, shape, init in zip(self._weights, self._param_shapes,
                                          init_weights):
                # Compare the shapes exactly. An element-wise numpy comparison
                # would broadcast and thus accept, e.g., `[10]` for `[10, 10]`.
                assert list(init.shape) == list(shape)
                param.data = init

        ######################
        ### Finalize Setup ###
        ######################
        if verbose:
            num_weights = MainNetInterface.shapes_to_num_weights(
                self.param_shapes)
            print('Constructed MLP that processes dimensionality reduced ' +
                  'inputs through chunking. The network has a total of %d ' %
                  num_weights + 'weights.')

        self._is_properly_setup()