def get_hnet_model(config, num_tasks, device, mnet_shapes, cprefix=None):
    """Generate a hypernetwork instance.

    A helper to generate the hypernetwork according to the given user
    configurations.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available according
                to the function :func:`utils.cli_args.hypernet_args`.
        num_tasks (int): The number of task embeddings the hypernetwork should
            have.
        device: PyTorch device.
        mnet_shapes: Dimensions of the weight tensors of the main network. See
            main net argument
            :attr:`mnets.mnet_interface.MainNetInterface.param_shapes`.
        cprefix (str, optional): A prefix of the config names. It might be that
            the config names used in this method are prefixed, since several
            hypernetworks should be generated (e.g., :code:`cprefix='gen_'` or
            ``'dis_'`` when training a GAN). Also see docstring of parameter
            ``prefix`` in function :func:`utils.cli_args.hypernet_args`.

    Returns:
        The created hypernet model.
    """
    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    hyper_chunks = misc.str_to_ints(gc('hyper_chunks'))
    assert len(hyper_chunks) in [1, 2, 3]
    if len(hyper_chunks) == 1:
        hyper_chunks = hyper_chunks[0]

    hnet_arch = misc.str_to_ints(gc('hnet_arch'))
    sa_hnet_filters = misc.str_to_ints(gc('sa_hnet_filters'))
    sa_hnet_kernels = misc.str_to_ints(gc('sa_hnet_kernels'))
    sa_hnet_attention_layers = misc.str_to_ints(gc('sa_hnet_attention_layers'))

    hnet_act = misc.str_to_act(gc('hnet_act'))

    if isinstance(hyper_chunks, list):  # Chunked self-attention hypernet
        if len(sa_hnet_kernels) == 1:
            sa_hnet_kernels = sa_hnet_kernels[0]
        # Note, the user can specify the kernel size for each dimension and
        # layer separately.
        elif len(sa_hnet_kernels) > 2 and \
                len(sa_hnet_kernels) == gc('sa_hnet_num_layers') * 2:
            tmp = sa_hnet_kernels
            sa_hnet_kernels = []
            for i in range(0, len(tmp), 2):
                sa_hnet_kernels.append([tmp[i], tmp[i + 1]])

        if gc('hnet_dropout_rate') != -1:
            warn('SA-Hypernet doesn\'t use dropout. Dropout rate will be ' +
                 'ignored.')
        if gc('hnet_act') != 'relu':
            warn('SA-Hypernet doesn\'t support non-linearities other than ' +
                 'ReLU yet. Option "%shnet_act" (%s) will be ignored.' %
                 (cprefix, gc('hnet_act')))

        hnet = SAHyperNetwork(mnet_shapes, num_tasks,
            out_size=hyper_chunks,
            num_layers=gc('sa_hnet_num_layers'),
            num_filters=sa_hnet_filters,
            kernel_size=sa_hnet_kernels,
            sa_units=sa_hnet_attention_layers,
            # Note, we don't use an additional hypernet for the remaining
            # weights!
            #rem_layers=hnet_arch,
            te_dim=gc('temb_size'),
            ce_dim=gc('emb_size'),
            no_theta=False,
            # Batchnorm and spectral norm are not yet implemented.
            #use_batch_norm=gc('hnet_batchnorm'),
            #use_spectral_norm=gc('hnet_specnorm'),
            # Dropout would only be used for the additional network, which we
            # don't use.
            #dropout_rate=gc('hnet_dropout_rate'),
            discard_remainder=True,
            noise_dim=gc('hnet_noise_dim'),
            temb_std=gc('temb_std')).to(device)
    elif hyper_chunks != -1:  # Chunked fully-connected hypernet
        hnet = ChunkedHyperNetworkHandler(mnet_shapes, num_tasks,
            chunk_dim=hyper_chunks,
            layers=hnet_arch,
            activation_fn=hnet_act,
            te_dim=gc('temb_size'),
            ce_dim=gc('emb_size'),
            dropout_rate=gc('hnet_dropout_rate'),
            noise_dim=gc('hnet_noise_dim'),
            temb_std=gc('temb_std')).to(device)
    else:  # Fully-connected hypernet.
        hnet = HyperNetwork(mnet_shapes, num_tasks,
            layers=hnet_arch,
            te_dim=gc('temb_size'),
            activation_fn=hnet_act,
            dropout_rate=gc('hnet_dropout_rate'),
            noise_dim=gc('hnet_noise_dim'),
            temb_std=gc('temb_std')).to(device)

    return hnet
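# Example (illustrative sketch): building a plain fully-connected hypernetwork
# with this helper. The attribute names below mirror the config values read via
# ``gc(...)`` above; the concrete values, the main-net shapes and the final
# ``forward`` call are hypothetical.
#
#     import argparse
#     import torch
#
#     cfg = argparse.Namespace(hyper_chunks='-1', hnet_arch='100,100',
#                              sa_hnet_filters='64', sa_hnet_kernels='5',
#                              sa_hnet_attention_layers='1', hnet_act='relu',
#                              temb_size=8, hnet_dropout_rate=-1,
#                              hnet_noise_dim=-1, temb_std=-1)
#     device = torch.device('cpu')
#     hnet = get_hnet_model(cfg, num_tasks=5, device=device,
#                           mnet_shapes=[[10, 784], [10]])
#     weights = hnet.forward(task_id=0)  # Main-net weights for task 0.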
def generate_classifier(config, data_handlers, device):
    """Create a classifier network.

    Depending on the experiment and method, this function builds either a
    classifier for task inference or a classifier that solves the actual task.
    This also determines whether the network will receive its weights from a
    hypernetwork or maintain its own weights.

    The following important configurations are determined in order to create
    the classifier: \n
    * input, output and hidden layer dimensions of the classifier. \n
    * architecture, chunk- and task-embedding details of the hypernetwork.

    See :class:`mnets.mlp.MLP` for details on the network that will be created
    as the classifier.

    .. note::
        This module also handles the initialisation of the weights of either
        the classifier or its hypernetwork. This will change in the near
        future.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network. And to
            infer the number of tasks.
        device: Torch device.

    Returns:
        (tuple): Tuple containing:

        - **net**: The classifier network.
        - **class_hnet**: (optional) The classifier's hypernetwork.
    """
    n_in = data_handlers[0].in_shape[0]
    pd = config.padding * 2

    if config.experiment == "splitMNIST":
        n_in = n_in * n_in
    else:  # permutedMNIST
        n_in = (n_in + pd) * (n_in + pd)

    config.input_dim = n_in
    if config.experiment == "splitMNIST":
        if config.class_incremental:
            config.out_dim = 1
        else:
            config.out_dim = 2
    else:  # permutedMNIST
        config.out_dim = 10

    if config.training_task_infer or config.class_incremental:
        # Task inference network.
        config.out_dim = 1

    # Have all output neurons already built up for CL scenario 2.
    if config.cl_scenario != 2:
        n_out = config.out_dim * config.num_tasks
    else:
        n_out = config.out_dim

    if config.training_task_infer or config.class_incremental:
        n_out = config.num_tasks

    # Build the classifier.
    print('For the Classifier: ')
    class_arch = misc.str_to_ints(config.class_fc_arch)
    if config.training_with_hnet:
        no_weights = True
    else:
        no_weights = False

    net = MLP(n_in=n_in, n_out=n_out, hidden_layers=class_arch,
              activation_fn=misc.str_to_act(config.class_net_act),
              dropout_rate=config.class_dropout_rate,
              no_weights=no_weights).to(device)
    print('Constructed MLP with shapes: ', net.param_shapes)

    config.num_weights_class_net = \
        MainNetInterface.shapes_to_num_weights(net.param_shapes)

    # Build the classifier hypernetwork.
    # This is set in the run method in train.py.
    if config.training_with_hnet:
        class_hnet = sim_utils.get_hnet_model(config, config.num_tasks, device,
                                              net.param_shapes,
                                              cprefix='class_')
        init_params = list(class_hnet.parameters())

        config.num_weights_class_hyper_net = sum(
            p.numel() for p in class_hnet.parameters() if p.requires_grad)
        config.compression_ratio_class = config.num_weights_class_hyper_net / \
            config.num_weights_class_net
        print('Created classifier Hypernetwork with ratio: ',
              config.compression_ratio_class)
        if config.compression_ratio_class > 1:
            print('Note that the compression ratio is computed compared to ' +
                  'the current target network and might not be directly ' +
                  'comparable with the number of parameters of methods we ' +
                  'compare against.')
    else:
        class_hnet = None
        init_params = list(net.parameters())
        config.num_weights_class_hyper_net = None
        config.compression_ratio_class = None

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if config.training_with_hnet:
        for temb in class_hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

        if hasattr(class_hnet, 'chunk_embeddings'):
            for emb in class_hnet.chunk_embeddings:
                torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    if not config.training_with_hnet:
        return net
    else:
        return net, class_hnet
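# Example (illustrative sketch): creating a classifier without a hypernetwork
# (``training_with_hnet=False``), so no ``class_``-prefixed hypernet options are
# needed. ``SimpleNamespace`` stands in for parsed command-line arguments and
# the data handlers; all values are hypothetical.
#
#     from types import SimpleNamespace
#     import torch
#
#     cfg = SimpleNamespace(experiment='splitMNIST', padding=0,
#                           class_incremental=False, training_task_infer=False,
#                           cl_scenario=1, num_tasks=5, class_fc_arch='400,400',
#                           class_net_act='relu', class_dropout_rate=-1,
#                           training_with_hnet=False)
#     dhandlers = [SimpleNamespace(in_shape=[28], out_shape=[2])
#                  for _ in range(cfg.num_tasks)]
#     device = torch.device('cpu')
#     net = generate_classifier(cfg, dhandlers, device)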
def get_hypernet(config, device, net_type, target_shapes, num_conds, no_cond_weights=False, no_uncond_weights=False, uncond_in_size=0, shmlp_chunk_shapes=None, shmlp_num_per_chunk=None, shmlp_assembly_fct=None, verbose=True, cprefix=None): """Generate a hypernetwork instance. A helper to generate the hypernetwork according to the given the user configurations. Args: config (argparse.Namespace): Command-line arguments. Note: The function expects command-line arguments available according to the function :func:`utils.cli_args.hnet_args`. device: PyTorch device. net_type (str): The type of network. The following options are available: - ``'hmlp'`` - ``'chunked_hmlp'`` - ``'structured_hmlp'`` - ``'hdeconv'`` - ``'chunked_hdeconv'`` target_shapes (list): See argument ``target_shapes`` of :class:`hnets.mlp_hnet.HMLP`. num_conds (int): Number of conditions that should be known to the hypernetwork. no_cond_weights (bool): See argument ``no_cond_weights`` of :class:`hnets.mlp_hnet.HMLP`. no_uncond_weights (bool): See argument ``no_uncond_weights`` of :class:`hnets.mlp_hnet.HMLP`. uncond_in_size (int): See argument ``uncond_in_size`` of :class:`hnets.mlp_hnet.HMLP`. shmlp_chunk_shapes (list, optional): Argument ``chunk_shapes`` of :class:`hnets.structured_mlp_hnet.StructuredHMLP`. shmlp_num_per_chunk (list, optional): Argument ``num_per_chunk`` of :class:`hnets.structured_mlp_hnet.StructuredHMLP`. shmlp_assembly_fct (func, optional): Argument ``assembly_fct`` of :class:`hnets.structured_mlp_hnet.StructuredHMLP`. verbose (bool): Argument ``verbose`` of :class:`hnets.mlp_hnet.HMLP`. cprefix (str, optional): A prefix of the config names. It might be, that the config names used in this function are prefixed, since several hypernetworks should be generated. Also see docstring of parameter ``prefix`` in function :func:`utils.cli_args.hnet_args`. """ assert net_type in [ 'hmlp', 'chunked_hmlp', 'structured_hmlp', 'hdeconv', 'chunked_hdeconv' ] hnet = None ### FIXME Code almost identically copied from `get_mnet_model` ### if cprefix is None: cprefix = '' def gc(name): """Get config value with that name.""" return getattr(config, '%s%s' % (cprefix, name)) def hc(name): """Check whether config exists.""" return hasattr(config, '%s%s' % (cprefix, name)) if hc('hnet_net_act'): net_act = gc('hnet_net_act') net_act = misc.str_to_act(net_act) else: net_act = None def get_val(name): ret = None if hc(name): ret = gc(name) return ret no_bias = get_val('hnet_no_bias') dropout_rate = get_val('hnet_dropout_rate') specnorm = get_val('hnet_specnorm') batchnorm = get_val('hnet_batchnorm') no_batchnorm = get_val('hnet_no_batchnorm') #bn_no_running_stats = get_val('hnet_bn_no_running_stats') #n_distill_stats = get_val('hnet_bn_distill_stats') use_bn = None if batchnorm is not None: use_bn = batchnorm elif no_batchnorm is not None: use_bn = not no_batchnorm # If an argument wasn't specified, then we use the default value that # is currently in the constructor. 
assign = lambda x, y: y if x is None else x ### FIXME Code copied until here ### if hc('hmlp_arch'): hmlp_arch_is_list = False hmlp_arch = gc('hmlp_arch') if ';' in hmlp_arch: hmlp_arch_is_list = True if net_type != 'structured_hmlp': raise ValueError('Option "%shmlp_arch" may only ' % (cprefix) + 'contain semicolons for network type ' + '"structured_hmlp"!') hmlp_arch = [misc.str_to_ints(ar) for ar in hmlp_arch.split(';')] else: hmlp_arch = misc.str_to_ints(hmlp_arch) if hc('chunk_emb_size'): chunk_emb_size = gc('chunk_emb_size') chunk_emb_size = misc.str_to_ints(chunk_emb_size) if len(chunk_emb_size) == 1: chunk_emb_size = chunk_emb_size[0] else: if net_type != 'structured_hmlp': raise ValueError('Option "%schunk_emb_size" may ' % (cprefix) + 'only contain multiple values for network ' + 'type "structured_hmlp"!') if hc('cond_emb_size'): cond_emb_size = gc('cond_emb_size') else: cond_emb_size = 0 if net_type == 'hmlp': assert hc('hmlp_arch') # Default keyword arguments of class HMLP. dkws = misc.get_default_args(HMLP.__init__) hnet = HMLP(target_shapes, uncond_in_size=uncond_in_size, cond_in_size=cond_emb_size, layers=hmlp_arch, verbose=verbose, activation_fn=assign(net_act, dkws['activation_fn']), use_bias=assign(not no_bias, dkws['use_bias']), no_uncond_weights=no_uncond_weights, no_cond_weights=no_cond_weights, num_cond_embs=num_conds, dropout_rate=assign(dropout_rate, dkws['dropout_rate']), use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']), use_batch_norm=assign(use_bn, dkws['use_batch_norm'])).to(device) elif net_type == 'chunked_hmlp': assert hc('hmlp_arch') assert hc('chmlp_chunk_size') assert hc('chunk_emb_size') cond_chunk_embs = get_val('use_cond_chunk_embs') # Default keyword arguments of class ChunkedHMLP. dkws = misc.get_default_args(ChunkedHMLP.__init__) hnet = ChunkedHMLP( target_shapes, gc('chmlp_chunk_size'), chunk_emb_size=chunk_emb_size, cond_chunk_embs=assign(cond_chunk_embs, dkws['cond_chunk_embs']), uncond_in_size=uncond_in_size, cond_in_size=cond_emb_size, layers=hmlp_arch, verbose=verbose, activation_fn=assign(net_act, dkws['activation_fn']), use_bias=assign(not no_bias, dkws['use_bias']), no_uncond_weights=no_uncond_weights, no_cond_weights=no_cond_weights, num_cond_embs=num_conds, dropout_rate=assign(dropout_rate, dkws['dropout_rate']), use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']), use_batch_norm=assign(use_bn, dkws['use_batch_norm'])).to(device) elif net_type == 'structured_hmlp': assert hc('hmlp_arch') assert hc('chunk_emb_size') cond_chunk_embs = get_val('use_cond_chunk_embs') assert shmlp_chunk_shapes is not None and \ shmlp_num_per_chunk is not None and \ shmlp_assembly_fct is not None # Default keyword arguments of class StructuredHMLP. 
dkws = misc.get_default_args(StructuredHMLP.__init__) dkws_hmlp = misc.get_default_args(HMLP.__init__) shmlp_hmlp_kwargs = [] if not hmlp_arch_is_list: hmlp_arch = [hmlp_arch] for i, arch in enumerate(hmlp_arch): shmlp_hmlp_kwargs.append({ 'layers': arch, 'activation_fn': assign(net_act, dkws_hmlp['activation_fn']), 'use_bias': assign(not no_bias, dkws_hmlp['use_bias']), 'dropout_rate': assign(dropout_rate, dkws_hmlp['dropout_rate']), 'use_spectral_norm': \ assign(specnorm, dkws_hmlp['use_spectral_norm']), 'use_batch_norm': assign(use_bn, dkws_hmlp['use_batch_norm']) }) if len(shmlp_hmlp_kwargs) == 1: shmlp_hmlp_kwargs = shmlp_hmlp_kwargs[0] hnet = StructuredHMLP(target_shapes, shmlp_chunk_shapes, shmlp_num_per_chunk, chunk_emb_size, shmlp_hmlp_kwargs, shmlp_assembly_fct, cond_chunk_embs=assign(cond_chunk_embs, dkws['cond_chunk_embs']), uncond_in_size=uncond_in_size, cond_in_size=cond_emb_size, verbose=verbose, no_uncond_weights=no_uncond_weights, no_cond_weights=no_cond_weights, num_cond_embs=num_conds).to(device) elif net_type == 'hdeconv': #HDeconv raise NotImplementedError else: assert net_type == 'chunked_hdeconv' #ChunkedHDeconv raise NotImplementedError return hnet
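# Example (illustrative sketch): creating a plain ``'hmlp'`` hypernetwork. Only
# ``hmlp_arch`` is strictly required by that branch; all other options fall
# back to the defaults of :class:`hnets.mlp_hnet.HMLP`. Target shapes and
# values are hypothetical.
#
#     from types import SimpleNamespace
#     import torch
#
#     cfg = SimpleNamespace(hmlp_arch='100,100', cond_emb_size=32)
#     device = torch.device('cpu')
#     hnet = get_hypernet(cfg, device, 'hmlp', target_shapes=[[20, 10], [20]],
#                         num_conds=5)
#     weights = hnet.forward(cond_id=0)  # Weights for condition (task) 0.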
def get_mnet_model(config, net_type, in_shape, out_shape, device, cprefix=None, no_weights=False): """Generate a main network instance. A helper to generate a main network according to the given the user configurations. .. note:: Generation of networks with context-modulation is not yet supported, since there is no global argument set in :mod:`utils.cli_args` yet. Args: config (argparse.Namespace): Command-line arguments. .. note:: The function expects command-line arguments available according to the function :func:`utils.cli_args.main_net_args`. net_type (str): The type of network. The following options are available: - ``mlp``: :class:`mnets.mlp.MLP` - ``resnet``: :class:`mnets.resnet.ResNet` - ``zenke``: :class:`mnets.zenkenet.ZenkeNet` - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet` in_shape (list): Shape of network inputs. Can be ``None`` if not required by network type. For instance: For an MLP network :class:`mnets.mlp.MLP` with 100 input neurons it should be :code:`in_shape=[100]`. out_shape (list): Shape of network outputs. See ``in_shape`` for more details. device: PyTorch device. cprefix (str, optional): A prefix of the config names. It might be, that the config names used in this method are prefixed, since several main networks should be generated (e.g., :code:`cprefix='gen_'` or ``'dis_'`` when training a GAN). Also see docstring of parameter ``prefix`` in function :func:`utils.cli_args.main_net_args`. no_weights (bool): Whether the main network should be generated without weights. Returns: The created main network model. """ assert (net_type in ['mlp', 'resnet', 'zenke', 'bio_conv_net']) if cprefix is None: cprefix = '' def gc(name): """Get config value with that name.""" return getattr(config, '%s%s' % (cprefix, name)) def hc(name): """Check whether config exists.""" return hasattr(config, '%s%s' % (cprefix, name)) mnet = None if hc('net_act'): net_act = gc('net_act') net_act = misc.str_to_act(net_act) else: net_act = None def get_val(name): ret = None if hc(name): ret = gc(name) return ret no_bias = get_val('no_bias') dropout_rate = get_val('dropout_rate') specnorm = get_val('specnorm') batchnorm = get_val('batchnorm') no_batchnorm = get_val('no_batchnorm') bn_no_running_stats = get_val('bn_no_running_stats') bn_distill_stats = get_val('bn_distill_stats') #bn_no_stats_checkpointing = get_val('bn_no_stats_checkpointing') use_bn = None if batchnorm is not None: use_bn = batchnorm elif no_batchnorm is not None: use_bn = not no_batchnorm # FIXME if an argument wasn't specified, then we use the default value that # is currently (at time of implementation) in the constructor. 
assign = lambda x, y: y if x is None else x if net_type == 'mlp': assert (hc('mlp_arch')) assert (len(in_shape) == 1 and len(out_shape) == 1) mnet = MLP( n_in=in_shape[0], n_out=out_shape[0], hidden_layers=misc.str_to_ints(gc('mlp_arch')), activation_fn=assign(net_act, torch.nn.ReLU()), use_bias=not assign(no_bias, False), no_weights=no_weights, #init_weights=None, dropout_rate=assign(dropout_rate, -1), use_spectral_norm=assign(specnorm, False), use_batch_norm=assign(use_bn, False), bn_track_stats=assign(not bn_no_running_stats, True), distill_bn_stats=assign(bn_distill_stats, False), #use_context_mod=False, #context_mod_inputs=False, #no_last_layer_context_mod=False, #context_mod_no_weights=False, #context_mod_post_activation=False, #context_mod_gain_offset=False, #out_fn=None, verbose=True).to(device) elif net_type == 'resnet': assert (len(out_shape) == 1) mnet = ResNet( in_shape=in_shape, num_classes=out_shape[0], verbose=True, #n=5, no_weights=no_weights, #init_weights=None, use_batch_norm=assign(use_bn, True), bn_track_stats=assign(not bn_no_running_stats, True), distill_bn_stats=assign(bn_distill_stats, False), #use_context_mod=False, #context_mod_inputs=False, #no_last_layer_context_mod=False, #context_mod_no_weights=False, #context_mod_post_activation=False, #context_mod_gain_offset=False, #context_mod_apply_pixel_wise=False ).to(device) elif net_type == 'zenke': assert (len(out_shape) == 1) mnet = ZenkeNet( in_shape=in_shape, num_classes=out_shape[0], verbose=True, #arch='cifar', no_weights=no_weights, #init_weights=None, dropout_rate=assign(dropout_rate, 0.25)).to(device) else: assert (net_type == 'bio_conv_net') assert (len(out_shape) == 1) raise NotImplementedError('Implementation not publicly available!') return mnet
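# Example (illustrative sketch): creating an MLP main network whose weights
# will later be produced by a hypernetwork (``no_weights=True``). Only
# ``mlp_arch`` is required by the branch above; unspecified options fall back
# to the defaults hard-coded there. Values are hypothetical.
#
#     from types import SimpleNamespace
#     import torch
#
#     cfg = SimpleNamespace(mlp_arch='256,256', net_act='relu')
#     device = torch.device('cpu')
#     mnet = get_mnet_model(cfg, 'mlp', [784], [10], device, no_weights=True)
#     # ``mnet.param_shapes`` now lists the shapes a hypernetwork must produce.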
def get_mnet_model(config, net_type, in_shape, out_shape, device, cprefix=None, no_weights=False, **mnet_kwargs): """Generate a main network instance. A helper to generate a main network according to the given the user configurations. .. note:: Generation of networks with context-modulation is not yet supported, since there is no global argument set in :mod:`utils.cli_args` yet. Args: config (argparse.Namespace): Command-line arguments. .. note:: The function expects command-line arguments available according to the function :func:`utils.cli_args.main_net_args`. net_type (str): The type of network. The following options are available: - ``mlp``: :class:`mnets.mlp.MLP` - ``resnet``: :class:`mnets.resnet.ResNet` - ``wrn``: :class:`mnets.wide_resnet.WRN` - ``iresnet``: :class:`mnets.resnet_imgnet.ResNetIN` - ``zenke``: :class:`mnets.zenkenet.ZenkeNet` - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet` - ``chunked_mlp``: :class:`mnets.chunk_squeezer.ChunkSqueezer` - ``simple_rnn``: :class:`mnets.simple_rnn.SimpleRNN` in_shape (list): Shape of network inputs. Can be ``None`` if not required by network type. For instance: For an MLP network :class:`mnets.mlp.MLP` with 100 input neurons it should be :code:`in_shape=[100]`. out_shape (list): Shape of network outputs. See ``in_shape`` for more details. device: PyTorch device. cprefix (str, optional): A prefix of the config names. It might be, that the config names used in this method are prefixed, since several main networks should be generated (e.g., :code:`cprefix='gen_'` or ``'dis_'`` when training a GAN). Also see docstring of parameter ``prefix`` in function :func:`utils.cli_args.main_net_args`. no_weights (bool): Whether the main network should be generated without weights. **mnet_kwargs: Additional keyword arguments that will be passed to the main network constructor. Returns: The created main network model. """ assert (net_type in [ 'mlp', 'lenet', 'resnet', 'zenke', 'bio_conv_net', 'chunked_mlp', 'simple_rnn', 'wrn', 'iresnet' ]) if cprefix is None: cprefix = '' def gc(name): """Get config value with that name.""" return getattr(config, '%s%s' % (cprefix, name)) def hc(name): """Check whether config exists.""" return hasattr(config, '%s%s' % (cprefix, name)) mnet = None if hc('net_act'): net_act = gc('net_act') net_act = misc.str_to_act(net_act) else: net_act = None def get_val(name): ret = None if hc(name): ret = gc(name) return ret no_bias = get_val('no_bias') dropout_rate = get_val('dropout_rate') specnorm = get_val('specnorm') batchnorm = get_val('batchnorm') no_batchnorm = get_val('no_batchnorm') bn_no_running_stats = get_val('bn_no_running_stats') bn_distill_stats = get_val('bn_distill_stats') # This argument has to be handled during usage of the network and not during # construction. #bn_no_stats_checkpointing = get_val('bn_no_stats_checkpointing') use_bn = None if batchnorm is not None: use_bn = batchnorm elif no_batchnorm is not None: use_bn = not no_batchnorm # If an argument wasn't specified, then we use the default value that # is currently in the constructor. assign = lambda x, y: y if x is None else x if net_type == 'mlp': assert (hc('mlp_arch')) assert (len(in_shape) == 1 and len(out_shape) == 1) # Default keyword arguments of class MLP. 
dkws = misc.get_default_args(MLP.__init__) mnet = MLP( n_in=in_shape[0], n_out=out_shape[0], hidden_layers=misc.str_to_ints(gc('mlp_arch')), activation_fn=assign(net_act, dkws['activation_fn']), use_bias=assign(not no_bias, dkws['use_bias']), no_weights=no_weights, #init_weights=None, dropout_rate=assign(dropout_rate, dkws['dropout_rate']), use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']), use_batch_norm=assign(use_bn, dkws['use_batch_norm']), bn_track_stats=assign(not bn_no_running_stats, dkws['bn_track_stats']), distill_bn_stats=assign(bn_distill_stats, dkws['distill_bn_stats']), #use_context_mod=False, #context_mod_inputs=False, #no_last_layer_context_mod=False, #context_mod_no_weights=False, #context_mod_post_activation=False, #context_mod_gain_offset=False, #out_fn=None, verbose=True, **mnet_kwargs).to(device) elif net_type == 'resnet': assert (len(out_shape) == 1) assert hc('resnet_block_depth') and hc('resnet_channel_sizes') # Default keyword arguments of class ResNet. dkws = misc.get_default_args(ResNet.__init__) mnet = ResNet( in_shape=in_shape, num_classes=out_shape[0], n=gc('resnet_block_depth'), use_bias=assign(not no_bias, dkws['use_bias']), num_feature_maps=misc.str_to_ints(gc('resnet_channel_sizes')), verbose=True, #n=5, no_weights=no_weights, #init_weights=None, use_batch_norm=assign(use_bn, dkws['use_batch_norm']), bn_track_stats=assign(not bn_no_running_stats, dkws['bn_track_stats']), distill_bn_stats=assign(bn_distill_stats, dkws['distill_bn_stats']), #use_context_mod=False, #context_mod_inputs=False, #no_last_layer_context_mod=False, #context_mod_no_weights=False, #context_mod_post_activation=False, #context_mod_gain_offset=False, #context_mod_apply_pixel_wise=False **mnet_kwargs).to(device) elif net_type == 'wrn': assert (len(out_shape) == 1) assert hc('wrn_block_depth') and hc('wrn_widening_factor') # Default keyword arguments of class WRN. dkws = misc.get_default_args(WRN.__init__) mnet = WRN( in_shape=in_shape, num_classes=out_shape[0], n=gc('wrn_block_depth'), use_bias=assign(not no_bias, dkws['use_bias']), #num_feature_maps=misc.str_to_ints(gc('wrn_channel_sizes')), verbose=True, no_weights=no_weights, use_batch_norm=assign(use_bn, dkws['use_batch_norm']), bn_track_stats=assign(not bn_no_running_stats, dkws['bn_track_stats']), distill_bn_stats=assign(bn_distill_stats, dkws['distill_bn_stats']), k=gc('wrn_widening_factor'), use_fc_bias=gc('wrn_use_fc_bias'), dropout_rate=gc('dropout_rate'), #use_context_mod=False, #context_mod_inputs=False, #no_last_layer_context_mod=False, #context_mod_no_weights=False, #context_mod_post_activation=False, #context_mod_gain_offset=False, #context_mod_apply_pixel_wise=False **mnet_kwargs).to(device) elif net_type == 'iresnet': assert (len(out_shape) == 1) assert hc('iresnet_use_fc_bias') and hc('iresnet_channel_sizes') \ and hc('iresnet_blocks_per_group') \ and hc('iresnet_bottleneck_blocks') \ and hc('iresnet_projection_shortcut') # Default keyword arguments of class WRN. 
dkws = misc.get_default_args(ResNetIN.__init__) mnet = ResNetIN( in_shape=in_shape, num_classes=out_shape[0], use_bias=assign(not no_bias, dkws['use_bias']), use_fc_bias=gc('wrn_use_fc_bias'), num_feature_maps=misc.str_to_ints(gc('iresnet_channel_sizes')), blocks_per_group=misc.str_to_ints(gc('iresnet_blocks_per_group')), projection_shortcut=gc('iresnet_projection_shortcut'), bottleneck_blocks=gc('iresnet_bottleneck_blocks'), #cutout_mod=False, no_weights=no_weights, use_batch_norm=assign(use_bn, dkws['use_batch_norm']), bn_track_stats=assign(not bn_no_running_stats, dkws['bn_track_stats']), distill_bn_stats=assign(bn_distill_stats, dkws['distill_bn_stats']), #chw_input_format=False, verbose=True, #use_context_mod=False, #context_mod_inputs=False, #no_last_layer_context_mod=False, #context_mod_no_weights=False, #context_mod_post_activation=False, #context_mod_gain_offset=False, #context_mod_apply_pixel_wise=False **mnet_kwargs).to(device) elif net_type == 'zenke': assert (len(out_shape) == 1) # Default keyword arguments of class ZenkeNet. dkws = misc.get_default_args(ZenkeNet.__init__) mnet = ZenkeNet( in_shape=in_shape, num_classes=out_shape[0], verbose=True, #arch='cifar', no_weights=no_weights, #init_weights=None, dropout_rate=assign(dropout_rate, dkws['dropout_rate']), **mnet_kwargs).to(device) elif net_type == 'bio_conv_net': assert (len(out_shape) == 1) # Default keyword arguments of class BioConvNet. #dkws = misc.get_default_args(BioConvNet.__init__) mnet = BioConvNet( in_shape=in_shape, num_classes=out_shape[0], no_weights=no_weights, #init_weights=None, #use_context_mod=False, #context_mod_inputs=False, #no_last_layer_context_mod=False, #context_mod_no_weights=False, #context_mod_post_activation=False, #context_mod_gain_offset=False, #context_mod_apply_pixel_wise=False **mnet_kwargs).to(device) elif net_type == 'chunked_mlp': assert hc('cmlp_arch') and hc('cmlp_chunk_arch') and \ hc('cmlp_in_cdim') and hc('cmlp_out_cdim') and \ hc('cmlp_cemb_dim') assert len(in_shape) == 1 and len(out_shape) == 1 # Default keyword arguments of class ChunkSqueezer. dkws = misc.get_default_args(ChunkSqueezer.__init__) mnet = ChunkSqueezer( n_in=in_shape[0], n_out=out_shape[0], inp_chunk_dim=gc('cmlp_in_cdim'), out_chunk_dim=gc('cmlp_out_cdim'), cemb_size=gc('cmlp_cemb_dim'), #cemb_init_std=1., red_layers=misc.str_to_ints(gc('cmlp_chunk_arch')), net_layers=misc.str_to_ints(gc('cmlp_arch')), activation_fn=assign(net_act, dkws['activation_fn']), use_bias=assign(not no_bias, dkws['use_bias']), #dynamic_biases=None, no_weights=no_weights, #init_weights=None, dropout_rate=assign(dropout_rate, dkws['dropout_rate']), use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']), use_batch_norm=assign(use_bn, dkws['use_batch_norm']), bn_track_stats=assign(not bn_no_running_stats, dkws['bn_track_stats']), distill_bn_stats=assign(bn_distill_stats, dkws['distill_bn_stats']), verbose=True, **mnet_kwargs).to(device) elif net_type == 'lenet': assert hc('lenet_type') assert len(out_shape) == 1 # Default keyword arguments of class LeNet. dkws = misc.get_default_args(LeNet.__init__) mnet = LeNet( in_shape=in_shape, num_classes=out_shape[0], verbose=True, arch=gc('lenet_type'), no_weights=no_weights, #init_weights=None, dropout_rate=assign(dropout_rate, dkws['dropout_rate']), # TODO Context-mod weights. 
**mnet_kwargs).to(device) else: assert (net_type == 'simple_rnn') assert hc('srnn_rec_layers') and hc('srnn_pre_fc_layers') and \ hc('srnn_post_fc_layers') and hc('srnn_no_fc_out') and \ hc('srnn_rec_type') assert len(in_shape) == 1 and len(out_shape) == 1 if gc('srnn_rec_type') == 'lstm': use_lstm = True else: assert gc('srnn_rec_type') == 'elman' use_lstm = False # Default keyword arguments of class SimpleRNN. dkws = misc.get_default_args(SimpleRNN.__init__) rnn_layers = misc.str_to_ints(gc('srnn_rec_layers')) fc_layers = misc.str_to_ints(gc('srnn_post_fc_layers')) if gc('srnn_no_fc_out'): rnn_layers.append(out_shape[0]) else: fc_layers.append(out_shape[0]) mnet = SimpleRNN(n_in=in_shape[0], rnn_layers=rnn_layers, fc_layers_pre=misc.str_to_ints( gc('srnn_pre_fc_layers')), fc_layers=fc_layers, activation=assign(net_act, dkws['activation']), use_lstm=use_lstm, use_bias=assign(not no_bias, dkws['use_bias']), no_weights=no_weights, verbose=True, **mnet_kwargs).to(device) return mnet
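# Example (illustrative sketch): instantiating the ``'simple_rnn'`` branch
# above. The ``srnn_*`` attributes mirror the config names read in that branch;
# all values are hypothetical (empty strings denote empty layer lists).
#
#     from types import SimpleNamespace
#     import torch
#
#     cfg = SimpleNamespace(srnn_rec_layers='64', srnn_pre_fc_layers='',
#                           srnn_post_fc_layers='', srnn_no_fc_out=False,
#                           srnn_rec_type='lstm')
#     device = torch.device('cpu')
#     mnet = get_mnet_model(cfg, 'simple_rnn', [10], [3], device)
#     # With ``srnn_no_fc_out=False``, the output size 3 is appended to the
#     # fully-connected output layers rather than to the recurrent layers.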
def _generate_networks(config, data_handlers, device, create_hnet=True, create_rnet=False, no_replay=False): """Create the main-net, hypernetwork and recognition network. Args: config: Command-line arguments. data_handlers: List of data handlers, one for each task. Needed to extract the number of inputs/outputs of the main network. And to infer the number of tasks. device: Torch device. create_hnet: Whether a hypernetwork should be constructed. If not, the main network will have trainable weights. create_rnet: Whether a task-recognition autoencoder should be created. no_replay: If the recognition network should be an instance of class MainModel rather than of class RecognitionNet (note, for multitask learning, no replay network is required). Returns: mnet: Main network instance. hnet: Hypernetwork instance. This return value is None if no hypernetwork should be constructed. rnet: RecognitionNet instance. This return value is None if no recognition network should be constructed. """ num_tasks = len(data_handlers) n_x = data_handlers[0].in_shape[0] n_y = data_handlers[0].out_shape[0] if config.multi_head: n_y = n_y * num_tasks main_arch = misc.str_to_ints(config.main_arch) main_shapes = MainNetwork.weight_shapes(n_in=n_x, n_out=n_y, hidden_layers=main_arch) mnet = MainNetwork(main_shapes, activation_fn=misc.str_to_act(config.main_act), use_bias=True, no_weights=create_hnet).to(device) if create_hnet: hnet_arch = misc.str_to_ints(config.hnet_arch) hnet = HyperNetwork(main_shapes, num_tasks, layers=hnet_arch, te_dim=config.emb_size, activation_fn=misc.str_to_act( config.hnet_act)).to(device) init_params = list(hnet.parameters()) else: hnet = None init_params = list(mnet.parameters()) if create_rnet: ae_arch = misc.str_to_ints(config.ae_arch) if no_replay: rnet_shapes = MainNetwork.weight_shapes(n_in=n_x, n_out=num_tasks, hidden_layers=ae_arch, use_bias=True) rnet = MainNetwork(rnet_shapes, activation_fn=misc.str_to_act(config.ae_act), use_bias=True, no_weights=False, dropout_rate=-1, out_fn=lambda x: F.softmax(x, dim=1)) else: rnet = RecognitionNet(n_x, num_tasks, dim_z=config.ae_dim_z, enc_layers=ae_arch, activation_fn=misc.str_to_act(config.ae_act), use_bias=True).to(device) init_params += list(rnet.parameters()) else: rnet = None ### Initialize network weights. for W in init_params: if W.ndimension() == 1: # Bias vector. torch.nn.init.constant_(W, 0) elif config.normal_init: torch.nn.init.normal_(W, mean=0, std=config.std_normal_init) else: torch.nn.init.xavier_uniform_(W) # The task embeddings are initialized differently. if create_hnet: for temb in hnet.get_task_embs(): torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb) if config.use_hyperfan_init: hnet.apply_hyperfan_init(temb_var=config.std_normal_temb**2) return mnet, hnet, rnet
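# Example (illustrative sketch): creating a main network and its hypernetwork,
# but no recognition network. The attribute names follow the config values read
# above; the values themselves are hypothetical.
#
#     from types import SimpleNamespace
#     import torch
#
#     cfg = SimpleNamespace(multi_head=True, main_arch='100,100',
#                           main_act='relu', hnet_arch='50,50', hnet_act='relu',
#                           emb_size=8, normal_init=False, std_normal_temb=1.,
#                           use_hyperfan_init=False)
#     dhandlers = [SimpleNamespace(in_shape=[10], out_shape=[2])
#                  for _ in range(3)]
#     device = torch.device('cpu')
#     mnet, hnet, rnet = _generate_networks(cfg, dhandlers, device,
#                                           create_hnet=True, create_rnet=False)
#     assert rnet is None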
def get_hnet_model(config, num_tasks, device, mnet_shapes, cprefix=None, no_weights=False, no_tembs=False, temb_size=None): """Generate a hypernetwork instance. A helper to generate the hypernetwork according to the given the user configurations. .. deprecated:: 1.0 Please use function :func:`get_hypernet` instead. As this function creates deprecated hypernetworks. Args: config (argparse.Namespace): Command-line arguments. .. note:: The function expects command-line arguments available according to the function :func:`utils.cli_args.hypernet_args`. num_tasks (int): The number of task embeddings the hypernetwork should have. device: PyTorch device. mnet_shapes: Dimensions of the weight tensors of the main network. See main net argument :attr:`mnets.mnet_interface.MainNetInterface.param_shapes`. cprefix (str, optional): A prefix of the config names. It might be, that the config names used in this method are prefixed, since several hypernetworks should be generated (e.g., :code:`cprefix='gen_'` or ``'dis_'`` when training a GAN). Also see docstring of parameter ``prefix`` in function :func:`utils.cli_args.hypernet_args`. no_weights (bool): Whether the hyper network should be generated without internal weights (excluding task embeddings). no_tembs (bool): Whether the hypernetwork should be generated without internally maintained task embeddings. temb_size (int, optional): If user config should be overwritten, then this option can be used to specify the dimensionality of task embeddings. Returns: The created hypernet model. """ warn('Please use function "utils.sim_utils.get_hypernet" instead. As ' +\ 'this function creates deprecated hypernetworks.', DeprecationWarning) if cprefix is None: cprefix = '' def gc(name): """Get config value with that name.""" return getattr(config, '%s%s' % (cprefix, name)) hyper_chunks = misc.str_to_ints(gc('hyper_chunks')) assert(len(hyper_chunks) in [1,2,3]) if len(hyper_chunks) == 1: hyper_chunks = hyper_chunks[0] hnet_arch = misc.str_to_ints(gc('hnet_arch')) sa_hnet_filters = misc.str_to_ints(gc('sa_hnet_filters')) sa_hnet_kernels = misc.str_to_ints(gc('sa_hnet_kernels')) sa_hnet_attention_layers = misc.str_to_ints(gc('sa_hnet_attention_layers')) hnet_act = misc.str_to_act(gc('hnet_act')) if temb_size is None: temb_size = gc('temb_size') if isinstance(hyper_chunks, list): # Chunked self-attention hypernet raise NotImplementedError('Not publicly available') elif hyper_chunks != -1: # Chunked fully-connected hypernet hnet = ChunkedHyperNetworkHandler(mnet_shapes, num_tasks, chunk_dim=hyper_chunks, layers=hnet_arch, activation_fn=hnet_act, te_dim=temb_size, no_te_embs=no_tembs, ce_dim=gc('emb_size'), dropout_rate=gc('hnet_dropout_rate'), noise_dim=gc('hnet_noise_dim'), no_weights=no_weights, temb_std=gc('temb_std')).to(device) else: # Fully-connected hypernet. hnet = HyperNetwork(mnet_shapes, num_tasks, layers=hnet_arch, te_dim=temb_size, no_te_embs=no_tembs, activation_fn=hnet_act, dropout_rate=gc('hnet_dropout_rate'), noise_dim=gc('hnet_noise_dim'), no_weights=no_weights, temb_std=gc('temb_std')).to(device) return hnet
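# Example (illustrative sketch): a rough modern counterpart to this deprecated
# helper is :func:`get_hypernet` with ``net_type='hmlp'``. Assuming ``device``,
# ``num_tasks`` and a main network ``mnet`` are already defined, and with
# hypothetical architecture/embedding values:
#
#     from types import SimpleNamespace
#
#     cfg = SimpleNamespace(hmlp_arch='100,100', cond_emb_size=32)
#     hnet = get_hypernet(cfg, device, 'hmlp', mnet.param_shapes,
#                         num_conds=num_tasks)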
def generate_replay_networks(config, data_handlers, device,
                             create_rp_hnet=True, only_train_replay=False):
    """Create a replay model that consists of either an encoder/decoder or a
    discriminator/generator pair. Additionally, this method manages the
    creation of a hypernetwork for the generator/decoder.

    The following important configurations are determined in order to create
    the replay model:

    * input, output and hidden layer dimensions of the encoder/decoder.
    * architecture, chunk- and task-embedding details of the decoder's
      hypernetwork.

    .. note::
        This module also handles the initialisation of the weights of either
        the classifier or its hypernetwork. This will change in the near
        future.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network. And to
            infer the number of tasks.
        device: Torch device.
        create_rp_hnet: Whether a hypernetwork for the replay should be
            constructed. If not, the decoder/generator will have trainable
            weights on its own.
        only_train_replay: We normally do not train on the last task since we
            do not need to replay this last task's data. But if we want a
            replay method to be able to generate data from all tasks, then we
            set this option to ``True``.

    Returns:
        (tuple): Tuple containing:

        - **enc**: Encoder/discriminator network instance.
        - **dec**: Decoder/generator network instance.
        - **dec_hnet**: Hypernetwork instance for the decoder/generator. This
          return value is None if no hypernetwork should be constructed.
    """
    if config.replay_method == 'gan':
        n_out = 1
    else:
        n_out = config.latent_dim * 2

    n_in = data_handlers[0].in_shape[0]
    pd = config.padding * 2

    if config.experiment == "splitMNIST":
        n_in = n_in * n_in
    else:  # permutedMNIST
        n_in = (n_in + pd) * (n_in + pd)

    config.input_dim = n_in
    if config.experiment == "splitMNIST":
        if config.single_class_replay:
            config.out_dim = 1
        else:
            config.out_dim = 2
    else:  # permutedMNIST
        config.out_dim = 10

    if config.infer_task_id:
        # Task inference network.
        config.out_dim = 1

    # Build the encoder.
    print('For the replay encoder/discriminator: ')
    enc_arch = misc.str_to_ints(config.enc_fc_arch)
    enc = MLP(n_in=n_in, n_out=n_out, hidden_layers=enc_arch,
              activation_fn=misc.str_to_act(config.enc_net_act),
              dropout_rate=config.enc_dropout_rate,
              no_weights=False).to(device)
    print('Constructed MLP with shapes: ', enc.param_shapes)
    init_params = list(enc.parameters())

    # Build the decoder.
    print('For the replay decoder/generator: ')
    dec_arch = misc.str_to_ints(config.dec_fc_arch)
    # Add dimensions for the conditional input.
    n_out = config.latent_dim
    if config.conditional_replay:
        n_out += config.conditional_dim
    dec = MLP(n_in=n_out, n_out=n_in, hidden_layers=dec_arch,
              activation_fn=misc.str_to_act(config.dec_net_act),
              use_bias=True,
              no_weights=config.rp_beta > 0,
              dropout_rate=config.dec_dropout_rate).to(device)
    print('Constructed MLP with shapes: ', dec.param_shapes)

    config.num_weights_enc = \
        MainNetInterface.shapes_to_num_weights(enc.param_shapes)
    config.num_weights_dec = \
        MainNetInterface.shapes_to_num_weights(dec.param_shapes)
    config.num_weights_rp_net = config.num_weights_enc + config.num_weights_dec

    # We do not need a replay model for the last task, unless we explicitly
    # want to train on it.
    if only_train_replay:
        subtr = 0
    else:
        subtr = 1

    num_embeddings = config.num_tasks - subtr if config.num_tasks > 1 else 1

    if config.single_class_replay:
        # We do not need a replay model for the last task.
        if config.num_tasks > 1:
            num_embeddings = config.out_dim * (config.num_tasks - subtr)
        else:
            num_embeddings = config.out_dim * config.num_tasks

    config.num_embeddings = num_embeddings

    # Build the decoder hypernetwork.
    if create_rp_hnet:
        print('For the decoder/generator hypernetwork: ')
        d_hnet = sim_utils.get_hnet_model(config, num_embeddings, device,
                                          dec.hyper_shapes_learned,
                                          cprefix='rp_')
        init_params += list(d_hnet.parameters())

        config.num_weights_rp_hyper_net = sum(p.numel()
            for p in d_hnet.parameters() if p.requires_grad)
        config.compression_ratio_rp = config.num_weights_rp_hyper_net / \
            config.num_weights_dec
        print('Created replay hypernetwork with ratio: ',
              config.compression_ratio_rp)
        if config.compression_ratio_rp > 1:
            print('Note that the compression ratio is computed compared to ' +
                  'the current target network,\nthis might not be directly ' +
                  'comparable with the number of parameters of methods we ' +
                  'compare against.')
    else:
        num_embeddings = config.num_tasks - subtr
        d_hnet = None
        init_params += list(dec.parameters())
        config.num_weights_rp_hyper_net = 0
        config.compression_ratio_rp = 0

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if create_rp_hnet:
        for temb in d_hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

        if hasattr(d_hnet, 'chunk_embeddings'):
            for emb in d_hnet.chunk_embeddings:
                torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    return enc, dec, d_hnet
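# Example (illustrative sketch): building only the encoder/decoder pair
# (``create_rp_hnet=False``), so none of the ``rp_``-prefixed hypernet options
# are needed. Attribute names mirror the config values read above; all values
# are hypothetical.
#
#     from types import SimpleNamespace
#     import torch
#
#     cfg = SimpleNamespace(replay_method='vae', latent_dim=20, padding=0,
#                           experiment='splitMNIST', single_class_replay=False,
#                           infer_task_id=False, enc_fc_arch='400,400',
#                           enc_net_act='relu', enc_dropout_rate=-1,
#                           dec_fc_arch='400,400', dec_net_act='relu',
#                           dec_dropout_rate=-1, conditional_replay=True,
#                           conditional_dim=5, rp_beta=0, num_tasks=5)
#     dhandlers = [SimpleNamespace(in_shape=[28]) for _ in range(cfg.num_tasks)]
#     device = torch.device('cpu')
#     enc, dec, d_hnet = generate_replay_networks(cfg, dhandlers, device,
#                                                 create_rp_hnet=False)
#     assert d_hnet is None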
def generate_gauss_networks(config, logger, data_handlers, device, no_mnet_weights=None, create_hnet=True, in_shape=None, out_shape=None, net_type='mlp', non_gaussian=False): """Create main network and potentially the corresponding hypernetwork. The function will first create a normal MLP and then convert it into a network with Gaussian weight distribution by using the wrapper :class:`probabilistic.gauss_mnet_interface.GaussianBNNWrapper`. This function also takes care of weight initialization. Args: config: Command-line arguments. logger: Console (and file) logger. data_handlers: List of data handlers, one for each task. Needed to extract the number of inputs/outputs of the main network. And to infer the number of tasks. device: Torch device. no_mnet_weights (bool, optional): Whether the main network should not have trainable weights. If left unspecified, then the main network will only have trainable weights if ``create_hnet`` is ``False``. create_hnet (bool): Whether a hypernetwork should be constructed. in_shape (list, optional): Input shape that is passed to function :func:`utils.sim_utils.get_mnet_model` as argument ``in_shape``. If not specified, it is set to ``[data_handlers[0].in_shape[0]]``. out_shape (list, optional): Output shape that is passed to function :func:`utils.sim_utils.get_mnet_model` as argument ``out_shape``. If not specified, it is set to ``[data_handlers[0].out_shape[0]]``. net_type (str): See argument ``net_type`` of function :func:`utils.sim_utils.get_mnet_model`. non_gaussian (bool): If ``True``, then the main network will not be converted into a Gaussian network. Hence, networks remain deterministic. Returns: (tuple): Tuple containing: - **mnet**: Main network instance. - **hnet** (optional): Hypernetwork instance. This return value is ``None`` if no hypernetwork should be constructed. """ assert not hasattr(config, 'mean_only') or config.mean_only == non_gaussian assert not non_gaussian or not config.local_reparam_trick assert not non_gaussian or not config.hyper_gauss_init num_tasks = len(data_handlers) # Should be set, except for regression. if in_shape is None or out_shape is None: assert in_shape is None and out_shape is None assert net_type == 'mlp' assert hasattr(config, 'multi_head') n_x = data_handlers[0].in_shape[0] n_y = data_handlers[0].out_shape[0] if config.multi_head: n_y = n_y * num_tasks in_shape = [n_x] out_shape = [n_y] ### Main network. logger.info('Creating main network ...') if no_mnet_weights is None: no_mnet_weights = create_hnet if config.local_reparam_trick: if net_type != 'mlp': raise NotImplementedError('The local reparametrization trick is ' + 'only implemented for MLPs so far!') assert len(in_shape) == 1 and len(out_shape) == 1 mlp_arch = utils.str_to_ints(config.mlp_arch) net_act = utils.str_to_act(config.net_act) mnet = GaussianMLP(n_in=in_shape[0], n_out=out_shape[0], hidden_layers=mlp_arch, activation_fn=net_act, use_bias=not config.no_bias, no_weights=no_mnet_weights).to(device) else: mnet_kwargs = {} if net_type == 'iresnet': mnet_kwargs['cutout_mod'] = True mnet = sutils.get_mnet_model(config, net_type, in_shape, out_shape, device, no_weights=no_mnet_weights, **mnet_kwargs) # Initiaize main net weights, if any. assert (not hasattr(config, 'custom_network_init')) mnet.custom_init(normal_init=config.normal_init, normal_std=config.std_normal_init, zero_bias=True) # Convert main net into Gaussian BNN. 
orig_mnet = mnet if not non_gaussian: mnet = GaussianBNNWrapper(mnet, no_mean_reinit=config.keep_orig_init, logvar_encoding=config.use_logvar_enc, apply_rho_offset=True, is_radial=config.radial_bnn).to(device) else: logger.debug('Created main network will not be converted into a ' + 'Gaussian main network.') ### Hypernet. hnet = None if create_hnet: logger.info('Creating hypernetwork ...') chunk_shapes, num_per_chunk, assembly_fct = None, None, None if config.hnet_type == 'structured_hmlp': if net_type == 'resnet': chunk_shapes, num_per_chunk, orig_assembly_fct = \ resnet_chunking(orig_mnet, gcd_chunking=config.shmlp_gcd_chunking) elif net_type == 'wrn': chunk_shapes, num_per_chunk, orig_assembly_fct = \ wrn_chunking(orig_mnet, gcd_chunking=config.shmlp_gcd_chunking, ignore_bn_weights=False, ignore_out_weights=False) else: raise NotImplementedError( '"structured_hmlp" not implemented ' + 'for network of type %s.' % net_type) if non_gaussian: assembly_fct = orig_assembly_fct else: chunk_shapes = chunk_shapes + chunk_shapes num_per_chunk = num_per_chunk + num_per_chunk def assembly_fct_gauss(list_of_chunks): n = len(list_of_chunks) mean_chunks = list_of_chunks[:n // 2] rho_chunks = list_of_chunks[n // 2:] return orig_assembly_fct(mean_chunks) + \ orig_assembly_fct(rho_chunks) assembly_fct = assembly_fct_gauss # For now, we either produce all or no weights with the hypernet. # Note, it can be that the mnet was produced with internal weights. assert mnet.hyper_shapes_learned is None or \ len(mnet.param_shapes) == len(mnet.hyper_shapes_learned) hnet = sutils.get_hypernet(config, device, config.hnet_type, mnet.param_shapes, num_tasks, shmlp_chunk_shapes=chunk_shapes, shmlp_num_per_chunk=num_per_chunk, shmlp_assembly_fct=assembly_fct) if config.hnet_out_masking != 0: logger.info('Generating binary masks to select task-specific ' + 'subnetworks from hypernetwork.') # Add a wrapper around the hypernpetwork that masks its outputs # using a task-specific binary mask layer per layer. Note that # output weights are not masked. # Ensure that masks are kind of deterministic for a given hyper- # param config/task. mask_gen = torch.Generator() mask_gen = mask_gen.manual_seed(42) # Generate a random binary mask per task. assert len(mnet.param_shapes) == len(hnet.target_shapes) hnet_out_masks = [] for tid in range(config.num_tasks): hnet_out_mask = [] for layer_shapes, is_output in zip(mnet.param_shapes, \ mnet.get_output_weight_mask()): layer_mask = torch.ones(layer_shapes) if is_output is None: # We only mask weights that are not output weights. layer_mask = torch.rand(layer_shapes, generator=mask_gen) layer_mask[layer_mask > config.hnet_out_masking] = 1 layer_mask[layer_mask <= config.hnet_out_masking] = 0 hnet_out_mask.append(layer_mask) hnet_out_masks.append(hnet_out_mask) hnet_out_masks = hnet.convert_out_format(hnet_out_masks, 'sequential', 'flattened') def hnet_out_masking_func(hnet_out_int, uncond_input=None, cond_input=None, cond_id=None): assert isinstance(cond_id, (int, list)) if isinstance(cond_id, int): cond_id = [cond_id] hnet_out_int[hnet_out_masks[cond_id, :] == 0] = 0 return hnet_out_int def hnet_inp_handler(uncond_input=None, cond_input=None, cond_id=None): # Identity return uncond_input, cond_input, cond_id hnet = HPerturbWrapper(hnet, output_handler=hnet_out_masking_func, input_handler=hnet_inp_handler) #if config.hnet_type == 'structured_hmlp': # print(num_per_chunk) # for ii, int_hnet in enumerate(hnet.internal_hnets): # print(' Internal hnet %d with %d outputs.' 
% \ # (ii, int_hnet.num_outputs)) ### Initialize hypernetwork. if not config.hyper_gauss_init: apply_custom_hnet_init(config, logger, hnet) else: # Initialize task embeddings, if any. hnet_helpers.init_conditional_embeddings( hnet, normal_std=config.std_normal_temb) gauss_hyperfan_init(hnet, mnet=mnet, use_xavier=True, cond_var=config.std_normal_temb**2, keep_hyperfan_mean=config.keep_orig_init) return mnet, hnet
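# Example (illustrative sketch): in a typical probabilistic CL run, ``config``
# comes from the simulation's argument parser and already contains the options
# consumed by ``get_mnet_model`` and ``get_hypernet``. The call then reduces to
# (variable names hypothetical):
#
#     mnet, hnet = generate_gauss_networks(config, logger, dhandlers, device,
#                                          create_hnet=True, net_type='resnet')
#     # Unless ``non_gaussian=True``, ``mnet`` is a ``GaussianBNNWrapper``
#     # around the created network and ``hnet`` produces its mean and rho
#     # parameters.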