def generate_classifier(config, data_handlers, device):
    """Create a classifier network.

    Depending on the experiment and method, either a classifier for task
    inference or a classifier that solves the actual task is built. This also
    determines whether the network receives its weights from a hypernetwork
    or maintains its own weights.

    The following important configurations will be determined in order to
    create the classifier: \n
    * in- and output and hidden layer dimensions of the classifier. \n
    * architecture, chunk- and task-embedding details of the hypernetwork.

    See :class:`mnets.mlp.MLP` for details on the network that will be
    created as the classifier.

    .. note::
        This module also handles the initialisation of the weights of either
        the classifier or its hypernetwork. This will change in the near
        future.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network and to
            infer the number of tasks.
        device: Torch device.

    Returns:
        (tuple): Tuple containing:

        - **net**: The classifier network.
        - **class_hnet**: (optional) The classifier's hypernetwork.
    """
    n_in = data_handlers[0].in_shape[0]
    pd = config.padding * 2

    if config.experiment == "splitMNIST":
        n_in = n_in * n_in
    else:  # permutedMNIST
        n_in = (n_in + pd) * (n_in + pd)

    config.input_dim = n_in
    if config.experiment == "splitMNIST":
        if config.class_incremental:
            config.out_dim = 1
        else:
            config.out_dim = 2
    else:  # permutedMNIST
        config.out_dim = 10

    if config.training_task_infer or config.class_incremental:
        # Task inference network.
        config.out_dim = 1

    # Have all output neurons already built up for CL scenario 2.
    if config.cl_scenario != 2:
        n_out = config.out_dim * config.num_tasks
    else:
        n_out = config.out_dim

    if config.training_task_infer or config.class_incremental:
        n_out = config.num_tasks

    # Build the classifier.
    print('For the Classifier: ')
    class_arch = misc.str_to_ints(config.class_fc_arch)
    if config.training_with_hnet:
        no_weights = True
    else:
        no_weights = False

    net = MLP(n_in=n_in, n_out=n_out, hidden_layers=class_arch,
              activation_fn=misc.str_to_act(config.class_net_act),
              dropout_rate=config.class_dropout_rate,
              no_weights=no_weights).to(device)
    print('Constructed MLP with shapes: ', net.param_shapes)

    config.num_weights_class_net = \
        MainNetInterface.shapes_to_num_weights(net.param_shapes)

    # Build the classifier hypernetwork.
    # This is set in the run method in train.py.
    if config.training_with_hnet:
        class_hnet = sim_utils.get_hnet_model(config, config.num_tasks,
                                              device, net.param_shapes,
                                              cprefix='class_')
        init_params = list(class_hnet.parameters())

        config.num_weights_class_hyper_net = sum(
            p.numel() for p in class_hnet.parameters() if p.requires_grad)
        config.compression_ratio_class = config.num_weights_class_hyper_net / \
            config.num_weights_class_net
        print('Created classifier Hypernetwork with ratio: ',
              config.compression_ratio_class)
        if config.compression_ratio_class > 1:
            print('Note that the compression ratio is computed with respect '
                  + 'to the current target network and might not be directly '
                  + 'comparable with the number of parameters of work we '
                  + 'compare against.')
    else:
        class_hnet = None
        init_params = list(net.parameters())
        config.num_weights_class_hyper_net = None
        config.compression_ratio_class = None

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if config.training_with_hnet:
        for temb in class_hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

    if hasattr(class_hnet, 'chunk_embeddings'):
        for emb in class_hnet.chunk_embeddings:
            torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    if not config.training_with_hnet:
        return net
    else:
        return net, class_hnet
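# Illustrative sketch (not part of the original training pipeline): one way
# this helper might be called. The `argparse.Namespace` fields are assumptions
# inferred from the attributes accessed in `generate_classifier` above; the
# real command-line parser defines many more options. `data_handlers` and
# `device` are assumed to be created elsewhere.
def _example_generate_classifier(data_handlers, device):
    """Hypothetical usage example for :func:`generate_classifier`."""
    import argparse
    config = argparse.Namespace(
        experiment='splitMNIST', padding=0, class_incremental=False,
        training_task_infer=False, cl_scenario=1, num_tasks=5,
        class_fc_arch='400,400', class_net_act='relu', class_dropout_rate=-1,
        # No hypernetwork -> the classifier maintains its own weights and the
        # function returns only the main network.
        training_with_hnet=False)
    net = generate_classifier(config, data_handlers, device)
    # With `cl_scenario != 2`, the classifier has
    # `config.out_dim * config.num_tasks` (here 2 * 5 = 10) output neurons.
    return net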
def get_mnet_model(config, net_type, in_shape, out_shape, device, cprefix=None, no_weights=False): """Generate a main network instance. A helper to generate a main network according to the given the user configurations. .. note:: Generation of networks with context-modulation is not yet supported, since there is no global argument set in :mod:`utils.cli_args` yet. Args: config (argparse.Namespace): Command-line arguments. .. note:: The function expects command-line arguments available according to the function :func:`utils.cli_args.main_net_args`. net_type (str): The type of network. The following options are available: - ``mlp``: :class:`mnets.mlp.MLP` - ``resnet``: :class:`mnets.resnet.ResNet` - ``zenke``: :class:`mnets.zenkenet.ZenkeNet` - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet` in_shape (list): Shape of network inputs. Can be ``None`` if not required by network type. For instance: For an MLP network :class:`mnets.mlp.MLP` with 100 input neurons it should be :code:`in_shape=[100]`. out_shape (list): Shape of network outputs. See ``in_shape`` for more details. device: PyTorch device. cprefix (str, optional): A prefix of the config names. It might be, that the config names used in this method are prefixed, since several main networks should be generated (e.g., :code:`cprefix='gen_'` or ``'dis_'`` when training a GAN). Also see docstring of parameter ``prefix`` in function :func:`utils.cli_args.main_net_args`. no_weights (bool): Whether the main network should be generated without weights. Returns: The created main network model. """ assert (net_type in ['mlp', 'resnet', 'zenke', 'bio_conv_net']) if cprefix is None: cprefix = '' def gc(name): """Get config value with that name.""" return getattr(config, '%s%s' % (cprefix, name)) def hc(name): """Check whether config exists.""" return hasattr(config, '%s%s' % (cprefix, name)) mnet = None if hc('net_act'): net_act = gc('net_act') net_act = misc.str_to_act(net_act) else: net_act = None def get_val(name): ret = None if hc(name): ret = gc(name) return ret no_bias = get_val('no_bias') dropout_rate = get_val('dropout_rate') specnorm = get_val('specnorm') batchnorm = get_val('batchnorm') no_batchnorm = get_val('no_batchnorm') bn_no_running_stats = get_val('bn_no_running_stats') bn_distill_stats = get_val('bn_distill_stats') #bn_no_stats_checkpointing = get_val('bn_no_stats_checkpointing') use_bn = None if batchnorm is not None: use_bn = batchnorm elif no_batchnorm is not None: use_bn = not no_batchnorm # FIXME if an argument wasn't specified, then we use the default value that # is currently (at time of implementation) in the constructor. 
assign = lambda x, y: y if x is None else x if net_type == 'mlp': assert (hc('mlp_arch')) assert (len(in_shape) == 1 and len(out_shape) == 1) mnet = MLP( n_in=in_shape[0], n_out=out_shape[0], hidden_layers=misc.str_to_ints(gc('mlp_arch')), activation_fn=assign(net_act, torch.nn.ReLU()), use_bias=not assign(no_bias, False), no_weights=no_weights, #init_weights=None, dropout_rate=assign(dropout_rate, -1), use_spectral_norm=assign(specnorm, False), use_batch_norm=assign(use_bn, False), bn_track_stats=assign(not bn_no_running_stats, True), distill_bn_stats=assign(bn_distill_stats, False), #use_context_mod=False, #context_mod_inputs=False, #no_last_layer_context_mod=False, #context_mod_no_weights=False, #context_mod_post_activation=False, #context_mod_gain_offset=False, #out_fn=None, verbose=True).to(device) elif net_type == 'resnet': assert (len(out_shape) == 1) mnet = ResNet( in_shape=in_shape, num_classes=out_shape[0], verbose=True, #n=5, no_weights=no_weights, #init_weights=None, use_batch_norm=assign(use_bn, True), bn_track_stats=assign(not bn_no_running_stats, True), distill_bn_stats=assign(bn_distill_stats, False), #use_context_mod=False, #context_mod_inputs=False, #no_last_layer_context_mod=False, #context_mod_no_weights=False, #context_mod_post_activation=False, #context_mod_gain_offset=False, #context_mod_apply_pixel_wise=False ).to(device) elif net_type == 'zenke': assert (len(out_shape) == 1) mnet = ZenkeNet( in_shape=in_shape, num_classes=out_shape[0], verbose=True, #arch='cifar', no_weights=no_weights, #init_weights=None, dropout_rate=assign(dropout_rate, 0.25)).to(device) else: assert (net_type == 'bio_conv_net') assert (len(out_shape) == 1) raise NotImplementedError('Implementation not publicly available!') return mnet
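# Illustrative sketch: creating a small MLP main network through the helper
# above. The config attributes mirror options from
# :func:`utils.cli_args.main_net_args`; the concrete values are assumptions.
def _example_get_mnet_model(device):
    """Hypothetical usage example for :func:`get_mnet_model`."""
    import argparse
    mnet_config = argparse.Namespace(
        mlp_arch='256,256', net_act='relu', no_bias=False, dropout_rate=-1,
        specnorm=False, batchnorm=False)
    mnet = get_mnet_model(mnet_config, 'mlp', in_shape=[784], out_shape=[10],
                          device=device, no_weights=False)
    # `mnet.param_shapes` lists the shapes a hypernetwork would have to
    # produce if `no_weights=True` were used instead.
    return mnet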
def get_hnet_model(config, num_tasks, device, mnet_shapes, cprefix=None):
    """Generate a hypernetwork instance.

    A helper to generate the hypernetwork according to the given user
    configuration.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available
                according to the function
                :func:`utils.cli_args.hypernet_args`.
        num_tasks (int): The number of task embeddings the hypernetwork
            should have.
        device: PyTorch device.
        mnet_shapes: Dimensions of the weight tensors of the main network.
            See main net argument
            :attr:`mnets.mnet_interface.MainNetInterface.param_shapes`.
        cprefix (str, optional): A prefix of the config names. It might be
            that the config names used in this method are prefixed, since
            several hypernetworks should be generated (e.g.,
            :code:`cprefix='gen_'` or ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.hypernet_args`.

    Returns:
        The created hypernet model.
    """
    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    hyper_chunks = misc.str_to_ints(gc('hyper_chunks'))
    assert (len(hyper_chunks) in [1, 2, 3])
    if len(hyper_chunks) == 1:
        hyper_chunks = hyper_chunks[0]

    hnet_arch = misc.str_to_ints(gc('hnet_arch'))
    sa_hnet_filters = misc.str_to_ints(gc('sa_hnet_filters'))
    sa_hnet_kernels = misc.str_to_ints(gc('sa_hnet_kernels'))
    sa_hnet_attention_layers = misc.str_to_ints(gc('sa_hnet_attention_layers'))

    hnet_act = misc.str_to_act(gc('hnet_act'))

    if isinstance(hyper_chunks, list):  # Chunked self-attention hypernet
        if len(sa_hnet_kernels) == 1:
            sa_hnet_kernels = sa_hnet_kernels[0]
        # Note that the user can specify the kernel size for each dimension
        # and layer separately.
        elif len(sa_hnet_kernels) > 2 and \
                len(sa_hnet_kernels) == gc('sa_hnet_num_layers') * 2:
            tmp = sa_hnet_kernels
            sa_hnet_kernels = []
            for i in range(0, len(tmp), 2):
                sa_hnet_kernels.append([tmp[i], tmp[i + 1]])

        if gc('hnet_dropout_rate') != -1:
            warn('SA-Hypernet doesn\'t use dropout. Dropout rate will be ' +
                 'ignored.')
        if gc('hnet_act') != 'relu':
            warn('SA-Hypernet doesn\'t support non-linearities other than ' +
                 'ReLU yet. Option "%shnet_act" (%s) will be ignored.' %
                 (cprefix, gc('hnet_act')))

        hnet = SAHyperNetwork(
            mnet_shapes, num_tasks,
            out_size=hyper_chunks,
            num_layers=gc('sa_hnet_num_layers'),
            num_filters=sa_hnet_filters,
            kernel_size=sa_hnet_kernels,
            sa_units=sa_hnet_attention_layers,
            # Note, we don't use an additional hypernet for the remaining
            # weights!
            #rem_layers=hnet_arch,
            te_dim=gc('temb_size'),
            ce_dim=gc('emb_size'),
            no_theta=False,
            # Batchnorm and spectral norm are not yet implemented.
            #use_batch_norm=gc('hnet_batchnorm'),
            #use_spectral_norm=gc('hnet_specnorm'),
            # Dropout would only be used for the additional network, which we
            # don't use.
            #dropout_rate=gc('hnet_dropout_rate'),
            discard_remainder=True,
            noise_dim=gc('hnet_noise_dim'),
            temb_std=gc('temb_std')).to(device)
    elif hyper_chunks != -1:  # Chunked fully-connected hypernet
        hnet = ChunkedHyperNetworkHandler(mnet_shapes, num_tasks,
            chunk_dim=hyper_chunks, layers=hnet_arch, activation_fn=hnet_act,
            te_dim=gc('temb_size'), ce_dim=gc('emb_size'),
            dropout_rate=gc('hnet_dropout_rate'),
            noise_dim=gc('hnet_noise_dim'),
            temb_std=gc('temb_std')).to(device)
    else:  # Fully-connected hypernet.
        hnet = HyperNetwork(mnet_shapes, num_tasks, layers=hnet_arch,
            te_dim=gc('temb_size'), activation_fn=hnet_act,
            dropout_rate=gc('hnet_dropout_rate'),
            noise_dim=gc('hnet_noise_dim'),
            temb_std=gc('temb_std')).to(device)

    return hnet
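# Illustrative sketch: a fully-connected hypernetwork for an existing main
# network `mnet`. The `Namespace` fields mirror the options read via `gc()`
# above (cf. :func:`utils.cli_args.hypernet_args`); the values are
# assumptions, not the settings of any particular experiment.
def _example_get_hnet_model(mnet, device):
    """Hypothetical usage example for :func:`get_hnet_model`."""
    import argparse
    hnet_config = argparse.Namespace(
        hyper_chunks='-1',  # -1 -> unchunked, fully-connected hypernetwork.
        hnet_arch='100,100', hnet_act='relu', hnet_dropout_rate=-1,
        # The SA-hypernet options are parsed unconditionally, hence they must
        # exist even though they are ignored for this network type.
        sa_hnet_filters='128', sa_hnet_kernels='5',
        sa_hnet_attention_layers='2',
        temb_size=32, emb_size=32, hnet_noise_dim=-1, temb_std=-1)
    hnet = get_hnet_model(hnet_config, num_tasks=5, device=device,
                          mnet_shapes=mnet.param_shapes)
    return hnet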
def get_mnet_model(config, net_type, in_shape, out_shape, device, cprefix=None, no_weights=False, **mnet_kwargs): """Generate a main network instance. A helper to generate a main network according to the given the user configurations. .. note:: Generation of networks with context-modulation is not yet supported, since there is no global argument set in :mod:`utils.cli_args` yet. Args: config (argparse.Namespace): Command-line arguments. .. note:: The function expects command-line arguments available according to the function :func:`utils.cli_args.main_net_args`. net_type (str): The type of network. The following options are available: - ``mlp``: :class:`mnets.mlp.MLP` - ``resnet``: :class:`mnets.resnet.ResNet` - ``wrn``: :class:`mnets.wide_resnet.WRN` - ``iresnet``: :class:`mnets.resnet_imgnet.ResNetIN` - ``zenke``: :class:`mnets.zenkenet.ZenkeNet` - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet` - ``chunked_mlp``: :class:`mnets.chunk_squeezer.ChunkSqueezer` - ``simple_rnn``: :class:`mnets.simple_rnn.SimpleRNN` in_shape (list): Shape of network inputs. Can be ``None`` if not required by network type. For instance: For an MLP network :class:`mnets.mlp.MLP` with 100 input neurons it should be :code:`in_shape=[100]`. out_shape (list): Shape of network outputs. See ``in_shape`` for more details. device: PyTorch device. cprefix (str, optional): A prefix of the config names. It might be, that the config names used in this method are prefixed, since several main networks should be generated (e.g., :code:`cprefix='gen_'` or ``'dis_'`` when training a GAN). Also see docstring of parameter ``prefix`` in function :func:`utils.cli_args.main_net_args`. no_weights (bool): Whether the main network should be generated without weights. **mnet_kwargs: Additional keyword arguments that will be passed to the main network constructor. Returns: The created main network model. """ assert (net_type in [ 'mlp', 'lenet', 'resnet', 'zenke', 'bio_conv_net', 'chunked_mlp', 'simple_rnn', 'wrn', 'iresnet' ]) if cprefix is None: cprefix = '' def gc(name): """Get config value with that name.""" return getattr(config, '%s%s' % (cprefix, name)) def hc(name): """Check whether config exists.""" return hasattr(config, '%s%s' % (cprefix, name)) mnet = None if hc('net_act'): net_act = gc('net_act') net_act = misc.str_to_act(net_act) else: net_act = None def get_val(name): ret = None if hc(name): ret = gc(name) return ret no_bias = get_val('no_bias') dropout_rate = get_val('dropout_rate') specnorm = get_val('specnorm') batchnorm = get_val('batchnorm') no_batchnorm = get_val('no_batchnorm') bn_no_running_stats = get_val('bn_no_running_stats') bn_distill_stats = get_val('bn_distill_stats') # This argument has to be handled during usage of the network and not during # construction. #bn_no_stats_checkpointing = get_val('bn_no_stats_checkpointing') use_bn = None if batchnorm is not None: use_bn = batchnorm elif no_batchnorm is not None: use_bn = not no_batchnorm # If an argument wasn't specified, then we use the default value that # is currently in the constructor. assign = lambda x, y: y if x is None else x if net_type == 'mlp': assert (hc('mlp_arch')) assert (len(in_shape) == 1 and len(out_shape) == 1) # Default keyword arguments of class MLP. 
dkws = misc.get_default_args(MLP.__init__) mnet = MLP( n_in=in_shape[0], n_out=out_shape[0], hidden_layers=misc.str_to_ints(gc('mlp_arch')), activation_fn=assign(net_act, dkws['activation_fn']), use_bias=assign(not no_bias, dkws['use_bias']), no_weights=no_weights, #init_weights=None, dropout_rate=assign(dropout_rate, dkws['dropout_rate']), use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']), use_batch_norm=assign(use_bn, dkws['use_batch_norm']), bn_track_stats=assign(not bn_no_running_stats, dkws['bn_track_stats']), distill_bn_stats=assign(bn_distill_stats, dkws['distill_bn_stats']), #use_context_mod=False, #context_mod_inputs=False, #no_last_layer_context_mod=False, #context_mod_no_weights=False, #context_mod_post_activation=False, #context_mod_gain_offset=False, #out_fn=None, verbose=True, **mnet_kwargs).to(device) elif net_type == 'resnet': assert (len(out_shape) == 1) assert hc('resnet_block_depth') and hc('resnet_channel_sizes') # Default keyword arguments of class ResNet. dkws = misc.get_default_args(ResNet.__init__) mnet = ResNet( in_shape=in_shape, num_classes=out_shape[0], n=gc('resnet_block_depth'), use_bias=assign(not no_bias, dkws['use_bias']), num_feature_maps=misc.str_to_ints(gc('resnet_channel_sizes')), verbose=True, #n=5, no_weights=no_weights, #init_weights=None, use_batch_norm=assign(use_bn, dkws['use_batch_norm']), bn_track_stats=assign(not bn_no_running_stats, dkws['bn_track_stats']), distill_bn_stats=assign(bn_distill_stats, dkws['distill_bn_stats']), #use_context_mod=False, #context_mod_inputs=False, #no_last_layer_context_mod=False, #context_mod_no_weights=False, #context_mod_post_activation=False, #context_mod_gain_offset=False, #context_mod_apply_pixel_wise=False **mnet_kwargs).to(device) elif net_type == 'wrn': assert (len(out_shape) == 1) assert hc('wrn_block_depth') and hc('wrn_widening_factor') # Default keyword arguments of class WRN. dkws = misc.get_default_args(WRN.__init__) mnet = WRN( in_shape=in_shape, num_classes=out_shape[0], n=gc('wrn_block_depth'), use_bias=assign(not no_bias, dkws['use_bias']), #num_feature_maps=misc.str_to_ints(gc('wrn_channel_sizes')), verbose=True, no_weights=no_weights, use_batch_norm=assign(use_bn, dkws['use_batch_norm']), bn_track_stats=assign(not bn_no_running_stats, dkws['bn_track_stats']), distill_bn_stats=assign(bn_distill_stats, dkws['distill_bn_stats']), k=gc('wrn_widening_factor'), use_fc_bias=gc('wrn_use_fc_bias'), dropout_rate=gc('dropout_rate'), #use_context_mod=False, #context_mod_inputs=False, #no_last_layer_context_mod=False, #context_mod_no_weights=False, #context_mod_post_activation=False, #context_mod_gain_offset=False, #context_mod_apply_pixel_wise=False **mnet_kwargs).to(device) elif net_type == 'iresnet': assert (len(out_shape) == 1) assert hc('iresnet_use_fc_bias') and hc('iresnet_channel_sizes') \ and hc('iresnet_blocks_per_group') \ and hc('iresnet_bottleneck_blocks') \ and hc('iresnet_projection_shortcut') # Default keyword arguments of class WRN. 
dkws = misc.get_default_args(ResNetIN.__init__) mnet = ResNetIN( in_shape=in_shape, num_classes=out_shape[0], use_bias=assign(not no_bias, dkws['use_bias']), use_fc_bias=gc('wrn_use_fc_bias'), num_feature_maps=misc.str_to_ints(gc('iresnet_channel_sizes')), blocks_per_group=misc.str_to_ints(gc('iresnet_blocks_per_group')), projection_shortcut=gc('iresnet_projection_shortcut'), bottleneck_blocks=gc('iresnet_bottleneck_blocks'), #cutout_mod=False, no_weights=no_weights, use_batch_norm=assign(use_bn, dkws['use_batch_norm']), bn_track_stats=assign(not bn_no_running_stats, dkws['bn_track_stats']), distill_bn_stats=assign(bn_distill_stats, dkws['distill_bn_stats']), #chw_input_format=False, verbose=True, #use_context_mod=False, #context_mod_inputs=False, #no_last_layer_context_mod=False, #context_mod_no_weights=False, #context_mod_post_activation=False, #context_mod_gain_offset=False, #context_mod_apply_pixel_wise=False **mnet_kwargs).to(device) elif net_type == 'zenke': assert (len(out_shape) == 1) # Default keyword arguments of class ZenkeNet. dkws = misc.get_default_args(ZenkeNet.__init__) mnet = ZenkeNet( in_shape=in_shape, num_classes=out_shape[0], verbose=True, #arch='cifar', no_weights=no_weights, #init_weights=None, dropout_rate=assign(dropout_rate, dkws['dropout_rate']), **mnet_kwargs).to(device) elif net_type == 'bio_conv_net': assert (len(out_shape) == 1) # Default keyword arguments of class BioConvNet. #dkws = misc.get_default_args(BioConvNet.__init__) mnet = BioConvNet( in_shape=in_shape, num_classes=out_shape[0], no_weights=no_weights, #init_weights=None, #use_context_mod=False, #context_mod_inputs=False, #no_last_layer_context_mod=False, #context_mod_no_weights=False, #context_mod_post_activation=False, #context_mod_gain_offset=False, #context_mod_apply_pixel_wise=False **mnet_kwargs).to(device) elif net_type == 'chunked_mlp': assert hc('cmlp_arch') and hc('cmlp_chunk_arch') and \ hc('cmlp_in_cdim') and hc('cmlp_out_cdim') and \ hc('cmlp_cemb_dim') assert len(in_shape) == 1 and len(out_shape) == 1 # Default keyword arguments of class ChunkSqueezer. dkws = misc.get_default_args(ChunkSqueezer.__init__) mnet = ChunkSqueezer( n_in=in_shape[0], n_out=out_shape[0], inp_chunk_dim=gc('cmlp_in_cdim'), out_chunk_dim=gc('cmlp_out_cdim'), cemb_size=gc('cmlp_cemb_dim'), #cemb_init_std=1., red_layers=misc.str_to_ints(gc('cmlp_chunk_arch')), net_layers=misc.str_to_ints(gc('cmlp_arch')), activation_fn=assign(net_act, dkws['activation_fn']), use_bias=assign(not no_bias, dkws['use_bias']), #dynamic_biases=None, no_weights=no_weights, #init_weights=None, dropout_rate=assign(dropout_rate, dkws['dropout_rate']), use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']), use_batch_norm=assign(use_bn, dkws['use_batch_norm']), bn_track_stats=assign(not bn_no_running_stats, dkws['bn_track_stats']), distill_bn_stats=assign(bn_distill_stats, dkws['distill_bn_stats']), verbose=True, **mnet_kwargs).to(device) elif net_type == 'lenet': assert hc('lenet_type') assert len(out_shape) == 1 # Default keyword arguments of class LeNet. dkws = misc.get_default_args(LeNet.__init__) mnet = LeNet( in_shape=in_shape, num_classes=out_shape[0], verbose=True, arch=gc('lenet_type'), no_weights=no_weights, #init_weights=None, dropout_rate=assign(dropout_rate, dkws['dropout_rate']), # TODO Context-mod weights. 
**mnet_kwargs).to(device) else: assert (net_type == 'simple_rnn') assert hc('srnn_rec_layers') and hc('srnn_pre_fc_layers') and \ hc('srnn_post_fc_layers') and hc('srnn_no_fc_out') and \ hc('srnn_rec_type') assert len(in_shape) == 1 and len(out_shape) == 1 if gc('srnn_rec_type') == 'lstm': use_lstm = True else: assert gc('srnn_rec_type') == 'elman' use_lstm = False # Default keyword arguments of class SimpleRNN. dkws = misc.get_default_args(SimpleRNN.__init__) rnn_layers = misc.str_to_ints(gc('srnn_rec_layers')) fc_layers = misc.str_to_ints(gc('srnn_post_fc_layers')) if gc('srnn_no_fc_out'): rnn_layers.append(out_shape[0]) else: fc_layers.append(out_shape[0]) mnet = SimpleRNN(n_in=in_shape[0], rnn_layers=rnn_layers, fc_layers_pre=misc.str_to_ints( gc('srnn_pre_fc_layers')), fc_layers=fc_layers, activation=assign(net_act, dkws['activation']), use_lstm=use_lstm, use_bias=assign(not no_bias, dkws['use_bias']), no_weights=no_weights, verbose=True, **mnet_kwargs).to(device) return mnet
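# Illustrative sketch: using the extended helper for a ResNet, where further
# constructor arguments could be forwarded via ``**mnet_kwargs``. Config names
# follow :func:`utils.cli_args.main_net_args`; the concrete values are
# assumptions.
def _example_get_mnet_model_resnet(device):
    """Hypothetical usage example for the extended :func:`get_mnet_model`."""
    import argparse
    res_config = argparse.Namespace(
        net_act='relu', no_bias=False, resnet_block_depth=5,
        resnet_channel_sizes='16,16,32,64')
    # `no_weights=True` creates a network whose weights must be supplied
    # externally, e.g., by a hypernetwork.
    mnet = get_mnet_model(res_config, 'resnet', in_shape=[32, 32, 3],
                          out_shape=[10], device=device, no_weights=True)
    return mnet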
def get_hypernet(config, device, net_type, target_shapes, num_conds, no_cond_weights=False, no_uncond_weights=False, uncond_in_size=0, shmlp_chunk_shapes=None, shmlp_num_per_chunk=None, shmlp_assembly_fct=None, verbose=True, cprefix=None): """Generate a hypernetwork instance. A helper to generate the hypernetwork according to the given the user configurations. Args: config (argparse.Namespace): Command-line arguments. Note: The function expects command-line arguments available according to the function :func:`utils.cli_args.hnet_args`. device: PyTorch device. net_type (str): The type of network. The following options are available: - ``'hmlp'`` - ``'chunked_hmlp'`` - ``'structured_hmlp'`` - ``'hdeconv'`` - ``'chunked_hdeconv'`` target_shapes (list): See argument ``target_shapes`` of :class:`hnets.mlp_hnet.HMLP`. num_conds (int): Number of conditions that should be known to the hypernetwork. no_cond_weights (bool): See argument ``no_cond_weights`` of :class:`hnets.mlp_hnet.HMLP`. no_uncond_weights (bool): See argument ``no_uncond_weights`` of :class:`hnets.mlp_hnet.HMLP`. uncond_in_size (int): See argument ``uncond_in_size`` of :class:`hnets.mlp_hnet.HMLP`. shmlp_chunk_shapes (list, optional): Argument ``chunk_shapes`` of :class:`hnets.structured_mlp_hnet.StructuredHMLP`. shmlp_num_per_chunk (list, optional): Argument ``num_per_chunk`` of :class:`hnets.structured_mlp_hnet.StructuredHMLP`. shmlp_assembly_fct (func, optional): Argument ``assembly_fct`` of :class:`hnets.structured_mlp_hnet.StructuredHMLP`. verbose (bool): Argument ``verbose`` of :class:`hnets.mlp_hnet.HMLP`. cprefix (str, optional): A prefix of the config names. It might be, that the config names used in this function are prefixed, since several hypernetworks should be generated. Also see docstring of parameter ``prefix`` in function :func:`utils.cli_args.hnet_args`. """ assert net_type in [ 'hmlp', 'chunked_hmlp', 'structured_hmlp', 'hdeconv', 'chunked_hdeconv' ] hnet = None ### FIXME Code almost identically copied from `get_mnet_model` ### if cprefix is None: cprefix = '' def gc(name): """Get config value with that name.""" return getattr(config, '%s%s' % (cprefix, name)) def hc(name): """Check whether config exists.""" return hasattr(config, '%s%s' % (cprefix, name)) if hc('hnet_net_act'): net_act = gc('hnet_net_act') net_act = misc.str_to_act(net_act) else: net_act = None def get_val(name): ret = None if hc(name): ret = gc(name) return ret no_bias = get_val('hnet_no_bias') dropout_rate = get_val('hnet_dropout_rate') specnorm = get_val('hnet_specnorm') batchnorm = get_val('hnet_batchnorm') no_batchnorm = get_val('hnet_no_batchnorm') #bn_no_running_stats = get_val('hnet_bn_no_running_stats') #n_distill_stats = get_val('hnet_bn_distill_stats') use_bn = None if batchnorm is not None: use_bn = batchnorm elif no_batchnorm is not None: use_bn = not no_batchnorm # If an argument wasn't specified, then we use the default value that # is currently in the constructor. 
assign = lambda x, y: y if x is None else x ### FIXME Code copied until here ### if hc('hmlp_arch'): hmlp_arch_is_list = False hmlp_arch = gc('hmlp_arch') if ';' in hmlp_arch: hmlp_arch_is_list = True if net_type != 'structured_hmlp': raise ValueError('Option "%shmlp_arch" may only ' % (cprefix) + 'contain semicolons for network type ' + '"structured_hmlp"!') hmlp_arch = [misc.str_to_ints(ar) for ar in hmlp_arch.split(';')] else: hmlp_arch = misc.str_to_ints(hmlp_arch) if hc('chunk_emb_size'): chunk_emb_size = gc('chunk_emb_size') chunk_emb_size = misc.str_to_ints(chunk_emb_size) if len(chunk_emb_size) == 1: chunk_emb_size = chunk_emb_size[0] else: if net_type != 'structured_hmlp': raise ValueError('Option "%schunk_emb_size" may ' % (cprefix) + 'only contain multiple values for network ' + 'type "structured_hmlp"!') if hc('cond_emb_size'): cond_emb_size = gc('cond_emb_size') else: cond_emb_size = 0 if net_type == 'hmlp': assert hc('hmlp_arch') # Default keyword arguments of class HMLP. dkws = misc.get_default_args(HMLP.__init__) hnet = HMLP(target_shapes, uncond_in_size=uncond_in_size, cond_in_size=cond_emb_size, layers=hmlp_arch, verbose=verbose, activation_fn=assign(net_act, dkws['activation_fn']), use_bias=assign(not no_bias, dkws['use_bias']), no_uncond_weights=no_uncond_weights, no_cond_weights=no_cond_weights, num_cond_embs=num_conds, dropout_rate=assign(dropout_rate, dkws['dropout_rate']), use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']), use_batch_norm=assign(use_bn, dkws['use_batch_norm'])).to(device) elif net_type == 'chunked_hmlp': assert hc('hmlp_arch') assert hc('chmlp_chunk_size') assert hc('chunk_emb_size') cond_chunk_embs = get_val('use_cond_chunk_embs') # Default keyword arguments of class ChunkedHMLP. dkws = misc.get_default_args(ChunkedHMLP.__init__) hnet = ChunkedHMLP( target_shapes, gc('chmlp_chunk_size'), chunk_emb_size=chunk_emb_size, cond_chunk_embs=assign(cond_chunk_embs, dkws['cond_chunk_embs']), uncond_in_size=uncond_in_size, cond_in_size=cond_emb_size, layers=hmlp_arch, verbose=verbose, activation_fn=assign(net_act, dkws['activation_fn']), use_bias=assign(not no_bias, dkws['use_bias']), no_uncond_weights=no_uncond_weights, no_cond_weights=no_cond_weights, num_cond_embs=num_conds, dropout_rate=assign(dropout_rate, dkws['dropout_rate']), use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']), use_batch_norm=assign(use_bn, dkws['use_batch_norm'])).to(device) elif net_type == 'structured_hmlp': assert hc('hmlp_arch') assert hc('chunk_emb_size') cond_chunk_embs = get_val('use_cond_chunk_embs') assert shmlp_chunk_shapes is not None and \ shmlp_num_per_chunk is not None and \ shmlp_assembly_fct is not None # Default keyword arguments of class StructuredHMLP. 
dkws = misc.get_default_args(StructuredHMLP.__init__) dkws_hmlp = misc.get_default_args(HMLP.__init__) shmlp_hmlp_kwargs = [] if not hmlp_arch_is_list: hmlp_arch = [hmlp_arch] for i, arch in enumerate(hmlp_arch): shmlp_hmlp_kwargs.append({ 'layers': arch, 'activation_fn': assign(net_act, dkws_hmlp['activation_fn']), 'use_bias': assign(not no_bias, dkws_hmlp['use_bias']), 'dropout_rate': assign(dropout_rate, dkws_hmlp['dropout_rate']), 'use_spectral_norm': \ assign(specnorm, dkws_hmlp['use_spectral_norm']), 'use_batch_norm': assign(use_bn, dkws_hmlp['use_batch_norm']) }) if len(shmlp_hmlp_kwargs) == 1: shmlp_hmlp_kwargs = shmlp_hmlp_kwargs[0] hnet = StructuredHMLP(target_shapes, shmlp_chunk_shapes, shmlp_num_per_chunk, chunk_emb_size, shmlp_hmlp_kwargs, shmlp_assembly_fct, cond_chunk_embs=assign(cond_chunk_embs, dkws['cond_chunk_embs']), uncond_in_size=uncond_in_size, cond_in_size=cond_emb_size, verbose=verbose, no_uncond_weights=no_uncond_weights, no_cond_weights=no_cond_weights, num_cond_embs=num_conds).to(device) elif net_type == 'hdeconv': #HDeconv raise NotImplementedError else: assert net_type == 'chunked_hdeconv' #ChunkedHDeconv raise NotImplementedError return hnet
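# Illustrative sketch: constructing a plain ``hmlp`` hypernetwork that
# produces the parameters of an existing main network `mnet`. Config names
# follow :func:`utils.cli_args.hnet_args`; the values are assumptions.
def _example_get_hypernet(mnet, device):
    """Hypothetical usage example for :func:`get_hypernet`."""
    import argparse
    hnet_config = argparse.Namespace(
        hmlp_arch='100,100', cond_emb_size=32, hnet_net_act='relu',
        hnet_no_bias=False, hnet_dropout_rate=-1)
    hnet = get_hypernet(hnet_config, device, 'hmlp', mnet.param_shapes,
                        num_conds=5)
    # Weights for condition (task) 0 could later be generated via something
    # like ``hnet.forward(cond_id=0)``.
    return hnet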
def _generate_networks(config, data_handlers, device, create_hnet=True, create_rnet=False, no_replay=False): """Create the main-net, hypernetwork and recognition network. Args: config: Command-line arguments. data_handlers: List of data handlers, one for each task. Needed to extract the number of inputs/outputs of the main network. And to infer the number of tasks. device: Torch device. create_hnet: Whether a hypernetwork should be constructed. If not, the main network will have trainable weights. create_rnet: Whether a task-recognition autoencoder should be created. no_replay: If the recognition network should be an instance of class MainModel rather than of class RecognitionNet (note, for multitask learning, no replay network is required). Returns: mnet: Main network instance. hnet: Hypernetwork instance. This return value is None if no hypernetwork should be constructed. rnet: RecognitionNet instance. This return value is None if no recognition network should be constructed. """ num_tasks = len(data_handlers) n_x = data_handlers[0].in_shape[0] n_y = data_handlers[0].out_shape[0] if config.multi_head: n_y = n_y * num_tasks main_arch = misc.str_to_ints(config.main_arch) main_shapes = MainNetwork.weight_shapes(n_in=n_x, n_out=n_y, hidden_layers=main_arch) mnet = MainNetwork(main_shapes, activation_fn=misc.str_to_act(config.main_act), use_bias=True, no_weights=create_hnet).to(device) if create_hnet: hnet_arch = misc.str_to_ints(config.hnet_arch) hnet = HyperNetwork(main_shapes, num_tasks, layers=hnet_arch, te_dim=config.emb_size, activation_fn=misc.str_to_act( config.hnet_act)).to(device) init_params = list(hnet.parameters()) else: hnet = None init_params = list(mnet.parameters()) if create_rnet: ae_arch = misc.str_to_ints(config.ae_arch) if no_replay: rnet_shapes = MainNetwork.weight_shapes(n_in=n_x, n_out=num_tasks, hidden_layers=ae_arch, use_bias=True) rnet = MainNetwork(rnet_shapes, activation_fn=misc.str_to_act(config.ae_act), use_bias=True, no_weights=False, dropout_rate=-1, out_fn=lambda x: F.softmax(x, dim=1)) else: rnet = RecognitionNet(n_x, num_tasks, dim_z=config.ae_dim_z, enc_layers=ae_arch, activation_fn=misc.str_to_act(config.ae_act), use_bias=True).to(device) init_params += list(rnet.parameters()) else: rnet = None ### Initialize network weights. for W in init_params: if W.ndimension() == 1: # Bias vector. torch.nn.init.constant_(W, 0) elif config.normal_init: torch.nn.init.normal_(W, mean=0, std=config.std_normal_init) else: torch.nn.init.xavier_uniform_(W) # The task embeddings are initialized differently. if create_hnet: for temb in hnet.get_task_embs(): torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb) if config.use_hyperfan_init: hnet.apply_hyperfan_init(temb_var=config.std_normal_temb**2) return mnet, hnet, rnet
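# Illustrative sketch: wiring up main network, hypernetwork and recognition
# network for a toy configuration. The `Namespace` fields reflect the
# attributes accessed in `_generate_networks`; all values are assumptions, and
# `data_handlers` and `device` are assumed to exist already.
def _example_generate_networks(data_handlers, device):
    """Hypothetical usage example for :func:`_generate_networks`."""
    import argparse
    config = argparse.Namespace(
        multi_head=False, main_arch='256,256', main_act='relu',
        hnet_arch='50,50', hnet_act='relu', emb_size=32,
        ae_arch='100', ae_act='relu', ae_dim_z=8,
        normal_init=False, std_normal_init=0.02, std_normal_temb=1.,
        use_hyperfan_init=False)
    mnet, hnet, rnet = _generate_networks(config, data_handlers, device,
                                          create_hnet=True, create_rnet=True,
                                          no_replay=False)
    return mnet, hnet, rnet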
def analyse_single_run(out_dir, device, writer, logger, analysis_kwd, get_loss_func, accuracy_func, generate_tasks_func, n_samples=-1, redo_analyses=False, do_kernel_pca=False, do_supervised_dimred=False, timesteps_for_analysis=None, copy_task=True, num_tasks=-1, sup_dimred_criterion=None, sup_dimred_args={}): """Analyse the hidden dimensionality for an individual run. Args: out_dir (str): The path to the output directory. device: The device. writer: The tensorboard writer. logger: The logger. analysis_kwd (dict): The dictionary containing important keywords for the current analysis. get_loss_func (func): A handler to generate the loss function. accuracy_func (func): A handler to the accuracy function. generate_tasks_func (func): A handler to a datahandler generator. redo_analyses (boolean, optional): If ``True``, analyses will be redone even if they had been stored previously. do_kernel_pca (bool, optional): If ``True``, kernel PCA will also be used to compute the number of hidden dimensions. do_supervised_dimred (bool, optional): If ``True``, supervised linear dimensionality reduction will be used to compute the number of task-relevant hidden dimensions. n_samples (int): The number of samples to be used. timesteps_for_analysis (str, optional): The timesteps to be used for the PCA analyses. copy_task (bool, optional): Indicates whether we are analysing the Copy Task or not. num_tasks (int, optional): The number of tasks to be considered. sup_dimred_criterion (int, optional): If provided, this value will be used as stopping criterion when looking for the number of necessary supervised components to describe the hidden activity. sup_dimred_args (dict): Optional arguments (e.g., optimization arguments) passed to the supervised dimensionality reduction :func:`sequential.ht_analyses.supervised_dimred_utils.\ get_loss_vs_supervised_n_dim`. Returns: (tuple): Tuple containing: - **results**: The dictionary of results for the current run. - **settings**: The dictionary with the values of the parameters that are specified in `analysis_kwd['fixed_params']`. """ ### Prepare the data and the networks. # Load the config if not os.path.exists(out_dir): raise ValueError('The directory "%s" does not exist.'%out_dir) with open(os.path.join(out_dir, "config.pickle"), "rb") as f: config = pickle.load(f) # Overwrite the directory it it's not the same as the original. if config.out_dir != out_dir: config.out_dir = out_dir # Check for old command line arguments and make compatible with new version. config = train_args_sequential.update_cli_args(config) print('Working on output directory "%s".' % out_dir) # Overwrite the number of tasks. if num_tasks == -1: num_tasks = config.num_tasks if sup_dimred_criterion == -1: sup_dimred_criterion = None stop_bit=None if copy_task: # Get the index of the stop bit. #stop_bit = getattr(config, analysis_kwd['complexity_measure']) # If we do not enforce the condition below, we have to determine the # location of the stop bit on a sample-by-sample basis. assert config.input_len_step == 0 and config.input_len_variability == 0 stop_bit = config.first_task_input_len if config.pad_after_stop: stop_bit = config.pat_len ### Sanity checks. # Do some sanity checks in the parameters. 
assert config.use_ewc or config.use_si if config.use_ewc: method = 'ewc' elif config.use_si: method = 'si' for key, value in analysis_kwd['forced_params']: assert getattr(config, key) == value # Ensure all runs have comparable properties if 'num_tasks' not in analysis_kwd['fixed_params']: analysis_kwd['fixed_params'].append('num_tasks') ### Create the settings dictionary. settings = {} for key in analysis_kwd['fixed_params']: settings[key] = getattr(config, key) if key == 'num_tasks': settings[key] = num_tasks ### Load or create the results dictionary. if os.path.exists(os.path.join(out_dir, "pca_results.pickle")) and \ not redo_analyses: ### Load existing results. with open(os.path.join(out_dir, "pca_results.pickle"), "rb") as f: results = pickle.load(f) print('PCA analyses have been done and stored previously and reloaded.') assert num_tasks == -1 or results['num_tasks'] == num_tasks if 'mean_fisher' in results: results['mean_importance'] = results['mean_fisher'] results['mean_importance_ho'] = results['mean_fisher_ho'] else: ### Prepare the environment. # Define functions. task_loss_func = get_loss_func(config, device, logger) accuracy_func = accuracy_func # Generate datahandlers dhandlers = generate_tasks_func(config, logger, writer=writer) config.show_plots = True plc.visualise_data(dhandlers, config, device) # Generate the networks shared = argparse.Namespace() # FIXME might not work for all datasets (e.g., PoS tagging). shared.feature_size = dhandlers[0].in_shape[0] target_net, hnet, _ = stu.generate_networks(config, shared, dhandlers, device) ### Initialize the results dictionary. results = {} if copy_task: results['masked'] = config.pat_len results['pad_after_stop'] = config.pad_after_stop results['accs_per_ts'] = [] results['permutation'] = [] results['expl_var_per_ts'] = [] results['kexpl_var_per_ts'] = [] results['expl_var_per_ts_yt'] = [] results['kexpl_var_per_ts_yt'] = [] results['complexity_measure'] = getattr(config, \ analysis_kwd['complexity_measure']) results['complexity_measure_name'] = \ analysis_kwd['complexity_measure_name'] results['num_tasks'] = num_tasks results['final_acc'] = [] results['final_loss'] = [] results['mean_importance'] = [] results['mean_importance_ho'] = [] results['expl_var'] = [] results['kexpl_var'] = [] results['expl_var_yt'] = [] results['kexpl_var_yt'] = [] if do_supervised_dimred: # Note, in the code 'loss_n_dim_supervised' plays, for the # supervised dimensionality reduction, the same role as 'expl_var' # for the standard PCA analysis, i.e. we store the explained # variance (resp. loss) as a function of how many dimensions are # taken into account, and then select a threshold for the explained # variance (resp. loss) to determine the number of intrinsic # dimensions. results['loss_n_dim_supervised'] = [] results['accu_n_dim_supervised'] = [] if copy_task: results['accu_n_dim_sup_at_stop'] = [] results['loss_n_dim_sup_at_stop'] = [] # Iterate over all tasks and accumulate results in lists within the # results dictionary values. all_during_act = [] all_during_act_yt = [] for task_id in range(num_tasks): if copy_task: results['permutation'].append(dhandlers[task_id].permutation) ### Load the checkpointed during model for the corresponding task. # Note, the return values of the function below are just references # to the variables `target_net` and `hnet`, which are modified in- # place. 
mnet, hnet = load_models(out_dir, device, logger, target_net, hnet, wembs=None, task_id=task_id, method=method) # FIXME Should we disentangle weight matrices and bias vectors? hh_imp_values = get_importance_values(mnet, connection_type='hh', method=method) results['mean_importance'].append(np.mean(hh_imp_values)) ho_imp_values = get_importance_values(mnet, connection_type='ho', method=method) if ho_imp_values != []: results['mean_importance_ho'].append(np.mean(ho_imp_values)) else: results['mean_importance_ho'].append(np.nan) ### Obtain hidden activations and performances. # We only measure the final accuracy up to the current task, since # we are simulating a continual learning setting with less tasks. loss, accs, accs_per_ts = test(dhandlers, device, config, None, logger, writer, mnet, hnet, store_activations=True, \ accuracy_func=accuracy_func, task_loss_func=task_loss_func, num_trained=task_id, return_acc_per_ts=True) results['final_loss'].append(np.mean(loss[:task_id+1])) if accs is None: results['final_acc'].append(None) else: results['final_acc'].append(np.mean(accs[:task_id+1])) if copy_task: results['accs_per_ts'].append(accs_per_ts[task_id]) ### Load the internal activations. tasks_act, act = get_activations(out_dir, task_id=task_id, vanilla_rnn=config.use_vanilla_rnn) n_hidden = np.sum(misc.str_to_ints(config.rnn_arch)) assert act.shape[-1] == n_hidden all_during_act.append(act) tasks_act_yt, act_yt = get_activations(out_dir, task_id=task_id, internal=False, vanilla_rnn=config.use_vanilla_rnn) all_during_act_yt.append(act_yt) ### Do PCA analyses. # Do analyses on internal recurrent activations. results = pca_analysis_single_task(act, results, do_kernel_pca=do_kernel_pca, n_samples=n_samples, timesteps=timesteps_for_analysis, stop_bit=stop_bit, do_supervised_dimred=do_supervised_dimred) # Do analyses on output recurrent activations. results = pca_analysis_single_task(act_yt, results, do_kernel_pca=do_kernel_pca, n_samples=n_samples, timesteps=timesteps_for_analysis, stop_bit=stop_bit, internal=False, do_supervised_dimred=do_supervised_dimred) if do_supervised_dimred: if not copy_task: raise NotImplementedError('TODO need to adapt the ' + 'loss computation for tasks other than the Copy Task.') # Do supervised dimensionality reduction on during models. loss_dim, accu_dim = get_loss_vs_supervised_n_dim(mnet, hnet, task_loss_func, accuracy_func, dhandlers, config, device, task_id=task_id, criterion=sup_dimred_criterion, writer_dir=out_dir, **sup_dimred_args) results['loss_n_dim_supervised'].append(loss_dim) results['accu_n_dim_supervised'].append(accu_dim) if copy_task: loss_dim, accu_dim = get_loss_vs_supervised_n_dim(mnet, hnet, task_loss_func, accuracy_func, dhandlers, config, device, stop_timestep=stop_bit, task_id=task_id, criterion=sup_dimred_criterion, writer_dir=out_dir, **sup_dimred_args) results['loss_n_dim_sup_at_stop'].append(loss_dim) results['accu_n_dim_sup_at_stop'].append(accu_dim) ### Get hidden dimensionality using the final model. # Note, here we overwrite the files "int_activations.pickle" and # "activations.pickle" that were generated when testing the model of # the current task. 
os.remove(os.path.join(out_dir, 'int_activations.pickle')) os.remove(os.path.join(out_dir, 'activations.pickle')) mnet, hnet = load_models(out_dir, device, logger, target_net, hnet, wembs=None, method=method) _ = test(dhandlers, device, config, shared, logger, writer, mnet, hnet, store_activations=True, accuracy_func=accuracy_func, task_loss_func=task_loss_func, num_trained=task_id, return_acc_per_ts=True) # Load internal activations. tasks_act, act = get_activations(out_dir, task_id=task_id, vanilla_rnn=config.use_vanilla_rnn) tasks_act_yt, act_yt = get_activations(out_dir, task_id=task_id, internal=False, vanilla_rnn=config.use_vanilla_rnn) ### Do PCA analyses on final models. results = pca_analysis_all_tasks(act, all_during_act, results, do_kernel_pca=do_kernel_pca, n_samples=n_samples, timesteps=timesteps_for_analysis, stop_bit=stop_bit, copy_task=copy_task) results = pca_analysis_all_tasks(act_yt, all_during_act_yt, results, do_kernel_pca=do_kernel_pca, n_samples=n_samples, timesteps=timesteps_for_analysis, stop_bit=stop_bit, copy_task=copy_task, internal=False) if do_supervised_dimred and len(all_during_act) > 1: ### Do supervised dimensionality reduction on final models. # Only do if we dealt with more than one task. loss_dim, accu_dim = get_loss_vs_supervised_n_dim(mnet, hnet, task_loss_func, accuracy_func, dhandlers, config, device, criterion=sup_dimred_criterion, writer_dir=out_dir, **sup_dimred_args) results['loss_n_dim_supervised_all_tasks'] = loss_dim results['accu_n_dim_supervised_all_tasks'] = accu_dim # Store pickle results. with open(os.path.join(out_dir, 'pca_results.pickle'), 'wb') as handle: pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL) return results, settings
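# Since the analysis results are cached in "pca_results.pickle" inside each
# run directory, a downstream plotting script might simply reload them. The
# directory name below is a placeholder.
def _example_load_pca_results(out_dir='./out/run_0'):
    """Hypothetical sketch for reloading cached analysis results."""
    import os
    import pickle
    with open(os.path.join(out_dir, 'pca_results.pickle'), 'rb') as f:
        results = pickle.load(f)
    # Keys such as 'num_tasks', 'final_acc' or 'expl_var' are populated by
    # `analyse_single_run` above.
    return results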
def get_hnet_model(config, num_tasks, device, mnet_shapes, cprefix=None, no_weights=False, no_tembs=False, temb_size=None): """Generate a hypernetwork instance. A helper to generate the hypernetwork according to the given the user configurations. .. deprecated:: 1.0 Please use function :func:`get_hypernet` instead. As this function creates deprecated hypernetworks. Args: config (argparse.Namespace): Command-line arguments. .. note:: The function expects command-line arguments available according to the function :func:`utils.cli_args.hypernet_args`. num_tasks (int): The number of task embeddings the hypernetwork should have. device: PyTorch device. mnet_shapes: Dimensions of the weight tensors of the main network. See main net argument :attr:`mnets.mnet_interface.MainNetInterface.param_shapes`. cprefix (str, optional): A prefix of the config names. It might be, that the config names used in this method are prefixed, since several hypernetworks should be generated (e.g., :code:`cprefix='gen_'` or ``'dis_'`` when training a GAN). Also see docstring of parameter ``prefix`` in function :func:`utils.cli_args.hypernet_args`. no_weights (bool): Whether the hyper network should be generated without internal weights (excluding task embeddings). no_tembs (bool): Whether the hypernetwork should be generated without internally maintained task embeddings. temb_size (int, optional): If user config should be overwritten, then this option can be used to specify the dimensionality of task embeddings. Returns: The created hypernet model. """ warn('Please use function "utils.sim_utils.get_hypernet" instead. As ' +\ 'this function creates deprecated hypernetworks.', DeprecationWarning) if cprefix is None: cprefix = '' def gc(name): """Get config value with that name.""" return getattr(config, '%s%s' % (cprefix, name)) hyper_chunks = misc.str_to_ints(gc('hyper_chunks')) assert(len(hyper_chunks) in [1,2,3]) if len(hyper_chunks) == 1: hyper_chunks = hyper_chunks[0] hnet_arch = misc.str_to_ints(gc('hnet_arch')) sa_hnet_filters = misc.str_to_ints(gc('sa_hnet_filters')) sa_hnet_kernels = misc.str_to_ints(gc('sa_hnet_kernels')) sa_hnet_attention_layers = misc.str_to_ints(gc('sa_hnet_attention_layers')) hnet_act = misc.str_to_act(gc('hnet_act')) if temb_size is None: temb_size = gc('temb_size') if isinstance(hyper_chunks, list): # Chunked self-attention hypernet raise NotImplementedError('Not publicly available') elif hyper_chunks != -1: # Chunked fully-connected hypernet hnet = ChunkedHyperNetworkHandler(mnet_shapes, num_tasks, chunk_dim=hyper_chunks, layers=hnet_arch, activation_fn=hnet_act, te_dim=temb_size, no_te_embs=no_tembs, ce_dim=gc('emb_size'), dropout_rate=gc('hnet_dropout_rate'), noise_dim=gc('hnet_noise_dim'), no_weights=no_weights, temb_std=gc('temb_std')).to(device) else: # Fully-connected hypernet. hnet = HyperNetwork(mnet_shapes, num_tasks, layers=hnet_arch, te_dim=temb_size, no_te_embs=no_tembs, activation_fn=hnet_act, dropout_rate=gc('hnet_dropout_rate'), noise_dim=gc('hnet_noise_dim'), no_weights=no_weights, temb_std=gc('temb_std')).to(device) return hnet
def run(ref_module, results_dir='./out/random_seeds', config=None, ignore_kwds=None, forced_params=None): """Run the script. Args: ref_module (str): Name of the reference module which contains the hyperparameter search config that can be modified to gather random seeds. results_dir (str, optional): The path where to store the results. config: The Namespace object containing argument names and values. If provided, all random seeds will be gathered from zero, with no reference run. ignore_kwds (list, optional): The list of keywords in the config file to exclude from the grid. forced_params (dict, optional): Dict of key-value pairs specifying hyperparameter values that should be fixed across runs """ if ignore_kwds is None: ignore_kwds = [] if forced_params is None: forced_params = {} ### Parse the command-line arguments. parser = argparse.ArgumentParser(description= \ 'Gathering random seeds for the specified experiment.') parser.add_argument('--out_dir', type=str, default='', help='The output directory of the run or runs. ' + 'For single runs, the configuration will be ' + 'loaded and run with different seeds.' + 'For multiple runs, i.e. results of ' + 'hyperparameter searches, the configuration ' + 'leading to the best mean final accuracy ' + 'will be selected and run with different seeds. ' + 'Default: %(default)s.') parser.add_argument('--config_name', type=str, default='hpsearch_random_seeds.py', help='The name of the hpsearch config file. Since ' + 'multiple random seed gathering experiments ' + 'might be running in parallel, it is important ' + 'that this file has a unique name for each ' + 'experiment. Default: %(default)s.') parser.add_argument('--config_pickle', type=str, default='', help='The path to a pickle file containing a run ' + ' config that will be loaded.') parser.add_argument('--num_seeds', type=int, default=10, help='The number of different random seeds.') # FIXME `None` is not a valid default value. parser.add_argument('--seeds_list', type=str, default=None, help='The list of seeds to use. If specified, ' + '"num_seeds" will be ignored.') parser.add_argument('--vary_data_seed', action='store_true', help='If activated, "data_random_seed"s are set ' + 'equal to "random_seed"s. Otherwise only ' + '"random_seed"s are varied.') parser.add_argument('--num_tot_hours', type=int, metavar='N', default=120, help='If "run_cluster" is activated, then this ' + 'option determines the maximum number of hours ' + 'the entire search may run on the cluster. ' + 'Default: %(default)s.') # FIXME Arguments below are copied from hpsearch. parser.add_argument('--run_cluster', action='store_true', help='This option would produce jobs for a GPU ' + 'cluser running a job scheduler (see option ' + '"scheduler".') parser.add_argument('--scheduler', type=str, default='lsf', choices=['lsf', 'slurm'], help='The job scheduler used on the cluster. ' + 'Default: %(default)s.') parser.add_argument('--num_jobs', type=int, metavar='N', default=8, help='If "run_cluster" is activated, then this ' + 'option determines the maximum number of jobs ' + 'that can be submitted in parallel. ' + 'Default: %(default)s.') parser.add_argument('--num_hours', type=int, metavar='N', default=24, help='If "run_cluster" is activated, then this ' + 'option determines the maximum number of hours ' + 'a job may run on the cluster. 
' + 'Default: %(default)s.') parser.add_argument('--resources', type=str, default='"rusage[mem=8000, ngpus_excl_p=1]"', help='If "run_cluster" is activated and "scheduler" ' + 'is "lsf", then this option determines the ' + 'resources assigned to job in the ' + 'hyperparameter search (option -R of bsub). ' + 'Default: %(default)s.') parser.add_argument('--slurm_mem', type=str, default='8G', help='If "run_cluster" is activated and "scheduler" ' + 'is "slurm", then this value will be passed as ' + 'argument "mem" of "sbatch". An empty string ' + 'means that "mem" will not be specified. ' + 'Default: %(default)s.') parser.add_argument('--slurm_gres', type=str, default='gpu:1', help='If "run_cluster" is activated and "scheduler" ' + 'is "slurm", then this value will be passed as ' + 'argument "gres" of "sbatch". An empty string ' + 'means that "gres" will not be specified. ' + 'Default: %(default)s.') parser.add_argument('--slurm_partition', type=str, default='', help='If "run_cluster" is activated and "scheduler" ' + 'is "slurm", then this value will be passed as ' + 'argument "partition" of "sbatch". An empty ' + 'string means that "partition" will not be ' + 'specified. Default: %(default)s.') parser.add_argument('--slurm_qos', type=str, default='', help='If "run_cluster" is activated and "scheduler" ' + 'is "slurm", then this value will be passed as ' + 'argument "qos" of "sbatch". An empty string ' + 'means that "qos" will not be specified. ' + 'Default: %(default)s.') parser.add_argument('--slurm_constraint', type=str, default='', help='If "run_cluster" is activated and "scheduler" ' + 'is "slurm", then this value will be passed as ' + 'argument "constraint" of "sbatch". An empty ' + 'string means that "constraint" will not be ' + 'specified. Default: %(default)s.') parser.add_argument('--visible_gpus', type=str, default='', help='If "run_cluster" is NOT activated, then this ' + 'option determines the CUDA devices visible to ' + 'the hyperparameter search. A string of comma ' + 'separated integers is expected. If the list is ' + 'empty, then all GPUs of the machine are used. ' + 'The relative memory usage is specified, i.e., ' + 'a number between 0 and 1. If "-1" is given, ' + 'the jobs will be executed sequentially and not ' + 'assigned to a particular GPU. ' + 'Default: %(default)s.') parser.add_argument('--allowed_load', type=float, default=0.5, help='If "run_cluster" is NOT activated, then this ' + 'option determines the maximum load a GPU may ' + 'have such that another process may start on ' + 'it. The relative load is specified, i.e., a ' + 'number between 0 and 1. Default: %(default)s.') parser.add_argument('--allowed_memory', type=float, default=0.5, help='If "run_cluster" is NOT activated, then this ' + 'option determines the maximum memory usage a ' + 'GPU may have such that another process may ' + 'start on it. Default: %(default)s.') parser.add_argument('--sim_startup_time', type=int, metavar='N', default=60, help='If "run_cluster" is NOT activated, then this ' + 'option determines the startup time of ' + 'simulations. If a job was assigned to a GPU, ' + 'then this time (in seconds) has to pass before ' + 'options "allowed_load" and "allowed_memory" ' + 'are checked to decide whether a new process ' + 'can be send to a GPU.Default: %(default)s.') parser.add_argument('--max_num_jobs_per_gpu', type=int, metavar='N', default=1, help='If "run_cluster" is NOT activated, then this ' + 'option determines the maximum number of jobs ' + 'per GPU that can be submitted in parallel. 
' + 'Note, this script does not validate whether ' + 'other processes are already assigned to a GPU. ' + 'Default: %(default)s.') cmd_args = parser.parse_args() out_dir = cmd_args.out_dir if cmd_args.out_dir == '' and cmd_args.config_pickle != '': with open(cmd_args.config_pickle, "rb") as f: config = pickle.load(f) # Either a config or an experiment folder need to be provided. assert config is not None or cmd_args.out_dir != '' if cmd_args.out_dir == '': out_dir = config.out_dir # Make sure that the provided hpsearch config file name does not exist. config_name = cmd_args.config_name if config_name[-3:] != '.py': config_name = config_name + '.py' if os.path.exists(config_name): overwrite = input('The config file "%s" '% config_name + \ 'already exists! Do you want to overwrite the file? [y/n] ') if not overwrite in ['yes', 'y', 'Y']: exit() # The following ensures that we can safely use `basename` later on. out_dir = os.path.normpath(out_dir) ### Create directory for results. if not os.path.exists(results_dir): os.makedirs(results_dir) # Define a subfolder for the current random seed runs. results_dir = os.path.join(results_dir, os.path.basename(out_dir)) print('Random seeds will be gathered in folder %s.' % results_dir) if os.path.exists(results_dir): # If random seeds have been gathered already, simply get the results for # publication. write_seeds_summary(results_dir) raise RuntimeError('Output directory %s already exists! ' %results_dir+\ 'seems like random seeds already have been gathered.') ### Get the experiments config. num_seeds = cmd_args.num_seeds if config is None: # Check if the current directory corresponds to a single run or not. # FIXME quick and dirty solution to figure out, whether it's a single # run. single_run = False if not os.path.exists(os.path.join(out_dir, 'search_results.csv')) \ and not os.path.exists(os.path.join(out_dir, \ 'postprocessing_results.csv')): single_run = True # Get the configuration. if single_run: config = get_single_run_config(out_dir) best_out_dir = out_dir else: config, best_out_dir = get_hpsearch_config(out_dir) # Since we already have a reference run, we can run one seed less. num_seeds -= 1 if cmd_args.seeds_list is not None: seeds_list = misc.str_to_ints(cmd_args.seeds_list) cmd_args.num_seeds = len(seeds_list) else: seeds_list = list(range(num_seeds)) # Replace config values provided via `forced_params`. if len(forced_params.keys()) > 0: for kwd, value in forced_params.items(): setattr(config, kwd, value) ### Write down the hp search grid module in its own file. ref_module_basename = ref_module[[i for i,e in \ enumerate(ref_module) if e == '.'][-1]+1:] ref_module_path = ref_module[:[i for i,e in \ enumerate(ref_module) if e == '.'][-1]+1] shutil.copy(ref_module_basename + '.py', config_name) # Define the kwds to be added to the grid. kwds = list(vars(config).keys()) for kwd in ignore_kwds: if kwd in kwds: kwds.remove(kwd) # Remove old grid and write new grid, and remove conditions. grid_loc = delete_object_from_text(config_name, 'grid', '{', '}') random_seeds = write_new_grid_to_text(config_name, config, grid_loc, \ seeds_list, cmd_args, kwds=kwds) cond_loc = delete_object_from_text(config_name, 'conditions', \ '[', ']') write_new_conditions_to_text(config_name, cond_loc, random_seeds, cmd_args) ### Run the hpsearch code with different random seeds. 
hpsearch_module = ref_module_path + config_name[:-3] cmd_str = get_command_line(hpsearch_module, results_dir, cmd_args) print(cmd_str) if cmd_args.run_cluster and cmd_args.scheduler == 'slurm': # FIXME hacky solution to write SLURM job script. # FIXME might be wrong to give the same `slurm_qos` to the hpsearch, # as the job might have to run much longer. job_script_fn = hpsearch._write_slurm_script( Namespace( **{ 'num_hours': cmd_args.num_tot_hours, 'slurm_mem': '8G', 'slurm_gres': '', 'slurm_partition': cmd_args.slurm_partition, 'slurm_qos': cmd_args.slurm_qos, 'slurm_constraint': cmd_args.slurm_constraint, }), cmd_str, 'random_seeds') cmd_str = 'sbatch %s' % job_script_fn print('We will execute command "%s".' % cmd_str) # Execute the program. print('Starting gathering random seeds...') ret = call(cmd_str, shell=True, executable='/bin/bash') print('Call finished with return code %d.' % ret) ### Add results of the reference run to our results folder. new_best_out_dir = os.path.join(results_dir, os.path.basename(out_dir)) copy_tree(best_out_dir, new_best_out_dir) ### Store results of given run in CSV file. # FIXME Extremely ugly solution. imported_grid_module = importlib.import_module(hpsearch_module) hpsearch._read_config(imported_grid_module) results_file = os.path.join(results_dir, 'search_results.csv') cmd_dict = dict() for k in kwds: cmd_dict[k] = getattr(config, k) # Get training results. performance_dict = hpsearch._SUMMARY_PARSER_HANDLE(new_best_out_dir, -1) for k, v in performance_dict.items(): cmd_dict[k] = v # Create or update the CSV file summarizing all runs. panda_frame = pd.DataFrame.from_dict(cmd_dict) if os.path.isfile(results_file): old_frame = pd.read_csv(results_file, sep=';') panda_frame = pd.concat([old_frame, panda_frame], sort=True) panda_frame.to_csv(results_file, sep=';', index=False) # Create a text file aggregating all results for publication. write_seeds_summary(results_dir)
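# The CSV bookkeeping above follows a simple create-or-extend pattern: build a
# DataFrame from the run's key/value summary, append it after any existing
# rows, and rewrite the ';'-separated file. Below is a minimal, self-contained
# sketch of that pattern; the helper name, file name and keys are placeholders
# and not part of this script.
import os
import pandas as pd

def _append_run_to_csv(results_file, run_summary):
    """Append one run's summary (a dict of single-element lists) to a
    ';'-separated CSV file, creating the file if it does not exist."""
    frame = pd.DataFrame.from_dict(run_summary)
    if os.path.isfile(results_file):
        old_frame = pd.read_csv(results_file, sep=';')
        frame = pd.concat([old_frame, frame], sort=True)
    frame.to_csv(results_file, sep=';', index=False)

# Example usage (illustrative values):
# _append_run_to_csv('my_results.csv', {'lr': [0.001], 'acc': [97.2]})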
def generate_replay_networks(config, data_handlers, device, create_rp_hnet=True, only_train_replay=False): """Create a replay model that consists of either an encoder/decoder or a discriminator/generator pair. Additionally, this method manages the creation of a hypernetwork for the generator/decoder. Following important configurations will be determined in order to create the replay model: * in- and output and hidden layer dimensions of the encoder/decoder. * architecture, chunk- and task-embedding details of the decoder's hypernetwork. .. note:: This module also handles the initialisation of the weights of either the replay networks or the decoder's hypernetwork. This will change in the near future. Args: config: Command-line arguments. data_handlers: List of data handlers, one for each task. Needed to extract the number of inputs/outputs of the main network and to infer the number of tasks. device: Torch device. create_rp_hnet: Whether a hypernetwork for the replay model should be constructed. If not, the decoder/generator will have trainable weights on its own. only_train_replay: We normally do not train on the last task since we do not need to replay this last task's data. But if we want a replay method to be able to generate data from all tasks, then this option should be set to ``True``. Returns: (tuple): Tuple containing: - **enc**: Encoder/discriminator network instance. - **dec**: Decoder/generator network instance. - **dec_hnet**: Hypernetwork instance for the decoder/generator. This return value is ``None`` if no hypernetwork should be constructed. """ if config.replay_method == 'gan': n_out = 1 else: n_out = config.latent_dim * 2 n_in = data_handlers[0].in_shape[0] pd = config.padding * 2 if config.experiment == "splitMNIST": n_in = n_in * n_in else: # permutedMNIST n_in = (n_in + pd) * (n_in + pd) config.input_dim = n_in if config.experiment == "splitMNIST": if config.single_class_replay: config.out_dim = 1 else: config.out_dim = 2 else: # permutedMNIST config.out_dim = 10 if config.infer_task_id: # task inference network config.out_dim = 1 # Build the encoder. print('For the replay encoder/discriminator: ') enc_arch = misc.str_to_ints(config.enc_fc_arch) enc = MLP(n_in=n_in, n_out=n_out, hidden_layers=enc_arch, activation_fn=misc.str_to_act(config.enc_net_act), dropout_rate=config.enc_dropout_rate, no_weights=False).to(device) print('Constructed MLP with shapes: ', enc.param_shapes) init_params = list(enc.parameters()) # Build the decoder. print('For the replay decoder/generator: ') dec_arch = misc.str_to_ints(config.dec_fc_arch) # Add dimensions for the conditional input. n_out = config.latent_dim if config.conditional_replay: n_out += config.conditional_dim dec = MLP(n_in=n_out, n_out=n_in, hidden_layers=dec_arch, activation_fn=misc.str_to_act(config.dec_net_act), use_bias=True, no_weights=config.rp_beta > 0, dropout_rate=config.dec_dropout_rate).to(device) print('Constructed MLP with shapes: ', dec.param_shapes) config.num_weights_enc = \ MainNetInterface.shapes_to_num_weights(enc.param_shapes) config.num_weights_dec = \ MainNetInterface.shapes_to_num_weights(dec.param_shapes) config.num_weights_rp_net = config.num_weights_enc + config.num_weights_dec # We normally do not need a replay model for the last task, unless we # explicitly want to train on it (`only_train_replay`). if only_train_replay: subtr = 0 else: subtr = 1 num_embeddings = config.num_tasks - subtr if config.num_tasks > 1 else 1 if config.single_class_replay: # We do not need a replay model for the last task. if config.num_tasks > 1: num_embeddings = config.out_dim * (config.num_tasks - subtr) else: 
num_embeddings = config.out_dim * (config.num_tasks) config.num_embeddings = num_embeddings # build decoder hnet if create_rp_hnet: print('For the decoder/generator hypernetwork: ') d_hnet = sim_utils.get_hnet_model(config, num_embeddings, device, dec.hyper_shapes_learned, cprefix='rp_') init_params += list(d_hnet.parameters()) config.num_weights_rp_hyper_net = sum(p.numel() for p in d_hnet.parameters() if p.requires_grad) config.compression_ratio_rp = config.num_weights_rp_hyper_net / \ config.num_weights_dec print('Created replay hypernetwork with ratio: ', config.compression_ratio_rp) if config.compression_ratio_rp > 1: print('Note that the compression ratio is computed compared to ' + 'current target network,\nthis might not be directly ' + 'comparable with the number of parameters of methods we ' + 'compare against.') else: num_embeddings = config.num_tasks - subtr d_hnet = None init_params += list(dec.parameters()) config.num_weights_rp_hyper_net = 0 config.compression_ratio_rp = 0 ### Initialize network weights. for W in init_params: if W.ndimension() == 1: # Bias vector. torch.nn.init.constant_(W, 0) else: torch.nn.init.xavier_uniform_(W) # The task embeddings are initialized differently. if create_rp_hnet: for temb in d_hnet.get_task_embs(): torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb) if hasattr(d_hnet, 'chunk_embeddings'): for emb in d_hnet.chunk_embeddings: torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb) return enc, dec, d_hnet
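# The decoder built above consumes `config.latent_dim` latent units plus, if
# `config.conditional_replay` is set, `config.conditional_dim` additional
# condition units. The sketch below only illustrates how such an input could
# be assembled, assuming the condition is a one-hot task/class code; the
# helper name and the concrete dimensions are illustrative, and in this module
# the decoder weights would additionally come from `d_hnet` when a
# hypernetwork is used.
import torch
import torch.nn.functional as F

def _assemble_decoder_input(batch_size, latent_dim, conditional_dim, cond_id):
    """Concatenate Gaussian latent samples with a one-hot condition vector."""
    z = torch.randn(batch_size, latent_dim)
    cond = F.one_hot(torch.full((batch_size,), cond_id, dtype=torch.long),
                     num_classes=conditional_dim).float()
    return torch.cat([z, cond], dim=1)

# Example: _assemble_decoder_input(32, 100, 10, cond_id=3) has shape [32, 110].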
def generate_gauss_networks(config, logger, data_handlers, device, no_mnet_weights=None, create_hnet=True, in_shape=None, out_shape=None, net_type='mlp', non_gaussian=False): """Create main network and potentially the corresponding hypernetwork. The function will first create a normal MLP and then convert it into a network with Gaussian weight distribution by using the wrapper :class:`probabilistic.gauss_mnet_interface.GaussianBNNWrapper`. This function also takes care of weight initialization. Args: config: Command-line arguments. logger: Console (and file) logger. data_handlers: List of data handlers, one for each task. Needed to extract the number of inputs/outputs of the main network. And to infer the number of tasks. device: Torch device. no_mnet_weights (bool, optional): Whether the main network should not have trainable weights. If left unspecified, then the main network will only have trainable weights if ``create_hnet`` is ``False``. create_hnet (bool): Whether a hypernetwork should be constructed. in_shape (list, optional): Input shape that is passed to function :func:`utils.sim_utils.get_mnet_model` as argument ``in_shape``. If not specified, it is set to ``[data_handlers[0].in_shape[0]]``. out_shape (list, optional): Output shape that is passed to function :func:`utils.sim_utils.get_mnet_model` as argument ``out_shape``. If not specified, it is set to ``[data_handlers[0].out_shape[0]]``. net_type (str): See argument ``net_type`` of function :func:`utils.sim_utils.get_mnet_model`. non_gaussian (bool): If ``True``, then the main network will not be converted into a Gaussian network. Hence, networks remain deterministic. Returns: (tuple): Tuple containing: - **mnet**: Main network instance. - **hnet** (optional): Hypernetwork instance. This return value is ``None`` if no hypernetwork should be constructed. """ assert not hasattr(config, 'mean_only') or config.mean_only == non_gaussian assert not non_gaussian or not config.local_reparam_trick assert not non_gaussian or not config.hyper_gauss_init num_tasks = len(data_handlers) # Should be set, except for regression. if in_shape is None or out_shape is None: assert in_shape is None and out_shape is None assert net_type == 'mlp' assert hasattr(config, 'multi_head') n_x = data_handlers[0].in_shape[0] n_y = data_handlers[0].out_shape[0] if config.multi_head: n_y = n_y * num_tasks in_shape = [n_x] out_shape = [n_y] ### Main network. logger.info('Creating main network ...') if no_mnet_weights is None: no_mnet_weights = create_hnet if config.local_reparam_trick: if net_type != 'mlp': raise NotImplementedError('The local reparametrization trick is ' + 'only implemented for MLPs so far!') assert len(in_shape) == 1 and len(out_shape) == 1 mlp_arch = utils.str_to_ints(config.mlp_arch) net_act = utils.str_to_act(config.net_act) mnet = GaussianMLP(n_in=in_shape[0], n_out=out_shape[0], hidden_layers=mlp_arch, activation_fn=net_act, use_bias=not config.no_bias, no_weights=no_mnet_weights).to(device) else: mnet_kwargs = {} if net_type == 'iresnet': mnet_kwargs['cutout_mod'] = True mnet = sutils.get_mnet_model(config, net_type, in_shape, out_shape, device, no_weights=no_mnet_weights, **mnet_kwargs) # Initiaize main net weights, if any. assert (not hasattr(config, 'custom_network_init')) mnet.custom_init(normal_init=config.normal_init, normal_std=config.std_normal_init, zero_bias=True) # Convert main net into Gaussian BNN. 
orig_mnet = mnet if not non_gaussian: mnet = GaussianBNNWrapper(mnet, no_mean_reinit=config.keep_orig_init, logvar_encoding=config.use_logvar_enc, apply_rho_offset=True, is_radial=config.radial_bnn).to(device) else: logger.debug('Created main network will not be converted into a ' + 'Gaussian main network.') ### Hypernet. hnet = None if create_hnet: logger.info('Creating hypernetwork ...') chunk_shapes, num_per_chunk, assembly_fct = None, None, None if config.hnet_type == 'structured_hmlp': if net_type == 'resnet': chunk_shapes, num_per_chunk, orig_assembly_fct = \ resnet_chunking(orig_mnet, gcd_chunking=config.shmlp_gcd_chunking) elif net_type == 'wrn': chunk_shapes, num_per_chunk, orig_assembly_fct = \ wrn_chunking(orig_mnet, gcd_chunking=config.shmlp_gcd_chunking, ignore_bn_weights=False, ignore_out_weights=False) else: raise NotImplementedError( '"structured_hmlp" not implemented ' + 'for network of type %s.' % net_type) if non_gaussian: assembly_fct = orig_assembly_fct else: chunk_shapes = chunk_shapes + chunk_shapes num_per_chunk = num_per_chunk + num_per_chunk def assembly_fct_gauss(list_of_chunks): n = len(list_of_chunks) mean_chunks = list_of_chunks[:n // 2] rho_chunks = list_of_chunks[n // 2:] return orig_assembly_fct(mean_chunks) + \ orig_assembly_fct(rho_chunks) assembly_fct = assembly_fct_gauss # For now, we either produce all or no weights with the hypernet. # Note, it can be that the mnet was produced with internal weights. assert mnet.hyper_shapes_learned is None or \ len(mnet.param_shapes) == len(mnet.hyper_shapes_learned) hnet = sutils.get_hypernet(config, device, config.hnet_type, mnet.param_shapes, num_tasks, shmlp_chunk_shapes=chunk_shapes, shmlp_num_per_chunk=num_per_chunk, shmlp_assembly_fct=assembly_fct) if config.hnet_out_masking != 0: logger.info('Generating binary masks to select task-specific ' + 'subnetworks from hypernetwork.') # Add a wrapper around the hypernpetwork that masks its outputs # using a task-specific binary mask layer per layer. Note that # output weights are not masked. # Ensure that masks are kind of deterministic for a given hyper- # param config/task. mask_gen = torch.Generator() mask_gen = mask_gen.manual_seed(42) # Generate a random binary mask per task. assert len(mnet.param_shapes) == len(hnet.target_shapes) hnet_out_masks = [] for tid in range(config.num_tasks): hnet_out_mask = [] for layer_shapes, is_output in zip(mnet.param_shapes, \ mnet.get_output_weight_mask()): layer_mask = torch.ones(layer_shapes) if is_output is None: # We only mask weights that are not output weights. layer_mask = torch.rand(layer_shapes, generator=mask_gen) layer_mask[layer_mask > config.hnet_out_masking] = 1 layer_mask[layer_mask <= config.hnet_out_masking] = 0 hnet_out_mask.append(layer_mask) hnet_out_masks.append(hnet_out_mask) hnet_out_masks = hnet.convert_out_format(hnet_out_masks, 'sequential', 'flattened') def hnet_out_masking_func(hnet_out_int, uncond_input=None, cond_input=None, cond_id=None): assert isinstance(cond_id, (int, list)) if isinstance(cond_id, int): cond_id = [cond_id] hnet_out_int[hnet_out_masks[cond_id, :] == 0] = 0 return hnet_out_int def hnet_inp_handler(uncond_input=None, cond_input=None, cond_id=None): # Identity return uncond_input, cond_input, cond_id hnet = HPerturbWrapper(hnet, output_handler=hnet_out_masking_func, input_handler=hnet_inp_handler) #if config.hnet_type == 'structured_hmlp': # print(num_per_chunk) # for ii, int_hnet in enumerate(hnet.internal_hnets): # print(' Internal hnet %d with %d outputs.' 
% \ # (ii, int_hnet.num_outputs)) ### Initialize hypernetwork. if not config.hyper_gauss_init: apply_custom_hnet_init(config, logger, hnet) else: # Initialize task embeddings, if any. hnet_helpers.init_conditional_embeddings( hnet, normal_std=config.std_normal_temb) gauss_hyperfan_init(hnet, mnet=mnet, use_xavier=True, cond_var=config.std_normal_temb**2, keep_hyperfan_mean=config.keep_orig_init) return mnet, hnet
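# The `hnet_out_masking` branch above selects a fixed, task-specific
# sub-network by thresholding uniform noise drawn from a seeded
# `torch.Generator`, so the same config always yields the same binary masks
# (output weights are left unmasked). Below is a minimal, self-contained
# sketch of that mask construction in isolation; the helper name, shape and
# sparsity value are illustrative only.
import torch

def _deterministic_binary_mask(shape, sparsity, seed=42):
    """Return a 0/1 mask in which roughly `sparsity` of the entries are 0."""
    gen = torch.Generator().manual_seed(seed)
    mask = torch.rand(shape, generator=gen)
    mask[mask > sparsity] = 1
    mask[mask <= sparsity] = 0
    return mask

# Example: _deterministic_binary_mask([256, 128], 0.2) zeroes about 20% of the
# entries, and repeated calls with the same seed return the identical mask.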
def run(grid_module=None, results_dir='./out/random_seeds', config=None, ignore_kwds=None, forced_params=None, summary_keys=None, summary_sem=False, summary_precs=None, hpmod_path=None): """Run the script. Args: grid_module (str, optional): Name of the reference module which contains the hyperparameter search config that can be modified to gather random seeds. results_dir (str, optional): The path where the hpsearch should store its results. config: The Namespace object containing argument names and values. If provided, all random seeds will be gathered from zero, with no reference run. ignore_kwds (list, optional): A list of keywords in the config file to exclude from the grid. forced_params (dict, optional): Dict of key-value pairs specifying hyperparameter values that should be fixed across runs. summary_keys (list, optional): If provided, the mean and std of those summary keys will be written by function :func:`write_seeds_summary`. Otherwise, the performance key defined in ``grid_module`` will be used. summary_sem (bool): Whether SEM or SD should be calculated in function :func:`write_seeds_summary`. summary_precs (list or int, optional): The precision with which the summary statistics according to ``summary_keys`` should be listed. hpmod_path (str, optional): If the hpsearch doesn't reside in the same directory as the calling script, then we need to know from where to start the hpsearch. """ if ignore_kwds is None: ignore_kwds = [] if forced_params is None: forced_params = {} ### Parse the command-line arguments. parser = argparse.ArgumentParser(description= \ 'Gathering random seeds for the specified experiment.') parser.add_argument('--seeds_dir', type=str, default='', help='If provided, all other arguments (except ' + '"grid_module") are ignored! ' + 'This is supposed to be the output folder of a ' + 'random seed gathering experiment. In this case, ' + 'the results (for different seeds) within this ' + 'directory are gathered and written to a human-' + 'readable text file.') parser.add_argument('--run_dir', type=str, default='', help='The output directory of a simulation or a ' + 'hyperparameter search. ' + 'For single runs, the configuration will be ' + 'loaded and run with different seeds. ' + 'For multiple runs, i.e., results of ' + 'hyperparameter searches, the configuration ' + 'leading to the best performance will be ' + 'selected and run with different seeds.') parser.add_argument('--config_name', type=str, default='hpsearch_random_seeds', help='A name for this call of gathering random ' + 'seeds. As multiple gatherings might be running ' + 'in parallel, it is important that this name is ' + 'unique for each experiment. ' + 'Default: %(default)s.') parser.add_argument('--grid_module', type=str, default=grid_module, help='See CLI argument "grid_module" of ' + 'hyperparameter search script "hpsearch". ' + ('Default: %(default)s.' \ if grid_module is not None else '')) parser.add_argument('--num_seeds', type=int, default=10, help='The number of different random seeds.') parser.add_argument('--seeds_list', type=str, default='', help='The list of seeds to use. If specified, ' + '"num_seeds" will be ignored.') parser.add_argument('--vary_data_seed', action='store_true', help='If activated, "data_random_seed"s are set ' + 'equal to "random_seed"s. Otherwise only ' + '"random_seed"s are varied.') parser.add_argument('--start_gathering', action='store_true', help='If activated, the actual gathering of random ' + 'seeds is started via the "hpsearch.py" script.') # Arguments only required if `start_gathering`. hpgroup = parser.add_argument_group('Hpsearch call options') hpgroup.add_argument('--hps_num_hours', type=int, metavar='N', default=24, help='If "run_cluster" is activated, then this ' + 'option determines the maximum number of hours ' + 'the entire search may run on the cluster. ' + 'Default: %(default)s.') hpgroup.add_argument( '--hps_resources', type=str, default='"rusage[mem=8000]"', help='If "run_cluster" is activated and "scheduler" ' + 'is "lsf", then this option determines the ' + 'resources assigned to the entire ' + 'hyperparameter search (option -R of bsub). ' + 'Default: %(default)s.') hpgroup.add_argument('--hps_slurm_mem', type=str, default='8G', help='See option "slurm_mem". This argument affects ' + 'the hyperparameter search itself. ' + 'Default: %(default)s.') rsgroup = parser.add_argument_group('Random seed hpsearch options') hpsearch.hpsearch_cli_arguments(rsgroup, show_out_dir=False, show_grid_module=False) cmd_args = parser.parse_args() grid_module = cmd_args.grid_module if grid_module is None: raise ValueError('"grid_module" needs to be specified.') grid_module = importlib.import_module(grid_module) hpsearch._read_config(grid_module, require_perf_eval_handle=True) if summary_keys is None: summary_keys = [hpsearch._PERFORMANCE_KEY] #################################################### ### Aggregate results of random seed experiments ### #################################################### if len(cmd_args.seeds_dir): print('Writing seed summary ...') write_seeds_summary(cmd_args.seeds_dir, summary_keys, summary_sem, summary_precs) exit(0) ####################################################### ### Create hp config grid for random seed gathering ### ####################################################### if len(cmd_args.seeds_list) > 0: seeds_list = misc.str_to_ints(cmd_args.seeds_list) cmd_args.num_seeds = len(seeds_list) else: seeds_list = list(range(cmd_args.num_seeds)) if config is not None and cmd_args.run_dir != '': raise ValueError('"run_dir" may not be specified if a configuration ' + 'is provided directly.') # The directory in which the hpsearch results should be written. Will only # be specified if the `config` is read from a finished simulation. hpsearch_dir = None # Get config if not provided. if config is None: if not os.path.exists(cmd_args.run_dir): raise_error = True # FIXME hacky solution. if cmd_args.run_cwd != '': tmp_dir = os.path.join(cmd_args.run_cwd, cmd_args.run_dir) if os.path.exists(tmp_dir): cmd_args.run_dir = tmp_dir raise_error = False if raise_error: raise ValueError('Directory "%s" does not exist!' % \ cmd_args.run_dir) # FIXME A bit of a shady decision. single_run = False if os.path.exists(os.path.join(cmd_args.run_dir, 'config.pickle')): single_run = True # Get the configuration. if single_run: config = get_single_run_config(cmd_args.run_dir) run_dir = cmd_args.run_dir else: config, run_dir = get_best_hpsearch_config(cmd_args.run_dir) # We should already have one random seed. try: performance_dict = hpsearch._SUMMARY_PARSER_HANDLE(run_dir, -1) has_finished = int(performance_dict['finished'][0]) if not has_finished: raise Exception() use_run = True except: use_run = False if use_run: # The following ensures that we can safely use `basename` later on. 
run_dir = os.path.normpath(run_dir) if not os.path.isabs(results_dir): if os.path.isdir(cmd_args.run_cwd): results_dir = os.path.join(cmd_args.run_cwd, results_dir) results_dir = os.path.abspath(results_dir) hpsearch_dir = os.path.join(results_dir, os.path.basename(run_dir)) if os.path.exists(hpsearch_dir): # TODO attempt to write summary and exclude existing seeds. warn('Folder "%s" already exists.' % hpsearch_dir) print('Attempting to aggregate random seed results ...') gathered_seeds = write_seeds_summary(hpsearch_dir, summary_keys, summary_sem, summary_precs, ret_seeds=True) if len(gathered_seeds) >= len(seeds_list): print('Already enough seeds have been gathered!') exit(0) for gs in gathered_seeds: if gs in seeds_list: seeds_list.remove(gs) else: ignored_seed = seeds_list.pop() if len(cmd_args.seeds_list) > 0: print('Seed %d is ignored as seed %d already ' \ % (ignored_seed, gs) + 'exists.') else: os.makedirs(hpsearch_dir) # We utilize the already existing random seed. shutil.copytree( run_dir, os.path.join(hpsearch_dir, os.path.basename(run_dir))) if config.random_seed in seeds_list: seeds_list.remove(config.random_seed) else: ignored_seed = seeds_list.pop() if len(cmd_args.seeds_list) > 0: print('Seed %d is ignored as seed %d already exists.' \ % (ignored_seed, config.random_seed)) print('%d random seeds will be gathered!' % len(seeds_list)) ### Which attributes of the `config` should be ignored? # We never set the ouput directory. if hpsearch._OUT_ARG not in ignore_kwds: ignore_kwds.append(hpsearch._OUT_ARG) for kwd in ignore_kwds: delattr(config, kwd) ### Replace config values provided via `forced_params`. if len(forced_params.keys()) > 0: for kwd, value in forced_params.items(): setattr(config, kwd, value) ### Get a filename for where to store the search grid. config_dn, config_bn = os.path.split(cmd_args.config_name) if len(config_dn) == 0: # No relative path given, store only temporary. config_dn = tempfile.gettempdir() else: config_dn = os.path.abspath(config_dn) config_fn_prefix = os.path.splitext(config_bn)[0] config_name = os.path.join(config_dn, config_fn_prefix + '.pickle') if os.path.exists(config_name): if len(config_dn) > 0: overwrite = input('The config file "%s" ' % config_name + \ 'already exists! Do you want to overwrite the file? [y/n] ') if not overwrite in ['yes', 'y', 'Y']: exit(1) else: # Get random temporary filename. config_name_temp = tempfile.NamedTemporaryFile( \ prefix=config_fn_prefix, suffix=".pickle") print('Search grid "%s" already exists, using name "%s" instead!' \ % (config_name, config_name_temp.name)) config_name = config_name_temp.name config_name_temp.close() ### Build and store hpconfig for random seed gathering! grid, conditions = build_grid_and_conditions(cmd_args, config, seeds_list) rseed_config = {'grid': grid, 'conditions': conditions} with open(config_name, 'wb') as f: pickle.dump(rseed_config, f) ### Gather random seeds. if cmd_args.start_gathering: cmd_str = get_hpsearch_call(cmd_args, len(seeds_list), config_name, hpsearch_dir=hpsearch_dir) print(cmd_str) ### Start hpsearch. if hpmod_path is not None: backup_curr_path = os.getcwd() os.chdir(hpmod_path) if cmd_args.run_cluster and cmd_args.scheduler == 'slurm': # FIXME hacky solution to write SLURM job script. # FIXME might be wrong to give the same `slurm_qos` to the hpsearch, # as the job might have to run much longer. 
job_script_fn = hpsearch._write_slurm_script( Namespace( **{ 'num_hours': cmd_args.hps_num_hours, 'slurm_mem': cmd_args.hps_slurm_mem, 'slurm_gres': 'gpu:0', 'slurm_partition': cmd_args.slurm_partition, 'slurm_qos': cmd_args.slurm_qos, 'slurm_constraint': cmd_args.slurm_constraint, }), cmd_str, 'random_seeds') cmd_str = 'sbatch %s' % job_script_fn print('We will execute command "%s".' % cmd_str) # Execute the program. print('Starting gathering random seeds...') ret = call(cmd_str, shell=True, executable='/bin/bash') print('Call finished with return code %d.' % ret) if hpmod_path is not None: os.chdir(backup_curr_path) # If we run the hpsearch on the cluster, then we just submitted a job # and the search didn't actually run yet. if not cmd_args.run_cluster and hpsearch_dir is not None: write_seeds_summary(hpsearch_dir, summary_keys, summary_sem, summary_precs) print('Random seed gathering finished successfully!') exit(0) ### Random seeds not gathered yet - finalize program. print(hpsearch_dir is None) if hpsearch_dir is not None: print('IMPORTANT: At least one random seed has already been ' + \ 'gathered! Please ensure that the hpsearch forces the correct ' + 'output path.') print('Below is a possible hpsearch call:') call_appendix = '' if hpsearch_dir is not None: call_appendix = '--force_out_dir --dont_force_new_dir ' + \ '--out_dir=%s' % hpsearch_dir print() print('python3 hpsearch.py --grid_module=%s --grid_config=%s %s' % \ (cmd_args.grid_module, config_name, call_appendix)) print() # We print the individual paths to allow easy parsing via `awk` and `xargs`. if hpsearch_dir is None: print('Below is the "grid_module" name and the path to the ' + '"grid_config".') print(cmd_args.grid_module, config_name) else: print( 'Below is the "grid_module" name, the path to the ' + '"grid_config" and the output path that should be used for the ' + 'hpsearch.') print(cmd_args.grid_module, config_name, hpsearch_dir)
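# `run()` hands the seed-gathering search space to the hpsearch as a pickled
# dict with exactly the two keys used above: 'grid' and 'conditions'. The
# sketch below only illustrates that file format with made-up values; the real
# content is produced by `build_grid_and_conditions` (presumably a mapping
# from argument names to value lists, where only the random seed entries hold
# more than one element). The resulting path could then be passed to the
# printed `--grid_config` argument of `hpsearch.py`.
import pickle
import tempfile

def _write_example_rseed_config():
    """Write an illustrative 'grid'/'conditions' pickle and return its path."""
    rseed_config = {
        'grid': {'lr': [0.001], 'batch_size': [32], 'random_seed': [0, 1, 2, 3]},
        'conditions': [],
    }
    with tempfile.NamedTemporaryFile(suffix='.pickle', delete=False) as f:
        pickle.dump(rseed_config, f)
        return f.name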