def initialize_model(
    meta_lr: float,
    decay_lr: float = 1.
) -> _typing.Tuple[torch.nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
    """Initialize the model, optimizer and lr_scheduler

    The example here loads the CNN model; a ResNet18 alternative is left
    commented out. You can write your own model class and instantiate it here instead.

    Args:
        meta_lr: learning rate for the meta-parameters
        decay_lr: multiplicative decay factor of the learning rate

    Returns:
        net: the initialized model
        meta_optimizer: the Adam optimizer over the model parameters
        schdlr: the exponential learning-rate scheduler
    """
    # net = ResNet18(input_channel=image_size[0], dim_output=n_way, bn_affine=True)
    net = CNN(dim_output=n_way, image_size=image_size, bn_affine=True)

    # initialize the weights
    net.apply(_weights_init)

    # move to GPU (or whatever device is configured)
    net.to(device)

    meta_optimizer = torch.optim.Adam(params=net.parameters(), lr=meta_lr)
    schdlr = torch.optim.lr_scheduler.ExponentialLR(optimizer=meta_optimizer, gamma=decay_lr)

    return net, meta_optimizer, schdlr
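# A minimal usage sketch for initialize_model() above. The learning-rate
# values and the number of epochs are illustrative assumptions; `n_way`,
# `image_size` and `device` are the module-level variables the function
# already relies on.
def _example_initialize_model(num_epochs: int = 10) -> None:
    """Hypothetical example: initialize the model and decay the learning rate each epoch."""
    net, meta_optimizer, schdlr = initialize_model(meta_lr=1e-3, decay_lr=0.99)
    for _ in range(num_epochs):
        # ... run one meta-training epoch, calling meta_optimizer.step() per task batch ...
        schdlr.step()  # apply the exponential decay once per epoch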
def initialize_model(
    hyper_net_cls,
    meta_lr: float,
    decay_lr: float = 1.
) -> typing.Tuple[torch.nn.Module, torch.nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
    """Initialize the hyper-net, base-net, optimizer and lr_scheduler

    The base network is selected from the configuration ('CNN' or 'ResNet18').
    You can write your own model class and add it to the selection below.

    Args:
        hyper_net_cls: a handle to the hyper-net class of interest
        meta_lr: learning rate for the meta-parameters
        decay_lr: multiplicative decay factor of the learning rate

    Returns:
        hyper_net: the hyper neural network
        base_net: the base neural network
        meta_opt: the Adam optimizer over the hyper-net parameters
        schdlr: the exponential learning-rate scheduler
    """
    if config['base_model'] in ['CNN']:
        base_net = CNN(
            dim_output=config['min_way'],
            image_size=image_size,
            bn_affine=config['batchnorm']
        )
    elif config['base_model'] in ['ResNet18']:
        base_net = ResNet18(
            input_channel=image_size[0],
            dim_output=config['min_way'] if config['min_way'] == config['max_way'] else None,
            bn_affine=config['batchnorm']
        )
    else:
        raise NotImplementedError

    hyper_net = hyper_net_cls(base_net=base_net)

    # move to GPU (or whatever device is configured)
    base_net.to(device)
    hyper_net.to(device)

    meta_opt = torch.optim.Adam(params=hyper_net.parameters(), lr=meta_lr)
    schdlr = torch.optim.lr_scheduler.ExponentialLR(optimizer=meta_opt, gamma=decay_lr)

    return hyper_net, base_net, meta_opt, schdlr
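# A minimal usage sketch for the hyper-net variant of initialize_model()
# above. IdentityNet (MAML) and NormalVariationalNet (VAMPIRE) are the
# hyper-net classes referred to elsewhere in this code base; the learning-rate
# values are illustrative assumptions.
def _example_initialize_hyper_model() -> None:
    """Hypothetical example: build a MAML-style hyper-net over the configured base network."""
    hyper_net, base_net, meta_opt, schdlr = initialize_model(
        hyper_net_cls=IdentityNet,  # swap for NormalVariationalNet to get VAMPIRE
        meta_lr=1e-3,
        decay_lr=0.99
    )
    # only hyper_net.parameters() are registered with meta_opt: the base-net
    # weights are produced by the hyper-net rather than optimized directly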
def load_model(
    self,
    resume_epoch: int = None,
    **kwargs
) -> typing.Tuple[torch.nn.Module, typing.Optional[higher.patch._MonkeyPatchBase], torch.optim.Optimizer]:
    """Initialize or load the protonet and its optimizer

    Args:
        resume_epoch: the index of the file containing the saved model
        kwargs: must contain 'eps_generator', the episode generator used to
            build a dummy task that initializes the lazy modules

    Returns:
        a tuple consisting of
            protonet: the prototypical network
            base_net: dummy (None) to match the interface of MAML and VAMPIRE
            opt: the optimizer for the prototypical network
    """
    if resume_epoch is None:
        resume_epoch = self.config['resume_epoch']

    if self.config['network_architecture'] == 'CNN':
        protonet = CNN(
            dim_output=None,
            bn_affine=self.config['batchnorm']
        )
    elif self.config['network_architecture'] == 'ResNet18':
        protonet = ResNet18(
            dim_output=None,
            bn_affine=self.config['batchnorm']
        )
    else:
        raise NotImplementedError('Network architecture is unknown. Please implement it in the CommonModels.py.')

    # ---------------------------------------------------------------
    # run a dummy task to initialize lazy modules defined in base_net
    # ---------------------------------------------------------------
    eps_data = kwargs['eps_generator'].generate_episode(episode_name=None)

    # split data into train and validation
    xt, _, _, _ = train_val_split(X=eps_data, k_shot=self.config['k_shot'], shuffle=True)

    # convert numpy data into torch tensor
    x_t = torch.from_numpy(xt).float()

    # run a forward pass to initialize the lazy modules
    protonet(x_t)

    # move to device
    protonet.to(self.config['device'])

    # optimizer
    opt = torch.optim.Adam(params=protonet.parameters(), lr=self.config['meta_lr'])

    # load model if there is a saved checkpoint
    if resume_epoch > 0:
        # path to the saved file
        checkpoint_path = os.path.join(self.config['logdir'], 'Epoch_{0:d}.pt'.format(resume_epoch))

        # load the checkpoint, mapping storages onto the configured device
        saved_checkpoint = torch.load(
            f=checkpoint_path,
            map_location=lambda storage, loc: storage.cuda(self.config['device'].index)
            if self.config['device'].type == 'cuda' else storage
        )

        # load state dictionaries
        protonet.load_state_dict(state_dict=saved_checkpoint['hyper_net_state_dict'])
        opt.load_state_dict(state_dict=saved_checkpoint['opt_state_dict'])

        # update the learning rate in case it differs from the saved one
        for param_group in opt.param_groups:
            if param_group['lr'] != self.config['meta_lr']:
                param_group['lr'] = self.config['meta_lr']

    return protonet, None, opt
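# A minimal sketch of saving a checkpoint that load_model() above can resume
# from. The file-name pattern 'Epoch_{N}.pt' and the dictionary keys
# 'hyper_net_state_dict' / 'opt_state_dict' are taken from the loading code;
# the helper itself and its signature are assumptions for illustration.
def _example_save_checkpoint(config: dict, protonet: torch.nn.Module, opt: torch.optim.Optimizer, epoch_id: int) -> None:
    """Hypothetical example: write a checkpoint in the format the loader expects."""
    checkpoint = {
        'hyper_net_state_dict': protonet.state_dict(),
        'opt_state_dict': opt.state_dict()
    }
    checkpoint_path = os.path.join(config['logdir'], 'Epoch_{0:d}.pt'.format(epoch_id))
    torch.save(obj=checkpoint, f=checkpoint_path)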
def load_maml_like_model(
    self,
    resume_epoch: int = None,
    **kwargs
) -> typing.Tuple[torch.nn.Module, typing.Optional[higher.patch._MonkeyPatchBase], torch.optim.Optimizer]:
    """Initialize or load the hyper-net and base-net models

    Args:
        resume_epoch: the index of the file containing the saved model
        kwargs: must contain 'eps_generator' (the episode generator) and
            'hyper_net_class' (the hyper-net class of interest: IdentityNet
            for MAML or NormalVariationalNet for VAMPIRE)

    Returns:
        a tuple consisting of
            hyper_net: the hyper neural network
            base_net: the base neural network
            meta_opt: the optimizer for the meta-parameters
    """
    if resume_epoch is None:
        resume_epoch = self.config['resume_epoch']

    if self.config['network_architecture'] == 'CNN':
        base_net = CNN(
            dim_output=self.config['min_way'],
            bn_affine=self.config['batchnorm']
        )
    elif self.config['network_architecture'] == 'ResNet18':
        base_net = ResNet18(
            dim_output=self.config['min_way'],
            bn_affine=self.config['batchnorm']
        )
    elif self.config['network_architecture'] == 'MiniCNN':
        base_net = MiniCNN(
            dim_output=self.config['min_way'],
            bn_affine=self.config['batchnorm']
        )
    else:
        raise NotImplementedError('Network architecture is unknown. Please implement it in the CommonModels.py.')

    # ---------------------------------------------------------------
    # run a dummy task to initialize lazy modules defined in base_net
    # ---------------------------------------------------------------
    eps_data = kwargs['eps_generator'].generate_episode(episode_name=None)

    # split data into train and validation
    xt, _, _, _ = train_val_split(X=eps_data, k_shot=self.config['k_shot'], shuffle=True)

    # convert numpy data into torch tensor
    x_t = torch.from_numpy(xt).float()

    # run a forward pass to initialize the lazy modules
    base_net(x_t)

    params = torch.nn.utils.parameters_to_vector(parameters=base_net.parameters())
    print('Number of parameters of the base network = {0:d}.\n'.format(params.numel()))

    hyper_net = kwargs['hyper_net_class'](base_net=base_net)

    # move to device
    base_net.to(self.config['device'])
    hyper_net.to(self.config['device'])

    # optimizer
    meta_opt = torch.optim.Adam(params=hyper_net.parameters(), lr=self.config['meta_lr'])

    # load model if there is a saved checkpoint
    if resume_epoch > 0:
        # path to the saved file
        checkpoint_path = os.path.join(self.config['logdir'], 'Epoch_{0:d}.pt'.format(resume_epoch))

        # load the checkpoint, mapping storages onto the configured device
        saved_checkpoint = torch.load(
            f=checkpoint_path,
            map_location=lambda storage, loc: storage.cuda(self.config['device'].index)
            if self.config['device'].type == 'cuda' else storage
        )

        # load state dictionaries
        hyper_net.load_state_dict(state_dict=saved_checkpoint['hyper_net_state_dict'])
        meta_opt.load_state_dict(state_dict=saved_checkpoint['opt_state_dict'])

        # update the learning rate in case it differs from the saved one
        for param_group in meta_opt.param_groups:
            if param_group['lr'] != self.config['meta_lr']:
                param_group['lr'] = self.config['meta_lr']

    return hyper_net, base_net, meta_opt
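# A minimal usage sketch for load_maml_like_model() above. The hyper-net
# class names follow the docstring (IdentityNet for MAML, NormalVariationalNet
# for VAMPIRE); `trainer` and its `eps_generator` attribute are assumed names
# for the object that owns this method and its episode generator.
def _example_load_maml_like_model(trainer) -> None:
    """Hypothetical example: build a fresh MAML-style model or resume from a checkpoint."""
    hyper_net, base_net, meta_opt = trainer.load_maml_like_model(
        resume_epoch=0,  # 0 starts from scratch; > 0 resumes from 'Epoch_{N}.pt' in logdir
        eps_generator=trainer.eps_generator,  # assumed attribute holding the episode generator
        hyper_net_class=IdentityNet  # swap for NormalVariationalNet to get VAMPIRE
    )
    # only the hyper-net's parameters are meta-optimized; the base-net consumes
    # the parameters generated by the hyper-net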