def get_hnet_model(config, num_tasks, device, mnet_shapes, cprefix=None):
    """Generate a hypernetwork instance.

    A helper to generate the hypernetwork according to the given user
    configuration.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available according
                to the function :func:`utils.cli_args.hypernet_args`.
        num_tasks (int): The number of task embeddings the hypernetwork should
            have.
        device: PyTorch device.
        mnet_shapes: Dimensions of the weight tensors of the main network.
            See main net argument
            :attr:`mnets.mnet_interface.MainNetInterface.param_shapes`.
        cprefix (str, optional): A prefix of the config names. It might be that
            the config names used in this method are prefixed, since several
            hypernetworks should be generated (e.g., :code:`cprefix='gen_'` or
            ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.hypernet_args`.

    Returns:
        The created hypernet model.
    """
    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    hyper_chunks = misc.str_to_ints(gc('hyper_chunks'))
    assert len(hyper_chunks) in [1, 2, 3]
    if len(hyper_chunks) == 1:
        hyper_chunks = hyper_chunks[0]

    hnet_arch = misc.str_to_ints(gc('hnet_arch'))
    sa_hnet_filters = misc.str_to_ints(gc('sa_hnet_filters'))
    sa_hnet_kernels = misc.str_to_ints(gc('sa_hnet_kernels'))
    sa_hnet_attention_layers = misc.str_to_ints(gc('sa_hnet_attention_layers'))

    hnet_act = misc.str_to_act(gc('hnet_act'))

    if isinstance(hyper_chunks, list): # Chunked self-attention hypernet
        if len(sa_hnet_kernels) == 1:
            sa_hnet_kernels = sa_hnet_kernels[0]
        # Note, the user can specify the kernel size for each dimension and
        # layer separately.
        elif len(sa_hnet_kernels) > 2 and \
                len(sa_hnet_kernels) == gc('sa_hnet_num_layers') * 2:
            tmp = sa_hnet_kernels
            sa_hnet_kernels = []
            for i in range(0, len(tmp), 2):
                sa_hnet_kernels.append([tmp[i], tmp[i+1]])

        if gc('hnet_dropout_rate') != -1:
            warn('SA-Hypernet doesn\'t use dropout. Dropout rate will be ' +
                 'ignored.')
        if gc('hnet_act') != 'relu':
            warn('SA-Hypernet doesn\'t support non-linearities other than ' +
                 'ReLUs yet. Option "%shnet_act" (%s) will be ignored.'
                 % (cprefix, gc('hnet_act')))

        hnet = SAHyperNetwork(mnet_shapes, num_tasks,
            out_size=hyper_chunks,
            num_layers=gc('sa_hnet_num_layers'),
            num_filters=sa_hnet_filters,
            kernel_size=sa_hnet_kernels,
            sa_units=sa_hnet_attention_layers,
            # Note, we don't use an additional hypernet for the remaining
            # weights!
            #rem_layers=hnet_arch,
            te_dim=gc('temb_size'),
            ce_dim=gc('emb_size'),
            no_theta=False,
            # Batchnorm and spectral norm are not yet implemented.
            #use_batch_norm=gc('hnet_batchnorm'),
            #use_spectral_norm=gc('hnet_specnorm'),
            # Dropout would only be used for the additional network, which we
            # don't use.
            #dropout_rate=gc('hnet_dropout_rate'),
            discard_remainder=True,
            noise_dim=gc('hnet_noise_dim'),
            temb_std=gc('temb_std')).to(device)
    elif hyper_chunks != -1: # Chunked fully-connected hypernet
        hnet = ChunkedHyperNetworkHandler(mnet_shapes, num_tasks,
            chunk_dim=hyper_chunks, layers=hnet_arch,
            activation_fn=hnet_act, te_dim=gc('temb_size'),
            ce_dim=gc('emb_size'), dropout_rate=gc('hnet_dropout_rate'),
            noise_dim=gc('hnet_noise_dim'),
            temb_std=gc('temb_std')).to(device)
    else: # Fully-connected hypernet.
        hnet = HyperNetwork(mnet_shapes, num_tasks, layers=hnet_arch,
            te_dim=gc('temb_size'), activation_fn=hnet_act,
            dropout_rate=gc('hnet_dropout_rate'),
            noise_dim=gc('hnet_noise_dim'),
            temb_std=gc('temb_std')).to(device)

    return hnet
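# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original helpers). It shows the
# config attributes that `get_hnet_model` reads via `gc(...)` when the plain
# fully-connected branch is selected. All concrete values, and the assumption
# that `misc.str_to_ints` parses comma-separated strings, are made up for
# demonstration only.
def _example_get_hnet_model():  # pragma: no cover
    """Sketch of how :func:`get_hnet_model` might be called."""
    import argparse

    example_config = argparse.Namespace(
        hyper_chunks='-1',            # -1 -> plain fully-connected hypernet.
        hnet_arch='50,50',            # Hidden layers of the hypernet.
        sa_hnet_filters='8',          # Only used by the SA branch.
        sa_hnet_kernels='5',
        sa_hnet_attention_layers='1',
        hnet_act='relu',
        hnet_dropout_rate=-1,
        temb_size=8,
        emb_size=8,
        hnet_noise_dim=-1,
        temb_std=-1)

    # `mnet_shapes` would normally be `mnet.param_shapes` of an existing main
    # network; a tiny made-up single-layer network is used here.
    mnet_shapes = [[20, 10], [20]]

    hnet = get_hnet_model(example_config, num_tasks=3,
                          device=torch.device('cpu'),
                          mnet_shapes=mnet_shapes)
    return hnet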
def _generate_networks(config, data_handlers, device, create_hnet=True,
                       create_rnet=False, no_replay=False):
    """Create the main-net, hypernetwork and recognition network.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network. And to
            infer the number of tasks.
        device: Torch device.
        create_hnet: Whether a hypernetwork should be constructed. If not, the
            main network will have trainable weights.
        create_rnet: Whether a task-recognition autoencoder should be created.
        no_replay: If the recognition network should be an instance of class
            MainModel rather than of class RecognitionNet (note, for multitask
            learning, no replay network is required).

    Returns:
        mnet: Main network instance.
        hnet: Hypernetwork instance. This return value is None if no
            hypernetwork should be constructed.
        rnet: RecognitionNet instance. This return value is None if no
            recognition network should be constructed.
    """
    num_tasks = len(data_handlers)

    n_x = data_handlers[0].in_shape[0]
    n_y = data_handlers[0].out_shape[0]
    if config.multi_head:
        n_y = n_y * num_tasks

    main_arch = misc.str_to_ints(config.main_arch)
    main_shapes = MainNetwork.weight_shapes(n_in=n_x, n_out=n_y,
                                            hidden_layers=main_arch)
    mnet = MainNetwork(main_shapes,
                       activation_fn=misc.str_to_act(config.main_act),
                       use_bias=True, no_weights=create_hnet).to(device)

    if create_hnet:
        hnet_arch = misc.str_to_ints(config.hnet_arch)
        hnet = HyperNetwork(main_shapes, num_tasks, layers=hnet_arch,
            te_dim=config.emb_size,
            activation_fn=misc.str_to_act(config.hnet_act)).to(device)
        init_params = list(hnet.parameters())
    else:
        hnet = None
        init_params = list(mnet.parameters())

    if create_rnet:
        ae_arch = misc.str_to_ints(config.ae_arch)
        if no_replay:
            rnet_shapes = MainNetwork.weight_shapes(n_in=n_x, n_out=num_tasks,
                hidden_layers=ae_arch, use_bias=True)
            rnet = MainNetwork(rnet_shapes,
                activation_fn=misc.str_to_act(config.ae_act), use_bias=True,
                no_weights=False, dropout_rate=-1,
                out_fn=lambda x: F.softmax(x, dim=1))
        else:
            rnet = RecognitionNet(n_x, num_tasks, dim_z=config.ae_dim_z,
                enc_layers=ae_arch,
                activation_fn=misc.str_to_act(config.ae_act),
                use_bias=True).to(device)
        init_params += list(rnet.parameters())
    else:
        rnet = None

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1: # Bias vector.
            torch.nn.init.constant_(W, 0)
        elif config.normal_init:
            torch.nn.init.normal_(W, mean=0, std=config.std_normal_init)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if create_hnet:
        for temb in hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

        if config.use_hyperfan_init:
            hnet.apply_hyperfan_init(temb_var=config.std_normal_temb**2)

    return mnet, hnet, rnet
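# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original helpers) of the data-handler
# attributes and config fields that `_generate_networks` relies on when called
# with `create_hnet=True` and `create_rnet=False`. The dummy handler and all
# values are made up; real data handlers come from the data loading utilities
# of the repository.
def _example_generate_networks():  # pragma: no cover
    """Sketch of how :func:`_generate_networks` might be called."""
    import argparse
    from collections import namedtuple

    # Stand-in exposing the two attributes accessed above (`in_shape` and
    # `out_shape`), one handler per task.
    DummyHandler = namedtuple('DummyHandler', ['in_shape', 'out_shape'])
    data_handlers = [DummyHandler(in_shape=[10], out_shape=[2])
                     for _ in range(3)]

    config = argparse.Namespace(
        multi_head=False,
        main_arch='10,10', main_act='relu',
        hnet_arch='10,10', hnet_act='relu', emb_size=2,
        normal_init=False, std_normal_init=0.02, std_normal_temb=1.,
        use_hyperfan_init=False,
        # Only read if `create_rnet=True`:
        ae_arch='10,10', ae_act='relu', ae_dim_z=5)

    mnet, hnet, rnet = _generate_networks(config, data_handlers,
                                          torch.device('cpu'),
                                          create_hnet=True, create_rnet=False)
    return mnet, hnet, rnet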
class ChunkedHyperNetworkHandler(nn.Module, CLHyperNetInterface):
    """This class handles an instance of the class
    :class:`toy_example.hyper_model.HyperNetwork` to produce the weights of a
    main network. I.e., it generates one instance of a full hypernetwork (that
    will produce only one chunk rather than all main net weights) and handles
    all the embedding vectors. Additionally, it provides an easy interface to
    generate the weights of the main network.

    Note:
        To implement ``noise_dim`` this class does not make use of the
        underlying full hypernetwork. Instead, it concatenates noise to the
        chunk embeddings before inputting them to the full hypernet (in this
        way, we make sure that we use the same noise (for all chunks) while
        producing one set of main network weights).

    Note:
        If ``no_weights`` is set, then there also won't be internal chunk
        embeddings.

    Attributes:
        chunk_embeddings: List of embedding vectors that encode main network
            location of the weights to be generated.

    Args:
        (....): See constructor arguments of class
            :class:`toy_example.hyper_model.HyperNetwork`.
        chunk_dim (int): The chunk size, i.e., the number of weights produced
            by a single run of the internally maintained instance of a full
            hypernet.
        ce_dim (int): The size of the chunk embeddings.
    """
    def __init__(self, target_shapes, num_tasks, chunk_dim=2586,
                 layers=[50, 100], te_dim=8, activation_fn=torch.nn.ReLU(),
                 use_bias=True, no_weights=False, ce_dim=None,
                 init_weights=None, dropout_rate=-1, noise_dim=-1,
                 temb_std=-1):
        # FIXME find a way using super to handle multiple inheritance.
        #super(ChunkedHyperNetworkHandler, self).__init__()
        nn.Module.__init__(self)
        CLHyperNetInterface.__init__(self)

        assert len(target_shapes) > 0
        assert init_weights is None or no_weights is False
        assert ce_dim is not None
        self._target_shapes = target_shapes
        self._num_tasks = num_tasks
        self._ce_dim = ce_dim
        self._chunk_dim = chunk_dim
        self._layers = layers
        self._use_bias = use_bias
        self._act_fn = activation_fn
        self._init_weights = init_weights
        self._no_weights = no_weights
        self._te_dim = te_dim
        self._noise_dim = noise_dim
        self._temb_std = temb_std
        self._shifts = None # FIXME temporary test.

        # FIXME: weights should incorporate chunk embeddings as they are part
        # of theta.
        if init_weights is not None:
            warn('Argument "init_weights" does not yet allow initialization ' +
                 'of chunk embeddings.')

        ### Generate hypernet with chunk_dim output.
        # Note, we can safely pass "temb_std" to the full hypernetwork, as we
        # process all chunks in one big batch and the hypernet will use the
        # same perturbed task embeddings for that reason (i.e., noise is
        # shared).
        self._hypernet = HyperNetwork([[chunk_dim]], num_tasks, verbose=False,
            layers=layers, te_dim=te_dim, activation_fn=activation_fn,
            use_bias=use_bias, no_weights=no_weights,
            init_weights=init_weights,
            ce_dim=ce_dim + (noise_dim if noise_dim != -1 else 0),
            dropout_rate=dropout_rate, noise_dim=-1, temb_std=temb_std)

        self._num_outputs = MainNetInterface.shapes_to_num_weights( \
            self._target_shapes)

        ### Generate embeddings for all weight chunks.
        self._num_chunks = int(np.ceil(self._num_outputs / chunk_dim))
        if no_weights:
            self._embs = None
        else:
            self._embs = nn.Parameter(data=torch.Tensor(self._num_chunks,
                                                        ce_dim),
                                      requires_grad=True)
            nn.init.normal_(self._embs, mean=0., std=1.)

        # Note, the chunk embeddings are part of theta.
        hdims = self._hypernet.theta_shapes
        ntheta = MainNetInterface.shapes_to_num_weights(hdims) + \
            (self._embs.numel() if not no_weights else 0)

        ntembs = int(np.sum([t.numel() for t in self.get_task_embs()]))
        self._num_weights = ntheta + ntembs

        print('Constructed hypernetwork with %d parameters ' % (ntheta \
              + ntembs) + '(%d network weights + %d task embedding weights).'
              % (ntheta, ntembs))

        print('The hypernetwork has a total of %d outputs.'
              % self._num_outputs)

        self._theta_shapes = [[self._num_chunks, ce_dim]] + \
            self._hypernet.theta_shapes

        self._is_properly_setup()

    @property
    def chunk_embeddings(self):
        """Getter for read-only attribute :attr:`chunk_embeddings`.

        Get the chunk embeddings used to produce a full set of main network
        weights with the underlying (small) hypernetwork.

        Returns:
            A list of all chunk embedding vectors.
        """
        return list(torch.split(self._embs, 1, dim=0))

    # @override from CLHyperNetInterface
    def forward(self, task_id=None, theta=None, dTheta=None, task_emb=None,
                ext_inputs=None, squeeze=True):
        """Implementation of abstract super class method.

        Note:
            This method can't handle external inputs yet!

        The method will iterate through the set of internal chunk embeddings,
        calling the internally maintained (small) full hypernetwork for each,
        in order to generate a full set of main network weights.
        """
        if task_id is None and task_emb is None:
            raise Exception('The hyper network has to get either a task ID ' +
                            'to choose the learned embedding or directly ' +
                            'get an embedding as input (e.g., from a task ' +
                            'recognition model).')

        if not self.has_theta and theta is None:
            raise Exception('Network was generated without internal ' +
                            'weights. Hence, "theta" option may not be None.')

        if ext_inputs is not None:
            # FIXME If this will be implemented, please consider:
            # * batch size will have to be multiplied based on num chunk
            #   embeddings and the number of external inputs -> large batches
            # * noise dim must adhere to correct behavior (different noise per
            #   external input).
            raise NotImplementedError('This hypernetwork implementation ' +
                'does not yet support the passing of external inputs.')

        if theta is None:
            theta = self.theta
        else:
            assert len(theta) == len(self.theta_shapes)
            assert np.all(np.equal(self._embs.shape, list(theta[0].shape)))

        chunk_embs = theta[0]
        hnet_theta = theta[1:]

        if dTheta is not None:
            assert len(dTheta) == len(self.theta_shapes)

            chunk_embs = chunk_embs + dTheta[0]
            hnet_dTheta = dTheta[1:]
        else:
            hnet_dTheta = None

        # Concatenate the same noise to all chunks, such that it can be
        # viewed as if it were an external input.
        if self._noise_dim != -1:
            if self.training:
                eps = torch.randn((1, self._noise_dim))
            else:
                eps = torch.zeros((1, self._noise_dim))
            if self._embs.is_cuda:
                eps = eps.to(self._embs.get_device())

            eps = eps.expand(self._num_chunks, self._noise_dim)
            chunk_embs = torch.cat([chunk_embs, eps], dim=1)

        # Get chunked weights from the hypernet.
        weights = self._hypernet.forward(task_id=task_id, theta=hnet_theta,
            dTheta=hnet_dTheta, task_emb=task_emb, ext_inputs=chunk_embs)
        weights = weights[0].view(1, -1)

        ### Reshape weights dependent on the main network's architecture.
        ind = 0
        ret = []
        for j, s in enumerate(self.target_shapes):
            num = int(np.prod(s))
            W = weights[0][ind:ind+num]
            ind += num
            W = W.view(*s)
            if self._shifts is not None: # FIXME temporary test!
                W += self._shifts[j]
            ret.append(W)
        return ret

    # @override from CLHyperNetInterface
    @property
    def theta(self):
        """Getter for read-only attribute ``theta``.

        ``theta`` are all learnable parameters of the chunked hypernet,
        including the chunk embeddings that need to be learned.
        Not included are the task embeddings, i.e., ``theta`` comprises all
        parameters that should be regularized in order to avoid catastrophic
        forgetting when training the hypernetwork in a Continual Learning
        setting.

        Note:
            Chunk embeddings are prepended to the list of weights ``theta``
            from the internal full hypernetwork.

        Returns:
            A list of tensors or ``None``, if ``no_weights`` was set to
            ``True`` in the constructor of this class.
        """
        return [self._embs] + list(self._hypernet.theta)

    # @override from CLHyperNetInterface
    def get_task_embs(self):
        """Overridden super class method."""
        return self._hypernet.get_task_embs()

    # @override from CLHyperNetInterface
    def get_task_emb(self, task_id):
        """Overridden super class method."""
        return self._hypernet.get_task_emb(task_id)

    # @override from CLHyperNetInterface
    @property
    def has_theta(self):
        """Getter for read-only attribute ``has_theta``."""
        return not self._no_weights

    # @override from CLHyperNetInterface
    @property
    def has_task_embs(self):
        """Getter for read-only attribute ``has_task_embs``."""
        return self._hypernet.has_task_embs

    # @override from CLHyperNetInterface
    @property
    def num_task_embs(self):
        """Getter for read-only attribute ``num_task_embs``."""
        return self._hypernet.num_task_embs

    def apply_chunked_hyperfan_init(self, method='in', use_xavier=False,
                                    temb_var=1., ext_inp_var=1., eps=1e-5,
                                    cemb_normal_init=False):
        r"""Initialize the network using a chunked hyperfan init.

        Inspired by the method
        `Hyperfan Init <https://openreview.net/forum?id=H1lma24tPB>`__, which
        we implemented for the full hypernetwork in method
        :meth:`toy_example.hyper_model.HyperNetwork.apply_hyperfan_init`, we
        heuristically developed a better initialization method for chunked
        hypernetworks.

        Unfortunately, the `Hyperfan Init` method does not apply to this kind
        of hypernetwork, since we reuse the same hypernet output head for the
        whole main network.

        Luckily, we can provide a simple heuristic. Similar to
        `Meyerson & Miikkulainen <https://arxiv.org/abs/1906.00097>`__ we play
        with the variance of the input embeddings to affect the variance of
        the output weights.

        In a chunked hypernetwork, the input for each chunk is identical
        except for the chunk embeddings :math:`\mathbf{c}`. Let
        :math:`\mathbf{e}` denote the remaining inputs to the hypernetwork,
        which are identical for all chunks. Then, assuming the hypernetwork
        was initialized via fan-in init, the variance of the hypernetwork
        output :math:`\mathbf{v}` can be written as follows (see documentation
        of method
        :meth:`toy_example.hyper_model.HyperNetwork.apply_hyperfan_init`):

        .. math::

            \text{Var}(v) = \frac{n_e}{n_e+n_c} \text{Var}(e) + \
                \frac{n_c}{n_e+n_c} \text{Var}(c)

        Hence, we can achieve a desired output variance :math:`\text{Var}(v)`
        by initializing the chunk embeddings :math:`\mathbf{c}` via the
        following variance:

        .. math::

            \text{Var}(c) = \max \Big\{ 0, \
                \frac{1}{n_c} \big[ (n_e+n_c) \text{Var}(v) - \
                n_e \text{Var}(e) \big] \Big\}

        Now, one important question remains. How do we pick a desired output
        variance :math:`\text{Var}(v)` for a chunk?

        Note, a chunk may include weights from several layers. The likelihood
        for this to happen depends on the main net architecture and the chunk
        size (see constructor argument ``chunk_dim``). The smaller the chunk
        size, the less likely it is that a chunk will contain elements from
        multiple main net weight tensors.

        In case each chunk would contain only weights from one main net weight
        tensor, we could simply pick the variance :math:`\text{Var}(v)` that
        would have been chosen by a main net initialization method (such as
        Xavier).

        In case a chunk contains contributions from several main net weight
        tensors, we apply the following heuristic. If a chunk contains
        contributions of a set of main network weight tensors
        :math:`W_1, \dots, W_K` with relative contribution sizes
        :math:`n_1, \dots, n_K` such that :math:`n_1 + \dots + n_K = n_v`,
        where :math:`n_v` denotes the chunk size, and if the corresponding
        main network initialization method would require init variances
        :math:`\text{Var}(w_1), \dots, \text{Var}(w_K)`, then we simply
        request a weighted average as follows:

        .. math::

            \text{Var}(v) = \frac{1}{n_v} \sum_{k=1}^K n_k \text{Var}(w_k)

        What about bias vectors? Usually, the variance analysis applied to
        Xavier or Kaiming init assumes that biases are initialized to zero.
        This is not possible in this setting, as it would require assigning a
        negative variance to :math:`\mathbf{c}`. Instead, we follow the
        default PyTorch initialization (e.g., see method ``reset_parameters``
        in class :class:`torch.nn.Linear`). There, bias vectors are
        initialized uniformly within a range of
        :math:`\pm \frac{1}{\sqrt{f_{\text{in}}}}`, where
        :math:`f_{\text{in}}` refers to the fan-in of the layer. This type of
        initialization corresponds to a variance of
        :math:`\text{Var}(v) = \frac{1}{3 f_{\text{in}}}`.

        Warning:
            Note, in order to compute the fan-in of layers with bias vectors,
            we need access to the corresponding weight tensor in the same
            layer. Since there is no clean way of matching a bias shape to its
            corresponding weight tensor shape, we use the following heuristic,
            which should be correct for most main networks. We assume that the
            shape directly preceding a bias shape in the constructor argument
            ``target_shapes`` is the corresponding weight tensor.

        Note:
            Constructor argument ``noise_dim`` is automatically considered by
            this method.

        Note:
            Hypernet inputs are assumed to be zero mean.

        Warning:
            This method considers all 1D target weight tensors as bias
            vectors.

        Note:
            To avoid that the variances with which chunks are initialized have
            to be clipped (because they are too small or even negative), the
            variance of the remaining hypernet inputs should be properly
            scaled. In general, one should adhere to the following rule

            .. math::

                \text{Var}(e) < \frac{n_e+n_c}{n_e} \text{Var}(v)

            This method will calculate and print the maximum value that should
            be chosen for :math:`\text{Var}(e)` and will print warnings if
            variances have to be clipped.

        Args:
            method (str): The type of initialization that should be applied.
                Possible options are:

                - ``in``: Use `Chunked Hyperfan-in`, i.e., the output
                  variances of the hypernetwork should correspond to fan-in
                  variances.
                - ``out``: Use `Chunked Hyperfan-out`, i.e., the output
                  variances of the hypernetwork should correspond to fan-out
                  variances.
                - ``harmonic``: Use the harmonic mean of the fan-in and
                  fan-out variance as target variance of the hypernetwork
                  output.
            use_xavier (bool): Whether Kaiming (``False``) or Xavier (``True``)
                init should be used.
            temb_var (float): The initial variance of the task embeddings.

                .. note::
                    If ``temb_std`` was set in the constructor, then this
                    method will automatically correct the provided
                    ``temb_var`` as follows: :code:`temb_var += temb_std**2`.
            ext_inp_var (float): The initial variance of the external input.
                Only needs to be specified if external inputs are provided.

                .. note::
                    Not supported yet by this hypernetwork type, but should
                    soon be included as a feature.
            eps (float): The minimum variance with which a chunk embedding is
                initialized.
            cemb_normal_init (bool): Use normal init for chunk embeddings
                rather than uniform init.
        """
        if method not in ['in', 'out', 'harmonic']:
            raise ValueError('Invalid value for argument "method".')
        if not self.has_theta:
            raise ValueError('Hypernet without internal weights can\'t be ' +
                             'initialized.')

        ### Compute input variance ###
        # The input variance does not include the variance of chunk
        # embeddings! Instead, it is the variance of the inputs that are
        # shared across all chunks.
        if self._temb_std != -1:
            # Sum of uncorrelated variables.
            temb_var += self._temb_std**2

        assert self._noise_dim == -1 or self._noise_dim > 0

        # TODO external inputs are not yet considered.
        inp_dim = self._te_dim + \
            (self._noise_dim if self._noise_dim != -1 else 0)
            #+ (self._size_ext_input if self._size_ext_input is not None \
            #   else 0)

        inp_var = (self._te_dim / inp_dim) * temb_var
        #if self._size_ext_input is not None:
        #    inp_var += (self._size_ext_input / inp_dim) * ext_inp_var
        if self._noise_dim != -1:
            inp_var += (self._noise_dim / inp_dim) * 1.

        c_dim = self._ce_dim

        ### Compute target variance of each output tensor ###
        target_vars = []

        for i, s in enumerate(self.target_shapes):
            # FIXME 1D shape is not necessarily a bias vector.
            if len(s) == 1: # Assume it's a bias vector.
                # Assume that the preceding shape is the corresponding weight
                # tensor.
                if i > 0 and len(self.target_shapes[i-1]) > 1:
                    fan_in, _ = iutils.calc_fan_in_and_out( \
                        self.target_shapes[i-1])
                else:
                    # FIXME Quick-fix, use fan-out instead.
                    fan_in = s[0]

                target_vars.append(1. / (3. * fan_in))
            else:
                fan_in, fan_out = iutils.calc_fan_in_and_out(s)

                c_relu = 1 if use_xavier else 2

                var_in = c_relu / fan_in
                var_out = c_relu / fan_out

                if method == 'in':
                    var = var_in
                elif method == 'out':
                    var = var_out
                else:
                    # Harmonic mean of fan-in and fan-out variance.
                    var = 2. / (1. / var_in + 1. / var_out)

                target_vars.append(var)

        ### Target variance per chunk ###
        chunk_vars = []
        i = 0
        n = np.prod(self.target_shapes[i])

        for j in range(self._num_chunks):
            m = self._chunk_dim
            var = 0

            while m > 0:
                # Special treatment to fill up last chunk.
                if j == self._num_chunks - 1 and i == len(target_vars) - 1:
                    assert n <= m
                    o = m
                else:
                    o = min(m, n)

                var += o / self._chunk_dim * target_vars[i]
                m -= o
                n -= o

                if n == 0:
                    i += 1
                    if i < len(target_vars):
                        n = np.prod(self.target_shapes[i])

            chunk_vars.append(var)

        max_inp_var = (inp_dim + c_dim) / inp_dim * min(chunk_vars)
        max_inp_std = math.sqrt(max_inp_var)
        print('Initializing hypernet with Chunked Hyperfan Init ...')
        if inp_var >= max_inp_var:
            warn('Note, hypernetwork inputs should have an initial total ' +
                 'variance (std) smaller than %f (%f) in order for this ' \
                 % (max_inp_var, max_inp_std) + 'method to work properly.')

        ### Compute variances of chunk embeddings ###
        # We could have done that in the previous loop. But I think the code
        # is more readable this way.
        c_vars = []
        n_clipped = 0
        for i, var in enumerate(chunk_vars):
            c_var = 1. / c_dim * ((inp_dim + c_dim) * var - inp_dim * inp_var)
            if c_var < eps:
                n_clipped += 1
                #warn('Initial variance of chunk embedding %d has to ' % i + \
                #     'be clipped.')

            c_vars.append(max(eps, c_var))

        if n_clipped > 0:
            warn('Initial variance of %d/%d ' % (n_clipped, len(chunk_vars)) +
                 'chunk embeddings had to be clipped.')

        ### Initialize chunk embeddings ###
        for i, c_emb in enumerate(self.chunk_embeddings):
            c_std = math.sqrt(c_vars[i])
            if cemb_normal_init:
                torch.nn.init.normal_(c_emb, mean=0, std=c_std)
            else:
                a = math.sqrt(3.0) * c_std
                torch.nn.init._no_grad_uniform_(c_emb, -a, a)

        ### Initialize hypernet with fan-in init ###
        for i, w in enumerate(self._hypernet.theta):
            if w.ndim == 1: # Bias vector.
                assert i % 2 == 1
                torch.nn.init.constant_(w, 0)
            else:
                if use_xavier:
                    iutils.xavier_fan_in_(w)
                else:
                    torch.nn.init.kaiming_uniform_(w, mode='fan_in',
                                                   nonlinearity='relu')
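# ---------------------------------------------------------------------------
# Illustrative numeric sketch (not part of the original class). It evaluates
# the chunk-embedding variance formula from
# `ChunkedHyperNetworkHandler.apply_chunked_hyperfan_init`,
#     Var(c) = max{0, [(n_e + n_c) Var(v) - n_e Var(e)] / n_c},
# for made-up numbers, to show how the chunk-embedding variance shrinks and is
# eventually clipped once the shared-input variance Var(e) exceeds the bound
# (n_e + n_c) / n_e * Var(v).
def _example_chunked_hyperfan_variance():  # pragma: no cover
    """Sketch of the chunked hyperfan variance computation."""
    n_e = 8            # Size of the shared hypernet input (e.g., task emb.).
    n_c = 8            # Chunk embedding size (``ce_dim``).
    var_v = 2. / 100.  # Desired output variance (fan-in init, fan_in = 100).

    max_var_e = (n_e + n_c) / n_e * var_v
    print('Var(e) must stay below %.4f.' % max_var_e)

    for var_e in [0.0, 0.01, 0.05, 0.1]:
        var_c = max(0., ((n_e + n_c) * var_v - n_e * var_e) / n_c)
        print('Var(e)=%.3f -> Var(c)=%.4f%s'
              % (var_e, var_c, ' (clipped)' if var_c == 0. else ''))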
class SAHyperNetwork(nn.Module, CLHyperNetInterface):
    """This class manages an instance of class :class:`SAHnetPart` and most
    likely an instance of class :class:`toy_example.hyper_model.HyperNetwork`.

    Given a certain output shape, the network will use a transpose
    convolutional hypernetwork with self-attention layers (instance of class
    :class:`SAHnetPart`) to generate as many weights as possible by running
    the network multiple times with different (learned) embeddings as inputs.
    The remaining weights will be generated using an instance of class
    :class:`toy_example.hyper_model.HyperNetwork` (only necessary if the
    number of main network weights is not divisible by the number of
    :class:`SAHnetPart` outputs).

    Hence, the constructor creates an instance of the class
    :class:`SAHnetPart` and, if needed, an instance of the class
    :class:`toy_example.hyper_model.HyperNetwork`. Additionally, it will
    create all embedding vectors (including task embeddings).

    Here are some suggested configurations that have a relatively small
    number of remaining weights (thus, the bulk of weights is generated by the
    SA hypernet).

    **resnet32**:

    - out_size = [36, 36], remaining: 21 weights
    - out_size = [50, 50], remaining: 193 weights
    - out_size = [77, 77], remaining: 231 weights
    - out_size = [77, 77, 3], remaining: 231 weights
    - out_size = [90, 97], remaining: 3 weights
    - out_size = [51, 54], remaining: 21 weights
    - out_size = [51, 54, 3], remaining: 21 weights
    - out_size = [11, 21], remaining: 0 weights

    Attributes:
        chunk_embeddings: List of embedding vectors that encode main network
            location of the weights to be generated.

    Args:
        main_dims: A list of lists, each entry denoting the size of a weight
            or bias tensor in the main network. Note, the output of the
            :meth:`forward` method will be a list of tensors, each having the
            shape of the corresponding list of integers provided as entry via
            this argument.
            See attribute
            :attr:`mnets.mnet_interface.MainNetInterface.param_shapes` for
            more information.
        num_tasks: Number of task embeddings to be generated.
        out_size: See constructor of class :class:`SAHnetPart`.
        num_layers: See constructor of class :class:`SAHnetPart`.
        num_filters: See constructor of class :class:`SAHnetPart`.
        kernel_size: See constructor of class :class:`SAHnetPart`.
        sa_units: See constructor of class :class:`SAHnetPart`.
        rem_layers: A list of integers, each indicating the size of a hidden
            layer in the network
            :class:`toy_example.hyper_model.HyperNetwork` that handles the
            remaining weights.
        te_dim: The dimensionality of task embeddings.
        ce_dim: The dimensionality of the chunk embeddings (that should notify
            the hypernets which weights of the main network they have to
            generate).

            .. note::
                The fully-connected hypernet for the remaining weights
                receives no such embedding.
        no_theta: If set to ``True``, no trainable parameters ``theta`` will
            be constructed, i.e., weights are assumed to be produced ad-hoc by
            a hypernetwork and passed to the forward function. Does not affect
            task embeddings.
        init_theta (optional): This option is for convenience reasons. The
            option expects a list of parameter values that are used to
            initialize the network weights. As such, it provides a convenient
            way of initializing a network with, for instance, a weight draw
            produced by the hypernetwork.
            The given data has to be in the same shape as the attribute
            ``theta`` if the network would be constructed with ``theta``. Does
            not affect task embeddings.
        use_batch_norm: Enable batchnorm in all subnetworks.
        use_spectral_norm: Enable spectral normalization in all subnetworks.
        dropout_rate: See constructor of class
            :class:`toy_example.hyper_model.HyperNetwork`. Does only apply to
            this network type.
        discard_remainder: Instead of generating a separate
            :class:`toy_example.hyper_model.HyperNetwork` for the remaining
            weights, these will be generated by another run of the internal
            :class:`SAHnetPart` network, discarding those outputs that are not
            needed.
        noise_dim: If ``-1``, no noise will be applied. Otherwise the
            hypernetwork will receive as additional input zero-mean Gaussian
            noise with unit variance during training (zeroes will be inputted
            during eval-mode). The same noise vector is concatenated to all
            chunk embeddings when generating one set of weights.
        temb_std (optional): If not ``-1``, the task embeddings will be
            perturbed by zero-mean Gaussian noise with the given std (additive
            noise). The perturbation is only applied if the network is in
            training mode. Note, per batch of external inputs, the
            perturbation of the task embedding will be shared.
    """
    def __init__(self, main_dims, num_tasks, out_size=[64, 64], num_layers=5,
                 num_filters=None, kernel_size=5, sa_units=[1, 3],
                 rem_layers=[50, 50, 50], te_dim=8, ce_dim=8, no_theta=False,
                 init_theta=None, use_batch_norm=False,
                 use_spectral_norm=False, dropout_rate=-1,
                 discard_remainder=False, noise_dim=-1, temb_std=-1):
        # FIXME find a way using super to handle multiple inheritance.
        #super(SAHyperNetwork, self).__init__()
        nn.Module.__init__(self)
        CLHyperNetInterface.__init__(self)

        if init_theta is not None:
            # FIXME I would need to know the number of parameter tensors in
            # each hypernet before creating them to split the list init_theta.
            raise NotImplementedError('Argument "init_theta" not ' +
                                      'implemented yet!')

        assert init_theta is None or no_theta is False
        self._no_theta = no_theta
        self._te_dim = te_dim
        self._discard_remainder = discard_remainder

        self._target_shapes = main_dims
        self._num_outputs = MainNetInterface.shapes_to_num_weights(main_dims)
        print('Building a self-attention hypernet for a network with %d ' %
              self._num_outputs + 'weights.')

        assert len(out_size) in [2, 3]
        self._out_size = out_size
        num_outs = np.prod(out_size)
        assert num_outs <= self._num_outputs

        self._noise_dim = noise_dim
        self._temb_std = temb_std

        num_embs = self._num_outputs // num_outs
        rem_weights = self._num_outputs % num_outs

        if rem_weights > 0 and not discard_remainder:
            print('%d remaining weights (%.2f%%) are generated by a fully-' \
                  % (rem_weights, 100.0 * rem_weights / self._num_outputs) + \
                  'connected hypernetwork.')
        elif rem_weights > 0:
            num_embs += 1
            print('%d weights generated by the last chunk of the self-'
                  % (num_outs - rem_weights) + 'attention hypernet will be ' +
                  'discarded.')

        self._num_embs = num_embs

        ### Generate hypernet.
        self._hypernet = SAHnetPart(out_size=out_size, num_layers=num_layers,
            num_filters=num_filters, kernel_size=kernel_size,
            sa_units=sa_units,
            input_dim=te_dim + ce_dim + (noise_dim if noise_dim != -1 else 0),
            use_batch_norm=use_batch_norm,
            use_spectral_norm=use_spectral_norm,
            no_theta=no_theta, init_theta=None)

        self._rem_hypernet = None
        self._remainder = rem_weights
        if rem_weights > 0 and not discard_remainder:
            print('A second hypernet for the remainder of the weights has ' +
                  'to be created, as %d is not divisible by %d '
                  % (self._num_outputs, num_outs) +
                  '(remainder %d).' % rem_weights)
            self._rem_hypernet = HyperNetwork([[rem_weights]], None,
                layers=rem_layers, te_dim=te_dim, no_te_embs=True,
                no_weights=no_theta,
                ce_dim=(noise_dim if noise_dim != -1 else None),
                dropout_rate=dropout_rate, use_batch_norm=use_batch_norm,
                use_spectral_norm=use_spectral_norm, noise_dim=-1,
                temb_std=None)

        ### Generate embeddings for all weight chunks.
        if no_theta:
            self._embs = None
        else:
            self._embs = nn.Parameter(data=torch.Tensor(num_embs, ce_dim),
                                      requires_grad=True)
            torch.nn.init.normal_(self._embs, mean=0., std=1.)

        # There is no need for a chunk embedding, as the remainder network
        # always produces the same chunk.
        #if self._remainder > 0 and not discard_remainder:
        #    self._rem_emb = nn.Parameter(data=torch.Tensor(1, ce_dim),
        #                                 requires_grad=True)
        #    torch.nn.init.normal_(self._rem_emb, mean=0., std=1.)

        ### Generate task embeddings.
        self._task_embs = nn.ParameterList()
        # We store individual task embeddings as it makes it easier to pass
        # only subsets of task embeddings to an optimizer.
        for _ in range(num_tasks):
            self._task_embs.append(nn.Parameter(data=torch.Tensor(te_dim),
                                                requires_grad=True))
            torch.nn.init.normal_(self._task_embs[-1], mean=0., std=1.)

        self._num_weights = 0
        for p in list(self.parameters()):
            self._num_weights += np.prod(p.shape)
        print('Total number of parameters in the hypernetwork: %d'
              % self._num_weights)

        self._theta_shapes = [[num_embs, ce_dim]] + \
            self._hypernet.theta_shapes
        if self._rem_hypernet is not None:
            self._theta_shapes += self._rem_hypernet.theta_shapes

    # @override from CLHyperNetInterface
    def forward(self, task_id=None, theta=None, dTheta=None, task_emb=None,
                ext_inputs=None, squeeze=True):
        """Implementation of abstract super class method.

        Note, this method can't handle external inputs yet!

        The method will iterate through the set of internal chunk embeddings,
        calling the internally maintained transpose conv. hypernetwork
        (potentially with self-attention layers). If necessary, a small
        portion of the chunks will be created by an additional
        fully-connected network.
        """
        if task_id is None and task_emb is None:
            raise Exception('The hyper network has to get either a task ID ' +
                            'to choose the learned embedding or directly ' +
                            'get an embedding as input (e.g., from a task ' +
                            'recognition model).')

        if not self.has_theta and theta is None:
            raise Exception('Network was generated without internal ' +
                            'weights. Hence, "theta" option may not be None.')

        if ext_inputs is not None:
            # FIXME If this will be implemented, please consider:
            # * batch size will have to be multiplied based on num chunk
            #   embeddings and the number of external inputs -> large batches
            # * noise dim must adhere to correct behavior (different noise per
            #   external input).
            raise NotImplementedError('This hypernetwork implementation ' +
                'does not yet support the passing of external inputs.')

        if theta is None:
            theta = self.theta
        else:
            assert len(theta) == len(self.theta_shapes)
            assert np.all(np.equal(self._embs.shape, list(theta[0].shape)))

        nhnet_shapes = len(self._hypernet.theta_shapes)

        chunk_embs = theta[0]
        hnet_theta = theta[1:1+nhnet_shapes]
        if self._rem_hypernet is not None:
            rem_hnet_theta = theta[1+nhnet_shapes:]

        if dTheta is not None:
            assert len(dTheta) == len(self.theta_shapes)

            chunk_embs = chunk_embs + dTheta[0]
            hnet_dTheta = dTheta[1:1+nhnet_shapes]
            if self._rem_hypernet is not None:
                rem_hnet_dTheta = dTheta[1+nhnet_shapes:]
        else:
            hnet_dTheta = None
            rem_hnet_dTheta = None

        # Currently, there is no option in the constructor to not generate
        # task embeddings, that is why the code below is commented out.
        # Select task embeddings.
        #if not self.has_task_embs and task_emb is None:
        #    raise Exception('The network was created with no internal task ' +
        #                    'embeddings, thus parameter "task_emb" has to ' +
        #                    'be specified.')

        if task_emb is None:
            task_emb = self._task_embs[task_id]
        if self.training and self._temb_std != -1:
            # Note, the result of the perturbation has to be assigned, as
            # "add" is not an in-place operation.
            task_emb = task_emb + torch.randn_like(task_emb) * self._temb_std

        # Concatenate the same noise to all chunks, such that it can be
        # viewed as if it were an external input.
        if self._noise_dim != -1:
            if self.training:
                eps = torch.randn((1, self._noise_dim))
            else:
                eps = torch.zeros((1, self._noise_dim))
            if self._embs.is_cuda:
                eps = eps.to(self._embs.get_device())

            # The hypernet input is a concatenation of the task embedding with
            # the noise vector and each chunk embedding.
            hnet_input = torch.cat([task_emb.view(1, -1), eps], dim=1)
            hnet_input = hnet_input.expand(self._num_embs,
                                           self._te_dim + self._noise_dim)
            hnet_input = torch.cat([chunk_embs, hnet_input], dim=1)
        else:
            eps = None
            # The hypernet input is a concatenation of the task embedding with
            # each chunk embedding.
            hnet_input = task_emb.view(1, -1).expand(self._num_embs,
                                                     self._te_dim)
            hnet_input = torch.cat([chunk_embs, hnet_input], dim=1)

        ### Gather all generated weights.
        weights = self._hypernet.forward(task_id=None, theta=hnet_theta,
            dTheta=hnet_dTheta, task_emb=None, ext_inputs=hnet_input)
        weights = weights.view(1, -1)

        if self._rem_hypernet is not None:
            rem_weights = self._rem_hypernet.forward(theta=rem_hnet_theta,
                dTheta=rem_hnet_dTheta, task_emb=task_emb, ext_inputs=eps)
            weights = torch.cat([weights, rem_weights[0].view(1, -1)], dim=1)

        ### Reshape weights.
        ind = 0
        ret = []
        for s in self.target_shapes:
            num = int(np.prod(s))
            W = weights[0][ind:ind+num]
            ind += num
            ret.append(W.view(*s))
        return ret

    @property
    def chunk_embeddings(self):
        """Getter for read-only attribute :attr:`chunk_embeddings`.

        Returns:
            A list of all chunk embedding vectors.
        """
        # Note, the remainder network has no chunk embedding.
        return list(torch.split(self._embs, 1, dim=0))

    # @override from CLHyperNetInterface
    @property
    def theta(self):
        """Getter for read-only attribute ``theta``.

        ``theta`` are all learnable parameters of the chunked hypernet,
        including the chunk embeddings that need to be learned.
        Not included are the task embeddings.

        .. note::
            Chunk embeddings are prepended to the list of weights ``theta``
            from the internal SA hypernetwork (if existing, ``theta`` from the
            remainder network will be appended).

        Returns:
            A list of tensors or ``None``, if ``no_theta`` was set to ``True``
            in the constructor of this class.
        """
        theta = [self._embs] + list(self._hypernet.theta)
        if self._rem_hypernet is not None:
            theta += list(self._rem_hypernet.theta)

        return theta

    # @override from CLHyperNetInterface
    @property
    def has_theta(self):
        """Getter for read-only attribute ``has_theta``."""
        return not self._no_theta