def generate_classifier(config, data_handlers, device):
    """Create a classifier network.

    Depending on the experiment and method, this builds either a classifier
    for task inference or a classifier that solves the actual task. This also
    determines whether the network receives its weights from a hypernetwork
    (``config.training_with_hnet``) or maintains weights on its own.

    The following configurations are determined in order to create the
    classifier:

    * input, output and hidden layer dimensions of the classifier.
    * architecture, chunk- and task-embedding details of the hypernetwork.

    See :class:`mnets.mlp.MLP` for details on the network that will be
    created to be a classifier.

    .. note::
        This function also handles the initialisation of the weights of
        either the classifier or its hypernetwork. This will change in the
        near future.

    Note:
        As a side effect, the following attributes are written to ``config``:
        ``input_dim``, ``out_dim``, ``num_weights_class_net``,
        ``num_weights_class_hyper_net`` and ``compression_ratio_class``.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Only the
            ``in_shape`` of the first handler is used, to determine the input
            dimensionality of the classifier.
        device: Torch device.

    Returns:
        If ``config.training_with_hnet`` is set, a tuple containing:

        - **net**: The classifier network (constructed without weights).
        - **class_hnet**: The classifier's hypernetwork.

        Otherwise, only the classifier network **net** is returned.
    """
    use_hnet = bool(config.training_with_hnet)
    is_split = config.experiment == "splitMNIST"

    ### Input dimensionality (flattened square images).
    in_side = data_handlers[0].in_shape[0]
    pd = config.padding * 2
    if is_split:
        n_in = in_side * in_side
    else:  # permutedMNIST: images are padded on both sides.
        n_in = (in_side + pd) * (in_side + pd)
    config.input_dim = n_in

    ### Output dimensionality per task.
    if is_split:
        config.out_dim = 1 if config.class_incremental else 2
    else:  # permutedMNIST
        config.out_dim = 10
    task_infer = config.training_task_infer or config.class_incremental
    if task_infer:
        # Task inference network.
        config.out_dim = 1

    # For CL scenarios other than 2, the output units of all tasks are
    # allocated up-front (`out_dim` units per task).
    if config.cl_scenario != 2:
        n_out = config.out_dim * config.num_tasks
    else:
        n_out = config.out_dim
    if task_infer:
        # One output unit per task.
        n_out = config.num_tasks

    ### Build classifier.
    print('For the Classifier: ')
    class_arch = misc.str_to_ints(config.class_fc_arch)
    # With a hypernetwork, the classifier's weights are produced externally.
    net = MLP(n_in=n_in, n_out=n_out, hidden_layers=class_arch,
              activation_fn=misc.str_to_act(config.class_net_act),
              dropout_rate=config.class_dropout_rate,
              no_weights=use_hnet).to(device)
    print('Constructed MLP with shapes: ', net.param_shapes)

    config.num_weights_class_net = \
        MainNetInterface.shapes_to_num_weights(net.param_shapes)

    ### Build classifier hypernetwork.
    # `config.training_with_hnet` is set in the run method in train.py.
    class_hnet = None
    if use_hnet:
        class_hnet = sim_utils.get_hnet_model(config, config.num_tasks,
                                              device, net.param_shapes,
                                              cprefix='class_')
        init_params = list(class_hnet.parameters())

        config.num_weights_class_hyper_net = sum(
            p.numel() for p in class_hnet.parameters() if p.requires_grad)
        config.compression_ratio_class = \
            config.num_weights_class_hyper_net / config.num_weights_class_net
        print('Created classifier Hypernetwork with ratio: ',
              config.compression_ratio_class)
        if config.compression_ratio_class > 1:
            print('Note that the compression ratio is computed compared to ' +
                  'current target network, this might not be directly ' +
                  'comparable with the number of parameters of work we ' +
                  'compare against.')
    else:
        init_params = list(net.parameters())
        config.num_weights_class_hyper_net = None
        config.compression_ratio_class = None

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:
            torch.nn.init.xavier_uniform_(W)

    if not use_hnet:
        return net

    # The task embeddings are initialized differently (this overwrites the
    # initialization applied in the loop above).
    for temb in class_hnet.get_task_embs():
        torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

    if hasattr(class_hnet, 'chunk_embeddings'):
        for emb in class_hnet.chunk_embeddings:
            torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    return net, class_hnet
# NOTE(review): a second `get_mnet_model` definition appears later in this
# file; if both live in the same module, the later one shadows this one at
# import time -- confirm which is intended.
def get_mnet_model(config, net_type, in_shape, out_shape, device, cprefix=None,
                   no_weights=False, **mnet_kwargs):
    """Generate a main network instance.

    A helper to generate a main network according to the given user
    configurations.

    Optional config options that are missing from ``config`` fall back to the
    default value of the corresponding constructor argument.

    .. note::
        Generation of networks with context-modulation is not yet supported,
        since there is no global argument set in :mod:`utils.cli_args` yet.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available
                according to the function :func:`utils.cli_args.main_net_args`.
        net_type (str): The type of network. The following options are
            available:

            - ``mlp``: :class:`mnets.mlp.MLP`
            - ``lenet``: ``LeNet``
            - ``resnet``: :class:`mnets.resnet.ResNet`
            - ``wrn``: :class:`mnets.wide_resnet.WRN`
            - ``iresnet``: :class:`mnets.resnet_imgnet.ResNetIN`
            - ``zenke``: :class:`mnets.zenkenet.ZenkeNet`
            - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet`
            - ``chunked_mlp``: :class:`mnets.chunk_squeezer.ChunkSqueezer`
            - ``simple_rnn``: :class:`mnets.simple_rnn.SimpleRNN`
        in_shape (list): Shape of network inputs. Can be ``None`` if not
            required by network type.

            For instance: For an MLP network :class:`mnets.mlp.MLP` with 100
            input neurons it should be :code:`in_shape=[100]`.
        out_shape (list): Shape of network outputs. See ``in_shape`` for more
            details.
        device: PyTorch device.
        cprefix (str, optional): A prefix of the config names. It might be,
            that the config names used in this method are prefixed, since
            several main networks should be generated (e.g.,
            :code:`cprefix='gen_'` or ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.main_net_args`.
        no_weights (bool): Whether the main network should be generated
            without weights.
        **mnet_kwargs: Additional keyword arguments that will be passed to the
            main network constructor.

    Returns:
        The created main network model.
    """
    assert net_type in ['mlp', 'lenet', 'resnet', 'zenke', 'bio_conv_net',
                        'chunked_mlp', 'simple_rnn', 'wrn', 'iresnet']

    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    def hc(name):
        """Check whether config exists."""
        return hasattr(config, '%s%s' % (cprefix, name))

    def get_val(name):
        """Get config value with that name, or ``None`` if it doesn't exist."""
        return gc(name) if hc(name) else None

    def assign(x, y):
        """Return ``x``, or the constructor default ``y`` if ``x`` is None."""
        return y if x is None else x

    def negate(x):
        """Negate a flag, but keep ``None`` (unspecified) as ``None``.

        Note, ``not None`` evaluates to ``True``, which would silently bypass
        the constructor-default fallback of :func:`assign`.
        """
        return None if x is None else not x

    if hc('net_act'):
        net_act = misc.str_to_act(gc('net_act'))
    else:
        net_act = None

    no_bias = get_val('no_bias')
    dropout_rate = get_val('dropout_rate')
    specnorm = get_val('specnorm')
    batchnorm = get_val('batchnorm')
    no_batchnorm = get_val('no_batchnorm')
    bn_no_running_stats = get_val('bn_no_running_stats')
    bn_distill_stats = get_val('bn_distill_stats')
    # Note, `bn_no_stats_checkpointing` has to be handled during usage of the
    # network and not during construction.

    use_bn = None
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm

    use_bias = negate(no_bias)
    bn_track_stats = negate(bn_no_running_stats)

    # Constructor arguments not listed below (e.g., context-modulation
    # options) keep their defaults unless provided via `mnet_kwargs`.
    if net_type == 'mlp':
        assert hc('mlp_arch')
        assert len(in_shape) == 1 and len(out_shape) == 1

        # Default keyword arguments of class MLP.
        dkws = misc.get_default_args(MLP.__init__)

        mnet = MLP(
            n_in=in_shape[0], n_out=out_shape[0],
            hidden_layers=misc.str_to_ints(gc('mlp_arch')),
            activation_fn=assign(net_act, dkws['activation_fn']),
            use_bias=assign(use_bias, dkws['use_bias']),
            no_weights=no_weights,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']),
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(bn_track_stats, dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            verbose=True,
            **mnet_kwargs).to(device)

    elif net_type == 'resnet':
        assert len(out_shape) == 1
        assert hc('resnet_block_depth') and hc('resnet_channel_sizes')

        # Default keyword arguments of class ResNet.
        dkws = misc.get_default_args(ResNet.__init__)

        mnet = ResNet(
            in_shape=in_shape, num_classes=out_shape[0],
            n=gc('resnet_block_depth'),
            use_bias=assign(use_bias, dkws['use_bias']),
            num_feature_maps=misc.str_to_ints(gc('resnet_channel_sizes')),
            verbose=True,
            no_weights=no_weights,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(bn_track_stats, dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            **mnet_kwargs).to(device)

    elif net_type == 'wrn':
        assert len(out_shape) == 1
        assert hc('wrn_block_depth') and hc('wrn_widening_factor')

        # Default keyword arguments of class WRN.
        dkws = misc.get_default_args(WRN.__init__)

        mnet = WRN(
            in_shape=in_shape, num_classes=out_shape[0],
            n=gc('wrn_block_depth'),
            use_bias=assign(use_bias, dkws['use_bias']),
            verbose=True,
            no_weights=no_weights,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(bn_track_stats, dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            k=gc('wrn_widening_factor'),
            use_fc_bias=gc('wrn_use_fc_bias'),
            dropout_rate=gc('dropout_rate'),
            **mnet_kwargs).to(device)

    elif net_type == 'iresnet':
        assert len(out_shape) == 1
        assert hc('iresnet_use_fc_bias') and hc('iresnet_channel_sizes') \
            and hc('iresnet_blocks_per_group') \
            and hc('iresnet_bottleneck_blocks') \
            and hc('iresnet_projection_shortcut')

        # Default keyword arguments of class ResNetIN.
        dkws = misc.get_default_args(ResNetIN.__init__)

        mnet = ResNetIN(
            in_shape=in_shape, num_classes=out_shape[0],
            use_bias=assign(use_bias, dkws['use_bias']),
            # Read the option asserted above (not the WRN one).
            use_fc_bias=gc('iresnet_use_fc_bias'),
            num_feature_maps=misc.str_to_ints(gc('iresnet_channel_sizes')),
            blocks_per_group=misc.str_to_ints(gc('iresnet_blocks_per_group')),
            projection_shortcut=gc('iresnet_projection_shortcut'),
            bottleneck_blocks=gc('iresnet_bottleneck_blocks'),
            no_weights=no_weights,
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(bn_track_stats, dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            verbose=True,
            **mnet_kwargs).to(device)

    elif net_type == 'zenke':
        assert len(out_shape) == 1

        # Default keyword arguments of class ZenkeNet.
        dkws = misc.get_default_args(ZenkeNet.__init__)

        mnet = ZenkeNet(
            in_shape=in_shape, num_classes=out_shape[0],
            verbose=True,
            no_weights=no_weights,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            **mnet_kwargs).to(device)

    elif net_type == 'bio_conv_net':
        assert len(out_shape) == 1

        mnet = BioConvNet(
            in_shape=in_shape, num_classes=out_shape[0],
            no_weights=no_weights,
            **mnet_kwargs).to(device)

    elif net_type == 'chunked_mlp':
        assert hc('cmlp_arch') and hc('cmlp_chunk_arch') and \
            hc('cmlp_in_cdim') and hc('cmlp_out_cdim') and \
            hc('cmlp_cemb_dim')
        assert len(in_shape) == 1 and len(out_shape) == 1

        # Default keyword arguments of class ChunkSqueezer.
        dkws = misc.get_default_args(ChunkSqueezer.__init__)

        mnet = ChunkSqueezer(
            n_in=in_shape[0], n_out=out_shape[0],
            inp_chunk_dim=gc('cmlp_in_cdim'),
            out_chunk_dim=gc('cmlp_out_cdim'),
            cemb_size=gc('cmlp_cemb_dim'),
            red_layers=misc.str_to_ints(gc('cmlp_chunk_arch')),
            net_layers=misc.str_to_ints(gc('cmlp_arch')),
            activation_fn=assign(net_act, dkws['activation_fn']),
            use_bias=assign(use_bias, dkws['use_bias']),
            no_weights=no_weights,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            use_spectral_norm=assign(specnorm, dkws['use_spectral_norm']),
            use_batch_norm=assign(use_bn, dkws['use_batch_norm']),
            bn_track_stats=assign(bn_track_stats, dkws['bn_track_stats']),
            distill_bn_stats=assign(bn_distill_stats,
                                    dkws['distill_bn_stats']),
            verbose=True,
            **mnet_kwargs).to(device)

    elif net_type == 'lenet':
        assert hc('lenet_type')
        assert len(out_shape) == 1

        # Default keyword arguments of class LeNet.
        dkws = misc.get_default_args(LeNet.__init__)

        mnet = LeNet(
            in_shape=in_shape, num_classes=out_shape[0],
            verbose=True,
            arch=gc('lenet_type'),
            no_weights=no_weights,
            dropout_rate=assign(dropout_rate, dkws['dropout_rate']),
            # TODO Context-mod weights.
            **mnet_kwargs).to(device)

    else:
        assert net_type == 'simple_rnn'
        assert hc('srnn_rec_layers') and hc('srnn_pre_fc_layers') and \
            hc('srnn_post_fc_layers') and hc('srnn_no_fc_out') and \
            hc('srnn_rec_type')
        assert len(in_shape) == 1 and len(out_shape) == 1

        rec_type = gc('srnn_rec_type')
        assert rec_type in ('lstm', 'elman')
        use_lstm = rec_type == 'lstm'

        # Default keyword arguments of class SimpleRNN.
        dkws = misc.get_default_args(SimpleRNN.__init__)

        rnn_layers = misc.str_to_ints(gc('srnn_rec_layers'))
        fc_layers = misc.str_to_ints(gc('srnn_post_fc_layers'))
        # The output layer is either the last recurrent layer or an
        # additional fully-connected layer.
        if gc('srnn_no_fc_out'):
            rnn_layers.append(out_shape[0])
        else:
            fc_layers.append(out_shape[0])

        mnet = SimpleRNN(
            n_in=in_shape[0], rnn_layers=rnn_layers,
            fc_layers_pre=misc.str_to_ints(gc('srnn_pre_fc_layers')),
            fc_layers=fc_layers,
            activation=assign(net_act, dkws['activation']),
            use_lstm=use_lstm,
            use_bias=assign(use_bias, dkws['use_bias']),
            no_weights=no_weights,
            verbose=True,
            **mnet_kwargs).to(device)

    return mnet
# NOTE(review): an earlier `get_mnet_model` definition (supporting more
# network types) appears above in this file; if both live in the same module,
# this one shadows it at import time -- confirm which is intended.
def get_mnet_model(config, net_type, in_shape, out_shape, device, cprefix=None,
                   no_weights=False, **mnet_kwargs):
    """Generate a main network instance.

    A helper to generate a main network according to the given user
    configurations.

    .. note::
        Generation of networks with context-modulation is not yet supported,
        since there is no global argument set in :mod:`utils.cli_args` yet.

    Args:
        config (argparse.Namespace): Command-line arguments.

            .. note::
                The function expects command-line arguments available
                according to the function :func:`utils.cli_args.main_net_args`.
        net_type (str): The type of network. The following options are
            available:

            - ``mlp``: :class:`mnets.mlp.MLP`
            - ``resnet``: :class:`mnets.resnet.ResNet`
            - ``zenke``: :class:`mnets.zenkenet.ZenkeNet`
            - ``bio_conv_net``: :class:`mnets.bio_conv_net.BioConvNet`
              (not available here, see "Raises").
        in_shape (list): Shape of network inputs. Can be ``None`` if not
            required by network type.

            For instance: For an MLP network :class:`mnets.mlp.MLP` with 100
            input neurons it should be :code:`in_shape=[100]`.
        out_shape (list): Shape of network outputs. See ``in_shape`` for more
            details.
        device: PyTorch device.
        cprefix (str, optional): A prefix of the config names. It might be,
            that the config names used in this method are prefixed, since
            several main networks should be generated (e.g.,
            :code:`cprefix='gen_'` or ``'dis_'`` when training a GAN).

            Also see docstring of parameter ``prefix`` in function
            :func:`utils.cli_args.main_net_args`.
        no_weights (bool): Whether the main network should be generated
            without weights.
        **mnet_kwargs: Additional keyword arguments that will be passed to the
            main network constructor (accepted but unused for
            ``bio_conv_net``).

    Returns:
        The created main network model.

    Raises:
        NotImplementedError: If ``net_type`` is ``'bio_conv_net'``.
    """
    assert net_type in ['mlp', 'resnet', 'zenke', 'bio_conv_net']

    if cprefix is None:
        cprefix = ''

    def gc(name):
        """Get config value with that name."""
        return getattr(config, '%s%s' % (cprefix, name))

    def hc(name):
        """Check whether config exists."""
        return hasattr(config, '%s%s' % (cprefix, name))

    def get_val(name):
        """Get config value with that name, or ``None`` if it doesn't exist."""
        return gc(name) if hc(name) else None

    # If an argument wasn't specified (`None`), then we use the default value
    # that is currently (at time of implementation) in the constructor.
    # FIXME The defaults below are hard-coded and may diverge from the
    # constructors' actual defaults.
    def assign(x, y):
        """Return ``x``, or the fallback ``y`` if ``x`` is None."""
        return y if x is None else x

    def negate(x):
        """Negate a flag, but keep ``None`` (unspecified) as ``None``.

        Note, ``not None`` evaluates to ``True``, which would silently bypass
        the fallback of :func:`assign`.
        """
        return None if x is None else not x

    if hc('net_act'):
        net_act = misc.str_to_act(gc('net_act'))
    else:
        net_act = None

    no_bias = get_val('no_bias')
    dropout_rate = get_val('dropout_rate')
    specnorm = get_val('specnorm')
    batchnorm = get_val('batchnorm')
    no_batchnorm = get_val('no_batchnorm')
    bn_no_running_stats = get_val('bn_no_running_stats')
    bn_distill_stats = get_val('bn_distill_stats')
    # Note, `bn_no_stats_checkpointing` is not a construction argument.

    use_bn = None
    if batchnorm is not None:
        use_bn = batchnorm
    elif no_batchnorm is not None:
        use_bn = not no_batchnorm

    bn_track_stats = negate(bn_no_running_stats)

    if net_type == 'mlp':
        assert hc('mlp_arch')
        assert len(in_shape) == 1 and len(out_shape) == 1

        mnet = MLP(
            n_in=in_shape[0], n_out=out_shape[0],
            hidden_layers=misc.str_to_ints(gc('mlp_arch')),
            activation_fn=assign(net_act, torch.nn.ReLU()),
            use_bias=not assign(no_bias, False),
            no_weights=no_weights,
            dropout_rate=assign(dropout_rate, -1),
            use_spectral_norm=assign(specnorm, False),
            use_batch_norm=assign(use_bn, False),
            bn_track_stats=assign(bn_track_stats, True),
            distill_bn_stats=assign(bn_distill_stats, False),
            verbose=True,
            **mnet_kwargs).to(device)

    elif net_type == 'resnet':
        assert len(out_shape) == 1

        mnet = ResNet(
            in_shape=in_shape, num_classes=out_shape[0],
            verbose=True,
            no_weights=no_weights,
            use_batch_norm=assign(use_bn, True),
            bn_track_stats=assign(bn_track_stats, True),
            distill_bn_stats=assign(bn_distill_stats, False),
            **mnet_kwargs).to(device)

    elif net_type == 'zenke':
        assert len(out_shape) == 1

        mnet = ZenkeNet(
            in_shape=in_shape, num_classes=out_shape[0],
            verbose=True,
            no_weights=no_weights,
            dropout_rate=assign(dropout_rate, 0.25),
            **mnet_kwargs).to(device)

    else:
        assert net_type == 'bio_conv_net'
        assert len(out_shape) == 1

        raise NotImplementedError('Implementation not publicly available!')

    return mnet
def generate_replay_networks(config, data_handlers, device, create_rp_hnet=True,
                             only_train_replay=False):
    """Create a replay model consisting of an encoder/decoder or a
    discriminator/generator pair.

    Additionally, this function manages the creation of a hypernetwork for
    the generator/decoder. The following configurations are determined in
    order to create the replay model:

    * input, output and hidden layer dimensions of the encoder/decoder.
    * architecture, chunk- and task-embedding details of the decoder's
      hypernetwork.

    .. note::
        This function also handles the initialisation of the weights of the
        encoder and of either the decoder or its hypernetwork. This will
        change in the near future.

    Note:
        As a side effect, the following attributes are written to ``config``:
        ``input_dim``, ``out_dim``, ``num_weights_enc``, ``num_weights_dec``,
        ``num_weights_rp_net``, ``num_embeddings``,
        ``num_weights_rp_hyper_net`` and ``compression_ratio_rp``.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Only the
            ``in_shape`` of the first handler is used, to determine the input
            dimensionality.
        device: Torch device.
        create_rp_hnet: Whether a hypernetwork for the replay should be
            constructed. If not, the decoder/generator's own parameters are
            initialized and used.

            NOTE(review): whether the decoder is constructed with weights is
            controlled by ``config.rp_beta > 0``, not by this flag -- confirm
            both are always set consistently.
        only_train_replay: We normally do not train on the last task since we
            do not need to replay this last task's data. But if we want a
            replay method to be able to generate data from all tasks, then we
            set this option to true.

    Returns:
        (tuple): Tuple containing:

        - **enc**: Encoder/discriminator network instance.
        - **dec**: Decoder/generator network instance.
        - **dec_hnet**: Hypernetwork instance for the decoder/generator. This
          return value is ``None`` if no hypernetwork should be constructed.
    """
    # Output size of the encoder/discriminator: a single output unit for a
    # GAN, otherwise `2 * latent_dim` outputs.
    if config.replay_method == 'gan':
        enc_n_out = 1
    else:
        enc_n_out = config.latent_dim * 2

    ### Input dimensionality (flattened square images).
    in_side = data_handlers[0].in_shape[0]
    pd = config.padding * 2
    if config.experiment == "splitMNIST":
        n_in = in_side * in_side
    else:  # permutedMNIST: images are padded on both sides.
        n_in = (in_side + pd) * (in_side + pd)
    config.input_dim = n_in

    ### Output dimensionality per task.
    if config.experiment == "splitMNIST":
        config.out_dim = 1 if config.single_class_replay else 2
    else:  # permutedMNIST
        config.out_dim = 10
    if config.infer_task_id:
        # Task inference network.
        config.out_dim = 1

    ### Build encoder/discriminator.
    print('For the replay encoder/discriminator: ')
    enc_arch = misc.str_to_ints(config.enc_fc_arch)
    enc = MLP(n_in=n_in, n_out=enc_n_out, hidden_layers=enc_arch,
              activation_fn=misc.str_to_act(config.enc_net_act),
              dropout_rate=config.enc_dropout_rate,
              no_weights=False).to(device)
    print('Constructed MLP with shapes: ', enc.param_shapes)
    init_params = list(enc.parameters())

    ### Build decoder/generator.
    print('For the replay decoder/generator: ')
    dec_arch = misc.str_to_ints(config.dec_fc_arch)
    # The decoder's input is the latent code, optionally extended by the
    # conditional input.
    dec_n_in = config.latent_dim
    if config.conditional_replay:
        dec_n_in += config.conditional_dim

    dec = MLP(n_in=dec_n_in, n_out=n_in, hidden_layers=dec_arch,
              activation_fn=misc.str_to_act(config.dec_net_act),
              use_bias=True,
              no_weights=config.rp_beta > 0,
              dropout_rate=config.dec_dropout_rate).to(device)
    print('Constructed MLP with shapes: ', dec.param_shapes)

    config.num_weights_enc = \
        MainNetInterface.shapes_to_num_weights(enc.param_shapes)
    config.num_weights_dec = \
        MainNetInterface.shapes_to_num_weights(dec.param_shapes)
    config.num_weights_rp_net = config.num_weights_enc + config.num_weights_dec

    ### Number of hypernetwork embeddings.
    # We normally do not need a replay model for the last task, unless
    # `only_train_replay` is set.
    subtr = 0 if only_train_replay else 1

    if config.num_tasks > 1:
        num_embeddings = config.num_tasks - subtr
    else:
        num_embeddings = 1

    if config.single_class_replay:
        # One embedding per output unit (`out_dim`) of each replayed task.
        if config.num_tasks > 1:
            num_embeddings = config.out_dim * (config.num_tasks - subtr)
        else:
            num_embeddings = config.out_dim * config.num_tasks

    config.num_embeddings = num_embeddings

    ### Build decoder hypernetwork.
    if create_rp_hnet:
        print('For the decoder/generator hypernetwork: ')
        d_hnet = sim_utils.get_hnet_model(config, num_embeddings, device,
                                          dec.hyper_shapes_learned,
                                          cprefix='rp_')

        init_params += list(d_hnet.parameters())

        config.num_weights_rp_hyper_net = sum(
            p.numel() for p in d_hnet.parameters() if p.requires_grad)
        config.compression_ratio_rp = \
            config.num_weights_rp_hyper_net / config.num_weights_dec
        print('Created replay hypernetwork with ratio: ',
              config.compression_ratio_rp)
        if config.compression_ratio_rp > 1:
            print('Note that the compression ratio is computed compared to ' +
                  'current target network,\nthis might not be directly ' +
                  'comparable with the number of parameters of methods we ' +
                  'compare against.')
    else:
        d_hnet = None
        init_params += list(dec.parameters())
        config.num_weights_rp_hyper_net = 0
        config.compression_ratio_rp = 0

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently (this overwrites the
    # initialization applied in the loop above).
    if d_hnet is not None:
        for temb in d_hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

        if hasattr(d_hnet, 'chunk_embeddings'):
            for emb in d_hnet.chunk_embeddings:
                torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    return enc, dec, d_hnet
class BiRNN(nn.Module, MainNetInterface): r"""Implementation of a bidirectional RNN. Note: The output is non-linear if the last layer is recurrent! Otherwise, logits are returned (cmp. attribute :attr:`mnets.mnet_interface.MainNetInterface.has_fc_out`). Example: Here is an example instantiation of a BiLSTM with a single bidirectional layer of dimensionality 256, assuming 100 dimensional inputs and 10 dimensional outputs. .. code-block:: python net = BiRNN(rnn_args={'n_in': 100, 'rnn_layers': [256], 'use_lstm': True, 'fc_layers_pre': [], 'fc_layers': []}, mlp_args={'n_in': 512, 'n_out': 10, 'hidden_layers': []}, no_weights=False) Attributes: preprocess_fct (func, optional): See constructor argument ``preprocess_fct``. num_rec_layers (int): See attribute :attr:`mnets.simple_rnn.SimpleRNN.num_rec_layers`. Total number of recurrent layer, where each bidirectional layer consists of at least two recurrent layers (forward and backward layer). use_lstm (bool): See attribute :attr:`mnets.simple_rnn.SimpleRNN.use_lstm`. Args: rnn_args (dict or list): A dictionary of arguments for an instance of class :class:`mnets.simple_rnn.SimpleRNN`. These arguments will be used to create two instances of this class, one representing the forward RNN and one the backward RNN. Note, each of these instances may contain multiple layers, even non-recurrent layers. The outputs of such an instance are considered the hidden activations :math:`\hat{h}_{1:T}^{(f)}` or :math:`\hat{h}_{1:T}^{(b)}`, respectively. To realize multiple bidirectional layers (which in itself can be multi-layer RNNs), one may provide a list of dictionaries. Each entry in such list will be used to generate a single bidirectional layer (i.e., consisting of two instances of class :class:`mnets.simple_rnn.SimpleRNN`). Note, the input size of each new layer has to be twice the size of :math:`\hat{h}_t^{(f)}` from the previous layer. mlp_args (dict, optional): A dictionary of arguments for class :class:`mnets.mlp.MLP`. 
The input size of such an MLP should be twice the size of :math:`\hat{h}_t^{(f)}`. If ``None``, then the output of the last bidirectional layer is considered the output of the network. preprocess_fct (func, optional): A function handle can be provided, that will process inputs ``x`` passed to the method :meth:`forward`. An example usecase could be the translation or selection of word embeddings. The function handle must have the signature: ``preprocess_fct(x, seq_lengths=None)``. See the corresponding argument descriptions of method :meth:`forward`.The function is expected to return the preprocessed ``x``. no_weights (bool): See parameter ``no_weights`` of class :class:`mnets.mlp.MLP`. verbose (bool): See parameter ``verbose`` of class :class:`mnets.mlp.MLP`. """ def __init__(self, rnn_args={}, mlp_args=None, preprocess_fct=None, no_weights=False, verbose=True): # FIXME find a way using super to handle multiple inheritance. nn.Module.__init__(self) MainNetInterface.__init__(self) assert isinstance(rnn_args, (dict, list, tuple)) assert mlp_args is None or isinstance(mlp_args, dict) if isinstance(rnn_args, dict): rnn_args = [rnn_args] self._forward_rnns = [] self._backward_rnns = [] self._out_mlp = None self._preprocess_fct = preprocess_fct self._forward_called = False # FIXME At the moment we do not control input and output size of # individual networks and need to assume that the user sets them # correctly. ### Create all forward and backward nets for each bidirectional layer. 
for rargs in rnn_args: assert isinstance(rargs, dict) if 'verbose' not in rargs.keys(): rargs['verbose'] = False if 'no_weights' in rargs.keys() and \ rargs['no_weights'] != no_weights: raise ValueError('Keyword argument "no_weights" of ' + 'bidirectional layer is in conflict with ' + 'constructor argument "no_weights".') elif 'no_weights' not in rargs.keys(): rargs['no_weights'] = no_weights self._forward_rnns.append(SimpleRNN(**rargs)) self._backward_rnns.append(SimpleRNN(**rargs)) ### Create output network. if mlp_args is not None: if 'verbose' not in mlp_args.keys(): mlp_args['verbose'] = False if 'no_weights' in mlp_args.keys() and \ mlp_args['no_weights'] != no_weights: raise ValueError('Keyword argument "no_weights" of ' + 'output MLP is in conflict with ' + 'constructor argument "no_weights".') elif 'no_weights' not in mlp_args.keys(): mlp_args['no_weights'] = no_weights self._out_mlp = MLP(**mlp_args) ### Set all interface attributes correctly. if self._out_mlp is None: self._has_fc_out = self._forward_rnns[-1].has_fc_out # We can't set the following attribute to true, as the output is # a concatenation of the outputs from two networks. Therefore, the # weights used two compute the outputs are at different locations # in the `param_shapes` list. self._mask_fc_out = False self._has_linear_out = self._forward_rnns[-1].has_linear_out else: self._has_fc_out = self._out_mlp.has_fc_out self._mask_fc_out = self._out_mlp.mask_fc_out self._has_linear_out = self._out_mlp.has_linear_out # Collect all internal net objects from which we need to collect # attributes. nets = [] for i, fnet in enumerate(self._forward_rnns): bnet = self._backward_rnns[i] nets.append((fnet, 'forward_rnn', i)) nets.append((bnet, 'backward_rnn', i)) if self._out_mlp is not None: nets.append((self._out_mlp, 'out_mlp', -1)) # Iterate over all nets to collect their attribute values. 
self._param_shapes = []
self._param_shapes_meta = []
self._layer_weight_tensors = nn.ParameterList()
self._layer_bias_vectors = nn.ParameterList()

for i, net_tup in enumerate(nets):
    net, net_type, net_id = net_tup

    # Note, it is important to convert lists into new object and not
    # just copy references!
    # Note, we have to adapt all references if `i > 0`.

    # Sanity check:
    if i == 0:
        cm_nw = net._context_mod_no_weights
    elif cm_nw != net._context_mod_no_weights:
        raise ValueError('Network expect that either all internal ' +
                         'networks maintain their context-mod ' +
                         'weights or non of them does!')

    ps_len_old = len(self._param_shapes)
    # Offset of this net's parameters inside the concatenated
    # `internal_params`. Computed for every net, so it is never stale
    # (left over from a previous net) or undefined.
    ip_len_old = 0 if self._internal_params is None \
        else len(self._internal_params)

    if net._internal_params is not None:
        if self._internal_params is None:
            self._internal_params = nn.ParameterList()
        self._internal_params.extend(
            nn.ParameterList(net._internal_params))
    self._param_shapes.extend(list(net._param_shapes))
    for meta in net.param_shapes_meta:
        assert 'birnn_layer_type' not in meta.keys()
        assert 'birnn_layer_id' not in meta.keys()

        new_meta = dict(meta)
        new_meta['birnn_layer_type'] = net_type
        new_meta['birnn_layer_id'] = net_id
        if i > 0:
            # FIXME We should properly adjust colliding `layer` IDs.
            new_meta['layer'] = -1
        # An index of -1 marks a parameter that is not maintained
        # internally; it must stay -1 instead of being shifted onto some
        # other net's parameter.
        if meta['index'] != -1:
            new_meta['index'] = meta['index'] + ip_len_old
        self._param_shapes_meta.append(new_meta)

    if net._hyper_shapes_learned is not None:
        if self._hyper_shapes_learned is None:
            self._hyper_shapes_learned = []
            self._hyper_shapes_learned_ref = []
        self._hyper_shapes_learned.extend(
            list(net._hyper_shapes_learned))
        # References point into `param_shapes`, hence shift them by the
        # number of shapes collected from previous nets.
        for ref in net._hyper_shapes_learned_ref:
            self._hyper_shapes_learned_ref.append(ref + ps_len_old)
    if net._hyper_shapes_distilled is not None:
        if self._hyper_shapes_distilled is None:
            self._hyper_shapes_distilled = []
        self._hyper_shapes_distilled.extend(
            list(net._hyper_shapes_distilled))

    if self._has_bias is None:
        self._has_bias = net._has_bias
    elif self._has_bias != net._has_bias:
        self._has_bias = False
        # FIXME We should overwrite the getter and throw an error!
        warn('Some internally maintained networks use biases, ' +
             'while others don\'t. Setting attribute "has_bias" to ' +
             'False.')

    self._layer_weight_tensors.extend(
        nn.ParameterList(net._layer_weight_tensors))
    self._layer_bias_vectors.extend(
        nn.ParameterList(net._layer_bias_vectors))
    if net._batchnorm_layers is not None:
        if self._batchnorm_layers is None:
            self._batchnorm_layers = nn.ModuleList()
        self._batchnorm_layers.extend(
            nn.ModuleList(net._batchnorm_layers))
    if net._context_mod_layers is not None:
        if self._context_mod_layers is None:
            self._context_mod_layers = nn.ModuleList()
        self._context_mod_layers.extend(
            nn.ModuleList(net._context_mod_layers))

self._is_properly_setup()

### Print user information.
if verbose:
    print('Constructed Bidirectional RNN with %d weights.' \
          % self.num_params)

@property
def preprocess_fct(self):
    """Getter for attribute :attr:`preprocess_fct`."""
    return self._preprocess_fct

@preprocess_fct.setter
def preprocess_fct(self, value):
    """Setter for attribute :attr:`preprocess_fct`.

    Note:
        This setter may only be called before the first call of the
        :meth:`forward` method.
    """
    if self._forward_called:
        raise RuntimeError('Attribute "preprocess_fct" cannot be ' +
                           'modified after method "forward" has been ' +
                           'called.')
    self._preprocess_fct = value

@property
def num_rec_layers(self):
    """Getter for read-only attribute :attr:`num_rec_layers`.

    Sum of the recurrent layers of all forward and backward RNNs.
    """
    num_rec_layers = 0
    for net in self._forward_rnns + self._backward_rnns:
        num_rec_layers += net.num_rec_layers
    return num_rec_layers

@property
def use_lstm(self):
    """Getter for read-only attribute :attr:`use_lstm`.

    Raises:
        RuntimeError: If the forward RNNs disagree on ``use_lstm``.
    """
    use_lstm = self._forward_rnns[0].use_lstm
    for i in range(1, len(self._forward_rnns)):
        if self._forward_rnns[i].use_lstm != use_lstm:
            raise RuntimeError('Attribute "use_lstm" not applicable to ' +
                               'this network as layers use mixed types ' +
                               'of RNNs.')
    return use_lstm

def distillation_targets(self):
    """Targets to be distilled after training.

    See docstring of abstract super method
    :meth:`mnets.mnet_interface.MainNetInterface.distillation_targets`.

    Returns:
        The distillation targets of the output MLP (if any), otherwise
        ``None``.

    Raises:
        RuntimeError: If any internal RNN reports distillation targets.
    """
    # SimpleRNNs should not have any distillation targets. Note, the
    # method has to be *called*; the bound method object itself is never
    # `None`.
    for net in self._forward_rnns + self._backward_rnns:
        if net.distillation_targets() is not None:
            raise RuntimeError('Internal RNNs of a BiRNN are not ' +
                               'expected to have distillation targets.')

    if self._out_mlp is not None:
        return self._out_mlp.distillation_targets()
    return None

def forward(self, x, weights=None, distilled_params=None, condition=None,
            seq_lengths=None):
    """Compute the output :math:`y` of this network given the input
    :math:`x`.

    Note:
        If constructor argument ``preprocess_fct`` was set, then all inputs
        ``x`` are first processed by this function.

    Args:
        (....): See docstring of method
            :meth:`mnets.mnet_interface.MainNetInterface.forward`. We
            provide some more specific information below.
        weights (list or dict): See argument ``weights`` of method
            :meth:`mnets.mlp.MLP.forward`.
        distilled_params: Will only be passed to the underlying instance
            of class :class:`mnets.mlp.MLP`
        condition (int or dict, optional): If provided, then this argument
            will be passed as argument ``ckpt_id`` to the method
            :meth:`utils.context_mod_layer.ContextModLayer.forward`.

            When providing as dict, see argument ``condition`` of method
            :meth:`mnets.mlp.MLP.forward` for more details.
        seq_lengths (numpy.ndarray, optional): List of sequence
            lengths. The length of the list has to match the batch size of
            inputs ``x``. The entries will correspond to the unpadded
            sequence lengths. If this option is provided, then the
            bidirectional layers will reverse its input sequences according
            to the unpadded sequence lengths.

            Example:
                ``x = [[a,b,0,0], [a,b,c,0]].T``. If
                ``seq_lengths = [2, 3]`` is provided, then the reverse
                sequences ``[[b,a,0,0], [c,b,a,0]].T`` are fed into the
                first bidirectional layer (and similarly for all subsequent
                bidirectional layers). Otherwise reverse sequences
                ``[[0,0,b,a], [0,c,b,a]].T`` are used.

            Caution:
                If this option is not provided but padded input sequences
                are used, the output of a bidirectional layer will depend
                on the padding. I.e., different padding lengths will lead
                to different results.

    Returns:
        (torch.Tensor or tuple): Where the tuple is containing:

        - **output** (torch.Tensor): The output of the network.
        - **hidden** (list): ``None`` - not implemented yet.
    """
    # FIXME Delete warning below.
    if seq_lengths is None:
        warn('"seq_lengths" has not been provided to BiRNN.')

    if self._out_mlp is None:
        assert distilled_params is None

    ########################
    ### Parse condition ###
    #######################

    rnn_cmod_cond = None
    mlp_cond = None

    if condition is not None:
        if isinstance(condition, dict):
            if 'cmod_ckpt_id' in condition.keys():
                rnn_cmod_cond = condition['cmod_ckpt_id']
            mlp_cond = condition
        else:
            rnn_cmod_cond = condition
            mlp_cond = {'cmod_ckpt_id': condition}

    ########################################
    ### Extract-weights for each network ###
    ########################################

    forward_weights = [None] * len(self._forward_rnns)
    backward_weights = [None] * len(self._backward_rnns)
    mlp_weights = None

    n_cm = self._num_context_mod_shapes()

    int_weights = None
    cm_weights = None
    all_weights = None
    if weights is not None and isinstance(weights, dict):
        if 'internal_weights' in weights.keys():
            int_weights = weights['internal_weights']
        if 'mod_weights' in weights.keys():
            cm_weights = weights['mod_weights']
    elif weights is not None:
        if len(weights) == n_cm:
            cm_weights = weights
        else:
            assert len(weights) == len(self.param_shapes)
            all_weights = weights

    if weights is not None:
        # Collect all context-mod and internal weights if not explicitly
        # passed. Note, those will either be taken from `all_weights` or
        # have to exist internally.
        if n_cm > 0 and cm_weights is None:
            cm_weights = []
            for ii, meta in enumerate(self.param_shapes_meta):
                if meta['name'].startswith('cm_'):
                    if all_weights is not None:
                        cm_weights.append(all_weights[ii])
                    else:
                        assert meta['index'] != -1
                        cm_weights.append(
                            self.internal_params[meta['index']])
        if int_weights is None:
            int_weights = []
            for ii, meta in enumerate(self.param_shapes_meta):
                if not meta['name'].startswith('cm_'):
                    if all_weights is not None:
                        int_weights.append(all_weights[ii])
                    else:
                        assert meta['index'] != -1
                        int_weights.append(
                            self.internal_params[meta['index']])

        # Now that we have all context-mod and internal weights, we need to
        # distribute them across networks. Therefore, note that the order
        # in which they appear in `param_shapes` matches the order of
        # `cm_weights` and `int_weights`.
        cm_ind = 0
        int_ind = 0
        for ii, meta in enumerate(self.param_shapes_meta):
            net_type = meta['birnn_layer_type']
            net_id = meta['birnn_layer_id']

            if net_type == 'forward_rnn':
                if forward_weights[net_id] is None:
                    forward_weights[net_id] = dict()
                curr_weights = forward_weights[net_id]
            elif net_type == 'backward_rnn':
                if backward_weights[net_id] is None:
                    backward_weights[net_id] = dict()
                curr_weights = backward_weights[net_id]
            else:
                assert net_type == 'out_mlp'
                if mlp_weights is None:
                    mlp_weights = dict()
                curr_weights = mlp_weights

            if meta['name'].startswith('cm_'):
                if 'mod_weights' not in curr_weights.keys():
                    curr_weights['mod_weights'] = []
                curr_weights['mod_weights'].append(cm_weights[cm_ind])
                cm_ind += 1
            else:
                if 'internal_weights' not in curr_weights.keys():
                    curr_weights['internal_weights'] = []
                curr_weights['internal_weights'].append(
                    int_weights[int_ind])
                int_ind += 1

    #####################################
    ### Apply potential preprocessing ###
    #####################################

    self._forward_called = True
    if self._preprocess_fct is not None:
        x = self._preprocess_fct(x, seq_lengths=seq_lengths)

    ####################################
    ### Process bidirectional layers ###
    ####################################

    # Create reverse input sequence for backward network.
    if seq_lengths is not None:
        assert seq_lengths.size == x.shape[1]

    def revert_order(inp):
        # Flip along dim 0. With `seq_lengths`, only the first
        # `seq_lengths[ii]` entries of batch element `ii` are flipped; the
        # remaining entries are set to zero.
        if seq_lengths is None:
            return torch.flip(inp, [0])
        else:
            inp_back = torch.zeros_like(inp)
            for ii in range(seq_lengths.size):
                inp_back[:int(seq_lengths[ii]), ii, :] = \
                    torch.flip(inp[:int(seq_lengths[ii]), ii, :], [0])
            return inp_back

    h = x
    for ll, fnet in enumerate(self._forward_rnns):
        bnet = self._backward_rnns[ll]

        # Revert inputs in time before processing them by the backward RNN.
        h_rev = revert_order(h)

        h_f = fnet.forward(h, weights=forward_weights[ll],
                           condition=rnn_cmod_cond, return_hidden=False,
                           return_hidden_int=False)
        h_b = bnet.forward(h_rev, weights=backward_weights[ll],
                           condition=rnn_cmod_cond, return_hidden=False,
                           return_hidden_int=False)

        # Revert outputs in time from the backward RNN before concatenation.
        # NOTE If `seq_lengths` are given, then this function will also set
        # the hidden timesteps corresponding to "padded timesteps" to zero.
        h_b = revert_order(h_b)

        # Set hidden states of `h_f` corresponding to padded timesteps to
        # zero to ensure consistency with `h_b`. Padded timesteps of batch
        # element `ii` are those from index `seq_lengths[ii]` onwards.
        if seq_lengths is not None:
            for ii in range(seq_lengths.size):
                h_f[int(seq_lengths[ii]):, ii, :] = 0

        h = torch.cat([h_f, h_b], dim=2)

    ##############################
    ### Compute network output ###
    ##############################

    if self._out_mlp is not None:
        h = self._out_mlp.forward(h, weights=mlp_weights,
                                  distilled_params=distilled_params,
                                  condition=mlp_cond)

    return h

def init_hh_weights_orthogonal(self):
    """Initialize hidden-to-hidden weights orthogonally.

    This method will call method
    :meth:`mnets.simple_rnn.SimpleRNN.init_hh_weights_orthogonal` of all
    internally maintained instances of class
    :class:`mnets.simple_rnn.SimpleRNN`.
    """
    for net in self._forward_rnns + self._backward_rnns:
        net.init_hh_weights_orthogonal()

def get_cm_weights(self):
    """Get internal maintained weights that are associated with context-
    modulation.

    Returns:
        (list): List of weights from
        :attr:`mnets.mnet_interface.MainNetInterface.internal_params` that
        are belonging to context-mod layers.
    """
    ret = []
    for meta in self.param_shapes_meta:
        if not (meta['name'] == 'cm_shift' or meta['name'] == 'cm_scale'):
            continue
        if meta['index'] != -1:
            ret.append(self.internal_params[meta['index']])
    return ret

def get_non_cm_weights(self):
    """Get internal weights that are not associated with context-modulation.

    Returns:
        (list): List of weights from
        :attr:`mnets.mnet_interface.MainNetInterface.internal_params` that
        are not belonging to context-mod layers.
    """
    n_cm = self._num_context_mod_shapes()

    if n_cm == 0:
        return self.internal_params
    else:
        ret = []
        for meta in self.param_shapes_meta:
            if meta['name'] == 'cm_shift' or meta['name'] == 'cm_scale':
                continue
            if meta['index'] != -1:
                ret.append(self.internal_params[meta['index']])
        return ret
def __init__(self, rnn_args=None, mlp_args=None, preprocess_fct=None,
             no_weights=False, verbose=True):
    """Construct a bidirectional RNN from pairs of forward/backward RNNs.

    Args:
        rnn_args (dict or list or tuple, optional): Keyword arguments for
            :class:`SimpleRNN`, one dict per bidirectional layer (a single
            dict is treated as a one-element list). Each dict is used to
            build both the forward and the backward RNN of that layer.
            ``None`` (the default) is equivalent to ``{}``. The given dicts
            are copied and never modified.
        mlp_args (dict, optional): Keyword arguments for an output
            :class:`MLP`. If ``None``, no output MLP is created. The dict
            is copied and never modified.
        preprocess_fct (optional): Stored and applied to the inputs in
            :meth:`forward`.
        no_weights (bool): Passed to all internal networks. A conflicting
            ``no_weights`` entry in ``rnn_args``/``mlp_args`` raises.
        verbose (bool): Whether to print the number of weights.

    Raises:
        ValueError: If an internal network's ``no_weights`` conflicts with
            argument ``no_weights``, or if internal networks disagree on
            whether they maintain their context-mod weights.
    """
    # FIXME find a way using super to handle multiple inheritance.
    nn.Module.__init__(self)
    MainNetInterface.__init__(self)

    # `None` instead of a mutable `{}` default: the dicts below get extra
    # keys injected, which previously leaked into the shared default.
    if rnn_args is None:
        rnn_args = {}

    assert isinstance(rnn_args, (dict, list, tuple))
    assert mlp_args is None or isinstance(mlp_args, dict)

    if isinstance(rnn_args, dict):
        rnn_args = [rnn_args]

    self._forward_rnns = []
    self._backward_rnns = []
    self._out_mlp = None
    self._preprocess_fct = preprocess_fct
    self._forward_called = False

    # FIXME At the moment we do not control input and output size of
    # individual networks and need to assume that the user sets them
    # correctly.

    ### Create all forward and backward nets for each bidirectional layer.
    for rargs in rnn_args:
        assert isinstance(rargs, dict)
        # Work on a copy so the caller's dict is never modified.
        rargs = dict(rargs)
        if 'verbose' not in rargs.keys():
            rargs['verbose'] = False
        if 'no_weights' in rargs.keys() and \
                rargs['no_weights'] != no_weights:
            raise ValueError('Keyword argument "no_weights" of ' +
                             'bidirectional layer is in conflict with ' +
                             'constructor argument "no_weights".')
        elif 'no_weights' not in rargs.keys():
            rargs['no_weights'] = no_weights

        self._forward_rnns.append(SimpleRNN(**rargs))
        self._backward_rnns.append(SimpleRNN(**rargs))

    ### Create output network.
    if mlp_args is not None:
        # Work on a copy so the caller's dict is never modified.
        mlp_args = dict(mlp_args)
        if 'verbose' not in mlp_args.keys():
            mlp_args['verbose'] = False
        if 'no_weights' in mlp_args.keys() and \
                mlp_args['no_weights'] != no_weights:
            raise ValueError('Keyword argument "no_weights" of ' +
                             'output MLP is in conflict with ' +
                             'constructor argument "no_weights".')
        elif 'no_weights' not in mlp_args.keys():
            mlp_args['no_weights'] = no_weights

        self._out_mlp = MLP(**mlp_args)

    ### Set all interface attributes correctly.
    if self._out_mlp is None:
        self._has_fc_out = self._forward_rnns[-1].has_fc_out
        # We can't set the following attribute to true, as the output is
        # a concatenation of the outputs from two networks. Therefore, the
        # weights used to compute the outputs are at different locations
        # in the `param_shapes` list.
        self._mask_fc_out = False
        self._has_linear_out = self._forward_rnns[-1].has_linear_out
    else:
        self._has_fc_out = self._out_mlp.has_fc_out
        self._mask_fc_out = self._out_mlp.mask_fc_out
        self._has_linear_out = self._out_mlp.has_linear_out

    # Collect all internal net objects from which we need to collect
    # attributes.
    nets = []
    for i, fnet in enumerate(self._forward_rnns):
        bnet = self._backward_rnns[i]

        nets.append((fnet, 'forward_rnn', i))
        nets.append((bnet, 'backward_rnn', i))
    if self._out_mlp is not None:
        nets.append((self._out_mlp, 'out_mlp', -1))

    # Iterate over all nets to collect their attribute values.
    self._param_shapes = []
    self._param_shapes_meta = []
    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    for i, net_tup in enumerate(nets):
        net, net_type, net_id = net_tup

        # Note, it is important to convert lists into new object and not
        # just copy references!
        # Note, we have to adapt all references if `i > 0`.

        # Sanity check:
        if i == 0:
            cm_nw = net._context_mod_no_weights
        elif cm_nw != net._context_mod_no_weights:
            raise ValueError('Network expect that either all internal ' +
                             'networks maintain their context-mod ' +
                             'weights or non of them does!')

        ps_len_old = len(self._param_shapes)
        # Offset of this net's parameters inside the concatenated
        # `internal_params`. Computed for every net, so it is never stale
        # (left over from a previous net) or undefined.
        ip_len_old = 0 if self._internal_params is None \
            else len(self._internal_params)

        if net._internal_params is not None:
            if self._internal_params is None:
                self._internal_params = nn.ParameterList()
            self._internal_params.extend(
                nn.ParameterList(net._internal_params))
        self._param_shapes.extend(list(net._param_shapes))
        for meta in net.param_shapes_meta:
            assert 'birnn_layer_type' not in meta.keys()
            assert 'birnn_layer_id' not in meta.keys()

            new_meta = dict(meta)
            new_meta['birnn_layer_type'] = net_type
            new_meta['birnn_layer_id'] = net_id
            if i > 0:
                # FIXME We should properly adjust colliding `layer` IDs.
                new_meta['layer'] = -1
            # An index of -1 marks a parameter that is not maintained
            # internally; it must stay -1 instead of being shifted onto
            # some other net's parameter.
            if meta['index'] != -1:
                new_meta['index'] = meta['index'] + ip_len_old
            self._param_shapes_meta.append(new_meta)

        if net._hyper_shapes_learned is not None:
            if self._hyper_shapes_learned is None:
                self._hyper_shapes_learned = []
                self._hyper_shapes_learned_ref = []
            self._hyper_shapes_learned.extend(
                list(net._hyper_shapes_learned))
            # References point into `param_shapes`, hence shift them by the
            # number of shapes collected from previous nets.
            for ref in net._hyper_shapes_learned_ref:
                self._hyper_shapes_learned_ref.append(ref + ps_len_old)
        if net._hyper_shapes_distilled is not None:
            if self._hyper_shapes_distilled is None:
                self._hyper_shapes_distilled = []
            self._hyper_shapes_distilled.extend(
                list(net._hyper_shapes_distilled))

        if self._has_bias is None:
            self._has_bias = net._has_bias
        elif self._has_bias != net._has_bias:
            self._has_bias = False
            # FIXME We should overwrite the getter and throw an error!
            warn('Some internally maintained networks use biases, ' +
                 'while others don\'t. Setting attribute "has_bias" to ' +
                 'False.')

        self._layer_weight_tensors.extend(
            nn.ParameterList(net._layer_weight_tensors))
        self._layer_bias_vectors.extend(
            nn.ParameterList(net._layer_bias_vectors))
        if net._batchnorm_layers is not None:
            if self._batchnorm_layers is None:
                self._batchnorm_layers = nn.ModuleList()
            self._batchnorm_layers.extend(
                nn.ModuleList(net._batchnorm_layers))
        if net._context_mod_layers is not None:
            if self._context_mod_layers is None:
                self._context_mod_layers = nn.ModuleList()
            self._context_mod_layers.extend(
                nn.ModuleList(net._context_mod_layers))

    self._is_properly_setup()

    ### Print user information.
    if verbose:
        print('Constructed Bidirectional RNN with %d weights.' \
              % self.num_params)
class ChunkSqueezer(nn.Module, MainNetInterface):
    """An MLP that first reduces the dimensionality of its inputs.

    The input dimensionality ``n_in`` is first reduced by a `reducer` network
    (which is an instance of class :class:`mnets.mlp.MLP`) using a chunking
    strategy. The reduced input will be then passed to the actual `network`
    (which is another instance of :class:`mnets.mlp.MLP`) to compute an
    output.

    Args:
        n_in (int): Input dimensionality.
        n_out (int): Number of output neurons.
        inp_chunk_dim (int): The input (dimensionality ``n_in``) will be split
            into chunks of size ``inp_chunk_dim``. Thus, there will be
            ``np.ceil(n_in/inp_chunk_dim)`` input chunks that are individually
            squeezed through the `reducer` network.

            Note:
                If ``n_in`` is not a multiple of ``inp_chunk_dim``, the last
                chunk is zero-padded.
        out_chunk_dim (int): The output size of the `reducer` network. The
            input size of the actual `network` is then
            ``np.ceil(n_in/inp_chunk_dim) * out_chunk_dim``.
        cemb_size (int): The `reducer` network processes every chunk
            individually. In order to do so, it needs to know which chunk it
            is processing. Therefore, it is conditioned on a learned chunk
            embedding (there will be ``np.ceil(n_in/inp_chunk_dim)`` chunk
            embeddings). The dimensionality of these chunk embeddings is
            determined by this argument.
        cemb_init_std (float): Standard deviation used for the normal
            initialization of the chunk embeddings.
        red_layers (list or tuple): The architecture of the `reducer` network.
            See argument ``hidden_layers`` of class :class:`mnets.mlp.MLP`.
        net_layers (list or tuple): The architecture of the actual `network`.
            See argument ``hidden_layers`` of class :class:`mnets.mlp.MLP`.
        activation_fn: The nonlinearity used in hidden layers. If ``None``,
            no nonlinearity will be applied.
        use_bias: Will be passed as option ``use_bias`` to the underlying
            MLPs (see :class:`mnets.mlp.MLP`).
        dynamic_biases (list, optional): This option determines the hidden
            layers of the `reducer` networks that receive the chunk embedding
            as dynamic biases. It is a list of indexes with the first hidden
            layer having index 0 and the output of the `reducer` would have
            index ``len(red_layers)``. The chunk embeddings will be
            transformed through a fully connected layer (no bias) and then
            added as "dynamic" bias to the output of the corresponding hidden
            layer.

            Note:
                If left unspecified, the chunk embeddings will just be another
                input to the `reducer` network.

            Note:
                Dynamic biases are not implemented yet; providing this option
                raises a :class:`NotImplementedError`.
        no_weights (bool): If set to ``True``, no trainable parameters will be
            constructed, i.e., weights are assumed to be produced ad-hoc
            by a hypernetwork and passed to the :meth:`forward` method.
        init_weights (optional): This option is for convenience reasons.
            The option expects a list of parameter values that are used to
            initialize the network weights. As such, it provides a
            convenient way of initializing a network with a weight draw
            produced by the hypernetwork.

            Note, internal weights (see
            :attr:`mnets.mnet_interface.MainNetInterface.weights`) will be
            affected by this argument only.
        dropout_rate (float): Will be passed as option ``dropout_rate`` to the
            underlying MLPs (see :class:`mnets.mlp.MLP`).
        use_spectral_norm (bool): Will be passed as option
            ``use_spectral_norm`` to the underlying MLPs (see
            :class:`mnets.mlp.MLP`).
        use_batch_norm (bool): Will be passed as option ``use_batch_norm``
            to the underlying MLPs (see :class:`mnets.mlp.MLP`).
        bn_track_stats (bool): Will be passed as option ``bn_track_stats``
            to the underlying MLPs (see :class:`mnets.mlp.MLP`).
        distill_bn_stats (bool): Will be passed as option ``distill_bn_stats``
            to the underlying MLPs (see :class:`mnets.mlp.MLP`).
        verbose (bool): Whether to print a summary after construction.
    """
    def __init__(self, n_in, n_out=1, inp_chunk_dim=100, out_chunk_dim=10,
                 cemb_size=8, cemb_init_std=1., red_layers=(10, 10),
                 net_layers=(10, 10), activation_fn=torch.nn.ReLU(),
                 use_bias=True, dynamic_biases=None, no_weights=False,
                 init_weights=None, dropout_rate=-1, use_spectral_norm=False,
                 use_batch_norm=False, bn_track_stats=True,
                 distill_bn_stats=False, verbose=True):
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        self._n_in = n_in
        self._n_out = n_out
        self._inp_chunk_dim = inp_chunk_dim
        self._out_chunk_dim = out_chunk_dim
        self._cemb_size = cemb_size
        self._a_fun = activation_fn
        self._no_weights = no_weights

        self._has_bias = use_bias
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        self._has_linear_out = True # Ensure that `out_fn` is `None`!

        self._param_shapes = []
        #self._param_shapes_meta = [] # TODO implement!
        self._weights = None if no_weights else nn.ParameterList()
        self._hyper_shapes_learned = None if not no_weights else []
        #self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
        #    is None else [] # TODO implement.

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        self._context_mod_layers = None
        self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

        #################################
        ### Generate Chunk Embeddings ###
        #################################
        self._num_cembs = int(np.ceil(n_in / inp_chunk_dim))
        last_chunk_size = n_in % inp_chunk_dim
        if last_chunk_size != 0:
            self._pad = inp_chunk_dim - last_chunk_size
        else:
            self._pad = -1

        cemb_shape = [self._num_cembs, cemb_size]
        self._param_shapes.append(cemb_shape)
        if no_weights:
            self._cembs = None
            self._hyper_shapes_learned.append(cemb_shape)
        else:
            self._cembs = nn.Parameter(data=torch.Tensor(*cemb_shape),
                                       requires_grad=True)
            nn.init.normal_(self._cembs, mean=0., std=cemb_init_std)

            self._weights.append(self._cembs)

        ############################
        ### Setup Dynamic Biases ###
        ############################
        self._has_dyn_bias = None
        if dynamic_biases is not None:
            assert np.all(np.array(dynamic_biases) >= 0) and \
                   np.all(np.array(dynamic_biases) < len(red_layers) + 1)
            dynamic_biases = np.sort(np.unique(dynamic_biases))

            # For each layer in the `reducer`, where we want to apply a dynamic
            # bias, we have to create a weight matrix for a corresponding
            # linear layer (we just ignore)
            # A plain list: entries are `None` or parameters registered via
            # `self._weights`. `nn.ModuleList` only accepts modules and would
            # reject an `nn.Parameter`.
            self._dyn_bias_weights = []
            self._has_dyn_bias = []

            for i in range(len(red_layers) + 1):
                if i in dynamic_biases:
                    self._has_dyn_bias.append(True)

                    trgt_dim = out_chunk_dim
                    if i < len(red_layers):
                        trgt_dim = red_layers[i]
                    trgt_shape = [trgt_dim, cemb_size]

                    self._param_shapes.append(trgt_shape)
                    # Without internal weights, the shape has to be produced
                    # externally (e.g., by a hypernetwork).
                    if no_weights:
                        self._dyn_bias_weights.append(None)
                        self._hyper_shapes_learned.append(trgt_shape)
                    else:
                        self._dyn_bias_weights.append(nn.Parameter( \
                            torch.Tensor(*trgt_shape), requires_grad=True))
                        self._weights.append(self._dyn_bias_weights[-1])

                        init_params(self._dyn_bias_weights[-1])

                        self._layer_weight_tensors.append( \
                            self._dyn_bias_weights[-1])
                        self._layer_bias_vectors.append(None)
                else:
                    self._has_dyn_bias.append(False)
                    self._dyn_bias_weights.append(None)

        ################################
        ### Create `Reducer` Network ###
        ################################
        red_inp_dim = inp_chunk_dim + \
            (cemb_size if dynamic_biases is None else 0)
        self._reducer = MLP(
            n_in=red_inp_dim,
            n_out=out_chunk_dim,
            hidden_layers=red_layers,
            activation_fn=activation_fn,
            use_bias=use_bias,
            no_weights=no_weights,
            init_weights=None,
            dropout_rate=dropout_rate,
            use_spectral_norm=use_spectral_norm,
            use_batch_norm=use_batch_norm,
            bn_track_stats=bn_track_stats,
            distill_bn_stats=distill_bn_stats,
            # We use context modulation to realize dynamic biases, since they
            # allow a different modulation per sample in the input mini-batch.
            # Hence, we can process several chunks in parallel with the reducer
            # network.
            use_context_mod=dynamic_biases is not None,
            context_mod_inputs=False,
            no_last_layer_context_mod=False,
            context_mod_no_weights=True,
            context_mod_post_activation=False,
            context_mod_gain_offset=False,
            context_mod_gain_softplus=False,
            out_fn=None,
            verbose=True)

        if dynamic_biases is not None:
            # FIXME We have to extract the param shapes from
            # `self._reducer.param_shapes`, as well as from
            # `self._reducer._hyper_shapes_learned` that belong to context-mod
            # layers. We may not add them to our own `param_shapes` attribute,
            # as these are not parameters (due to our misuse of the context-mod
            # layers).
            # Note, in the `forward` method, we need to supply context-mod
            # weights for all reducer networks, independent on whether they have
            # a dynamic bias or not. We can do so, by providing constant ones
            # for all gains and constance zero-shift for all layers without
            # dynamic biases (note, we need to ensure the correct batch dim!).
            raise NotImplementedError(
                'Dynamic biases are not yet implemented!')

        assert self._reducer._context_mod_layers is None

        ### Overtake all attributes from the underlying MLP.
        for s in self._reducer.param_shapes:
            self._param_shapes.append(s)
        if no_weights:
            for s in self._reducer._hyper_shapes_learned:
                self._hyper_shapes_learned.append(s)
        else:
            for p in self._reducer._weights:
                self._weights.append(p)

        for p in self._reducer._layer_weight_tensors:
            self._layer_weight_tensors.append(p)
        for p in self._reducer._layer_bias_vectors:
            self._layer_bias_vectors.append(p)

        if use_batch_norm:
            for p in self._reducer._batchnorm_layers:
                self._batchnorm_layers.append(p)

        if self._reducer._hyper_shapes_distilled is not None:
            self._hyper_shapes_distilled = []
            for s in self._reducer._hyper_shapes_distilled:
                self._hyper_shapes_distilled.append(s)

        ###############################
        ### Create Actual `Network` ###
        ###############################
        net_inp_dim = out_chunk_dim * self._num_cembs
        self._network = MLP(
            n_in=net_inp_dim,
            n_out=n_out,
            hidden_layers=net_layers,
            activation_fn=activation_fn,
            use_bias=use_bias,
            no_weights=no_weights,
            init_weights=None,
            dropout_rate=dropout_rate,
            use_spectral_norm=use_spectral_norm,
            use_batch_norm=use_batch_norm,
            bn_track_stats=bn_track_stats,
            distill_bn_stats=distill_bn_stats,
            use_context_mod=False,
            out_fn=None,
            verbose=True)

        ### Overtake all attributes from the underlying MLP.
        for s in self._network.param_shapes:
            self._param_shapes.append(s)
        if no_weights:
            for s in self._network._hyper_shapes_learned:
                self._hyper_shapes_learned.append(s)
        else:
            for p in self._network._weights:
                self._weights.append(p)

        for p in self._network._layer_weight_tensors:
            self._layer_weight_tensors.append(p)
        for p in self._network._layer_bias_vectors:
            self._layer_bias_vectors.append(p)

        if use_batch_norm:
            for p in self._network._batchnorm_layers:
                self._batchnorm_layers.append(p)

        if self._hyper_shapes_distilled is not None:
            assert self._network._hyper_shapes_distilled is not None
            for s in self._network._hyper_shapes_distilled:
                self._hyper_shapes_distilled.append(s)

        #####################################
        ### Takeover given Initialization ###
        #####################################
        if init_weights is not None:
            assert len(init_weights) == len(self._weights)
            for i in range(len(init_weights)):
                assert np.all(np.equal(list(init_weights[i].shape),
                                       self._param_shapes[i]))
                self._weights[i].data = init_weights[i]

        ######################
        ### Finalize Setup ###
        ######################
        if verbose:
            num_weights = MainNetInterface.shapes_to_num_weights(
                self.param_shapes)
            print('Constructed MLP that processes dimensionality reduced ' +
                  'inputs through chunking. The network has a total of ' +
                  '%d weights.' % num_weights)

        self._is_properly_setup()

    def distillation_targets(self):
        """Targets to be distilled after training.

        See docstring of abstract super method
        :meth:`mnets.mnet_interface.MainNetInterface.distillation_targets`.

        This method will return the distillation targets from the 2 underlying
        networks, see method :meth:`mnets.mlp.MLP.distillation_targets`.

        Returns:
            The target tensors corresponding to the shapes specified in
            attribute :attr:`hyper_shapes_distilled`, or ``None`` if there are
            no distillation targets.
        """
        if self.hyper_shapes_distilled is None:
            return None

        # Both sub-network methods must be *called*; adding the bound method
        # objects themselves raises a `TypeError`.
        red_targets = self._reducer.distillation_targets()
        net_targets = self._network.distillation_targets()
        return list(red_targets or []) + list(net_targets or [])

    def forward(self, x, weights=None, distilled_params=None, condition=None):
        """Compute the output :math:`y` of this network given the input
        :math:`x`.

        Args:
            (....): See docstring of method
                :meth:`mnets.mnet_interface.MainNetInterface.forward`. We
                provide some more specific information below.
            distilled_params: Will be split and passed as distillation targets
                to the underlying instances of class :class:`mnets.mlp.MLP` if
                specified.
            condition (optional, int or dict): Will be passed to the
                underlying instances of class :class:`mnets.mlp.MLP`.

        Returns:
            The output :math:`y` of the network.
        """
        if self._no_weights and weights is None:
            raise Exception('Network was generated without weights. ' +
                            'Hence, "weights" option may not be None.')

        if weights is None:
            weights = self._weights
        else:
            assert len(weights) == len(self.param_shapes)
            for i, s in enumerate(self.param_shapes):
                assert np.all(np.equal(s, list(weights[i].shape)))

        #########################################
        ### Extract parameters from `weights` ###
        #########################################
        # Layout of `weights`: chunk embeddings, dynamic-bias weights (one per
        # dynamic bias), reducer weights, network weights.
        cembs = weights[0]
        w_ind = 1

        if self._has_dyn_bias is not None:
            # Only layers with a dynamic bias own an entry in `weights`.
            w_ind_new = w_ind + int(np.sum(self._has_dyn_bias))
            dyn_bias_weights = weights[w_ind:w_ind_new]
            w_ind = w_ind_new

            # TODO use `dyn_bias_weights` to construct weights for context-mod
            # layers.
            raise NotImplementedError

        w_ind_new = w_ind + len(self._reducer.param_shapes)
        red_weights = weights[w_ind:w_ind_new]
        w_ind = w_ind_new

        w_ind_new = w_ind + len(self._network.param_shapes)
        net_weights = weights[w_ind:w_ind_new]
        w_ind = w_ind_new

        red_distilled_params = None
        net_distilled_params = None
        if distilled_params is not None:
            if self.hyper_shapes_distilled is None:
                raise ValueError(
                    'Argument "distilled_params" can only be ' +
                    'provided if the return value of ' +
                    'method "distillation_targets()" is not None.')

            assert len(distilled_params) == len(self.hyper_shapes_distilled)
            red_distilled_params = \
                distilled_params[:len(self._reducer.hyper_shapes_distilled)]
            net_distilled_params = \
                distilled_params[len(self._reducer.hyper_shapes_distilled):]

        ###########################
        ### Chunk network input ###
        ###########################
        assert x.shape[1] == self._n_in

        if self._pad != -1:
            x = F.pad(x, (0, self._pad))
        # After padding, the input splits into whole *input* chunks.
        assert x.shape[1] % self._inp_chunk_dim == 0

        batch_size = x.shape[0]

        # We now split the input `x` into chunks and convert them into
        # separate samples, i.e., the `batch_size` will be multiplied by the
        # number of chunks.
        # So, we parallel process a huge batch with a small network rather than
        # processing a huge input with a huge network.
        chunks = torch.split(x, self._inp_chunk_dim, dim=1)
        # Concatenate the chunks along the batch dimension.
        chunks = torch.cat(chunks, dim=0)

        if self._has_dyn_bias is not None:
            raise NotImplementedError()
        else:
            # Within a chunk the same chunk embedding is used.
            cembs = torch.split(cembs, 1, dim=0)
            cembs = [emb.expand(batch_size, -1) for emb in cembs]
            cembs = torch.cat(cembs, dim=0)

            chunks = torch.cat([chunks, cembs], dim=1)

        ###################################
        ### Reduce input dimensionality ###
        ###################################
        if self._has_dyn_bias is not None:
            # TODO pass context-mod weights to `reducer`.
            raise NotImplementedError()
        chunks = self._reducer.forward(chunks, weights=red_weights,
                                       distilled_params=red_distilled_params,
                                       condition=condition)

        ### Reformat `reducer` output into the input of the actual `network`.
        chunks = torch.split(chunks, batch_size, dim=0)
        net_input = torch.cat(chunks, dim=1)
        assert net_input.shape[0] == batch_size

        ##############################
        ### Compute network output ###
        ##############################
        return self._network.forward(net_input, weights=net_weights,
                                     distilled_params=net_distilled_params,
                                     condition=condition)
def __init__(self, n_in, n_out=1, inp_chunk_dim=100, out_chunk_dim=10,
             cemb_size=8, cemb_init_std=1., red_layers=(10, 10),
             net_layers=(10, 10), activation_fn=torch.nn.ReLU(),
             use_bias=True, dynamic_biases=None, no_weights=False,
             init_weights=None, dropout_rate=-1, use_spectral_norm=False,
             use_batch_norm=False, bn_track_stats=True,
             distill_bn_stats=False, verbose=True):
    """Construct the chunk-based network.

    The network consists of two internal MLPs:

    - ``reducer``: applied to every input chunk separately. An input of
      size ``n_in`` is cut into ``ceil(n_in / inp_chunk_dim)`` chunks of
      size ``inp_chunk_dim`` (the input is padded if ``n_in`` is not a
      multiple of ``inp_chunk_dim``). Each chunk is concatenated with its
      own learnable chunk embedding of size ``cemb_size`` and mapped to
      ``out_chunk_dim`` outputs.
    - ``network``: receives the concatenation of all reducer outputs (size
      ``out_chunk_dim * num_chunks``) and computes the final ``n_out``
      outputs.

    The chunk embeddings and all parameters/layers of both internal MLPs
    are exposed through this instance's own attributes (``param_shapes``,
    ``weights``, ``hyper_shapes_learned``, ...), in the order: chunk
    embeddings, reducer, network.

    Args:
        n_in (int): Input dimensionality.
        n_out (int): Output dimensionality.
        inp_chunk_dim (int): Size of each input chunk.
        out_chunk_dim (int): Output size of the reducer for one chunk.
        cemb_size (int): Size of each chunk embedding.
        cemb_init_std (float): Standard deviation of the normal
            distribution used to initialize the chunk embeddings (unused if
            ``no_weights`` is set).
        red_layers (tuple): Hidden layer sizes of the reducer.
        net_layers (tuple): Hidden layer sizes of the final network.
        activation_fn: Activation function of both internal MLPs.
        use_bias (bool): Whether the internal MLPs use biases.
        dynamic_biases (list, optional): Indices of reducer layers (in
            ``0 .. len(red_layers)``) whose bias should be computed from the
            chunk embedding. Not implemented yet: anything other than
            ``None`` raises :class:`NotImplementedError`.
        no_weights (bool): If ``True``, no parameters are created. Their
            shapes are listed in ``hyper_shapes_learned`` instead, and the
            weights have to be provided externally (e.g., via the
            ``weights`` argument of :meth:`forward`).
        init_weights (list, optional): Tensors matching ``param_shapes``
            that overwrite the initial weights. Requires
            ``no_weights=False``.
        dropout_rate, use_spectral_norm, use_batch_norm, bn_track_stats,
        distill_bn_stats: Forwarded to both internal MLPs.
        verbose (bool): Whether to print construction information. Also
            forwarded to both internal MLPs.

    Raises:
        NotImplementedError: If ``dynamic_biases`` is not ``None``.
        ValueError: If ``init_weights`` is given while ``no_weights`` is
            set.
    """
    # FIXME find a way using super to handle multiple inheritance.
    nn.Module.__init__(self)
    MainNetInterface.__init__(self)

    # Fail early rather than after building all sub-networks (previously
    # this ended in `len(None)` at the very end of the constructor).
    if init_weights is not None and no_weights:
        raise ValueError('Argument "init_weights" can only be provided if ' +
                         '"no_weights" is not set.')

    self._n_in = n_in
    self._n_out = n_out
    self._inp_chunk_dim = inp_chunk_dim
    self._out_chunk_dim = out_chunk_dim
    self._cemb_size = cemb_size
    self._a_fun = activation_fn
    self._no_weights = no_weights

    self._has_bias = use_bias
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer.
    self._mask_fc_out = True
    # Both internal MLPs are created with `out_fn=None`.
    self._has_linear_out = True

    self._param_shapes = []
    # TODO implement `_param_shapes_meta` and `_hyper_shapes_learned_ref`.
    self._weights = None if no_weights else nn.ParameterList()
    self._hyper_shapes_learned = [] if no_weights else None

    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    self._context_mod_layers = None
    self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

    #################################
    ### Generate Chunk Embeddings ###
    #################################
    self._num_cembs = int(np.ceil(n_in / inp_chunk_dim))
    # Number of values that have to be appended to an input so that it
    # splits into equally sized chunks (-1 means no padding is needed).
    last_chunk_size = n_in % inp_chunk_dim
    if last_chunk_size != 0:
        self._pad = inp_chunk_dim - last_chunk_size
    else:
        self._pad = -1

    cemb_shape = [self._num_cembs, cemb_size]
    self._param_shapes.append(cemb_shape)
    if no_weights:
        self._cembs = None
        self._hyper_shapes_learned.append(cemb_shape)
    else:
        self._cembs = nn.Parameter(data=torch.Tensor(*cemb_shape),
                                   requires_grad=True)
        nn.init.normal_(self._cembs, mean=0., std=cemb_init_std)
        self._weights.append(self._cembs)

    ############################
    ### Setup Dynamic Biases ###
    ############################
    self._has_dyn_bias = None
    if dynamic_biases is not None:
        assert np.all(np.array(dynamic_biases) >= 0) and \
            np.all(np.array(dynamic_biases) < len(red_layers) + 1)
        dynamic_biases = np.sort(np.unique(dynamic_biases))

        # For each layer in the `reducer` where we want to apply a dynamic
        # bias, we create the weight matrix of a linear layer that maps the
        # chunk embedding onto that layer's output size.
        # Note, an `nn.ModuleList` cannot store `nn.Parameter` objects (it
        # raises a `TypeError`), hence a `ParameterList` is used.
        self._dyn_bias_weights = nn.ParameterList()
        self._has_dyn_bias = []
        for i in range(len(red_layers) + 1):
            if i not in dynamic_biases:
                self._has_dyn_bias.append(False)
                self._dyn_bias_weights.append(None)
                continue

            self._has_dyn_bias.append(True)
            trgt_dim = red_layers[i] if i < len(red_layers) else out_chunk_dim
            trgt_shape = [trgt_dim, cemb_size]
            self._param_shapes.append(trgt_shape)
            # Previously this condition was inverted, which dereferenced
            # `None` (`_hyper_shapes_learned` resp. `_weights`) either way.
            if no_weights:
                self._dyn_bias_weights.append(None)
                self._hyper_shapes_learned.append(trgt_shape)
            else:
                dyn_w = nn.Parameter(torch.Tensor(*trgt_shape),
                                     requires_grad=True)
                self._dyn_bias_weights.append(dyn_w)
                self._weights.append(dyn_w)
                init_params(dyn_w)
                self._layer_weight_tensors.append(dyn_w)
                self._layer_bias_vectors.append(None)

    def _take_over_mlp(mlp):
        # Register shapes, parameters and layers of an internal MLP as our
        # own (appended after everything registered so far).
        self._param_shapes.extend(mlp.param_shapes)
        if no_weights:
            self._hyper_shapes_learned.extend(mlp._hyper_shapes_learned)
        else:
            for p in mlp._weights:
                self._weights.append(p)
        for p in mlp._layer_weight_tensors:
            self._layer_weight_tensors.append(p)
        for p in mlp._layer_bias_vectors:
            self._layer_bias_vectors.append(p)
        if use_batch_norm:
            for bn_layer in mlp._batchnorm_layers:
                self._batchnorm_layers.append(bn_layer)

    ################################
    ### Create `Reducer` Network ###
    ################################
    # Without dynamic biases, the chunk embedding is concatenated to the
    # chunk and therefore enlarges the reducer input.
    red_inp_dim = inp_chunk_dim + \
        (cemb_size if dynamic_biases is None else 0)
    self._reducer = MLP(
        n_in=red_inp_dim, n_out=out_chunk_dim, hidden_layers=red_layers,
        activation_fn=activation_fn, use_bias=use_bias,
        no_weights=no_weights, init_weights=None, dropout_rate=dropout_rate,
        use_spectral_norm=use_spectral_norm, use_batch_norm=use_batch_norm,
        bn_track_stats=bn_track_stats, distill_bn_stats=distill_bn_stats,
        # We use context modulation to realize dynamic biases, since they
        # allow a different modulation per sample in the input mini-batch.
        # Hence, we can process several chunks in parallel with the reducer
        # network.
        use_context_mod=dynamic_biases is not None,
        context_mod_inputs=False, no_last_layer_context_mod=False,
        context_mod_no_weights=True, context_mod_post_activation=False,
        context_mod_gain_offset=False, context_mod_gain_softplus=False,
        out_fn=None, verbose=verbose)

    if dynamic_biases is not None:
        # FIXME We have to extract the param shapes from
        # `self._reducer.param_shapes`, as well as from
        # `self._reducer._hyper_shapes_learned` that belong to context-mod
        # layers. We may not add them to our own `param_shapes` attribute,
        # as these are not parameters (due to our misuse of the context-mod
        # layers).
        # Note, in the `forward` method, we need to supply context-mod
        # weights for all reducer networks, independent on whether they have
        # a dynamic bias or not. We can do so, by providing constant ones
        # for all gains and constant zero-shift for all layers without
        # dynamic biases (note, we need to ensure the correct batch dim!).
        raise NotImplementedError(
            'Dynamic biases are not yet implemented!')

    assert self._reducer._context_mod_layers is None

    ### Overtake all attributes from the underlying MLP.
    _take_over_mlp(self._reducer)
    if self._reducer._hyper_shapes_distilled is not None:
        self._hyper_shapes_distilled = \
            list(self._reducer._hyper_shapes_distilled)

    ###############################
    ### Create Actual `Network` ###
    ###############################
    net_inp_dim = out_chunk_dim * self._num_cembs
    self._network = MLP(
        n_in=net_inp_dim, n_out=n_out, hidden_layers=net_layers,
        activation_fn=activation_fn, use_bias=use_bias,
        no_weights=no_weights, init_weights=None, dropout_rate=dropout_rate,
        use_spectral_norm=use_spectral_norm, use_batch_norm=use_batch_norm,
        bn_track_stats=bn_track_stats, distill_bn_stats=distill_bn_stats,
        use_context_mod=False, out_fn=None, verbose=verbose)

    ### Overtake all attributes from the underlying MLP.
    _take_over_mlp(self._network)
    if self._hyper_shapes_distilled is not None:
        assert self._network._hyper_shapes_distilled is not None
        self._hyper_shapes_distilled.extend(
            self._network._hyper_shapes_distilled)

    #####################################
    ### Takeover given Initialization ###
    #####################################
    if init_weights is not None:
        assert len(init_weights) == len(self._weights)
        for i, init_w in enumerate(init_weights):
            assert np.all(np.equal(list(init_w.shape),
                                   self._param_shapes[i]))
            self._weights[i].data = init_w

    ######################
    ### Finalize Setup ###
    ######################
    num_weights = MainNetInterface.shapes_to_num_weights(self.param_shapes)
    if verbose:
        print('Constructed MLP that processes dimensionality reduced inputs ' +
              'through chunking. The network has a total of %d weights.'
              % num_weights)

    self._is_properly_setup()