def __init__(self, num_classes, verbose):
    """Initialize the network.

    Args:
        num_classes: The number of output neurons.
        verbose: Allow printing of general information about the
            generated network (such as number of weights).
    """
    # FIXME find a way using super to handle multiple inheritance.
    #super(Classifier, self).__init__()
    nn.Module.__init__(self)
    MainNetInterface.__init__(self)

    assert num_classes > 0
    self._num_classes = num_classes
    self._verbose = verbose
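# Illustrative sketch (not part of the original code): a concrete classifier
# would forward `num_classes` and `verbose` to this initializer before
# building its own layers. The subclass name `MyClassifier` is made up for
# illustration only.
#
#     class MyClassifier(Classifier):
#         def __init__(self, num_classes=10, verbose=True):
#             Classifier.__init__(self, num_classes, verbose)
#             # ... create layers and fill `self._param_shapes`, etc. ...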
def __init__(self, mnet, no_mean_reinit=False, logvar_encoding=False,
             apply_rho_offset=False, is_radial=False):
    # FIXME find a way using super to handle multiple inheritance.
    #super(GaussianBNNWrapper, self).__init__()
    nn.Module.__init__(self)
    MainNetInterface.__init__(self)

    assert isinstance(mnet, MainNetInterface)
    assert not isinstance(mnet, GaussianBNNWrapper)

    if is_radial:
        print('Converting network into BNN with radial weight ' +
              'distribution ...')
    else:
        print('Converting network into BNN with diagonal Gaussian weight ' +
              'distribution ...')

    self._mnet = mnet
    self._logvar_encoding = logvar_encoding
    self._apply_rho_offset = apply_rho_offset
    self._rho_offset = -2.5
    self._is_radial = is_radial

    # Take over attributes of `mnet` and modify them if necessary.
    self._mean_params = None
    self._rho_params = None
    if mnet.internal_params is not None:
        self._mean_params = mnet.internal_params
        self._rho_params = nn.ParameterList()
        for p in self._mean_params:
            self._rho_params.append(
                nn.Parameter(torch.Tensor(p.size()), requires_grad=True))

        # Initialize weights.
        if not no_mean_reinit:
            for p in self._mean_params:
                p.data.uniform_(-0.1, 0.1)
        for p in self._rho_params:
            if apply_rho_offset:
                # We will subtract 2.5 from `rho` in the forward method.
                #p.data.uniform_(-0.5, 0.5)
                p.data.uniform_(-3 - self._rho_offset,
                                -2 - self._rho_offset)
            else:
                p.data.uniform_(-3, -2)

        self._internal_params = nn.ParameterList()
        for p in self._mean_params:
            self._internal_params.append(p)
        for p in self._rho_params:
            self._internal_params.append(p)

    # Simply duplicate `param_shapes` and `hyper_shapes_learned`.
    self._param_shapes = mnet.param_shapes + mnet.param_shapes
    if mnet._param_shapes_meta is not None:
        self._param_shapes_meta = []
        old_wlen = 0 if self.internal_params is None \
            else len(mnet.internal_params)

        for dd in mnet._param_shapes_meta:
            dd['dist_param'] = 'mean'
            self._param_shapes_meta.append(dd)
        for dd_old in mnet._param_shapes_meta:
            dd = dict(dd_old)
            dd['index'] += old_wlen
            dd['dist_param'] = 'rho'
            self._param_shapes_meta.append(dd)

    if mnet._hyper_shapes_learned is not None:
        self._hyper_shapes_learned = mnet._hyper_shapes_learned + \
            mnet._hyper_shapes_learned
    if mnet._hyper_shapes_learned_ref is not None:
        self._hyper_shapes_learned_ref = \
            list(mnet._hyper_shapes_learned_ref)
        old_plen = len(mnet.param_shapes)
        for ii in mnet._hyper_shapes_learned_ref:
            self._hyper_shapes_learned_ref.append(ii + old_plen)

    self._hyper_shapes_distilled = mnet._hyper_shapes_distilled
    if self._hyper_shapes_distilled is not None:
        # In general, that shouldn't be an issue, as those distilled values
        # are just things like batchnorm stats. But it might be good to
        # inform the user about the fact that we are not considering this
        # attribute as special.
        warn('Class "GaussianBNNWrapper" doesn\'t modify the existing ' +
             'attribute "hyper_shapes_distilled".')

    self._has_bias = mnet._has_bias
    self._has_fc_out = mnet._has_fc_out
    # Note, it's still true that the last two entries of
    # `hyper_shapes_learned` belong to the output layer. But those are only
    # the variance weights. So, we would forget about the mean weights when
    # setting this quantity to True.
    self._mask_fc_out = False #mnet._mask_fc_out
    self._has_linear_out = mnet._has_linear_out

    # We don't modify the following attributes, but generate warnings
    # when using them.
    self._layer_weight_tensors = mnet._layer_weight_tensors
    self._layer_bias_vectors = mnet._layer_bias_vectors
    self._batchnorm_layers = mnet._batchnorm_layers
    self._context_mod_layers = mnet._context_mod_layers

    self._is_properly_setup(check_has_bias=False)
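# Illustrative usage sketch (not part of the original code). It assumes an
# `MLP` main network like the one defined below is in scope; the keyword
# values are arbitrary examples.
#
#     mnet = MLP(n_in=10, n_out=2, hidden_layers=(32, 32), verbose=False)
#     bnn = GaussianBNNWrapper(mnet, apply_rho_offset=True)
#
#     # Every shape of `mnet` now appears twice: once for the mean
#     # parameters and once for the `rho` parameters that encode the
#     # variances (e.g., via a softplus mapping in the forward pass).
#     assert len(bnn.param_shapes) == 2 * len(mnet.param_shapes)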
def __init__(self, n_in=1, n_out=1, hidden_layers=[2000, 2000],
             activation_fn=torch.nn.ReLU(), use_bias=True, no_weights=False,
             init_weights=None, dropout_rate=-1, use_spectral_norm=False,
             use_batch_norm=False, bn_track_stats=True,
             distill_bn_stats=False, use_context_mod=False,
             context_mod_inputs=False, no_last_layer_context_mod=False,
             context_mod_no_weights=False, context_mod_post_activation=False,
             context_mod_gain_offset=False, out_fn=None, verbose=True):
    # FIXME find a way using super to handle multiple inheritance.
    #super(MainNetwork, self).__init__()
    nn.Module.__init__(self)
    MainNetInterface.__init__(self)

    if use_spectral_norm:
        raise NotImplementedError('Spectral normalization not yet ' +
                                  'implemented for this network.')
    if use_batch_norm and use_context_mod:
        # FIXME Does it make sense to have both enabled?
        # I.e., should we produce a warning or error?
        pass

    self._a_fun = activation_fn
    assert init_weights is None or \
        (not no_weights or not context_mod_no_weights)
    self._no_weights = no_weights
    self._dropout_rate = dropout_rate
    #self._use_spectral_norm = use_spectral_norm
    self._use_batch_norm = use_batch_norm
    self._bn_track_stats = bn_track_stats
    self._distill_bn_stats = distill_bn_stats and use_batch_norm
    self._use_context_mod = use_context_mod
    self._context_mod_inputs = context_mod_inputs
    self._no_last_layer_context_mod = no_last_layer_context_mod
    self._context_mod_no_weights = context_mod_no_weights
    self._context_mod_post_activation = context_mod_post_activation
    self._context_mod_gain_offset = context_mod_gain_offset
    self._out_fn = out_fn

    self._has_bias = use_bias
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer.
    self._mask_fc_out = True
    self._has_linear_out = True if out_fn is None else False

    if use_spectral_norm and no_weights:
        raise ValueError('Cannot use spectral norm in a network without ' +
                         'parameters.')

    # FIXME make sure that this implementation is correct in all situations
    # (e.g., what to do if weights are passed to the forward method?).
    if use_spectral_norm:
        self._spec_norm = nn.utils.spectral_norm
    else:
        self._spec_norm = lambda x: x  # identity

    self._param_shapes = []
    self._weights = None if no_weights and context_mod_no_weights \
        else nn.ParameterList()
    self._hyper_shapes_learned = None \
        if not no_weights and not context_mod_no_weights else []

    if dropout_rate != -1:
        assert dropout_rate >= 0. and dropout_rate <= 1.
        self._dropout = nn.Dropout(p=dropout_rate)

    ### Define and initialize context mod weights.
    self._context_mod_layers = nn.ModuleList() if use_context_mod else None
    self._context_mod_shapes = [] if use_context_mod else None

    if use_context_mod:
        cm_ind = 0
        cm_sizes = []
        if context_mod_inputs:
            cm_sizes.append(n_in)
        cm_sizes.extend(hidden_layers)
        if not no_last_layer_context_mod:
            cm_sizes.append(n_out)

        for i, n in enumerate(cm_sizes):
            cmod_layer = ContextModLayer(n,
                no_weights=context_mod_no_weights,
                apply_gain_offset=context_mod_gain_offset)
            self._context_mod_layers.append(cmod_layer)

            self.param_shapes.extend(cmod_layer.param_shapes)
            self._context_mod_shapes.extend(cmod_layer.param_shapes)
            if context_mod_no_weights:
                self._hyper_shapes_learned.extend(cmod_layer.param_shapes)
            else:
                self._weights.extend(cmod_layer.weights)

            # FIXME ugly code. Move initialization somewhere else.
            if not context_mod_no_weights and init_weights is not None:
                assert len(cmod_layer.weights) == 2
                for ii in range(2):
                    assert np.all(np.equal(
                        list(init_weights[cm_ind].shape),
                        list(cmod_layer.weights[ii].shape)))
                    cmod_layer.weights[ii].data = init_weights[cm_ind]
                    cm_ind += 1

        if init_weights is not None:
            init_weights = init_weights[cm_ind:]

    ### Define and initialize batch norm weights.
    self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

    if use_batch_norm:
        if distill_bn_stats:
            self._hyper_shapes_distilled = []

        bn_ind = 0
        for i, n in enumerate(hidden_layers):
            bn_layer = BatchNormLayer(n, affine=not no_weights,
                                      track_running_stats=bn_track_stats)
            self._batchnorm_layers.append(bn_layer)

            self._param_shapes.extend(bn_layer.param_shapes)
            if no_weights:
                self._hyper_shapes_learned.extend(bn_layer.param_shapes)
            else:
                self._weights.extend(bn_layer.weights)

            if distill_bn_stats:
                self._hyper_shapes_distilled.extend( \
                    [list(p.shape) for p in bn_layer.get_stats(0)])

            # FIXME ugly code. Move initialization somewhere else.
            if not no_weights and init_weights is not None:
                assert len(bn_layer.weights) == 2
                for ii in range(2):
                    assert np.all(np.equal(
                        list(init_weights[bn_ind].shape),
                        list(bn_layer.weights[ii].shape)))
                    bn_layer.weights[ii].data = init_weights[bn_ind]
                    bn_ind += 1

        if init_weights is not None:
            init_weights = init_weights[bn_ind:]

    # Compute shapes of linear layers.
    linear_shapes = MLP.weight_shapes(n_in=n_in, n_out=n_out,
                                      hidden_layers=hidden_layers,
                                      use_bias=use_bias)
    self._param_shapes.extend(linear_shapes)

    num_weights = MainNetInterface.shapes_to_num_weights(self._param_shapes)

    if verbose:
        if use_context_mod:
            cm_num_weights = 0
            for cm_layer in self._context_mod_layers:
                cm_num_weights += MainNetInterface.shapes_to_num_weights( \
                    cm_layer.param_shapes)

        print('Creating an MLP with %d weights' % num_weights
              + (' (including %d weights associated with ' % cm_num_weights
                 + 'context modulation)' if use_context_mod else '')
              + '.'
              + (' The network uses dropout.' if dropout_rate != -1 else '')
              + (' The network uses batchnorm.' if use_batch_norm else ''))

    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    if no_weights:
        self._hyper_shapes_learned.extend(linear_shapes)
        self._is_properly_setup()
        return

    ### Define and initialize linear weights.
    for i, dims in enumerate(linear_shapes):
        self._weights.append(nn.Parameter(torch.Tensor(*dims),
                                          requires_grad=True))
        if len(dims) == 1:
            self._layer_bias_vectors.append(self._weights[-1])
        else:
            self._layer_weight_tensors.append(self._weights[-1])

    if init_weights is not None:
        assert len(init_weights) == len(linear_shapes)
        for i in range(len(init_weights)):
            assert np.all(np.equal(list(init_weights[i].shape),
                                   linear_shapes[i]))
            if use_bias:
                if i % 2 == 0:
                    self._layer_weight_tensors[i // 2].data = init_weights[i]
                else:
                    self._layer_bias_vectors[i // 2].data = init_weights[i]
            else:
                self._layer_weight_tensors[i].data = init_weights[i]
    else:
        for i in range(len(self._layer_weight_tensors)):
            if use_bias:
                init_params(self._layer_weight_tensors[i],
                            self._layer_bias_vectors[i])
            else:
                init_params(self._layer_weight_tensors[i])

    self._is_properly_setup()
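# Illustrative usage sketch (not part of the original code); the keyword
# values below are arbitrary examples.
#
#     # Standalone MLP that maintains its own weights:
#     net = MLP(n_in=5, n_out=2, hidden_layers=[100, 100], verbose=False)
#
#     # MLP without internal weights, e.g., to be driven by a hypernetwork
#     # that produces tensors matching the shapes collected in
#     # `_hyper_shapes_learned`:
#     net = MLP(n_in=5, n_out=2, hidden_layers=[100, 100], no_weights=True)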
def __init__(self, rnn_args={}, mlp_args=None, preprocess_fct=None,
             no_weights=False, verbose=True):
    # FIXME find a way using super to handle multiple inheritance.
    nn.Module.__init__(self)
    MainNetInterface.__init__(self)

    assert isinstance(rnn_args, (dict, list, tuple))
    assert mlp_args is None or isinstance(mlp_args, dict)

    if isinstance(rnn_args, dict):
        rnn_args = [rnn_args]

    self._forward_rnns = []
    self._backward_rnns = []
    self._out_mlp = None

    self._preprocess_fct = preprocess_fct
    self._forward_called = False

    # FIXME At the moment we do not control input and output size of
    # individual networks and need to assume that the user sets them
    # correctly.

    ### Create all forward and backward nets for each bidirectional layer.
    for rargs in rnn_args:
        assert isinstance(rargs, dict)
        if 'verbose' not in rargs.keys():
            rargs['verbose'] = False
        if 'no_weights' in rargs.keys() and \
                rargs['no_weights'] != no_weights:
            raise ValueError('Keyword argument "no_weights" of ' +
                             'bidirectional layer is in conflict with ' +
                             'constructor argument "no_weights".')
        elif 'no_weights' not in rargs.keys():
            rargs['no_weights'] = no_weights

        self._forward_rnns.append(SimpleRNN(**rargs))
        self._backward_rnns.append(SimpleRNN(**rargs))

    ### Create output network.
    if mlp_args is not None:
        if 'verbose' not in mlp_args.keys():
            mlp_args['verbose'] = False
        if 'no_weights' in mlp_args.keys() and \
                mlp_args['no_weights'] != no_weights:
            raise ValueError('Keyword argument "no_weights" of ' +
                             'output MLP is in conflict with ' +
                             'constructor argument "no_weights".')
        elif 'no_weights' not in mlp_args.keys():
            mlp_args['no_weights'] = no_weights

        self._out_mlp = MLP(**mlp_args)

    ### Set all interface attributes correctly.
    if self._out_mlp is None:
        self._has_fc_out = self._forward_rnns[-1].has_fc_out
        # We can't set the following attribute to True, as the output is
        # a concatenation of the outputs from two networks. Therefore, the
        # weights used to compute the outputs are at different locations
        # in the `param_shapes` list.
        self._mask_fc_out = False
        self._has_linear_out = self._forward_rnns[-1].has_linear_out
    else:
        self._has_fc_out = self._out_mlp.has_fc_out
        self._mask_fc_out = self._out_mlp.mask_fc_out
        self._has_linear_out = self._out_mlp.has_linear_out

    # Collect all internal net objects from which we need to collect
    # attributes.
    nets = []
    for i, fnet in enumerate(self._forward_rnns):
        bnet = self._backward_rnns[i]
        nets.append((fnet, 'forward_rnn', i))
        nets.append((bnet, 'backward_rnn', i))
    if self._out_mlp is not None:
        nets.append((self._out_mlp, 'out_mlp', -1))

    # Iterate over all nets to collect their attribute values.
    self._param_shapes = []
    self._param_shapes_meta = []
    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    for i, net_tup in enumerate(nets):
        net, net_type, net_id = net_tup
        # Note, it is important to convert lists into new objects and not
        # just copy references!
        # Note, we have to adapt all references if `i > 0`.
        # Sanity check:
        if i == 0:
            cm_nw = net._context_mod_no_weights
        elif cm_nw != net._context_mod_no_weights:
            raise ValueError('Network expects that either all internal ' +
                             'networks maintain their context-mod ' +
                             'weights or none of them does!')

        ps_len_old = len(self._param_shapes)

        if net._internal_params is not None:
            if self._internal_params is None:
                self._internal_params = nn.ParameterList()
            ip_len_old = len(self._internal_params)
            self._internal_params.extend( \
                nn.ParameterList(net._internal_params))

        self._param_shapes.extend(list(net._param_shapes))
        for meta in net.param_shapes_meta:
            assert 'birnn_layer_type' not in meta.keys()
            assert 'birnn_layer_id' not in meta.keys()
            new_meta = dict(meta)
            new_meta['birnn_layer_type'] = net_type
            new_meta['birnn_layer_id'] = net_id
            if i > 0:
                # FIXME We should properly adjust colliding `layer` IDs.
                new_meta['layer'] = -1
                new_meta['index'] = meta['index'] + ip_len_old
            self._param_shapes_meta.append(new_meta)

        if net._hyper_shapes_learned is not None:
            if self._hyper_shapes_learned is None:
                self._hyper_shapes_learned = []
                self._hyper_shapes_learned_ref = []
            self._hyper_shapes_learned.extend( \
                list(net._hyper_shapes_learned))
            for ref in net._hyper_shapes_learned_ref:
                self._hyper_shapes_learned_ref.append(ref + ps_len_old)

        if net._hyper_shapes_distilled is not None:
            if self._hyper_shapes_distilled is None:
                self._hyper_shapes_distilled = []
            self._hyper_shapes_distilled.extend( \
                list(net._hyper_shapes_distilled))

        if self._has_bias is None:
            self._has_bias = net._has_bias
        elif self._has_bias != net._has_bias:
            self._has_bias = False
            # FIXME We should overwrite the getter and throw an error!
            warn('Some internally maintained networks use biases, ' +
                 'while others don\'t. Setting attribute "has_bias" to ' +
                 'False.')

        self._layer_weight_tensors.extend( \
            nn.ParameterList(net._layer_weight_tensors))
        self._layer_bias_vectors.extend( \
            nn.ParameterList(net._layer_bias_vectors))
        if net._batchnorm_layers is not None:
            if self._batchnorm_layers is None:
                self._batchnorm_layers = nn.ModuleList()
            self._batchnorm_layers.extend( \
                nn.ModuleList(net._batchnorm_layers))
        if net._context_mod_layers is not None:
            if self._context_mod_layers is None:
                self._context_mod_layers = nn.ModuleList()
            self._context_mod_layers.extend( \
                nn.ModuleList(net._context_mod_layers))

    self._is_properly_setup()

    ### Print user information.
    if verbose:
        print('Constructed Bidirectional RNN with %d weights.' \
              % self.num_params)
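# Illustrative usage sketch (not part of the original code). The class name
# `BiRNN` and all keyword values below are assumptions; `rnn_args` is simply
# forwarded as `SimpleRNN(**rnn_args)` for each forward/backward net and
# `mlp_args` as `MLP(**mlp_args)` for the output head.
#
#     net = BiRNN(rnn_args={'n_in': 16},   # plus further SimpleRNN kwargs
#                 # The MLP input must match the concatenated
#                 # forward + backward output size of the last layer.
#                 mlp_args={'n_in': 64, 'n_out': 5, 'hidden_layers': (32,)},
#                 no_weights=False)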
def __init__(self, n_in=1, n_out=1, hidden_layers=(10, 10),
             activation_fn=torch.nn.ReLU(), use_bias=True, no_weights=False,
             init_weights=None, dropout_rate=-1, use_spectral_norm=False,
             use_batch_norm=False, bn_track_stats=True,
             distill_bn_stats=False, use_context_mod=False,
             context_mod_inputs=False, no_last_layer_context_mod=False,
             context_mod_no_weights=False, context_mod_post_activation=False,
             context_mod_gain_offset=False, context_mod_gain_softplus=False,
             out_fn=None, verbose=True):
    # FIXME find a way using super to handle multiple inheritance.
    nn.Module.__init__(self)
    MainNetInterface.__init__(self)

    # FIXME Spectral norm is incorrectly implemented. Function
    # `nn.utils.spectral_norm` needs to be called in the constructor, such
    # that spec norm is wrapped around a module.
    if use_spectral_norm:
        raise NotImplementedError('Spectral normalization not yet ' +
                                  'implemented for this network.')

    if use_batch_norm and use_context_mod:
        # FIXME Does it make sense to have both enabled?
        # I.e., should we produce a warning or error?
        pass

    self._a_fun = activation_fn
    assert init_weights is None or \
        (not no_weights or not context_mod_no_weights)
    self._no_weights = no_weights
    self._dropout_rate = dropout_rate
    #self._use_spectral_norm = use_spectral_norm
    self._use_batch_norm = use_batch_norm
    self._bn_track_stats = bn_track_stats
    self._distill_bn_stats = distill_bn_stats and use_batch_norm
    self._use_context_mod = use_context_mod
    self._context_mod_inputs = context_mod_inputs
    self._no_last_layer_context_mod = no_last_layer_context_mod
    self._context_mod_no_weights = context_mod_no_weights
    self._context_mod_post_activation = context_mod_post_activation
    self._context_mod_gain_offset = context_mod_gain_offset
    self._context_mod_gain_softplus = context_mod_gain_softplus
    self._out_fn = out_fn

    self._has_bias = use_bias
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer.
    self._mask_fc_out = True
    self._has_linear_out = True if out_fn is None else False

    if use_spectral_norm and no_weights:
        raise ValueError('Cannot use spectral norm in a network without ' +
                         'parameters.')

    # FIXME make sure that this implementation is correct in all situations
    # (e.g., what to do if weights are passed to the forward method?).
    if use_spectral_norm:
        self._spec_norm = nn.utils.spectral_norm
    else:
        self._spec_norm = lambda x: x  # identity

    self._param_shapes = []
    self._param_shapes_meta = []
    self._weights = None if no_weights and context_mod_no_weights \
        else nn.ParameterList()
    self._hyper_shapes_learned = None \
        if not no_weights and not context_mod_no_weights else []
    self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
        is None else []

    if dropout_rate != -1:
        assert dropout_rate >= 0. and dropout_rate <= 1.
        self._dropout = nn.Dropout(p=dropout_rate)

    ### Define and initialize context mod weights.
    self._context_mod_layers = nn.ModuleList() if use_context_mod else None
    self._context_mod_shapes = [] if use_context_mod else None

    if use_context_mod:
        cm_ind = 0
        cm_sizes = []
        if context_mod_inputs:
            cm_sizes.append(n_in)
        cm_sizes.extend(hidden_layers)
        if not no_last_layer_context_mod:
            cm_sizes.append(n_out)

        for i, n in enumerate(cm_sizes):
            cmod_layer = ContextModLayer(n,
                no_weights=context_mod_no_weights,
                apply_gain_offset=context_mod_gain_offset,
                apply_gain_softplus=context_mod_gain_softplus)
            self._context_mod_layers.append(cmod_layer)

            self.param_shapes.extend(cmod_layer.param_shapes)
            assert len(cmod_layer.param_shapes) == 2
            self._param_shapes_meta.extend([
                {'name': 'cm_scale',
                 'index': -1 if context_mod_no_weights else \
                     len(self._weights),
                 'layer': -1},  # 'layer' is set later.
                {'name': 'cm_shift',
                 'index': -1 if context_mod_no_weights else \
                     len(self._weights) + 1,
                 'layer': -1},  # 'layer' is set later.
            ])

            self._context_mod_shapes.extend(cmod_layer.param_shapes)
            if context_mod_no_weights:
                self._hyper_shapes_learned.extend(cmod_layer.param_shapes)
            else:
                self._weights.extend(cmod_layer.weights)

            # FIXME ugly code. Move initialization somewhere else.
            if not context_mod_no_weights and init_weights is not None:
                assert len(cmod_layer.weights) == 2
                for ii in range(2):
                    assert np.all(np.equal(
                        list(init_weights[cm_ind].shape),
                        list(cmod_layer.weights[ii].shape)))
                    cmod_layer.weights[ii].data = init_weights[cm_ind]
                    cm_ind += 1

        if init_weights is not None:
            init_weights = init_weights[cm_ind:]

        if context_mod_no_weights:
            self._hyper_shapes_learned_ref = \
                list(range(len(self._param_shapes)))

    ### Define and initialize batch norm weights.
    self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

    if use_batch_norm:
        if distill_bn_stats:
            self._hyper_shapes_distilled = []

        bn_ind = 0
        for i, n in enumerate(hidden_layers):
            bn_layer = BatchNormLayer(n, affine=not no_weights,
                                      track_running_stats=bn_track_stats)
            self._batchnorm_layers.append(bn_layer)

            self._param_shapes.extend(bn_layer.param_shapes)
            assert len(bn_layer.param_shapes) == 2
            self._param_shapes_meta.extend([
                {'name': 'bn_scale',
                 'index': -1 if no_weights else len(self._weights),
                 'layer': -1},  # 'layer' is set later.
                {'name': 'bn_shift',
                 'index': -1 if no_weights else len(self._weights) + 1,
                 'layer': -1},  # 'layer' is set later.
            ])

            if no_weights:
                self._hyper_shapes_learned.extend(bn_layer.param_shapes)
            else:
                self._weights.extend(bn_layer.weights)

            if distill_bn_stats:
                self._hyper_shapes_distilled.extend( \
                    [list(p.shape) for p in bn_layer.get_stats(0)])

            # FIXME ugly code. Move initialization somewhere else.
            if not no_weights and init_weights is not None:
                assert len(bn_layer.weights) == 2
                for ii in range(2):
                    assert np.all(np.equal(
                        list(init_weights[bn_ind].shape),
                        list(bn_layer.weights[ii].shape)))
                    bn_layer.weights[ii].data = init_weights[bn_ind]
                    bn_ind += 1

        if init_weights is not None:
            init_weights = init_weights[bn_ind:]

    ### Compute shapes of linear layers.
    linear_shapes = MLP.weight_shapes(n_in=n_in, n_out=n_out,
                                      hidden_layers=hidden_layers,
                                      use_bias=use_bias)
    self._param_shapes.extend(linear_shapes)
    for i, s in enumerate(linear_shapes):
        self._param_shapes_meta.append({
            'name': 'weight' if len(s) != 1 else 'bias',
            'index': -1 if no_weights else len(self._weights) + i,
            'layer': -1  # 'layer' is set later.
        })

    num_weights = MainNetInterface.shapes_to_num_weights(self._param_shapes)

    ### Set missing meta information of param_shapes.
    offset = 1 if use_context_mod and context_mod_inputs else 0
    shift = 1
    if use_batch_norm:
        shift += 1
    if use_context_mod:
        shift += 1

    cm_offset = 2 if context_mod_post_activation else 1
    bn_offset = 1 if context_mod_post_activation else 2

    cm_ind = 0
    bn_ind = 0
    layer_ind = 0

    for i, dd in enumerate(self._param_shapes_meta):
        if dd['name'].startswith('cm'):
            if offset == 1 and i in [0, 1]:
                dd['layer'] = 0
            else:
                if cm_ind < len(hidden_layers):
                    dd['layer'] = offset + cm_ind * shift + cm_offset
                else:
                    assert cm_ind == len(hidden_layers) and \
                        not no_last_layer_context_mod
                    # No batchnorm in output layer.
                    dd['layer'] = offset + cm_ind * shift + 1
                if dd['name'] == 'cm_shift':
                    cm_ind += 1
        elif dd['name'].startswith('bn'):
            dd['layer'] = offset + bn_ind * shift + bn_offset
            if dd['name'] == 'bn_shift':
                bn_ind += 1
        else:
            dd['layer'] = offset + layer_ind * shift
            if not use_bias or dd['name'] == 'bias':
                layer_ind += 1

    ### User information.
    if verbose:
        if use_context_mod:
            cm_num_weights = 0
            for cm_layer in self._context_mod_layers:
                cm_num_weights += MainNetInterface.shapes_to_num_weights( \
                    cm_layer.param_shapes)

        print('Creating an MLP with %d weights' % num_weights
              + (' (including %d weights associated with ' % cm_num_weights
                 + 'context modulation)' if use_context_mod else '')
              + '.'
              + (' The network uses dropout.' if dropout_rate != -1 else '')
              + (' The network uses batchnorm.' if use_batch_norm else ''))

    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    if no_weights:
        self._hyper_shapes_learned.extend(linear_shapes)

        if use_context_mod:
            if context_mod_no_weights:
                self._hyper_shapes_learned_ref = \
                    list(range(len(self._param_shapes)))
            else:
                ncm = len(self._context_mod_shapes)
                self._hyper_shapes_learned_ref = \
                    list(range(ncm, len(self._param_shapes)))

        self._is_properly_setup()
        return

    ### Define and initialize linear weights.
    for i, dims in enumerate(linear_shapes):
        self._weights.append(nn.Parameter(torch.Tensor(*dims),
                                          requires_grad=True))
        if len(dims) == 1:
            self._layer_bias_vectors.append(self._weights[-1])
        else:
            self._layer_weight_tensors.append(self._weights[-1])

    if init_weights is not None:
        assert len(init_weights) == len(linear_shapes)
        for i in range(len(init_weights)):
            assert np.all(np.equal(list(init_weights[i].shape),
                                   linear_shapes[i]))
            if use_bias:
                if i % 2 == 0:
                    self._layer_weight_tensors[i // 2].data = init_weights[i]
                else:
                    self._layer_bias_vectors[i // 2].data = init_weights[i]
            else:
                self._layer_weight_tensors[i].data = init_weights[i]
    else:
        for i in range(len(self._layer_weight_tensors)):
            if use_bias:
                init_params(self._layer_weight_tensors[i],
                            self._layer_bias_vectors[i])
            else:
                init_params(self._layer_weight_tensors[i])

    if self._num_context_mod_shapes() == 0:
        # Note, that might be the case if no hidden layers exist and no
        # input or output modulation is used.
        self._use_context_mod = False

    self._is_properly_setup()
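# Illustrative usage sketch (not part of the original code); keyword values
# are arbitrary examples.
#
#     # MLP whose per-layer activations can be modulated externally: the
#     # context-mod gains/shifts are not maintained internally but are
#     # expected to be passed to `forward`.
#     net = MLP(n_in=10, n_out=2, hidden_layers=(32, 32),
#               use_context_mod=True, context_mod_no_weights=True,
#               verbose=False)
#
#     # `param_shapes_meta` now labels each entry of `param_shapes`
#     # ('cm_scale', 'cm_shift', 'weight', 'bias', ...) together with the
#     # layer index it belongs to.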
def __init__(self, n_in, n_out=1, inp_chunk_dim=100, out_chunk_dim=10,
             cemb_size=8, cemb_init_std=1., red_layers=(10, 10),
             net_layers=(10, 10), activation_fn=torch.nn.ReLU(),
             use_bias=True, dynamic_biases=None, no_weights=False,
             init_weights=None, dropout_rate=-1, use_spectral_norm=False,
             use_batch_norm=False, bn_track_stats=True,
             distill_bn_stats=False, verbose=True):
    # FIXME find a way using super to handle multiple inheritance.
    nn.Module.__init__(self)
    MainNetInterface.__init__(self)

    self._n_in = n_in
    self._n_out = n_out
    self._inp_chunk_dim = inp_chunk_dim
    self._out_chunk_dim = out_chunk_dim
    self._cemb_size = cemb_size
    self._a_fun = activation_fn
    self._no_weights = no_weights

    self._has_bias = use_bias
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer.
    self._mask_fc_out = True
    self._has_linear_out = True  # Ensure that `out_fn` is `None`!

    self._param_shapes = []
    #self._param_shapes_meta = []  # TODO implement!
    self._weights = None if no_weights else nn.ParameterList()
    self._hyper_shapes_learned = None if not no_weights else []
    #self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
    #    is None else []  # TODO implement.
    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()
    self._context_mod_layers = None
    self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

    #################################
    ### Generate Chunk Embeddings ###
    #################################

    self._num_cembs = int(np.ceil(n_in / inp_chunk_dim))
    last_chunk_size = n_in % inp_chunk_dim
    if last_chunk_size != 0:
        self._pad = inp_chunk_dim - last_chunk_size
    else:
        self._pad = -1

    cemb_shape = [self._num_cembs, cemb_size]
    self._param_shapes.append(cemb_shape)
    if no_weights:
        self._cembs = None
        self._hyper_shapes_learned.append(cemb_shape)
    else:
        self._cembs = nn.Parameter(data=torch.Tensor(*cemb_shape),
                                   requires_grad=True)
        nn.init.normal_(self._cembs, mean=0., std=cemb_init_std)

        self._weights.append(self._cembs)

    ############################
    ### Setup Dynamic Biases ###
    ############################

    self._has_dyn_bias = None
    if dynamic_biases is not None:
        assert np.all(np.array(dynamic_biases) >= 0) and \
            np.all(np.array(dynamic_biases) < len(red_layers) + 1)
        dynamic_biases = np.sort(np.unique(dynamic_biases))

        # For each layer in the `reducer`, where we want to apply a dynamic
        # bias, we have to create a weight matrix for a corresponding
        # linear layer (we just ignore).
        self._dyn_bias_weights = nn.ModuleList()
        self._has_dyn_bias = []
        for i in range(len(red_layers) + 1):
            if i in dynamic_biases:
                self._has_dyn_bias.append(True)

                trgt_dim = out_chunk_dim
                if i < len(red_layers):
                    trgt_dim = red_layers[i]
                trgt_shape = [trgt_dim, cemb_size]
                self._param_shapes.append(trgt_shape)

                if no_weights:
                    self._dyn_bias_weights.append(None)
                    self._hyper_shapes_learned.append(trgt_shape)
                else:
                    self._dyn_bias_weights.append(nn.Parameter( \
                        torch.Tensor(*trgt_shape), requires_grad=True))
                    self._weights.append(self._dyn_bias_weights[-1])

                    init_params(self._dyn_bias_weights[-1])
                    self._layer_weight_tensors.append( \
                        self._dyn_bias_weights[-1])
                    self._layer_bias_vectors.append(None)
            else:
                self._has_dyn_bias.append(False)
                self._dyn_bias_weights.append(None)

    ################################
    ### Create `Reducer` Network ###
    ################################

    red_inp_dim = inp_chunk_dim + \
        (cemb_size if dynamic_biases is None else 0)
    self._reducer = MLP(n_in=red_inp_dim, n_out=out_chunk_dim,
        hidden_layers=red_layers, activation_fn=activation_fn,
        use_bias=use_bias, no_weights=no_weights, init_weights=None,
        dropout_rate=dropout_rate, use_spectral_norm=use_spectral_norm,
        use_batch_norm=use_batch_norm, bn_track_stats=bn_track_stats,
        distill_bn_stats=distill_bn_stats,
        # We use context modulation to realize dynamic biases, since they
        # allow a different modulation per sample in the input mini-batch.
        # Hence, we can process several chunks in parallel with the reducer
        # network.
        use_context_mod=dynamic_biases is not None,
        context_mod_inputs=False,
        no_last_layer_context_mod=False,
        context_mod_no_weights=True,
        context_mod_post_activation=False,
        context_mod_gain_offset=False,
        context_mod_gain_softplus=False,
        out_fn=None, verbose=True)

    if dynamic_biases is not None:
        # FIXME We have to extract the param shapes from
        # `self._reducer.param_shapes`, as well as from
        # `self._reducer._hyper_shapes_learned`, that belong to context-mod
        # layers. We may not add them to our own `param_shapes` attribute,
        # as these are not parameters (due to our misuse of the context-mod
        # layers).
        # Note, in the `forward` method, we need to supply context-mod
        # weights for all reducer networks, independent of whether they
        # have a dynamic bias or not. We can do so by providing constant
        # ones for all gains and constant zero-shifts for all layers
        # without dynamic biases (note, we need to ensure the correct
        # batch dim!).
        raise NotImplementedError('Dynamic biases are not yet implemented!')

    assert self._reducer._context_mod_layers is None

    ### Overtake all attributes from the underlying MLP.
    for s in self._reducer.param_shapes:
        self._param_shapes.append(s)
    if no_weights:
        for s in self._reducer._hyper_shapes_learned:
            self._hyper_shapes_learned.append(s)
    else:
        for p in self._reducer._weights:
            self._weights.append(p)

    for p in self._reducer._layer_weight_tensors:
        self._layer_weight_tensors.append(p)
    for p in self._reducer._layer_bias_vectors:
        self._layer_bias_vectors.append(p)

    if use_batch_norm:
        for p in self._reducer._batchnorm_layers:
            self._batchnorm_layers.append(p)

    if self._reducer._hyper_shapes_distilled is not None:
        self._hyper_shapes_distilled = []
        for s in self._reducer._hyper_shapes_distilled:
            self._hyper_shapes_distilled.append(s)

    ###############################
    ### Create Actual `Network` ###
    ###############################

    net_inp_dim = out_chunk_dim * self._num_cembs
    self._network = MLP(n_in=net_inp_dim, n_out=n_out,
        hidden_layers=net_layers, activation_fn=activation_fn,
        use_bias=use_bias, no_weights=no_weights, init_weights=None,
        dropout_rate=dropout_rate, use_spectral_norm=use_spectral_norm,
        use_batch_norm=use_batch_norm, bn_track_stats=bn_track_stats,
        distill_bn_stats=distill_bn_stats, use_context_mod=False,
        out_fn=None, verbose=True)

    ### Overtake all attributes from the underlying MLP.
    for s in self._network.param_shapes:
        self._param_shapes.append(s)
    if no_weights:
        for s in self._network._hyper_shapes_learned:
            self._hyper_shapes_learned.append(s)
    else:
        for p in self._network._weights:
            self._weights.append(p)

    for p in self._network._layer_weight_tensors:
        self._layer_weight_tensors.append(p)
    for p in self._network._layer_bias_vectors:
        self._layer_bias_vectors.append(p)

    if use_batch_norm:
        for p in self._network._batchnorm_layers:
            self._batchnorm_layers.append(p)

    if self._hyper_shapes_distilled is not None:
        assert self._network._hyper_shapes_distilled is not None
        for s in self._network._hyper_shapes_distilled:
            self._hyper_shapes_distilled.append(s)

    #####################################
    ### Takeover given Initialization ###
    #####################################

    if init_weights is not None:
        assert len(init_weights) == len(self._weights)
        for i in range(len(init_weights)):
            assert np.all(np.equal(list(init_weights[i].shape),
                                   self._param_shapes[i]))
            self._weights[i].data = init_weights[i]

    ######################
    ### Finalize Setup ###
    ######################

    num_weights = MainNetInterface.shapes_to_num_weights(self.param_shapes)
    print('Constructed MLP that processes dimensionality reduced inputs ' +
          'through chunking. The network has a total of %d weights.'
          % num_weights)

    self._is_properly_setup()
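# Illustrative usage sketch (not part of the original code). The class name
# (`ChunkSqueezer` is assumed here) and the keyword values are assumptions.
#
#     # A 1000-dimensional input is split into ceil(1000 / 100) = 10 chunks.
#     # Each chunk (plus its chunk embedding) is compressed to 10 dimensions
#     # by the `reducer` MLP, and the concatenated 10 * 10 = 100-dimensional
#     # code is processed by the final `network` MLP.
#     net = ChunkSqueezer(n_in=1000, n_out=3, inp_chunk_dim=100,
#                         out_chunk_dim=10, red_layers=(50, 50),
#                         net_layers=(50, 50))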