Example #1
from warnings import warn  # needed for the deprecation warning below

# "MainNetwork" is assumed to be imported from the surrounding project.
def calc_reg_masks(data_handlers, mnet, device, config):
    """Compute the regularizer mask for each task when using a multi-head setup.
    See method "get_reg_masks" of class "MainNetwork" for more details.

    Deprecated: Method "calc_fix_target_reg" has its own way of handling
    multi-head setups that is more memory-efficient than keeping a mask for
    each task.

    Args:
        (....): See docstring of method :func:`train_reg`.
        data_handlers: A list of all data_handlers, one for each task.

    Returns:
        A list of regularizer masks.
    """
    assert (config.multi_head and config.masked_reg)

    warn('"calc_reg_masks" is deprecated and not maintained as it is unused ' +
         'currently.', DeprecationWarning)

    assert (mnet.has_fc_out)
    main_shapes = mnet.hyper_shapes

    masks = []
    for i, data in enumerate(data_handlers):
        n_y = data.out_shape[0]
        allowed_outputs = list(range(i * n_y, (i + 1) * n_y))

        masks.append(MainNetwork.get_reg_masks(main_shapes, allowed_outputs,
                                               device, use_bias=mnet.has_bias))

    return masks
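
The helper below is a minimal, hypothetical sketch of the masking logic this
function relies on (the real "MainNetwork.get_reg_masks" is not shown here).
It assumes the usual PyTorch layout, where the last two entries of the shape
list are the output weight matrix [n_out, n_in] and the output bias [n_out]:

import torch

def build_output_masks(weight_shapes, allowed_outputs, device):
    """Hypothetical stand-in for "MainNetwork.get_reg_masks"."""
    # Regularize all hidden-layer weights; in the final fully-connected
    # layer, regularize only the rows (output units) in "allowed_outputs".
    masks = [torch.ones(*s, device=device) for s in weight_shapes]
    for m in masks[-2:]:  # output weight matrix and output bias
        m.zero_()
        m[allowed_outputs] = 1.
    return masks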
Example #2
import torch
import torch.nn.functional as F

# "MainNetwork", "HyperNetwork", "RecognitionNet", and the "misc" utilities
# are assumed to be imported from the surrounding project.
def _generate_networks(config,
                       data_handlers,
                       device,
                       create_hnet=True,
                       create_rnet=False,
                       no_replay=False):
    """Create the main-net, hypernetwork and recognition network.

    Args:
        config: Command-line arguments.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network and to
            infer the number of tasks.
        device: Torch device.
        create_hnet: Whether a hypernetwork should be constructed. If not, the
            main network will have trainable weights.
        create_rnet: Whether a task-recognition autoencoder should be created.
        no_replay: Whether the recognition network should be an instance of
            class MainNetwork rather than of class RecognitionNet (note that
            no replay network is required for multitask learning).

    Returns:
        mnet: Main network instance.
        hnet: Hypernetwork instance. This return value is None if no
            hypernetwork should be constructed.
        rnet: RecognitionNet instance. This return value is None if no
            recognition network should be constructed.
    """
    num_tasks = len(data_handlers)

    n_x = data_handlers[0].in_shape[0]
    n_y = data_handlers[0].out_shape[0]
    if config.multi_head:
        n_y = n_y * num_tasks

    main_arch = misc.str_to_ints(config.main_arch)
    main_shapes = MainNetwork.weight_shapes(n_in=n_x,
                                            n_out=n_y,
                                            hidden_layers=main_arch)
    mnet = MainNetwork(main_shapes,
                       activation_fn=misc.str_to_act(config.main_act),
                       use_bias=True,
                       no_weights=create_hnet).to(device)
    if create_hnet:
        hnet_arch = misc.str_to_ints(config.hnet_arch)
        hnet = HyperNetwork(main_shapes,
                            num_tasks,
                            layers=hnet_arch,
                            te_dim=config.emb_size,
                            activation_fn=misc.str_to_act(
                                config.hnet_act)).to(device)
        init_params = list(hnet.parameters())
    else:
        hnet = None
        init_params = list(mnet.parameters())

    if create_rnet:
        ae_arch = misc.str_to_ints(config.ae_arch)
        if no_replay:
            rnet_shapes = MainNetwork.weight_shapes(n_in=n_x,
                                                    n_out=num_tasks,
                                                    hidden_layers=ae_arch,
                                                    use_bias=True)
            rnet = MainNetwork(rnet_shapes,
                               activation_fn=misc.str_to_act(config.ae_act),
                               use_bias=True,
                               no_weights=False,
                               dropout_rate=-1,
                               out_fn=lambda x: F.softmax(x, dim=1)).to(device)
        else:
            rnet = RecognitionNet(n_x,
                                  num_tasks,
                                  dim_z=config.ae_dim_z,
                                  enc_layers=ae_arch,
                                  activation_fn=misc.str_to_act(config.ae_act),
                                  use_bias=True).to(device)
        init_params += list(rnet.parameters())
    else:
        rnet = None

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        elif config.normal_init:
            torch.nn.init.normal_(W, mean=0, std=config.std_normal_init)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently.
    if create_hnet:
        for temb in hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

    # The hyperfan init (if requested) overrides the initialization applied
    # above; guard against "hnet" being None when no hypernetwork was built.
    if create_hnet and config.use_hyperfan_init:
        hnet.apply_hyperfan_init(temb_var=config.std_normal_temb**2)

    return mnet, hnet, rnet
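
Below is a minimal usage sketch for the factory above; "SimpleNamespace"
stands in for the parsed command-line arguments, and all field values are
hypothetical placeholders (only fields read by the function are set):

from types import SimpleNamespace

config = SimpleNamespace(multi_head=True, main_arch='100,100',
                         main_act='relu', hnet_arch='50,50', hnet_act='relu',
                         emb_size=8, ae_arch='20', ae_act='relu', ae_dim_z=8,
                         normal_init=False, std_normal_init=0.02,
                         std_normal_temb=1., use_hyperfan_init=False)

# "data_handlers" and "device" are assumed to exist (see the docstring).
mnet, hnet, rnet = _generate_networks(config, data_handlers, device,
                                      create_hnet=True, create_rnet=False)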
Example #3
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

# "MainNetwork" is assumed to be imported from the surrounding project.
class RecognitionNet(nn.Module):
    """The recognition network consists of an encoder and decoder. The encoder
    gets as input X (the same input as the main network) and has two outputs:
    a softmax layer named alpha, that can be used to determine the task inferred
    from the input; and a latent embedding nu_z, that we interpret as parameters
    of a normal distribution (mean and log-variances). Using the
    reparametrization trick, we can sample z.
    Both, alpha and z are going to be the input to the decoder. The decoder aims
    to reconstruct the input of the encoder.

    The network consists only of fully-connected layers. The decoder
    architecture is a mirrored version of the encoder architecture, except for
    the fact that the input to the decoder is z and not nu_z.

    Attributes (additional to base class):
        encoder_weights: All parameters of the encoder network.
        decoder_weights: All parameters of the decoder network.
        dim_alpha: Dimensionality of the softmax layer alpha.
        dim_z: Dimensionality of the latent embeddings z.
    """
    def __init__(self,
                 n_in,
                 n_tasks,
                 dim_z=8,
                 enc_layers=[10, 10],
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True):
        """Initialize the network.

        Args:
            n_in: Input size (input dim of encoder and output dim of decoder).
            n_tasks: The maximum number of tasks to be detected (size of
                softmax layer).
            dim_z: Dimensionality of latent space z.
            enc_layers: A list of integers, each denoting the size of a hidden
                layer in the encoder. The decoder will have layer sizes in
                reverse order.
            activation_fn: The nonlinearity used in hidden layers. If None, no
                nonlinearity will be applied.
            use_bias: Whether layers may have bias terms.
        """
        super(RecognitionNet, self).__init__()

        self._n_alpha = n_tasks
        self._n_nu_z = 2 * dim_z
        self._n_z = dim_z

        ## Encoder
        encoder_shapes = MainNetwork.weight_shapes(n_in=n_in,
                                                   n_out=self._n_alpha +
                                                   self._n_nu_z,
                                                   hidden_layers=enc_layers,
                                                   use_bias=use_bias)
        self._encoder = MainNetwork(encoder_shapes,
                                    activation_fn=activation_fn,
                                    use_bias=use_bias,
                                    no_weights=False,
                                    dropout_rate=-1,
                                    verbose=False)
        self._weights_enc = self._encoder.weights

        ## Decoder
        decoder_shapes = MainNetwork.weight_shapes(
            n_in=self._n_alpha + self._n_z,
            n_out=n_in,
            hidden_layers=list(reversed(enc_layers)),
            use_bias=use_bias)
        self._decoder = MainNetwork(decoder_shapes,
                                    activation_fn=activation_fn,
                                    use_bias=use_bias,
                                    no_weights=False,
                                    dropout_rate=-1,
                                    verbose=False)
        self._weights_dec = self._decoder.weights

        ## Prior
        # Note: when changing the prior, the method "prior_matching" has to
        # be adapted accordingly.
        # The prior parameters are registered as buffers so that they are
        # moved along with the module by ".to(device)".
        self.register_buffer('_mu_z', torch.zeros(dim_z))
        self.register_buffer('_sigma_z', torch.ones(dim_z))

        n_params = np.sum([np.prod(p.shape) for p in self.parameters()])
        print('Constructed recognition model with %d parameters.' % n_params)

    @property
    def dim_alpha(self):
        """Getter for read-only attribute dim_alpha.

        Returns:
            Size of alpha layer.
        """
        return self._n_alpha

    @property
    def dim_z(self):
        """Getter for read-only attribute dim_z.

        Returns:
            Size of z layer.
        """
        return self._n_z

    @property
    def encoder_weights(self):
        """Getter for read-only attribute encoder_weights.

        Returns:
            A torch.nn.ParameterList.
        """
        return self._weights_enc

    @property
    def decoder_weights(self):
        """Getter for read-only attribute decoder_weights.

        Returns:
            A torch.nn.ParameterList.
        """
        return self._weights_dec

    def forward(self, x):
        """This function computes
            x_rec = decode(encode(x))

        Note, the function utilizes the class members "encode" and "decode".

        Args:
            x: The input to the "autoencoder".

        Returns:
            x_rec: The reconstruction of the input.
        """
        alpha, _, z = self.encode(x)
        x_rec = self.decode(alpha, z)

        return x_rec

    def encode(self, x, ret_log_alpha=False, encoder_weights=None):
        """Encode a sample x -> "recognize the task of x".
        
        Args:
            x: An input sample (from which a task should be inferred).
            ret_log_alpha (optional): Whether the log-softmax distribution of
                the output layer alpha should be returned as well.
            encoder_weights (optional): If given, these parameters will be
                used in the encoder rather than the ones maintained internally
                by the object.

        Returns:
            (tuple): Tuple containing:

            - **alpha**: The softmax output (task classification output).
            - **nu_z**: The parameters of the latent distribution from which "z" is
              sampled (i.e., the actual output of the encoder besides alpha).
              Note that these parameters are the concatenated means and
              log-variances of the latent distribution.
            - **z**: A latent space embedding retrieved via the
              reparametrization trick.
            - **log_alpha** (optional): The log softmax activity of alpha.
        """
        # If no weights are passed explicitly, the encoder falls back to its
        # internally maintained parameters.
        phi_e = encoder_weights

        h = self._encoder.forward(x, weights=phi_e)

        h_alpha = h[:, :self._n_alpha]
        alpha = F.softmax(h_alpha, dim=1)

        params_z = h[:, self._n_alpha:]
        mu_z = params_z[:, :self._n_z]
        logvar_z = params_z[:, self._n_z:]

        std_z = torch.exp(0.5 * logvar_z)
        z = Normal(mu_z, std_z).rsample()

        if ret_log_alpha:
            return alpha, params_z, z, F.log_softmax(h_alpha, dim=1)

        return alpha, params_z, z
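
    # Note: "Normal(mu_z, std_z).rsample()" above implements the
    # reparametrization trick, i.e., z = mu_z + std_z * eps with
    # eps ~ N(0, I), so z remains differentiable w.r.t. the encoder output.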

    def decode(self, alpha, z, decoder_weights=None):
        """Decode a latent representation back to a sample.
        If alpha is a 1-hot encoding denoting a specific task and z are latent
        space samples, the decoding can be seen as "replay" of task samples.

        Args:
            alpha: See return value of method "encode".
            z: See return value of method "encode".
            decoder_weights (optional): If given, these parameters will be
                used in the decoder rather than the ones maintained internally
                by the object.

        Returns:
            x_dec: The decoded sample.
        """
        # If no weights are passed explicitly, the decoder falls back to its
        # internally maintained parameters.
        phi_d = decoder_weights

        x_dec = self._decoder.forward(torch.cat([alpha, z], dim=1),
                                      weights=phi_d)

        return x_dec
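
    # Note: feeding a one-hot "alpha" (fixing the task) together with samples
    # from "prior_samples" turns the decoder into a replay model that
    # generates pseudo-inputs for that task.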

    def prior_samples(self, batch_size):
        """Obtain a batch of samples from the prior for the latent space z.

        Args:
            batch_size: Number of samples to acquire.

        Returns:
            A torch tensor of samples.
        """
        return Normal(self._mu_z, self._sigma_z).rsample([batch_size])
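
    # Note: since "self._mu_z" has shape [dim_z], "rsample([batch_size])"
    # returns a tensor of shape [batch_size, dim_z].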

    def prior_matching(self, nu_z):
        """Compute the prior matching term between the Gaussian described by
        the parameters "nu_z" and a standard normal distribution N(0, I).

        Args:
            nu_z: Part of the encoder output.

        Returns:
            The value of the prior matching loss.
        """
        mu_z = nu_z[:, :self._n_z]
        logvar_z = nu_z[:, self._n_z:]

        var_z = logvar_z.exp()

        return -0.5 * torch.sum(1 + logvar_z - mu_z.pow(2) - var_z)
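
    # The expression above is the closed-form KL divergence
    # KL(N(mu, diag(sigma^2)) || N(0, I))
    #     = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
    # summed over the batch and latent dimensions.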

    @staticmethod
    def task_cross_entropy(log_alpha, target):
        """A call to pytorch "nll_loss".

        Args:
            log_alpha: The log softmax activity of the alpha layer.
            target: A vector of task ids.

        Returns:
            The cross-entropy loss.
        """
        return F.nll_loss(log_alpha, target)
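
    # Note: "nll_loss" expects log-probabilities, which is why the
    # log-softmax output "log_alpha" (and not "alpha") must be passed.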

    @staticmethod
    def reconstruction_loss(x, x_rec):
        """A call to pytorch "mse_loss"

        Args:
            x: An input sample.
            x_rec: The reconstruction provided by the recognition AE when seeing
                input "x".

        Returns:
            The MSE loss between x and x_rec.
        """
        return F.mse_loss(x, x_rec)
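
A minimal, hypothetical end-to-end sketch of how the pieces of
"RecognitionNet" compose into a training loss ("x" is a batch of flattened
inputs, "task_ids" a vector of task labels; both are assumed to exist, and
the loss weighting is a placeholder):

rnet = RecognitionNet(n_in=784, n_tasks=5, dim_z=8, enc_layers=[100, 100])

alpha, nu_z, z, log_alpha = rnet.encode(x, ret_log_alpha=True)
x_rec = rnet.decode(alpha, z)

loss = RecognitionNet.reconstruction_loss(x, x_rec) \
    + rnet.prior_matching(nu_z) \
    + RecognitionNet.task_cross_entropy(log_alpha, task_ids)

# "Replay" pseudo-samples for, e.g., task 3: fix alpha to a one-hot vector
# and decode samples drawn from the prior.
z_prior = rnet.prior_samples(batch_size=32)
alpha_fix = torch.zeros(32, rnet.dim_alpha)
alpha_fix[:, 3] = 1.
x_replay = rnet.decode(alpha_fix, z_prior)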