def __str__(self):
    """Return a human-readable summary of the hypernetwork's size.

    The summary reports the total number of weights, the number of outputs,
    the ratio between the two, and how the weights split into unconditional
    and conditional parameters (including how many of each are maintained
    internally by this object).

    Returns:
        (str): The formatted summary.
    """
    num_uncond = MainNetInterface.shapes_to_num_weights(
        self.unconditional_param_shapes)
    num_cond = MainNetInterface.shapes_to_num_weights(
        self.conditional_param_shapes)

    # Internally maintained parameters may be ``None``; they then contribute
    # zero weights to the "internally maintained" counts.
    num_uncond_internal = 0
    if self.unconditional_params is not None:
        num_uncond_internal = MainNetInterface.shapes_to_num_weights(
            [p.shape for p in self.unconditional_params])

    num_cond_internal = 0
    # The guard has to look at the conditional parameters themselves;
    # checking the unconditional ones here would either iterate over ``None``
    # or silently report zero conditional weights.
    if self.conditional_params is not None:
        num_cond_internal = MainNetInterface.shapes_to_num_weights(
            [p.shape for p in self.conditional_params])

    msg = 'Hypernetwork with %d weights and %d outputs (compression ' + \
        'ratio: %.2f).\nThe network consists of %d unconditional ' + \
        'weights (%d internally maintained) and %d conditional '+ \
        'weights (%d internally maintained).'
    return msg % (self.num_params, self.num_outputs,
                  self.num_params / self.num_outputs,
                  num_uncond, num_uncond_internal,
                  num_cond, num_cond_internal)
def generate_classifier(config, data_handlers, device):
    """Create a classifier network and, optionally, its hypernetwork.

    Depending on the experiment and method, either a classifier for task
    inference or a classifier that solves the actual task is built. The
    configuration also decides whether the classifier owns its weights or
    receives them from a hypernetwork.

    The following settings are derived and written back to ``config``:

    * ``input_dim`` and ``out_dim`` of the classifier,
    * ``num_weights_class_net``,
    * ``num_weights_class_hyper_net`` and ``compression_ratio_class``
      (both ``None`` if no hypernetwork is used).

    See :class:`mnets.mlp.MLP` for details on the network that is created as
    classifier.

    .. note::
        This function also initialises the weights of either the classifier
        or its hypernetwork (biases to zero, all other tensors Xavier
        uniform; task and chunk embeddings are drawn from normal
        distributions).

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. The first
            handler is used to read the input shape.
        device: Torch device.

    Returns:
        If ``config.training_with_hnet`` is set, a tuple containing:

        - **net**: The classifier network.
        - **class_hnet**: The classifier's hypernetwork.

        Otherwise only **net** is returned.
    """
    side = data_handlers[0].in_shape[0]
    pad_total = config.padding * 2
    is_split_mnist = config.experiment == "splitMNIST"

    # Inputs are flattened square images; only permutedMNIST is padded.
    edge = side if is_split_mnist else side + pad_total
    n_in = edge * edge
    config.input_dim = n_in

    if is_split_mnist:
        config.out_dim = 1 if config.class_incremental else 2
    else:  # permutedMNIST
        config.out_dim = 10

    # A task inference network has one output neuron per task.
    one_output_per_task = \
        config.training_task_infer or config.class_incremental
    if one_output_per_task:
        config.out_dim = 1

    # Except for CL scenario 2, all output heads are built up front.
    n_out = config.out_dim * config.num_tasks \
        if config.cl_scenario != 2 else config.out_dim
    if one_output_per_task:
        n_out = config.num_tasks

    # Build the classifier.
    print('For the Classifier: ')
    hidden_sizes = misc.str_to_ints(config.class_fc_arch)
    net = MLP(n_in=n_in, n_out=n_out, hidden_layers=hidden_sizes,
              activation_fn=misc.str_to_act(config.class_net_act),
              dropout_rate=config.class_dropout_rate,
              no_weights=bool(config.training_with_hnet)).to(device)
    print('Constructed MLP with shapes: ', net.param_shapes)
    config.num_weights_class_net = \
        MainNetInterface.shapes_to_num_weights(net.param_shapes)

    # Build the classifier's hypernetwork (if requested).
    if not config.training_with_hnet:
        class_hnet = None
        params_to_init = list(net.parameters())
        config.num_weights_class_hyper_net = None
        config.compression_ratio_class = None
    else:
        class_hnet = sim_utils.get_hnet_model(config, config.num_tasks,
                                              device, net.param_shapes,
                                              cprefix='class_')
        params_to_init = list(class_hnet.parameters())

        num_trainable = 0
        for param in class_hnet.parameters():
            if param.requires_grad:
                num_trainable += param.numel()
        config.num_weights_class_hyper_net = num_trainable

        ratio = num_trainable / config.num_weights_class_net
        config.compression_ratio_class = ratio
        print('Created classifier Hypernetwork with ratio: ', ratio)
        if ratio > 1:
            print('Note that the compression ratio is computed compared to ' +
                  'current target network, not might not be directly ' +
                  'comparable with the number of parameters of work we ' +
                  'compare against.')

    # Initialize network weights: 1D tensors are treated as bias vectors.
    for tensor in params_to_init:
        if tensor.ndimension() != 1:
            torch.nn.init.xavier_uniform_(tensor)
        else:
            torch.nn.init.constant_(tensor, 0)

    # Embeddings are initialized differently from the remaining weights.
    if config.training_with_hnet:
        for task_emb in class_hnet.get_task_embs():
            torch.nn.init.normal_(task_emb, mean=0.,
                                  std=config.std_normal_temb)
        if hasattr(class_hnet, 'chunk_embeddings'):
            for chunk_emb in class_hnet.chunk_embeddings:
                torch.nn.init.normal_(chunk_emb, mean=0,
                                      std=config.std_normal_emb)

    if config.training_with_hnet:
        return net, class_hnet
    return net
def __init__(self, in_shape=(32, 32, 3), num_classes=10, verbose=True,
             arch='cifar', no_weights=False, init_weights=None,
             dropout_rate=0.25):
    """Build a ZenkeNet.

    Args:
        in_shape: Shape of the input images. The first two entries must
            both be 32.
        num_classes: Number of output classes. It is written into the first
            dimension of the last two parameter shapes (output weight matrix
            and output bias vector).
        verbose: Whether a summary of the network is printed.
        arch: Key into ``ZenkeNet._architectures`` selecting the list of
            parameter shapes.
        no_weights: If ``True``, no parameters are created and all
            parameter shapes are registered in ``_hyper_shapes_learned``
            instead.
        init_weights (optional): List of tensors used to initialize the
            parameters. Must have one entry per parameter shape and may not
            be combined with ``no_weights``.
        dropout_rate: If ``-1``, no dropout layers are created. Otherwise
            the rate of the 2D dropout layer; the second dropout layer uses
            twice this rate, hence values above 0.5 are rejected.

    Raises:
        ValueError: If dropout is enabled and ``dropout_rate`` exceeds 0.5.
    """
    super(ZenkeNet, self).__init__(num_classes, verbose)

    assert(in_shape[0] == 32 and in_shape[1] == 32)
    self._in_shape = in_shape

    assert(arch in ZenkeNet._architectures.keys())
    # Work on a private copy of the shape list. Writing `num_classes` into
    # the class-level architecture table would leak into every other
    # instance (past and future) that uses the same architecture.
    self._param_shapes = [list(shape)
                          for shape in ZenkeNet._architectures[arch]]
    self._param_shapes[-2][0] = num_classes
    self._param_shapes[-1][0] = num_classes

    assert(init_weights is None or no_weights is False)
    self._no_weights = no_weights

    self._use_dropout = dropout_rate != -1

    self._has_bias = True
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer.
    self._mask_fc_out = True
    # We don't use any output non-linearity.
    self._has_linear_out = True

    self._num_weights = MainNetInterface.shapes_to_num_weights(
        self._param_shapes)
    if verbose:
        print('Creating a ZenkeNet with %d weights' \
              % (self._num_weights)
              + (', that uses dropout.' if self._use_dropout else '.'))

    if self._use_dropout:
        if dropout_rate > 0.5:
            # FIXME not a pretty solution, but we aim to follow the original
            # paper.
            raise ValueError('Dropout rate must be smaller equal 0.5.')
        self._drop_conv = nn.Dropout2d(p=dropout_rate)
        self._drop_fc1 = nn.Dropout(p=dropout_rate * 2.)

    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    if no_weights:
        # All weights have to be provided externally.
        self._weights = None
        self._hyper_shapes_learned = self._param_shapes
        self._hyper_shapes_learned_ref = \
            list(range(len(self._param_shapes)))
        self._is_properly_setup()
        return

    ### Define and initialize network weights.
    # Entries at even indices (0, 2, ...) are weight tensors, entries at odd
    # indices are the corresponding bias vectors.
    self._weights = nn.ParameterList()

    for i, dims in enumerate(self._param_shapes):
        self._weights.append(nn.Parameter(torch.Tensor(*dims),
                                          requires_grad=True))

        if i % 2 == 0:
            self._layer_weight_tensors.append(self._weights[i])
        else:
            assert(len(dims) == 1)
            self._layer_bias_vectors.append(self._weights[i])

    if init_weights is not None:
        assert(len(init_weights) == len(self._param_shapes))
        for init_w, weight in zip(init_weights, self._weights):
            assert(np.all(np.equal(list(init_w.shape),
                                   list(weight.shape))))
            weight.data = init_w
    else:
        for w_tensor, b_vector in zip(self._layer_weight_tensors,
                                      self._layer_bias_vectors):
            init_params(w_tensor, b_vector)

    self._is_properly_setup()
def __init__(self, target_shapes, num_tasks, chunk_dim=2586, layers=[50, 100],
             te_dim=8, activation_fn=torch.nn.ReLU(), use_bias=True,
             no_weights=False, ce_dim=None, init_weights=None,
             dropout_rate=-1, noise_dim=-1, temb_std=-1):
    """Build a hypernetwork that produces its outputs chunk by chunk.

    An internal :class:`HyperNetwork` with ``chunk_dim`` outputs is created
    together with one chunk embedding per chunk. The number of chunks is the
    total number of target weights divided by ``chunk_dim``, rounded up.

    Args:
        target_shapes: Non-empty list of parameter shapes whose total size
            determines the number of outputs.
        num_tasks: Forwarded to the internal hypernetwork.
        chunk_dim: Number of outputs of the internal hypernetwork.
        layers: Forwarded to the internal hypernetwork.
        te_dim: Forwarded to the internal hypernetwork.
        activation_fn: Forwarded to the internal hypernetwork.
        use_bias: Forwarded to the internal hypernetwork.
        no_weights: If ``True``, no chunk embeddings are created. Also
            forwarded to the internal hypernetwork.
        ce_dim: Size of a chunk embedding. Must not be ``None``.
        init_weights (optional): Forwarded to the internal hypernetwork. A
            warning is issued, as chunk embeddings cannot be initialized
            this way. May not be combined with ``no_weights``.
        dropout_rate: Forwarded to the internal hypernetwork.
        noise_dim: If not ``-1``, the external input size of the internal
            hypernetwork is enlarged by this amount.
        temb_std: Forwarded to the internal hypernetwork.
    """
    # FIXME find a way using super to handle multiple inheritance.
    nn.Module.__init__(self)
    CLHyperNetInterface.__init__(self)

    assert len(target_shapes) > 0
    assert init_weights is None or no_weights is False
    assert ce_dim is not None

    self._target_shapes = target_shapes
    self._num_tasks = num_tasks
    self._ce_dim = ce_dim
    self._chunk_dim = chunk_dim
    self._layers = layers
    self._use_bias = use_bias
    self._act_fn = activation_fn
    self._init_weights = init_weights
    self._no_weights = no_weights
    self._te_dim = te_dim
    self._noise_dim = noise_dim
    self._temb_std = temb_std
    self._shifts = None # FIXME temporary test.

    # FIXME: weights should incorporate chunk embeddings as they are part of
    # theta.
    if init_weights is not None:
        warn('Argument "init_weights" does not yet allow initialization ' +
             'of chunk embeddings.')

    ### Generate Hypernet with chunk_dim output.
    # Note, "temb_std" can safely be passed to the full hypernetwork: all
    # chunks are processed in one big batch, so the hypernet uses the same
    # perturbed task embedding for all of them (i.e., noise is shared).
    extra_input_dim = 0 if noise_dim == -1 else noise_dim
    self._hypernet = HyperNetwork(
        [[chunk_dim]], num_tasks, verbose=False, layers=layers,
        te_dim=te_dim, activation_fn=activation_fn, use_bias=use_bias,
        no_weights=no_weights, init_weights=init_weights,
        ce_dim=ce_dim + extra_input_dim, dropout_rate=dropout_rate,
        noise_dim=-1, temb_std=temb_std)

    self._num_outputs = MainNetInterface.shapes_to_num_weights(
        self._target_shapes)

    ### Generate embeddings for all weight chunks.
    self._num_chunks = int(np.ceil(self._num_outputs / chunk_dim))
    if not no_weights:
        self._embs = nn.Parameter(
            data=torch.Tensor(self._num_chunks, ce_dim), requires_grad=True)
        nn.init.normal_(self._embs, mean=0., std=1.)
    else:
        self._embs = None

    # Note, the chunk embeddings are part of theta.
    num_theta = MainNetInterface.shapes_to_num_weights(
        self._hypernet.theta_shapes)
    if not no_weights:
        num_theta = num_theta + self._embs.numel()
    num_temb = int(np.sum([emb.numel() for emb in self.get_task_embs()]))

    total = num_theta + num_temb
    self._num_weights = total

    print('Constructed hypernetwork with %d parameters ' % total
          + '(%d network weights + %d task embedding weights).'
          % (num_theta, num_temb))
    print('The hypernetwork has a total of %d outputs.' % self._num_outputs)

    chunk_emb_shape = [self._num_chunks, ce_dim]
    self._theta_shapes = [chunk_emb_shape] + self._hypernet.theta_shapes

    self._is_properly_setup()
def __init__(self, n_in=1, n_out=1, hidden_layers=[2000, 2000],
             activation_fn=torch.nn.ReLU(), use_bias=True, no_weights=False,
             init_weights=None, dropout_rate=-1, use_spectral_norm=False,
             use_batch_norm=False, bn_track_stats=True,
             distill_bn_stats=False, use_context_mod=False,
             context_mod_inputs=False, no_last_layer_context_mod=False,
             context_mod_no_weights=False,
             context_mod_post_activation=False,
             context_mod_gain_offset=False, out_fn=None, verbose=True):
    """Build a fully-connected network.

    Parameter shapes are registered in the order: context-modulation
    layers, batch-norm layers, linear layers. ``init_weights`` is consumed
    in that same order.

    Args:
        n_in: Number of inputs.
        n_out: Number of outputs.
        hidden_layers: Sizes of the hidden layers.
        activation_fn: Stored as the network's activation function.
        use_bias: Whether linear layers have bias vectors.
        no_weights: If ``True``, no linear (and no affine batch-norm)
            parameters are created; their shapes are registered in
            ``_hyper_shapes_learned`` instead.
        init_weights (optional): List of tensors used to initialize the
            internally maintained parameters. May not be given if both
            ``no_weights`` and ``context_mod_no_weights`` are set.
        dropout_rate: If ``-1``, no dropout layer is created. Otherwise a
            value in ``[0, 1]``.
        use_spectral_norm: Not supported.
        use_batch_norm: Whether one batch-norm layer per hidden layer is
            created.
        bn_track_stats: Passed as ``track_running_stats`` to the batch-norm
            layers.
        distill_bn_stats: Only effective together with ``use_batch_norm``.
            If set, the shapes of the batch-norm statistics are collected in
            ``_hyper_shapes_distilled``.
        use_context_mod: Whether context-modulation layers are created.
        context_mod_inputs: Whether a context-modulation layer of size
            ``n_in`` is created in addition to the ones for hidden layers.
        no_last_layer_context_mod: If ``True``, no context-modulation layer
            of size ``n_out`` is created.
        context_mod_no_weights: If ``True``, context-modulation layers have
            no internal weights; their shapes are registered in
            ``_hyper_shapes_learned`` instead.
        context_mod_post_activation: Stored for later use.
        context_mod_gain_offset: Passed as ``apply_gain_offset`` to the
            context-modulation layers.
        out_fn (optional): Stored for later use. The network is flagged as
            having a linear output iff this is ``None``.
        verbose: Whether information about the network is printed.

    Raises:
        NotImplementedError: If ``use_spectral_norm`` is set.
    """
    # FIXME find a way using super to handle multiple inheritance.
    nn.Module.__init__(self)
    MainNetInterface.__init__(self)

    if use_spectral_norm:
        raise NotImplementedError('Spectral normalization not yet ' +
                                  'implemented for this network.')

    if use_batch_norm and use_context_mod:
        # FIXME Does it make sense to have both enabled?
        # I.e., should we produce a warning or error?
        pass

    self._a_fun = activation_fn
    assert(init_weights is None or \
           (not no_weights or not context_mod_no_weights))
    self._no_weights = no_weights
    self._dropout_rate = dropout_rate
    self._use_batch_norm = use_batch_norm
    self._bn_track_stats = bn_track_stats
    self._distill_bn_stats = distill_bn_stats and use_batch_norm
    self._use_context_mod = use_context_mod
    self._context_mod_inputs = context_mod_inputs
    self._no_last_layer_context_mod = no_last_layer_context_mod
    self._context_mod_no_weights = context_mod_no_weights
    self._context_mod_post_activation = context_mod_post_activation
    self._context_mod_gain_offset = context_mod_gain_offset
    self._out_fn = out_fn

    self._has_bias = use_bias
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer.
    self._mask_fc_out = True
    self._has_linear_out = out_fn is None

    if use_spectral_norm and no_weights:
        raise ValueError('Cannot use spectral norm in a network without ' +
                         'parameters.')

    # FIXME make sure that this implementation is correct in all situations
    # (e.g., what to do if weights are passed to the forward method?).
    if use_spectral_norm:
        self._spec_norm = nn.utils.spectral_norm
    else:
        self._spec_norm = lambda x: x # identity

    self._param_shapes = []
    # Only if neither linear nor context-mod weights are maintained
    # internally, there is no internal weight list at all.
    self._weights = None if no_weights and context_mod_no_weights \
        else nn.ParameterList()
    self._hyper_shapes_learned = None \
        if not no_weights and not context_mod_no_weights else []

    if dropout_rate != -1:
        assert(dropout_rate >= 0. and dropout_rate <= 1.)
        self._dropout = nn.Dropout(p=dropout_rate)

    ### Define and initialize context mod weights.
    self._context_mod_layers = nn.ModuleList() if use_context_mod else None
    self._context_mod_shapes = [] if use_context_mod else None

    if use_context_mod:
        # Number of entries of `init_weights` consumed by context-mod layers.
        cm_ind = 0
        cm_sizes = []
        if context_mod_inputs:
            cm_sizes.append(n_in)
        cm_sizes.extend(hidden_layers)
        if not no_last_layer_context_mod:
            cm_sizes.append(n_out)

        for n in cm_sizes:
            cmod_layer = ContextModLayer(
                n, no_weights=context_mod_no_weights,
                apply_gain_offset=context_mod_gain_offset)
            self._context_mod_layers.append(cmod_layer)

            self._param_shapes.extend(cmod_layer.param_shapes)
            self._context_mod_shapes.extend(cmod_layer.param_shapes)
            if context_mod_no_weights:
                self._hyper_shapes_learned.extend(cmod_layer.param_shapes)
            else:
                self._weights.extend(cmod_layer.weights)

            # FIXME ugly code. Move initialization somewhere else.
            if not context_mod_no_weights and init_weights is not None:
                assert(len(cmod_layer.weights) == 2)
                for ii in range(2):
                    # The shapes have to be compared against the layer's
                    # weights (`cm_ind` is only an integer index).
                    assert(np.all(np.equal(
                        list(init_weights[cm_ind].shape),
                        list(cmod_layer.weights[ii].shape))))
                    cmod_layer.weights[ii].data = init_weights[cm_ind]
                    cm_ind += 1

        # Drop the entries that were consumed by context-mod layers.
        if init_weights is not None:
            init_weights = init_weights[cm_ind:]

    # Console output is only produced if requested via `verbose`.
    if verbose:
        print('hidden_layers', hidden_layers)

    ### Define and initialize batch norm weights.
    self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

    if use_batch_norm:
        if distill_bn_stats:
            self._hyper_shapes_distilled = []

        # Number of entries of `init_weights` consumed by batch-norm layers.
        bn_ind = 0
        for n in hidden_layers:
            bn_layer = BatchNormLayer(n, affine=not no_weights,
                                      track_running_stats=bn_track_stats)
            self._batchnorm_layers.append(bn_layer)
            self._param_shapes.extend(bn_layer.param_shapes)

            if no_weights:
                self._hyper_shapes_learned.extend(bn_layer.param_shapes)
            else:
                self._weights.extend(bn_layer.weights)

            if distill_bn_stats:
                self._hyper_shapes_distilled.extend(
                    [list(p.shape) for p in bn_layer.get_stats(0)])

            # FIXME ugly code. Move initialization somewhere else.
            if not no_weights and init_weights is not None:
                assert(len(bn_layer.weights) == 2)
                for ii in range(2):
                    assert(np.all(np.equal(
                        list(init_weights[bn_ind].shape),
                        list(bn_layer.weights[ii].shape))))
                    bn_layer.weights[ii].data = init_weights[bn_ind]
                    bn_ind += 1

        # Drop the entries that were consumed by batch-norm layers.
        if init_weights is not None:
            init_weights = init_weights[bn_ind:]

    # Compute shapes of linear layers.
    linear_shapes = MLP.weight_shapes(n_in=n_in, n_out=n_out,
                                      hidden_layers=hidden_layers,
                                      use_bias=use_bias)
    self._param_shapes.extend(linear_shapes)

    num_weights = MainNetInterface.shapes_to_num_weights(
        self._param_shapes)

    if verbose:
        if use_context_mod:
            cm_num_weights = 0
            for cm_layer in self._context_mod_layers:
                cm_num_weights += MainNetInterface.shapes_to_num_weights(
                    cm_layer.param_shapes)

        print('Creating an MLP with %d weights' % num_weights
              + (' (including %d weights associated with-' % cm_num_weights
                 + 'context modulation)' if use_context_mod else '')
              + '.'
              + (' The network uses dropout.' if dropout_rate != -1 else '')
              + (' The network uses batchnorm.' if use_batch_norm else ''))

    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    if no_weights:
        # Linear weights have to be provided externally.
        self._hyper_shapes_learned.extend(linear_shapes)
        self._is_properly_setup()
        return

    ### Define and initialize linear weights.
    for dims in linear_shapes:
        self._weights.append(nn.Parameter(torch.Tensor(*dims),
                                          requires_grad=True))
        # 1D shapes are bias vectors, everything else is a weight matrix.
        if len(dims) == 1:
            self._layer_bias_vectors.append(self._weights[-1])
        else:
            self._layer_weight_tensors.append(self._weights[-1])

    if init_weights is not None:
        assert(len(init_weights) == len(linear_shapes))
        for i in range(len(init_weights)):
            assert(np.all(np.equal(list(init_weights[i].shape),
                                   linear_shapes[i])))
            if use_bias:
                # Weight matrices and bias vectors alternate.
                if i % 2 == 0:
                    self._layer_weight_tensors[i // 2].data = \
                        init_weights[i]
                else:
                    self._layer_bias_vectors[i // 2].data = init_weights[i]
            else:
                self._layer_weight_tensors[i].data = init_weights[i]
    else:
        for i in range(len(self._layer_weight_tensors)):
            if use_bias:
                init_params(self._layer_weight_tensors[i],
                            self._layer_bias_vectors[i])
            else:
                init_params(self._layer_weight_tensors[i])

    self._is_properly_setup()
def __init__(self, target_shapes, num_tasks, layers=[50, 100], verbose=True,
             te_dim=8, no_te_embs=False, activation_fn=torch.nn.ReLU(),
             use_bias=True, no_weights=False, init_weights=None,
             ce_dim=None, dropout_rate=-1, use_spectral_norm=False,
             create_feedback_matrix=False, target_net_out_dim=10,
             random_scale_feedback_matrix=1., use_batch_norm=False,
             noise_dim=-1, temb_std=-1):
    """Build the network.

    The network consists of several hidden layers and a dedicated output
    layer for each weight matrix/bias vector of the main network. Its input
    is a learned task embedding.

    Args:
        target_shapes: A list of lists of integers, denoting the shape of
            each parameter tensor in the main network (hence, determining
            the output of this network).
        num_tasks: The number of task embeddings needed.
        layers: A list of integers, each indicating the size of a hidden
            layer in this network.
        verbose: Whether a summary of the constructed network is printed.
        te_dim: The dimensionality of the task embeddings.
        no_te_embs: If ``True``, no class-internal task embeddings are
            constructed; they are instead expected to be provided to the
            forward method.
        activation_fn: The nonlinearity used in hidden layers. If ``None``,
            no nonlinearity will be applied.
        use_bias: Whether layers may have bias terms.
        no_weights: If ``True``, no trainable parameters will be
            constructed, i.e., weights are assumed to be produced ad-hoc by
            a hypernetwork and passed to the forward function. Does not
            affect task embeddings.
        init_weights (optional): A convenience option. It expects a list of
            parameter values that are used to initialize the network
            weights, e.g., a weight draw produced by another hypernetwork.
            Does not affect task embeddings.
        ce_dim (optional): The dimensionality of any additional embeddings
            (in addition to the task embedding) that will be used as input
            to the hypernetwork. If ``None``, no additional input is
            expected. Otherwise, an additional embedding has to be passed
            to the :meth:`forward` method (see argument ``ext_inputs``). A
            typical use case would be a chunk embedding.
        dropout_rate (optional): If ``-1``, no dropout will be applied.
            Otherwise a number between 0 and 1 is expected, denoting the
            dropout of hidden layers.
        use_spectral_norm: Enable spectral normalization for all layers.
            Currently raises a ``NotImplementedError``.
        create_feedback_matrix: A feedback matrix for credit assignment in
            the main network will be created. See attribute
            :attr:`feedback_matrix`.
        target_net_out_dim: Target network output dimension. This
            information is needed to create feedback matrices that can be
            used in conjunction with Direct Feedback Alignment. Only needs
            to be specified when enabling ``create_feedback_matrix``.
        random_scale_feedback_matrix: Scale of the uniform distribution
            used to create the feedback matrix. Only needs to be specified
            when enabling ``create_feedback_matrix``.
        use_batch_norm: Request batchnorm for all hidden layers. Currently
            raises an error (``ValueError`` if ``ce_dim`` is ``None``,
            ``NotImplementedError`` otherwise).
        noise_dim (optional): If ``-1``, no noise will be applied.
            Otherwise the hypernetwork will receive as additional input
            zero-mean Gaussian noise with unit variance during training
            (zeroes will be inputted during eval-mode). Note, if a batch of
            inputs is given, then a different noise vector is generated for
            every sample in the batch.
        temb_std (optional): If not ``-1``, the task embeddings will be
            perturbed by zero-mean Gaussian noise with the given std
            (additive noise). The perturbation is only applied if the
            network is in training mode. Note, per batch of external
            inputs, the perturbation of the task embedding will be shared.
    """
    # FIXME find a way using super to handle multiple inheritance.
    nn.Module.__init__(self)
    CLHyperNetInterface.__init__(self)

    if use_spectral_norm:
        raise NotImplementedError(
            'Spectral normalization not yet ' +
            'implemented for this hypernetwork type.')

    if use_batch_norm:
        # Note, batch normalization only makes sense when batch processing
        # is applied during training (i.e., batch size > 1).
        # As long as only 1 task embedding can be processed, that means that
        # external inputs are required.
        if ce_dim is None:
            raise ValueError('Can\'t use batchnorm as long as ' +
                             'hypernetwork process more than 1 sample ' +
                             '("ce_dim" must be specified).')
        raise NotImplementedError(
            'Batch normalization not yet ' +
            'implemented for this hypernetwork type.')

    assert len(target_shapes) > 0
    assert no_te_embs or num_tasks > 0
    self._num_tasks = num_tasks

    assert init_weights is None or no_weights is False
    self._no_weights = no_weights
    self._no_te_embs = no_te_embs
    self._te_dim = te_dim
    self._size_ext_input = ce_dim
    self._layers = layers
    self._target_shapes = target_shapes
    self._use_bias = use_bias
    self._act_fn = activation_fn
    self._init_weights = init_weights
    self._dropout_rate = dropout_rate
    self._noise_dim = noise_dim
    self._temb_std = temb_std
    self._shifts = None # FIXME temporary test.

    ### Hidden layers
    self._gen_layers(layers, te_dim, use_bias, no_weights, init_weights,
                     ce_dim, noise_dim)

    if create_feedback_matrix:
        self._create_feedback_matrix(target_shapes, target_net_out_dim,
                                     random_scale_feedback_matrix)

    if dropout_rate == -1:
        self._dropout = None
    else:
        self._dropout = None
        assert 0 <= dropout_rate <= 1
        self._dropout = nn.Dropout(dropout_rate)

    # Task embeddings.
    if not no_te_embs:
        self._task_embs = nn.ParameterList()
        for _ in range(num_tasks):
            task_emb = nn.Parameter(data=torch.Tensor(te_dim),
                                    requires_grad=True)
            self._task_embs.append(task_emb)
            torch.nn.init.normal_(self._task_embs[-1], mean=0., std=1.)
    else:
        self._task_embs = None

    self._theta_shapes = self._hidden_dims + self._out_dims

    num_theta = MainNetInterface.shapes_to_num_weights(self._theta_shapes)
    if no_te_embs:
        num_temb = 0
    else:
        num_temb = int(np.sum([emb.numel() for emb in self._task_embs]))
    self._num_weights = num_theta + num_temb

    self._num_outputs = MainNetInterface.shapes_to_num_weights(
        self.target_shapes)

    if verbose:
        print('Constructed hypernetwork with %d parameters ('
              % (num_theta + num_temb)
              + '%d network weights + %d task embedding weights).'
              % (num_theta, num_temb))
        print('The hypernetwork has a total of %d outputs.'
              % self._num_outputs)

    self._is_properly_setup()
def __init__(self, main_dims, num_tasks, out_size=[64, 64], num_layers=5,
             num_filters=None, kernel_size=5, sa_units=[1, 3],
             rem_layers=[50, 50, 50], te_dim=8, ce_dim=8, no_theta=False,
             init_theta=None, use_batch_norm=False, use_spectral_norm=False,
             dropout_rate=-1, discard_remainder=False, noise_dim=-1,
             temb_std=-1):
    """Build a self-attention hypernetwork.

    The main network's weights are produced in chunks of
    ``np.prod(out_size)`` weights by an :class:`SAHnetPart` instance. If the
    number of weights is not divisible by the chunk size, the remainder is
    either produced by an additional fully-connected :class:`HyperNetwork`
    or, if ``discard_remainder`` is set, by one extra chunk whose surplus
    outputs are discarded.

    Args:
        main_dims: List of parameter shapes of the main network.
        num_tasks: Number of task embeddings to create.
        out_size: Output size of the self-attention hypernet (2 or 3
            entries). Its product may not exceed the number of main network
            weights.
        num_layers: Forwarded to :class:`SAHnetPart`.
        num_filters: Forwarded to :class:`SAHnetPart`.
        kernel_size: Forwarded to :class:`SAHnetPart`.
        sa_units: Forwarded to :class:`SAHnetPart`.
        rem_layers: Hidden layer sizes of the remainder hypernetwork.
        te_dim: Size of a task embedding.
        ce_dim: Size of a chunk embedding.
        no_theta: If ``True``, no chunk embeddings are created. Also
            forwarded to the internal hypernetworks.
        init_theta (optional): Not supported yet.
        use_batch_norm: Forwarded to the internal hypernetworks.
        use_spectral_norm: Forwarded to the internal hypernetworks.
        dropout_rate: Forwarded to the remainder hypernetwork.
        discard_remainder: See description above.
        noise_dim: If not ``-1``, the input size of the internal
            hypernetworks is enlarged by this amount.
        temb_std: Stored for later use.

    Raises:
        NotImplementedError: If ``init_theta`` is given.
    """
    # FIXME find a way using super to handle multiple inheritance.
    nn.Module.__init__(self)
    CLHyperNetInterface.__init__(self)

    if init_theta is not None:
        # FIXME The number of parameter tensors in each hypernet would have
        # to be known before creating them in order to split `init_theta`.
        raise NotImplementedError(
            'Argument "init_theta" not implemented ' + 'yet!')

    assert init_theta is None or no_theta is False
    self._no_theta = no_theta
    self._te_dim = te_dim
    self._discard_remainder = discard_remainder

    self._target_shapes = main_dims
    self._num_outputs = MainNetInterface.shapes_to_num_weights(main_dims)
    print('Building a self-attention hypernet for a network with %d '% \
          self._num_outputs + 'weights.')

    assert len(out_size) in [2, 3]
    self._out_size = out_size
    num_outs = np.prod(out_size)
    assert num_outs <= self._num_outputs
    self._noise_dim = noise_dim
    self._temb_std = temb_std

    num_embs = self._num_outputs // num_outs
    rem_weights = self._num_outputs % num_outs

    has_remainder = rem_weights > 0
    needs_rem_hnet = has_remainder and not discard_remainder

    if needs_rem_hnet:
        print('%d remaining weights (%.2f%%) are generated by a fully-' \
              % (rem_weights, 100.0 * rem_weights / self._num_outputs) + \
              'connected hypernetwork.')
    elif has_remainder:
        # One extra chunk covers the remainder; its surplus is dropped.
        num_embs += 1
        print('%d weights generated by the last chunk of the self-'
              % (num_outs - rem_weights) + 'attention hypernet will be ' +
              'discarded.')

    self._num_embs = num_embs

    ### Generate Hypernet.
    extra_input_dim = 0 if noise_dim == -1 else noise_dim
    self._hypernet = SAHnetPart(
        out_size=out_size, num_layers=num_layers, num_filters=num_filters,
        kernel_size=kernel_size, sa_units=sa_units,
        input_dim=te_dim + ce_dim + extra_input_dim,
        use_batch_norm=use_batch_norm,
        use_spectral_norm=use_spectral_norm, no_theta=no_theta,
        init_theta=None)

    self._rem_hypernet = None
    self._remainder = rem_weights
    if needs_rem_hnet:
        print('A second hypernet for the remainder of the weights has ' +
              'to be created, as %d is not dividable by %d '
              % (self._num_outputs, num_outs) +
              '(remaidner %d)' % rem_weights)
        rem_ext_dim = None if noise_dim == -1 else noise_dim
        self._rem_hypernet = HyperNetwork(
            [[rem_weights]], None, layers=rem_layers, te_dim=te_dim,
            no_te_embs=True, no_weights=no_theta, ce_dim=rem_ext_dim,
            dropout_rate=dropout_rate, use_batch_norm=use_batch_norm,
            use_spectral_norm=use_spectral_norm, noise_dim=-1,
            temb_std=None)

    ### Generate embeddings for all weight chunks.
    # Note, the remainder hypernet needs no chunk embedding, as it always
    # produces the same chunk.
    if not no_theta:
        self._embs = nn.Parameter(data=torch.Tensor(num_embs, ce_dim),
                                  requires_grad=True)
        torch.nn.init.normal_(self._embs, mean=0., std=1.)
    else:
        self._embs = None

    ### Generate task embeddings.
    # Task embeddings are stored individually, as this makes it easier to
    # pass only subsets of them to an optimizer.
    self._task_embs = nn.ParameterList()
    for _ in range(num_tasks):
        task_emb = nn.Parameter(data=torch.Tensor(te_dim),
                                requires_grad=True)
        self._task_embs.append(task_emb)
        torch.nn.init.normal_(self._task_embs[-1], mean=0., std=1.)

    # Count all parameters registered so far.
    param_count = 0
    for param in list(self.parameters()):
        param_count += np.prod(param.shape)
    self._num_weights = param_count
    print('Total number of parameters in the hypernetwork: %d' %
          self._num_weights)

    theta_shapes = [[num_embs, ce_dim]] + self._hypernet.theta_shapes
    if self._rem_hypernet is not None:
        theta_shapes += self._rem_hypernet.theta_shapes
    self._theta_shapes = theta_shapes
def __init__(self, n_in=1, n_out=1, hidden_layers=(10, 10),
             activation_fn=torch.nn.ReLU(), use_bias=True, no_weights=False,
             init_weights=None, dropout_rate=-1, use_spectral_norm=False,
             use_batch_norm=False, bn_track_stats=True,
             distill_bn_stats=False, use_context_mod=False,
             context_mod_inputs=False, no_last_layer_context_mod=False,
             context_mod_no_weights=False,
             context_mod_post_activation=False,
             context_mod_gain_offset=False, context_mod_gain_softplus=False,
             out_fn=None, verbose=True):
    """Set up the parameter bookkeeping and (optionally) weights of the MLP.

    Parameter shapes are registered in the order: context-mod layers,
    batchnorm layers, linear layers. The same order is expected for
    ``init_weights``.

    Args:
        n_in: Number of inputs (forwarded to ``MLP.weight_shapes``; also
            the size of the input context-mod layer if
            ``context_mod_inputs`` is set).
        n_out: Number of outputs (forwarded to ``MLP.weight_shapes``; also
            the size of the output context-mod layer unless
            ``no_last_layer_context_mod`` is set).
        hidden_layers: Sizes of the hidden layers. One batchnorm layer and
            one context-mod layer is created per entry if the respective
            feature is enabled.
        activation_fn: Stored in ``_a_fun``.
        use_bias: Forwarded to ``MLP.weight_shapes``; decides whether bias
            vectors exist.
        no_weights: If ``True``, no linear weights are created internally;
            their shapes are appended to ``_hyper_shapes_learned`` instead
            and batchnorm layers are created with ``affine=False``.
        init_weights: Optional list of tensors used to initialize
            internally maintained weights. Entries are consumed in the
            order context-mod weights, batchnorm weights, linear weights;
            shapes are asserted to match.
        dropout_rate: ``-1`` disables dropout, otherwise a value in
            ``[0, 1]`` passed to ``nn.Dropout``.
        use_spectral_norm: Not supported; setting it raises
            ``NotImplementedError``.
        use_batch_norm: Whether batchnorm layers are created for all
            hidden layers.
        bn_track_stats: Forwarded as ``track_running_stats`` to the
            batchnorm layers.
        distill_bn_stats: If set together with ``use_batch_norm``, the
            shapes of the batchnorm statistics are collected in
            ``_hyper_shapes_distilled``.
        use_context_mod: Whether context-modulation layers are created.
        context_mod_inputs: Whether a context-mod layer of size ``n_in``
            is added for the inputs.
        no_last_layer_context_mod: If ``True``, no context-mod layer of
            size ``n_out`` is added.
        context_mod_no_weights: If ``True``, context-mod weights are not
            maintained internally; their shapes go to
            ``_hyper_shapes_learned``.
        context_mod_post_activation: Only influences the ``'layer'``
            indices written to ``_param_shapes_meta`` in this method
            (order of context-mod and batchnorm layers).
        context_mod_gain_offset: Forwarded to ``ContextModLayer`` as
            ``apply_gain_offset``.
        context_mod_gain_softplus: Forwarded to ``ContextModLayer`` as
            ``apply_gain_softplus``.
        out_fn: Stored in ``_out_fn``; ``_has_linear_out`` is ``True``
            only if this is ``None``.
        verbose: Whether a summary of the network is printed.

    Raises:
        NotImplementedError: If ``use_spectral_norm`` is ``True``.
    """
    # FIXME find a way using super to handle multiple inheritance.
    nn.Module.__init__(self)
    MainNetInterface.__init__(self)

    # FIXME Spectral norm is incorrectly implemented. Function
    # `nn.utils.spectral_norm` needs to be called in the constructor, such
    # that spec norm is wrapped around a module.
    if use_spectral_norm:
        raise NotImplementedError('Spectral normalization not yet ' +
                                  'implemented for this network.')

    if use_batch_norm and use_context_mod:
        # FIXME Does it make sense to have both enabled?
        # I.e., should we produce a warning or error?
        pass

    self._a_fun = activation_fn
    # Custom initial values only make sense if at least one group of
    # weights is maintained internally.
    assert(init_weights is None or
           (not no_weights or not context_mod_no_weights))
    self._no_weights = no_weights
    self._dropout_rate = dropout_rate
    #self._use_spectral_norm = use_spectral_norm
    self._use_batch_norm = use_batch_norm
    self._bn_track_stats = bn_track_stats
    self._distill_bn_stats = distill_bn_stats and use_batch_norm
    self._use_context_mod = use_context_mod
    self._context_mod_inputs = context_mod_inputs
    self._no_last_layer_context_mod = no_last_layer_context_mod
    self._context_mod_no_weights = context_mod_no_weights
    self._context_mod_post_activation = context_mod_post_activation
    self._context_mod_gain_offset = context_mod_gain_offset
    self._context_mod_gain_softplus = context_mod_gain_softplus
    self._out_fn = out_fn

    self._has_bias = use_bias
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer.
    self._mask_fc_out = True
    self._has_linear_out = out_fn is None

    # NOTE: Currently unreachable because of the NotImplementedError above;
    # kept for the time spectral norm gets implemented.
    if use_spectral_norm and no_weights:
        raise ValueError('Cannot use spectral norm in a network without ' +
                         'parameters.')

    # FIXME make sure that this implementation is correct in all situations
    # (e.g., what to do if weights are passed to the forward method?).
    if use_spectral_norm:
        self._spec_norm = nn.utils.spectral_norm
    else:
        self._spec_norm = lambda x: x  # identity

    self._param_shapes = []
    self._param_shapes_meta = []
    self._weights = None if no_weights and context_mod_no_weights \
        else nn.ParameterList()
    self._hyper_shapes_learned = None \
        if not no_weights and not context_mod_no_weights else []
    self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
        is None else []

    if dropout_rate != -1:
        assert (dropout_rate >= 0. and dropout_rate <= 1.)
        self._dropout = nn.Dropout(p=dropout_rate)

    ### Define and initialize context mod weights.
    self._context_mod_layers = nn.ModuleList() if use_context_mod else None
    self._context_mod_shapes = [] if use_context_mod else None

    if use_context_mod:
        # Number of entries of `init_weights` consumed by context-mod
        # layers so far.
        cm_ind = 0
        cm_sizes = []
        if context_mod_inputs:
            cm_sizes.append(n_in)
        cm_sizes.extend(hidden_layers)
        if not no_last_layer_context_mod:
            cm_sizes.append(n_out)

        for n in cm_sizes:
            cmod_layer = ContextModLayer(
                n, no_weights=context_mod_no_weights,
                apply_gain_offset=context_mod_gain_offset,
                apply_gain_softplus=context_mod_gain_softplus)
            self._context_mod_layers.append(cmod_layer)

            self._param_shapes.extend(cmod_layer.param_shapes)
            assert len(cmod_layer.param_shapes) == 2
            self._param_shapes_meta.extend([
                {'name': 'cm_scale',
                 'index': -1 if context_mod_no_weights else
                     len(self._weights),
                 'layer': -1},  # 'layer' is set later.
                {'name': 'cm_shift',
                 'index': -1 if context_mod_no_weights else
                     len(self._weights) + 1,
                 'layer': -1},  # 'layer' is set later.
            ])

            self._context_mod_shapes.extend(cmod_layer.param_shapes)
            if context_mod_no_weights:
                self._hyper_shapes_learned.extend(cmod_layer.param_shapes)
            else:
                self._weights.extend(cmod_layer.weights)

            # FIXME ugly code. Move initialization somewhere else.
            if not context_mod_no_weights and init_weights is not None:
                assert (len(cmod_layer.weights) == 2)
                for ii in range(2):
                    # Compare against the layer's own weights. `cm_ind` is
                    # just an integer counter into `init_weights`.
                    assert(np.all(np.equal(
                        list(init_weights[cm_ind].shape),
                        list(cmod_layer.weights[ii].shape))))
                    cmod_layer.weights[ii].data = init_weights[cm_ind]
                    cm_ind += 1

        # Drop the entries that have been consumed above.
        if init_weights is not None:
            init_weights = init_weights[cm_ind:]

        if context_mod_no_weights:
            self._hyper_shapes_learned_ref = \
                list(range(len(self._param_shapes)))

    ### Define and initialize batch norm weights.
    self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

    if use_batch_norm:
        if distill_bn_stats:
            self._hyper_shapes_distilled = []

        # Number of entries of `init_weights` consumed by batchnorm layers.
        bn_ind = 0
        for n in hidden_layers:
            bn_layer = BatchNormLayer(n, affine=not no_weights,
                                      track_running_stats=bn_track_stats)
            self._batchnorm_layers.append(bn_layer)

            self._param_shapes.extend(bn_layer.param_shapes)
            assert len(bn_layer.param_shapes) == 2
            self._param_shapes_meta.extend([
                {'name': 'bn_scale',
                 'index': -1 if no_weights else len(self._weights),
                 'layer': -1},  # 'layer' is set later.
                {'name': 'bn_shift',
                 'index': -1 if no_weights else len(self._weights) + 1,
                 'layer': -1},  # 'layer' is set later.
            ])

            if no_weights:
                self._hyper_shapes_learned.extend(bn_layer.param_shapes)
            else:
                self._weights.extend(bn_layer.weights)

            if distill_bn_stats:
                self._hyper_shapes_distilled.extend(
                    [list(p.shape) for p in bn_layer.get_stats(0)])

            # FIXME ugly code. Move initialization somewhere else.
            if not no_weights and init_weights is not None:
                assert (len(bn_layer.weights) == 2)
                for ii in range(2):
                    assert(np.all(np.equal(
                        list(init_weights[bn_ind].shape),
                        list(bn_layer.weights[ii].shape))))
                    bn_layer.weights[ii].data = init_weights[bn_ind]
                    bn_ind += 1

        # Drop the entries that have been consumed above.
        if init_weights is not None:
            init_weights = init_weights[bn_ind:]

    ### Compute shapes of linear layers.
    linear_shapes = MLP.weight_shapes(n_in=n_in, n_out=n_out,
                                      hidden_layers=hidden_layers,
                                      use_bias=use_bias)
    self._param_shapes.extend(linear_shapes)

    for i, s in enumerate(linear_shapes):
        self._param_shapes_meta.append({
            'name': 'weight' if len(s) != 1 else 'bias',
            'index': -1 if no_weights else len(self._weights) + i,
            'layer': -1  # 'layer' is set later.
        })

    num_weights = MainNetInterface.shapes_to_num_weights(
        self._param_shapes)

    ### Set missing meta information of param_shapes.
    # `offset` reserves layer index 0 for an input context-mod layer;
    # `shift` is the number of layer indices consumed per network layer
    # (linear + optional batchnorm + optional context-mod).
    offset = 1 if use_context_mod and context_mod_inputs else 0
    shift = 1
    if use_batch_norm:
        shift += 1
    if use_context_mod:
        shift += 1

    cm_offset = 2 if context_mod_post_activation else 1
    bn_offset = 1 if context_mod_post_activation else 2

    cm_ind = 0
    bn_ind = 0
    layer_ind = 0

    for i, dd in enumerate(self._param_shapes_meta):
        if dd['name'].startswith('cm'):
            if offset == 1 and i in [0, 1]:
                # The first two entries belong to the input context-mod
                # layer.
                dd['layer'] = 0
            else:
                if cm_ind < len(hidden_layers):
                    dd['layer'] = offset + cm_ind * shift + cm_offset
                else:
                    assert cm_ind == len(hidden_layers) and \
                        not no_last_layer_context_mod
                    # No batchnorm in output layer.
                    dd['layer'] = offset + cm_ind * shift + 1

                if dd['name'] == 'cm_shift':
                    cm_ind += 1

        elif dd['name'].startswith('bn'):
            dd['layer'] = offset + bn_ind * shift + bn_offset
            if dd['name'] == 'bn_shift':
                bn_ind += 1

        else:
            dd['layer'] = offset + layer_ind * shift
            # A linear layer is complete after its bias (or right after
            # its weight matrix if no biases are used).
            if not use_bias or dd['name'] == 'bias':
                layer_ind += 1

    ### User information
    if verbose:
        if use_context_mod:
            cm_num_weights = 0
            for cm_layer in self._context_mod_layers:
                cm_num_weights += MainNetInterface.shapes_to_num_weights(
                    cm_layer.param_shapes)

        print('Creating an MLP with %d weights' % num_weights
              + (' (including %d weights associated with-' % cm_num_weights
                 + 'context modulation)' if use_context_mod else '')
              + '.'
              + (' The network uses dropout.' if dropout_rate != -1 else '')
              + (' The network uses batchnorm.' if use_batch_norm else ''))

    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    if no_weights:
        self._hyper_shapes_learned.extend(linear_shapes)

        if use_context_mod:
            if context_mod_no_weights:
                self._hyper_shapes_learned_ref = \
                    list(range(len(self._param_shapes)))
            else:
                # Context-mod shapes come first in `_param_shapes` and are
                # maintained internally, hence skipped here.
                ncm = len(self._context_mod_shapes)
                self._hyper_shapes_learned_ref = \
                    list(range(ncm, len(self._param_shapes)))
        self._is_properly_setup()
        return

    ### Define and initialize linear weights.
    for dims in linear_shapes:
        self._weights.append(nn.Parameter(torch.Tensor(*dims),
                                          requires_grad=True))
        if len(dims) == 1:
            self._layer_bias_vectors.append(self._weights[-1])
        else:
            self._layer_weight_tensors.append(self._weights[-1])

    if init_weights is not None:
        assert (len(init_weights) == len(linear_shapes))
        for i in range(len(init_weights)):
            assert (np.all(np.equal(list(init_weights[i].shape),
                                    linear_shapes[i])))
            if use_bias:
                # Weights and biases alternate in `linear_shapes`.
                if i % 2 == 0:
                    self._layer_weight_tensors[i // 2].data = \
                        init_weights[i]
                else:
                    self._layer_bias_vectors[i // 2].data = init_weights[i]
            else:
                self._layer_weight_tensors[i].data = init_weights[i]
    else:
        for i in range(len(self._layer_weight_tensors)):
            if use_bias:
                init_params(self._layer_weight_tensors[i],
                            self._layer_bias_vectors[i])
            else:
                init_params(self._layer_weight_tensors[i])

    if self._num_context_mod_shapes() == 0:
        # Note, that might be the case if no hidden layers exist and no
        # input or output modulation is used.
        self._use_context_mod = False

    self._is_properly_setup()
def __init__(self, in_shape=(32, 32, 3), num_classes=10, no_weights=False,
             init_weights=None, use_context_mod=False,
             context_mod_inputs=False, no_last_layer_context_mod=False,
             context_mod_no_weights=False,
             context_mod_post_activation=False,
             context_mod_gain_offset=False, context_mod_gain_softplus=False,
             context_mod_apply_pixel_wise=False):
    """Build three locally-connected conv layers followed by two FC layers.

    Args:
        in_shape: Input shape, indexed as ``(height, width, channels)``;
            the last entry must be 1 or 3.
        num_classes: Size of the final fully-connected layer.
        no_weights: If ``True``, conv and linear weights are not created
            internally; their shapes are added to
            ``_hyper_shapes_learned``.
        init_weights: Optional list of tensors that overwrite all entries
            of ``_weights`` (same length and shapes are asserted).
        use_context_mod: Whether context-mod layers are created.
        context_mod_inputs: Whether the input gets a context-mod layer.
        no_last_layer_context_mod: If ``True``, the output layer gets no
            context-mod layer.
        context_mod_no_weights: If ``True``, context-mod weights are not
            maintained internally.
        context_mod_post_activation: Stored as attribute.
        context_mod_gain_offset: Forwarded to ``ContextModLayer``.
        context_mod_gain_softplus: Forwarded to ``ContextModLayer``.
        context_mod_apply_pixel_wise: If ``False``, context-mod shapes of
            feature maps are reduced to one gain/shift per channel.
    """
    super(BioConvNet, self).__init__(num_classes, True)

    assert len(in_shape) == 3
    # FIXME Not mandatory; merely checks that the user seems to use the
    # Tensorflow (channels-last) layout.
    assert in_shape[2] in [1, 3]
    assert init_weights is None or \
        (not no_weights or not context_mod_no_weights)

    self._in_shape = in_shape
    self._no_weights = no_weights
    self._use_context_mod = use_context_mod
    self._context_mod_inputs = context_mod_inputs
    self._no_last_layer_context_mod = no_last_layer_context_mod
    self._context_mod_no_weights = context_mod_no_weights
    self._context_mod_post_activation = context_mod_post_activation
    self._context_mod_gain_offset = context_mod_gain_offset
    self._context_mod_gain_softplus = context_mod_gain_softplus
    self._context_mod_apply_pixel_wise = context_mod_apply_pixel_wise

    self._has_bias = True
    self._has_fc_out = True
    # The last two entries of `weights` have to be the weight matrix and
    # bias vector of the output layer.
    self._mask_fc_out = True
    self._has_linear_out = True

    everything_external = no_weights and context_mod_no_weights
    everything_internal = not no_weights and not context_mod_no_weights

    self._param_shapes = []
    self._weights = None if everything_external else nn.ParameterList()
    self._hyper_shapes_learned = None if everything_internal else []

    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    # Activation shapes that may receive a context-mod layer.
    cm_shapes = []
    if context_mod_inputs:
        cm_shapes.append([in_shape[2], *in_shape[:2]])

    ### Locally-connected ("bio") conv layers.
    height = in_shape[0]
    width = in_shape[1]
    channels_in = in_shape[2]

    # (out channels, kernel size, stride, padding) per conv layer.
    conv_specs = [(64, 5, 2, 0), (128, 5, 2, 0), (256, 3, 1, 1)]

    self._conv_layer = []
    for channels_out, kernel, stride, padding in conv_specs:
        conv = LocalConv2dLayer(channels_in, channels_out, height, width,
                                kernel, stride=stride, padding=padding,
                                no_weights=no_weights)
        self._conv_layer.append(conv)

        height = conv.out_height
        width = conv.out_width
        cm_shapes.append([channels_out, height, width])
        channels_in = channels_out

        self._param_shapes.extend(conv.param_shapes)
        if no_weights:
            self._hyper_shapes_learned.extend(conv.param_shapes)
            continue

        self._weights.extend(conv.weights)
        assert len(conv.weights) == 2
        self._layer_weight_tensors.append(conv.filters)
        self._layer_bias_vectors.append(conv.bias)

    ### Fully-connected layers.
    fan_in = height * width * channels_in
    assert fan_in == 6400

    fc_sizes = [1024, num_classes]
    last_fc = len(fc_sizes) - 1
    for idx, fan_out in enumerate(fc_sizes):
        weight_shape = [fan_out, fan_in]
        bias_shape = [fan_out]

        # The output layer is only modulated if requested.
        if idx < last_fc or not no_last_layer_context_mod:
            cm_shapes.append([fan_out])

        fan_in = fan_out

        self._param_shapes.extend([weight_shape, bias_shape])
        if no_weights:
            self._hyper_shapes_learned.extend([weight_shape, bias_shape])
            continue

        fc_weight = nn.Parameter(torch.Tensor(*weight_shape),
                                 requires_grad=True)
        fc_bias = nn.Parameter(torch.Tensor(*bias_shape),
                               requires_grad=True)
        init_params(fc_weight, fc_bias)

        self._weights.extend([fc_weight, fc_bias])
        self._layer_weight_tensors.append(fc_weight)
        self._layer_bias_vectors.append(fc_bias)

    ### Context-mod layers.
    if use_context_mod:
        self._context_mod_layers = nn.ModuleList()
        self._context_mod_shapes = []
        self._context_mod_weights = nn.ParameterList()
    else:
        self._context_mod_layers = None
        self._context_mod_shapes = None
        self._context_mod_weights = None

    if use_context_mod:
        if not context_mod_apply_pixel_wise:
            # One scalar gain and shift per feature map only.
            cm_shapes = [[shp[0], 1, 1] if len(shp) == 3 else shp
                         for shp in cm_shapes]

        for shp in cm_shapes:
            cm_layer = ContextModLayer(
                shp, no_weights=context_mod_no_weights,
                apply_gain_offset=context_mod_gain_offset,
                apply_gain_softplus=context_mod_gain_softplus)
            self._context_mod_layers.append(cm_layer)

            self._context_mod_shapes.extend(cm_layer.param_shapes)
            if not context_mod_no_weights:
                self._context_mod_weights.extend(cm_layer.weights)

        # Context-mod shapes/weights always lead the list attributes.
        self._param_shapes = self._context_mod_shapes + self._param_shapes
        if context_mod_no_weights:
            self._hyper_shapes_learned = self._context_mod_shapes + \
                self._hyper_shapes_learned
        else:
            main_weights = self._weights
            self._weights = nn.ParameterList(self._context_mod_weights)
            for param in main_weights:
                self._weights.append(param)

    ### Optional custom initialization.
    if init_weights is not None:
        assert len(self.weights) == len(init_weights)
        for idx, init_value in enumerate(init_weights):
            assert np.all(np.equal(list(init_value.shape),
                                   list(self._weights[idx].shape)))
            self._weights[idx].data = init_value

    ### User information.
    num_weights = MainNetInterface.shapes_to_num_weights(
        self._param_shapes)

    if use_context_mod:
        cm_num_weights = MainNetInterface.shapes_to_num_weights(
            self._context_mod_shapes)

    print('Creating bio-plausible convnet with %d weights' % num_weights
          + (' (including %d weights associated with-' % cm_num_weights
             + 'context modulation)' if use_context_mod else '')
          + '.')

    self._is_properly_setup()
def num_outputs(self):
    """Getter for the attribute :attr:`num_outputs`.

    Returns:
        The result of ``MainNetInterface.shapes_to_num_weights`` applied to
        :attr:`target_shapes`.
    """
    shapes = self.target_shapes
    return MainNetInterface.shapes_to_num_weights(shapes)
def __init__(self, in_shape=(224, 224, 3), num_classes=1000, use_bias=True,
             use_fc_bias=None, num_feature_maps=(64, 64, 128, 256, 512),
             blocks_per_group=(2, 2, 2, 2), projection_shortcut=False,
             bottleneck_blocks=False, cutout_mod=False, no_weights=False,
             use_batch_norm=True, bn_track_stats=True,
             distill_bn_stats=False, chw_input_format=False, verbose=True,
             **kwargs):
    """Register all parameter shapes (and weights) of the ResNet.

    Layer indices stored in the meta information follow the scheme
    ``l mod 3 == 1`` for conv/linear layers, ``l mod 3 == 2`` for batchnorm
    layers and ``l mod 3 == 0`` for context-mod layers.

    Args:
        in_shape: Input shape; ``in_shape[2]`` is used as the number of
            input channels of the first conv layer.
        num_classes: Number of outputs of the final linear layer.
        use_bias: Whether conv layers get a bias vector.
        use_fc_bias: Whether the final linear layer gets a bias vector;
            ``None`` means "same as ``use_bias``".
        num_feature_maps: Five filter counts (first conv layer and the
            four groups of residual blocks).
        blocks_per_group: Four integers, number of blocks per group.
        projection_shortcut: Whether 1x1 projection weights (plus
            batchnorm layers, if enabled) are created for non-identity
            skip connections.
        bottleneck_blocks: Whether blocks consist of three layers
            (1x1, 3x3, 1x1 with 4x feature maps) instead of two 3x3 layers.
        cutout_mod: If set, the first conv layer uses a 3x3 kernel with
            padding 1 and stride 1 instead of 7x7 / 3 / 2.
        no_weights: If ``True``, no weights are created internally; shapes
            are added to ``_hyper_shapes_learned``.
        use_batch_norm: Whether batchnorm layers are created.
        bn_track_stats: Forwarded to ``_add_batchnorm_layers``.
        distill_bn_stats: Forwarded to ``_add_batchnorm_layers``.
        chw_input_format: Stored as attribute.
        verbose: Whether a summary is printed.
        **kwargs: Context-mod options, parsed by
            ``MainNetInterface._parse_context_mod_args``; additionally
            ``context_mod_apply_pixel_wise`` is accepted.

    Raises:
        ValueError: On unknown keyword arguments or if
            ``blocks_per_group`` / ``num_feature_maps`` have the wrong
            length.
    """
    super(ResNetIN, self).__init__(num_classes, verbose)

    ### Context-mod keyword arguments ###
    pixel_wise_key = 'context_mod_apply_pixel_wise'
    unknown_kwargs = MainNetInterface._parse_context_mod_args(kwargs)
    if pixel_wise_key in unknown_kwargs:
        unknown_kwargs.remove(pixel_wise_key)
    if len(unknown_kwargs) > 0:
        raise ValueError('Keyword arguments %s unknown.'
                         % str(unknown_kwargs))
    # Conv-net specific option that the generic parser does not set.
    if pixel_wise_key not in kwargs.keys():
        kwargs[pixel_wise_key] = False

    self._use_context_mod = kwargs['use_context_mod']
    self._context_mod_inputs = kwargs['context_mod_inputs']
    self._no_last_layer_context_mod = kwargs['no_last_layer_context_mod']
    self._context_mod_no_weights = kwargs['context_mod_no_weights']
    self._context_mod_post_activation = \
        kwargs['context_mod_post_activation']
    self._context_mod_gain_offset = kwargs['context_mod_gain_offset']
    self._context_mod_gain_softplus = kwargs['context_mod_gain_softplus']
    self._context_mod_apply_pixel_wise = kwargs[pixel_wise_key]

    ### Remaining arguments ###
    self._in_shape = in_shape
    self._projection_shortcut = projection_shortcut
    self._bottleneck_blocks = bottleneck_blocks
    self._cutout_mod = cutout_mod
    if use_fc_bias is None:
        use_fc_bias = use_bias
    # See also attribute `_has_bias` below.
    self._use_bias = use_bias
    self._use_fc_bias = use_fc_bias
    self._no_weights = no_weights
    assert not use_batch_norm or (not distill_bn_stats or bn_track_stats)
    self._use_batch_norm = use_batch_norm
    self._bn_track_stats = bn_track_stats
    self._distill_bn_stats = distill_bn_stats and use_batch_norm
    self._chw_input_format = chw_input_format

    if len(blocks_per_group) != 4:
        raise ValueError('Option "blocks_per_group" must be a list of 4 ' +
                         'integers.')
    self._num_blocks = blocks_per_group
    if len(num_feature_maps) != 5:
        raise ValueError('Option "num_feature_maps" must be a list of 5 ' +
                         'integers.')
    self._filter_sizes = list(num_feature_maps)

    # Groups 3, 4 and 5 start with a strided convolution, so their first
    # skip connection cannot be an identity. The transition 1->2 needs a
    # non-identity skip only if the number of feature maps changes.
    layers_per_block = 3 if self._bottleneck_blocks else 2
    expansion = 4 if self._bottleneck_blocks else 1

    self._num_non_ident_skips = 3
    if self._filter_sizes[0] != self._filter_sizes[1] * expansion:
        self._num_non_ident_skips += 1

    self._group_has_1x1 = [False] * 4
    if self._projection_shortcut:
        first_excluded = 3 - self._num_non_ident_skips
        for grp in range(3, first_excluded, -1):
            self._group_has_1x1[grp] = True

    # Conv layers on the main path (skip connections excluded).
    self._num_main_conv_layers = 1 + int(np.sum(
        [self._num_blocks[grp] * layers_per_block for grp in range(4)]))

    # First conv layer: 7x7 / padding 3 / stride 2, unless the cutout
    # modification is requested.
    if self._cutout_mod:
        self._init_kernel_size = (3, 3)
        self._init_padding = 1
        self._init_stride = 1
    else:
        self._init_kernel_size = (7, 7)
        self._init_padding = 3
        self._init_stride = 2

    ### Required class attributes ###
    # `has_bias` is only meaningful if conv and fc layers agree.
    self._has_bias = use_bias if use_bias == use_fc_bias else False
    self._has_fc_out = True
    # The last entries of `weights` have to belong to the output layer.
    self._mask_fc_out = True
    self._has_linear_out = True

    self._param_shapes = []
    self._param_shapes_meta = []

    if no_weights and self._context_mod_no_weights:
        self._internal_params = None
    else:
        self._internal_params = nn.ParameterList()

    if not no_weights and not self._context_mod_no_weights:
        self._hyper_shapes_learned = None
    else:
        self._hyper_shapes_learned = []

    if self._hyper_shapes_learned is None:
        self._hyper_shapes_learned_ref = None
    else:
        self._hyper_shapes_learned_ref = []

    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    ### Context-mod layers ###
    self._context_mod_layers = nn.ModuleList() if self._use_context_mod \
        else None

    if self._use_context_mod:
        cm_layer_inds = []
        cm_shapes = []
        if self._context_mod_inputs:
            cm_shapes.append([in_shape[2], *in_shape[:2]])
            # Layer index 0 is reserved for input context-mod.
            cm_layer_inds.append(0)

        layer_out_shapes = self._compute_layer_out_sizes()
        cm_shapes.extend(layer_out_shapes)
        cm_layer_inds.extend(range(3, 3 * len(layer_out_shapes) + 1, 3))

        if self._no_last_layer_context_mod:
            cm_shapes = cm_shapes[:-1]
            cm_layer_inds = cm_layer_inds[:-1]

        if not self._context_mod_apply_pixel_wise:
            # One scalar gain and shift per feature map only.
            cm_shapes = [[shp[0], 1, 1] if len(shp) == 3 else shp
                         for shp in cm_shapes]

        self._add_context_mod_layers(cm_shapes, cm_layers=cm_layer_inds)

    ### Batchnorm layers ###
    if use_batch_norm:
        bn_sizes = [self._filter_sizes[0]]
        for grp, fmaps in enumerate(self._filter_sizes[1:]):
            if self._bottleneck_blocks:
                block_sizes = [fmaps, fmaps, 4 * fmaps]
            else:
                block_sizes = [fmaps, fmaps]
            for _ in range(self._num_blocks[grp]):
                bn_sizes.extend(block_sizes)

        bn_layers = list(range(2, 3 * len(bn_sizes) + 1, 3))

        # One additional batchnorm layer per 1x1 projection shortcut.
        if self._projection_shortcut:
            skip_bn_ind = 3 * (self._num_main_conv_layers + 1) + 2
            for grp in range(4):
                if self._group_has_1x1[grp]:
                    bn_sizes.append(
                        self._filter_sizes[grp + 1] * expansion)
                    bn_layers.append(skip_bn_ind)
                    skip_bn_ind += 3

        self._add_batchnorm_layers(bn_sizes, no_weights,
                                   bn_layers=bn_layers,
                                   distill_bn_stats=distill_bn_stats,
                                   bn_track_stats=bn_track_stats)

    ### Skip connection weights ###
    if self._projection_shortcut:
        skip_layer_ind = 3 * (self._num_main_conv_layers + 1) + 1
        skip_in = self._filter_sizes[0]

        for grp in range(4):
            if not self._group_has_1x1[grp]:
                continue

            skip_out = self._filter_sizes[grp + 1] * expansion
            skip_shape = [skip_out, skip_in, 1, 1]

            if no_weights:
                self._hyper_shapes_learned.append(skip_shape)
                self._hyper_shapes_learned_ref.append(
                    len(self.param_shapes))
            else:
                self._internal_params.append(nn.Parameter(
                    torch.Tensor(*skip_shape), requires_grad=True))
                self._layer_weight_tensors.append(
                    self._internal_params[-1])
                # Projections have no bias.
                self._layer_bias_vectors.append(None)
                init_params(self._layer_weight_tensors[-1])

            self._param_shapes.append(skip_shape)
            self._param_shapes_meta.append({
                'name': 'weight',
                'index': -1 if no_weights else
                    len(self._internal_params) - 1,
                'layer': skip_layer_ind
            })

            skip_layer_ind += 3
            skip_in = skip_out

    ### Conv layers and final linear layer ###
    layer_id = 1
    for stage in range(6):
        if stage == 0:
            # First conv layer.
            num_layers = 1
            prev_fs = self._in_shape[2]
            curr_fs = self._filter_sizes[0]
            kernel_size = self._init_kernel_size
        elif stage == 5:
            # Final fully-connected layer.
            num_layers = 1
            curr_fs = num_classes
            kernel_size = None
        else:
            # Group of residual blocks.
            num_layers = self._num_blocks[stage - 1] * layers_per_block
            curr_fs = self._filter_sizes[stage]
            kernel_size = (3, 3)  # Overridden for bottleneck 1x1 layers.

        for pos in range(num_layers):
            if stage == 5:
                layer_shapes = [[curr_fs, prev_fs]]
                if use_fc_bias:
                    layer_shapes.append([curr_fs])
                prev_fs = curr_fs
            else:
                if stage > 0 and self._bottleneck_blocks:
                    pos_in_block = pos % 3
                    if pos_in_block == 0:
                        fs, ks = curr_fs, (1, 1)
                    elif pos_in_block == 1:
                        fs, ks = curr_fs, kernel_size
                    else:
                        fs, ks = 4 * curr_fs, (1, 1)
                else:
                    fs, ks = curr_fs, kernel_size

                layer_shapes = [[fs, prev_fs, *ks]]
                if use_bias:
                    layer_shapes.append([fs])
                prev_fs = fs

            for shp in layer_shapes:
                if no_weights:
                    self._hyper_shapes_learned.append(shp)
                    self._hyper_shapes_learned_ref.append(
                        len(self.param_shapes))
                else:
                    self._internal_params.append(nn.Parameter(
                        torch.Tensor(*shp), requires_grad=True))
                    if len(shp) == 1:
                        self._layer_bias_vectors.append(
                            self._internal_params[-1])
                    else:
                        self._layer_weight_tensors.append(
                            self._internal_params[-1])

                self._param_shapes.append(shp)
                self._param_shapes_meta.append({
                    'name': 'weight' if len(shp) != 1 else 'bias',
                    'index': -1 if no_weights else
                        len(self._internal_params) - 1,
                    'layer': layer_id
                })
            # Weight and bias of one layer share the same layer index.
            layer_id += 3

            if not no_weights:
                init_params(self._layer_weight_tensors[-1],
                            self._layer_bias_vectors[-1]
                            if len(layer_shapes) == 2 else None)

    ### User information ###
    if verbose:
        if self._use_context_mod:
            cm_param_shapes = []
            for cm_layer in self.context_mod_layers:
                cm_param_shapes.extend(cm_layer.param_shapes)
            cm_num_params = \
                MainNetInterface.shapes_to_num_weights(cm_param_shapes)

        print('Creating a "%s" with %d weights'
              % (str(self), self.num_params)
              + (' (including %d weights associated with-' % cm_num_params
                 + 'context modulation)' if self._use_context_mod else '')
              + '.'
              + (' The network uses batchnorm.' if use_batch_norm else ''))

    self._is_properly_setup(check_has_bias=False)
def generate_replay_networks(config, data_handlers, device,
                             create_rp_hnet=True, only_train_replay=False):
    """Build the replay model and, optionally, a hypernetwork for it.

    The replay model is an encoder/decoder or discriminator/generator pair
    of MLPs. Several derived quantities are written back into ``config``
    (``input_dim``, ``out_dim``, ``num_weights_enc``, ``num_weights_dec``,
    ``num_weights_rp_net``, ``num_embeddings``,
    ``num_weights_rp_hyper_net`` and ``compression_ratio_rp``).

    .. note::
        This function also initializes the weights of the created
        networks: 1D parameters are set to zero, all others are Xavier
        uniform initialized; task and chunk embeddings of the
        hypernetwork are drawn from a normal distribution afterwards.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Only the
            ``in_shape`` of the first one is read here.
        device: Torch device.
        create_rp_hnet: Whether a hypernetwork for the decoder/generator
            should be constructed.
        only_train_replay: If ``True``, embeddings are also created for
            the last task (by default the last task is excluded).

    Returns:
        (tuple): Tuple containing:

        - **enc**: Encoder/discriminator network instance.
        - **dec**: Decoder/generator network instance.
        - **dec_hnet**: Hypernetwork instance for the decoder/generator,
          or ``None`` if no hypernetwork was constructed.
    """
    # A GAN discriminator has one output; otherwise the encoder emits two
    # values per latent dimension.
    n_out = 1 if config.replay_method == 'gan' else config.latent_dim * 2

    side = data_handlers[0].in_shape[0]
    pd = config.padding * 2

    if config.experiment == "splitMNIST":
        n_in = side * side
    else:  # permutedMNIST
        n_in = (side + pd) * (side + pd)
    config.input_dim = n_in

    if config.experiment == "splitMNIST":
        config.out_dim = 1 if config.single_class_replay else 2
    else:  # permutedMNIST
        config.out_dim = 10

    if config.infer_task_id:
        # Task inference network.
        config.out_dim = 1

    # Encoder / discriminator.
    print('For the replay encoder/discriminator: ')
    enc_arch = misc.str_to_ints(config.enc_fc_arch)
    enc = MLP(n_in=n_in, n_out=n_out, hidden_layers=enc_arch,
              activation_fn=misc.str_to_act(config.enc_net_act),
              dropout_rate=config.enc_dropout_rate,
              no_weights=False).to(device)
    print('Constructed MLP with shapes: ', enc.param_shapes)
    params_to_init = list(enc.parameters())

    # Decoder / generator; its input is the latent code, optionally
    # extended by the conditional input.
    print('For the replay decoder/generator: ')
    dec_arch = misc.str_to_ints(config.dec_fc_arch)
    dec_in = config.latent_dim
    if config.conditional_replay:
        dec_in += config.conditional_dim

    dec = MLP(n_in=dec_in, n_out=n_in, hidden_layers=dec_arch,
              activation_fn=misc.str_to_act(config.dec_net_act),
              use_bias=True, no_weights=config.rp_beta > 0,
              dropout_rate=config.dec_dropout_rate).to(device)
    print('Constructed MLP with shapes: ', dec.param_shapes)

    config.num_weights_enc = \
        MainNetInterface.shapes_to_num_weights(enc.param_shapes)
    config.num_weights_dec = \
        MainNetInterface.shapes_to_num_weights(dec.param_shapes)
    config.num_weights_rp_net = \
        config.num_weights_enc + config.num_weights_dec

    # By default no replay model is needed for the last task.
    subtr = 0 if only_train_replay else 1

    if config.num_tasks > 1:
        num_embeddings = config.num_tasks - subtr
    else:
        num_embeddings = 1

    if config.single_class_replay:
        if config.num_tasks > 1:
            num_embeddings = config.out_dim * (config.num_tasks - subtr)
        else:
            num_embeddings = config.out_dim * (config.num_tasks)

    config.num_embeddings = num_embeddings

    # Decoder hypernetwork.
    if create_rp_hnet:
        print('For the decoder/generator hypernetwork: ')
        d_hnet = sim_utils.get_hnet_model(config, num_embeddings, device,
                                          dec.hyper_shapes_learned,
                                          cprefix='rp_')
        params_to_init += list(d_hnet.parameters())

        config.num_weights_rp_hyper_net = sum(
            p.numel() for p in d_hnet.parameters() if p.requires_grad)
        config.compression_ratio_rp = \
            config.num_weights_rp_hyper_net / config.num_weights_dec
        print('Created replay hypernetwork with ratio: ',
              config.compression_ratio_rp)
        if config.compression_ratio_rp > 1:
            print('Note that the compression ratio is computed compared to ' +
                  'current target network,\nthis might not be directly ' +
                  'comparable with the number of parameters of methods we ' +
                  'compare against.')
    else:
        # NOTE(review): this rebinding is never read afterwards and
        # `config.num_embeddings` keeps the value assigned above.
        num_embeddings = config.num_tasks - subtr
        d_hnet = None
        params_to_init += list(dec.parameters())
        config.num_weights_rp_hyper_net = 0
        config.compression_ratio_rp = 0

    ### Initialize network weights.
    for param in params_to_init:
        if param.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(param, 0)
        else:
            torch.nn.init.xavier_uniform_(param)

    # Task embeddings are initialized differently.
    if create_rp_hnet:
        for temb in d_hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

    # Chunk embeddings only exist for some hypernetwork types.
    if hasattr(d_hnet, 'chunk_embeddings'):
        for emb in d_hnet.chunk_embeddings:
            torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    return enc, dec, d_hnet
def __init__(self, in_shape=(28, 28, 1), num_classes=10, verbose=True,
             arch='mnist_large', no_weights=False, init_weights=None,
             dropout_rate=-1,  # 0.5
             **kwargs):
    """Set up attributes, context-mod layers and layer weights of a LeNet.

    Args:
        in_shape: Shape of an input image. The first two entries must both
            be 28 for architectures whose name starts with ``mnist`` and 32
            for all other architectures. The third entry is used as the
            leading dimension of the input context-mod shape.
        num_classes: Number of output units. If different from 10, the
            first dimension of the last two parameter shapes of the chosen
            architecture is set to this value.
        verbose: Whether to print a summary of the created network.
        arch: Key into ``LeNet._ARCHITECTURES``.
        no_weights: If ``True``, no layer parameters are created; their
            shapes are recorded in ``_hyper_shapes_learned`` instead.
        init_weights: Optional list of tensors. Its length and the shape of
            every entry must match ``self.weights``; entries are assigned
            to the parameters' ``.data``.
        dropout_rate: ``-1`` disables dropout. Otherwise a value in
            ``[0, 1]`` that is used for all dropout layers.
        **kwargs: Context-modulation options, parsed by
            ``MainNetInterface._parse_context_mod_args``.

    Raises:
        ValueError: If ``in_shape`` does not fit the chosen architecture or
            if unknown keyword arguments are passed.
    """
    super(LeNet, self).__init__(num_classes, verbose)

    self._in_shape = in_shape
    assert arch in LeNet._ARCHITECTURES.keys()
    # Work on a per-instance copy of the shape lists. The shapes may be
    # modified below and must not leak into the class-level table that is
    # shared by all instances.
    self._chosen_arch = [list(s) for s in LeNet._ARCHITECTURES[arch]]
    if num_classes != 10:
        self._chosen_arch[-2][0] = num_classes
        self._chosen_arch[-1][0] = num_classes

    # Sanity check, given current implementation.
    if arch.startswith('mnist'):
        if not in_shape[0] == in_shape[1] == 28:
            raise ValueError('MNIST LeNet architectures expect input ' +
                             'images of size 28x28.')
    else:
        if not in_shape[0] == in_shape[1] == 32:
            raise ValueError('CIFAR LeNet architectures expect input ' +
                             'images of size 32x32.')

    ### Parse or set context-mod arguments ###
    rem_kwargs = MainNetInterface._parse_context_mod_args(kwargs)
    if len(rem_kwargs) > 0:
        raise ValueError('Keyword arguments %s unknown.' % str(rem_kwargs))
    # Since this is a conv-net, we may also want to add the following.
    if 'context_mod_apply_pixel_wise' not in kwargs.keys():
        kwargs['context_mod_apply_pixel_wise'] = False

    self._use_context_mod = kwargs['use_context_mod']
    self._context_mod_inputs = kwargs['context_mod_inputs']
    self._no_last_layer_context_mod = kwargs['no_last_layer_context_mod']
    self._context_mod_no_weights = kwargs['context_mod_no_weights']
    self._context_mod_post_activation = \
        kwargs['context_mod_post_activation']
    self._context_mod_gain_offset = kwargs['context_mod_gain_offset']
    self._context_mod_gain_softplus = kwargs['context_mod_gain_softplus']
    self._context_mod_apply_pixel_wise = \
        kwargs['context_mod_apply_pixel_wise']

    ### Setup class attributes ###
    # Initial values can only be assigned if at least some parameters are
    # maintained internally.
    assert (init_weights is None or
            (not no_weights or not self._context_mod_no_weights))
    self._no_weights = no_weights
    self._dropout_rate = dropout_rate

    self._has_bias = True
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer!
    self._mask_fc_out = True
    self._has_linear_out = True

    self._param_shapes = []
    self._param_shapes_meta = []
    self._internal_params = None if no_weights and \
        self._context_mod_no_weights else nn.ParameterList()
    self._hyper_shapes_learned = None \
        if not no_weights and not self._context_mod_no_weights else []
    self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
        is None else []
    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    if dropout_rate != -1:
        assert (dropout_rate >= 0. and dropout_rate <= 1.)
        # FIXME `nn.Dropout2d` zeroes out whole feature maps. Is that really
        # desired here?
        self._drop_conv1 = nn.Dropout2d(p=dropout_rate)
        self._drop_conv2 = nn.Dropout2d(p=dropout_rate)
        self._drop_fc1 = nn.Dropout(p=dropout_rate)

    ### Define and initialize context mod layers/weights ###
    self._context_mod_layers = nn.ModuleList() if self._use_context_mod \
        else None

    if self._use_context_mod:
        cm_layer_inds = []
        cm_shapes = []  # Output shape of all context-mod layers.
        if self._context_mod_inputs:
            cm_shapes.append([in_shape[2], *in_shape[:2]])
            # We reserve layer zero for input context-mod. Otherwise, there
            # is no layer zero.
            cm_layer_inds.append(0)

        layer_out_shapes = self._compute_layer_out_sizes()
        # Context-modulation is applied after the pooling layers.
        # So we delete the shapes of the conv-layer outputs and keep the
        # ones of the pooling layer outputs.
        del layer_out_shapes[2]
        del layer_out_shapes[0]
        cm_shapes.extend(layer_out_shapes)
        cm_layer_inds.extend(range(2, 2 * len(layer_out_shapes) + 1, 2))
        if self._no_last_layer_context_mod:
            cm_shapes = cm_shapes[:-1]
            cm_layer_inds = cm_layer_inds[:-1]

        if not self._context_mod_apply_pixel_wise:
            # Only scalar gain and shift per feature map!
            for i, s in enumerate(cm_shapes):
                if len(s) == 3:
                    cm_shapes[i] = [s[0], 1, 1]

        self._add_context_mod_layers(cm_shapes, cm_layers=cm_layer_inds)

    ### Define and add conv- and fc-layer weights.
    for i, s in enumerate(self._chosen_arch):
        if not no_weights:
            self._internal_params.append(
                nn.Parameter(torch.Tensor(*s), requires_grad=True))
            # One-dimensional shapes are bias vectors.
            if len(s) == 1:
                self._layer_bias_vectors.append(self._internal_params[-1])
            else:
                self._layer_weight_tensors.append(
                    self._internal_params[-1])
        else:
            self._hyper_shapes_learned.append(s)
            self._hyper_shapes_learned_ref.append(len(self.param_shapes))

        self._param_shapes.append(s)
        self._param_shapes_meta.append({
            'name': 'weight' if len(s) != 1 else 'bias',
            'index': -1 if no_weights else len(self._internal_params) - 1,
            # Weight and bias of the same layer share one (odd) layer index.
            'layer': 2 * (i // 2) + 1
        })

    ### Initialize weights.
    if init_weights is not None:
        assert len(init_weights) == len(self.weights)
        for i in range(len(init_weights)):
            assert np.all(
                np.equal(list(init_weights[i].shape),
                         self.weights[i].shape))
            self.weights[i].data = init_weights[i]
    else:
        for i in range(len(self._layer_weight_tensors)):
            init_params(self._layer_weight_tensors[i],
                        self._layer_bias_vectors[i])

    ### Print user info.
    if verbose:
        if self._use_context_mod:
            cm_param_shapes = []
            for cm_layer in self.context_mod_layers:
                cm_param_shapes.extend(cm_layer.param_shapes)
            cm_num_weights = \
                MainNetInterface.shapes_to_num_weights(cm_param_shapes)

        print('Creating a LeNet with %d weights' % self.num_params
              + (' (including %d weights associated with-' % cm_num_weights
                 + 'context modulation)' if self._use_context_mod else '')
              + '.'
              + (' The network uses dropout.' if dropout_rate != -1
                 else ''))

    self._is_properly_setup()
def __init__(self, in_shape=[32, 32, 3], num_classes=10, verbose=True, n=5,
             no_weights=False, init_weights=None, use_batch_norm=True,
             bn_track_stats=True, distill_bn_stats=False,
             use_context_mod=False, context_mod_inputs=False,
             no_last_layer_context_mod=False, context_mod_no_weights=False,
             context_mod_post_activation=False,
             context_mod_gain_offset=False,
             context_mod_apply_pixel_wise=False):
    """Set up context-mod layers, batch-norm layers and weights of a ResNet.

    Internally maintained parameters are stored in ``self._weights`` in the
    order: context-mod weights, batch-norm weights, then layer weights and
    biases interleaved per layer. ``init_weights`` is consumed in the same
    order.

    Args:
        in_shape: Shape of an input image; the third entry is used as the
            number of input feature maps of the first layer. The default
            list is never modified by this method.
        num_classes: Number of output units of the final dense layer.
        verbose: Whether to print the number of layers and weights.
        n: Each of the three groups following the first layer consists of
            ``2 * n`` layers; the network has ``6 * n + 2`` layers.
        no_weights: If ``True``, no layer (and no affine batch-norm)
            parameters are created; only their shapes are recorded.
        init_weights: Optional list of tensors used to initialise all
            internally maintained parameters.
        use_batch_norm: Whether batch-norm layers are created.
        bn_track_stats: Passed to the batch-norm layers as
            ``track_running_stats``.
        distill_bn_stats: If ``True`` (and batch norm is used), the shapes
            of the batch-norm statistics are collected in
            ``_hyper_shapes_distilled``. Requires ``bn_track_stats``.
        use_context_mod: Whether context-mod layers are created.
        context_mod_inputs: Whether a context-mod layer is added for the
            network input.
        no_last_layer_context_mod: If ``True``, the shape of the last layer
            output gets no context-mod layer.
        context_mod_no_weights: If ``True``, context-mod layers do not
            maintain their own weights; their shapes are recorded instead.
        context_mod_post_activation: Stored as attribute only.
        context_mod_gain_offset: Passed to the context-mod layers as
            ``apply_gain_offset``.
        context_mod_apply_pixel_wise: If ``False``, 3D context-mod shapes
            are reduced to one gain/shift per feature map.
    """
    super(ResNet, self).__init__(num_classes, verbose)

    self._in_shape = in_shape
    self._n = n

    # Initial values can only be assigned if at least some parameters are
    # maintained internally.
    assert (init_weights is None or
            (not no_weights or not context_mod_no_weights))
    self._no_weights = no_weights

    assert (not use_batch_norm or (not distill_bn_stats or bn_track_stats))

    self._use_batch_norm = use_batch_norm
    self._bn_track_stats = bn_track_stats
    self._distill_bn_stats = distill_bn_stats and use_batch_norm

    self._use_context_mod = use_context_mod
    self._context_mod_inputs = context_mod_inputs
    self._no_last_layer_context_mod = no_last_layer_context_mod
    self._context_mod_no_weights = context_mod_no_weights
    self._context_mod_post_activation = context_mod_post_activation
    self._context_mod_gain_offset = context_mod_gain_offset
    self._context_mod_apply_pixel_wise = context_mod_apply_pixel_wise

    self._kernel_size = [3, 3]
    self._filter_sizes = [16, 16, 32, 64]

    self._has_bias = True
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer.
    self._mask_fc_out = True
    # We don't use any output non-linearity.
    self._has_linear_out = True

    self._param_shapes = []
    self._weights = None if no_weights and context_mod_no_weights \
        else nn.ParameterList()
    self._hyper_shapes_learned = None \
        if not no_weights and not context_mod_no_weights else []

    #################################################
    ### Define and initialize context mod weights ###
    #################################################
    self._context_mod_layers = nn.ModuleList() if use_context_mod else None
    self._context_mod_shapes = [] if use_context_mod else None

    if use_context_mod:
        # Number of entries of `init_weights` consumed by context-mod.
        cm_ind = 0
        cm_shapes = []  # Output shape of all layers.
        if context_mod_inputs:
            cm_shapes.append([in_shape[2], *in_shape[:2]])

        layer_out_shapes = self._compute_layer_out_sizes()
        if no_last_layer_context_mod:
            cm_shapes.extend(layer_out_shapes[:-1])
        else:
            cm_shapes.extend(layer_out_shapes)

        if not context_mod_apply_pixel_wise:
            # Only scalar gain and shift per feature map!
            for i, s in enumerate(cm_shapes):
                if len(s) == 3:
                    cm_shapes[i] = [s[0], 1, 1]

        for s in cm_shapes:
            cmod_layer = ContextModLayer(
                s, no_weights=context_mod_no_weights,
                apply_gain_offset=context_mod_gain_offset)
            self._context_mod_layers.append(cmod_layer)

            self.param_shapes.extend(cmod_layer.param_shapes)
            self._context_mod_shapes.extend(cmod_layer.param_shapes)
            if context_mod_no_weights:
                self._hyper_shapes_learned.extend(cmod_layer.param_shapes)
            else:
                self._weights.extend(cmod_layer.weights)

            # FIXME ugly code. Move initialization somewhere else.
            if not context_mod_no_weights and init_weights is not None:
                assert (len(cmod_layer.weights) == 2)
                for ii in range(2):
                    # Compare against the layer's own weights (`cm_ind` is
                    # only the running index into `init_weights`).
                    assert (np.all(np.equal(
                        list(init_weights[cm_ind].shape),
                        list(cmod_layer.weights[ii].shape))))
                    cmod_layer.weights[ii].data = init_weights[cm_ind]
                    cm_ind += 1

        if init_weights is not None:
            init_weights = init_weights[cm_ind:]

    ###########################
    ### Print infos to user ###
    ###########################
    # Compute the total number of weights in this network and display
    # them to the user.
    # Note, this complicated calculation is not necessary as we can simply
    # count the number of weights afterwards. But it's an additional sanity
    # check for us.
    fs = self._filter_sizes
    num_weights = np.prod(self._kernel_size) * \
        (in_shape[2] * fs[0] + np.sum([fs[i] * fs[i + 1] +
            (2 * n - 1) * fs[i + 1] ** 2 for i in range(3)])) + \
        (fs[0] + 2 * n * np.sum([fs[i] for i in range(1, 4)])) + \
        (fs[-1] * num_classes + num_classes)

    cm_num_weights = MainNetInterface.shapes_to_num_weights(
        self._context_mod_shapes) if use_context_mod else 0
    num_weights += cm_num_weights

    if use_batch_norm:
        # The gamma and beta parameters of a batch norm layer are
        # learned as well.
        num_weights += 2 * (fs[0] +
                            2 * n * np.sum([fs[i] for i in range(1, 4)]))

    if verbose:
        print('A ResNet with %d layers and %d weights is created'
              % (6 * n + 2, num_weights)
              + (' (including %d context-mod weights).' % cm_num_weights
                 if cm_num_weights > 0 else '.'))

    ################################################
    ### Define and initialize batch norm weights ###
    ################################################
    self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

    if use_batch_norm:
        if distill_bn_stats:
            self._hyper_shapes_distilled = []

        # Number of entries of `init_weights` consumed by batch norm.
        bn_ind = 0
        for i, s in enumerate(self._filter_sizes):
            # One batch-norm layer for the first layer, `2n` per group.
            num = 1 if i == 0 else 2 * n

            for _ in range(num):
                bn_layer = BatchNormLayer(
                    s, affine=not no_weights,
                    track_running_stats=bn_track_stats)
                self._batchnorm_layers.append(bn_layer)

                if distill_bn_stats:
                    self._hyper_shapes_distilled.extend(
                        [list(p.shape) for p in bn_layer.get_stats(0)])

                if not no_weights and init_weights is not None:
                    assert (len(bn_layer.weights) == 2)
                    for ii in range(2):
                        assert (np.all(np.equal(
                            list(init_weights[bn_ind].shape),
                            list(bn_layer.weights[ii].shape))))
                        bn_layer.weights[ii].data = init_weights[bn_ind]
                        bn_ind += 1

        if init_weights is not None:
            init_weights = init_weights[bn_ind:]

    # Note, method `_compute_hyper_shapes` doesn't take context-mod into
    # consideration.
    self._param_shapes.extend(self._compute_hyper_shapes(no_weights=True))
    assert (num_weights ==
            MainNetInterface.shapes_to_num_weights(self._param_shapes))

    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    if no_weights:
        if self._hyper_shapes_learned is None:
            self._hyper_shapes_learned = self._compute_hyper_shapes()
        else:
            # Context-mod weights are already included.
            self._hyper_shapes_learned.extend(self._compute_hyper_shapes())

        self._is_properly_setup()
        return

    # At this point `_weights` only holds internally maintained context-mod
    # weights (if any). Remember how many, since they precede all others.
    num_cm_weights = len(self._weights)

    if use_batch_norm:
        for bn_layer in self._batchnorm_layers:
            self._weights.extend(bn_layer.weights)

    ############################################
    ### Define and initialize layer weights ###
    ###########################################
    ### Does not include context-mod or batchnorm weights.
    # First layer.
    self._layer_weight_tensors.append(
        nn.Parameter(torch.Tensor(self._filter_sizes[0], self._in_shape[2],
                                  *self._kernel_size),
                     requires_grad=True))
    self._layer_bias_vectors.append(
        nn.Parameter(torch.Tensor(self._filter_sizes[0]),
                     requires_grad=True))

    # Each block consists of 2n layers.
    for i in range(1, len(self._filter_sizes)):
        in_filters = self._filter_sizes[i - 1]
        out_filters = self._filter_sizes[i]

        for _ in range(2 * n):
            self._layer_weight_tensors.append(
                nn.Parameter(torch.Tensor(out_filters, in_filters,
                                          *self._kernel_size),
                             requires_grad=True))
            self._layer_bias_vectors.append(
                nn.Parameter(torch.Tensor(out_filters),
                             requires_grad=True))
            # Note, that the first layer in this block has potentially a
            # different number of input filters.
            in_filters = out_filters

    # After the average pooling, there is one more dense layer.
    self._layer_weight_tensors.append(
        nn.Parameter(torch.Tensor(num_classes, self._filter_sizes[-1]),
                     requires_grad=True))
    self._layer_bias_vectors.append(
        nn.Parameter(torch.Tensor(num_classes), requires_grad=True))

    # We add the weights interleaved, such that there are always consecutive
    # weight tensor and bias vector per layer. This fulfils the requirements
    # of attribute `mask_fc_out`.
    for i in range(len(self._layer_weight_tensors)):
        self._weights.append(self._layer_weight_tensors[i])
        self._weights.append(self._layer_bias_vectors[i])

    ### Initialize weights.
    if init_weights is not None:
        num_layers = 6 * n + 2
        assert (len(init_weights) == 2 * num_layers)
        # Layer weights are preceded by the internally maintained
        # context-mod weights and by the batch-norm weights.
        offset = num_cm_weights
        if use_batch_norm:
            offset += 2 * (6 * n + 1)
        assert (len(self._weights) == offset + 2 * num_layers)
        for i in range(len(init_weights)):
            j = offset + i
            assert (np.all(
                np.equal(list(init_weights[i].shape),
                         list(self._weights[j].shape))))
            self._weights[j].data = init_weights[i]
    else:
        for i in range(len(self._layer_weight_tensors)):
            init_params(self._layer_weight_tensors[i],
                        self._layer_bias_vectors[i])

    self._is_properly_setup()
def __init__(self, in_shape=(32, 32, 3), num_classes=10, n=4, k=10,
             num_feature_maps=(16, 16, 32, 64), use_bias=True,
             use_fc_bias=None, no_weights=False, use_batch_norm=True,
             bn_track_stats=True, distill_bn_stats=False, dropout_rate=-1,
             chw_input_format=False, verbose=True, **kwargs):
    """Set up all layers and parameter bookkeeping of a wide ResNet.

    Parameters are registered in the order: context-mod layers, batch-norm
    layers, 1x1 skip-connection weights, then the first conv layer, three
    groups of ``2 * n`` conv layers and the final fully-connected layer.

    Args:
        in_shape: Shape of an input image; the third entry is the number
            of input feature maps of the first conv layer.
        num_classes: Number of output units of the fully-connected layer.
        n: Every group of residual blocks consists of ``2 * n`` layers.
        k: Factor applied to the last three entries of
            ``num_feature_maps``.
        num_feature_maps: Exactly 4 integers: feature maps of the first
            layer and of the three groups (before multiplying with ``k``).
        use_bias: Whether conv layers get a bias vector.
        use_fc_bias: Whether the fully-connected layer gets a bias vector.
            ``None`` means: same as ``use_bias``.
        no_weights: If ``True``, no layer parameters are created; only
            their shapes are recorded.
        use_batch_norm: Whether batch-norm layers are added.
        bn_track_stats: Forwarded to ``_add_batchnorm_layers``.
        distill_bn_stats: Forwarded to ``_add_batchnorm_layers``. Requires
            ``bn_track_stats`` if batch norm is used.
        dropout_rate: ``-1`` disables dropout; otherwise a value in
            ``[0, 1]``.
        chw_input_format: Stored as attribute only.
        verbose: Whether to print a summary of the created network.
        **kwargs: Context-modulation options, parsed by
            ``MainNetInterface._parse_context_mod_args``.

    Raises:
        ValueError: If unknown keyword arguments are passed or if
            ``num_feature_maps`` does not have 4 entries.
    """
    super(WRN, self).__init__(num_classes, verbose)

    ### Context-mod options ###
    unknown_kwargs = MainNetInterface._parse_context_mod_args(kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Keyword arguments %s unknown.'
                         % str(unknown_kwargs))
    # Conv-net specific option that may not have been set by the parser.
    kwargs.setdefault('context_mod_apply_pixel_wise', False)

    for option in ('use_context_mod',
                   'context_mod_inputs',
                   'no_last_layer_context_mod',
                   'context_mod_no_weights',
                   'context_mod_post_activation',
                   'context_mod_gain_offset',
                   'context_mod_gain_softplus',
                   'context_mod_apply_pixel_wise'):
        setattr(self, '_' + option, kwargs[option])

    ### Remaining constructor arguments ###
    self._in_shape = in_shape
    self._n = n
    self._k = k
    if use_fc_bias is None:
        # Mirror the conv-layer setting; see `_has_bias` below.
        use_fc_bias = use_bias
    self._use_bias = use_bias
    self._use_fc_bias = use_fc_bias
    self._no_weights = no_weights
    # Distilling batch-norm statistics requires that they are tracked.
    assert not (use_batch_norm and distill_bn_stats and not bn_track_stats)
    self._use_batch_norm = use_batch_norm
    self._bn_track_stats = bn_track_stats
    self._distill_bn_stats = distill_bn_stats and use_batch_norm
    self._dropout_rate = dropout_rate
    self._chw_input_format = chw_input_format

    # The original authors found that the best configuration uses this
    # kernel in all convolutional layers.
    self._kernel_size = (3, 3)
    if len(num_feature_maps) != 4:
        raise ValueError('Option "num_feature_maps" must be a list of 4 ' +
                         'integers.')
    if k != 1:
        # Only the three groups are widened, not the first layer.
        self._filter_sizes = [num_feature_maps[0]] + \
            [k * num_feature_maps[g] for g in range(1, 4)]
    else:
        self._filter_sizes = list(num_feature_maps)
    # Strides used in the first layer of each convolutional group.
    self._strides = (1, 1, 2, 2)

    ### Required class attributes ###
    # Note, the getter of attribute `has_bias` is overwritten, as a single
    # flag is not meaningful if `use_bias` and `use_fc_bias` differ.
    self._has_bias = use_bias if use_bias == use_fc_bias else False
    self._has_fc_out = True
    # The last 2 entries of `weights` have to be the weight matrix and bias
    # vector of the last layer!
    self._mask_fc_out = True
    self._has_linear_out = True

    nothing_internal = no_weights and self._context_mod_no_weights
    everything_internal = not no_weights and \
        not self._context_mod_no_weights
    self._param_shapes = []
    self._param_shapes_meta = []
    self._internal_params = None if nothing_internal \
        else nn.ParameterList()
    self._hyper_shapes_learned = None if everything_internal else []
    self._hyper_shapes_learned_ref = None if everything_internal else []
    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    if dropout_rate != -1:
        assert 0. <= dropout_rate <= 1.
        self._dropout = nn.Dropout(p=dropout_rate)

    #################################
    ### Create context mod layers ###
    #################################
    if self._use_context_mod:
        self._context_mod_layers = nn.ModuleList()

        shapes = []  # Output shapes of all modulated layers.
        layer_inds = []
        if self._context_mod_inputs:
            shapes.append([in_shape[2], *in_shape[:2]])
            # Layer zero is reserved for input context-mod. Otherwise,
            # there is no layer zero.
            layer_inds.append(0)

        out_shapes = self._compute_layer_out_sizes()
        shapes.extend(out_shapes)
        # All layer indices `l` with `l mod 3 == 0` are context-mod layers.
        layer_inds.extend(range(3, 3 * len(out_shapes) + 1, 3))
        if self._no_last_layer_context_mod:
            shapes = shapes[:-1]
            layer_inds = layer_inds[:-1]

        if not self._context_mod_apply_pixel_wise:
            # Only scalar gain and shift per feature map!
            shapes = [[s[0], 1, 1] if len(s) == 3 else s for s in shapes]

        self._add_context_mod_layers(shapes, cm_layers=layer_inds)
    else:
        self._context_mod_layers = None

    ###############################
    ### Create batchnorm layers ###
    ###############################
    if use_batch_norm:
        # One batchnorm layer after the first conv layer and one after each
        # of the `2n` layers of every group.
        first_size, *group_sizes = self._filter_sizes
        bn_sizes = [first_size] + \
            [size for size in group_sizes for _ in range(2 * n)]

        # All layer indices `l` with `l mod 3 == 2` are batchnorm layers.
        self._add_batchnorm_layers(
            bn_sizes, no_weights,
            bn_layers=list(range(2, 3 * len(bn_sizes) + 1, 3)),
            distill_bn_stats=distill_bn_stats,
            bn_track_stats=bn_track_stats)

    def _register(shape, layer_ind):
        """Create (or only record) one parameter tensor of given shape."""
        if no_weights:
            self._hyper_shapes_learned.append(shape)
            self._hyper_shapes_learned_ref.append(len(self.param_shapes))
        else:
            self._internal_params.append(
                nn.Parameter(torch.Tensor(*shape), requires_grad=True))
            new_param = self._internal_params[-1]
            # One-dimensional shapes are bias vectors.
            if len(shape) == 1:
                self._layer_bias_vectors.append(new_param)
            else:
                self._layer_weight_tensors.append(new_param)

        self._param_shapes.append(shape)
        self._param_shapes_meta.append({
            'name': 'weight' if len(shape) != 1 else 'bias',
            'index': -1 if no_weights else
                len(self._internal_params) - 1,
            'layer': layer_ind
        })

    ######################################
    ### Create skip connection weights ###
    ######################################
    # A group gets a bias-free 1x1 conv layer in its skip connection if the
    # number of feature maps changes or if a stride greater than 1 is
    # applied. Hence, at most 3 such layers are added.
    sizes = self._filter_sizes
    self._group_has_1x1 = [
        bool(sizes[g - 1] != sizes[g] or self._strides[g] != 1)
        for g in range(1, 4)
    ]
    skip_shapes = [[sizes[g], sizes[g - 1], 1, 1]
                   for g in range(1, 4) if self._group_has_1x1[g - 1]]

    for shape in skip_shapes:
        _register(shape, -1)
        if not no_weights:
            init_params(self._layer_weight_tensors[-1])

    ############################################################
    ### Create convolutional layers and final linear weights ###
    ############################################################
    # (number of layers, number of output feature maps) per stage: first
    # conv layer, three groups of residual blocks, fully-connected layer.
    stages = [(1, self._filter_sizes[0])]
    stages.extend((2 * n, self._filter_sizes[g]) for g in range(1, 4))
    stages.append((1, num_classes))
    fc_stage = len(stages) - 1

    # Convolutional layers will get IDs `l` such that `l mod 3 == 1`.
    layer_id = 1
    prev_fs = self._in_shape[2]
    for stage, (num_layers, curr_fs) in enumerate(stages):
        for _ in range(num_layers):
            if stage == fc_stage:
                layer_shapes = [[curr_fs, prev_fs]]
                with_bias = use_fc_bias
            else:
                layer_shapes = [[curr_fs, prev_fs, *self._kernel_size]]
                with_bias = use_bias
            if with_bias:
                layer_shapes.append([curr_fs])

            for shape in layer_shapes:
                _register(shape, layer_id)

            prev_fs = curr_fs
            layer_id += 3

            # Initialize the parameters that have just been created.
            if not no_weights:
                init_params(self._layer_weight_tensors[-1],
                            self._layer_bias_vectors[-1]
                            if len(layer_shapes) == 2 else None)

    ###########################
    ### Print infos to user ###
    ###########################
    if verbose:
        if self._use_context_mod:
            cm_param_shapes = []
            for cm_layer in self.context_mod_layers:
                cm_param_shapes.extend(cm_layer.param_shapes)
            cm_num_params = \
                MainNetInterface.shapes_to_num_weights(cm_param_shapes)

        print('Creating a WideResnet "%s" with %d weights'
              % (str(self), self.num_params)
              + (' (including %d weights associated with-' % cm_num_params
                 + 'context modulation)' if self._use_context_mod else '')
              + '.'
              + (' The network uses batchnorm.' if use_batch_norm else '')
              + (' The network uses dropout.' if dropout_rate != -1
                 else ''))

    self._is_properly_setup(check_has_bias=False)
def __init__(self, weight_shapes, activation_fn=torch.nn.ReLU(),
             use_bias=True, no_weights=False, init_weights=None,
             dropout_rate=-1, out_fn=None, verbose=True,
             use_spectral_norm=False, use_batch_norm=False):
    """Initialize the network.

    Args:
        weight_shapes: A list of list of integers, denoting the shape of
            each parameter tensor in this network. Note, this parameter
            only has an effect on the construction of this network, if
            "no_weights" is False. Otherwise, it is just used to check the
            shapes of the input to the network in the forward method.
        activation_fn: The nonlinearity used in hidden layers. If None, no
            nonlinearity will be applied.
        use_bias: Whether layers may have bias terms. If True, the shapes
            in "weight_shapes" are expected to alternate between weight
            tensor and bias vector.
        no_weights: If set to True, no trainable parameters will be
            constructed, i.e., weights are assumed to be produced ad-hoc
            by a hypernetwork and passed to the forward function.
        init_weights (optional): This option is for convenience reasons.
            The option expects a list of parameter values that are used to
            initialize the network weights. As such, it provides a
            convenient way of initializing a network with a weight draw
            produced by the hypernetwork.
        dropout_rate: If -1, no dropout will be applied. Otherwise a number
            between 0 and 1 is expected, denoting the dropout rate of
            hidden layers.
        out_fn (optional): If provided, this function will be applied to
            the output neurons of the network. Note, this changes the
            output of the forward method.
        verbose: Whether to print the number of weights in the network.
        use_spectral_norm: Use spectral normalization for training. Not
            implemented for this network.
        use_batch_norm: Whether batch normalization should be used. Not
            implemented for this network.

    Raises:
        NotImplementedError: If "use_spectral_norm" or "use_batch_norm" is
            set.
    """
    # FIXME find a way using super to handle multiple inheritance.
    #super(MainNetwork, self).__init__()
    nn.Module.__init__(self)
    MainNetInterface.__init__(self)

    warn('Please use class "mnets.mlp.MLP" instead.', DeprecationWarning)

    if use_spectral_norm:
        raise NotImplementedError('Spectral normalization not yet ' +
                                  'implemented for this network.')
    if use_batch_norm:
        raise NotImplementedError('Batch normalization not yet ' +
                                  'implemented for this network.')

    assert (len(weight_shapes) > 0)
    self._all_shapes = weight_shapes
    self._has_bias = use_bias
    self._a_fun = activation_fn
    assert (init_weights is None or no_weights is False)
    self._no_weights = no_weights
    self._dropout_rate = dropout_rate
    self._out_fn = out_fn

    self._has_fc_out = True

    # Spectral norm is rejected above, so this is always the identity.
    self._spec_norm = lambda x: x

    if verbose:
        print('Creating an MLP with %d weights'
              % (MainNetInterface.shapes_to_num_weights(self._all_shapes))
              + (', that uses dropout.' if dropout_rate != -1 else '.'))

    if dropout_rate != -1:
        assert (dropout_rate >= 0. and dropout_rate <= 1.)
        self._dropout = nn.Dropout(p=dropout_rate)

    self._weights = None

    if no_weights:
        # All parameters have to be provided externally.
        self._hyper_shapes = self._all_shapes
        self._is_properly_setup()
        return

    ### Define and initialize network weights.
    # If biases are used, each even entry (0, 2, ...) of this list contains
    # a weight tensor and the subsequent odd entry the corresponding bias
    # vector.
    self._weights = nn.ParameterList()

    for dims in self._all_shapes:
        self._weights.append(nn.Parameter(torch.Tensor(*dims),
                                          requires_grad=True))

    if init_weights is not None:
        assert (len(init_weights) == len(self._all_shapes))
        for i in range(len(init_weights)):
            assert (np.all(np.equal(list(init_weights[i].shape),
                                    list(self._weights[i].shape))))
            self._weights[i].data = init_weights[i]
    else:
        # Step over (weight, bias) pairs if biases are used.
        for i in range(0, len(self._weights), 2 if use_bias else 1):
            if use_bias:
                init_params(self._weights[i], self._weights[i + 1])
            else:
                init_params(self._weights[i])

    self._is_properly_setup()
def __init__(self, n_in, n_out=1, inp_chunk_dim=100, out_chunk_dim=10,
             cemb_size=8, cemb_init_std=1., red_layers=(10, 10),
             net_layers=(10, 10), activation_fn=torch.nn.ReLU(),
             use_bias=True, dynamic_biases=None, no_weights=False,
             init_weights=None, dropout_rate=-1, use_spectral_norm=False,
             use_batch_norm=False, bn_track_stats=True,
             distill_bn_stats=False, verbose=True):
    """Build the chunk embeddings, the `reducer` MLP and the `network` MLP.

    The input dimensionality ``n_in`` is divided into
    ``ceil(n_in / inp_chunk_dim)`` chunks. One embedding of size
    ``cemb_size`` is maintained per chunk. A `reducer` MLP with
    ``out_chunk_dim`` outputs and a `network` MLP with
    ``out_chunk_dim * num_chunks`` inputs and ``n_out`` outputs are
    created, and their parameter shapes, parameters and layers are
    registered on this object (chunk embeddings first, then reducer,
    then network).

    Args:
        n_in: Input dimensionality that is split into chunks.
        n_out: Number of outputs of the `network` MLP.
        inp_chunk_dim: Size of a single input chunk. If ``n_in`` is not
            a multiple of it, the missing amount is stored in
            ``self._pad`` (``-1`` otherwise).
        out_chunk_dim: Number of outputs of the `reducer` MLP.
        cemb_size: Dimensionality of each chunk embedding.
        cemb_init_std: Standard deviation of the normal distribution
            used to initialize the chunk embeddings.
        red_layers: Hidden layer sizes of the `reducer` MLP.
        net_layers: Hidden layer sizes of the `network` MLP.
        activation_fn: Forwarded to both MLPs.
        use_bias: Forwarded to both MLPs.
        dynamic_biases: Optional indices (in
            ``[0, len(red_layers)]``) of reducer layers that should
            receive a dynamic bias. Not implemented yet; any value
            other than ``None`` ends in a ``NotImplementedError``.
        no_weights: If ``True``, no parameters are created. Instead,
            all shapes are collected in ``self._hyper_shapes_learned``.
        init_weights: Optional list of tensors, one per entry of
            ``self._weights`` and with matching shapes, that is taken
            over as initialization. Requires ``no_weights=False``.
        dropout_rate: Forwarded to both MLPs.
        use_spectral_norm: Forwarded to both MLPs.
        use_batch_norm: Forwarded to both MLPs. If ``True``, their
            batchnorm layers are also registered on this object.
        bn_track_stats: Forwarded to both MLPs.
        distill_bn_stats: Forwarded to both MLPs.
        verbose: Forwarded to both MLPs. If ``True``, a summary of the
            constructed network is printed.

    Raises:
        NotImplementedError: If ``dynamic_biases`` is not ``None``.
    """
    # FIXME find a way using super to handle multiple inheritance.
    nn.Module.__init__(self)
    MainNetInterface.__init__(self)

    # An initialization can only be taken over if this object maintains
    # its own weights. Fail early instead of at the end of the setup.
    assert init_weights is None or not no_weights

    self._n_in = n_in
    self._n_out = n_out
    self._inp_chunk_dim = inp_chunk_dim
    self._out_chunk_dim = out_chunk_dim
    self._cemb_size = cemb_size
    self._a_fun = activation_fn
    self._no_weights = no_weights

    self._has_bias = use_bias
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer.
    self._mask_fc_out = True
    self._has_linear_out = True # Ensure that `out_fn` is `None`!

    self._param_shapes = []
    #self._param_shapes_meta = [] # TODO implement!
    self._weights = None if no_weights else nn.ParameterList()
    self._hyper_shapes_learned = None if not no_weights else []
    #self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
    #    is None else [] # TODO implement.

    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    self._context_mod_layers = None
    self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

    #################################
    ### Generate Chunk Embeddings ###
    #################################
    self._num_cembs = int(np.ceil(n_in / inp_chunk_dim))
    last_chunk_size = n_in % inp_chunk_dim
    if last_chunk_size != 0:
        # Amount by which the last chunk falls short of a full chunk.
        self._pad = inp_chunk_dim - last_chunk_size
    else:
        self._pad = -1

    cemb_shape = [self._num_cembs, cemb_size]
    self._param_shapes.append(cemb_shape)
    if no_weights:
        self._cembs = None
        self._hyper_shapes_learned.append(cemb_shape)
    else:
        self._cembs = nn.Parameter(data=torch.Tensor(*cemb_shape),
                                   requires_grad=True)
        nn.init.normal_(self._cembs, mean=0., std=cemb_init_std)

        self._weights.append(self._cembs)

    ############################
    ### Setup Dynamic Biases ###
    ############################
    self._has_dyn_bias = None
    if dynamic_biases is not None:
        assert np.all(np.array(dynamic_biases) >= 0) and \
               np.all(np.array(dynamic_biases) < len(red_layers) + 1)
        dynamic_biases = np.sort(np.unique(dynamic_biases))

        # For each layer in the `reducer`, where we want to apply a dynamic
        # bias, we have to create a weight matrix for a corresponding
        # linear layer. Layers without dynamic bias get a `None` entry.
        # A plain list is used, since an `nn.ModuleList` cannot hold
        # `nn.Parameter` objects. Existing parameters are registered via
        # `self._weights`.
        self._dyn_bias_weights = []
        self._has_dyn_bias = []

        for i in range(len(red_layers) + 1):
            if i not in dynamic_biases:
                self._has_dyn_bias.append(False)
                self._dyn_bias_weights.append(None)
                continue

            self._has_dyn_bias.append(True)

            # The last index refers to the output layer of the reducer.
            trgt_dim = out_chunk_dim
            if i < len(red_layers):
                trgt_dim = red_layers[i]
            trgt_shape = [trgt_dim, cemb_size]

            self._param_shapes.append(trgt_shape)
            if no_weights:
                self._dyn_bias_weights.append(None)
                self._hyper_shapes_learned.append(trgt_shape)
            else:
                dyn_bias_weight = nn.Parameter(torch.Tensor(*trgt_shape),
                                               requires_grad=True)
                self._dyn_bias_weights.append(dyn_bias_weight)
                self._weights.append(dyn_bias_weight)

                init_params(dyn_bias_weight)

                self._layer_weight_tensors.append(dyn_bias_weight)
                self._layer_bias_vectors.append(None)

    def _overtake_mlp_attributes(mlp):
        """Register shapes, parameters and layers of `mlp` on `self`."""
        for s in mlp.param_shapes:
            self._param_shapes.append(s)
        if no_weights:
            for s in mlp._hyper_shapes_learned:
                self._hyper_shapes_learned.append(s)
        else:
            for p in mlp._weights:
                self._weights.append(p)

        for p in mlp._layer_weight_tensors:
            self._layer_weight_tensors.append(p)
        for p in mlp._layer_bias_vectors:
            self._layer_bias_vectors.append(p)
        if use_batch_norm:
            for p in mlp._batchnorm_layers:
                self._batchnorm_layers.append(p)

    ################################
    ### Create `Reducer` Network ###
    ################################
    # Without dynamic biases, the reducer input is widened by `cemb_size`
    # (presumably the chunk embedding is fed in alongside each chunk in
    # `forward` - TODO confirm).
    red_inp_dim = inp_chunk_dim + \
        (cemb_size if dynamic_biases is None else 0)
    self._reducer = MLP(n_in=red_inp_dim, n_out=out_chunk_dim,
        hidden_layers=red_layers, activation_fn=activation_fn,
        use_bias=use_bias, no_weights=no_weights, init_weights=None,
        dropout_rate=dropout_rate, use_spectral_norm=use_spectral_norm,
        use_batch_norm=use_batch_norm, bn_track_stats=bn_track_stats,
        distill_bn_stats=distill_bn_stats,
        # We use context modulation to realize dynamic biases, since they
        # allow a different modulation per sample in the input mini-batch.
        # Hence, we can process several chunks in parallel with the reducer
        # network.
        use_context_mod=dynamic_biases is not None,
        context_mod_inputs=False, no_last_layer_context_mod=False,
        context_mod_no_weights=True, context_mod_post_activation=False,
        context_mod_gain_offset=False, context_mod_gain_softplus=False,
        out_fn=None, verbose=verbose)

    if dynamic_biases is not None:
        # FIXME We have to extract the param shapes from
        # `self._reducer.param_shapes`, as well as from
        # `self._reducer._hyper_shapes_learned` that belong to context-mod
        # layers. We may not add them to our own `param_shapes` attribute,
        # as these are not parameters (due to our misuse of the context-mod
        # layers).
        # Note, in the `forward` method, we need to supply context-mod
        # weights for all reducer networks, independent on whether they have
        # a dynamic bias or not. We can do so, by providing constant ones
        # for all gains and constant zero-shift for all layers without
        # dynamic biases (note, we need to ensure the correct batch dim!).
        raise NotImplementedError(
            'Dynamic biases are not yet implemented!')

    assert self._reducer._context_mod_layers is None

    ### Overtake all attributes from the underlying MLP.
    _overtake_mlp_attributes(self._reducer)

    if self._reducer._hyper_shapes_distilled is not None:
        self._hyper_shapes_distilled = []
        for s in self._reducer._hyper_shapes_distilled:
            self._hyper_shapes_distilled.append(s)

    ###############################
    ### Create Actual `Network` ###
    ###############################
    # The network receives the reducer outputs of all chunks at once.
    net_inp_dim = out_chunk_dim * self._num_cembs
    self._network = MLP(n_in=net_inp_dim, n_out=n_out,
        hidden_layers=net_layers, activation_fn=activation_fn,
        use_bias=use_bias, no_weights=no_weights, init_weights=None,
        dropout_rate=dropout_rate, use_spectral_norm=use_spectral_norm,
        use_batch_norm=use_batch_norm, bn_track_stats=bn_track_stats,
        distill_bn_stats=distill_bn_stats, use_context_mod=False,
        out_fn=None, verbose=verbose)

    ### Overtake all attributes from the underlying MLP.
    _overtake_mlp_attributes(self._network)

    if self._hyper_shapes_distilled is not None:
        assert self._network._hyper_shapes_distilled is not None
        for s in self._network._hyper_shapes_distilled:
            self._hyper_shapes_distilled.append(s)

    #####################################
    ### Takeover given Initialization ###
    #####################################
    if init_weights is not None:
        assert len(init_weights) == len(self._weights)
        for i, w_init in enumerate(init_weights):
            assert np.all(np.equal(list(w_init.shape),
                                   self._param_shapes[i]))
            self._weights[i].data = w_init

    ######################
    ### Finalize Setup ###
    ######################
    if verbose:
        num_weights = MainNetInterface.shapes_to_num_weights(
            self.param_shapes)
        print('Constructed MLP that processes dimensionality reduced ' +
              'inputs through chunking. The network has a total of %d '
              % num_weights + 'weights.')

    self._is_properly_setup()