def get_hypernet(config, device, net_type, target_shapes, num_conds,
                 no_cond_weights=False, no_uncond_weights=False,
                 uncond_in_size=0, shmlp_chunk_shapes=None,
                 shmlp_num_per_chunk=None, shmlp_assembly_fct=None,
                 verbose=True, cprefix=None):
    """Generate a hypernetwork instance.

    A helper to generate the hypernetwork according to the given user
    configuration.

    Options that are not present in ``config`` fall back to the default
    value of the corresponding constructor argument.

    Args:
        config (argparse.Namespace): Command-line arguments.

            Note:
                The function expects command-line arguments available
                according to the function :func:`utils.cli_args.hnet_args`.
        device: PyTorch device.
        net_type (str): The type of network. The following options are
            available:

            - ``'hmlp'``
            - ``'chunked_hmlp'``
            - ``'structured_hmlp'``
            - ``'hdeconv'``
            - ``'chunked_hdeconv'``
        target_shapes (list): See argument ``target_shapes`` of
            :class:`hnets.mlp_hnet.HMLP`.
        num_conds (int): Number of conditions that should be known to the
            hypernetwork.
        no_cond_weights (bool): See argument ``no_cond_weights`` of
            :class:`hnets.mlp_hnet.HMLP`.
        no_uncond_weights (bool): See argument ``no_uncond_weights`` of
            :class:`hnets.mlp_hnet.HMLP`.
        uncond_in_size (int): See argument ``uncond_in_size`` of
            :class:`hnets.mlp_hnet.HMLP`.
        shmlp_chunk_shapes (list, optional): Argument ``chunk_shapes`` of
            :class:`hnets.structured_mlp_hnet.StructuredHMLP`.
        shmlp_num_per_chunk (list, optional): Argument ``num_per_chunk`` of
            :class:`hnets.structured_mlp_hnet.StructuredHMLP`.
        shmlp_assembly_fct (func, optional): Argument ``assembly_fct`` of
            :class:`hnets.structured_mlp_hnet.StructuredHMLP`.
        verbose (bool): Argument ``verbose`` of
            :class:`hnets.mlp_hnet.HMLP`.
        cprefix (str, optional): A prefix of the config names. It might be,
            that the config names used in this function are prefixed, since
            several hypernetworks should be generated.

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.hnet_args`.

    Returns:
        The created hypernetwork, moved to ``device``.

    Raises:
        ValueError: If ``hmlp_arch`` contains semicolons or
            ``chunk_emb_size`` contains multiple values while ``net_type``
            is not ``'structured_hmlp'``.
        NotImplementedError: For the network types ``'hdeconv'`` and
            ``'chunked_hdeconv'``.
    """
    assert net_type in ['hmlp', 'chunked_hmlp', 'structured_hmlp', 'hdeconv',
                        'chunked_hdeconv']

    hnet = None

    ### FIXME Code almost identically copied from `get_mnet_model` ###
    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    def hc(name):
        """Check whether config exists."""
        return hasattr(config, '%s%s' % (cprefix, name))

    def get_val(name):
        """Get config value with that name or ``None`` if not existing."""
        return gc(name) if hc(name) else None

    def assign(value, default):
        """Use ``default`` if the option ``value`` wasn't specified."""
        return default if value is None else value

    net_act = None
    if hc('hnet_net_act'):
        net_act = misc.str_to_act(gc('hnet_net_act'))

    no_bias = get_val('hnet_no_bias')
    dropout_rate = get_val('hnet_dropout_rate')
    specnorm = get_val('hnet_specnorm')
    batchnorm = get_val('hnet_batchnorm')
    no_batchnorm = get_val('hnet_no_batchnorm')
    #bn_no_running_stats = get_val('hnet_bn_no_running_stats')
    #n_distill_stats = get_val('hnet_bn_distill_stats')

    # Note, `not None` evaluates to `True`. Hence, the negation may only be
    # applied if the option exists, otherwise the constructor default could
    # never be selected.
    use_bias = None if no_bias is None else not no_bias

    use_bn = None
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm

    def hmlp_options(defaults):
        """Options shared by all MLP hypernets.

        If an argument wasn't specified, then we use the default value that
        is currently in the constructor (given via ``defaults``).
        """
        return {
            'activation_fn': assign(net_act, defaults['activation_fn']),
            'use_bias': assign(use_bias, defaults['use_bias']),
            'dropout_rate': assign(dropout_rate, defaults['dropout_rate']),
            'use_spectral_norm': assign(specnorm,
                                        defaults['use_spectral_norm']),
            'use_batch_norm': assign(use_bn, defaults['use_batch_norm']),
        }
    ### FIXME Code copied until here ###

    if hc('hmlp_arch'):
        hmlp_arch = gc('hmlp_arch')
        # Semicolons separate the architectures of several internal hnets.
        hmlp_arch_is_list = ';' in hmlp_arch
        if hmlp_arch_is_list:
            if net_type != 'structured_hmlp':
                raise ValueError('Option "%shmlp_arch" may only ' % (cprefix) +
                                 'contain semicolons for network type ' +
                                 '"structured_hmlp"!')
            hmlp_arch = [misc.str_to_ints(ar) for ar in hmlp_arch.split(';')]
        else:
            hmlp_arch = misc.str_to_ints(hmlp_arch)

    if hc('chunk_emb_size'):
        chunk_emb_size = misc.str_to_ints(gc('chunk_emb_size'))
        if len(chunk_emb_size) == 1:
            chunk_emb_size = chunk_emb_size[0]
        elif net_type != 'structured_hmlp':
            raise ValueError('Option "%schunk_emb_size" may ' % (cprefix) +
                             'only contain multiple values for network ' +
                             'type "structured_hmlp"!')

    cond_emb_size = gc('cond_emb_size') if hc('cond_emb_size') else 0

    if net_type == 'hmlp':
        assert hc('hmlp_arch')

        # Default keyword arguments of class HMLP.
        dkws = misc.get_default_args(HMLP.__init__)

        hnet = HMLP(target_shapes, uncond_in_size=uncond_in_size,
                    cond_in_size=cond_emb_size, layers=hmlp_arch,
                    verbose=verbose, no_uncond_weights=no_uncond_weights,
                    no_cond_weights=no_cond_weights, num_cond_embs=num_conds,
                    **hmlp_options(dkws)).to(device)

    elif net_type == 'chunked_hmlp':
        assert hc('hmlp_arch')
        assert hc('chmlp_chunk_size')
        assert hc('chunk_emb_size')
        cond_chunk_embs = get_val('use_cond_chunk_embs')

        # Default keyword arguments of class ChunkedHMLP.
        dkws = misc.get_default_args(ChunkedHMLP.__init__)

        hnet = ChunkedHMLP(
            target_shapes, gc('chmlp_chunk_size'),
            chunk_emb_size=chunk_emb_size,
            cond_chunk_embs=assign(cond_chunk_embs, dkws['cond_chunk_embs']),
            uncond_in_size=uncond_in_size, cond_in_size=cond_emb_size,
            layers=hmlp_arch, verbose=verbose,
            no_uncond_weights=no_uncond_weights,
            no_cond_weights=no_cond_weights, num_cond_embs=num_conds,
            **hmlp_options(dkws)).to(device)

    elif net_type == 'structured_hmlp':
        assert hc('hmlp_arch')
        assert hc('chunk_emb_size')
        cond_chunk_embs = get_val('use_cond_chunk_embs')

        assert shmlp_chunk_shapes is not None and \
            shmlp_num_per_chunk is not None and \
            shmlp_assembly_fct is not None

        # Default keyword arguments of class StructuredHMLP.
        dkws = misc.get_default_args(StructuredHMLP.__init__)
        dkws_hmlp = misc.get_default_args(HMLP.__init__)

        if not hmlp_arch_is_list:
            hmlp_arch = [hmlp_arch]

        # One set of keyword arguments per internally maintained hypernet.
        shmlp_hmlp_kwargs = [dict(hmlp_options(dkws_hmlp), layers=arch)
                             for arch in hmlp_arch]
        if len(shmlp_hmlp_kwargs) == 1:
            # A single dict is shared by all internal hypernetworks.
            shmlp_hmlp_kwargs = shmlp_hmlp_kwargs[0]

        hnet = StructuredHMLP(
            target_shapes, shmlp_chunk_shapes, shmlp_num_per_chunk,
            chunk_emb_size, shmlp_hmlp_kwargs, shmlp_assembly_fct,
            cond_chunk_embs=assign(cond_chunk_embs, dkws['cond_chunk_embs']),
            uncond_in_size=uncond_in_size, cond_in_size=cond_emb_size,
            verbose=verbose, no_uncond_weights=no_uncond_weights,
            no_cond_weights=no_cond_weights,
            num_cond_embs=num_conds).to(device)

    elif net_type == 'hdeconv':
        #HDeconv
        raise NotImplementedError

    else:
        assert net_type == 'chunked_hdeconv'
        #ChunkedHDeconv
        raise NotImplementedError

    return hnet
def __init__(self, target_shapes, chunk_shapes, num_per_chunk,
             chunk_emb_sizes, hmlp_kwargs, assembly_fct,
             cond_chunk_embs=False, uncond_in_size=0, cond_in_size=8,
             verbose=True, no_uncond_weights=False, no_cond_weights=False,
             num_cond_embs=1):
    """Set up the internal hypernetworks and their chunk embeddings.

    One :class:`HMLP` instance is created per entry in ``chunk_shapes``.
    The parameter bookkeeping of all internal hypernetworks is merged into
    this object, and the chunk embeddings are appended at the end of
    ``param_shapes``.

    Args:
        target_shapes (list): Shapes of the target network weights.
        chunk_shapes (list): Non-empty list; each entry is a non-empty list
            of shapes produced by one internal hypernetwork.
        num_per_chunk (list): Number of chunks each internal hypernetwork
            has to produce (same length as ``chunk_shapes``, no zeros).
        chunk_emb_sizes (int or list): Chunk embedding size per internal
            hypernetwork. An integer is used for all of them. A size of 0
            is only allowed if the corresponding ``num_per_chunk`` entry is
            1 and external inputs exist.
        hmlp_kwargs (dict or list): Keyword arguments for the internal
            :class:`HMLP` instances (one dict for all, or one dict per
            internal hypernetwork). The passed dictionaries are not
            modified.
        assembly_fct (func): Stored for assembling chunks to target
            weights.
        cond_chunk_embs (bool): Whether chunk embeddings are conditional
            parameters (one set per condition).
        uncond_in_size (int): Size of the unconditional input.
        cond_in_size (int): Size of the conditional input.
        verbose (bool): Whether to print information about the created
            network.
        no_uncond_weights (bool): If ``True``, unconditional parameters
            are not maintained internally.
        no_cond_weights (bool): If ``True``, conditional parameters are not
            maintained internally.
        num_cond_embs (int): Number of conditions.

    Raises:
        ValueError: If the constructor arguments are inconsistent.
        NotImplementedError: If an internal hypernetwork requests
            distilled parameters.
    """
    # FIXME find a way using super to handle multiple inheritance.
    nn.Module.__init__(self)
    HyperNetInterface.__init__(self)

    ### Basic checks for user inputs ###
    assert isinstance(chunk_shapes, (list, tuple)) and len(chunk_shapes) > 0
    num_chunk_weights = 0
    for chunk in chunk_shapes:
        # Each chunk is a list of shapes!
        assert isinstance(chunk, (list, tuple)) and len(chunk) > 0
        num_chunk_weights += StructuredHMLP.shapes_to_num_weights(chunk)
    num_trgt_weights = StructuredHMLP.shapes_to_num_weights(target_shapes)
    if num_trgt_weights > num_chunk_weights:
        # TODO Should we display a warning? The user might actively want
        # to reuse the same weights in the target network. In the end, the
        # user should be completely free on how he assembles the chunks to
        # weights within the `assembly_fct`.
        pass

    assert isinstance(num_per_chunk, (list, tuple)) and \
        len(num_per_chunk) == len(chunk_shapes)
    if 0 in num_per_chunk:
        raise ValueError('Option "num_per_chunk" may not contains 0s. ' +
                         'Each internal hypernetwork must create at ' +
                         'least one chunk!')

    assert isinstance(chunk_emb_sizes, (int, list, tuple))
    if isinstance(chunk_emb_sizes, int):
        chunk_emb_sizes = [chunk_emb_sizes] * len(chunk_shapes)
    assert len(chunk_emb_sizes) == len(chunk_shapes)
    if 0 in chunk_emb_sizes and uncond_in_size == 0 and cond_in_size == 0:
        raise ValueError('Argument "chunk_emb_sizes" may not contain ' +
                         '0s if "uncond_in_size" and "cond_in_size" are ' +
                         '0!')
    for emb_size, num_chunks in zip(chunk_emb_sizes, num_per_chunk):
        if emb_size == 0 and num_chunks != 1:
            raise ValueError('Option "chunk_emb_sizes" may only contain ' +
                             'zeroes if the corresponding entry in ' +
                             '"num_per_chunk" is 1.')

    assert isinstance(hmlp_kwargs, (dict, list, tuple))
    if isinstance(hmlp_kwargs, dict):
        hmlp_kwargs = [hmlp_kwargs] * len(chunk_shapes)
    assert len(hmlp_kwargs) == len(chunk_shapes)
    # These options are determined by this class for each internal hnet.
    forbidden = ('uncond_in_size', 'cond_in_size', 'no_uncond_weights',
                 'no_cond_weights', 'num_cond_embs')
    hnet_kwargs = []
    for hkwargs in hmlp_kwargs:
        assert isinstance(hkwargs, dict)
        for kw in forbidden:
            if kw in hkwargs:
                raise ValueError('Key %s may not be passed with argument '
                                 % kw + '"hmlp_kwargs"!')
        # Work on a copy, such that the dictionaries of the caller are not
        # modified when setting the default for `verbose`.
        hkwargs = dict(hkwargs)
        hkwargs.setdefault('verbose', False)
        hnet_kwargs.append(hkwargs)
    hmlp_kwargs = hnet_kwargs

    ### Make constructor arguments internally available ###
    self._chunk_shapes = chunk_shapes
    self._num_per_chunk = num_per_chunk
    self._chunk_emb_sizes = chunk_emb_sizes
    #self._hkwargs = hkwargs
    self._assembly_fct = assembly_fct
    self._cond_chunk_embs = cond_chunk_embs
    self._uncond_in_size = uncond_in_size
    self._cond_in_size = cond_in_size
    self._no_uncond_weights = no_uncond_weights
    self._no_cond_weights = no_cond_weights
    self._num_cond_embs = num_cond_embs

    ### Create underlying full hypernets ###
    num_hnets = len(chunk_shapes)
    self._hnets = []

    for i in range(num_hnets):
        # Note, even if chunk embeddings are considered conditional, they
        # are maintained in this object and just fed as an external input
        # to the underlying hnet.
        hnet_uncond_in_size = uncond_in_size + chunk_emb_sizes[i]

        # Conditional inputs (`cond_in_size`) will be maintained by the
        # first internal hypernetwork.
        if i == 0:
            hnet_no_cond_weights = no_cond_weights
            hnet_num_cond_embs = num_cond_embs
            if cond_chunk_embs and cond_in_size == 0:
                # If there are no other conditional embeddings except the
                # chunk embeddings, we tell the first underlying hnet
                # explicitly that it doesn't need to maintain any
                # conditional weights to avoid that it will throw a warning.
                hnet_num_cond_embs = 0
        else:
            # All other hypernetworks will be passed the conditional
            # embeddings from the first hypernet as input.
            hnet_no_cond_weights = True
            hnet_num_cond_embs = 0

        self._hnets.append(
            HMLP(chunk_shapes[i], uncond_in_size=hnet_uncond_in_size,
                 cond_in_size=cond_in_size,
                 no_uncond_weights=no_uncond_weights,
                 no_cond_weights=hnet_no_cond_weights,
                 num_cond_embs=hnet_num_cond_embs, **hmlp_kwargs[i]))

    ### Setup attributes required by interface ###
    # Most of these attributes are taken over from the internally
    # maintained hypernetworks.
    self._target_shapes = target_shapes
    self._num_known_conds = self._num_cond_embs

    # As we just append the weights of the internal hypernets we will have
    # output weights all over the place.
    # Additionally, it would be complicated to assign outputs to target
    # outputs, as we do not know, what is happening in the `assembly_fct`.
    # Also, keep in mind that we will append chunk embeddings at the end
    # of `param_shapes`.
    self._mask_fc_out = False

    self._unconditional_param_shapes_ref = []
    self._param_shapes = []
    self._param_shapes_meta = []
    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    for i, hnet in enumerate(self._hnets):
        # Note, it is important to convert lists into new object and not
        # just copy references!
        # Note, we have to adapt all references if `i > 0`.
        ps_len_old = len(self._param_shapes)
        # Offset of the internal parameters of this hnet within
        # `self._internal_params`. It has to be determined for every hnet,
        # even if the current one doesn't maintain any parameters.
        ip_len_old = 0
        if self._internal_params is not None:
            ip_len_old = len(self._internal_params)

        for ref in hnet._unconditional_param_shapes_ref:
            self._unconditional_param_shapes_ref.append(ref + ps_len_old)

        if hnet._internal_params is not None:
            if self._internal_params is None:
                self._internal_params = nn.ParameterList()
            self._internal_params.extend(
                nn.ParameterList(hnet._internal_params))

        self._param_shapes.extend(list(hnet._param_shapes))

        for meta in hnet.param_shapes_meta:
            assert 'hnet_ind' not in meta.keys()
            assert 'layer' in meta.keys()
            assert 'index' in meta.keys()
            new_meta = dict(meta)
            new_meta['hnet_ind'] = i
            if i > 0:
                # FIXME We should properly adjust colliding `layer` IDs.
                new_meta['layer'] = -1
            # An index of -1 marks parameters that are not internally
            # maintained, which must not be shifted.
            if meta['index'] != -1:
                new_meta['index'] = meta['index'] + ip_len_old
            self._param_shapes_meta.append(new_meta)

        if hnet._hyper_shapes_learned is not None:
            if self._hyper_shapes_learned is None:
                self._hyper_shapes_learned = []
                self._hyper_shapes_learned_ref = []
            self._hyper_shapes_learned.extend(
                list(hnet._hyper_shapes_learned))
            for ref in hnet._hyper_shapes_learned_ref:
                self._hyper_shapes_learned_ref.append(ref + ps_len_old)

        if hnet._hyper_shapes_distilled is not None:
            if self._hyper_shapes_distilled is None:
                self._hyper_shapes_distilled = []
            self._hyper_shapes_distilled.extend(
                list(hnet._hyper_shapes_distilled))

        if self._has_bias is None:
            self._has_bias = hnet._has_bias
        elif self._has_bias != hnet._has_bias:
            self._has_bias = False
            # FIXME We should overwrite the getter and throw an error!
            warn('Some internally maintained hypernetworks use biases, ' +
                 'while others don\'t. Setting attribute "has_bias" to ' +
                 'False.')

        if self._has_fc_out is None:
            self._has_fc_out = hnet._has_fc_out
        else:
            assert self._has_fc_out == hnet._has_fc_out

        if self._has_linear_out is None:
            self._has_linear_out = hnet._has_linear_out
        else:
            assert self._has_linear_out == hnet._has_linear_out

        self._layer_weight_tensors.extend(
            nn.ParameterList(hnet._layer_weight_tensors))
        self._layer_bias_vectors.extend(
            nn.ParameterList(hnet._layer_bias_vectors))

        if hnet._batchnorm_layers is not None:
            if self._batchnorm_layers is None:
                self._batchnorm_layers = nn.ModuleList()
            self._batchnorm_layers.extend(
                nn.ModuleList(hnet._batchnorm_layers))

        if hnet._context_mod_layers is not None:
            if self._context_mod_layers is None:
                self._context_mod_layers = nn.ModuleList()
            self._context_mod_layers.extend(
                nn.ModuleList(hnet._context_mod_layers))

    if self._hyper_shapes_distilled is not None:
        raise NotImplementedError('Distillation of parameters not ' +
                                  'supported yet!')

    ### Create chunk embeddings ###
    # Note, the case that an internal hypernetwork has neither chunk
    # embeddings nor external inputs has already been excluded by the
    # checks of `chunk_emb_sizes` above.
    if cond_in_size == 0 and uncond_in_size == 0 and not cond_chunk_embs:
        # Note, we could also allow this case. It would be analoguous to
        # creating a full hypernet with no unconditional input and one
        # conditional embedding. But the user can explicitly achieve that
        # as noted below.
        raise ValueError('If no external (conditional or unconditional) ' +
                         'input is provided to the hypernetwork, then ' +
                         'it can only learn a fixed output. If this ' +
                         'behavior is desired, please enable ' +
                         '"cond_chunk_embs" and set "num_cond_embs=1".')

    chunk_emb_shapes = []
    # To which internal hnet does the corresponding chunk shape belong to.
    chunk_emb_refs = []
    for i, size in enumerate(chunk_emb_sizes):
        if size == 0:
            # No chunk embeddings for internal hnet `i`.
            continue

        chunk_emb_refs.append(i)
        assert num_per_chunk[i] > 0
        chunk_emb_shapes.append([num_per_chunk[i], size])
    self._chunk_emb_shapes = chunk_emb_shapes
    self._chunk_emb_refs = chunk_emb_refs

    # How often do we have to instantiate the chunk embeddings prescribed by
    # `chunk_emb_shapes`?
    num_cemb_weights = 1
    no_cemb_weights = no_uncond_weights
    if cond_chunk_embs:
        num_cemb_weights = num_cond_embs
        no_cemb_weights = no_cond_weights

    # Number of conditional and unconditional parameters so far.
    tmp_num_uncond = len(self._unconditional_param_shapes_ref)
    tmp_num_cond = len(self._param_shapes) - tmp_num_uncond

    # List of lists of inds.
    # Indices of chunk embedding per condition within
    # `conditional_param_shapes`, if chunk embeddings are conditional.
    # Otherwise, indices of chunk embeddings within
    # `unconditional_param_shapes`.
    self._chunk_emb_inds = [[] for _ in range(num_cemb_weights)]

    for i in range(num_cemb_weights):
        for j, shape in enumerate(chunk_emb_shapes):
            if not no_cemb_weights:
                # The internal hypernets might not maintain any parameters
                # themselves, in which case the list doesn't exist yet.
                if self._internal_params is None:
                    self._internal_params = nn.ParameterList()
                self._internal_params.append(nn.Parameter(
                    data=torch.Tensor(*shape), requires_grad=True))
                torch.nn.init.normal_(self._internal_params[-1], mean=0.,
                                      std=1.)
            else:
                # Same as above: the lists only exist if an internal
                # hypernet expects externally provided weights.
                if self._hyper_shapes_learned is None:
                    self._hyper_shapes_learned = []
                    self._hyper_shapes_learned_ref = []
                self._hyper_shapes_learned.append(shape)
                self._hyper_shapes_learned_ref.append(
                    len(self.param_shapes))

            if not cond_chunk_embs:
                self._unconditional_param_shapes_ref.append(
                    len(self.param_shapes))

            self._param_shapes.append(shape)
            # In principle, these embeddings also belong to the input, so we
            # just assign them as "layer" 0 (note, the underlying hnets use
            # the same layer ID for its embeddings.
            self._param_shapes_meta.append({
                'name': 'embedding',
                'index': -1 if no_cemb_weights else
                         len(self._internal_params) - 1,
                'layer': 0,
                'info': 'chunk embeddings',
                'hnet_ind': chunk_emb_refs[j],
                'cond_id': i if cond_chunk_embs else -1
            })

            if cond_chunk_embs:
                self._chunk_emb_inds[i].append(tmp_num_cond)
                tmp_num_cond += 1
            else:
                self._chunk_emb_inds[i].append(tmp_num_uncond)
                tmp_num_uncond += 1

    assert len(self.param_shapes) == tmp_num_uncond + tmp_num_cond

    ### Finalize construction ###
    self._is_properly_setup()

    if verbose:
        print('Created Structured Chunked MLP Hypernet.')
        print('It manages %d full hypernetworks internally that produce '
              % (num_hnets) + '%s chunks in total.' % (self.num_chunks))
        print('The internal hypernetworks have a combined output size of ' +
              '%d compared to %d weights produced by this network.'
              % (num_chunk_weights, self.num_outputs))
        print(self)
def __init__(self, target_shapes, chunk_size, chunk_emb_size=8,
             cond_chunk_embs=False, uncond_in_size=0, cond_in_size=8,
             layers=(100, 100), verbose=True, activation_fn=torch.nn.ReLU(),
             use_bias=True, no_uncond_weights=False, no_cond_weights=False,
             num_cond_embs=1, dropout_rate=-1, use_spectral_norm=False,
             use_batch_norm=False):
    """Set up the internal full hypernetwork and the chunk embeddings.

    A single :class:`HMLP` instance with output size ``chunk_size`` is
    created. Its parameter bookkeeping is taken over by this object, and
    the chunk embeddings are appended at the end of ``param_shapes``.

    Args:
        target_shapes (list): Shapes of the target network weights.
        chunk_size (int): Positive number of weights produced per chunk
            embedding.
        chunk_emb_size (int): Positive size of a chunk embedding.
        cond_chunk_embs (bool): Whether chunk embeddings are conditional
            parameters (one set per condition).
        uncond_in_size (int): Size of the unconditional input.
        cond_in_size (int): Size of the conditional input.
        layers, activation_fn, use_bias, dropout_rate, use_spectral_norm,
            use_batch_norm: Passed on to the internal :class:`HMLP`.
        verbose (bool): Whether to print information about the created
            network.
        no_uncond_weights (bool): If ``True``, unconditional parameters
            are not maintained internally.
        no_cond_weights (bool): If ``True``, conditional parameters are not
            maintained internally.
        num_cond_embs (int): Number of conditions.

    Raises:
        ValueError: If conditional chunk embeddings are requested without
            any known condition, or if the hypernetwork has no external
            input and no conditional chunk embeddings.
    """
    # FIXME find a way using super to handle multiple inheritance.
    nn.Module.__init__(self)
    HyperNetInterface.__init__(self)

    assert isinstance(chunk_size, int) and chunk_size > 0
    assert isinstance(chunk_emb_size, int) and chunk_emb_size > 0

    ### Make constructor arguments internally available ###
    self._chunk_size = chunk_size
    self._chunk_emb_size = chunk_emb_size
    self._cond_chunk_embs = cond_chunk_embs
    self._uncond_in_size = uncond_in_size
    self._cond_in_size = cond_in_size
    self._no_uncond_weights = no_uncond_weights
    self._no_cond_weights = no_cond_weights
    self._num_cond_embs = num_cond_embs

    ### Create underlying full hypernet ###
    # Note, even if chunk embeddings are considered conditional, they
    # are maintained in this object and just fed as an external input to the
    # underlying hnet.
    hnet_uncond_in_size = uncond_in_size + chunk_emb_size
    hnet_num_cond_embs = num_cond_embs
    if cond_chunk_embs and num_cond_embs == 0:
        raise ValueError('Conditional chunk embeddings can only be used ' +
                         'if conditions are known to the hypernetwork!')
    if cond_chunk_embs and cond_in_size == 0:
        # If there are no other conditional embeddings except the chunk
        # embeddings, we tell the underlying hnet explicitly that it doesn't
        # need to maintain any conditional weights to avoid that it will
        # throw a warning.
        hnet_num_cond_embs = 0
    self._hnet = HMLP([[chunk_size]], uncond_in_size=hnet_uncond_in_size,
                      cond_in_size=cond_in_size, layers=layers,
                      verbose=False, activation_fn=activation_fn,
                      use_bias=use_bias,
                      no_uncond_weights=no_uncond_weights,
                      no_cond_weights=no_cond_weights,
                      num_cond_embs=hnet_num_cond_embs,
                      dropout_rate=dropout_rate,
                      use_spectral_norm=use_spectral_norm,
                      use_batch_norm=use_batch_norm)

    ### Setup attributes required by interface ###
    # Most of these attributes are taken over from `self._hnet`
    self._target_shapes = target_shapes
    self._num_known_conds = self._num_cond_embs
    self._unconditional_param_shapes_ref = \
        list(self._hnet._unconditional_param_shapes_ref)

    if self._hnet._internal_params is not None:
        self._internal_params = \
            nn.ParameterList(self._hnet._internal_params)
    self._param_shapes = list(self._hnet._param_shapes)
    self._param_shapes_meta = list(self._hnet._param_shapes_meta)
    if self._hnet._hyper_shapes_learned is not None:
        self._hyper_shapes_learned = list(self._hnet._hyper_shapes_learned)
        self._hyper_shapes_learned_ref = \
            list(self._hnet._hyper_shapes_learned_ref)
    if self._hnet._hyper_shapes_distilled is not None:
        self._hyper_shapes_distilled = \
            list(self._hnet._hyper_shapes_distilled)
    self._has_bias = self._hnet._has_bias
    self._has_fc_out = self._hnet._has_fc_out
    # Just to make that clear explicitly. We will additionally append
    # the chunk embeddings at the end of `param_shapes`.
    # We don't prepend it to the beginning, to keep conditional input
    # embeddings at the beginning.
    self._mask_fc_out = False
    self._has_linear_out = self._hnet._has_linear_out
    self._layer_weight_tensors = \
        nn.ParameterList(self._hnet._layer_weight_tensors)
    self._layer_bias_vectors = \
        nn.ParameterList(self._hnet._layer_bias_vectors)
    if self._hnet._batchnorm_layers is not None:
        self._batchnorm_layers = nn.ModuleList(
            self._hnet._batchnorm_layers)
    if self._hnet._context_mod_layers is not None:
        self._context_mod_layers = \
            nn.ModuleList(self._hnet._context_mod_layers)

    ### Create chunk embeddings ###
    if cond_in_size == 0 and uncond_in_size == 0 and not cond_chunk_embs:
        # Note, we could also allow this case. It would be analoguous to
        # creating a full hypernet with no unconditional input and one
        # conditional embedding. But the user can explicitly achieve that
        # as noted below.
        raise ValueError('If no external (conditional or unconditional) ' +
                         'input is provided to the hypernetwork, then ' +
                         'it can only learn a fixed output. If this ' +
                         'behavior is desired, please enable ' +
                         '"cond_chunk_embs" and set "num_cond_embs=1".')

    # How many chunk embedding matrices are needed (one per condition if
    # they are conditional) and are they maintained internally?
    num_cemb_mats = 1
    no_cemb_weights = no_uncond_weights
    if cond_chunk_embs:
        num_cemb_mats = num_cond_embs
        no_cemb_weights = no_cond_weights

    self._cemb_shape = [self.num_chunks, chunk_emb_size]

    for _ in range(num_cemb_mats):
        if not no_cemb_weights:
            # The underlying hnet might not maintain any parameters itself
            # (e.g., `no_uncond_weights=True` together with conditional
            # chunk embeddings), in which case the list doesn't exist yet.
            if self._internal_params is None:
                self._internal_params = nn.ParameterList()
            self._internal_params.append(nn.Parameter(
                data=torch.Tensor(*self._cemb_shape), requires_grad=True))
            torch.nn.init.normal_(self._internal_params[-1], mean=0.,
                                  std=1.)
        else:
            # Same as above: the lists only exist if the underlying hnet
            # expects externally provided weights.
            if self._hyper_shapes_learned is None:
                self._hyper_shapes_learned = []
                self._hyper_shapes_learned_ref = []
            self._hyper_shapes_learned.append(self._cemb_shape)
            self._hyper_shapes_learned_ref.append(len(self.param_shapes))

        if not cond_chunk_embs:
            self._unconditional_param_shapes_ref.append(
                len(self.param_shapes))

        self._param_shapes.append(self._cemb_shape)
        # In principle, these embeddings also belong to the input, so we
        # just assign them as "layer" 0 (note, the underlying hnet uses the
        # same layer ID for its embeddings.
        self._param_shapes_meta.append({
            'name': 'embedding',
            'index': -1 if no_cemb_weights else
                     len(self._internal_params) - 1,
            'layer': 0,
            'info': 'chunk embeddings'
        })

    ### Finalize construction ###
    self._is_properly_setup()

    if verbose:
        print('Created Chunked MLP Hypernet with %d chunk(s) of size %d.'
              % (self.num_chunks, chunk_size))
        print(self)
class ChunkedHMLP(nn.Module, HyperNetInterface): """Implementation of a `chunked fully-connected hypernet`. The ``target_shapes`` will be flattened and split into chunks of size ``chunk_size``. In total, there will be ``np.ceil(self.num_outputs/chunk_size)`` chunks, where the last chunk produced might contain a remainder that is discarded. Each chunk has it's own `chunk embedding` that is fed into the underlying hypernetwork. Note: It is possible to set ``uncond_in_size`` and ``cond_in_size`` to zero if ``cond_chunk_embs`` is ``True``. Attributes: num_chunks (int): The number of chunks that make up the final hypernet output. This also corresponds to the number of chunk embeddings required per forward sweep. chunk_emb_size (int): See constructor argument ``chunk_emb_size``. Args: (....): See constructor arguments of class :class:`hnets.mlp_hnet.HMLP`. chunk_size (int): The chunk size, i.e, the number of weights produced by single the internally maintained instance of a full hypernet (see :class:`hnets.mlp_hnet.HMLP`) at a time (i.e., per chunk embedding). chunk_emb_size (int): The size of a chunk embedding. cond_chunk_embs (bool): Whether chunk embeddings are unconditional (``False``) or conditional (``True``) parameters. See constructor argument ``cond_chunk_embs``. Note: Embeddings will be initialized with a normal distribution using zero mean and unit variance. cond_chunk_embs (bool): Consider chunk embeddings to be conditional. In this case, there will be a different set of chunk embeddings per condition (specified via ``num_cond_embs``). If ``False``, there will be a total of :attr:`num_chunks` chunk embeddings that are maintained within :attr:`hnets.hnet_interface.\ HyperNetInterface.unconditional_param_shapes`. If ``True``, there will be ``num_cond_embs * self.num_chunks`` chunk embeddings that are maintained within :attr:`hnets.hnet_interface.\ HyperNetInterface.conditional_param_shapes`. 
However, if ``num_cond_embs == 0``, then chunk embeddings have to be provided in a special way to the :meth:`forward` method (see the corresponding argument ``weights``). """ def __init__(self, target_shapes, chunk_size, chunk_emb_size=8, cond_chunk_embs=False, uncond_in_size=0, cond_in_size=8, layers=(100, 100), verbose=True, activation_fn=torch.nn.ReLU(), use_bias=True, no_uncond_weights=False, no_cond_weights=False, num_cond_embs=1, dropout_rate=-1, use_spectral_norm=False, use_batch_norm=False): # FIXME find a way using super to handle multiple inheritance. nn.Module.__init__(self) HyperNetInterface.__init__(self) assert isinstance(chunk_size, int) and chunk_size > 0 assert isinstance(chunk_emb_size, int) and chunk_emb_size > 0 ### Make constructor arguments internally available ### self._chunk_size = chunk_size self._chunk_emb_size = chunk_emb_size self._cond_chunk_embs = cond_chunk_embs self._uncond_in_size = uncond_in_size self._cond_in_size = cond_in_size self._no_uncond_weights = no_uncond_weights self._no_cond_weights = no_cond_weights self._num_cond_embs = num_cond_embs ### Create underlying full hypernet ### # Note, even if chunk embeddings are considered conditional, they # are maintained in this object and just fed as an external input to the # underlying hnet. hnet_uncond_in_size = uncond_in_size + chunk_emb_size hnet_num_cond_embs = num_cond_embs if cond_chunk_embs and num_cond_embs == 0: raise ValueError('Conditional chunk embeddings can only be used ' + 'if conditions are known to the hypernetwork!') if cond_chunk_embs and cond_in_size == 0: # If there are no other conditional embeddings except the chunk # embeddings, we tell the underlying hnet explicitly that it doesn't # need to maintain any conditional weights to avoid that it will # throw a warning. 
hnet_num_cond_embs = 0 self._hnet = HMLP([[chunk_size]], uncond_in_size=hnet_uncond_in_size, cond_in_size=cond_in_size, layers=layers, verbose=False, activation_fn=activation_fn, use_bias=use_bias, no_uncond_weights=no_uncond_weights, no_cond_weights=no_cond_weights, num_cond_embs=hnet_num_cond_embs, dropout_rate=dropout_rate, use_spectral_norm=use_spectral_norm, use_batch_norm=use_batch_norm) ### Setup attributes required by interface ### # Most of these attributes are taken over from `self._hnet` self._target_shapes = target_shapes self._num_known_conds = self._num_cond_embs self._unconditional_param_shapes_ref = \ list(self._hnet._unconditional_param_shapes_ref) if self._hnet._internal_params is not None: self._internal_params = \ nn.ParameterList(self._hnet._internal_params) self._param_shapes = list(self._hnet._param_shapes) self._param_shapes_meta = list(self._hnet._param_shapes_meta) if self._hnet._hyper_shapes_learned is not None: self._hyper_shapes_learned = list(self._hnet._hyper_shapes_learned) self._hyper_shapes_learned_ref = \ list(self._hnet._hyper_shapes_learned_ref) if self._hnet._hyper_shapes_distilled is not None: self._hyper_shapes_distilled = \ list(self._hnet._hyper_shapes_distilled) self._has_bias = self._hnet._has_bias self._has_fc_out = self._hnet._has_fc_out # Just to make that clear explicitly. We will additionally append # the chunk embeddings at the end of `param_shapes`. # We don't prepend it to the beginning, to keep conditional input # embeddings at the beginning. 
self._mask_fc_out = False self._has_linear_out = self._hnet._has_linear_out self._layer_weight_tensors = \ nn.ParameterList(self._hnet._layer_weight_tensors) self._layer_bias_vectors = \ nn.ParameterList(self._hnet._layer_bias_vectors) if self._hnet._batchnorm_layers is not None: self._batchnorm_layers = nn.ModuleList( self._hnet._batchnorm_layers) if self._hnet._context_mod_layers is not None: self._context_mod_layers = \ nn.ModuleList(self._hnet._context_mod_layers) ### Create chunk embeddings ### if cond_in_size == 0 and uncond_in_size == 0 and not cond_chunk_embs: # Note, we could also allow this case. It would be analoguous to # creating a full hypernet with no unconditional input and one # conditional embedding. But the user can explicitly achieve that # as noted below. raise ValueError('If no external (conditional or unconditional) ' + 'input is provided to the hypernetwork, then ' + 'it can only learn a fixed output. If this ' + 'behavior is desired, please enable ' + '"cond_chunk_embs" and set "num_cond_embs=1".') num_cemb_mats = 1 no_cemb_weights = no_uncond_weights if cond_chunk_embs: num_cemb_mats = num_cond_embs no_cemb_weights = no_cond_weights self._cemb_shape = [self.num_chunks, chunk_emb_size] for _ in range(num_cemb_mats): if not no_cemb_weights: self._internal_params.append(nn.Parameter( \ data=torch.Tensor(*self._cemb_shape), requires_grad=True)) torch.nn.init.normal_(self._internal_params[-1], mean=0., std=1.) else: self._hyper_shapes_learned.append(self._cemb_shape) self._hyper_shapes_learned_ref.append(len(self.param_shapes)) if not cond_chunk_embs: self._unconditional_param_shapes_ref.append( \ len(self.param_shapes)) self._param_shapes.append(self._cemb_shape) # In principle, these embeddings also belong to the input, so we # just assign them as "layer" 0 (note, the underlying hnet uses the # same layer ID for its embeddings. 
self._param_shapes_meta.append({ 'name': 'embedding', 'index': -1 if no_cemb_weights else \ len(self._internal_params)-1, 'layer': 0, 'info': 'chunk embeddings' }) ### Finalize construction ### self._is_properly_setup() if verbose: print('Created Chunked MLP Hypernet with %d chunk(s) of size %d.' \ % (self.num_chunks, chunk_size)) print(self) @property def num_chunks(self): """Getter for read-only attribute :attr:`num_chunks`.""" return int(np.ceil(self.num_outputs / self._chunk_size)) @property def chunk_emb_size(self): """Getter for read-only attribute :attr:`chunk_emb_size`.""" return self._chunk_emb_size @property def cond_chunk_embs(self): """Getter for read-only attribute :attr:`cond_chunk_embs`.""" return self._cond_chunk_embs def forward(self, uncond_input=None, cond_input=None, cond_id=None, weights=None, distilled_params=None, condition=None, ret_format='squeezed', ext_inputs=None, task_emb=None, task_id=None, theta=None, dTheta=None): """Compute the weights of a target network. Args: (....): See docstring of method :meth:`hnets.mlp_hnet.HMLP.forward`. weights (list or dict, optional): If provided as ``dict`` and chunk embeddings are considered conditional (see constructor argument ``cond_chunk_embs``), then the additional key ``chunk_embs`` can be used to pass a batch of chunk embeddings. This option is mutually exclusive with the option of passing ``cond_id``. Note, if conditional inputs via ``cond_input`` are expected, then the batch sizes must agree. A batch of chunk embeddings is expected to be tensor of shape ``[B, num_chunks, chunk_emb_size]``, where ``B`` denotes the batch size. Returns: (list or torch.Tensor): See docstring of method :meth:`hnets.hnet_interface.HyperNetInterface.forward`. 
""" cond_chunk_embs = None if isinstance(weights, dict): if 'chunk_embs' in weights.keys(): cond_chunk_embs = weights['chunk_embs'] if not self._cond_chunk_embs: raise ValueError('Key "chunk_embs" for argument ' + '"weights" is only allowed if chunk ' + 'embeddings are conditional.') assert len(cond_chunk_embs.shape) == 3 and \ np.all(np.equal(cond_chunk_embs.shape[1:], [self.num_chunks, self.chunk_emb_size])) if cond_id is not None: raise ValueError( 'Option "cond_id" is mutually exclusive ' + 'with key "chunk_embs" for argument ' + '"weights".') assert cond_input is None or \ cond_input.shape[0] == cond_chunk_embs.shape[0] # Remove `chunk_embs` from dictionary, since upper class parser # doesn't know how to deal with it. del weights['chunk_embs'] if len(weights.keys()) == 0: # Empty dictionary. weights = None if cond_input is not None and self._cond_chunk_embs and \ cond_chunk_embs is None: raise ValueError('Conditional chunk embeddings have to be ' + 'provided via "weights" if "cond_input" is ' + 'specified.') _input_required = self._cond_in_size > 0 or self._uncond_in_size > 0 # We parse `cond_id` afterwards if chunk embeddings are also # conditional. 
if self._cond_chunk_embs: _parse_cond_id_fct = lambda x, y, z: None else: _parse_cond_id_fct = None uncond_input, cond_input, uncond_weights, cond_weights = \ self._preprocess_forward_args(_input_required=_input_required, _parse_cond_id_fct=_parse_cond_id_fct, uncond_input=uncond_input, cond_input=cond_input, cond_id=cond_id, weights=weights, distilled_params=distilled_params, condition=condition, ret_format=ret_format, ext_inputs=ext_inputs, task_emb=task_emb, task_id=task_id, theta=theta, dTheta=dTheta) ### Translate IDs to conditional inputs ### if cond_id is not None and self._cond_chunk_embs: assert cond_input is None and cond_chunk_embs is None cond_id = [cond_id] if isinstance(cond_id, int) else cond_id if cond_weights is None: raise ValueError('Forward option "cond_id" can only be ' + 'used if conditional parameters are ' + 'maintained internally or passed to the ' + 'forward method via option "weights".') cond_chunk_embs = [] cond_input = [] if self._cond_in_size > 0 else None for i, cid in enumerate(cond_id): if cid < 0 or cid >= self._num_cond_embs: raise ValueError('Condition %d not existing!' % (cid)) # Note, we do not necessarily have conditional embeddings. 
if self._cond_in_size > 0: cond_input.append(cond_weights[cid]) cond_chunk_embs.append( \ cond_weights[-self._num_cond_embs+cid]) if self._cond_in_size > 0: cond_input = torch.stack(cond_input, dim=0) cond_chunk_embs = torch.stack(cond_chunk_embs, dim=0) ### Assemble hypernetwork input ### batch_size = None if cond_input is not None: batch_size = cond_input.shape[0] if cond_chunk_embs is not None: assert batch_size is None or batch_size == cond_chunk_embs.shape[0] batch_size = cond_chunk_embs.shape[0] if uncond_input is not None: if batch_size is None: batch_size = uncond_input.shape[0] else: assert batch_size == uncond_input.shape[0] assert batch_size is not None chunk_embs = None if self._cond_chunk_embs: assert cond_chunk_embs is not None and \ len(cond_chunk_embs.shape) == 3 assert self._cond_in_size == 0 or cond_input is not None chunk_embs = cond_chunk_embs else: assert cond_chunk_embs is None chunk_embs = uncond_weights[-1] # Insert batch dimension. chunk_embs = chunk_embs.expand(batch_size, self.num_chunks, self.chunk_emb_size) # We now have the following setup: # cond_input: [batch_size, cond_in_size] or None # uncond_input: [batch_size, uncond_in_size] or None # chunk_embs: [batch_size, num_chunks, chunk_emb_size] # We now first copy the hypernet inputs for each chunk, arriving at # cond_input: [batch_size, num_chunks, cond_in_size] or None # uncond_input: [batch_size, num_chunks, uncond_in_size] or None if cond_input is not None: cond_input = cond_input.reshape(batch_size, 1, -1) cond_input = cond_input.expand(batch_size, self.num_chunks, self._cond_in_size) if uncond_input is not None: uncond_input = uncond_input.reshape(batch_size, 1, -1) uncond_input = uncond_input.expand(batch_size, self.num_chunks, self._uncond_in_size) # The chunk embeddings are considered unconditional inputs to the # underlying hypernetwork. 
uncond_input = torch.cat([uncond_input, chunk_embs], dim=2) else: uncond_input = chunk_embs # Now we build one big batch for the underlying hypernetwork, with # batch size: batch_size * num_chunks. if cond_input is not None: cond_input = cond_input.reshape(batch_size * self.num_chunks, -1) uncond_input = uncond_input.reshape(batch_size * self.num_chunks, -1) ### Weight of underlying hypernetwork ### weights = dict() if cond_weights is not None and self._cond_chunk_embs: weights['cond_weights'] = cond_weights[:-self._num_cond_embs] elif cond_weights is not None: weights['cond_weights'] = cond_weights assert uncond_weights is not None if self._cond_chunk_embs: weights['uncond_weights'] = uncond_weights else: weights['uncond_weights'] = uncond_weights[:-1] ### Process chunks ### hnet_out = self._hnet.forward(uncond_input=uncond_input, cond_input=cond_input, cond_id=None, weights=weights, distilled_params=distilled_params, condition=condition, ret_format='flattened') assert np.all( np.equal(hnet_out.shape, [batch_size * self.num_chunks, self._chunk_size])) # FIXME We can skip this line, right? hnet_out = hnet_out.view(batch_size, self.num_chunks, self._chunk_size) # Concatenate individual chunks. hnet_out = hnet_out.view(batch_size, self.num_chunks * self._chunk_size) # Throw away unused part of last chunk. hnet_out = hnet_out[:, :self.num_outputs] ### Assemble hypernet output ### ret = self._flat_to_ret_format(hnet_out, ret_format) return ret def distillation_targets(self): """Targets to be distilled after training. See docstring of abstract super method :meth:`mnets.mnet_interface.MainNetInterface.distillation_targets`. Returns: See :meth:`hnets.mlp_hnet.HMLP.distillation_targets`. """ # We don't have any additional distillation targets. We also just pass # `distilled_params` to the underlying hypernetwork in the `forward` # method. 
return self._hnet.distillation_targets def apply_chunked_hyperfan_init(self, method='in', use_xavier=False, uncond_var=1., cond_var=1., eps=1e-5, cemb_normal_init=False, mnet=None, target_vars=None): r"""Initialize the network using a chunked hyperfan init. Inspired by the method `Hyperfan Init <https://openreview.net/forum?id=H1lma24tPB>`__ which we implemented for the MLP hypernetwork in method :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`, we heuristically developed a better initialization method for chunked hypernetworks. Unfortunately, the `Hyperfan Init` method from the paper does not apply to this kind of hypernetwork, since we reuse the same hypernet output head for the whole main network. Luckily, we can provide a simple heuristic. Similar to `Meyerson & Miikkulainen <https://arxiv.org/abs/1906.00097>`__ we play with the variance of the input embeddings to affect the variance of the output weights. In a chunked hypernetwork, the input for each chunk is identical except for the chunk embeddings :math:`\mathbf{c}`. Let :math:`\mathbf{e}` denote the remaining inputs to the hypernetwork, which are identical for all chunks. Then, assuming the hypernetwork was initialized via fan-in init, the variance of the hypernetwork output :math:`\mathbf{v}` can be written as follows (see documentation of method :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`): .. math:: \text{Var}(v) = \frac{n_e}{n_e+n_c} \text{Var}(e) + \ \frac{n_c}{n_e+n_c} \text{Var}(c) Hence, we can achieve a desired output variance :math:`\text{Var}(v)` by initializing the chunk embeddings :math:`\mathbf{c}` via the following variance: .. math:: \text{Var}(c) = \max \Big\{ 0, \ \frac{1}{n_c} \big[ (n_e+n_c) \text{Var}(v) - \ n_e \text{Var}(e) \big] \Big\} Now, one important question remains. How do we pick a desired output variance :math:`\text{Var}(v)` for a chunk? Note, a chunk may include weights from several layers. 
The likelihood for this to happen depends on the main net architecture and the chunk size (see constructor argument ``chunk_size``). The smaller the chunk size, the less likely it is that a chunk will contain elements from multiple main net weight tensors. In case each chunk would contain only weights from one main net weight tensor, we could simply pick the variance :math:`\text{Var}(v)` that would have been chosen by a main net initialization method (such as Xavier). In case a chunk contains contributions from several main net weight tensors, we apply the following heuristic. If a chunk contains contributions of a set of main network weight tensors :math:`W_1, \dots, W_K` with relative contribution sizes\ :math:`n_1, \dots, n_K` such that :math:`n_1 + \dots + n_K = n_v` where :math:`n_v` denotes the chunk size and if the corresponding main network initialization method would require init variances :math:`\text{Var}(w_1), \dots, \text{Var}(w_K)`, then we simply request a weighted average as follow: .. math:: \text{Var}(v) = \frac{1}{n_v} \sum_{k=1}^K n_k \text{Var}(w_k) What about bias vectors? Usually, the variance analysis applied to Xavier or Kaiming init assumes that biases are initialized to zero. This is not possible in this setting, as it would require assigning a negative variance to :math:`\mathbf{c}`. Instead, we follow the default PyTorch initialization (e.g., see method ``reset_parameters`` in class :class:`torch.nn.Linear`). There, bias vectors are initialized uniformly within a range of :math:`\pm \frac{1}{\sqrt{f_{\text{in}}}}` where :math:`f_{\text{in}}` refers to the fan-in of the layer. This type of initialization corresponds to a variance of :math:`\text{Var}(v) = \frac{1}{3 f_{\text{in}}}`. Note: All hypernet inputs are assumed to be zero-mean random variables. 
Note: To avoid that the variances with which chunks are initialized have to be clipped (because they are too small or even negative), the variance of the remaining hypernet inputs should be properly scaled. In general, one should adhere the following rule .. math:: \text{Var}(e) < \frac{n_e+n_c}{n_e} \text{Var}(v) This method will calculate and print the maximum value that should be chosen for :math:`\text{Var}(e)` and will print warnings if variances have to be clipped. Args: (....): See arguments of method :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`. method (str): The type of initialization that should be applied. Possible options are: - ``in``: Use `Chunked Hyperfan-in`, i.e., rather the output variances of the hypernetwork should correspond to fan-in variances. - ``out``: Use `Chunked Hyperfan-out`, i.e., rather the output variances of the hypernetwork should correspond to fan-out variances. - ``harmonic``: Use the harmonic mean of the fan-in and fan-out variance as target variance of the hypernetwork output. eps (float): The minimum variance with which a chunk embedding is initialized. cemb_normal_init (bool): Use normal init for chunk embeddings rather than uniform init. target_vars (list or dict, optional): The variance of the distribution for each parameter tensor generated by this hypernetwork. Target variance values can either be provided as list of length ``len(hnet.target_shapes)`` or as dictionary. The usage is analoguous to the usage of parameter ``w_val`` of method :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`. Note: This method currently does not allow initial output distributions with non-zero mean. However, the docstring of method :meth:`probabilistic.gauss_hnet_init.gauss_hyperfan_init` describes how this is in principle feasible and might be incorporated in the future. Note: Unspecified target variances for parameter tensors of type ``'weight'`` or ``'bias'`` are computed as described above. 
Default target variances for all other parameter tensor types are simply ``1``. """ if method not in ['in', 'out', 'harmonic']: raise ValueError('Invalid value "%s" for argument "method".' % method) if self.unconditional_params is None: assert self._no_uncond_weights raise ValueError('Hypernet without internal weights can\'t be ' + 'initialized.') if self.unconditional_params is None and self._cond_chunk_embs: assert self._no_cond_weights raise ValueError('Chunked hyperfan init cannot be applied if ' + 'chunk embeddings are not internally maintained.') ### Extract meta-information about target shapes ### # FIXME This section is copied from the HMLP implementation. meta = None if mnet is not None: assert isinstance(mnet, MainNetInterface) try: meta = mnet.param_shapes_meta except: meta = None if meta is not None: if len(self.target_shapes) == len(mnet.param_shapes): pass # meta = mnet.param_shapes_meta elif len(self.target_shapes) == len(mnet.hyper_shapes_learned): meta = [] for ii in mnet.hyper_shapes_learned_ref: meta.append(mnet.param_shapes_meta[ii]) else: warn('Target shapes of this hypernetwork could not be ' + 'matched to the meta information provided to the ' + 'initialization.') meta = None # TODO If the user doesn't (or can't) provide an `mnet` instance, we # should alternatively allow him to pass meta information directly. if meta is None: meta = [] # Heuristical approach to derive meta information from given shapes. layer_ind = 0 for i, s in enumerate(self.target_shapes): curr_meta = dict() if len(s) > 1: curr_meta['name'] = 'weight' curr_meta['layer'] = layer_ind layer_ind += 1 else: # just a heuristic, we can't know curr_meta['name'] = 'bias' if i > 0 and meta[-1]['name'] == 'weight': curr_meta['layer'] = meta[-1]['layer'] else: curr_meta['layer'] = -1 meta.append(curr_meta) assert len(meta) == len(self.target_shapes) # Mapping from layer index to the corresponding shape. 
layer_shapes = dict() # Mapping from layer index to whether the layer has a bias vector. layer_has_bias = defaultdict(lambda: False) for i, m in enumerate(meta): if m['name'] == 'weight' and m['layer'] != -1: assert len(self.target_shapes[i]) > 1 layer_shapes[m['layer']] = self.target_shapes[i] if m['name'] == 'bias' and m['layer'] != -1: layer_has_bias[m['layer']] = True ### Compute input variance ### # The input variance does not include the variance of chunk embeddings! # Instead, it is the variance of the inputs that are shared across all # chunks. cond_dim = self._cond_in_size uncond_dim = self._uncond_in_size # Note, `inp_dim` can be zero if conditional chunk embeddings are used. inp_dim = cond_dim + uncond_dim inp_var = 0 if cond_dim > 0: inp_var += (cond_dim / inp_dim) * cond_var if uncond_dim > 0: inp_var += (uncond_dim / inp_dim) * uncond_var c_dim = self.chunk_emb_size ### Initialize hypernet with fan-in init ### if self.batchnorm_layers is not None and len( self.batchnorm_layers) > 0: # Note, batchnorm layers simply whiten the incoming statistics. # Thus, if we tune the variance of chunk embeddings, this variance # is normalized by a batchnorm layer and thus vanishes. raise RuntimeError('Chunked hyperfan init not applicable if a ' + 'hypernetwork with batchnorm layers is used.') # Note, the whole internal hypernetwork is initialized with fan-in init # to simply pass the variance of all inputs to the hypernet output. 
for i, w_tensor in enumerate(self.layer_weight_tensors): if use_xavier: iutils.xavier_fan_in_(w_tensor) else: torch.nn.init.kaiming_uniform_(w_tensor, mode='fan_in', nonlinearity='relu') if self.has_bias: nn.init.zeros_(self.layer_bias_vectors[i]) ### Compute target variance of each output tensor ### if target_vars is None: target_vars = [None] * len(self.target_shapes) elif isinstance(target_vars, dict): target_vars_d = target_vars target_vars = [] for i, m in enumerate(meta): if m['name'] in target_vars_d.keys(): target_vars.append(target_vars_d[m['name']]) else: target_vars.append(None) else: assert isinstance(target_vars, (list, tuple)) assert len(target_vars) == len(self.target_shapes) for i, s in enumerate(self.target_shapes): if target_vars[i] is not None: # Use user specified target variance. continue m = meta[i] if m['name'] == 'bias': if m['layer'] != -1: fan_in, _ = iutils.calc_fan_in_and_out( \ layer_shapes[m['layer']]) else: # FIXME Quick-fix, use fan-out instead. fan_in = s[0] target_vars[i] = 1. / (3. * fan_in) elif m['name'] == 'weight': fan_in, fan_out = iutils.calc_fan_in_and_out(s) c_relu = 1 if use_xavier else 2 var_in = c_relu / fan_in var_out = c_relu / fan_out if method == 'in': var = var_in elif method == 'out': var = var_out else: var = 2 * (1. / var_in + 1. / var_out) target_vars[i] = var else: target_vars[i] = 1. ### Target variance per chunk ### chunk_vars = [] i = 0 n = np.prod(self.target_shapes[i]) for j in range(self.num_chunks): m = self._chunk_size var = 0 while m > 0: # Special treatment to fill up last chunk. 
if j == self.num_chunks - 1 and i == len(target_vars) - 1: assert n <= m o = m else: o = min(m, n) var += o / self._chunk_size * target_vars[i] m -= o n -= o if n == 0: i += 1 if i < len(target_vars): n = np.prod(self.target_shapes[i]) chunk_vars.append(var) if inp_dim > 0: max_inp_var = (inp_dim + c_dim) / inp_dim * min(chunk_vars) max_inp_std = math.sqrt(max_inp_var) print('Initializing hypernet with Chunked Hyperfan Init ...') if inp_var >= max_inp_var: warn('Note, hypernetwork inputs should have an initial total ' + 'variance (std) smaller than %f (%f) in order for this ' \ % (max_inp_var, max_inp_std) + 'method to work properly.') ### Compute variances of chunk embeddings ### # We could have done that in the previous loop. But I think the code is # more readible this way. c_vars = [] n_clipped = 0 for i, var in enumerate(chunk_vars): c_var = 1. / c_dim * ((inp_dim + c_dim) * var - inp_dim * inp_var) if c_var < eps: n_clipped += 1 #warn('Initial variance of chunk embedding %d has to ' % i + \ # 'be clipped.') c_vars.append(max(eps, c_var)) if n_clipped > 0: warn('Initial variance of %d/%d ' % (n_clipped, len(chunk_vars)) + \ 'chunk embeddings had to be clipped.') ### Initialize chunk embeddings ### for i in range(self.num_chunks): c_std = math.sqrt(c_vars[i]) num_conds = self.num_known_conds if self._cond_chunk_embs else 1 for j in range(num_conds): cond_id = j if self._cond_chunk_embs else None c_emb = self.get_chunk_emb(chunk_id=i, cond_id=cond_id) if cemb_normal_init: torch.nn.init.normal_(c_emb, mean=0, std=c_std) else: a = math.sqrt(3.0) * c_std torch.nn.init._no_grad_uniform_(c_emb, -a, a) def get_cond_in_emb(self, cond_id): """Get the ``cond_id``-th (conditional) input embedding. Args: (....): See docstring of method :meth:`hnets.mlp_hnet.HMLP.get_cond_in_emb`. Returns: (torch.nn.Parameter) """ return self._hnet.get_cond_in_emb(cond_id) def get_chunk_emb(self, chunk_id=None, cond_id=None): """Get the ``chunk_id``-th chunk embedding. 
Args: chunk_id (int, optional): A number between 0 and :attr:`num_chunks` - 1. If not specified, a full chunk matrix with shape ``[num_chunks, chunk_emb_size]`` is returned. Otherwise, the ``chunk_id``-th row is returned. cond_id (int): Is mandatory if constructor argument ``cond_chunk_embs`` was set. Determines the set of chunk embeddings to be considered. Returns: (torch.nn.Parameter) """ if self._cond_chunk_embs: if cond_id is None: raise RuntimeError('Option "cond_id" has to be set if chunk ' + 'embeddings are conditional parameters!') if self.conditional_params is None: raise RuntimeError('Conditional chunk embeddings are not ' + 'internally maintained!') if not isinstance(cond_id, int) or cond_id < 0 or \ cond_id >= self._num_cond_embs: raise RuntimeError('Option "cond_id" must be between 0 and ' + '%d!' % (self._num_cond_embs - 1)) # Note, the last `self._num_cond_embs` params are chunk embeddings. chunk_embs = self.conditional_params[-self._num_cond_embs + cond_id] else: assert cond_id is None if self.unconditional_params is None: raise RuntimeError('Chunk embeddings are not internally ' + 'maintained!') chunk_embs = self.unconditional_params[-1] if chunk_id is None: return chunk_embs else: if not isinstance(chunk_id, int) or chunk_id < 0 or \ chunk_id >= self.num_chunks: raise RuntimeError('Option "chunk_id" must be between 0 and ' + '%d!' % (self.num_chunks - 1)) return chunk_embs[chunk_id, :]
class ChunkedHMLP(nn.Module, HyperNetInterface):
    """Implementation of a `chunked fully-connected hypernet`.

    The ``target_shapes`` will be flattened and split into chunks of size
    ``chunk_size``. In total, there will be
    ``np.ceil(self.num_outputs/chunk_size)`` chunks, where the last chunk
    produced might contain a remainder that is discarded.

    Each chunk has its own `chunk embedding` that is fed into the underlying
    hypernetwork.

    Note:
        It is possible to set ``uncond_in_size`` and ``cond_in_size`` to zero
        if ``cond_chunk_embs`` is ``True``.

    Attributes:
        num_chunks (int): The number of chunks that make up the final hypernet
            output. This also corresponds to the number of chunk embeddings
            required per forward sweep.
        chunk_emb_size (int): See constructor argument ``chunk_emb_size``.

    Args:
        (....): See constructor arguments of class
            :class:`hnets.mlp_hnet.HMLP`.
        chunk_size (int): The chunk size, i.e, the number of weights produced
            by the single internally maintained instance of a full hypernet
            (see :class:`hnets.mlp_hnet.HMLP`) at a time (i.e., per chunk
            embedding).
        chunk_emb_size (int): The size of a chunk embedding.

            Note:
                Embeddings will be initialized with a normal distribution using
                zero mean and unit variance.
        cond_chunk_embs (bool): Consider chunk embeddings to be conditional.
            In this case, there will be a different set of chunk embeddings per
            condition (specified via ``num_cond_embs``).

            If ``False``, there will be a total of :attr:`num_chunks` chunk
            embeddings that are maintained within
            :attr:`hnets.hnet_interface.HyperNetInterface.unconditional_param_shapes`.

            If ``True``, there will be ``num_cond_embs * self.num_chunks``
            chunk embeddings that are maintained within
            :attr:`hnets.hnet_interface.HyperNetInterface.conditional_param_shapes`.
            However, if ``num_cond_embs == 0``, then chunk embeddings have to
            be provided in a special way to the :meth:`forward` method (see the
            corresponding argument ``weights``).

    Raises:
        ValueError: If neither external inputs (``cond_in_size`` and
            ``uncond_in_size`` are both zero) nor conditional chunk embeddings
            are used.
    """
    def __init__(self, target_shapes, chunk_size, chunk_emb_size=8,
                 cond_chunk_embs=False, uncond_in_size=0, cond_in_size=8,
                 layers=(100, 100), verbose=True, activation_fn=torch.nn.ReLU(),
                 use_bias=True, no_uncond_weights=False, no_cond_weights=False,
                 num_cond_embs=1, dropout_rate=-1, use_spectral_norm=False,
                 use_batch_norm=False):
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        HyperNetInterface.__init__(self)

        assert isinstance(chunk_size, int) and chunk_size > 0
        assert isinstance(chunk_emb_size, int) and chunk_emb_size > 0

        ### Make constructor arguments internally available ###
        self._chunk_size = chunk_size
        self._chunk_emb_size = chunk_emb_size
        self._cond_chunk_embs = cond_chunk_embs
        self._uncond_in_size = uncond_in_size
        self._cond_in_size = cond_in_size
        self._no_uncond_weights = no_uncond_weights
        self._no_cond_weights = no_cond_weights
        self._num_cond_embs = num_cond_embs

        ### Create underlying full hypernet ###
        # Note, even if chunk embeddings are considered conditional, they
        # are maintained in this object and just fed as an external input to
        # the underlying hnet.
        hnet_uncond_in_size = uncond_in_size + chunk_emb_size
        hnet_num_cond_embs = num_cond_embs
        if cond_chunk_embs and cond_in_size == 0:
            # If there are no other conditional embeddings except the chunk
            # embeddings, we tell the underlying hnet explicitly that it
            # doesn't need to maintain any conditional weights to avoid that it
            # will throw a warning.
            hnet_num_cond_embs = 0
        self._hnet = HMLP([[chunk_size]], uncond_in_size=hnet_uncond_in_size,
            cond_in_size=cond_in_size, layers=layers, verbose=False,
            activation_fn=activation_fn, use_bias=use_bias,
            no_uncond_weights=no_uncond_weights,
            no_cond_weights=no_cond_weights, num_cond_embs=hnet_num_cond_embs,
            dropout_rate=dropout_rate, use_spectral_norm=use_spectral_norm,
            use_batch_norm=use_batch_norm)

        ### Setup attributes required by interface ###
        # Most of these attributes are taken over from `self._hnet`
        self._target_shapes = target_shapes
        self._num_known_conds = self._num_cond_embs
        self._unconditional_param_shapes_ref = \
            list(self._hnet._unconditional_param_shapes_ref)

        if self._hnet._internal_params is not None:
            self._internal_params = \
                nn.ParameterList(self._hnet._internal_params)
        self._param_shapes = list(self._hnet._param_shapes)
        self._param_shapes_meta = list(self._hnet._param_shapes_meta)
        if self._hnet._hyper_shapes_learned is not None:
            self._hyper_shapes_learned = list(self._hnet._hyper_shapes_learned)
            self._hyper_shapes_learned_ref = \
                list(self._hnet._hyper_shapes_learned_ref)
        if self._hnet._hyper_shapes_distilled is not None:
            self._hyper_shapes_distilled = \
                list(self._hnet._hyper_shapes_distilled)
        self._has_bias = self._hnet._has_bias
        self._has_fc_out = self._hnet._has_fc_out
        # Just to make that clear explicitly. We will additionally append
        # the chunk embeddings at the end of `param_shapes`.
        # We don't prepend it to the beginning, to keep conditional input
        # embeddings at the beginning.
        self._mask_fc_out = False
        self._has_linear_out = self._hnet._has_linear_out
        self._layer_weight_tensors = \
            nn.ParameterList(self._hnet._layer_weight_tensors)
        self._layer_bias_vectors = \
            nn.ParameterList(self._hnet._layer_bias_vectors)
        if self._hnet._batchnorm_layers is not None:
            self._batchnorm_layers = nn.ModuleList(self._hnet._batchnorm_layers)
        if self._hnet._context_mod_layers is not None:
            self._context_mod_layers = \
                nn.ModuleList(self._hnet._context_mod_layers)

        ### Create chunk embeddings ###
        if cond_in_size == 0 and uncond_in_size == 0 and not cond_chunk_embs:
            # Note, we could also allow this case. It would be analoguous to
            # creating a full hypernet with no unconditional input and one
            # conditional embedding. But the user can explicitly achieve that
            # as noted below.
            raise ValueError('If no external (conditional or unconditional) ' +
                             'input is provided to the hypernetwork, then ' +
                             'it can only learn a fixed output. If this ' +
                             'behavior is desired, please enable ' +
                             '"cond_chunk_embs" and set "num_cond_embs=1".')

        num_cemb_mats = 1
        no_cemb_weights = no_uncond_weights
        if cond_chunk_embs:
            num_cemb_mats = num_cond_embs
            no_cemb_weights = no_cond_weights

        self._cemb_shape = [self.num_chunks, chunk_emb_size]

        for _ in range(num_cemb_mats):
            if not no_cemb_weights:
                # The containers are only taken over from the underlying hnet
                # if it maintains some itself, so they might not exist yet.
                if getattr(self, '_internal_params', None) is None:
                    self._internal_params = nn.ParameterList()
                self._internal_params.append(nn.Parameter(
                    data=torch.Tensor(*self._cemb_shape), requires_grad=True))
                torch.nn.init.normal_(self._internal_params[-1], mean=0.,
                                      std=1.)
            else:
                if getattr(self, '_hyper_shapes_learned', None) is None:
                    self._hyper_shapes_learned = []
                if getattr(self, '_hyper_shapes_learned_ref', None) is None:
                    self._hyper_shapes_learned_ref = []
                self._hyper_shapes_learned.append(self._cemb_shape)
                self._hyper_shapes_learned_ref.append(len(self.param_shapes))

            if not cond_chunk_embs:
                self._unconditional_param_shapes_ref.append(
                    len(self.param_shapes))

            self._param_shapes.append(self._cemb_shape)
            # In principle, these embeddings also belong to the input, so we
            # just assign them as "layer" 0 (note, the underlying hnet uses the
            # same layer ID for its embeddings.
            self._param_shapes_meta.append({
                'name': 'embedding',
                'index': -1 if no_cemb_weights else \
                    len(self._internal_params)-1,
                'layer': 0,
                'info': 'chunk embeddings'
            })

        ### Finalize construction ###
        self._is_properly_setup()

        if verbose:
            print('Created Chunked MLP Hypernet with %d chunk(s) of size %d.' \
                  % (self.num_chunks, chunk_size))
            print(self)

    @property
    def num_chunks(self):
        """Getter for read-only attribute :attr:`num_chunks`."""
        return int(np.ceil(self.num_outputs / self._chunk_size))

    @property
    def chunk_emb_size(self):
        """Getter for read-only attribute :attr:`chunk_emb_size`."""
        return self._chunk_emb_size

    def forward(self, uncond_input=None, cond_input=None, cond_id=None,
                weights=None, distilled_params=None, condition=None,
                ret_format='squeezed', ext_inputs=None, task_emb=None,
                task_id=None, theta=None, dTheta=None):
        """Compute the weights of a target network.

        The hypernetwork inputs are replicated once per chunk, concatenated
        with the corresponding chunk embedding and processed by the underlying
        hypernetwork as one big batch of size ``batch_size * num_chunks``. The
        resulting chunks are concatenated and the unused remainder of the last
        chunk is discarded.

        Args:
            (....): See docstring of method
                :meth:`hnets.mlp_hnet.HMLP.forward`.
            weights (list or dict, optional): If provided as ``dict`` and
                chunk embeddings are considered conditional (see constructor
                argument ``cond_chunk_embs``), then the additional key
                ``chunk_embs`` can be used to pass a batch of chunk embeddings.
                This option is mutually exclusive with the option of passing
                ``cond_id``. Note, if conditional inputs via ``cond_input`` are
                expected, then the batch sizes must agree.

                A batch of chunk embeddings is expected to be tensor of shape
                ``[B, num_chunks, chunk_emb_size]``, where ``B`` denotes the
                batch size.

                Note:
                    A dictionary passed via this argument is not modified.

        Returns:
            (list or torch.Tensor): See docstring of method
            :meth:`hnets.hnet_interface.HyperNetInterface.forward`.

        Raises:
            ValueError: If key ``chunk_embs`` is used although chunk embeddings
                are not conditional, if it is combined with ``cond_id``, if
                ``cond_input`` is given without chunk embeddings although they
                are conditional, or if ``cond_id`` refers to an unknown
                condition or can't be resolved to conditional weights.
        """
        ### Extract chunk embeddings passed via `weights` ###
        cond_chunk_embs = None
        if isinstance(weights, dict) and 'chunk_embs' in weights:
            cond_chunk_embs = weights['chunk_embs']
            if not self._cond_chunk_embs:
                raise ValueError('Key "chunk_embs" for argument ' +
                                 '"weights" is only allowed if chunk ' +
                                 'embeddings are conditional.')
            assert len(cond_chunk_embs.shape) == 3 and \
                np.all(np.equal(cond_chunk_embs.shape[1:],
                                [self.num_chunks, self.chunk_emb_size]))

            if cond_id is not None:
                raise ValueError('Option "cond_id" is mutually exclusive ' +
                                 'with key "chunk_embs" for argument ' +
                                 '"weights".')
            assert cond_input is None or \
                cond_input.shape[0] == cond_chunk_embs.shape[0]

            # The parser of the super class doesn't know how to deal with
            # `chunk_embs`. We build a filtered copy rather than deleting the
            # key, such that the dictionary of the caller stays untouched.
            weights = {k: v for k, v in weights.items() if k != 'chunk_embs'}
            if not weights:  # Empty dictionary.
                weights = None

        if cond_input is not None and self._cond_chunk_embs and \
                cond_chunk_embs is None:
            raise ValueError('Conditional chunk embeddings have to be ' +
                             'provided via "weights" if "cond_input" is ' +
                             'specified.')

        _input_required = self._cond_in_size > 0 or self._uncond_in_size > 0
        # We parse `cond_id` afterwards if chunk embeddings are also
        # conditional.
        if self._cond_chunk_embs:
            _parse_cond_id_fct = lambda x, y, z: None
        else:
            _parse_cond_id_fct = None

        uncond_input, cond_input, uncond_weights, cond_weights = \
            self._preprocess_forward_args(_input_required=_input_required,
                _parse_cond_id_fct=_parse_cond_id_fct,
                uncond_input=uncond_input, cond_input=cond_input,
                cond_id=cond_id, weights=weights,
                distilled_params=distilled_params, condition=condition,
                ret_format=ret_format, ext_inputs=ext_inputs,
                task_emb=task_emb, task_id=task_id, theta=theta,
                dTheta=dTheta)

        ### Translate IDs to conditional inputs ###
        if cond_id is not None and self._cond_chunk_embs:
            assert cond_input is None and cond_chunk_embs is None
            cond_id = [cond_id] if isinstance(cond_id, int) else cond_id

            if cond_weights is None:
                raise ValueError('Forward option "cond_id" can only be ' +
                                 'used if conditional parameters are ' +
                                 'maintained internally or passed to the ' +
                                 'forward method via option "weights".')

            cond_chunk_embs = []
            cond_input = [] if self._cond_in_size > 0 else None

            for cid in cond_id:
                if cid < 0 or cid >= self._num_cond_embs:
                    raise ValueError('Condition %d not existing!' % (cid))

                # Note, we do not necessarily have conditional embeddings.
                if self._cond_in_size > 0:
                    cond_input.append(cond_weights[cid])

                # The last `num_cond_embs` conditional weights are the chunk
                # embedding matrices.
                cond_chunk_embs.append(
                    cond_weights[-self._num_cond_embs+cid])

            if self._cond_in_size > 0:
                cond_input = torch.stack(cond_input, dim=0)
            cond_chunk_embs = torch.stack(cond_chunk_embs, dim=0)

        ### Assemble hypernetwork input ###
        batch_size = None
        if cond_input is not None:
            batch_size = cond_input.shape[0]
        if cond_chunk_embs is not None:
            assert batch_size is None or batch_size == cond_chunk_embs.shape[0]
            batch_size = cond_chunk_embs.shape[0]
        if uncond_input is not None:
            if batch_size is None:
                batch_size = uncond_input.shape[0]
            else:
                assert batch_size == uncond_input.shape[0]
        assert batch_size is not None

        chunk_embs = None
        if self._cond_chunk_embs:
            assert cond_chunk_embs is not None and \
                len(cond_chunk_embs.shape) == 3
            assert self._cond_in_size == 0 or cond_input is not None
            chunk_embs = cond_chunk_embs
        else:
            assert cond_chunk_embs is None
            chunk_embs = uncond_weights[-1]
            # Insert batch dimension.
            chunk_embs = chunk_embs.expand(batch_size, self.num_chunks,
                                           self.chunk_emb_size)

        # We now have the following setup:
        # cond_input: [batch_size, cond_in_size] or None
        # uncond_input: [batch_size, uncond_in_size] or None
        # chunk_embs: [batch_size, num_chunks, chunk_emb_size]

        # We now first copy the hypernet inputs for each chunk, arriving at
        # cond_input: [batch_size, num_chunks, cond_in_size] or None
        # uncond_input: [batch_size, num_chunks, uncond_in_size] or None
        if cond_input is not None:
            cond_input = cond_input.reshape(batch_size, 1, -1)
            cond_input = cond_input.expand(batch_size, self.num_chunks,
                                           self._cond_in_size)
        if uncond_input is not None:
            uncond_input = uncond_input.reshape(batch_size, 1, -1)
            uncond_input = uncond_input.expand(batch_size, self.num_chunks,
                                               self._uncond_in_size)
            # The chunk embeddings are considered unconditional inputs to the
            # underlying hypernetwork.
            uncond_input = torch.cat([uncond_input, chunk_embs], dim=2)
        else:
            uncond_input = chunk_embs

        # Now we build one big batch for the underlying hypernetwork, with
        # batch size: batch_size * num_chunks.
        if cond_input is not None:
            cond_input = cond_input.reshape(batch_size * self.num_chunks, -1)
        uncond_input = uncond_input.reshape(batch_size * self.num_chunks, -1)

        ### Weight of underlying hypernetwork ###
        # Chunk embeddings are maintained by this class and must be stripped
        # from the weights that are handed to the underlying hypernetwork.
        hnet_weights = dict()
        if cond_weights is not None and self._cond_chunk_embs:
            hnet_weights['cond_weights'] = cond_weights[:-self._num_cond_embs]
        elif cond_weights is not None:
            hnet_weights['cond_weights'] = cond_weights
        assert uncond_weights is not None
        if self._cond_chunk_embs:
            hnet_weights['uncond_weights'] = uncond_weights
        else:
            hnet_weights['uncond_weights'] = uncond_weights[:-1]

        ### Process chunks ###
        hnet_out = self._hnet.forward(uncond_input=uncond_input,
            cond_input=cond_input, cond_id=None, weights=hnet_weights,
            distilled_params=distilled_params, condition=condition,
            ret_format='flattened')
        assert np.all(np.equal(hnet_out.shape,
                               [batch_size * self.num_chunks,
                                self._chunk_size]))
        # Concatenate individual chunks. Note, the rows of `hnet_out` are
        # ordered as [batch, chunk], such that a single view is sufficient.
        hnet_out = hnet_out.view(batch_size,
                                 self.num_chunks * self._chunk_size)
        # Throw away unused part of last chunk. Without this step, the output
        # would contain more than `num_outputs` values whenever `num_outputs`
        # is not a multiple of the chunk size.
        hnet_out = hnet_out[:, :self.num_outputs]

        ### Assemble hypernet output ###
        ret = self._flat_to_ret_format(hnet_out, ret_format)

        return ret

def distillation_targets(self): """Targets to be distilled after training. See docstring of abstract super method :meth:`mnets.mnet_interface.MainNetInterface.distillation_targets`. Returns: See :meth:`hnets.mlp_hnet.HMLP.distillation_targets`. """ # We don't have any additional distillation targets. We also just pass # `distilled_params` to the underlying hypernetwork in the `forward` # method.
return self._hnet.distillation_targets def apply_chunked_hyperfan_init(self, method='in', use_xavier=False, uncond_var=1., cond_var=1., eps=1e-5, cemb_normal_init=False): """Not implemented yet!""" # TODO Translate from old hypernet implementation and take meta # information of generated parameters into account. raise NotImplementedError() def get_cond_in_emb(self, cond_id): """Get the ``cond_id``-th (conditional) input embedding. Args: (....): See docstring of method :meth:`hnets.mlp_hnet.HMLP.get_cond_in_emb`. Returns: (torch.nn.Parameter) """ return self._hnet.get_cond_in_emb(cond_id) def get_chunk_emb(self, chunk_id=None, cond_id=None): """Get the ``chunk_id``-th chunk embedding. Args: chunk_id (int, optional): A number between 0 and :attr:`num_chunks` - 1. If not specified, a full chunk matrix with shape ``[num_chunks, chunk_emb_size]`` is returned. Otherwise, the ``chunk_id``-th row is returned. cond_id (int): Is mandatory if constructor argument ``cond_chunk_embs`` was set. Determines the set of chunk embeddings to be considered. Returns: (torch.nn.Parameter) """ if self._cond_chunk_embs: if cond_id is None: raise RuntimeError('Option "cond_id" has to be set if chunk ' + 'embeddings are conditional parameters!') if self.conditional_params is None: raise RuntimeError('Conditional chunk embeddings are not ' + 'internally maintained!') if not isinstance(cond_id, int) or cond_id < 0 or \ cond_id >= self._num_cond_embs: raise RuntimeError('Option "cond_id" must be between 0 and ' + '%d!' % (self._num_cond_embs-1)) # Note, the last `self._num_cond_embs` params are chunk embeddings. 
chunk_embs = self.conditional_params[-self._num_cond_embs+cond_id] else: assert cond_id is None if self.unconditional_params is None: raise RuntimeError('Chunk embeddings are not internally ' + 'maintained!') chunk_embs = self.unconditional_params[-1] if chunk_id is None: return chunk_embs else: if not isinstance(chunk_id, int) or chunk_id < 0 or \ chunk_id >= self.num_chunks: raise RuntimeError('Option "chunk_id" must be between 0 and ' + '%d!' % (self.num_chunks-1)) return chunk_embs[chunk_id, :]