def _parse_context_mod_args(cm_kwargs):
    """Parse context-modulation arguments for a class.

    This function first loads the default values of all context-mod arguments
    passed to class :class:`mnets.mlp.MLP`. If any of these arguments does not
    occur in the dictionary ``cm_kwargs``, it is added using the default value
    from class :class:`mnets.mlp.MLP`.

    Args:
        cm_kwargs (dict): A dictionary that is modified in place (i.e.,
            missing keys are added).

    Returns:
        (list): A list of key names from ``cm_kwargs`` that are not related to
        context-modulation, i.e., unknown to this function.
    """
    from mnets.mlp import MLP

    # All context-mod related arguments in `mnets.mlp.MLP.__init__`.
    cm_keys = ['use_context_mod',
               'context_mod_inputs',
               'no_last_layer_context_mod',
               'context_mod_no_weights',
               'context_mod_post_activation',
               'context_mod_gain_offset',
               'context_mod_gain_softplus']

    default_cm_kwargs = misc.get_default_args(MLP.__init__)

    # Fill in missing context-mod keys with their `MLP` default values.
    for k in cm_keys:
        assert k in default_cm_kwargs.keys()
        if k not in cm_kwargs.keys():
            cm_kwargs[k] = default_cm_kwargs[k]

    # Extract keyword arguments that do not belong to context-mod.
    unknown_kwargs = []
    for k in cm_kwargs.keys():
        if k not in default_cm_kwargs.keys():
            unknown_kwargs.append(k)

    return unknown_kwargs

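# Example (sketch) of how `_parse_context_mod_args` behaves on a partial
# dictionary. The key 'lr' is a hypothetical non-context-mod entry, used only
# for illustration:
#
#     cm_kwargs = {'use_context_mod': True, 'lr': 1e-3}
#     unknown = _parse_context_mod_args(cm_kwargs)
#     # `cm_kwargs` now additionally contains the remaining context-mod keys
#     # (e.g., 'context_mod_inputs') with their `MLP` default values.
#     # `unknown` contains ['lr'], since 'lr' is not an argument of
#     # `MLP.__init__`.
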
def get_hypernet(config, device, net_type, target_shapes, num_conds,
                 no_cond_weights=False, no_uncond_weights=False,
                 uncond_in_size=0, shmlp_chunk_shapes=None,
                 shmlp_num_per_chunk=None, shmlp_assembly_fct=None,
                 verbose=True, cprefix=None):
    """Generate a hypernetwork instance.

    A helper to generate a hypernetwork according to the given user
    configuration.

    Args:
        config (argparse.Namespace): Command-line arguments.

            Note:
                The function expects command-line arguments available
                according to the function :func:`utils.cli_args.hnet_args`.
        device: PyTorch device.
        net_type (str): The type of network. The following options are
            available:

            - ``'hmlp'``
            - ``'chunked_hmlp'``
            - ``'structured_hmlp'``
            - ``'hdeconv'``
            - ``'chunked_hdeconv'``
        target_shapes (list): See argument ``target_shapes`` of
            :class:`hnets.mlp_hnet.HMLP`.
        num_conds (int): Number of conditions that should be known to the
            hypernetwork.
        no_cond_weights (bool): See argument ``no_cond_weights`` of
            :class:`hnets.mlp_hnet.HMLP`.
        no_uncond_weights (bool): See argument ``no_uncond_weights`` of
            :class:`hnets.mlp_hnet.HMLP`.
        uncond_in_size (int): See argument ``uncond_in_size`` of
            :class:`hnets.mlp_hnet.HMLP`.
        shmlp_chunk_shapes (list, optional): Argument ``chunk_shapes`` of
            :class:`hnets.structured_mlp_hnet.StructuredHMLP`.
        shmlp_num_per_chunk (list, optional): Argument ``num_per_chunk`` of
            :class:`hnets.structured_mlp_hnet.StructuredHMLP`.
        shmlp_assembly_fct (func, optional): Argument ``assembly_fct`` of
            :class:`hnets.structured_mlp_hnet.StructuredHMLP`.
        verbose (bool): Argument ``verbose`` of :class:`hnets.mlp_hnet.HMLP`.
        cprefix (str, optional): A prefix of the config names. It might be
            that the config names used in this function are prefixed, since
            several hypernetworks should be generated. Also see docstring of
            parameter ``prefix`` in function :func:`utils.cli_args.hnet_args`.

    Returns:
        The created hypernetwork model.
    """
    assert net_type in ['hmlp', 'chunked_hmlp', 'structured_hmlp', 'hdeconv',
                        'chunked_hdeconv']

    hnet = None

    ### FIXME Code almost identically copied from `get_mnet_model` ###
    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get the config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    def hc(name):
        """Check whether a config value with that name exists."""
        return hasattr(config, '%s%s' % (cprefix, name))

    if hc('hnet_net_act'):
        net_act = gc('hnet_net_act')
        net_act = misc.str_to_act(net_act)
    else:
        net_act = None

    def get_val(name):
        """Return the config value with that name, or ``None`` if absent."""
        ret = None
        if hc(name):
            ret = gc(name)
        return ret

    no_bias = get_val('hnet_no_bias')
    dropout_rate = get_val('hnet_dropout_rate')
    specnorm = get_val('hnet_specnorm')
    batchnorm = get_val('hnet_batchnorm')
    no_batchnorm = get_val('hnet_no_batchnorm')
    #bn_no_running_stats = get_val('hnet_bn_no_running_stats')
    #bn_distill_stats = get_val('hnet_bn_distill_stats')

    use_bn = None
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm

    # If an argument wasn't specified, then we use the default value that
    # is currently set in the constructor.
    assign = lambda x, y: y if x is None else x
    ### FIXME Code copied until here ###

    if hc('hmlp_arch'):
        hmlp_arch_is_list = False
        hmlp_arch = gc('hmlp_arch')
        if ';' in hmlp_arch:
            # Semicolons separate per-chunk architectures, which only the
            # `structured_hmlp` type supports.
            hmlp_arch_is_list = True
            if net_type != 'structured_hmlp':
                raise ValueError('Option "%shmlp_arch" may only ' % (cprefix) +
                                 'contain semicolons for network type ' +
                                 '"structured_hmlp"!')
            hmlp_arch = [misc.str_to_ints(ar) for ar in hmlp_arch.split(';')]
        else:
            hmlp_arch = misc.str_to_ints(hmlp_arch)

    if hc('chunk_emb_size'):
        chunk_emb_size = gc('chunk_emb_size')
        chunk_emb_size = misc.str_to_ints(chunk_emb_size)
        if len(chunk_emb_size) == 1:
            chunk_emb_size = chunk_emb_size[0]
        else:
            if net_type != 'structured_hmlp':
                raise ValueError('Option "%schunk_emb_size" may ' % (cprefix) +
                                 'only contain multiple values for network ' +
                                 'type "structured_hmlp"!')

    if hc('cond_emb_size'):
        cond_emb_size = gc('cond_emb_size')
    else:
        cond_emb_size = 0

    if net_type == 'hmlp':
        assert hc('hmlp_arch')

        # Default keyword arguments of class HMLP.
        dkws = misc.get_default_args(HMLP.__init__)

        hnet = HMLP(target_shapes,
            uncond_in_size=uncond_in_size,
            cond_in_size=cond_emb_size,
            layers=hmlp_arch,
            verbose=verbose,
            activation_fn=assign(net_act, dkws['activation_fn']),
            use_bias=assign(not no_bias, dkws['use_bias']),
            no_uncond_weights=no_uncond_weights,
            no_cond_weights=no_cond_weights,
            num_cond_embs=num_conds,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']),
            use_batch_norm=assign(use_bn, dkws['use_batch_norm'])).to(device)

    elif net_type == 'chunked_hmlp':
        assert hc('hmlp_arch')
        assert hc('chmlp_chunk_size')
        assert hc('chunk_emb_size')
        cond_chunk_embs = get_val('use_cond_chunk_embs')

        # Default keyword arguments of class ChunkedHMLP.
        dkws = misc.get_default_args(ChunkedHMLP.__init__)

        hnet = ChunkedHMLP(target_shapes, gc('chmlp_chunk_size'),
            chunk_emb_size=chunk_emb_size,
            cond_chunk_embs=assign(cond_chunk_embs, dkws['cond_chunk_embs']),
            uncond_in_size=uncond_in_size,
            cond_in_size=cond_emb_size,
            layers=hmlp_arch,
            verbose=verbose,
            activation_fn=assign(net_act, dkws['activation_fn']),
            use_bias=assign(not no_bias, dkws['use_bias']),
            no_uncond_weights=no_uncond_weights,
            no_cond_weights=no_cond_weights,
            num_cond_embs=num_conds,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']),
            use_batch_norm=assign(use_bn, dkws['use_batch_norm'])).to(device)

    elif net_type == 'structured_hmlp':
        assert hc('hmlp_arch')
        assert hc('chunk_emb_size')
        cond_chunk_embs = get_val('use_cond_chunk_embs')

        assert shmlp_chunk_shapes is not None and \
            shmlp_num_per_chunk is not None and \
            shmlp_assembly_fct is not None

        # Default keyword arguments of class StructuredHMLP.
        dkws = misc.get_default_args(StructuredHMLP.__init__)
        dkws_hmlp = misc.get_default_args(HMLP.__init__)

        # Build one keyword-argument dict per internal `HMLP`.
        shmlp_hmlp_kwargs = []
        if not hmlp_arch_is_list:
            hmlp_arch = [hmlp_arch]
        for arch in hmlp_arch:
            shmlp_hmlp_kwargs.append({
                'layers': arch,
                'activation_fn': assign(net_act, dkws_hmlp['activation_fn']),
                'use_bias': assign(not no_bias, dkws_hmlp['use_bias']),
                'dropout_rate': assign(dropout_rate,
                                       dkws_hmlp['dropout_rate']),
                'use_spectral_norm': \
                    assign(specnorm, dkws_hmlp['use_spectral_norm']),
                'use_batch_norm': assign(use_bn, dkws_hmlp['use_batch_norm'])
            })
        if len(shmlp_hmlp_kwargs) == 1:
            shmlp_hmlp_kwargs = shmlp_hmlp_kwargs[0]

        hnet = StructuredHMLP(target_shapes, shmlp_chunk_shapes,
            shmlp_num_per_chunk, chunk_emb_size, shmlp_hmlp_kwargs,
            shmlp_assembly_fct,
            cond_chunk_embs=assign(cond_chunk_embs, dkws['cond_chunk_embs']),
            uncond_in_size=uncond_in_size,
            cond_in_size=cond_emb_size,
            verbose=verbose,
            no_uncond_weights=no_uncond_weights,
            no_cond_weights=no_cond_weights,
            num_cond_embs=num_conds).to(device)

    elif net_type == 'hdeconv':
        # HDeconv
        raise NotImplementedError
    else:
        assert net_type == 'chunked_hdeconv'
        # ChunkedHDeconv
        raise NotImplementedError

    return hnet

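# Example (sketch): generating a task-conditioned hypernetwork that produces
# the weights of an existing main network. The names `mnet` (some `mnets`
# instance), `config` (parsed via `utils.cli_args.hnet_args`) and `num_tasks`
# are placeholders for illustration:
#
#     hnet = get_hypernet(config, device, 'hmlp', mnet.param_shapes,
#                         num_tasks)
#     weights = hnet.forward(cond_id=0)  # Main-net weights for condition 0.
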
def get_mnet_model(config, net_type, in_shape, out_shape, device,
                   cprefix=None, no_weights=False, **mnet_kwargs):
    """Generate a main network instance.

    A helper to generate a main network according to the given user
    configuration.

    .. note::
        Generation of networks with context-modulation is not yet supported,
        since there is no global argument set in :mod:`utils.cli_args` yet.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available
                according to the function
                :func:`utils.cli_args.main_net_args`.
        net_type (str): The type of network. The following options are
            available:

            - ``mlp``: :class:`mnets.mlp.MLP`
            - ``lenet``: :class:`mnets.lenet.LeNet`
            - ``resnet``: :class:`mnets.resnet.ResNet`
            - ``wrn``: :class:`mnets.wide_resnet.WRN`
            - ``iresnet``: :class:`mnets.resnet_imgnet.ResNetIN`
            - ``zenke``: :class:`mnets.zenkenet.ZenkeNet`
            - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet`
            - ``chunked_mlp``: :class:`mnets.chunk_squeezer.ChunkSqueezer`
            - ``simple_rnn``: :class:`mnets.simple_rnn.SimpleRNN`
        in_shape (list): Shape of network inputs. Can be ``None`` if not
            required by network type.

            For instance: For an MLP network :class:`mnets.mlp.MLP` with 100
            input neurons it should be :code:`in_shape=[100]`.
        out_shape (list): Shape of network outputs. See ``in_shape`` for more
            details.
        device: PyTorch device.
        cprefix (str, optional): A prefix of the config names. It might be
            that the config names used in this method are prefixed, since
            several main networks should be generated (e.g.,
            :code:`cprefix='gen_'` or ``'dis_'`` when training a GAN). Also
            see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.main_net_args`.
        no_weights (bool): Whether the main network should be generated
            without weights.
        **mnet_kwargs: Additional keyword arguments that will be passed to
            the main network constructor.

    Returns:
        The created main network model.
    """
    assert net_type in ['mlp', 'lenet', 'resnet', 'zenke', 'bio_conv_net',
                        'chunked_mlp', 'simple_rnn', 'wrn', 'iresnet']

    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get the config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    def hc(name):
        """Check whether a config value with that name exists."""
        return hasattr(config, '%s%s' % (cprefix, name))

    mnet = None

    if hc('net_act'):
        net_act = gc('net_act')
        net_act = misc.str_to_act(net_act)
    else:
        net_act = None

    def get_val(name):
        """Return the config value with that name, or ``None`` if absent."""
        ret = None
        if hc(name):
            ret = gc(name)
        return ret

    no_bias = get_val('no_bias')
    dropout_rate = get_val('dropout_rate')
    specnorm = get_val('specnorm')
    batchnorm = get_val('batchnorm')
    no_batchnorm = get_val('no_batchnorm')
    bn_no_running_stats = get_val('bn_no_running_stats')
    bn_distill_stats = get_val('bn_distill_stats')
    # This argument has to be handled during usage of the network and not
    # during construction.
    #bn_no_stats_checkpointing = get_val('bn_no_stats_checkpointing')

    use_bn = None
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm

    # If an argument wasn't specified, then we use the default value that
    # is currently set in the constructor.
    assign = lambda x, y: y if x is None else x

    if net_type == 'mlp':
        assert hc('mlp_arch')
        assert len(in_shape) == 1 and len(out_shape) == 1

        # Default keyword arguments of class MLP.
        dkws = misc.get_default_args(MLP.__init__)

        mnet = MLP(n_in=in_shape[0], n_out=out_shape[0],
            hidden_layers=misc.str_to_ints(gc('mlp_arch')),
            activation_fn=assign(net_act, dkws['activation_fn']),
            use_bias=assign(not no_bias, dkws['use_bias']),
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']),
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #out_fn=None,
            verbose=True,
            **mnet_kwargs).to(device)

    elif net_type == 'resnet':
        assert len(out_shape) == 1
        assert hc('resnet_block_depth') and hc('resnet_channel_sizes')

        # Default keyword arguments of class ResNet.
        dkws = misc.get_default_args(ResNet.__init__)

        mnet = ResNet(in_shape=in_shape, num_classes=out_shape[0],
            n=gc('resnet_block_depth'),
            use_bias=assign(not no_bias, dkws['use_bias']),
            num_feature_maps=misc.str_to_ints(gc('resnet_channel_sizes')),
            verbose=True,
            no_weights=no_weights,
            #init_weights=None,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
            **mnet_kwargs).to(device)

    elif net_type == 'wrn':
        assert len(out_shape) == 1
        assert hc('wrn_block_depth') and hc('wrn_widening_factor')

        # Default keyword arguments of class WRN.
        dkws = misc.get_default_args(WRN.__init__)

        mnet = WRN(in_shape=in_shape, num_classes=out_shape[0],
            n=gc('wrn_block_depth'),
            use_bias=assign(not no_bias, dkws['use_bias']),
            #num_feature_maps=misc.str_to_ints(gc('wrn_channel_sizes')),
            verbose=True,
            no_weights=no_weights,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            k=gc('wrn_widening_factor'),
            use_fc_bias=gc('wrn_use_fc_bias'),
            dropout_rate=gc('dropout_rate'),
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
            **mnet_kwargs).to(device)

    elif net_type == 'iresnet':
        assert len(out_shape) == 1
        assert hc('iresnet_use_fc_bias') and hc('iresnet_channel_sizes') \
            and hc('iresnet_blocks_per_group') \
            and hc('iresnet_bottleneck_blocks') \
            and hc('iresnet_projection_shortcut')

        # Default keyword arguments of class ResNetIN.
        dkws = misc.get_default_args(ResNetIN.__init__)

        mnet = ResNetIN(in_shape=in_shape, num_classes=out_shape[0],
            use_bias=assign(not no_bias, dkws['use_bias']),
            use_fc_bias=gc('iresnet_use_fc_bias'),
            num_feature_maps=misc.str_to_ints(gc('iresnet_channel_sizes')),
            blocks_per_group=misc.str_to_ints(gc('iresnet_blocks_per_group')),
            projection_shortcut=gc('iresnet_projection_shortcut'),
            bottleneck_blocks=gc('iresnet_bottleneck_blocks'),
            #cutout_mod=False,
            no_weights=no_weights,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            #chw_input_format=False,
            verbose=True,
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
            **mnet_kwargs).to(device)

    elif net_type == 'zenke':
        assert len(out_shape) == 1

        # Default keyword arguments of class ZenkeNet.
        dkws = misc.get_default_args(ZenkeNet.__init__)

        mnet = ZenkeNet(in_shape=in_shape, num_classes=out_shape[0],
            verbose=True,
            #arch='cifar',
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            **mnet_kwargs).to(device)

    elif net_type == 'bio_conv_net':
        assert len(out_shape) == 1

        # Default keyword arguments of class BioConvNet.
        #dkws = misc.get_default_args(BioConvNet.__init__)

        mnet = BioConvNet(in_shape=in_shape, num_classes=out_shape[0],
            no_weights=no_weights,
            #init_weights=None,
            #use_context_mod=False,
            #context_mod_inputs=False,
            #no_last_layer_context_mod=False,
            #context_mod_no_weights=False,
            #context_mod_post_activation=False,
            #context_mod_gain_offset=False,
            #context_mod_apply_pixel_wise=False
            **mnet_kwargs).to(device)

    elif net_type == 'chunked_mlp':
        assert hc('cmlp_arch') and hc('cmlp_chunk_arch') and \
            hc('cmlp_in_cdim') and hc('cmlp_out_cdim') and \
            hc('cmlp_cemb_dim')
        assert len(in_shape) == 1 and len(out_shape) == 1

        # Default keyword arguments of class ChunkSqueezer.
        dkws = misc.get_default_args(ChunkSqueezer.__init__)

        mnet = ChunkSqueezer(n_in=in_shape[0], n_out=out_shape[0],
            inp_chunk_dim=gc('cmlp_in_cdim'),
            out_chunk_dim=gc('cmlp_out_cdim'),
            cemb_size=gc('cmlp_cemb_dim'),
            #cemb_init_std=1.,
            red_layers=misc.str_to_ints(gc('cmlp_chunk_arch')),
            net_layers=misc.str_to_ints(gc('cmlp_arch')),
            activation_fn=assign(net_act, dkws['activation_fn']),
            use_bias=assign(not no_bias, dkws['use_bias']),
            #dynamic_biases=None,
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']),
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(not bn_no_running_stats,
                                  dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            verbose=True,
            **mnet_kwargs).to(device)

    elif net_type == 'lenet':
        assert hc('lenet_type')
        assert len(out_shape) == 1

        # Default keyword arguments of class LeNet.
        dkws = misc.get_default_args(LeNet.__init__)

        mnet = LeNet(in_shape=in_shape, num_classes=out_shape[0],
            verbose=True,
            arch=gc('lenet_type'),
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            # TODO Context-mod weights.
            **mnet_kwargs).to(device)

    else:
        assert net_type == 'simple_rnn'
        assert hc('srnn_rec_layers') and hc('srnn_pre_fc_layers') and \
            hc('srnn_post_fc_layers') and hc('srnn_no_fc_out') and \
            hc('srnn_rec_type')
        assert len(in_shape) == 1 and len(out_shape) == 1

        if gc('srnn_rec_type') == 'lstm':
            use_lstm = True
        else:
            assert gc('srnn_rec_type') == 'elman'
            use_lstm = False

        # Default keyword arguments of class SimpleRNN.
        dkws = misc.get_default_args(SimpleRNN.__init__)

        rnn_layers = misc.str_to_ints(gc('srnn_rec_layers'))
        fc_layers = misc.str_to_ints(gc('srnn_post_fc_layers'))
        # Append the output layer to either the recurrent or the
        # fully-connected part of the network.
        if gc('srnn_no_fc_out'):
            rnn_layers.append(out_shape[0])
        else:
            fc_layers.append(out_shape[0])

        mnet = SimpleRNN(n_in=in_shape[0],
            rnn_layers=rnn_layers,
            fc_layers_pre=misc.str_to_ints(gc('srnn_pre_fc_layers')),
            fc_layers=fc_layers,
            activation=assign(net_act, dkws['activation']),
            use_lstm=use_lstm,
            use_bias=assign(not no_bias, dkws['use_bias']),
            no_weights=no_weights,
            verbose=True,
            **mnet_kwargs).to(device)

    return mnet
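
if __name__ == '__main__':
    # Minimal usage sketch. Assumptions (not guaranteed by this module):
    # `utils.cli_args.main_net_args` populates an `argparse` parser with
    # default values for the `mlp` network type, as the docstring of
    # `get_mnet_model` suggests.
    import argparse
    import torch
    from utils import cli_args

    parser = argparse.ArgumentParser()
    cli_args.main_net_args(parser)
    config = parser.parse_args([])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build a fully-connected main network with 784 inputs and 10 outputs.
    mnet = get_mnet_model(config, 'mlp', [784], [10], device)
    print(mnet)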