class BiRNN(nn.Module, MainNetInterface):
    r"""Implementation of a bidirectional RNN.

    Note:
        The output is non-linear if the last layer is recurrent! Otherwise,
        logits are returned (cmp. attribute
        :attr:`mnets.mnet_interface.MainNetInterface.has_fc_out`).

    Example:
        Here is an example instantiation of a BiLSTM with a single
        bidirectional layer of dimensionality 256, assuming 100 dimensional
        inputs and 10 dimensional outputs.

        .. code-block:: python

            net = BiRNN(rnn_args={'n_in': 100, 'rnn_layers': [256],
                                  'use_lstm': True, 'fc_layers_pre': [],
                                  'fc_layers': []},
                        mlp_args={'n_in': 512, 'n_out': 10,
                                  'hidden_layers': []},
                        no_weights=False)

    Attributes:
        preprocess_fct (func, optional): See constructor argument
            ``preprocess_fct``.
        num_rec_layers (int): See attribute
            :attr:`mnets.simple_rnn.SimpleRNN.num_rec_layers`. Total number of
            recurrent layers, where each bidirectional layer consists of at
            least two recurrent layers (forward and backward layer).
        use_lstm (bool): See attribute
            :attr:`mnets.simple_rnn.SimpleRNN.use_lstm`.

    Args:
        rnn_args (dict or list): A dictionary of arguments for an instance of
            class :class:`mnets.simple_rnn.SimpleRNN`. These arguments will be
            used to create two instances of this class, one representing the
            forward RNN and one the backward RNN.

            Note, each of these instances may contain multiple layers, even
            non-recurrent layers. The outputs of such an instance are
            considered the hidden activations :math:`\hat{h}_{1:T}^{(f)}` or
            :math:`\hat{h}_{1:T}^{(b)}`, respectively.

            To realize multiple bidirectional layers (which themselves can be
            multi-layer RNNs), one may provide a list of dictionaries. Each
            entry in such a list will be used to generate a single
            bidirectional layer (i.e., consisting of two instances of class
            :class:`mnets.simple_rnn.SimpleRNN`). Note, the input size of each
            new layer has to be twice the size of :math:`\hat{h}_t^{(f)}` from
            the previous layer.
        mlp_args (dict, optional): A dictionary of arguments for class
            :class:`mnets.mlp.MLP`. The input size of such an MLP should be
            twice the size of :math:`\hat{h}_t^{(f)}`. If ``None``, then the
            output of the last bidirectional layer is considered the output of
            the network.
        preprocess_fct (func, optional): A function handle can be provided
            that will process inputs ``x`` passed to the method
            :meth:`forward`. An example use case could be the translation or
            selection of word embeddings.

            The function handle must have the signature:
            ``preprocess_fct(x, seq_lengths=None)``. See the corresponding
            argument descriptions of method :meth:`forward`. The function is
            expected to return the preprocessed ``x``.
        no_weights (bool): See parameter ``no_weights`` of class
            :class:`mnets.mlp.MLP`.
        verbose (bool): See parameter ``verbose`` of class
            :class:`mnets.mlp.MLP`.
    """
    def __init__(self, rnn_args={}, mlp_args=None, preprocess_fct=None,
                 no_weights=False, verbose=True):
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        assert isinstance(rnn_args, (dict, list, tuple))
        assert mlp_args is None or isinstance(mlp_args, dict)

        if isinstance(rnn_args, dict):
            rnn_args = [rnn_args]

        self._forward_rnns = []
        self._backward_rnns = []
        self._out_mlp = None

        self._preprocess_fct = preprocess_fct
        self._forward_called = False

        # FIXME At the moment we do not control input and output size of
        # individual networks and need to assume that the user sets them
        # correctly.

        ### Create all forward and backward nets for each bidirectional layer.
        for rargs in rnn_args:
            assert isinstance(rargs, dict)
            if 'verbose' not in rargs.keys():
                rargs['verbose'] = False

            if 'no_weights' in rargs.keys() and \
                    rargs['no_weights'] != no_weights:
                raise ValueError('Keyword argument "no_weights" of ' +
                                 'bidirectional layer is in conflict with ' +
                                 'constructor argument "no_weights".')
            elif 'no_weights' not in rargs.keys():
                rargs['no_weights'] = no_weights

            self._forward_rnns.append(SimpleRNN(**rargs))
            self._backward_rnns.append(SimpleRNN(**rargs))

        ### Create output network.
        if mlp_args is not None:
            if 'verbose' not in mlp_args.keys():
                mlp_args['verbose'] = False

            if 'no_weights' in mlp_args.keys() and \
                    mlp_args['no_weights'] != no_weights:
                raise ValueError('Keyword argument "no_weights" of output ' +
                                 'MLP is in conflict with constructor ' +
                                 'argument "no_weights".')
            elif 'no_weights' not in mlp_args.keys():
                mlp_args['no_weights'] = no_weights

            self._out_mlp = MLP(**mlp_args)

        ### Set all interface attributes correctly.
        if self._out_mlp is None:
            self._has_fc_out = self._forward_rnns[-1].has_fc_out
            # We can't set the following attribute to true, as the output is
            # a concatenation of the outputs from two networks. Therefore, the
            # weights used to compute the outputs are at different locations
            # in the `param_shapes` list.
            self._mask_fc_out = False
            self._has_linear_out = self._forward_rnns[-1].has_linear_out
        else:
            self._has_fc_out = self._out_mlp.has_fc_out
            self._mask_fc_out = self._out_mlp.mask_fc_out
            self._has_linear_out = self._out_mlp.has_linear_out

        # Collect all internal net objects from which we need to collect
        # attributes.
        nets = []
        for i, fnet in enumerate(self._forward_rnns):
            bnet = self._backward_rnns[i]
            nets.append((fnet, 'forward_rnn', i))
            nets.append((bnet, 'backward_rnn', i))
        if self._out_mlp is not None:
            nets.append((self._out_mlp, 'out_mlp', -1))

        # Iterate over all nets to collect their attribute values.
        self._param_shapes = []
        self._param_shapes_meta = []
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()
        for i, net_tup in enumerate(nets):
            net, net_type, net_id = net_tup

            # Note, it is important to convert lists into new objects and not
            # just copy references!
            # Note, we have to adapt all references if `i > 0`.

            # Sanity check:
            if i == 0:
                cm_nw = net._context_mod_no_weights
            elif cm_nw != net._context_mod_no_weights:
                raise ValueError('Network expects that either all internal ' +
                                 'networks maintain their context-mod ' +
                                 'weights or none of them does!')

            ps_len_old = len(self._param_shapes)

            if net._internal_params is not None:
                if self._internal_params is None:
                    self._internal_params = nn.ParameterList()
                ip_len_old = len(self._internal_params)
                self._internal_params.extend( \
                    nn.ParameterList(net._internal_params))

            self._param_shapes.extend(list(net._param_shapes))
            for meta in net.param_shapes_meta:
                assert 'birnn_layer_type' not in meta.keys()
                assert 'birnn_layer_id' not in meta.keys()

                new_meta = dict(meta)
                new_meta['birnn_layer_type'] = net_type
                new_meta['birnn_layer_id'] = net_id
                if i > 0:
                    # FIXME We should properly adjust colliding `layer` IDs.
                    new_meta['layer'] = -1
                    new_meta['index'] = meta['index'] + ip_len_old
                self._param_shapes_meta.append(new_meta)

            if net._hyper_shapes_learned is not None:
                if self._hyper_shapes_learned is None:
                    self._hyper_shapes_learned = []
                    self._hyper_shapes_learned_ref = []
                self._hyper_shapes_learned.extend( \
                    list(net._hyper_shapes_learned))
                for ref in net._hyper_shapes_learned_ref:
                    self._hyper_shapes_learned_ref.append(ref + ps_len_old)

            if net._hyper_shapes_distilled is not None:
                if self._hyper_shapes_distilled is None:
                    self._hyper_shapes_distilled = []
                self._hyper_shapes_distilled.extend( \
                    list(net._hyper_shapes_distilled))

            if self._has_bias is None:
                self._has_bias = net._has_bias
            elif self._has_bias != net._has_bias:
                self._has_bias = False
                # FIXME We should overwrite the getter and throw an error!
                warn('Some internally maintained networks use biases, ' +
                     'while others don\'t. Setting attribute "has_bias" to ' +
                     'False.')

            self._layer_weight_tensors.extend( \
                nn.ParameterList(net._layer_weight_tensors))
            self._layer_bias_vectors.extend( \
                nn.ParameterList(net._layer_bias_vectors))

            if net._batchnorm_layers is not None:
                if self._batchnorm_layers is None:
                    self._batchnorm_layers = nn.ModuleList()
                self._batchnorm_layers.extend( \
                    nn.ModuleList(net._batchnorm_layers))

            if net._context_mod_layers is not None:
                if self._context_mod_layers is None:
                    self._context_mod_layers = nn.ModuleList()
                self._context_mod_layers.extend( \
                    nn.ModuleList(net._context_mod_layers))

        self._is_properly_setup()

        ### Print user information.
        if verbose:
            print('Constructed Bidirectional RNN with %d weights.' \
                  % self.num_params)

    @property
    def preprocess_fct(self):
        """Getter for attribute :attr:`preprocess_fct`."""
        return self._preprocess_fct

    @preprocess_fct.setter
    def preprocess_fct(self, value):
        """Setter for attribute :attr:`preprocess_fct`.

        Note:
            This setter may only be called before the first call of the
            :meth:`forward` method.
        """
        if self._forward_called:
            raise RuntimeError('Attribute "preprocess_fct" cannot be ' +
                               'modified after method "forward" has been ' +
                               'called.')
        self._preprocess_fct = value

    @property
    def num_rec_layers(self):
        """Getter for read-only attribute :attr:`num_rec_layers`."""
        num_rec_layers = 0
        for net in self._forward_rnns + self._backward_rnns:
            num_rec_layers += net.num_rec_layers
        return num_rec_layers

    @property
    def use_lstm(self):
        """Getter for read-only attribute :attr:`use_lstm`."""
        use_lstm = self._forward_rnns[0].use_lstm
        for i in range(1, len(self._forward_rnns)):
            if self._forward_rnns[i].use_lstm != use_lstm:
                raise RuntimeError('Attribute "use_lstm" not applicable to ' +
                                   'this network as layers use mixed types ' +
                                   'of RNNs.')
        return use_lstm

    def distillation_targets(self):
        """Targets to be distilled after training.

        See docstring of abstract super method
        :meth:`mnets.mnet_interface.MainNetInterface.distillation_targets`.
        """
        # SimpleRNNs should not have any distillation targets.
        for net in self._forward_rnns + self._backward_rnns:
            if net.distillation_targets() is not None:
                raise RuntimeError()

        if self._out_mlp is not None:
            return self._out_mlp.distillation_targets()

        return None

    def forward(self, x, weights=None, distilled_params=None, condition=None,
                seq_lengths=None):
        """Compute the output :math:`y` of this network given the input
        :math:`x`.

        Note:
            If constructor argument ``preprocess_fct`` was set, then all
            inputs ``x`` are first processed by this function.

        Args:
            (....): See docstring of method
                :meth:`mnets.mnet_interface.MainNetInterface.forward`. We
                provide some more specific information below.
            weights (list or dict): See argument ``weights`` of method
                :meth:`mnets.mlp.MLP.forward`.
            distilled_params: Will only be passed to the underlying instance
                of class :class:`mnets.mlp.MLP`.
            condition (int or dict, optional): If provided, then this argument
                will be passed as argument ``ckpt_id`` to the method
                :meth:`utils.context_mod_layer.ContextModLayer.forward`.

                When provided as a dict, see argument ``condition`` of method
                :meth:`mnets.mlp.MLP.forward` for more details.
            seq_lengths (numpy.ndarray, optional): List of sequence lengths.
                The length of the list has to match the batch size of inputs
                ``x``. The entries will correspond to the unpadded sequence
                lengths. If this option is provided, then the bidirectional
                layers will reverse their input sequences according to the
                unpadded sequence lengths.

                Example: ``x = [[a,b,0,0], [a,b,c,0]].T``. If
                ``seq_lengths = [2, 3]`` is provided, then the reverse
                sequences ``[[b,a,0,0], [c,b,a,0]].T`` are fed into the first
                bidirectional layer (and similarly for all subsequent
                bidirectional layers). Otherwise, the reverse sequences
                ``[[0,0,b,a], [0,c,b,a]].T`` are used.

                Caution:
                    If this option is not provided but padded input sequences
                    are used, the output of a bidirectional layer will depend
                    on the padding. I.e., different padding lengths will lead
                    to different results.

        Returns:
            (torch.Tensor or tuple): Where the tuple contains:

            - **output** (torch.Tensor): The output of the network.
            - **hidden** (list): ``None`` - not implemented yet.
        """
        # FIXME Delete warning below.
        if seq_lengths is None:
            warn('"seq_lengths" has not been provided to BiRNN.')

        if self._out_mlp is None:
            assert distilled_params is None

        #######################
        ### Parse condition ###
        #######################
        rnn_cmod_cond = None
        mlp_cond = None
        if condition is not None:
            if isinstance(condition, dict):
                if 'cmod_ckpt_id' in condition.keys():
                    rnn_cmod_cond = condition['cmod_ckpt_id']
                mlp_cond = condition
            else:
                rnn_cmod_cond = condition
                mlp_cond = {'cmod_ckpt_id': condition}

        ########################################
        ### Extract weights for each network ###
        ########################################
        forward_weights = [None] * len(self._forward_rnns)
        backward_weights = [None] * len(self._backward_rnns)
        mlp_weights = None

        n_cm = self._num_context_mod_shapes()

        int_weights = None
        cm_weights = None
        all_weights = None

        if weights is not None and isinstance(weights, dict):
            if 'internal_weights' in weights.keys():
                int_weights = weights['internal_weights']
            if 'mod_weights' in weights.keys():
                cm_weights = weights['mod_weights']
        elif weights is not None:
            if len(weights) == n_cm:
                cm_weights = weights
            else:
                assert len(weights) == len(self.param_shapes)
                all_weights = weights

        if weights is not None:
            # Collect all context-mod and internal weights if not explicitly
            # passed. Note, those will either be taken from `all_weights` or
            # have to exist internally.
            if n_cm > 0 and cm_weights is None:
                cm_weights = []
                for ii, meta in enumerate(self.param_shapes_meta):
                    if meta['name'].startswith('cm_'):
                        if all_weights is not None:
                            cm_weights.append(all_weights[ii])
                        else:
                            assert meta['index'] != -1
                            cm_weights.append( \
                                self.internal_params[meta['index']])

            if int_weights is None:
                int_weights = []
                for ii, meta in enumerate(self.param_shapes_meta):
                    if not meta['name'].startswith('cm_'):
                        if all_weights is not None:
                            int_weights.append(all_weights[ii])
                        else:
                            assert meta['index'] != -1
                            int_weights.append( \
                                self.internal_params[meta['index']])

            # Now that we have all context-mod and internal weights, we need
            # to distribute them across networks. Therefore, note that the
            # order in which they appear in `param_shapes` matches the order
            # of `cm_weights` and `int_weights`.
            cm_ind = 0
            int_ind = 0
            for ii, meta in enumerate(self.param_shapes_meta):
                net_type = meta['birnn_layer_type']
                net_id = meta['birnn_layer_id']

                if net_type == 'forward_rnn':
                    if forward_weights[net_id] is None:
                        forward_weights[net_id] = dict()
                    curr_weights = forward_weights[net_id]
                elif net_type == 'backward_rnn':
                    if backward_weights[net_id] is None:
                        backward_weights[net_id] = dict()
                    curr_weights = backward_weights[net_id]
                else:
                    assert net_type == 'out_mlp'
                    if mlp_weights is None:
                        mlp_weights = dict()
                    curr_weights = mlp_weights

                if meta['name'].startswith('cm_'):
                    if 'mod_weights' not in curr_weights.keys():
                        curr_weights['mod_weights'] = []
                    curr_weights['mod_weights'].append(cm_weights[cm_ind])
                    cm_ind += 1
                else:
                    if 'internal_weights' not in curr_weights.keys():
                        curr_weights['internal_weights'] = []
                    curr_weights['internal_weights'].append( \
                        int_weights[int_ind])
                    int_ind += 1

        #####################################
        ### Apply potential preprocessing ###
        #####################################
        self._forward_called = True
        if self._preprocess_fct is not None:
            x = self._preprocess_fct(x, seq_lengths=seq_lengths)

        ####################################
        ### Process bidirectional layers ###
        ####################################
        # Create reverse input sequence for backward network.
        if seq_lengths is not None:
            assert seq_lengths.size == x.shape[1]

        def revert_order(inp):
            if seq_lengths is None:
                return torch.flip(inp, [0])
            else:
                inp_back = torch.zeros_like(inp)
                for ii in range(seq_lengths.size):
                    inp_back[:int(seq_lengths[ii]), ii, :] = \
                        torch.flip(inp[:int(seq_lengths[ii]), ii, :], [0])
                return inp_back

        h = x
        for ll, fnet in enumerate(self._forward_rnns):
            bnet = self._backward_rnns[ll]

            # Revert inputs in time before processing them by the backward
            # RNN.
            h_rev = revert_order(h)

            h_f = fnet.forward(h, weights=forward_weights[ll],
                condition=rnn_cmod_cond, return_hidden=False,
                return_hidden_int=False)
            h_b = bnet.forward(h_rev, weights=backward_weights[ll],
                condition=rnn_cmod_cond, return_hidden=False,
                return_hidden_int=False)

            # Revert outputs in time from the backward RNN before
            # concatenation.
            # NOTE If `seq_lengths` are given, then this function will also
            # set the hidden timesteps corresponding to "padded timesteps" to
            # zero.
            h_b = revert_order(h_b)

            # Set hidden states of `h_f` corresponding to padded timesteps to
            # zero to ensure consistency. Note, will only ever affect those
            # "padded timesteps".
            if seq_lengths is not None:
                for ii in range(seq_lengths.size):
                    h_f[int(seq_lengths[ii]):, ii, :] = 0

            h = torch.cat([h_f, h_b], dim=2)

        ##############################
        ### Compute network output ###
        ##############################
        if self._out_mlp is not None:
            #n_time, n_batch, n_feat = h.shape
            #h = h.view(n_time*n_batch, n_feat)
            h = self._out_mlp.forward(h, weights=mlp_weights,
                distilled_params=distilled_params, condition=mlp_cond)
            #h = h.view(n_time, n_batch, -1)

        return h

    def init_hh_weights_orthogonal(self):
        """Initialize hidden-to-hidden weights orthogonally.

        This method will call method
        :meth:`mnets.simple_rnn.SimpleRNN.init_hh_weights_orthogonal` of all
        internally maintained instances of class
        :class:`mnets.simple_rnn.SimpleRNN`.
        """
        for net in self._forward_rnns + self._backward_rnns:
            net.init_hh_weights_orthogonal()

    def get_cm_weights(self):
        """Get internally maintained weights that are associated with
        context-modulation.

        Returns:
            (list): List of weights from
            :attr:`mnets.mnet_interface.MainNetInterface.internal_params` that
            belong to context-mod layers.
        """
        ret = []
        for i, meta in enumerate(self.param_shapes_meta):
            if not (meta['name'] == 'cm_shift' or meta['name'] == 'cm_scale'):
                continue
            if meta['index'] != -1:
                ret.append(self.internal_params[meta['index']])
        return ret

    def get_non_cm_weights(self):
        """Get internal weights that are not associated with
        context-modulation.

        Returns:
            (list): List of weights from
            :attr:`mnets.mnet_interface.MainNetInterface.internal_params` that
            do not belong to context-mod layers.
        """
        n_cm = self._num_context_mod_shapes()
        if n_cm == 0:
            return self.internal_params
        else:
            ret = []
            for i, meta in enumerate(self.param_shapes_meta):
                if meta['name'] == 'cm_shift' or meta['name'] == 'cm_scale':
                    continue
                if meta['index'] != -1:
                    ret.append(self.internal_params[meta['index']])
            return ret
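
# The following sketch is illustrative only and not part of the class above.
# It shows one way an instance of `BiRNN` could be used on a padded,
# time-major mini-batch. The helper name `_birnn_usage_sketch` and all tensor
# shapes are assumptions made for this example.
def _birnn_usage_sketch():
    """Minimal usage sketch for :class:`BiRNN` (illustrative assumptions)."""
    # A single BiLSTM layer of size 256; the output MLP hence receives
    # 2 * 256 = 512 features per timestep (cf. the class docstring example).
    net = BiRNN(rnn_args={'n_in': 100, 'rnn_layers': [256],
                          'use_lstm': True, 'fc_layers_pre': [],
                          'fc_layers': []},
                mlp_args={'n_in': 512, 'n_out': 10, 'hidden_layers': []})

    # Time-major input of shape [T, B, n_in] containing two padded sequences.
    x = torch.rand(4, 2, 100)
    # Unpadded sequence lengths (one entry per sample), so that the backward
    # RNNs only reverse the valid timesteps of each sequence.
    seq_lengths = np.array([2, 3])

    # With an output MLP, the result has shape [T, B, n_out], i.e.,
    # [4, 2, 10] here.
    return net.forward(x, seq_lengths=seq_lengths)
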
class ChunkSqueezer(nn.Module, MainNetInterface):
    """An MLP that first reduces the dimensionality of its inputs.

    The input dimensionality ``n_in`` is first reduced by a `reducer` network
    (which is an instance of class :class:`mnets.mlp.MLP`) using a chunking
    strategy. The reduced input will then be passed to the actual `network`
    (which is another instance of :class:`mnets.mlp.MLP`) to compute an
    output.

    Args:
        n_in (int): Input dimensionality.
        n_out (int): Number of output neurons.
        inp_chunk_dim (int): The input (dimensionality ``n_in``) will be split
            into chunks of size ``inp_chunk_dim``. Thus, there will be
            ``np.ceil(n_in/inp_chunk_dim)`` input chunks that are individually
            squeezed through the `reducer` network.

            Note:
                The last chunk might be zero-padded.
        out_chunk_dim (int): The output size of the `reducer` network. The
            input size of the actual `network` is then
            ``np.ceil(n_in/inp_chunk_dim) * out_chunk_dim``.
        cemb_size (int): The `reducer` network processes every chunk
            individually. In order to do so, it needs to know which chunk it
            is processing. Therefore, it is conditioned on a learned chunk
            embedding (there will be ``np.ceil(n_in/inp_chunk_dim)`` chunk
            embeddings). The dimensionality of these chunk embeddings is
            determined by this argument.
        cemb_init_std (float): Standard deviation used for the normal
            initialization of the chunk embeddings.
        red_layers (list or tuple): The architecture of the `reducer` network.
            See argument ``hidden_layers`` of class :class:`mnets.mlp.MLP`.
        net_layers (list or tuple): The architecture of the actual `network`.
            See argument ``hidden_layers`` of class :class:`mnets.mlp.MLP`.
        activation_fn: The nonlinearity used in hidden layers. If ``None``, no
            nonlinearity will be applied.
        use_bias: Will be passed as option ``use_bias`` to the underlying MLPs
            (see :class:`mnets.mlp.MLP`).
        dynamic_biases (list, optional): This option determines the hidden
            layers of the `reducer` network that receive the chunk embedding
            as dynamic biases. It is a list of indexes, with the first hidden
            layer having index 0 and the output of the `reducer` having index
            ``len(red_layers)``. The chunk embeddings will be transformed
            through a fully connected layer (no bias) and then added as
            "dynamic" bias to the output of the corresponding hidden layer.

            Note:
                If left unspecified, the chunk embeddings will just be another
                input to the `reducer` network.
        no_weights (bool): If set to ``True``, no trainable parameters will be
            constructed, i.e., weights are assumed to be produced ad-hoc by a
            hypernetwork and passed to the :meth:`forward` method.
        init_weights (optional): This option is for convenience reasons. The
            option expects a list of parameter values that are used to
            initialize the network weights. As such, it provides a convenient
            way of initializing a network with a weight draw produced by the
            hypernetwork.

            Note, only internal weights (see
            :attr:`mnets.mnet_interface.MainNetInterface.weights`) will be
            affected by this argument.
        dropout_rate (float): Will be passed as option ``dropout_rate`` to the
            underlying MLPs (see :class:`mnets.mlp.MLP`).
        use_spectral_norm (bool): Will be passed as option
            ``use_spectral_norm`` to the underlying MLPs (see
            :class:`mnets.mlp.MLP`).
        use_batch_norm (bool): Will be passed as option ``use_batch_norm`` to
            the underlying MLPs (see :class:`mnets.mlp.MLP`).
        bn_track_stats (bool): Will be passed as option ``bn_track_stats`` to
            the underlying MLPs (see :class:`mnets.mlp.MLP`).
        distill_bn_stats (bool): Will be passed as option ``distill_bn_stats``
            to the underlying MLPs (see :class:`mnets.mlp.MLP`).
    """
    def __init__(self, n_in, n_out=1, inp_chunk_dim=100, out_chunk_dim=10,
                 cemb_size=8, cemb_init_std=1., red_layers=(10, 10),
                 net_layers=(10, 10), activation_fn=torch.nn.ReLU(),
                 use_bias=True, dynamic_biases=None, no_weights=False,
                 init_weights=None, dropout_rate=-1, use_spectral_norm=False,
                 use_batch_norm=False, bn_track_stats=True,
                 distill_bn_stats=False, verbose=True):
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        self._n_in = n_in
        self._n_out = n_out
        self._inp_chunk_dim = inp_chunk_dim
        self._out_chunk_dim = out_chunk_dim
        self._cemb_size = cemb_size
        self._a_fun = activation_fn
        self._no_weights = no_weights

        self._has_bias = use_bias
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        self._has_linear_out = True # Ensure that `out_fn` is `None`!

        self._param_shapes = []
        #self._param_shapes_meta = [] # TODO implement!
        self._weights = None if no_weights else nn.ParameterList()
        self._hyper_shapes_learned = None if not no_weights else []
        #self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
        #    is None else [] # TODO implement.
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()
        self._context_mod_layers = None
        self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

        #################################
        ### Generate Chunk Embeddings ###
        #################################
        self._num_cembs = int(np.ceil(n_in / inp_chunk_dim))
        last_chunk_size = n_in % inp_chunk_dim
        if last_chunk_size != 0:
            self._pad = inp_chunk_dim - last_chunk_size
        else:
            self._pad = -1

        cemb_shape = [self._num_cembs, cemb_size]
        self._param_shapes.append(cemb_shape)
        if no_weights:
            self._cembs = None
            self._hyper_shapes_learned.append(cemb_shape)
        else:
            self._cembs = nn.Parameter(data=torch.Tensor(*cemb_shape),
                                       requires_grad=True)
            nn.init.normal_(self._cembs, mean=0., std=cemb_init_std)

            self._weights.append(self._cembs)

        ############################
        ### Setup Dynamic Biases ###
        ############################
        self._has_dyn_bias = None
        if dynamic_biases is not None:
            assert np.all(np.array(dynamic_biases) >= 0) and \
                np.all(np.array(dynamic_biases) < len(red_layers) + 1)
            dynamic_biases = np.sort(np.unique(dynamic_biases))

            # For each layer in the `reducer`, where we want to apply a
            # dynamic bias, we have to create a weight matrix for a
            # corresponding linear layer (whose bias we simply omit).
            self._dyn_bias_weights = nn.ModuleList()
            self._has_dyn_bias = []
            for i in range(len(red_layers) + 1):
                if i in dynamic_biases:
                    self._has_dyn_bias.append(True)

                    trgt_dim = out_chunk_dim
                    if i < len(red_layers):
                        trgt_dim = red_layers[i]
                    trgt_shape = [trgt_dim, cemb_size]

                    self._param_shapes.append(trgt_shape)
                    if no_weights:
                        self._dyn_bias_weights.append(None)
                        self._hyper_shapes_learned.append(trgt_shape)
                    else:
                        self._dyn_bias_weights.append(nn.Parameter( \
                            torch.Tensor(*trgt_shape), requires_grad=True))
                        self._weights.append(self._dyn_bias_weights[-1])
                        init_params(self._dyn_bias_weights[-1])

                        self._layer_weight_tensors.append( \
                            self._dyn_bias_weights[-1])
                        self._layer_bias_vectors.append(None)
                else:
                    self._has_dyn_bias.append(False)
                    self._dyn_bias_weights.append(None)

        ################################
        ### Create `Reducer` Network ###
        ################################
        red_inp_dim = inp_chunk_dim + \
            (cemb_size if dynamic_biases is None else 0)
        self._reducer = MLP(
            n_in=red_inp_dim, n_out=out_chunk_dim, hidden_layers=red_layers,
            activation_fn=activation_fn, use_bias=use_bias,
            no_weights=no_weights, init_weights=None,
            dropout_rate=dropout_rate, use_spectral_norm=use_spectral_norm,
            use_batch_norm=use_batch_norm, bn_track_stats=bn_track_stats,
            distill_bn_stats=distill_bn_stats,
            # We use context modulation to realize dynamic biases, since they
            # allow a different modulation per sample in the input mini-batch.
            # Hence, we can process several chunks in parallel with the
            # reducer network.
            use_context_mod=dynamic_biases is not None,
            context_mod_inputs=False, no_last_layer_context_mod=False,
            context_mod_no_weights=True, context_mod_post_activation=False,
            context_mod_gain_offset=False, context_mod_gain_softplus=False,
            out_fn=None, verbose=True)

        if dynamic_biases is not None:
            # FIXME We have to extract the param shapes from
            # `self._reducer.param_shapes`, as well as from
            # `self._reducer._hyper_shapes_learned` that belong to context-mod
            # layers. We may not add them to our own `param_shapes` attribute,
            # as these are not parameters (due to our misuse of the
            # context-mod layers).
            # Note, in the `forward` method, we need to supply context-mod
            # weights for all reducer layers, independent of whether they have
            # a dynamic bias or not. We can do so by providing constant ones
            # for all gains and constant zero-shifts for all layers without
            # dynamic biases (note, we need to ensure the correct batch dim!).
            raise NotImplementedError('Dynamic biases are not yet ' +
                                      'implemented!')
        assert self._reducer._context_mod_layers is None

        ### Take over all attributes from the underlying MLP.
        for s in self._reducer.param_shapes:
            self._param_shapes.append(s)
        if no_weights:
            for s in self._reducer._hyper_shapes_learned:
                self._hyper_shapes_learned.append(s)
        else:
            for p in self._reducer._weights:
                self._weights.append(p)

        for p in self._reducer._layer_weight_tensors:
            self._layer_weight_tensors.append(p)
        for p in self._reducer._layer_bias_vectors:
            self._layer_bias_vectors.append(p)
        if use_batch_norm:
            for p in self._reducer._batchnorm_layers:
                self._batchnorm_layers.append(p)
        if self._reducer._hyper_shapes_distilled is not None:
            self._hyper_shapes_distilled = []
            for s in self._reducer._hyper_shapes_distilled:
                self._hyper_shapes_distilled.append(s)

        ###############################
        ### Create Actual `Network` ###
        ###############################
        net_inp_dim = out_chunk_dim * self._num_cembs
        self._network = MLP(n_in=net_inp_dim, n_out=n_out,
            hidden_layers=net_layers, activation_fn=activation_fn,
            use_bias=use_bias, no_weights=no_weights, init_weights=None,
            dropout_rate=dropout_rate, use_spectral_norm=use_spectral_norm,
            use_batch_norm=use_batch_norm, bn_track_stats=bn_track_stats,
            distill_bn_stats=distill_bn_stats, use_context_mod=False,
            out_fn=None, verbose=True)

        ### Take over all attributes from the underlying MLP.
        for s in self._network.param_shapes:
            self._param_shapes.append(s)
        if no_weights:
            for s in self._network._hyper_shapes_learned:
                self._hyper_shapes_learned.append(s)
        else:
            for p in self._network._weights:
                self._weights.append(p)

        for p in self._network._layer_weight_tensors:
            self._layer_weight_tensors.append(p)
        for p in self._network._layer_bias_vectors:
            self._layer_bias_vectors.append(p)
        if use_batch_norm:
            for p in self._network._batchnorm_layers:
                self._batchnorm_layers.append(p)
        if self._hyper_shapes_distilled is not None:
            assert self._network._hyper_shapes_distilled is not None
            for s in self._network._hyper_shapes_distilled:
                self._hyper_shapes_distilled.append(s)

        #####################################
        ### Takeover given Initialization ###
        #####################################
        if init_weights is not None:
            assert len(init_weights) == len(self._weights)
            for i in range(len(init_weights)):
                assert np.all(np.equal(list(init_weights[i].shape),
                                       self._param_shapes[i]))
                self._weights[i].data = init_weights[i]

        ######################
        ### Finalize Setup ###
        ######################
        num_weights = MainNetInterface.shapes_to_num_weights( \
            self.param_shapes)

        print('Constructed MLP that processes dimensionality-reduced inputs ' +
              'through chunking. The network has a total of %d weights.'
              % num_weights)

        self._is_properly_setup()

    def distillation_targets(self):
        """Targets to be distilled after training.

        See docstring of abstract super method
        :meth:`mnets.mnet_interface.MainNetInterface.distillation_targets`.

        This method will return the distillation targets from the two
        underlying networks, see method
        :meth:`mnets.mlp.MLP.distillation_targets`.

        Returns:
            The target tensors corresponding to the shapes specified in
            attribute :attr:`hyper_shapes_distilled`.
        """
        if self.hyper_shapes_distilled is None:
            return None

        ret = self._reducer.distillation_targets() + \
            self._network.distillation_targets()

        return ret

    def forward(self, x, weights=None, distilled_params=None, condition=None):
        """Compute the output :math:`y` of this network given the input
        :math:`x`.

        Args:
            (....): See docstring of method
                :meth:`mnets.mnet_interface.MainNetInterface.forward`. We
                provide some more specific information below.
            distilled_params: Will be split and passed as distillation targets
                to the underlying instances of class :class:`mnets.mlp.MLP` if
                specified.
            condition (int or dict, optional): Will be passed to the
                underlying instances of class :class:`mnets.mlp.MLP`.

        Returns:
            The output :math:`y` of the network.
        """
        if self._no_weights and weights is None:
            raise Exception('Network was generated without weights. ' +
                            'Hence, "weights" option may not be None.')

        if weights is None:
            weights = self._weights
        else:
            assert len(weights) == len(self.param_shapes)
            for i, s in enumerate(self.param_shapes):
                assert np.all(np.equal(s, list(weights[i].shape)))

        #########################################
        ### Extract parameters from `weights` ###
        #########################################
        cembs = weights[0]
        w_ind = 1

        if self._has_dyn_bias is not None:
            w_ind_new = w_ind + len(self._dyn_bias_weights)
            dyn_bias_weights = weights[w_ind:w_ind_new]
            w_ind = w_ind_new

            # TODO use `dyn_bias_weights` to construct weights for context-mod
            # layers.
            raise NotImplementedError

        w_ind_new = w_ind + len(self._reducer.param_shapes)
        red_weights = weights[w_ind:w_ind_new]
        w_ind = w_ind_new

        w_ind_new = w_ind + len(self._network.param_shapes)
        net_weights = weights[w_ind:w_ind_new]
        w_ind = w_ind_new

        red_distilled_params = None
        net_distilled_params = None
        if distilled_params is not None:
            if self.hyper_shapes_distilled is None:
                raise ValueError('Argument "distilled_params" can only be ' +
                                 'provided if the return value of method ' +
                                 '"distillation_targets()" is not None.')
            assert len(distilled_params) == len(self.hyper_shapes_distilled)

            red_distilled_params = \
                distilled_params[:len(self._reducer.hyper_shapes_distilled)]
            net_distilled_params = \
                distilled_params[len(self._reducer.hyper_shapes_distilled):]

        ###########################
        ### Chunk network input ###
        ###########################
        assert x.shape[1] == self._n_in
        if self._pad != -1:
            x = F.pad(x, (0, self._pad))
        assert x.shape[1] % self._inp_chunk_dim == 0

        batch_size = x.shape[0]

        # We now split the input `x` into chunks and convert them into
        # separate samples, i.e., the `batch_size` will be multiplied by the
        # number of chunks.
        # So, we process a huge batch in parallel with a small network rather
        # than processing a huge input with a huge network.
        chunks = torch.split(x, self._inp_chunk_dim, dim=1)
        # Concatenate the chunks along the batch dimension.
        chunks = torch.cat(chunks, dim=0)

        if self._has_dyn_bias is not None:
            raise NotImplementedError()
        else:
            # Within a chunk the same chunk embedding is used.
            cembs = torch.split(cembs, 1, dim=0)
            cembs = [emb.expand(batch_size, -1) for emb in cembs]
            cembs = torch.cat(cembs, dim=0)

            chunks = torch.cat([chunks, cembs], dim=1)

        ###################################
        ### Reduce input dimensionality ###
        ###################################
        if self._has_dyn_bias is not None:
            # TODO pass context-mod weights to `reducer`.
            raise NotImplementedError()
        chunks = self._reducer.forward(chunks, weights=red_weights,
            distilled_params=red_distilled_params, condition=condition)

        ### Reformat `reducer` output into the input of the actual `network`.
        chunks = torch.split(chunks, batch_size, dim=0)
        net_input = torch.cat(chunks, dim=1)
        assert net_input.shape[0] == batch_size

        ##############################
        ### Compute network output ###
        ##############################
        return self._network.forward(net_input, weights=net_weights,
            distilled_params=net_distilled_params, condition=condition)
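
# Illustrative sketch only, not part of the class above: it walks through the
# chunking arithmetic of `ChunkSqueezer` for assumed dimensions. The helper
# name `_chunk_squeezer_usage_sketch` and all shapes are assumptions made for
# this example.
def _chunk_squeezer_usage_sketch():
    """Minimal usage sketch for :class:`ChunkSqueezer` (illustrative)."""
    # With `n_in=100` and `inp_chunk_dim=30` there are ceil(100/30) = 4
    # chunks, the last one zero-padded by 20 entries. Each chunk is squeezed
    # to `out_chunk_dim=10` values, so the actual `network` receives
    # 4 * 10 = 40 features per sample.
    net = ChunkSqueezer(n_in=100, n_out=10, inp_chunk_dim=30,
                        out_chunk_dim=10, cemb_size=8,
                        red_layers=(20, 20), net_layers=(20, 20))

    x = torch.rand(32, 100)  # Batch of 32 flat input vectors.
    return net.forward(x)    # Expected output shape: [32, 10].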