Exemplo n.º 1
0
def _generate_networks(config,
                       data_handlers,
                       device,
                       create_hnet=True,
                       create_rnet=False,
                       no_replay=False):
    """Create the main-net, hypernetwork and recognition network.

    Every network that is created is moved to ``device``. Afterwards the
    weights of the hypernetwork (or of the main network, if no hypernetwork
    is created) and of the recognition network are initialized: 1D parameters
    are treated as bias vectors and set to zero, all other parameters are
    drawn from a normal distribution (if ``config.normal_init``) or
    initialized via Xavier-uniform. Task embeddings of the hypernetwork are
    then re-initialized from a normal distribution with standard deviation
    ``config.std_normal_temb``.

    Args:
        config: Command-line arguments. The attributes read by this function
            are ``multi_head``, ``main_arch``, ``main_act``, ``normal_init``,
            ``std_normal_init`` (only if ``normal_init`` is set) and
            ``use_hyperfan_init``; additionally ``hnet_arch``, ``hnet_act``,
            ``emb_size`` and ``std_normal_temb`` if a hypernetwork is
            created, and ``ae_arch``, ``ae_act`` and ``ae_dim_z`` if a
            recognition network is created.
        data_handlers: List of data handlers, one for each task. Needed to
            extract the number of inputs/outputs of the main network (taken
            from the first handler) and to infer the number of tasks. Must
            contain at least one element.
        device: Torch device.
        create_hnet: Whether a hypernetwork should be constructed. If not, the
            main network will have trainable weights.
        create_rnet: Whether a task-recognition autoencoder should be created.
        no_replay: If the recognition network should be an instance of class
            MainNetwork (with a softmax output over tasks) rather than of
            class RecognitionNet (note, for multitask learning, no replay
            network is required). Only has an effect if ``create_rnet`` is
            set.

    Returns:
        mnet: Main network instance.
        hnet: Hypernetwork instance. This return value is None if no
            hypernetwork should be constructed.
        rnet: RecognitionNet instance (or MainNetwork instance if
            ``no_replay`` is set). This return value is None if no
            recognition network should be constructed.
    """
    num_tasks = len(data_handlers)

    n_x = data_handlers[0].in_shape[0]
    n_y = data_handlers[0].out_shape[0]
    if config.multi_head:
        # One output head per task.
        n_y = n_y * num_tasks

    ### Main network.
    main_arch = misc.str_to_ints(config.main_arch)
    main_shapes = MainNetwork.weight_shapes(n_in=n_x,
                                            n_out=n_y,
                                            hidden_layers=main_arch)
    # If a hypernetwork is used, the main network gets no weights of its own.
    mnet = MainNetwork(main_shapes,
                       activation_fn=misc.str_to_act(config.main_act),
                       use_bias=True,
                       no_weights=create_hnet).to(device)

    ### Hypernetwork.
    if create_hnet:
        hnet_arch = misc.str_to_ints(config.hnet_arch)
        hnet = HyperNetwork(main_shapes,
                            num_tasks,
                            layers=hnet_arch,
                            te_dim=config.emb_size,
                            activation_fn=misc.str_to_act(
                                config.hnet_act)).to(device)
        init_params = list(hnet.parameters())
    else:
        hnet = None
        init_params = list(mnet.parameters())

    ### Recognition network.
    if create_rnet:
        ae_arch = misc.str_to_ints(config.ae_arch)
        if no_replay:
            # Plain classifier over tasks; no decoder/replay is needed.
            rnet_shapes = MainNetwork.weight_shapes(n_in=n_x,
                                                    n_out=num_tasks,
                                                    hidden_layers=ae_arch,
                                                    use_bias=True)
            # Moved to `device` like all other networks created here.
            rnet = MainNetwork(rnet_shapes,
                               activation_fn=misc.str_to_act(config.ae_act),
                               use_bias=True,
                               no_weights=False,
                               dropout_rate=-1,
                               out_fn=lambda x: F.softmax(x, dim=1)
                               ).to(device)
        else:
            rnet = RecognitionNet(n_x,
                                  num_tasks,
                                  dim_z=config.ae_dim_z,
                                  enc_layers=ae_arch,
                                  activation_fn=misc.str_to_act(config.ae_act),
                                  use_bias=True).to(device)
        init_params += list(rnet.parameters())
    else:
        rnet = None

    ### Initialize network weights.
    for W in init_params:
        if W.ndimension() == 1:  # Bias vector.
            torch.nn.init.constant_(W, 0)
        elif config.normal_init:
            torch.nn.init.normal_(W, mean=0, std=config.std_normal_init)
        else:
            torch.nn.init.xavier_uniform_(W)

    # The task embeddings are initialized differently. Note, this has to
    # happen after the loop above, which also touches hypernet parameters.
    if hnet is not None:
        for temb in hnet.get_task_embs():
            torch.nn.init.normal_(temb, mean=0., std=config.std_normal_temb)

    if config.use_hyperfan_init:
        if hnet is not None:
            hnet.apply_hyperfan_init(temb_var=config.std_normal_temb**2)
        else:
            # Hyperfan init is a method of the hypernetwork; without one
            # there is nothing to apply it to (previously this crashed with
            # an AttributeError on None).
            import warnings
            warnings.warn('Option "use_hyperfan_init" is ignored, since no '
                          'hypernetwork is created.')

    return mnet, hnet, rnet
Exemplo n.º 2
0
    def __init__(self,
                 n_in,
                 n_tasks,
                 dim_z=8,
                 enc_layers=None,
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True):
        """Initialize the network.

        An encoder and a decoder are constructed, both instances of class
        MainNetwork that maintain their own weights. The encoder maps inputs
        of size ``n_in`` to ``n_tasks + 2 * dim_z`` outputs, the decoder maps
        ``n_tasks + dim_z`` inputs back to ``n_in`` outputs.

        Args:
            n_in: Input size (input dim of encoder and output dim of decoder).
            n_tasks: The maximum number of tasks to be detected (size of
                softmax layer).
            dim_z: Dimensionality of latent space z.
            enc_layers: A list of integers, each denoting the size of a hidden
                layer in the encoder. The decoder will have layer sizes in
                reverse order. If None, ``[10, 10]`` is used.
            activation_fn: The nonlinearity used in hidden layers. If None, no
                nonlinearity will be applied.
            use_bias: Whether layers may have bias terms.
        """
        super(RecognitionNet, self).__init__()

        # Work on a private copy, such that neither a default nor a list
        # owned by the caller is shared with the networks built below.
        enc_layers = [10, 10] if enc_layers is None else list(enc_layers)

        self._n_alpha = n_tasks
        # Twice the latent dimension (presumably mean and variance parameters
        # of the latent distribution - TODO confirm in the forward pass).
        self._n_nu_z = 2 * dim_z
        self._n_z = dim_z

        ## Encoder
        encoder_shapes = MainNetwork.weight_shapes(n_in=n_in,
                                                   n_out=self._n_alpha +
                                                   self._n_nu_z,
                                                   hidden_layers=enc_layers,
                                                   use_bias=use_bias)
        self._encoder = MainNetwork(encoder_shapes,
                                    activation_fn=activation_fn,
                                    use_bias=use_bias,
                                    no_weights=False,
                                    dropout_rate=-1,
                                    verbose=False)
        self._weights_enc = self._encoder.weights

        ## Decoder
        decoder_shapes = MainNetwork.weight_shapes(
            n_in=self._n_alpha + self._n_z,
            n_out=n_in,
            hidden_layers=list(reversed(enc_layers)),
            use_bias=use_bias)
        self._decoder = MainNetwork(decoder_shapes,
                                    activation_fn=activation_fn,
                                    use_bias=use_bias,
                                    no_weights=False,
                                    dropout_rate=-1,
                                    verbose=False)
        self._weights_dec = self._decoder.weights

        ## Prior
        # Note, when changing the prior, one has to change the method
        # "prior_matching".
        # NOTE(review): these are plain tensor attributes, not registered
        # buffers, so moving the module via ``.to(device)`` presumably leaves
        # them on the CPU - confirm that their users account for that.
        self._mu_z = torch.zeros(dim_z)
        self._sigma_z = torch.ones(dim_z)

        n_params = sum(p.numel() for p in self.parameters())
        print('Constructed recognition model with %d parameters.' % n_params)