    def apply_chunked_hyperfan_init(self, method='in', use_xavier=False,
                                    temb_var=1., ext_inp_var=1., eps=1e-5,
                                    cemb_normal_init=False):
        r"""Initialize the network using a chunked hyperfan init.

        Inspired by the method
        `Hyperfan Init <https://openreview.net/forum?id=H1lma24tPB>`__ which we
        implemented for the full hypernetwork in method
        :meth:`toy_example.hyper_model.HyperNetwork.apply_hyperfan_init`, we
        heuristically developed a better initialization method for chunked
        hypernetworks.

        Unfortunately, the `Hyperfan Init` method does not apply to this kind of
        hypernetwork, since we reuse the same hypernet output head for the whole
        main network.

        Luckily, we can provide a simple heuristic. Similar to
        `Meyerson & Miikkulainen <https://arxiv.org/abs/1906.00097>`__ we play
        with the variance of the input embeddings to affect the variance of the
        output weights.

        In a chunked hypernetwork, the input for each chunk is identical except
        for the chunk embeddings :math:`\mathbf{c}`. Let :math:`\mathbf{e}`
        denote the remaining inputs to the hypernetwork, which are identical
        for all chunks. Then, assuming the hypernetwork was initialized via
        fan-in init, the variance of the hypernetwork output :math:`\mathbf{v}`
        can be written as follows (see documentation of method
        :meth:`toy_example.hyper_model.HyperNetwork.apply_hyperfan_init`):

        .. math::

            \text{Var}(v) = \frac{n_e}{n_e+n_c} \text{Var}(e) + \
                \frac{n_c}{n_e+n_c} \text{Var}(c)

        Hence, we can achieve a desired output variance :math:`\text{Var}(v)`
        by initializing the chunk embeddings :math:`\mathbf{c}` via the
        following variance:

        .. math::

            \text{Var}(c) = \max \Big\{ 0, \
                \frac{1}{n_c} \big[ (n_e+n_c) \text{Var}(v) - \
                n_e \text{Var}(e) \big] \Big\}

        Now, one important question remains. How do we pick a desired output
        variance :math:`\text{Var}(v)` for a chunk?

        Note, a chunk may include weights from several layers.
        The likelihood for this to happen depends on the main net architecture
        and the chunk size (see constructor argument ``chunk_dim``). The smaller
        the chunk size, the less likely it is that a chunk will contain elements
        from multiple main net weight tensors.

        In case each chunk would contain only weights from one main net weight
        tensor, we could simply pick the variance :math:`\text{Var}(v)` that
        would have been chosen by a main net initialization method (such as
        Xavier).

        In case a chunk contains contributions from several main net weight
        tensors, we apply the following heuristic. If a chunk contains
        contributions of a set of main network weight tensors
        :math:`W_1, \dots, W_K` with relative contribution sizes
        :math:`n_1, \dots, n_K` such that :math:`n_1 + \dots + n_K = n_v` where
        :math:`n_v` denotes the chunk size and if the corresponding main network
        initialization method would require init variances
        :math:`\text{Var}(w_1), \dots, \text{Var}(w_K)`, then we simply request
        a weighted average as follows:

        .. math::

            \text{Var}(v) = \frac{1}{n_v} \sum_{k=1}^K n_k \text{Var}(w_k)

        What about bias vectors? Usually, the variance analysis applied to
        Xavier or Kaiming init assumes that biases are initialized to zero. This
        is not possible in this setting, as it would require assigning a
        negative variance to :math:`\mathbf{c}`. Instead, we follow the default
        PyTorch initialization (e.g., see method ``reset_parameters`` in class
        :class:`torch.nn.Linear`). There, bias vectors are initialized uniformly
        within a range of :math:`\pm \frac{1}{\sqrt{f_{\text{in}}}}` where
        :math:`f_{\text{in}}` refers to the fan-in of the layer. This type of
        initialization corresponds to a variance of
        :math:`\text{Var}(v) = \frac{1}{3 f_{\text{in}}}`.

        Warning:
            Note, in order to compute the fan-in of layers with bias vectors, we
            need access to the corresponding weight tensor in the same layer.
            Since there is no clean way of matching a bias shape to its
            corresponding weight tensor shape, we use the following heuristic,
            which should be correct for most main networks. We assume that the
            shape directly preceding a bias shape in the constructor argument
            ``target_shapes`` is the corresponding weight tensor.

        Note:
            Constructor argument ``noise_dim`` is automatically considered by
            this method.

        Note:
            The hypernet inputs should be zero-mean.

        Warning:
            This method considers all 1D target weight tensors as bias vectors.

        Note:
            To avoid that the variances with which chunks are initialized have
            to be clipped (because they are too small or even negative), the
            variance of the remaining hypernet inputs should be properly scaled.
            In general, one should adhere to the following rule for every chunk

            .. math::

                \text{Var}(e) < \frac{n_e+n_c}{n_e} \text{Var}(v)

            This method computes the maximum value that should be chosen for
            :math:`\text{Var}(e)` (based on the smallest chunk variance
            :math:`\text{Var}(v)`) and issues warnings if the actual input
            variance exceeds this value or if variances have to be clipped.

        Args:
            method (str): The type of initialization that should be applied.
                Possible options are:

                - ``in``: Use `Chunked Hyperfan-in`, i.e., the output variances
                  of the hypernetwork should correspond to fan-in variances.
                - ``out``: Use `Chunked Hyperfan-out`, i.e., the output
                  variances of the hypernetwork should correspond to fan-out
                  variances.
                - ``harmonic``: Use the harmonic mean of the fan-in and fan-out
                  variance as target variance of the hypernetwork output.
            use_xavier (bool): Whether Kaiming (``False``) or Xavier (``True``)
                init should be used.
            temb_var (float): The initial variance of the task embeddings.

                .. note::
                    If ``temb_std`` was set in the constructor, then this method
                    will automatically correct the provided ``temb_var`` as
                    follows: :code:`temb_var += temb_std**2`.
            ext_inp_var (float): The initial variance of the external input.
                Only needs to be specified if external inputs are provided.

                .. note::
                    Not supported yet by this hypernetwork type, but should soon
                    be included as a feature.
            eps (float): The minimum variance with which a chunk embedding is
                initialized.
            cemb_normal_init (bool): Use normal init for chunk embeddings rather
                than uniform init.
        """
        if method not in ['in', 'out', 'harmonic']:
            raise ValueError('Invalid value for argument "method".')
        if not self.has_theta:
            raise ValueError('Hypernet without internal weights can\'t be ' +
                             'initialized.')

        ### Compute input variance ###
        # The input variance does not include the variance of chunk embeddings!
        # Instead, it is the variance of the inputs that are shared across all
        # chunks.
        if self._temb_std != -1:
            # Sum of uncorrelated variables.
            temb_var += self._temb_std**2

        assert self._noise_dim == -1 or self._noise_dim > 0

        # TODO external inputs are not yet considered.
        inp_dim = self._te_dim + \
            (self._noise_dim if self._noise_dim != -1 else 0)
            #(self._size_ext_input if self._size_ext_input is not None else 0) \

        inp_var = (self._te_dim / inp_dim) * temb_var
        #if self._size_ext_input is not None:
        #    inp_var += (self._size_ext_input / inp_dim) * ext_inp_var
        if self._noise_dim != -1:
            inp_var += (self._noise_dim / inp_dim) * 1.

        c_dim = self._ce_dim

        ### Compute target variance of each output tensor ###
        target_vars = []

        for i, s in enumerate(self.target_shapes):
            # FIXME 1D shape is not necessarily bias vector.
            if len(s) == 1: # Assume it's a bias vector
                # Assume that last shape has been the corresponding weight
                # tensor.
                if i > 0 and len(self.target_shapes[i - 1]) > 1:
                    fan_in, _ = iutils.calc_fan_in_and_out( \
                        self.target_shapes[i-1])
                else:
                    # FIXME Quick-fix, use fan-out instead.
                    fan_in = s[0]

                target_vars.append(1. / (3. * fan_in))

            else:
                fan_in, fan_out = iutils.calc_fan_in_and_out(s)

                c_relu = 1 if use_xavier else 2

                var_in = c_relu / fan_in
                var_out = c_relu / fan_out

                if method == 'in':
                    var = var_in
                elif method == 'out':
                    var = var_out
                else:
                    # Harmonic mean of the fan-in and fan-out variance.
                    var = 2. * var_in * var_out / (var_in + var_out)

                target_vars.append(var)

        ### Target variance per chunk ###
        chunk_vars = []
        i = 0
        n = np.prod(self.target_shapes[i])

        for j in range(self._num_chunks):
            m = self._chunk_dim
            var = 0

            while m > 0:
                # Special treatment to fill up last chunk.
                if j == self._num_chunks - 1 and i == len(target_vars) - 1:
                    assert n <= m
                    o = m
                else:
                    o = min(m, n)

                var += o / self._chunk_dim * target_vars[i]
                m -= o
                n -= o

                if n == 0:
                    i += 1
                    if i < len(target_vars):
                        n = np.prod(self.target_shapes[i])

            chunk_vars.append(var)

        max_inp_var = (inp_dim + c_dim) / inp_dim * min(chunk_vars)
        max_inp_std = math.sqrt(max_inp_var)
        print('Initializing hypernet with Chunked Hyperfan Init ...')
        if inp_var >= max_inp_var:
            warn('Note, hypernetwork inputs should have an initial total ' +
                 'variance (std) smaller than %f (%f) in order for this ' \
                 % (max_inp_var, max_inp_std) + 'method to work properly.')

        ### Compute variances of chunk embeddings ###
        # We could have done that in the previous loop. But I think the code is
        # more readable this way.
        c_vars = []
        n_clipped = 0
        for i, var in enumerate(chunk_vars):
            c_var = 1. / c_dim * ((inp_dim + c_dim) * var - inp_dim * inp_var)
            if c_var < eps:
                n_clipped += 1
                #warn('Initial variance of chunk embedding %d has to ' % i + \
                #     'be clipped.')

            c_vars.append(max(eps, c_var))

        if n_clipped > 0:
            warn('Initial variance of %d/%d ' % (n_clipped, len(chunk_vars)) + \
                 'chunk embeddings had to be clipped.')

        ### Initialize chunk embeddings ###
        for i, c_emb in enumerate(self.chunk_embeddings):
            c_std = math.sqrt(c_vars[i])
            if cemb_normal_init:
                torch.nn.init.normal_(c_emb, mean=0, std=c_std)
            else:
                a = math.sqrt(3.0) * c_std
                torch.nn.init._no_grad_uniform_(c_emb, -a, a)

        ### Initialize hypernet with fan-in init ###
        for i, w in enumerate(self._hypernet.theta):
            if w.ndim == 1: # bias
                assert i % 2 == 1
                torch.nn.init.constant_(w, 0)

            else:
                if use_xavier:
                    iutils.xavier_fan_in_(w)
                else:
                    torch.nn.init.kaiming_uniform_(w, mode='fan_in',
                                                   nonlinearity='relu')

    def apply_hyperfan_init(self, method='in', use_xavier=False,
                            uncond_var=1., cond_var=1., mnet=None,
                            w_val=None, w_var=None, b_val=None, b_var=None):
        r"""Initialize the network using `hyperfan init`.

        Hyperfan initialization was developed in the following paper for this
        kind of hypernetwork

            "Principled Weight Initialization for Hypernetworks"
            https://openreview.net/forum?id=H1lma24tPB

        The initialization is based on the following idea: When the main network
        would be initialized using Xavier or Kaiming init, then the variance of
        activations (fan-in) or gradients (fan-out) would be preserved by using
        a proper variance for the initial weight distribution (assuming certain
        assumptions hold at initialization, which are different for Xavier and
        Kaiming).

        When using these kinds of initializations in the hypernetwork, then the
        variance of the initial main net weight distribution would simply equal
        the variance of the input embeddings (which can lead to exploding
        activations, e.g., for fan-in inits).

        The above mentioned paper proposes a quick fix for the type of hypernets
        that resemble the simple MLP hnet implemented in this class, i.e., which
        have a separate output head per weight tensor in the main network.

        Assuming that input embeddings are initialized with a certain variance
        (e.g., 1) and we use Xavier or Kaiming init for the hypernet, then the
        variance of the last hidden activation will also be 1.

        Then, we can modify the variance of the weights of each output head in
        the hypernet to obtain the same variance per main net weight tensor that
        we would typically obtain when applying Xavier or Kaiming to the main
        network directly.

        Note:
            If ``mnet`` is not provided or the corresponding attribute
            :attr:`mnets.mnet_interface.MainNetInterface.param_shapes_meta` is
            not implemented, then this method assumes that 1D target tensors
            (cf. constructor argument ``target_shapes``) represent bias vectors
            in the main network.
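
        As an illustration, the per-head variance rule can be sketched as a
        small stand-alone function. This is a simplified sketch and not part of
        this class: ``hyperfan_head_var`` is a hypothetical name, and it only
        covers heads producing a main net weight matrix of shape
        ``[m_fan_out, m_fan_in]``, given the size ``d_k`` of the last hidden
        layer and the hypernet input variance ``input_var``. It is meant to
        mirror the formulas used for weight tensors in the implementation below.

        .. code-block:: python

            import math

            def hyperfan_head_var(m_fan_in, m_fan_out, d_k, input_var=1.,
                                  use_xavier=False, mnet_has_bias=True,
                                  method='in'):
                # Variance of the output head weights producing one main net
                # weight matrix (hypothetical helper, for illustration only).
                c_relu = 1 if use_xavier else 2  # Kaiming (ReLU) gain of 2.
                # Layers with a bias share the variance budget with it.
                c_bias = 2 if mnet_has_bias else 1
                var_in = c_relu / (c_bias * m_fan_in * d_k * input_var)
                var_out = c_relu / (m_fan_out * d_k * input_var)
                if method == 'in':
                    return var_in
                if method == 'out':
                    return var_out
                # Harmonic mean of hyperfan-in and hyperfan-out.
                return 2. * var_in * var_out / (var_in + var_out)

            # Head producing a 100x50 weight from a hidden layer of size 200.
            var = hyperfan_head_var(m_fan_in=50, m_fan_out=100, d_k=200)
            # A uniform init with this variance uses the bound sqrt(3 * var).
            bound = math.sqrt(3. * var)

        Under the assumptions stated above, multiplying the head variance by
        ``d_k * input_var`` recovers the main net variance
        ``c_relu / (c_bias * m_fan_in)`` that Hyperfan-in targets.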

        Note:
            To compute the hyperfan-out initialization of bias vectors, we need
            access to the fan-in of the layer, which we can only compute based
            on the corresponding weight tensor in the same layer. This is only
            possible if ``mnet`` is provided. Otherwise, the following heuristic
            is applied. We assume that the shape directly preceding a bias shape
            in the constructor argument ``target_shapes`` is the corresponding
            weight tensor.

        Note:
            All hypernet inputs are assumed to be zero-mean random variables.

        **Variance of the hypernet input**

        In general, the input to the hypernetwork can be a concatenation of
        multiple embeddings (see description of arguments ``uncond_var`` and
        ``cond_var``).

        Let's denote the complete hypernetwork input by
        :math:`\mathbf{x} \in \mathbb{R}^n`, which consists of a conditional
        embedding :math:`\mathbf{e} \in \mathbb{R}^{n_e}` and an unconditional
        input :math:`\mathbf{c} \in \mathbb{R}^{n_c}`, i.e.,

        .. math::

            \mathbf{x} = \begin{bmatrix} \
            \mathbf{e} \\ \
            \mathbf{c} \
            \end{bmatrix}

        We simply define the variance of an input :math:`\text{Var}(x_j)` as
        the weighted average of the individual variances, i.e.,

        .. math::

            \text{Var}(x_j) \equiv \frac{n_e}{n_e+n_c} \text{Var}(e) + \
                \frac{n_c}{n_e+n_c} \text{Var}(c)

        To see that this is correct, consider a linear layer
        :math:`\mathbf{y} = W \mathbf{x}` or

        .. math::

            y_i &= \sum_j w_{ij} x_j \\ \
                &= \sum_{j=1}^{n_e} w_{ij} e_j + \
                   \sum_{j=n_e+1}^{n_e+n_c} w_{ij} c_{j-n_e}

        Hence, we can compute the variance of :math:`y_i` as follows (assuming
        the typical Xavier assumptions):

        .. math::

            \text{Var}(y) &= n_e \text{Var}(w) \text{Var}(e) + \
                n_c \text{Var}(w) \text{Var}(c) \\ \
                &= \frac{n_e}{n_e+n_c} \text{Var}(e) + \
                \frac{n_c}{n_e+n_c} \text{Var}(c)

        Note, that Xavier would have initialized :math:`W` using
        :math:`\text{Var}(w) = \frac{1}{n} = \frac{1}{n_e+n_c}`.

        Args:
            method (str): The type of initialization that should be applied.
                Possible options are:

                - ``'in'``: Use `Hyperfan-in`.
                - ``'out'``: Use `Hyperfan-out`.
                - ``'harmonic'``: Use the harmonic mean of the `Hyperfan-in`
                  and `Hyperfan-out` init.
            use_xavier (bool): Whether Kaiming (``False``) or Xavier (``True``)
                init should be used.
            uncond_var (float): The variance of unconditional embeddings. This
                value is only taken into consideration if
                ``uncond_in_size > 0`` (cf. constructor arguments).
            cond_var (float): The initial variance of conditional embeddings.
                This value is only taken into consideration if
                ``cond_in_size > 0`` (cf. constructor arguments).
            mnet (mnets.mnet_interface.MainNetInterface, optional): If
                applicable, the user should provide the main (or target)
                network, whose weights are generated by this hypernetwork. The
                ``mnet`` instance is used to extract valuable information that
                improves the initialization result. If provided, it is assumed
                that ``target_shapes`` (cf. constructor arguments) corresponds
                either to
                :attr:`mnets.mnet_interface.MainNetInterface.param_shapes` or
                :attr:`mnets.mnet_interface.MainNetInterface.hyper_shapes_learned`.
            w_val (list or dict, optional): The mean of the distribution with
                which output head weight matrices are initialized. Note, each
                weight tensor prescribed by
                :attr:`hnets.hnet_interface.HyperNetInterface.target_shapes` is
                produced via an independent linear output head.

                One may either specify a list of numbers having the same length
                as :attr:`hnets.hnet_interface.HyperNetInterface.target_shapes`
                or specify a dictionary which may have as keys the tensor names
                occurring in
                :attr:`mnets.mnet_interface.MainNetInterface.param_shapes_meta`
                and the corresponding mean value for the weight matrices of all
                output heads producing this type of tensor.

                If a list is provided, entries may be ``None`` and if a
                dictionary is provided, not all types of parameter tensors need
                to be specified. For tensors for which no value is specified,
                the default value will be used. The default values for tensor
                types ``'weight'`` and ``'bias'`` are calculated based on the
                proposed hyperfan-initialization. For these two tensor types,
                user-provided values only take effect if all four arguments
                ``w_val``, ``w_var``, ``b_val`` and ``b_var`` specify a value;
                otherwise, hyperfan init is used.
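
                A minimal sketch of this lookup rule, written as a hypothetical
                stand-alone helper ``resolve_head_arg`` (the implementation
                below uses an equivalent inline function), where ``i`` is the
                index of a target tensor and ``name`` its tensor type:

                .. code-block:: python

                    def resolve_head_arg(user_arg, i, name):
                        # Lists/tuples are indexed by target tensor; None
                        # entries request the default value.
                        if (isinstance(user_arg, (list, tuple))
                                and user_arg[i] is not None):
                            return user_arg[i]
                        # Dicts are keyed by tensor type, e.g., 'bn_scale'.
                        if isinstance(user_arg, dict) and name in user_arg:
                            return user_arg[name]
                        # None means that the default value is used.
                        return None

                    resolve_head_arg([None, 0.5], 1, 'bias')        # 0.5
                    resolve_head_arg({'bn_scale': 0}, 0, 'weight')  # None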

                For other tensor types the actual hypernet outputs should be
                drawn from the following distributions

                - ``'bn_scale'``: :math:`w \sim \delta(w - 1)`
                - ``'bn_shift'``: :math:`w \sim \delta(w)`
                - ``'cm_scale'``: :math:`w \sim \delta(w - 1)`
                - ``'cm_shift'``: :math:`w \sim \delta(w)`
                - ``'embedding'``: :math:`w \sim \mathcal{N}(0, 1)`

                which would correspond to the following passed arguments

                .. code-block:: python

                    w_val = {
                        'bn_scale': 0,
                        'bn_shift': 0,
                        'cm_scale': 0,
                        'cm_shift': 0,
                        'embedding': 0
                    }
                    w_var = {
                        'bn_scale': 0,
                        'bn_shift': 0,
                        'cm_scale': 0,
                        'cm_shift': 0,
                        'embedding': 0
                    }
                    b_val = {
                        'bn_scale': 1,
                        'bn_shift': 0,
                        'cm_scale': 1,
                        'cm_shift': 0,
                        'embedding': 0
                    }
                    b_var = {
                        'bn_scale': 0,
                        'bn_shift': 0,
                        'cm_scale': 0,
                        'cm_shift': 0,
                        'embedding': 1
                    }
            w_var (list or dict, optional): The variance of the distribution
                with which output head weight matrices are initialized. Variance
                values of zero mean that weights are set to a constant defined
                by ``w_val``. Non-zero variances are realized via a uniform
                distribution with the given mean and variance. See description
                of argument ``w_val`` for more details.
            b_val (list or dict, optional): The mean of the distribution
                with which output head bias vectors are initialized.
                See description of argument ``w_val`` for more details.
            b_var (list or dict, optional): The variance of the distribution
                with which output head bias vectors are initialized.
                See description of argument ``w_val`` for more details.
        """
        if method not in ['in', 'out', 'harmonic']:
            raise ValueError('Invalid value "%s" for argument "method".' %
                             method)
        if self.unconditional_params is None:
            assert self._no_uncond_weights
            raise ValueError('Hypernet without internal weights can\'t be ' +
                             'initialized.')

        ### Extract meta-information about target shapes ###
        meta = None
        if mnet is not None:
            assert isinstance(mnet, MainNetInterface)

            try:
                meta = mnet.param_shapes_meta
            except Exception:
                meta = None

            if meta is not None:
                if len(self.target_shapes) == len(mnet.param_shapes):
                    pass
                    # meta = mnet.param_shapes_meta
                elif len(self.target_shapes) == len(mnet.hyper_shapes_learned):
                    meta = []
                    for ii in mnet.hyper_shapes_learned_ref:
                        meta.append(mnet.param_shapes_meta[ii])
                else:
                    warn('Target shapes of this hypernetwork could not be ' +
                         'matched to the meta information provided to the ' +
                         'initialization.')
                    meta = None

        # TODO If the user doesn't (or can't) provide an `mnet` instance, we
        # should alternatively allow them to pass meta information directly.
        if meta is None:
            meta = []

            # Heuristical approach to derive meta information from given shapes.
            layer_ind = 0
            for i, s in enumerate(self.target_shapes):
                curr_meta = dict()
                if len(s) > 1:
                    curr_meta['name'] = 'weight'
                    curr_meta['layer'] = layer_ind
                    layer_ind += 1
                else: # just a heuristic, we can't know
                    curr_meta['name'] = 'bias'
                    if i > 0 and meta[-1]['name'] == 'weight':
                        curr_meta['layer'] = meta[-1]['layer']
                    else:
                        curr_meta['layer'] = -1

                meta.append(curr_meta)

        assert len(meta) == len(self.target_shapes)

        # Mapping from layer index to the corresponding shape.
        layer_shapes = dict()
        # Mapping from layer index to whether the layer has a bias vector.
        layer_has_bias = defaultdict(lambda: False)
        for i, m in enumerate(meta):
            if m['name'] == 'weight' and m['layer'] != -1:
                assert len(self.target_shapes[i]) > 1
                layer_shapes[m['layer']] = self.target_shapes[i]
            if m['name'] == 'bias' and m['layer'] != -1:
                layer_has_bias[m['layer']] = True

        ### Compute input variance ###
        cond_dim = self._cond_in_size
        uncond_dim = self._uncond_in_size
        inp_dim = cond_dim + uncond_dim

        input_variance = 0
        if cond_dim > 0:
            input_variance += (cond_dim / inp_dim) * cond_var
        if uncond_dim > 0:
            input_variance += (uncond_dim / inp_dim) * uncond_var

        ### Initialize hidden layers to preserve variance ###
        # Note, if batchnorm layers are used, they will simply be initialized to
        # have no effect after initialization. This does not affect the
        # performed whitening operation.
        if self.batchnorm_layers is not None:
            for bn_layer in self.batchnorm_layers:
                if hasattr(bn_layer, 'scale'):
                    nn.init.ones_(bn_layer.scale)
                if hasattr(bn_layer, 'bias'):
                    nn.init.zeros_(bn_layer.bias)

            # Since batchnorm layers whiten the statistics of hidden
            # activities, the variance of the input will not be preserved by
            # Xavier/Kaiming.
            if len(self.batchnorm_layers) > 0:
                input_variance = 1.

        # We initialize biases with 0 (see Xavier assumption 4 in the Hyperfan
        # paper). Otherwise, we couldn't ignore the biases when computing the
        # output variance of a layer.
        # Note, we have to use fan-in init for the hidden layer to ensure the
        # property, that we preserve the input variance.
        assert len(self._layers) + 1 == len(self.layer_weight_tensors)
        for i, w_tensor in enumerate(self.layer_weight_tensors[:-1]):
            if use_xavier:
                iutils.xavier_fan_in_(w_tensor)
            else:
                torch.nn.init.kaiming_uniform_(w_tensor, mode='fan_in',
                                               nonlinearity='relu')

            if self.has_bias:
                nn.init.zeros_(self.layer_bias_vectors[i])

        ### Define default parameters of weight init distributions ###
        w_val_list = []
        w_var_list = []
        b_val_list = []
        b_var_list = []

        for i, m in enumerate(meta):
            def extract_val(user_arg):
                curr = None
                if isinstance(user_arg, (list, tuple)) and \
                        user_arg[i] is not None:
                    curr = user_arg[i]
                elif isinstance(user_arg, (dict)) and \
                        m['name'] in user_arg.keys():
                    curr = user_arg[m['name']]
                return curr

            curr_w_val = extract_val(w_val)
            curr_w_var = extract_val(w_var)
            curr_b_val = extract_val(b_val)
            curr_b_var = extract_val(b_var)

            if m['name'] == 'weight' or m['name'] == 'bias':
                if None in [curr_w_val, curr_w_var, curr_b_val, curr_b_var]:
                    # If distribution not fully specified, then we just fall
                    # back to hyper-fan init.
                    curr_w_val = None
                    curr_w_var = None
                    curr_b_val = None
                    curr_b_var = None
            else:
                assert m['name'] in [
                    'bn_scale', 'bn_shift', 'cm_scale', 'cm_shift', 'embedding'
                ]
                if curr_w_val is None:
                    curr_w_val = 0
                if curr_w_var is None:
                    curr_w_var = 0
                if curr_b_val is None:
                    curr_b_val = 1 if m['name'] in ['bn_scale', 'cm_scale'] \
                        else 0
                if curr_b_var is None:
                    curr_b_var = 1 if m['name'] in ['embedding'] else 0

            w_val_list.append(curr_w_val)
            w_var_list.append(curr_w_var)
            b_val_list.append(curr_b_val)
            b_var_list.append(curr_b_var)

        ### Initialize output heads ###
        # Note, that all output heads are realized internally via one large
        # fully-connected layer.
        # All output heads are linear layers. The biases of these linear
        # layers (called gamma and beta in the paper) are simply initialized
        # to zero. Note, that we allow deviations from this below.
        if self.has_bias:
            nn.init.zeros_(self.layer_bias_vectors[-1])

        c_relu = 1 if use_xavier else 2

        # We are not interested in the fan-out, since the fan-out is just
        # the number of elements in the main network.
        # `fan-in` is called `d_k` in the paper and is just the size of the
        # last hidden layer.
        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(\
            self.layer_weight_tensors[-1])

        s_ind = 0
        for i, out_shape in enumerate(self.target_shapes):
            m = meta[i]
            e_ind = s_ind + int(np.prod(out_shape))

            curr_w_val = w_val_list[i]
            curr_w_var = w_var_list[i]
            curr_b_val = b_val_list[i]
            curr_b_var = b_var_list[i]

            if curr_w_val is None:
                c_bias = 2 if layer_has_bias[m['layer']] else 1

                if m['name'] == 'bias':
                    m_fan_out = out_shape[0]

                    # NOTE For the hyperfan-out init, we also need to know the
                    # fan-in of the layer.
                    if m['layer'] != -1:
                        m_fan_in, _ = iutils.calc_fan_in_and_out( \
                            layer_shapes[m['layer']])
                    else:
                        # FIXME Quick-fix.
                        m_fan_in = m_fan_out

                    var_in = c_relu / (2. * fan_in * input_variance)
                    num = c_relu * (1. - m_fan_in / m_fan_out)
                    denom = fan_in * input_variance
                    var_out = max(0, num / denom)

                else:
                    assert m['name'] == 'weight'
                    m_fan_in, m_fan_out = iutils.calc_fan_in_and_out(out_shape)
                    var_in = c_relu / (c_bias * m_fan_in * fan_in * \
                                       input_variance)
                    var_out = c_relu / (m_fan_out * fan_in * input_variance)

                if method == 'in':
                    var = var_in
                elif method == 'out':
                    var = var_out
                elif method == 'harmonic':
                    # Harmonic mean of the hyperfan-in and hyperfan-out
                    # variance (well-defined even if `var_out` is zero).
                    var = 2. * var_in * var_out / (var_in + var_out)
                else:
                    raise ValueError('Method %s invalid.' % method)

                # Initialize output head weight tensor using `var`.
                std = math.sqrt(var)
                a = math.sqrt(3.0) * std
                torch.nn.init._no_grad_uniform_( \
                    self.layer_weight_tensors[-1][s_ind:e_ind, :], -a, a)
            else:
                if curr_w_var == 0:
                    nn.init.constant_(
                        self.layer_weight_tensors[-1][s_ind:e_ind, :],
                        curr_w_val)
                else:
                    std = math.sqrt(curr_w_var)
                    a = math.sqrt(3.0) * std
                    torch.nn.init._no_grad_uniform_( \
                        self.layer_weight_tensors[-1][s_ind:e_ind, :],
                        curr_w_val-a, curr_w_val+a)

                if curr_b_var == 0:
                    nn.init.constant_(
                        self.layer_bias_vectors[-1][s_ind:e_ind],
                        curr_b_val)
                else:
                    std = math.sqrt(curr_b_var)
                    a = math.sqrt(3.0) * std
                    torch.nn.init._no_grad_uniform_( \
                        self.layer_bias_vectors[-1][s_ind:e_ind],
                        curr_b_val-a, curr_b_val+a)

            s_ind = e_ind

    def apply_hyperfan_init(self, method='in', use_xavier=False,
                            temb_var=1., ext_inp_var=1.):
        r"""Initialize the network using hyperfan init.

        Hyperfan initialization was developed in the following paper for this
        kind of hypernetwork

            "Principled Weight Initialization for Hypernetworks"
            https://openreview.net/forum?id=H1lma24tPB

        The initialization is based on the following idea: When the main network
        would be initialized using Xavier or Kaiming init, then the variance of
        activations (fan-in) or gradients (fan-out) would be preserved by using
        a proper variance for the initial weight distribution (assuming certain
        assumptions hold at initialization, which are different for Xavier and
        Kaiming).

        When using these kinds of initializations in the hypernetwork, then the
        variance of the initial main net weight distribution would simply equal
        the variance of the input embeddings (which can lead to exploding
        activations, e.g., for fan-in inits).

        The above mentioned paper proposes a quick fix for the type of hypernets
        which have a separate output head per weight tensor in the main network
        (which is the case for this hypernetwork class).

        Assuming that input embeddings are initialized with a certain variance
        (e.g., 1) and we use Xavier or Kaiming init for the hypernet, then the
        variance of the last hidden activation will also be 1.

        Then, we can modify the variance of the weights of each output head in
        the hypernet to obtain the variance for the main net weight tensors that
        we would typically obtain when applying Xavier or Kaiming to the main
        network directly.

        Warning:
            This method currently assumes that 1D target tensors (cf.
            constructor argument ``target_shapes``) are bias vectors in the
            main network.

        Warning:
            To compute the hyperfan-out initialization of bias vectors, we need
            access to the fan-in of the layer, which we can only compute based
            on the corresponding weight tensor in the same layer.
            Since there is no clean way of matching a bias shape to its
            corresponding weight tensor shape, we use the following heuristic,
            which should be correct for most main networks. We assume that the
            shape directly preceding a bias shape in the constructor argument
            ``target_shapes`` is the corresponding weight tensor.

        **Variance of the hypernet input**

        In general, the input to the hypernetwork can be a concatenation of
        multiple embeddings (see description of arguments ``temb_var`` and
        ``ext_inp_var``).

        Let's denote the complete hypernetwork input by
        :math:`\mathbf{x} \in \mathbb{R}^n`, which consists of a task embedding
        :math:`\mathbf{e} \in \mathbb{R}^{n_e}` and an external input
        :math:`\mathbf{c} \in \mathbb{R}^{n_c}`, i.e.,

        .. math::

            \mathbf{x} = \begin{bmatrix} \
            \mathbf{e} \\ \
            \mathbf{c} \
            \end{bmatrix}

        We simply define the variance of an input :math:`\text{Var}(x_j)` as
        the weighted average of the individual variances, i.e.,

        .. math::

            \text{Var}(x_j) \equiv \frac{n_e}{n_e+n_c} \text{Var}(e) + \
                \frac{n_c}{n_e+n_c} \text{Var}(c)

        To see that this is correct, consider a linear layer
        :math:`\mathbf{y} = W \mathbf{x}` or

        .. math::

            y_i &= \sum_j w_{ij} x_j \\ \
                &= \sum_{j=1}^{n_e} w_{ij} e_j + \
                   \sum_{j=n_e+1}^{n_e+n_c} w_{ij} c_{j-n_e}

        Hence, we can compute the variance of :math:`y_i` as follows (assuming
        the typical Xavier assumptions):

        .. math::

            \text{Var}(y) &= n_e \text{Var}(w) \text{Var}(e) + \
                n_c \text{Var}(w) \text{Var}(c) \\ \
                &= \frac{n_e}{n_e+n_c} \text{Var}(e) + \
                \frac{n_c}{n_e+n_c} \text{Var}(c)

        Note, that Xavier would have initialized :math:`W` using
        :math:`\text{Var}(w) = \frac{1}{n} = \frac{1}{n_e+n_c}`.

        Note:
            This method will automatically incorporate the noise embedding that
            is fed into the network if constructor argument ``noise_dim`` was
            set.

        Note:
            The hypernet inputs should be zero-mean.

        Args:
            method (str): The type of initialization that should be applied.
                Possible options are:

                - ``in``: Use `Hyperfan-in`.
                - ``out``: Use `Hyperfan-out`.
                - ``harmonic``: Use the harmonic mean of the `Hyperfan-in` and
                  `Hyperfan-out` init.
            use_xavier (bool): Whether Kaiming (``False``) or Xavier (``True``)
                init should be used.
            temb_var (float): The initial variance of the task embeddings.

                .. note::
                    If ``temb_std`` was set in the constructor, then this method
                    will automatically correct the provided ``temb_var`` as
                    follows: :code:`temb_var += temb_std**2`.
            ext_inp_var (float): The initial variance of the external input.
                Only needs to be specified if external inputs are provided
                (see argument ``ce_dim`` of constructor).
        """
        # FIXME If the network has external inputs and task embeddings, then
        # both these inputs might have different variances. Thus, a single
        # parameter `input_variance` might not be sufficient.
        # Now, we assume that the user provides a proper variance. We could
        # simplify the job for the user by providing multiple arguments and
        # compute the weighting ourselves.
        # FIXME Handle constructor arguments `noise_dim` and `temb_std`.
        # Note, we would just need to add `temb_std**2` to the variance of
        # task embeddings, since the variance of a sum of uncorrelated RVs is
        # just the sum of the individual variances.
        if method not in ['in', 'out', 'harmonic']:
            raise ValueError('Invalid value for argument "method".')
        if not self.has_theta:
            raise ValueError('Hypernet without internal weights can\'t be ' +
                             'initialized.')

        ### Compute input variance ###
        if self._temb_std != -1:
            # Sum of uncorrelated variables.
            temb_var += self._temb_std**2

        assert self._size_ext_input is None or self._size_ext_input > 0
        assert self._noise_dim == -1 or self._noise_dim > 0

        inp_dim = self._te_dim + \
            (self._size_ext_input if self._size_ext_input is not None else 0) \
            + (self._noise_dim if self._noise_dim != -1 else 0)

        input_variance = (self._te_dim / inp_dim) * temb_var
        if self._size_ext_input is not None:
            input_variance += (self._size_ext_input / inp_dim) * ext_inp_var
        if self._noise_dim != -1:
            input_variance += (self._noise_dim / inp_dim) * 1.

        ### Initialize hidden layers to preserve variance ###
        # We initialize biases with 0 (see Xavier assumption 4 in the Hyperfan
        # paper). Otherwise, we couldn't ignore the biases when computing the
        # output variance of a layer.
        # Note, we have to use fan-in init for the hidden layer to ensure the
        # property, that we preserve the input variance.

        for i in range(0, len(self._hidden_dims), 2 if self._use_bias else 1):
            #W = self.theta[i]
            if use_xavier:
                iutils.xavier_fan_in_(self.theta[i])
            else:
                torch.nn.init.kaiming_uniform_(self.theta[i], mode='fan_in',
                                               nonlinearity='relu')

            if self._use_bias:
                #b = self.theta[i+1]
                torch.nn.init.constant_(self.theta[i + 1], 0)

        ### Initialize output heads ###
        c_relu = 1 if use_xavier else 2
        # FIXME Not a proper way to figure out whether the hnet produces
        # bias vectors in the mnet.
        c_bias = 1
        for s in self.target_shapes:
            if len(s) == 1:
                c_bias = 2
                break
        # This is how we should do it instead.
        #c_bias = 2 if mnet.has_bias else 1

        j = 0
        for i in range(len(self._hidden_dims), len(self._theta_shapes),
                       2 if self._use_bias else 1):

            # All output heads are linear layers. The biases of these linear
            # layers (called gamma and beta in the paper) are simply initialized
            # to zero.
            if self._use_bias:
                #b = self.theta[i+1]
                torch.nn.init.constant_(self.theta[i + 1], 0)

            # We are not interested in the fan-out, since the fan-out is just
            # the number of elements in the corresponding main network tensor.
            # `fan-in` is called `d_k` in the paper and is just the size of the
            # last hidden layer.
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out( \
                self.theta[i])

            out_shape = self.target_shapes[j]

            # FIXME 1D output tensors don't need to be bias vectors. They can
            # be arbitrary embeddings or, for instance, batchnorm weights.
            if len(out_shape) == 1: # Assume output is bias vector.
                m_fan_out = out_shape[0]

                # NOTE For the hyperfan-out init, we also need to know the
                # fan-in of the layer.
                # FIXME We have no proper way at the moment to get the correct
                # fan-in of the layer this bias vector belongs to.
                if j > 0 and len(self.target_shapes[j - 1]) > 1:
                    m_fan_in, _ = iutils.calc_fan_in_and_out( \
                        self.target_shapes[j-1])
                else:
                    # FIXME Quick-fix.
                    m_fan_in = m_fan_out

                var_in = c_relu / (2. * fan_in * input_variance)
                num = c_relu * (1. - m_fan_in / m_fan_out)
                denom = fan_in * input_variance
                # Clip at zero (rather than dividing by this quantity, which
                # would fail whenever `m_fan_in == m_fan_out`).
                var_out = max(0, num / denom)

            else:
                m_fan_in, m_fan_out = iutils.calc_fan_in_and_out(out_shape)
                var_in = c_relu / (c_bias * m_fan_in * fan_in * input_variance)
                var_out = c_relu / (m_fan_out * fan_in * input_variance)

            if method == 'in':
                var = var_in
            elif method == 'out':
                var = var_out
            elif method == 'harmonic':
                # Harmonic mean of the hyperfan-in and hyperfan-out variance.
                var = 2. * var_in * var_out / (var_in + var_out)
            else:
                raise ValueError('Method %s invalid.' % method)

            # Initialize output head weight tensor using `var`.
            std = math.sqrt(var)
            a = math.sqrt(3.0) * std
            torch.nn.init._no_grad_uniform_(self.theta[i], -a, a)

            # Move on to the next target shape.
            j += 1
    def apply_chunked_hyperfan_init(self, method='in', use_xavier=False,
                                    uncond_var=1., cond_var=1., eps=1e-5,
                                    cemb_normal_init=False, mnet=None,
                                    target_vars=None):
        r"""Initialize the network using a chunked hyperfan init.

        Inspired by the method
        `Hyperfan Init <https://openreview.net/forum?id=H1lma24tPB>`__, which
        we implemented for the MLP hypernetwork in method
        :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`, we heuristically
        developed an initialization method for chunked hypernetworks.

        Unfortunately, the `Hyperfan Init` method from the paper does not apply
        to this kind of hypernetwork, since we reuse the same hypernet output
        head for the whole main network.

        Luckily, we can provide a simple heuristic. Similar to
        `Meyerson & Miikkulainen <https://arxiv.org/abs/1906.00097>`__, we
        adjust the variance of the input embeddings to control the variance of
        the output weights.

        In a chunked hypernetwork, the input for each chunk is identical except
        for the chunk embeddings :math:`\mathbf{c}`. Let :math:`\mathbf{e}`
        denote the remaining inputs to the hypernetwork, which are identical
        for all chunks, and let :math:`n_e` and :math:`n_c` denote the
        dimensionalities of :math:`\mathbf{e}` and :math:`\mathbf{c}`. Then,
        assuming the hypernetwork was initialized via fan-in init, the variance
        of the hypernetwork output :math:`\mathbf{v}` can be written as follows
        (see documentation of method
        :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`):

        .. math::

            \text{Var}(v) = \frac{n_e}{n_e+n_c} \text{Var}(e) +
                \frac{n_c}{n_e+n_c} \text{Var}(c)

        Hence, we can achieve a desired output variance :math:`\text{Var}(v)`
        by initializing the chunk embeddings :math:`\mathbf{c}` with the
        following variance:

        .. math::

            \text{Var}(c) = \max \Big\{ 0,
                \frac{1}{n_c} \big[ (n_e+n_c) \text{Var}(v) -
                n_e \text{Var}(e) \big] \Big\}

        Now, one important question remains: how do we pick a desired output
        variance :math:`\text{Var}(v)` for a chunk?

        Note that a chunk may include weights from several layers.
        The likelihood of this happening depends on the main net architecture
        and the chunk size (see constructor argument ``chunk_size``). The
        smaller the chunk size, the less likely it is that a chunk contains
        elements from multiple main net weight tensors.

        If each chunk contained only weights from a single main net weight
        tensor, we could simply pick the variance :math:`\text{Var}(v)` that
        would have been chosen by a main net initialization method (such as
        Xavier).

        If a chunk contains contributions from several main net weight tensors,
        we apply the following heuristic. Suppose a chunk contains
        :math:`n_1, \dots, n_K` elements from the main network weight tensors
        :math:`W_1, \dots, W_K`, respectively, such that
        :math:`n_1 + \dots + n_K = n_v`, where :math:`n_v` denotes the chunk
        size, and suppose the main network initialization method would require
        the init variances :math:`\text{Var}(w_1), \dots, \text{Var}(w_K)`.
        Then we simply request their weighted average:

        .. math::

            \text{Var}(v) = \frac{1}{n_v} \sum_{k=1}^K n_k \text{Var}(w_k)

        What about bias vectors? Usually, the variance analysis applied to
        Xavier or Kaiming init assumes that biases are initialized to zero.
        This is not possible in this setting, as it would require assigning a
        negative variance to :math:`\mathbf{c}` (setting
        :math:`\text{Var}(v) = 0` in the equation above yields
        :math:`\text{Var}(c) = -\frac{n_e}{n_c} \text{Var}(e)`, which is
        negative whenever :math:`\text{Var}(e) > 0`). Instead, we follow the
        default PyTorch initialization (e.g., see method ``reset_parameters``
        in class :class:`torch.nn.Linear`). There, bias vectors are initialized
        uniformly within a range of :math:`\pm \frac{1}{\sqrt{f_{\text{in}}}}`,
        where :math:`f_{\text{in}}` refers to the fan-in of the layer. Since a
        uniform distribution on :math:`[-a, a]` has variance :math:`a^2 / 3`,
        this type of initialization corresponds to a variance of
        :math:`\text{Var}(v) = \frac{1}{3 f_{\text{in}}}`.

        Note:
            All hypernet inputs are assumed to be zero-mean random variables.
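A minimal plain-Python sketch of the three formulas above, for illustration only: the per-chunk weighted average, the chunk-embedding variance, and the PyTorch-style bias variance. The helper names ``chunk_target_variances``, ``chunk_emb_variance`` and ``bias_variance`` are hypothetical and not part of this class. The sketch follows the implementation below in two respects. It assumes that all target tensors are flattened and concatenated in order. It also assumes that the unused tail of the last chunk is filled using the variance of the last tensor.

```python
import math


def chunk_target_variances(tensor_sizes, tensor_vars, chunk_size):
    # Var(v) of chunk j is (1 / n_v) * sum_k n_k * Var(w_k), where n_k is the
    # number of elements tensor k contributes to chunk j and n_v = chunk_size.
    num_chunks = math.ceil(sum(tensor_sizes) / chunk_size)
    chunk_vars = [0.0] * num_chunks
    pos = 0  # Global index of the next (flattened) output element.
    for size, var in zip(tensor_sizes, tensor_vars):
        end = pos + size
        while pos < end:
            j = pos // chunk_size
            take = min(end, (j + 1) * chunk_size) - pos
            chunk_vars[j] += take / chunk_size * var
            pos += take
    # Fill the unused tail of the last chunk with the last tensor's variance.
    pad = num_chunks * chunk_size - pos
    if pad > 0:
        chunk_vars[-1] += pad / chunk_size * tensor_vars[-1]
    return chunk_vars


def chunk_emb_variance(var_v, var_e, n_e, n_c, eps=1e-5):
    # Solve Var(v) = n_e/(n_e+n_c) Var(e) + n_c/(n_e+n_c) Var(c) for Var(c)
    # and clip at `eps` (instead of 0) so the embedding can still be sampled.
    var_c = ((n_e + n_c) * var_v - n_e * var_e) / n_c
    return max(eps, var_c)


def bias_variance(fan_in):
    # U(-a, a) has variance a**2 / 3; with a = 1 / sqrt(f_in) this is
    # 1 / (3 * f_in), matching PyTorch's default bias init.
    a = 1.0 / math.sqrt(fan_in)
    return a ** 2 / 3.0


if __name__ == '__main__':
    # Hypothetical layer: 4x3 weight (Kaiming fan-in: 2 / 3) and bias of size 4.
    weight_var = 2.0 / 3
    target_vars = [weight_var, bias_variance(3)]
    chunk_vars = chunk_target_variances([12, 4], target_vars, chunk_size=10)
    for v in chunk_vars:
        print(v, chunk_emb_variance(v, var_e=0.1, n_e=8, n_c=8))
```

Consider the hypothetical layer above, split into chunks of size 10. The first chunk gets a target variance of 2/3, and the second, padded chunk gets (2 * 2/3 + 8 * 1/9) / 10 = 2/9. With Var(e) = 0.1 and n_e = n_c = 8, neither chunk embedding variance needs clipping. With Var(e) = 1.0, the second one would be clipped to ``eps``.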
        Note:
            To avoid having to clip the variances with which chunk embeddings
            are initialized (because they would be too small or even
            negative), the variance of the remaining hypernet inputs should be
            scaled appropriately. In general, one should adhere to the
            following rule for every chunk, i.e., for the smallest per-chunk
            target variance :math:`\text{Var}(v)`:

            .. math::

                \text{Var}(e) < \frac{n_e+n_c}{n_e} \text{Var}(v)

            This method computes the maximum admissible value for
            :math:`\text{Var}(e)` and issues a warning (reporting this value)
            if it is exceeded. It also warns if chunk embedding variances have
            to be clipped.

        Args:
            (....): See arguments of method
                :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`.
            method (str): The type of initialization that should be applied.
                Possible options are:

                - ``in``: Use `Chunked Hyperfan-in`, i.e., the output
                  variances of the hypernetwork should correspond to fan-in
                  variances.
                - ``out``: Use `Chunked Hyperfan-out`, i.e., the output
                  variances of the hypernetwork should correspond to fan-out
                  variances.
                - ``harmonic``: Use the harmonic mean of the fan-in and fan-out
                  variance as target variance of the hypernetwork output.
            eps (float): The minimum variance with which a chunk embedding is
                initialized.
            cemb_normal_init (bool): Use normal init for chunk embeddings
                rather than uniform init.
            target_vars (list or dict, optional): The variance of the
                distribution for each parameter tensor generated by this
                hypernetwork. Target variance values can either be provided as
                a list of length ``len(hnet.target_shapes)`` or as a dictionary
                whose keys are parameter tensor types (such as ``'weight'`` or
                ``'bias'``). The usage is analogous to the usage of parameter
                ``w_val`` of method
                :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`.

        Note:
            This method currently does not allow initial output distributions
            with non-zero mean. However, the docstring of method
            :meth:`probabilistic.gauss_hnet_init.gauss_hyperfan_init`
            describes how this is in principle feasible and might be
            incorporated in the future.

        Note:
            Unspecified target variances for parameter tensors of type
            ``'weight'`` or ``'bias'`` are computed as described above.
            Default target variances for all other parameter tensor types are
            simply ``1``.
        """
        if method not in ['in', 'out', 'harmonic']:
            raise ValueError('Invalid value "%s" for argument "method".' %
                             method)
        if self.unconditional_params is None:
            assert self._no_uncond_weights
            raise ValueError('Hypernet without internal weights can\'t be ' +
                             'initialized.')
        if self.conditional_params is None and self._cond_chunk_embs:
            assert self._no_cond_weights
            raise ValueError('Chunked hyperfan init cannot be applied if ' +
                             'chunk embeddings are not internally maintained.')

        ### Extract meta-information about target shapes ###
        # FIXME This section is copied from the HMLP implementation.
        meta = None
        if mnet is not None:
            assert isinstance(mnet, MainNetInterface)

            try:
                meta = mnet.param_shapes_meta
            except Exception:
                meta = None

            if meta is not None:
                if len(self.target_shapes) == len(mnet.param_shapes):
                    pass # meta = mnet.param_shapes_meta
                elif len(self.target_shapes) == len(mnet.hyper_shapes_learned):
                    meta = []
                    for ii in mnet.hyper_shapes_learned_ref:
                        meta.append(mnet.param_shapes_meta[ii])
                else:
                    warn('Target shapes of this hypernetwork could not be ' +
                         'matched to the meta information provided to the ' +
                         'initialization.')
                    meta = None

        # TODO If the user doesn't (or can't) provide an `mnet` instance, we
        # should alternatively allow them to pass meta information directly.
        if meta is None:
            meta = []

            # Heuristic approach to derive meta information from given shapes.
            layer_ind = 0
            for i, s in enumerate(self.target_shapes):
                curr_meta = dict()
                if len(s) > 1:
                    curr_meta['name'] = 'weight'
                    curr_meta['layer'] = layer_ind
                    layer_ind += 1
                else: # Just a heuristic, we can't know.
                    curr_meta['name'] = 'bias'
                    if i > 0 and meta[-1]['name'] == 'weight':
                        curr_meta['layer'] = meta[-1]['layer']
                    else:
                        curr_meta['layer'] = -1

                meta.append(curr_meta)

        assert len(meta) == len(self.target_shapes)

        # Mapping from layer index to the corresponding shape.
        layer_shapes = dict()
        # Mapping from layer index to whether the layer has a bias vector.
        layer_has_bias = defaultdict(lambda: False)
        for i, m in enumerate(meta):
            if m['name'] == 'weight' and m['layer'] != -1:
                assert len(self.target_shapes[i]) > 1
                layer_shapes[m['layer']] = self.target_shapes[i]
            if m['name'] == 'bias' and m['layer'] != -1:
                layer_has_bias[m['layer']] = True

        ### Compute input variance ###
        # The input variance does not include the variance of chunk embeddings!
        # Instead, it is the variance of the inputs that are shared across all
        # chunks.
        cond_dim = self._cond_in_size
        uncond_dim = self._uncond_in_size
        # Note, `inp_dim` can be zero if conditional chunk embeddings are used.
        inp_dim = cond_dim + uncond_dim
        inp_var = 0
        if cond_dim > 0:
            inp_var += (cond_dim / inp_dim) * cond_var
        if uncond_dim > 0:
            inp_var += (uncond_dim / inp_dim) * uncond_var

        c_dim = self.chunk_emb_size

        ### Initialize hypernet with fan-in init ###
        if self.batchnorm_layers is not None and \
                len(self.batchnorm_layers) > 0:
            # Note, batchnorm layers simply whiten the incoming statistics.
            # Thus, if we tune the variance of chunk embeddings, this variance
            # is normalized by a batchnorm layer and thus vanishes.
            raise RuntimeError('Chunked hyperfan init not applicable if a ' +
                               'hypernetwork with batchnorm layers is used.')

        # Note, the whole internal hypernetwork is initialized with fan-in init
        # to simply pass the variance of all inputs to the hypernet output.
        for i, w_tensor in enumerate(self.layer_weight_tensors):
            if use_xavier:
                iutils.xavier_fan_in_(w_tensor)
            else:
                torch.nn.init.kaiming_uniform_(w_tensor, mode='fan_in',
                                               nonlinearity='relu')

            if self.has_bias:
                torch.nn.init.zeros_(self.layer_bias_vectors[i])

        ### Compute target variance of each output tensor ###
        if target_vars is None:
            target_vars = [None] * len(self.target_shapes)
        elif isinstance(target_vars, dict):
            target_vars_d = target_vars
            target_vars = []
            for i, m in enumerate(meta):
                if m['name'] in target_vars_d.keys():
                    target_vars.append(target_vars_d[m['name']])
                else:
                    target_vars.append(None)
        else:
            assert isinstance(target_vars, (list, tuple))
            assert len(target_vars) == len(self.target_shapes)
            # Copy, such that tuples are supported and the caller's list is
            # not modified by the assignments below.
            target_vars = list(target_vars)

        for i, s in enumerate(self.target_shapes):
            if target_vars[i] is not None:
                # Use user specified target variance.
                continue

            m = meta[i]

            if m['name'] == 'bias':
                if m['layer'] != -1:
                    fan_in, _ = iutils.calc_fan_in_and_out(
                        layer_shapes[m['layer']])
                else:
                    # FIXME Quick-fix, use fan-out instead.
                    fan_in = s[0]

                target_vars[i] = 1. / (3. * fan_in)

            elif m['name'] == 'weight':
                fan_in, fan_out = iutils.calc_fan_in_and_out(s)

                c_relu = 1 if use_xavier else 2

                var_in = c_relu / fan_in
                var_out = c_relu / fan_out

                if method == 'in':
                    var = var_in
                elif method == 'out':
                    var = var_out
                else:
                    # Harmonic mean of `var_in` and `var_out`.
                    var = 2. * var_in * var_out / (var_in + var_out)

                target_vars[i] = var

            else:
                target_vars[i] = 1.

        ### Target variance per chunk ###
        chunk_vars = []
        i = 0
        n = np.prod(self.target_shapes[i])

        for j in range(self.num_chunks):
            m = self._chunk_size
            var = 0

            while m > 0:
                # Special treatment to fill up last chunk.
                if j == self.num_chunks - 1 and i == len(target_vars) - 1:
                    assert n <= m
                    o = m
                else:
                    o = min(m, n)

                var += o / self._chunk_size * target_vars[i]
                m -= o
                n -= o

                if n == 0:
                    i += 1
                    if i < len(target_vars):
                        n = np.prod(self.target_shapes[i])

            chunk_vars.append(var)

        if inp_dim > 0:
            max_inp_var = (inp_dim + c_dim) / inp_dim * min(chunk_vars)
            max_inp_std = math.sqrt(max_inp_var)
            print('Initializing hypernet with Chunked Hyperfan Init ...')
            if inp_var >= max_inp_var:
                warn('Note, hypernetwork inputs should have an initial total ' +
                     'variance (std) smaller than %f (%f) in order for this '
                     % (max_inp_var, max_inp_std) + 'method to work properly.')

        ### Compute variances of chunk embeddings ###
        # This could have been done in the previous loop, but the code is more
        # readable this way.
        c_vars = []
        n_clipped = 0
        for i, var in enumerate(chunk_vars):
            # Var(c) = [(n_e + n_c) Var(v) - n_e Var(e)] / n_c, see docstring.
            c_var = 1. / c_dim * ((inp_dim + c_dim) * var - inp_dim * inp_var)
            if c_var < eps:
                n_clipped += 1

            c_vars.append(max(eps, c_var))

        if n_clipped > 0:
            warn('Initial variance of %d/%d ' % (n_clipped, len(chunk_vars)) +
                 'chunk embeddings had to be clipped.')

        ### Initialize chunk embeddings ###
        for i in range(self.num_chunks):
            c_std = math.sqrt(c_vars[i])

            num_conds = self.num_known_conds if self._cond_chunk_embs else 1
            for j in range(num_conds):
                cond_id = j if self._cond_chunk_embs else None
                c_emb = self.get_chunk_emb(chunk_id=i, cond_id=cond_id)

                if cemb_normal_init:
                    torch.nn.init.normal_(c_emb, mean=0, std=c_std)
                else:
                    a = math.sqrt(3.0) * c_std
                    torch.nn.init.uniform_(c_emb, -a, a)