def calc_reg_masks(data_handlers, mnet, device, config):
    """Compute the regularizer mask for each task when using a multi-head
    setup.

    See method "get_reg_masks" of class "MainNetwork" for more details.

    Deprecated:
        Method "calc_fix_target_reg" has its own way of handling multi-head
        setups that is more memory-efficient than keeping a mask for each
        task.

    Args:
        (....): See docstring of method :func:`train_reg`.
        data_handlers: A list of all data handlers, one for each task.

    Returns:
        A list of regularizer masks.
    """
    assert config.multi_head and config.masked_reg

    warn('"calc_reg_masks" is deprecated and not maintained, as it is ' +
         'currently unused.', DeprecationWarning)

    assert mnet.has_fc_out
    main_shapes = mnet.hyper_shapes

    masks = []
    for i, data in enumerate(data_handlers):
        n_y = data.out_shape[0]
        allowed_outputs = list(range(i * n_y, (i + 1) * n_y))

        masks.append(MainNetwork.get_reg_masks(main_shapes, allowed_outputs,
                                               device,
                                               use_bias=mnet.has_bias))

    return masks
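# Illustrative note (not part of the original module): in a multi-head setup,
# task i owns the contiguous output slice [i*n_y, (i+1)*n_y) of the shared
# fully-connected output layer, which is exactly how "allowed_outputs" is
# built above. A minimal sketch, assuming 2 outputs per task:
#
#     n_y = 2
#     for i in range(3):  # three tasks
#         print(list(range(i * n_y, (i + 1) * n_y)))
#     # -> [0, 1], [2, 3], [4, 5]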
def _generate_networks(config, data_handlers, device, create_hnet=True,
                       create_rnet=False, no_replay=False):
    """Create the main network, the hypernetwork and the recognition network.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network and to
            infer the number of tasks.
        device: Torch device.
        create_hnet: Whether a hypernetwork should be constructed. If not, the
            main network will have trainable weights.
        create_rnet: Whether a task-recognition autoencoder should be created.
        no_replay: Whether the recognition network should be an instance of
            class MainNetwork rather than of class RecognitionNet (note, for
            multitask learning, no replay network is required).

    Returns:
        mnet: Main network instance.
        hnet: Hypernetwork instance. This return value is None if no
            hypernetwork should be constructed.
        rnet: RecognitionNet instance. This return value is None if no
            recognition network should be constructed.
    """
    num_tasks = len(data_handlers)

    n_x = data_handlers[0].in_shape[0]
    n_y = data_handlers[0].out_shape[0]
    if config.multi_head:
        n_y = n_y * num_tasks

    main_arch = misc.str_to_ints(config.main_arch)
    main_shapes = MainNetwork.weight_shapes(n_in=n_x, n_out=n_y,
                                            hidden_layers=main_arch)
    mnet = MainNetwork(main_shapes,
                       activation_fn=misc.str_to_act(config.main_act),
                       use_bias=True, no_weights=create_hnet).to(device)

    if create_hnet:
        hnet_arch = misc.str_to_ints(config.hnet_arch)
        hnet = HyperNetwork(main_shapes, num_tasks, layers=hnet_arch,
                            te_dim=config.emb_size,
                            activation_fn=misc.str_to_act(
                                config.hnet_act)).to(device)
        init_params = list(hnet.parameters())
    else:
        hnet = None
        init_params = list(mnet.parameters())

    if create_rnet:
        ae_arch = misc.str_to_ints(config.ae_arch)
        if no_replay:
            rnet_shapes = MainNetwork.weight_shapes(n_in=n_x, n_out=num_tasks,
                                                    hidden_layers=ae_arch,
                                                    use_bias=True)
            rnet = MainNetwork(rnet_shapes,
                               activation_fn=misc.str_to_act(config.ae_act),
                               use_bias=True, no_weights=False,
                               dropout_rate=-1,
                               out_fn=lambda x: F.softmax(x, dim=1)).to(device)
        else:
            rnet = RecognitionNet(n_x, num_tasks, dim_z=config.ae_dim_z,
                                  enc_layers=ae_arch,
                                  activation_fn=misc.str_to_act(config.ae_act),
                                  use_bias=True).to(device)
        init_params += list(rnet.parameters())
    else:
        rnet = None

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        elif config.normal_init:
            torch.nn.init.normal_(W, mean=0, std=config.std_normal_init)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if create_hnet:
        for temb in hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

        if config.use_hyperfan_init:
            hnet.apply_hyperfan_init(temb_var=config.std_normal_temb**2)

    return mnet, hnet, rnet
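# Hedged usage sketch (not part of the original module): how the pieces
# returned by "_generate_networks" fit together. It assumes "config" is the
# argparse.Namespace built by this repository's command-line parser (i.e.,
# fields such as "main_arch", "hnet_arch" or "emb_size" exist) and that
# "data_handlers" is the per-task list described above.
def _example_setup(config, data_handlers):
    """Illustrative only: build all three networks for replay-based CL."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    mnet, hnet, rnet = _generate_networks(config, data_handlers, device,
                                          create_hnet=True, create_rnet=True,
                                          no_replay=False)
    # With "create_hnet=True", "mnet" owns no weights of its own; the
    # hypernetwork generates one weight set per task from its task embeddings.
    return mnet, hnet, rnet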
class RecognitionNet(nn.Module):
    """The recognition network consists of an encoder and a decoder.

    The encoder receives X (the same input as the main network) and has two
    outputs: a softmax layer named alpha, which can be used to determine the
    task inferred from the input, and a latent embedding nu_z, which we
    interpret as the parameters of a normal distribution (means and
    log-variances). Using the reparametrization trick, we can sample z. Both
    alpha and z are then fed into the decoder, which aims to reconstruct the
    input of the encoder.

    The network consists only of fully-connected layers. The decoder
    architecture is a mirrored version of the encoder architecture, except
    that the input to the decoder is z and not nu_z.

    Attributes (additional to base class):
        encoder_weights: All parameters of the encoder network.
        decoder_weights: All parameters of the decoder network.
        dim_alpha: Dimensionality of the softmax layer alpha.
        dim_z: Dimensionality of the latent embeddings z.
    """
    def __init__(self, n_in, n_tasks, dim_z=8, enc_layers=[10, 10],
                 activation_fn=torch.nn.ReLU(), use_bias=True):
        """Initialize the network.

        Args:
            n_in: Input size (input dim of encoder and output dim of decoder).
            n_tasks: The maximum number of tasks to be detected (size of the
                softmax layer).
            dim_z: Dimensionality of latent space z.
            enc_layers: A list of integers, each denoting the size of a hidden
                layer in the encoder. The decoder will have layer sizes in
                reverse order.
            activation_fn: The nonlinearity used in hidden layers. If None, no
                nonlinearity will be applied.
            use_bias: Whether layers may have bias terms.
        """
        super(RecognitionNet, self).__init__()

        self._n_alpha = n_tasks
        self._n_nu_z = 2 * dim_z
        self._n_z = dim_z

        ## Encoder
        encoder_shapes = MainNetwork.weight_shapes(n_in=n_in,
            n_out=self._n_alpha + self._n_nu_z, hidden_layers=enc_layers,
            use_bias=use_bias)
        self._encoder = MainNetwork(encoder_shapes,
            activation_fn=activation_fn, use_bias=use_bias, no_weights=False,
            dropout_rate=-1, verbose=False)
        self._weights_enc = self._encoder.weights

        ## Decoder
        decoder_shapes = MainNetwork.weight_shapes(
            n_in=self._n_alpha + self._n_z, n_out=n_in,
            hidden_layers=list(reversed(enc_layers)), use_bias=use_bias)
        self._decoder = MainNetwork(decoder_shapes,
            activation_fn=activation_fn, use_bias=use_bias, no_weights=False,
            dropout_rate=-1, verbose=False)
        self._weights_dec = self._decoder.weights

        ## Prior
        # Note, when changing the prior, one also has to change the method
        # "prior_matching".
        self._mu_z = torch.zeros(dim_z)
        self._sigma_z = torch.ones(dim_z)

        n_params = np.sum([np.prod(p.shape) for p in self.parameters()])
        print('Constructed recognition model with %d parameters.' % n_params)

    @property
    def dim_alpha(self):
        """Getter for read-only attribute dim_alpha.

        Returns:
            Size of alpha layer.
        """
        return self._n_alpha

    @property
    def dim_z(self):
        """Getter for read-only attribute dim_z.

        Returns:
            Size of z layer.
        """
        return self._n_z

    @property
    def encoder_weights(self):
        """Getter for read-only attribute encoder_weights.

        Returns:
            A torch.nn.ParameterList.
        """
        return self._weights_enc

    @property
    def decoder_weights(self):
        """Getter for read-only attribute decoder_weights.

        Returns:
            A torch.nn.ParameterList.
        """
        return self._weights_dec

    def forward(self, x):
        """Compute x_rec = decode(encode(x)).

        Note, this function utilizes the class methods "encode" and "decode".

        Args:
            x: The input to the "autoencoder".

        Returns:
            x_rec: The reconstruction of the input.
        """
        alpha, _, z = self.encode(x)
        x_rec = self.decode(alpha, z)
        return x_rec

    def encode(self, x, ret_log_alpha=False, encoder_weights=None):
        """Encode a sample x -> "recognize the task of x".

        Args:
            x: An input sample (from which a task should be inferred).
            ret_log_alpha (optional): Whether the log-softmax distribution of
                the output layer alpha should be returned as well.
            encoder_weights (optional): If given, these parameters will be
                used by the encoder rather than the ones maintained
                internally.

        Returns:
            (tuple): Tuple containing:

            - **alpha**: The softmax output (task classification output).
            - **nu_z**: The parameters of the latent distribution from which
              "z" is sampled (i.e., the actual output of the encoder besides
              alpha). Note that these parameters are the concatenated means
              and log-variances of the latent distribution.
            - **z**: A latent space embedding retrieved via the
              reparametrization trick.
            - **log_alpha** (optional): The log-softmax activity of alpha.
        """
        phi_e = None
        if encoder_weights is not None:
            phi_e = encoder_weights

        h = self._encoder.forward(x, weights=phi_e)
        h_alpha = h[:, :self._n_alpha]
        alpha = F.softmax(h_alpha, dim=1)

        params_z = h[:, self._n_alpha:]
        mu_z = params_z[:, :self._n_z]
        logvar_z = params_z[:, self._n_z:]
        std_z = torch.exp(0.5 * logvar_z)
        # Reparametrization trick: sample z while keeping the computation
        # differentiable w.r.t. mu_z and logvar_z.
        z = Normal(mu_z, std_z).rsample()

        if ret_log_alpha:
            return alpha, params_z, z, F.log_softmax(h_alpha, dim=1)

        return alpha, params_z, z

    def decode(self, alpha, z, decoder_weights=None):
        """Decode a latent representation back to a sample.

        If alpha is a 1-hot encoding denoting a specific task and z are latent
        space samples, the decoding can be seen as "replay" of task samples.

        Args:
            alpha: See return value of method "encode".
            z: See return value of method "encode".
            decoder_weights (optional): If given, these parameters will be
                used by the decoder rather than the ones maintained
                internally.

        Returns:
            x_dec: The decoded sample.
        """
        phi_d = None
        if decoder_weights is not None:
            phi_d = decoder_weights

        x_dec = self._decoder.forward(torch.cat([alpha, z], dim=1),
                                      weights=phi_d)
        return x_dec

    def prior_samples(self, batch_size):
        """Obtain a batch of samples from the prior for the latent space z.

        Args:
            batch_size: Number of samples to acquire.

        Returns:
            A torch tensor of samples.
        """
        return Normal(self._mu_z, self._sigma_z).rsample([batch_size])

    def prior_matching(self, nu_z):
        """Compute the prior-matching term between the Gaussian described by
        the parameters "nu_z" and a standard normal distribution N(0, I).

        Args:
            nu_z: Part of the encoder output.

        Returns:
            The value of the prior-matching loss.
        """
        mu_z = nu_z[:, :self._n_z]
        logvar_z = nu_z[:, self._n_z:]
        var_z = logvar_z.exp()

        # Analytic KL divergence KL(N(mu_z, var_z) || N(0, I)), summed over
        # the batch and the latent dimensions.
        return -0.5 * torch.sum(1 + logvar_z - mu_z.pow(2) - var_z)

    @staticmethod
    def task_cross_entropy(log_alpha, target):
        """A call to PyTorch's "nll_loss".

        Args:
            log_alpha: The log-softmax activity of the alpha layer.
            target: A vector of task ids.

        Returns:
            The cross-entropy loss.
        """
        return F.nll_loss(log_alpha, target)

    @staticmethod
    def reconstruction_loss(x, x_rec):
        """A call to PyTorch's "mse_loss".

        Args:
            x: An input sample.
            x_rec: The reconstruction provided by the recognition AE when
                seeing input "x".

        Returns:
            The MSE loss between x and x_rec.
        """
        return F.mse_loss(x, x_rec)