def generate_classifier(config, data_handlers, device):
    """Create a classifier network.

    Depending on the experiment and method, either a classifier for task
    inference or a classifier that solves the actual task is built. This also
    determines whether the network receives its weights from a hypernetwork
    or maintains weights of its own.

    The following configurations are determined in order to create the
    classifier:

    * Input, output and hidden layer dimensions of the classifier.
    * Architecture, chunk- and task-embedding details of the hypernetwork.

    See :class:`mnets.mlp.MLP` for details on the network that will be
    created to be a classifier.

    .. note::
        This module also handles the initialisation of the weights of either
        the classifier or its hypernetwork. This will change in the near
        future.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network and to
            infer the number of tasks.
        device: Torch device.

    Returns:
        (tuple): Tuple containing:

        - **net**: The classifier network.
        - **class_hnet**: (optional) The classifier's hypernetwork.
    """
    n_in = data_handlers[0].in_shape[0]
    pd = config.padding * 2

    if config.experiment == "splitMNIST":
        n_in = n_in * n_in
    else:  # permutedMNIST
        n_in = (n_in + pd) * (n_in + pd)

    config.input_dim = n_in
    if config.experiment == "splitMNIST":
        if config.class_incremental:
            config.out_dim = 1
        else:
            config.out_dim = 2
    else:  # permutedMNIST
        config.out_dim = 10

    if config.training_task_infer or config.class_incremental:
        # Task inference network.
        config.out_dim = 1

    # Have all output neurons already built up for CL scenario 2.
    if config.cl_scenario != 2:
        n_out = config.out_dim * config.num_tasks
    else:
        n_out = config.out_dim

    if config.training_task_infer or config.class_incremental:
        n_out = config.num_tasks

    # Build classifier.
    print('For the Classifier: ')
    class_arch = misc.str_to_ints(config.class_fc_arch)
    if config.training_with_hnet:
        no_weights = True
    else:
        no_weights = False

    net = MLP(n_in=n_in, n_out=n_out, hidden_layers=class_arch,
              activation_fn=misc.str_to_act(config.class_net_act),
              dropout_rate=config.class_dropout_rate,
              no_weights=no_weights).to(device)
    print('Constructed MLP with shapes: ', net.param_shapes)

    config.num_weights_class_net = \
        MainNetInterface.shapes_to_num_weights(net.param_shapes)

    # Build classifier hypernetwork.
    # This is set in the run method in train.py.
    if config.training_with_hnet:
        class_hnet = sim_utils.get_hnet_model(config, config.num_tasks,
                                              device, net.param_shapes,
                                              cprefix='class_')
        init_params = list(class_hnet.parameters())

        config.num_weights_class_hyper_net = sum(
            p.numel() for p in class_hnet.parameters() if p.requires_grad)
        config.compression_ratio_class = \
            config.num_weights_class_hyper_net / config.num_weights_class_net
        print('Created classifier Hypernetwork with ratio: ',
              config.compression_ratio_class)
        if config.compression_ratio_class > 1:
            print('Note that the compression ratio is computed compared to ' +
                  'the current target network; this might not be directly ' +
                  'comparable with the number of parameters of methods we ' +
                  'compare against.')
    else:
        class_hnet = None
        init_params = list(net.parameters())
        config.num_weights_class_hyper_net = None
        config.compression_ratio_class = None

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if config.training_with_hnet:
        for temb in class_hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

        if hasattr(class_hnet, 'chunk_embeddings'):
            for emb in class_hnet.chunk_embeddings:
                torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    if not config.training_with_hnet:
        return net
    else:
        return net, class_hnet
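
# ---------------------------------------------------------------------------
# Editor's usage sketch (hypothetical, not called anywhere in the repo): how
# `generate_classifier` might be invoked without a hypernetwork. The
# `Namespace` fields below are assumptions that mirror the config attributes
# read above; in the repository they are produced by the command-line parser.
def _example_generate_classifier(data_handlers, device):
    from argparse import Namespace

    config = Namespace(experiment='splitMNIST', padding=0,
                       class_incremental=False, training_task_infer=False,
                       cl_scenario=1, num_tasks=5, class_fc_arch='400,400',
                       class_net_act='relu', class_dropout_rate=-1,
                       training_with_hnet=False)
    # Without `training_with_hnet`, only the classifier itself is returned
    # and it maintains its own weights.
    net = generate_classifier(config, data_handlers, device)
    return net
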

def get_mnet_model(config, net_type, in_shape, out_shape, device,
                   cprefix=None, no_weights=False):
    """Generate a main network instance.

    A helper to generate a main network according to the given user
    configuration.

    .. note::
        Generation of networks with context-modulation is not yet supported,
        since there is no global argument set in :mod:`utils.cli_args` yet.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available
                according to the function
                :func:`utils.cli_args.main_net_args`.
        net_type (str): The type of network. The following options are
            available:

            - ``mlp``: :class:`mnets.mlp.MLP`
            - ``resnet``: :class:`mnets.resnet.ResNet`
            - ``zenke``: :class:`mnets.zenkenet.ZenkeNet`
            - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet`
        in_shape (list): Shape of network inputs. Can be ``None`` if not
            required by network type.

            For instance: For an MLP network :class:`mnets.mlp.MLP` with 100
            input neurons it should be :code:`in_shape=[100]`.
        out_shape (list): Shape of network outputs. See ``in_shape`` for more
            details.
        device: PyTorch device.
        cprefix (str, optional): A prefix of the config names. It might be
            that the config names used in this method are prefixed, since
            several main networks should be generated (e.g.,
            :code:`cprefix='gen_'` or ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.main_net_args`.
        no_weights (bool): Whether the main network should be generated
            without weights.

    Returns:
        The created main network model.
    """
    assert net_type in ['mlp', 'resnet', 'zenke', 'bio_conv_net']

    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    def hc(name):
        """Check whether config exists."""
        return hasattr(config, '%s%s' % (cprefix, name))

    mnet = None

    if hc('net_act'):
        net_act = gc('net_act')
        net_act = misc.str_to_act(net_act)
    else:
        net_act = None

    def get_val(name):
        ret = None
        if hc(name):
            ret = gc(name)
        return ret

    no_bias = get_val('no_bias')
    dropout_rate = get_val('dropout_rate')
    specnorm = get_val('specnorm')
    batchnorm = get_val('batchnorm')
    no_batchnorm = get_val('no_batchnorm')
    bn_no_running_stats = get_val('bn_no_running_stats')
    bn_distill_stats = get_val('bn_distill_stats')
    #bn_no_stats_checkpointing = get_val('bn_no_stats_checkpointing')

    use_bn = None
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm

    # FIXME If an argument wasn't specified, then we use the default value
    # that is currently (at time of implementation) in the constructor.
    assign = lambda x, y: y if x is None else x

    if net_type == 'mlp':
        assert hc('mlp_arch')
        assert len(in_shape) == 1 and len(out_shape) == 1
        mnet = MLP(n_in=in_shape[0], n_out=out_shape[0],
                   hidden_layers=misc.str_to_ints(gc('mlp_arch')),
                   activation_fn=assign(net_act, torch.nn.ReLU()),
                   use_bias=not assign(no_bias, False),
                   no_weights=no_weights,
                   #init_weights=None,
                   dropout_rate=assign(dropout_rate, -1),
                   use_spectral_norm=assign(specnorm, False),
                   use_batch_norm=assign(use_bn, False),
                   bn_track_stats=assign(not bn_no_running_stats, True),
                   distill_bn_stats=assign(bn_distill_stats, False),
                   #use_context_mod=False,
                   #context_mod_inputs=False,
                   #no_last_layer_context_mod=False,
                   #context_mod_no_weights=False,
                   #context_mod_post_activation=False,
                   #context_mod_gain_offset=False,
                   #out_fn=None,
                   verbose=True).to(device)

    elif net_type == 'resnet':
        assert len(out_shape) == 1
        mnet = ResNet(in_shape=in_shape, num_classes=out_shape[0],
                      verbose=True, #n=5,
                      no_weights=no_weights,
                      #init_weights=None,
                      use_batch_norm=assign(use_bn, True),
                      bn_track_stats=assign(not bn_no_running_stats, True),
                      distill_bn_stats=assign(bn_distill_stats, False),
                      #use_context_mod=False,
                      #context_mod_inputs=False,
                      #no_last_layer_context_mod=False,
                      #context_mod_no_weights=False,
                      #context_mod_post_activation=False,
                      #context_mod_gain_offset=False,
                      #context_mod_apply_pixel_wise=False
                      ).to(device)

    elif net_type == 'zenke':
        assert len(out_shape) == 1
        mnet = ZenkeNet(in_shape=in_shape, num_classes=out_shape[0],
                        verbose=True, #arch='cifar',
                        no_weights=no_weights,
                        #init_weights=None,
                        dropout_rate=assign(dropout_rate, 0.25)).to(device)

    else:
        assert net_type == 'bio_conv_net'
        assert len(out_shape) == 1
        raise NotImplementedError('Implementation not publicly available!')

    return mnet
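
# Editor's note (hypothetical sketch): the `assign` lambda above implements a
# simple "use the config value if it was given, otherwise fall back to the
# constructor default" rule. A minimal illustration:
def _example_assign_fallback():
    assign = lambda x, y: y if x is None else x
    dropout_rate = None                      # argument not present in config
    assert assign(dropout_rate, -1) == -1    # falls back to the default
    assert assign(0.5, -1) == 0.5            # explicit value takes precedence
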

def get_mnet_model(config, net_type, in_shape, out_shape, device,
                   cprefix=None, no_weights=False, **mnet_kwargs):
    """Generate a main network instance.

    A helper to generate a main network according to the given user
    configuration.

    .. note::
        Generation of networks with context-modulation is not yet supported,
        since there is no global argument set in :mod:`utils.cli_args` yet.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available
                according to the function
                :func:`utils.cli_args.main_net_args`.
        net_type (str): The type of network. The following options are
            available:

            - ``mlp``: :class:`mnets.mlp.MLP`
            - ``lenet``: :class:`mnets.lenet.LeNet`
            - ``resnet``: :class:`mnets.resnet.ResNet`
            - ``wrn``: :class:`mnets.wide_resnet.WRN`
            - ``iresnet``: :class:`mnets.resnet_imgnet.ResNetIN`
            - ``zenke``: :class:`mnets.zenkenet.ZenkeNet`
            - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet`
            - ``chunked_mlp``: :class:`mnets.chunk_squeezer.ChunkSqueezer`
            - ``simple_rnn``: :class:`mnets.simple_rnn.SimpleRNN`
        in_shape (list): Shape of network inputs. Can be ``None`` if not
            required by network type.

            For instance: For an MLP network :class:`mnets.mlp.MLP` with 100
            input neurons it should be :code:`in_shape=[100]`.
        out_shape (list): Shape of network outputs. See ``in_shape`` for more
            details.
        device: PyTorch device.
        cprefix (str, optional): A prefix of the config names. It might be
            that the config names used in this method are prefixed, since
            several main networks should be generated (e.g.,
            :code:`cprefix='gen_'` or ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.main_net_args`.
        no_weights (bool): Whether the main network should be generated
            without weights.
        **mnet_kwargs: Additional keyword arguments that will be passed to
            the main network constructor.

    Returns:
        The created main network model.
    """
    assert net_type in ['mlp', 'lenet', 'resnet', 'zenke', 'bio_conv_net',
                        'chunked_mlp', 'simple_rnn', 'wrn', 'iresnet']

    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    def hc(name):
        """Check whether config exists."""
        return hasattr(config, '%s%s' % (cprefix, name))

    mnet = None

    if hc('net_act'):
        net_act = gc('net_act')
        net_act = misc.str_to_act(net_act)
    else:
        net_act = None

    def get_val(name):
        ret = None
        if hc(name):
            ret = gc(name)
        return ret

    no_bias = get_val('no_bias')
    dropout_rate = get_val('dropout_rate')
    specnorm = get_val('specnorm')
    batchnorm = get_val('batchnorm')
    no_batchnorm = get_val('no_batchnorm')
    bn_no_running_stats = get_val('bn_no_running_stats')
    bn_distill_stats = get_val('bn_distill_stats')
    # This argument has to be handled during usage of the network and not
    # during construction.
    #bn_no_stats_checkpointing = get_val('bn_no_stats_checkpointing')

    use_bn = None
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm

    # If an argument wasn't specified, then we use the default value that
    # is currently in the constructor.
    assign = lambda x, y: y if x is None else x

    if net_type == 'mlp':
        assert hc('mlp_arch')
        assert len(in_shape) == 1 and len(out_shape) == 1
        # Default keyword arguments of class MLP.
        dkws = misc.get_default_args(MLP.__init__)

        mnet = MLP(n_in=in_shape[0], n_out=out_shape[0],
                   hidden_layers=misc.str_to_ints(gc('mlp_arch')),
                   activation_fn=assign(net_act, dkws['activation_fn']),
                   use_bias=assign(not no_bias, dkws['use_bias']),
                   no_weights=no_weights,
                   #init_weights=None,
                   dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
                   use_spectral_norm=assign(specnorm,
                                            dkws['use_spectral_norm']),
                   use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
                   bn_track_stats=assign(not bn_no_running_stats,
                                         dkws['bn_track_stats']),
                   distill_bn_stats=assign(bn_distill_stats,
                                           dkws['distill_bn_stats']),
                   #use_context_mod=False,
                   #context_mod_inputs=False,
                   #no_last_layer_context_mod=False,
                   #context_mod_no_weights=False,
                   #context_mod_post_activation=False,
                   #context_mod_gain_offset=False,
                   #out_fn=None,
                   verbose=True,
                   **mnet_kwargs).to(device)

    elif net_type == 'resnet':
        assert len(out_shape) == 1
        assert hc('resnet_block_depth') and hc('resnet_channel_sizes')
        # Default keyword arguments of class ResNet.
        dkws = misc.get_default_args(ResNet.__init__)

        mnet = ResNet(in_shape=in_shape, num_classes=out_shape[0],
                      n=gc('resnet_block_depth'),
                      use_bias=assign(not no_bias, dkws['use_bias']),
                      num_feature_maps=misc.str_to_ints(
                          gc('resnet_channel_sizes')),
                      verbose=True, #n=5,
                      no_weights=no_weights,
                      #init_weights=None,
                      use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
                      bn_track_stats=assign(not bn_no_running_stats,
                                            dkws['bn_track_stats']),
                      distill_bn_stats=assign(bn_distill_stats,
                                              dkws['distill_bn_stats']),
                      #use_context_mod=False,
                      #context_mod_inputs=False,
                      #no_last_layer_context_mod=False,
                      #context_mod_no_weights=False,
                      #context_mod_post_activation=False,
                      #context_mod_gain_offset=False,
                      #context_mod_apply_pixel_wise=False
                      **mnet_kwargs).to(device)

    elif net_type == 'wrn':
        assert len(out_shape) == 1
        assert hc('wrn_block_depth') and hc('wrn_widening_factor')
        # Default keyword arguments of class WRN.
        dkws = misc.get_default_args(WRN.__init__)

        mnet = WRN(in_shape=in_shape, num_classes=out_shape[0],
                   n=gc('wrn_block_depth'),
                   use_bias=assign(not no_bias, dkws['use_bias']),
                   #num_feature_maps=misc.str_to_ints(gc('wrn_channel_sizes')),
                   verbose=True,
                   no_weights=no_weights,
                   use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
                   bn_track_stats=assign(not bn_no_running_stats,
                                         dkws['bn_track_stats']),
                   distill_bn_stats=assign(bn_distill_stats,
                                           dkws['distill_bn_stats']),
                   k=gc('wrn_widening_factor'),
                   use_fc_bias=gc('wrn_use_fc_bias'),
                   dropout_rate=gc('dropout_rate'),
                   #use_context_mod=False,
                   #context_mod_inputs=False,
                   #no_last_layer_context_mod=False,
                   #context_mod_no_weights=False,
                   #context_mod_post_activation=False,
                   #context_mod_gain_offset=False,
                   #context_mod_apply_pixel_wise=False
                   **mnet_kwargs).to(device)

    elif net_type == 'iresnet':
        assert len(out_shape) == 1
        assert hc('iresnet_use_fc_bias') and hc('iresnet_channel_sizes') \
            and hc('iresnet_blocks_per_group') \
            and hc('iresnet_bottleneck_blocks') \
            and hc('iresnet_projection_shortcut')
        # Default keyword arguments of class ResNetIN.
        dkws = misc.get_default_args(ResNetIN.__init__)

        mnet = ResNetIN(in_shape=in_shape, num_classes=out_shape[0],
                        use_bias=assign(not no_bias, dkws['use_bias']),
                        use_fc_bias=gc('iresnet_use_fc_bias'),
                        num_feature_maps=misc.str_to_ints(
                            gc('iresnet_channel_sizes')),
                        blocks_per_group=misc.str_to_ints(
                            gc('iresnet_blocks_per_group')),
                        projection_shortcut=gc('iresnet_projection_shortcut'),
                        bottleneck_blocks=gc('iresnet_bottleneck_blocks'),
                        #cutout_mod=False,
                        no_weights=no_weights,
                        use_batch_norm=assign(use_bn,
                                              dkws['use_batch_norm']),
                        bn_track_stats=assign(not bn_no_running_stats,
                                              dkws['bn_track_stats']),
                        distill_bn_stats=assign(bn_distill_stats,
                                                dkws['distill_bn_stats']),
                        #chw_input_format=False,
                        verbose=True,
                        #use_context_mod=False,
                        #context_mod_inputs=False,
                        #no_last_layer_context_mod=False,
                        #context_mod_no_weights=False,
                        #context_mod_post_activation=False,
                        #context_mod_gain_offset=False,
                        #context_mod_apply_pixel_wise=False
                        **mnet_kwargs).to(device)

    elif net_type == 'zenke':
        assert len(out_shape) == 1
        # Default keyword arguments of class ZenkeNet.
        dkws = misc.get_default_args(ZenkeNet.__init__)

        mnet = ZenkeNet(in_shape=in_shape, num_classes=out_shape[0],
                        verbose=True, #arch='cifar',
                        no_weights=no_weights,
                        #init_weights=None,
                        dropout_rate=assign(dropout_rate,
                                            dkws['dropout_rate']),
                        **mnet_kwargs).to(device)

    elif net_type == 'bio_conv_net':
        assert len(out_shape) == 1
        # Default keyword arguments of class BioConvNet.
        #dkws = misc.get_default_args(BioConvNet.__init__)

        mnet = BioConvNet(in_shape=in_shape, num_classes=out_shape[0],
                          no_weights=no_weights,
                          #init_weights=None,
                          #use_context_mod=False,
                          #context_mod_inputs=False,
                          #no_last_layer_context_mod=False,
                          #context_mod_no_weights=False,
                          #context_mod_post_activation=False,
                          #context_mod_gain_offset=False,
                          #context_mod_apply_pixel_wise=False
                          **mnet_kwargs).to(device)

    elif net_type == 'chunked_mlp':
        assert hc('cmlp_arch') and hc('cmlp_chunk_arch') and \
            hc('cmlp_in_cdim') and hc('cmlp_out_cdim') and \
            hc('cmlp_cemb_dim')
        assert len(in_shape) == 1 and len(out_shape) == 1
        # Default keyword arguments of class ChunkSqueezer.
        dkws = misc.get_default_args(ChunkSqueezer.__init__)

        mnet = ChunkSqueezer(n_in=in_shape[0], n_out=out_shape[0],
                             inp_chunk_dim=gc('cmlp_in_cdim'),
                             out_chunk_dim=gc('cmlp_out_cdim'),
                             cemb_size=gc('cmlp_cemb_dim'),
                             #cemb_init_std=1.,
                             red_layers=misc.str_to_ints(
                                 gc('cmlp_chunk_arch')),
                             net_layers=misc.str_to_ints(gc('cmlp_arch')),
                             activation_fn=assign(net_act,
                                                  dkws['activation_fn']),
                             use_bias=assign(not no_bias, dkws['use_bias']),
                             #dynamic_biases=None,
                             no_weights=no_weights,
                             #init_weights=None,
                             dropout_rate=assign(dropout_rate,
                                                 dkws['dropout_rate']),
                             use_spectral_norm=assign(
                                 specnorm, dkws['use_spectral_norm']),
                             use_batch_norm=assign(use_bn,
                                                   dkws['use_batch_norm']),
                             bn_track_stats=assign(
                                 not bn_no_running_stats,
                                 dkws['bn_track_stats']),
                             distill_bn_stats=assign(
                                 bn_distill_stats,
                                 dkws['distill_bn_stats']),
                             verbose=True,
                             **mnet_kwargs).to(device)

    elif net_type == 'lenet':
        assert hc('lenet_type')
        assert len(out_shape) == 1
        # Default keyword arguments of class LeNet.
        dkws = misc.get_default_args(LeNet.__init__)

        mnet = LeNet(in_shape=in_shape, num_classes=out_shape[0],
            verbose=True,
            arch=gc('lenet_type'),
            no_weights=no_weights,
            #init_weights=None,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            # TODO Context-mod weights.
            **mnet_kwargs).to(device)

    else:
        assert net_type == 'simple_rnn'
        assert hc('srnn_rec_layers') and hc('srnn_pre_fc_layers') and \
            hc('srnn_post_fc_layers') and hc('srnn_no_fc_out') and \
            hc('srnn_rec_type')
        assert len(in_shape) == 1 and len(out_shape) == 1

        if gc('srnn_rec_type') == 'lstm':
            use_lstm = True
        else:
            assert gc('srnn_rec_type') == 'elman'
            use_lstm = False

        # Default keyword arguments of class SimpleRNN.
        dkws = misc.get_default_args(SimpleRNN.__init__)

        rnn_layers = misc.str_to_ints(gc('srnn_rec_layers'))
        fc_layers = misc.str_to_ints(gc('srnn_post_fc_layers'))
        if gc('srnn_no_fc_out'):
            rnn_layers.append(out_shape[0])
        else:
            fc_layers.append(out_shape[0])

        mnet = SimpleRNN(n_in=in_shape[0],
                         rnn_layers=rnn_layers,
                         fc_layers_pre=misc.str_to_ints(
                             gc('srnn_pre_fc_layers')),
                         fc_layers=fc_layers,
                         activation=assign(net_act, dkws['activation']),
                         use_lstm=use_lstm,
                         use_bias=assign(not no_bias, dkws['use_bias']),
                         no_weights=no_weights,
                         verbose=True,
                         **mnet_kwargs).to(device)

    return mnet
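
# ---------------------------------------------------------------------------
# Editor's usage sketch (hypothetical): building a small MLP main network via
# `get_mnet_model` with a config prefix. The `Namespace` fields mimic what
# `utils.cli_args.main_net_args` would add for `prefix='class_'`; they are
# assumptions, not part of the original module.
def _example_get_mnet_model(device):
    from argparse import Namespace

    config = Namespace(class_mlp_arch='100,100', class_net_act='relu',
                       class_no_bias=False, class_dropout_rate=-1,
                       class_specnorm=False, class_batchnorm=False)
    # A main network with 784 inputs and 10 outputs whose weights are
    # maintained internally (i.e., not provided by a hypernetwork).
    mnet = get_mnet_model(config, 'mlp', [784], [10], device,
                          cprefix='class_', no_weights=False)
    return mnet
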

def generate_replay_networks(config, data_handlers, device,
                             create_rp_hnet=True, only_train_replay=False):
    """Create a replay model that consists of either an encoder/decoder or a
    discriminator/generator pair. Additionally, this method manages the
    creation of a hypernetwork for the generator/decoder.

    The following configurations are determined in order to create the replay
    model:

    * Input, output and hidden layer dimensions of the encoder/decoder.
    * Architecture, chunk- and task-embedding details of the decoder's
      hypernetwork.

    .. note::
        This module also handles the initialisation of the weights of either
        the replay model or its hypernetwork. This will change in the near
        future.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network and to
            infer the number of tasks.
        device: Torch device.
        create_rp_hnet: Whether a hypernetwork for the replay model should be
            constructed. If not, the decoder/generator will have trainable
            weights of its own.
        only_train_replay: We normally do not train on the last task, since
            we do not need to replay this last task's data. If the replay
            method should be able to generate data from all tasks, this
            option has to be set to ``True``.

    Returns:
        (tuple): Tuple containing:

        - **enc**: Encoder/discriminator network instance.
        - **dec**: Decoder/generator network instance.
        - **dec_hnet**: Hypernetwork instance for the decoder/generator. This
          return value is ``None`` if no hypernetwork should be constructed.
    """
    if config.replay_method == 'gan':
        n_out = 1
    else:
        n_out = config.latent_dim * 2

    n_in = data_handlers[0].in_shape[0]
    pd = config.padding * 2

    if config.experiment == "splitMNIST":
        n_in = n_in * n_in
    else:  # permutedMNIST
        n_in = (n_in + pd) * (n_in + pd)

    config.input_dim = n_in
    if config.experiment == "splitMNIST":
        if config.single_class_replay:
            config.out_dim = 1
        else:
            config.out_dim = 2
    else:  # permutedMNIST
        config.out_dim = 10

    if config.infer_task_id:
        # Task inference network.
        config.out_dim = 1

    # Build encoder.
    print('For the replay encoder/discriminator: ')
    enc_arch = misc.str_to_ints(config.enc_fc_arch)
    enc = MLP(n_in=n_in, n_out=n_out, hidden_layers=enc_arch,
              activation_fn=misc.str_to_act(config.enc_net_act),
              dropout_rate=config.enc_dropout_rate,
              no_weights=False).to(device)
    print('Constructed MLP with shapes: ', enc.param_shapes)
    init_params = list(enc.parameters())

    # Build decoder.
    print('For the replay decoder/generator: ')
    dec_arch = misc.str_to_ints(config.dec_fc_arch)
    # Add dimensions for the conditional input.
    n_out = config.latent_dim

    if config.conditional_replay:
        n_out += config.conditional_dim

    # The decoder maps from the latent space (plus conditional input) back to
    # the input space.
    dec = MLP(n_in=n_out, n_out=n_in, hidden_layers=dec_arch,
              activation_fn=misc.str_to_act(config.dec_net_act),
              use_bias=True,
              no_weights=config.rp_beta > 0,
              dropout_rate=config.dec_dropout_rate).to(device)
    print('Constructed MLP with shapes: ', dec.param_shapes)

    config.num_weights_enc = \
        MainNetInterface.shapes_to_num_weights(enc.param_shapes)
    config.num_weights_dec = \
        MainNetInterface.shapes_to_num_weights(dec.param_shapes)
    config.num_weights_rp_net = config.num_weights_enc + \
        config.num_weights_dec

    # We do not need a replay model for the last task, unless we explicitly
    # want to train on the last task as well.
    if only_train_replay:
        subtr = 0
    else:
        subtr = 1

    num_embeddings = config.num_tasks - subtr if config.num_tasks > 1 else 1

    if config.single_class_replay:
        # We do not need a replay model for the last task.
        if config.num_tasks > 1:
            num_embeddings = config.out_dim * (config.num_tasks - subtr)
        else:
            num_embeddings = config.out_dim * config.num_tasks

    config.num_embeddings = num_embeddings

    # Build decoder hypernetwork.
    if create_rp_hnet:
        print('For the decoder/generator hypernetwork: ')
        d_hnet = sim_utils.get_hnet_model(config, num_embeddings, device,
                                          dec.hyper_shapes_learned,
                                          cprefix='rp_')
        init_params += list(d_hnet.parameters())

        config.num_weights_rp_hyper_net = sum(p.numel() for p in
                                              d_hnet.parameters()
                                              if p.requires_grad)
        config.compression_ratio_rp = config.num_weights_rp_hyper_net / \
            config.num_weights_dec
        print('Created replay hypernetwork with ratio: ',
              config.compression_ratio_rp)
        if config.compression_ratio_rp > 1:
            print('Note that the compression ratio is computed compared to ' +
                  'the current target network,\nthis might not be directly ' +
                  'comparable with the number of parameters of methods we ' +
                  'compare against.')
    else:
        num_embeddings = config.num_tasks - subtr
        d_hnet = None
        init_params += list(dec.parameters())
        config.num_weights_rp_hyper_net = 0
        config.compression_ratio_rp = 0

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if create_rp_hnet:
        for temb in d_hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

        if hasattr(d_hnet, 'chunk_embeddings'):
            for emb in d_hnet.chunk_embeddings:
                torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    return enc, dec, d_hnet
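
# ---------------------------------------------------------------------------
# Editor's usage sketch (hypothetical): constructing a VAE-style replay model
# without a decoder hypernetwork. The `Namespace` fields are assumptions that
# mirror the config attributes read above; real values come from the CLI.
def _example_generate_replay_networks(data_handlers, device):
    from argparse import Namespace

    config = Namespace(replay_method='vae', latent_dim=100, padding=0,
                       experiment='splitMNIST', single_class_replay=False,
                       infer_task_id=False, num_tasks=5,
                       enc_fc_arch='400,400', enc_net_act='relu',
                       enc_dropout_rate=-1, dec_fc_arch='400,400',
                       dec_net_act='relu', dec_dropout_rate=-1,
                       conditional_replay=True, conditional_dim=5,
                       rp_beta=0.)
    # With `create_rp_hnet=False` the decoder keeps its own weights and no
    # hypernetwork is returned.
    enc, dec, d_hnet = generate_replay_networks(
        config, data_handlers, device, create_rp_hnet=False)
    assert d_hnet is None
    return enc, dec
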

def __init__(self, rnn_args={}, mlp_args=None, preprocess_fct=None,
             no_weights=False, verbose=True):
    # FIXME Find a way using super to handle multiple inheritance.
    nn.Module.__init__(self)
    MainNetInterface.__init__(self)

    assert isinstance(rnn_args, (dict, list, tuple))
    assert mlp_args is None or isinstance(mlp_args, dict)

    if isinstance(rnn_args, dict):
        rnn_args = [rnn_args]

    self._forward_rnns = []
    self._backward_rnns = []
    self._out_mlp = None

    self._preprocess_fct = preprocess_fct

    self._forward_called = False

    # FIXME At the moment we do not control input and output size of
    # individual networks and need to assume that the user sets them
    # correctly.

    ### Create all forward and backward nets for each bidirectional layer.
    for rargs in rnn_args:
        assert isinstance(rargs, dict)
        if 'verbose' not in rargs.keys():
            rargs['verbose'] = False
        if 'no_weights' in rargs.keys() and \
                rargs['no_weights'] != no_weights:
            raise ValueError('Keyword argument "no_weights" of ' +
                             'bidirectional layer is in conflict with ' +
                             'constructor argument "no_weights".')
        elif 'no_weights' not in rargs.keys():
            rargs['no_weights'] = no_weights

        self._forward_rnns.append(SimpleRNN(**rargs))
        self._backward_rnns.append(SimpleRNN(**rargs))

    ### Create output network.
    if mlp_args is not None:
        if 'verbose' not in mlp_args.keys():
            mlp_args['verbose'] = False
        if 'no_weights' in mlp_args.keys() and \
                mlp_args['no_weights'] != no_weights:
            raise ValueError('Keyword argument "no_weights" of ' +
                             'output MLP is in conflict with ' +
                             'constructor argument "no_weights".')
        elif 'no_weights' not in mlp_args.keys():
            mlp_args['no_weights'] = no_weights

        self._out_mlp = MLP(**mlp_args)

    ### Set all interface attributes correctly.
    if self._out_mlp is None:
        self._has_fc_out = self._forward_rnns[-1].has_fc_out
        # We can't set the following attribute to True, as the output is a
        # concatenation of the outputs from two networks. Therefore, the
        # weights used to compute the outputs are at different locations
        # in the `param_shapes` list.
        self._mask_fc_out = False
        self._has_linear_out = self._forward_rnns[-1].has_linear_out
    else:
        self._has_fc_out = self._out_mlp.has_fc_out
        self._mask_fc_out = self._out_mlp.mask_fc_out
        self._has_linear_out = self._out_mlp.has_linear_out

    # Collect all internal net objects from which we need to collect
    # attributes.
    nets = []
    for i, fnet in enumerate(self._forward_rnns):
        bnet = self._backward_rnns[i]
        nets.append((fnet, 'forward_rnn', i))
        nets.append((bnet, 'backward_rnn', i))
    if self._out_mlp is not None:
        nets.append((self._out_mlp, 'out_mlp', -1))

    # Iterate over all nets to collect their attribute values.
    self._param_shapes = []
    self._param_shapes_meta = []
    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    for i, net_tup in enumerate(nets):
        net, net_type, net_id = net_tup
        # Note, it is important to convert lists into new objects and not
        # just copy references!
        # Note, we have to adapt all references if `i > 0`.
        # Sanity check:
        if i == 0:
            cm_nw = net._context_mod_no_weights
        elif cm_nw != net._context_mod_no_weights:
            raise ValueError('Network expects that either all internal ' +
                             'networks maintain their context-mod weights ' +
                             'or none of them does!')

        ps_len_old = len(self._param_shapes)

        if net._internal_params is not None:
            if self._internal_params is None:
                self._internal_params = nn.ParameterList()
            ip_len_old = len(self._internal_params)
            self._internal_params.extend(
                nn.ParameterList(net._internal_params))

        self._param_shapes.extend(list(net._param_shapes))
        for meta in net.param_shapes_meta:
            assert 'birnn_layer_type' not in meta.keys()
            assert 'birnn_layer_id' not in meta.keys()

            new_meta = dict(meta)
            new_meta['birnn_layer_type'] = net_type
            new_meta['birnn_layer_id'] = net_id
            if i > 0:
                # FIXME We should properly adjust colliding `layer` IDs.
                new_meta['layer'] = -1
                new_meta['index'] = meta['index'] + ip_len_old
            self._param_shapes_meta.append(new_meta)

        if net._hyper_shapes_learned is not None:
            if self._hyper_shapes_learned is None:
                self._hyper_shapes_learned = []
                self._hyper_shapes_learned_ref = []
            self._hyper_shapes_learned.extend(
                list(net._hyper_shapes_learned))
            for ref in net._hyper_shapes_learned_ref:
                self._hyper_shapes_learned_ref.append(ref + ps_len_old)

        if net._hyper_shapes_distilled is not None:
            if self._hyper_shapes_distilled is None:
                self._hyper_shapes_distilled = []
            self._hyper_shapes_distilled.extend(
                list(net._hyper_shapes_distilled))

        if self._has_bias is None:
            self._has_bias = net._has_bias
        elif self._has_bias != net._has_bias:
            self._has_bias = False
            # FIXME We should overwrite the getter and throw an error!
            warn('Some internally maintained networks use biases, while ' +
                 'others don\'t. Setting attribute "has_bias" to False.')

        self._layer_weight_tensors.extend(
            nn.ParameterList(net._layer_weight_tensors))
        self._layer_bias_vectors.extend(
            nn.ParameterList(net._layer_bias_vectors))

        if net._batchnorm_layers is not None:
            if self._batchnorm_layers is None:
                self._batchnorm_layers = nn.ModuleList()
            self._batchnorm_layers.extend(
                nn.ModuleList(net._batchnorm_layers))

        if net._context_mod_layers is not None:
            if self._context_mod_layers is None:
                self._context_mod_layers = nn.ModuleList()
            self._context_mod_layers.extend(
                nn.ModuleList(net._context_mod_layers))

    self._is_properly_setup()

    ### Print user information.
    if verbose:
        print('Constructed Bidirectional RNN with %d weights.'
              % self.num_params)
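
# Editor's note (hypothetical sketch): `rnn_args` is either a single dict or a
# list of dicts, one per bidirectional layer, each holding keyword arguments
# for `SimpleRNN`; `mlp_args` optionally holds keyword arguments for the `MLP`
# readout that consumes the concatenated forward/backward features. Assuming
# the surrounding class is called `BiRNN` (its class statement is not shown in
# this excerpt), a configuration could look like:
#
#     rnn_args = {'n_in': 32, 'rnn_layers': [64], 'fc_layers': [64],
#                 'use_lstm': True}
#     # The readout sees the forward and backward features concatenated.
#     mlp_args = {'n_in': 2 * 64, 'n_out': 10, 'hidden_layers': [100]}
#     birnn = BiRNN(rnn_args=rnn_args, mlp_args=mlp_args)
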

def __init__(self, n_in, n_out=1, inp_chunk_dim=100, out_chunk_dim=10,
             cemb_size=8, cemb_init_std=1., red_layers=(10, 10),
             net_layers=(10, 10), activation_fn=torch.nn.ReLU(),
             use_bias=True, dynamic_biases=None, no_weights=False,
             init_weights=None, dropout_rate=-1, use_spectral_norm=False,
             use_batch_norm=False, bn_track_stats=True,
             distill_bn_stats=False, verbose=True):
    # FIXME Find a way using super to handle multiple inheritance.
    nn.Module.__init__(self)
    MainNetInterface.__init__(self)

    self._n_in = n_in
    self._n_out = n_out
    self._inp_chunk_dim = inp_chunk_dim
    self._out_chunk_dim = out_chunk_dim
    self._cemb_size = cemb_size
    self._a_fun = activation_fn
    self._no_weights = no_weights

    self._has_bias = use_bias
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer.
    self._mask_fc_out = True
    self._has_linear_out = True  # Ensure that `out_fn` is `None`!

    self._param_shapes = []
    #self._param_shapes_meta = [] # TODO implement!
    self._weights = None if no_weights else nn.ParameterList()
    self._hyper_shapes_learned = None if not no_weights else []
    #self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
    #    is None else [] # TODO implement.
    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()
    self._context_mod_layers = None
    self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

    #################################
    ### Generate Chunk Embeddings ###
    #################################

    self._num_cembs = int(np.ceil(n_in / inp_chunk_dim))
    last_chunk_size = n_in % inp_chunk_dim
    if last_chunk_size != 0:
        self._pad = inp_chunk_dim - last_chunk_size
    else:
        self._pad = -1

    cemb_shape = [self._num_cembs, cemb_size]
    self._param_shapes.append(cemb_shape)
    if no_weights:
        self._cembs = None
        self._hyper_shapes_learned.append(cemb_shape)
    else:
        self._cembs = nn.Parameter(data=torch.Tensor(*cemb_shape),
                                   requires_grad=True)
        nn.init.normal_(self._cembs, mean=0., std=cemb_init_std)

        self._weights.append(self._cembs)

    ############################
    ### Setup Dynamic Biases ###
    ############################
    self._has_dyn_bias = None
    if dynamic_biases is not None:
        assert np.all(np.array(dynamic_biases) >= 0) and \
            np.all(np.array(dynamic_biases) < len(red_layers) + 1)
        dynamic_biases = np.sort(np.unique(dynamic_biases))

        # For each layer in the `reducer` where we want to apply a dynamic
        # bias, we have to create a weight matrix for a corresponding
        # linear layer (we simply ignore the bias of that linear layer).
        self._dyn_bias_weights = nn.ModuleList()
        self._has_dyn_bias = []
        for i in range(len(red_layers) + 1):
            if i in dynamic_biases:
                self._has_dyn_bias.append(True)

                trgt_dim = out_chunk_dim
                if i < len(red_layers):
                    trgt_dim = red_layers[i]
                trgt_shape = [trgt_dim, cemb_size]

                self._param_shapes.append(trgt_shape)
                if no_weights:
                    self._dyn_bias_weights.append(None)
                    self._hyper_shapes_learned.append(trgt_shape)
                else:
                    self._dyn_bias_weights.append(nn.Parameter(
                        torch.Tensor(*trgt_shape), requires_grad=True))
                    self._weights.append(self._dyn_bias_weights[-1])
                    init_params(self._dyn_bias_weights[-1])

                    self._layer_weight_tensors.append(
                        self._dyn_bias_weights[-1])
                    self._layer_bias_vectors.append(None)
            else:
                self._has_dyn_bias.append(False)
                self._dyn_bias_weights.append(None)

    ################################
    ### Create `Reducer` Network ###
    ################################
    red_inp_dim = inp_chunk_dim + \
        (cemb_size if dynamic_biases is None else 0)
    self._reducer = MLP(
        n_in=red_inp_dim,
        n_out=out_chunk_dim,
        hidden_layers=red_layers,
        activation_fn=activation_fn,
        use_bias=use_bias,
        no_weights=no_weights,
        init_weights=None,
        dropout_rate=dropout_rate,
        use_spectral_norm=use_spectral_norm,
        use_batch_norm=use_batch_norm,
        bn_track_stats=bn_track_stats,
        distill_bn_stats=distill_bn_stats,
        # We use context modulation to realize dynamic biases, since they
        # allow a different modulation per sample in the input mini-batch.
        # Hence, we can process several chunks in parallel with the reducer
        # network.
        use_context_mod=dynamic_biases is not None,
        context_mod_inputs=False,
        no_last_layer_context_mod=False,
        context_mod_no_weights=True,
        context_mod_post_activation=False,
        context_mod_gain_offset=False,
        context_mod_gain_softplus=False,
        out_fn=None,
        verbose=True)

    if dynamic_biases is not None:
        # FIXME We have to extract the param shapes from
        # `self._reducer.param_shapes`, as well as from
        # `self._reducer._hyper_shapes_learned`, that belong to context-mod
        # layers. We may not add them to our own `param_shapes` attribute,
        # as these are not parameters (due to our misuse of the context-mod
        # layers).
        # Note, in the `forward` method we need to supply context-mod
        # weights for all reducer networks, independent of whether they
        # have a dynamic bias or not. We can do so by providing constant
        # ones for all gains and constant zero-shifts for all layers
        # without dynamic biases (note, we need to ensure the correct batch
        # dimension!).
        raise NotImplementedError('Dynamic biases are not yet implemented!')

    assert self._reducer._context_mod_layers is None

    ### Overtake all attributes from the underlying MLP.
    for s in self._reducer.param_shapes:
        self._param_shapes.append(s)
    if no_weights:
        for s in self._reducer._hyper_shapes_learned:
            self._hyper_shapes_learned.append(s)
    else:
        for p in self._reducer._weights:
            self._weights.append(p)

    for p in self._reducer._layer_weight_tensors:
        self._layer_weight_tensors.append(p)
    for p in self._reducer._layer_bias_vectors:
        self._layer_bias_vectors.append(p)
    if use_batch_norm:
        for p in self._reducer._batchnorm_layers:
            self._batchnorm_layers.append(p)
    if self._reducer._hyper_shapes_distilled is not None:
        self._hyper_shapes_distilled = []
        for s in self._reducer._hyper_shapes_distilled:
            self._hyper_shapes_distilled.append(s)

    ###############################
    ### Create Actual `Network` ###
    ###############################
    net_inp_dim = out_chunk_dim * self._num_cembs
    self._network = MLP(
        n_in=net_inp_dim,
        n_out=n_out,
        hidden_layers=net_layers,
        activation_fn=activation_fn,
        use_bias=use_bias,
        no_weights=no_weights,
        init_weights=None,
        dropout_rate=dropout_rate,
        use_spectral_norm=use_spectral_norm,
        use_batch_norm=use_batch_norm,
        bn_track_stats=bn_track_stats,
        distill_bn_stats=distill_bn_stats,
        use_context_mod=False,
        out_fn=None,
        verbose=True)

    ### Overtake all attributes from the underlying MLP.
    for s in self._network.param_shapes:
        self._param_shapes.append(s)
    if no_weights:
        for s in self._network._hyper_shapes_learned:
            self._hyper_shapes_learned.append(s)
    else:
        for p in self._network._weights:
            self._weights.append(p)

    for p in self._network._layer_weight_tensors:
        self._layer_weight_tensors.append(p)
    for p in self._network._layer_bias_vectors:
        self._layer_bias_vectors.append(p)
    if use_batch_norm:
        for p in self._network._batchnorm_layers:
            self._batchnorm_layers.append(p)
    if self._hyper_shapes_distilled is not None:
        assert self._network._hyper_shapes_distilled is not None
        for s in self._network._hyper_shapes_distilled:
            self._hyper_shapes_distilled.append(s)

    #####################################
    ### Takeover given Initialization ###
    #####################################
    if init_weights is not None:
        assert len(init_weights) == len(self._weights)
        for i in range(len(init_weights)):
            assert np.all(np.equal(list(init_weights[i].shape),
                                   self._param_shapes[i]))
            self._weights[i].data = init_weights[i]

    ######################
    ### Finalize Setup ###
    ######################
    num_weights = MainNetInterface.shapes_to_num_weights(self.param_shapes)
    print('Constructed MLP that processes dimensionality reduced inputs ' +
          'through chunking. The network has a total of %d weights.'
          % num_weights)

    self._is_properly_setup()
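
# Editor's sketch (hypothetical helper): the chunking arithmetic performed in
# the constructor above, in numbers. With `n_in=784` and `inp_chunk_dim=100`,
# the input is split into ceil(784 / 100) = 8 chunks, the last chunk requiring
# 100 - (784 % 100) = 16 padding entries, and the reducer input size equals
# `inp_chunk_dim + cemb_size` (chunk plus its embedding) when no dynamic
# biases are used.
def _example_chunking_arithmetic():
    import numpy as np

    n_in, inp_chunk_dim, cemb_size = 784, 100, 8
    num_cembs = int(np.ceil(n_in / inp_chunk_dim))        # -> 8 chunks
    last_chunk_size = n_in % inp_chunk_dim                # -> 84
    pad = inp_chunk_dim - last_chunk_size if last_chunk_size != 0 else -1
    red_inp_dim = inp_chunk_dim + cemb_size               # -> 108
    assert (num_cembs, pad, red_inp_dim) == (8, 16, 108)
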