import numpy as np
import torch
import torch.nn.functional as F

# Note: `misc`, `MainNetwork`, `HyperNetwork` and `RecognitionNet` are
# project-local helpers/classes whose imports are omitted in this excerpt.


def _generate_networks(config, data_handlers, device, create_hnet=True,
                       create_rnet=False, no_replay=False):
    """Create the main-net, hypernetwork and recognition network.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network and to
            infer the number of tasks.
        device: Torch device.
        create_hnet: Whether a hypernetwork should be constructed. If not,
            the main network will have its own trainable weights.
        create_rnet: Whether a task-recognition autoencoder should be
            created.
        no_replay: Whether the recognition network should be an instance of
            class MainNetwork rather than of class RecognitionNet (note,
            for multitask learning, no replay network is required).

    Returns:
        mnet: Main network instance.
        hnet: Hypernetwork instance. This return value is None if no
            hypernetwork should be constructed.
        rnet: RecognitionNet instance. This return value is None if no
            recognition network should be constructed.
    """
    num_tasks = len(data_handlers)

    n_x = data_handlers[0].in_shape[0]
    n_y = data_handlers[0].out_shape[0]
    if config.multi_head:
        n_y = n_y * num_tasks

    main_arch = misc.str_to_ints(config.main_arch)
    main_shapes = MainNetwork.weight_shapes(n_in=n_x, n_out=n_y,
                                            hidden_layers=main_arch)
    # If a hypernetwork is used, the main network holds no weights of its
    # own (`no_weights=create_hnet`); they are produced by the hypernet.
    mnet = MainNetwork(main_shapes,
                       activation_fn=misc.str_to_act(config.main_act),
                       use_bias=True, no_weights=create_hnet).to(device)

    if create_hnet:
        hnet_arch = misc.str_to_ints(config.hnet_arch)
        hnet = HyperNetwork(main_shapes, num_tasks, layers=hnet_arch,
                            te_dim=config.emb_size,
                            activation_fn=misc.str_to_act(
                                config.hnet_act)).to(device)
        init_params = list(hnet.parameters())
    else:
        hnet = None
        init_params = list(mnet.parameters())

    if create_rnet:
        ae_arch = misc.str_to_ints(config.ae_arch)
        if no_replay:
            # Without replay, the recognition net is a plain softmax
            # classifier over task identities.
            rnet_shapes = MainNetwork.weight_shapes(
                n_in=n_x, n_out=num_tasks, hidden_layers=ae_arch,
                use_bias=True)
            rnet = MainNetwork(rnet_shapes,
                               activation_fn=misc.str_to_act(config.ae_act),
                               use_bias=True, no_weights=False,
                               dropout_rate=-1,
                               out_fn=lambda x: F.softmax(x, dim=1)
                              ).to(device)
        else:
            rnet = RecognitionNet(n_x, num_tasks, dim_z=config.ae_dim_z,
                                  enc_layers=ae_arch,
                                  activation_fn=misc.str_to_act(
                                      config.ae_act),
                                  use_bias=True).to(device)
        init_params += list(rnet.parameters())
    else:
        rnet = None

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        elif config.normal_init:
            torch.nn.init.normal_(W, mean=0, std=config.std_normal_init)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if create_hnet:
        for temb in hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

        if config.use_hyperfan_init:
            hnet.apply_hyperfan_init(temb_var=config.std_normal_temb**2)

    return mnet, hnet, rnet
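# Usage sketch (illustrative assumption, not part of the original code):
# how `_generate_networks` might be invoked. The `config` fields mirror the
# command-line arguments read above; the concrete values, the
# `SimpleNamespace` stand-in for the argparse result and the dummy data
# handlers (which only need `in_shape`/`out_shape` attributes here) are
# hypothetical. String formats are assumed to match what `misc.str_to_ints`
# and `misc.str_to_act` expect.
from types import SimpleNamespace

_example_config = SimpleNamespace(
    multi_head=False, main_arch='100,100', main_act='relu',
    hnet_arch='50,50', hnet_act='relu', emb_size=8,
    ae_arch='25,25', ae_act='relu', ae_dim_z=8,
    normal_init=False, std_normal_init=0.02, std_normal_temb=1.,
    use_hyperfan_init=False)

_dummy_handlers = [SimpleNamespace(in_shape=[2], out_shape=[2])
                   for _ in range(3)]  # 3 tasks with 2D inputs/outputs.

mnet, hnet, rnet = _generate_networks(_example_config, _dummy_handlers,
                                      torch.device('cpu'),
                                      create_hnet=True, create_rnet=True)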
def __init__(self, n_in, n_tasks, dim_z=8, enc_layers=[10, 10],
             activation_fn=torch.nn.ReLU(), use_bias=True):
    """Initialize the network.

    Args:
        n_in: Input size (input dim of encoder and output dim of decoder).
        n_tasks: The maximum number of tasks to be detected (size of the
            softmax layer).
        dim_z: Dimensionality of the latent space z.
        enc_layers: A list of integers, each denoting the size of a hidden
            layer in the encoder. The decoder will have layer sizes in
            reverse order.
        activation_fn: The nonlinearity used in hidden layers. If None, no
            nonlinearity will be applied.
        use_bias: Whether layers may have bias terms.
    """
    super(RecognitionNet, self).__init__()

    self._n_alpha = n_tasks
    self._n_nu_z = 2 * dim_z  # Parameters of the latent distribution.
    self._n_z = dim_z

    ## Encoder
    encoder_shapes = MainNetwork.weight_shapes(
        n_in=n_in, n_out=self._n_alpha + self._n_nu_z,
        hidden_layers=enc_layers, use_bias=use_bias)
    self._encoder = MainNetwork(encoder_shapes,
                                activation_fn=activation_fn,
                                use_bias=use_bias, no_weights=False,
                                dropout_rate=-1, verbose=False)
    self._weights_enc = self._encoder.weights

    ## Decoder
    decoder_shapes = MainNetwork.weight_shapes(
        n_in=self._n_alpha + self._n_z, n_out=n_in,
        hidden_layers=list(reversed(enc_layers)), use_bias=use_bias)
    self._decoder = MainNetwork(decoder_shapes,
                                activation_fn=activation_fn,
                                use_bias=use_bias, no_weights=False,
                                dropout_rate=-1, verbose=False)
    self._weights_dec = self._decoder.weights

    ## Prior
    # Note, when changing the prior, one has to change the method
    # "prior_matching".
    self._mu_z = torch.zeros(dim_z)
    self._sigma_z = torch.ones(dim_z)

    n_params = np.sum([np.prod(p.shape) for p in self.parameters()])
    print('Constructed recognition model with %d parameters.' % n_params)
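# Sketch (an assumption, not the class's actual method): how the encoder
# output could be split into the task posterior `alpha` and a latent
# sample `z`. The encoder emits `_n_alpha + _n_nu_z` units; it is assumed
# here that `alpha` occupies the first `_n_alpha` units and that `nu_z`
# stores the mean and log-variance of q(z|x), common VAE conventions that
# this excerpt does not confirm.
def _encode_sketch(self, x):
    h = self._encoder.forward(x)
    alpha = F.softmax(h[:, :self._n_alpha], dim=1)  # Task identity.
    mu = h[:, self._n_alpha:self._n_alpha + self._n_z]
    log_var = h[:, self._n_alpha + self._n_z:]
    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
    # which keeps the sampling step differentiable w.r.t. mu and log_var.
    z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
    return alpha, z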