def initialize_model(
    meta_lr: float,
    decay_lr: float = 1.
) -> _typing.Tuple[torch.nn.Module, torch.optim.Optimizer,
                   torch.optim.lr_scheduler._LRScheduler]:
    """Initialize the model, optimizer and lr_scheduler
    The example here is to load ResNet18. You can write
    your own class of model, and specify here.

    Args:
        meta_lr: learning rate for meta-parameters
        decay_lr: decay factor of learning rate

    Returns:
        net:
        meta-optimizer:
        schdlr:
    """
    # net = ResNet18(input_channel=image_size[0], dim_output=n_way, bn_affine=True)
    net = CNN(dim_output=n_way, image_size=image_size, bn_affine=True)

    # initialize
    net.apply(_weights_init)

    # move to gpu
    net.to(device)

    meta_optimizer = torch.optim.Adam(params=net.parameters(), lr=meta_lr)
    schdlr = torch.optim.lr_scheduler.ExponentialLR(optimizer=meta_optimizer, gamma=decay_lr)

    return net, meta_optimizer, schdlr
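
The following is a minimal usage sketch for the function above. It assumes the module-level names referenced inside the snippet (n_way, image_size, device, CNN, _weights_init) are already defined; the learning-rate values are illustrative only, not taken from the source.

# Hypothetical usage sketch (assumes the globals used by initialize_model exist).
net, meta_optimizer, schdlr = initialize_model(meta_lr=1e-3, decay_lr=0.95)

# One meta-update step, then decay the meta learning rate.
# loss = ...  (meta-loss computed over a batch of tasks)
# loss.backward()
meta_optimizer.step()
meta_optimizer.zero_grad()
schdlr.step()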
Example #2
def initialize_model(
    hyper_net_cls,
    meta_lr: float,
    decay_lr: float = 1.
) -> typing.Tuple[torch.nn.Module, torch.nn.Module, torch.optim.Optimizer,
                  torch.optim.lr_scheduler._LRScheduler]:
    """Initialize the model, optimizer and lr_scheduler
    The example here is to load ResNet18. You can write
    your own class of model, and specify here.

    Args:
        hyper_net_cls: a handler to refer to a hyper-net class
        meta_lr: learning rate for meta-parameters
        decay_lr: decay factor of learning rate

    Returns:
        net:
        meta-optimizer:
        schdlr:
    """
    if config['base_model'] in ['CNN']:
        base_net = CNN(dim_output=config['min_way'],
                       image_size=image_size,
                       bn_affine=config['batchnorm'])
    elif config['base_model'] in ['ResNet18']:
        base_net = ResNet18(input_channel=image_size[0],
                            dim_output=config['min_way'] if config['min_way']
                            == config['max_way'] else None,
                            bn_affine=config['batchnorm'])
    else:
        raise NotImplementedError(
            'Base model {0:s} is unknown.'.format(config['base_model']))

    hyper_net = hyper_net_cls(base_net=base_net)

    # move to gpu
    base_net.to(device)
    hyper_net.to(device)

    meta_opt = torch.optim.Adam(params=hyper_net.parameters(), lr=meta_lr)
    schdlr = torch.optim.lr_scheduler.ExponentialLR(optimizer=meta_opt,
                                                    gamma=decay_lr)

    return hyper_net, base_net, meta_opt, schdlr
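
A similarly hedged usage sketch for Example #2. IdentityNet is the MAML-style hyper-net class named in Example #3 and stands in here for any class accepting a base_net= argument; config, image_size and device are assumed to be defined as in the snippet above.

# Hypothetical usage sketch for the hyper-net variant.
hyper_net, base_net, meta_opt, schdlr = initialize_model(
    hyper_net_cls=IdentityNet,
    meta_lr=1e-3,
    decay_lr=0.95)

# The meta-parameters now live in hyper_net, not base_net.
num_meta_params = sum(p.numel() for p in hyper_net.parameters())
print('Number of meta-parameters = {0:d}'.format(num_meta_params))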
Example #3
    def load_model(
        self,
        resume_epoch: typing.Optional[int] = None,
        **kwargs
    ) -> typing.Tuple[torch.nn.Module,
                      typing.Optional[higher.patch._MonkeyPatchBase],
                      torch.optim.Optimizer]:
        """Initialize or load the protonet and its optimizer

        Args:
            resume_epoch: the index of the file containing the saved model
            eps_generator: (keyword argument) the episode generator used to
                run a dummy task that initializes the lazy modules

        Returns: a tuple consisting of
            protonet: the prototypical network
            base_net: dummy (always None) to match the interface of MAML and VAMPIRE
            opt: the optimizer for the prototypical network
        """
        if resume_epoch is None:
            resume_epoch = self.config['resume_epoch']

        if self.config['network_architecture'] == 'CNN':
            protonet = CNN(dim_output=None, bn_affine=self.config['batchnorm'])
        elif self.config['network_architecture'] == 'ResNet18':
            protonet = ResNet18(dim_output=None,
                                bn_affine=self.config['batchnorm'])
        else:
            raise NotImplementedError(
                'Network architecture is unknown. Please implement it in the CommonModels.py.'
            )

        # ---------------------------------------------------------------
        # run a dummy task to initialize lazy modules defined in base_net
        # ---------------------------------------------------------------
        eps_data = kwargs['eps_generator'].generate_episode(episode_name=None)
        # split data into train and validation
        xt, _, _, _ = train_val_split(X=eps_data,
                                      k_shot=self.config['k_shot'],
                                      shuffle=True)
        # convert numpy data into torch tensor
        x_t = torch.from_numpy(xt).float()
        # run to initialize lazy modules
        protonet(x_t)

        # move to device
        protonet.to(self.config['device'])

        # optimizer
        opt = torch.optim.Adam(params=protonet.parameters(),
                               lr=self.config['meta_lr'])

        # load model if there is saved file
        if resume_epoch > 0:
            # path to the saved file
            checkpoint_path = os.path.join(
                self.config['logdir'], 'Epoch_{0:d}.pt'.format(resume_epoch))

            # load file
            saved_checkpoint = torch.load(
                f=checkpoint_path,
                map_location=lambda storage, loc:
                    storage.cuda(self.config['device'].index)
                    if self.config['device'].type == 'cuda' else storage)

            # load state dictionaries
            protonet.load_state_dict(
                state_dict=saved_checkpoint['hyper_net_state_dict'])
            opt.load_state_dict(state_dict=saved_checkpoint['opt_state_dict'])

            # update learning rate
            for param_group in opt.param_groups:
                if param_group['lr'] != self.config['meta_lr']:
                    param_group['lr'] = self.config['meta_lr']

        return protonet, None, opt

    def load_maml_like_model(
        self,
        resume_epoch: typing.Optional[int] = None,
        **kwargs
    ) -> typing.Tuple[torch.nn.Module,
                      typing.Optional[higher.patch._MonkeyPatchBase],
                      torch.optim.Optimizer]:
        """Initialize or load the hyper-net and base-net models

        Args:
            hyper_net_class: point to the hyper-net class of interest: IdentityNet for MAML or NormalVariationalNet for VAMPIRE
            resume_epoch: the index of the file containing the saved model

        Returns: a tuple consisting of
            hypet_net: the hyper neural network
            base_net: the base neural network
            meta_opt: the optimizer for meta-parameter
        """
        if resume_epoch is None:
            resume_epoch = self.config['resume_epoch']

        if self.config['network_architecture'] == 'CNN':
            base_net = CNN(dim_output=self.config['min_way'],
                           bn_affine=self.config['batchnorm'])
        elif self.config['network_architecture'] == 'ResNet18':
            base_net = ResNet18(dim_output=self.config['min_way'],
                                bn_affine=self.config['batchnorm'])
        elif self.config['network_architecture'] == 'MiniCNN':
            base_net = MiniCNN(dim_output=self.config['min_way'],
                               bn_affine=self.config['batchnorm'])
        else:
            raise NotImplementedError(
                'Network architecture is unknown. Please implement it in the CommonModels.py.'
            )

        # ---------------------------------------------------------------
        # run a dummy task to initialize lazy modules defined in base_net
        # ---------------------------------------------------------------
        eps_data = kwargs['eps_generator'].generate_episode(episode_name=None)
        # split data into train and validation
        xt, _, _, _ = train_val_split(X=eps_data,
                                      k_shot=self.config['k_shot'],
                                      shuffle=True)
        # convert numpy data into torch tensor
        x_t = torch.from_numpy(xt).float()
        # run to initialize lazy modules
        base_net(x_t)
        params = torch.nn.utils.parameters_to_vector(
            parameters=base_net.parameters())
        print('Number of parameters of the base network = {0:d}.\n'.format(
            params.numel()))

        hyper_net = kwargs['hyper_net_class'](base_net=base_net)

        # move to device
        base_net.to(self.config['device'])
        hyper_net.to(self.config['device'])

        # optimizer
        meta_opt = torch.optim.Adam(params=hyper_net.parameters(),
                                    lr=self.config['meta_lr'])

        # load model if there is saved file
        if resume_epoch > 0:
            # path to the saved file
            checkpoint_path = os.path.join(
                self.config['logdir'], 'Epoch_{0:d}.pt'.format(resume_epoch))

            # load file
            saved_checkpoint = torch.load(
                f=checkpoint_path,
                map_location=lambda storage, loc:
                    storage.cuda(self.config['device'].index)
                    if self.config['device'].type == 'cuda' else storage)

            # load state dictionaries
            hyper_net.load_state_dict(
                state_dict=saved_checkpoint['hyper_net_state_dict'])
            meta_opt.load_state_dict(
                state_dict=saved_checkpoint['opt_state_dict'])

            # update learning rate
            for param_group in meta_opt.param_groups:
                if param_group['lr'] != self.config['meta_lr']:
                    param_group['lr'] = self.config['meta_lr']

        return hyper_net, base_net, meta_opt
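
A hedged usage sketch for the two loader methods of Example #3, assuming they live on an object that carries the config dictionary shown above, and that eps_generator and IdentityNet exist with the interfaces the methods expect (a generate_episode method and a base_net= constructor argument, respectively).

# Hypothetical usage sketch inside the owning class (names are assumptions).
# Prototypical network: the second return value is a dummy None.
protonet, _, proto_opt = self.load_model(
    resume_epoch=0,
    eps_generator=eps_generator)

# MAML/VAMPIRE-style hyper-net and base-net pair.
hyper_net, base_net, meta_opt = self.load_maml_like_model(
    resume_epoch=0,
    eps_generator=eps_generator,
    hyper_net_class=IdentityNet)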