def get_hnet_model(config, mnet, logger, device):
    """Generate the hypernetwork.

    This function uses :func:`utils.sim_utils.get_hnet_model` to generate the
    hypernetwork. The function also takes care of weight initialization, if
    configured.

    Args:
        (....): See docstring of function :func:`get_main_model`.
        mnet: The main network.

    Returns:
        The hypernetwork or ``None`` if no hypernet is needed.
    """
    logger.info('Creating hypernetwork ...')

    hnet = sutils.get_hnet_model(config, config.num_tasks, device,
                                 mnet.param_shapes)

    # FIXME There should be a nicer way of initializing hypernets in the
    # future.
    chunk_embs = None
    if hasattr(hnet, 'chunk_embeddings'):
        chunk_embs = hnet.chunk_embeddings
    init_network_weights(hnet.parameters(), config, logger,
                         chunk_embs=chunk_embs,
                         task_embs=hnet.get_task_embs(), net=hnet)
    if config.hnet_init_shift:
        hnet_init_shift(hnet, mnet, config, logger, device)

    # TODO Incorporate hyperchunk init.
    #if isinstance(hnet, ChunkedHyperNetworkHandler):
    #    hnet.apply_chunked_hyperfan_init(temb_var=config.std_normal_temb**2)

    return hnet
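
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal toy version
# of the hypernetwork idea used in ``get_hnet_model`` above. A small network
# maps a learnable per-task embedding to a flat weight vector, which is then
# reshaped into the parameter shapes of a target network (cf.
# ``mnet.param_shapes``). ``ToyHypernet`` and all sizes below are assumptions
# chosen only to keep the sketch self-contained; the repository's actual
# hypernetwork classes are created via :mod:`utils.sim_utils`.
# ---------------------------------------------------------------------------
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyHypernet(nn.Module):
    """Toy hypernetwork: task embedding -> weights of a target network."""

    def __init__(self, target_shapes, num_tasks, temb_dim=8):
        super().__init__()
        self.target_shapes = target_shapes
        n_out = sum(math.prod(s) for s in target_shapes)
        # One learnable embedding per task (cf. ``get_task_embs`` above).
        self.task_embs = nn.Parameter(torch.randn(num_tasks, temb_dim))
        self.net = nn.Sequential(nn.Linear(temb_dim, 50), nn.ReLU(),
                                 nn.Linear(50, n_out))

    def forward(self, task_id):
        # Map the selected task embedding to one flat weight vector and cut it
        # into tensors matching the target parameter shapes.
        flat = self.net(self.task_embs[task_id])
        weights, offset = [], 0
        for s in self.target_shapes:
            n = math.prod(s)
            weights.append(flat[offset:offset + n].view(*s))
            offset += n
        return weights


def _toy_hypernet_demo():
    # Target network: a 784->10 linear classifier whose weights are produced
    # by the hypernetwork instead of being trainable parameters of their own.
    target_shapes = [(10, 784), (10,)]
    hnet = ToyHypernet(target_shapes, num_tasks=5)
    W, b = hnet(task_id=0)
    y = F.linear(torch.randn(1, 784), W, b)  # Forward with generated weights.
    return y.shape  # torch.Size([1, 10])
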
def generate_classifier(config, data_handlers, device):
    """Create a classifier network.

    Depending on the experiment and method, either a classifier for task
    inference or a classifier that solves the actual task is built. This also
    determines whether the network will receive its weights from a
    hypernetwork or will have weights of its own.

    The following important configurations will be determined in order to
    create the classifier: \n
    * in- and output and hidden layer dimensions of the classifier. \n
    * architecture, chunk- and task-embedding details of the hypernetwork.

    See :class:`mnets.mlp.MLP` for details on the network that will be created
    to be a classifier.

    .. note::
        This module also handles the initialisation of the weights of either
        the classifier or its hypernetwork. This will change in the near
        future.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network and to
            infer the number of tasks.
        device: Torch device.

    Returns:
        (tuple): Tuple containing:

        - **net**: The classifier network.
        - **class_hnet**: (optional) The classifier's hypernetwork.
    """
    n_in = data_handlers[0].in_shape[0]
    pd = config.padding * 2

    if config.experiment == "splitMNIST":
        n_in = n_in * n_in
    else:  # permutedMNIST
        n_in = (n_in + pd) * (n_in + pd)

    config.input_dim = n_in
    if config.experiment == "splitMNIST":
        if config.class_incremental:
            config.out_dim = 1
        else:
            config.out_dim = 2
    else:  # permutedMNIST
        config.out_dim = 10

    if config.training_task_infer or config.class_incremental:
        # Task inference network.
        config.out_dim = 1

    # Have all output neurons already built up for CL scenario 2.
    if config.cl_scenario != 2:
        n_out = config.out_dim * config.num_tasks
    else:
        n_out = config.out_dim

    if config.training_task_infer or config.class_incremental:
        n_out = config.num_tasks

    # Build the classifier.
    print('For the Classifier: ')
    class_arch = misc.str_to_ints(config.class_fc_arch)
    if config.training_with_hnet:
        no_weights = True
    else:
        no_weights = False

    net = MLP(n_in=n_in, n_out=n_out, hidden_layers=class_arch,
              activation_fn=misc.str_to_act(config.class_net_act),
              dropout_rate=config.class_dropout_rate,
              no_weights=no_weights).to(device)
    print('Constructed MLP with shapes: ', net.param_shapes)

    config.num_weights_class_net = \
        MainNetInterface.shapes_to_num_weights(net.param_shapes)

    # Build the classifier hypernetwork.
    # This is set in the run method in train.py.
    if config.training_with_hnet:
        class_hnet = sim_utils.get_hnet_model(config, config.num_tasks,
            device, net.param_shapes, cprefix='class_')
        init_params = list(class_hnet.parameters())

        config.num_weights_class_hyper_net = sum(p.numel() for p in
            class_hnet.parameters() if p.requires_grad)
        config.compression_ratio_class = config.num_weights_class_hyper_net / \
            config.num_weights_class_net
        print('Created classifier Hypernetwork with ratio: ',
              config.compression_ratio_class)
        if config.compression_ratio_class > 1:
            print('Note that the compression ratio is computed compared to ' +
                  'the current target network, so it might not be directly ' +
                  'comparable with the number of parameters of methods we ' +
                  'compare against.')
    else:
        class_hnet = None
        init_params = list(net.parameters())
        config.num_weights_class_hyper_net = None
        config.compression_ratio_class = None

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if config.training_with_hnet:
        for temb in class_hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

        if hasattr(class_hnet, 'chunk_embeddings'):
            for emb in class_hnet.chunk_embeddings:
                torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    if not config.training_with_hnet:
        return net
    else:
        return net, class_hnet
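
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the weight
# initialization convention used above, shown on a plain ``torch.nn`` module.
# Weight matrices receive Xavier/Glorot-uniform init, bias vectors are zeroed,
# and task embeddings are drawn from a zero-mean normal. The layer sizes, the
# number of tasks and the default ``std_normal_temb=1.0`` are assumed values
# for illustration only; the sketch relies on the module-level ``torch``
# import already used above.
# ---------------------------------------------------------------------------
def _toy_init_demo(num_tasks=5, temb_dim=8, std_normal_temb=1.0):
    net = torch.nn.Sequential(torch.nn.Linear(784, 100), torch.nn.ReLU(),
                              torch.nn.Linear(100, 10))
    for W in net.parameters():
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:  # Weight matrix.
            torch.nn.init.xavier_uniform_(W)

    # Task embeddings (one per task) are initialized differently.
    task_embs = torch.nn.ParameterList(
        torch.nn.Parameter(torch.empty(temb_dim)) for _ in range(num_tasks))
    for temb in task_embs:
        torch.nn.init.normal_(temb, mean=0., std=std_normal_temb)
    return net, task_embs
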
def generate_replay_networks(config, data_handlers, device,
                             create_rp_hnet=True, only_train_replay=False):
    """Create a replay model that consists of either an encoder/decoder or a
    discriminator/generator pair. Additionally, this method manages the
    creation of a hypernetwork for the generator/decoder.

    The following important configurations will be determined in order to
    create the replay model:
    * in- and output and hidden layer dimensions of the encoder/decoder.
    * architecture, chunk- and task-embedding details of the decoder's
      hypernetwork.

    .. note::
        This module also handles the initialisation of the weights of either
        the replay networks or the decoder's hypernetwork. This will change in
        the near future.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network and to
            infer the number of tasks.
        device: Torch device.
        create_rp_hnet: Whether a hypernetwork for the replay model should be
            constructed. If not, the decoder/generator will have trainable
            weights of its own.
        only_train_replay: We normally do not train on the last task, since we
            do not need to replay this last task's data. But if we want a
            replay method to be able to generate data from all tasks, then we
            set this option to ``True``.

    Returns:
        (tuple): Tuple containing:

        - **enc**: Encoder/discriminator network instance.
        - **dec**: Decoder/generator network instance.
        - **dec_hnet**: Hypernetwork instance for the decoder/generator. This
          return value is ``None`` if no hypernetwork should be constructed.
    """
    if config.replay_method == 'gan':
        n_out = 1
    else:
        n_out = config.latent_dim * 2

    n_in = data_handlers[0].in_shape[0]
    pd = config.padding * 2
    if config.experiment == "splitMNIST":
        n_in = n_in * n_in
    else:  # permutedMNIST
        n_in = (n_in + pd) * (n_in + pd)

    config.input_dim = n_in
    if config.experiment == "splitMNIST":
        if config.single_class_replay:
            config.out_dim = 1
        else:
            config.out_dim = 2
    else:  # permutedMNIST
        config.out_dim = 10

    if config.infer_task_id:
        # Task inference network.
        config.out_dim = 1

    # Build the encoder.
    print('For the replay encoder/discriminator: ')
    enc_arch = misc.str_to_ints(config.enc_fc_arch)
    enc = MLP(n_in=n_in, n_out=n_out, hidden_layers=enc_arch,
              activation_fn=misc.str_to_act(config.enc_net_act),
              dropout_rate=config.enc_dropout_rate,
              no_weights=False).to(device)
    print('Constructed MLP with shapes: ', enc.param_shapes)
    init_params = list(enc.parameters())

    # Build the decoder.
    print('For the replay decoder/generator: ')
    dec_arch = misc.str_to_ints(config.dec_fc_arch)
    # Add dimensions for the conditional input.
    n_out = config.latent_dim
    if config.conditional_replay:
        n_out += config.conditional_dim
    dec = MLP(n_in=n_out, n_out=n_in, hidden_layers=dec_arch,
              activation_fn=misc.str_to_act(config.dec_net_act),
              use_bias=True, no_weights=config.rp_beta > 0,
              dropout_rate=config.dec_dropout_rate).to(device)
    print('Constructed MLP with shapes: ', dec.param_shapes)

    config.num_weights_enc = \
        MainNetInterface.shapes_to_num_weights(enc.param_shapes)
    config.num_weights_dec = \
        MainNetInterface.shapes_to_num_weights(dec.param_shapes)
    config.num_weights_rp_net = config.num_weights_enc + \
        config.num_weights_dec

    # We do not need a replay model for the last task, unless we explicitly
    # want to train on the last task as well.
    if only_train_replay:
        subtr = 0
    else:
        subtr = 1

    num_embeddings = config.num_tasks - subtr if config.num_tasks > 1 else 1

    if config.single_class_replay:
        # We do not need a replay model for the last task.
        if config.num_tasks > 1:
            num_embeddings = config.out_dim * (config.num_tasks - subtr)
        else:
            num_embeddings = config.out_dim * config.num_tasks

    config.num_embeddings = num_embeddings

    # Build the decoder hypernetwork.
    if create_rp_hnet:
        print('For the decoder/generator hypernetwork: ')
        d_hnet = sim_utils.get_hnet_model(config, num_embeddings, device,
            dec.hyper_shapes_learned, cprefix='rp_')
        init_params += list(d_hnet.parameters())

        config.num_weights_rp_hyper_net = sum(p.numel() for p in
            d_hnet.parameters() if p.requires_grad)
        config.compression_ratio_rp = config.num_weights_rp_hyper_net / \
            config.num_weights_dec
        print('Created replay hypernetwork with ratio: ',
              config.compression_ratio_rp)
        if config.compression_ratio_rp > 1:
            print('Note that the compression ratio is computed compared to ' +
                  'the current target network,\nthis might not be directly ' +
                  'comparable with the number of parameters of methods we ' +
                  'compare against.')
    else:
        num_embeddings = config.num_tasks - subtr
        d_hnet = None
        init_params += list(dec.parameters())
        config.num_weights_rp_hyper_net = 0
        config.compression_ratio_rp = 0

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if create_rp_hnet:
        for temb in d_hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

        if hasattr(d_hnet, 'chunk_embeddings'):
            for emb in d_hnet.chunk_embeddings:
                torch.nn.init.normal_(emb, mean=0, std=config.std_normal_emb)

    return enc, dec, d_hnet
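
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how the compression
# ratios reported above are computed, i.e. the number of trainable
# hypernetwork parameters divided by the number of parameters of the target
# network it generates. The two toy modules below are assumptions chosen only
# to make the sketch self-contained (reusing the module-level ``torch``
# import). A ratio > 1 means the hypernetwork is larger than the network it
# produces, as warned about in the print statements above.
# ---------------------------------------------------------------------------
def _toy_compression_ratio():
    # Stands in for ``dec``, the target network.
    target = torch.nn.Linear(784, 10)
    # Stands in for ``d_hnet``, the hypernetwork generating its weights.
    hyper = torch.nn.Sequential(torch.nn.Linear(8, 50), torch.nn.ReLU(),
                                torch.nn.Linear(50, 7850))
    n_target = sum(p.numel() for p in target.parameters())
    n_hyper = sum(p.numel() for p in hyper.parameters() if p.requires_grad)
    return n_hyper / n_target
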