import copy
import os
import unittest

import torch
import torch.nn as nn

import utils
from model import MetaLearner  # import path assumed; adjust to this repo's layout


class test_meta_learner(unittest.TestCase):
    def setUp(self):
        # Load the base configuration (3-way 3-shot with a 3-sample query
        # set) and override a few fields for this test
        model_dir = 'experiments/base_model'
        json_path = os.path.join(model_dir, 'params.json')
        assert os.path.isfile(
            json_path), "No json configuration file found at {}".format(
                json_path)
        params = utils.Params(json_path)
        params.in_channels = 3
        params.num_classes = 5
        params.dataset = 'ImageNet'
        params.cuda = torch.cuda.is_available()  # use the GPU only when one exists

        # Data setting
        N = 5
        self.X = torch.ones([N, params.in_channels, 84, 84])
        self.Y = torch.randint(params.num_classes, (N,), dtype=torch.long)

        # Optim & loss setting
        if params.cuda:
            self.model = MetaLearner(params).cuda()
            self.X = self.X.cuda()
            self.Y = self.Y.cuda()
        else:
            self.model = MetaLearner(params)
        self.model.define_task_lr_params()
        # task_lr values live outside the module's parameter registry, so
        # they must be handed to the optimizer explicitly
        model_params = list(self.model.parameters()) + list(
            self.model.task_lr.values())
        self.optim = torch.optim.SGD(model_params, lr=1e-3)
        self.loss_fn = nn.NLLLoss()

    def test_params(self):
        for key, val in self.model.state_dict().items():
            print(key)
        for key, val in self.model.task_lr.items():
            print(key, val.requires_grad)

    def test_grad_check(self):
        # Snapshot the current parameters (including the per-parameter
        # learning rates), then update the model once with data
        stored_params = {
            key: val.clone()
            for key, val in self.model.named_parameters()
        }
        task_lr_params = {
            key: val.clone()
            for key, val in self.model.task_lr.items()
        }
        stored_params.update(task_lr_params)

        Y_hat = self.model(self.X)
        loss = self.loss_fn(Y_hat, self.Y)
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

        # Every parameter should have moved after one optimizer step
        for key, val in self.model.named_parameters():
            self.assertTrue((val != stored_params[key]).any())
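
# A minimal sketch of what MetaLearner.define_task_lr_params is assumed to
# do here: create Meta-SGD-style learnable per-parameter inner-loop learning
# rates. Only the attribute name `task_lr` comes from the tests in this file;
# the helper name, the initial value, and everything else is illustrative.
def _define_task_lr_params_sketch(model, init_lr=1e-2):
    # one learnable learning rate per weight tensor, keyed by parameter name
    model.task_lr = {
        key: nn.Parameter(init_lr * torch.ones(1))
        for key, _ in model.named_parameters()
    }
# Because these learning rates are trained by the outer optimizer but are
# not registered as module parameters, setUp above passes them to SGD
# alongside model.parameters().
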
# Variant of the suite above that exercises storing, adapting, and
# restoring the meta-parameters
class test_meta_learner_adapt(unittest.TestCase):
    def setUp(self):
        # Load the base configuration (3-way 3-shot with a 3-sample query
        # set) and override a few fields for this test
        model_dir = 'experiments/base_model'
        json_path = os.path.join(model_dir, 'params.json')
        assert os.path.isfile(
            json_path), "No json configuration file found at {}".format(
                json_path)
        params = utils.Params(json_path)
        params.in_channels = 3
        params.num_classes = 5
        params.dataset = 'ImageNet'
        self.model = MetaLearner(params)

        # Data setting
        N = 5
        self.X = torch.ones([N, params.in_channels, 84, 84])
        self.Y = torch.randint(params.num_classes, (N,), dtype=torch.long)

        # Optim & loss setting
        self.optim = torch.optim.SGD(self.model.parameters(), lr=1e-3)
        self.loss_fn = nn.NLLLoss()

    def test_store_cur_params(self):
        # Store current parameters
        self.model.store_cur_params()

        # Update the model once with data
        Y_hat = self.model(self.X)
        loss = self.loss_fn(Y_hat, self.Y)
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

        # stored_params must be a deep copy: the live parameters changed,
        # so the stored ones must now differ
        for key, val in self.model.state_dict().items():
            self.assertTrue((val != self.model.stored_params[key]).any())

    # @unittest.skip("Adaptation process does not work..")
    def test_adapt_and_init_params(self):
        # Store current parameters
        self.model.store_cur_params()

        # Update the model once with data
        Y_hat = self.model(self.X)
        loss = self.loss_fn(Y_hat, self.Y)
        self.optim.zero_grad()
        # grads are in the order of model.parameters()
        grads = torch.autograd.grad(
            loss, self.model.parameters(), create_graph=True)

        # Perform updates using the calculated gradients. The adapted
        # parameters are computed manually since optimizer.step() operates
        # in-place, which would break the graph needed for the meta-update
        adapted_params = {
            key: val.clone()
            for key, val in self.model.state_dict().items()
        }
        for (key, val), grad in zip(self.model.named_parameters(), grads):
            adapted_params[key] = self.model.stored_params[key] - 1e-2 * grad

        # Check parameter not changed
        # self.model.check_params_not_changed()

        # Confirm that adapted_params differ from the current params
        for key, val in self.model.named_parameters():
            self.assertTrue((val != adapted_params[key]).any())

        # Load adapted_params into the model and confirm that they now
        # match the current params
        self.model.adapt_params(adapted_params)
        for key, val in adapted_params.items():
            self.assertTrue((val == self.model.state_dict()[key]).all())

        # Compute the loss with the adapted parameters and optimize
        # w.r.t. the meta-parameters
        Y_hat = self.model(self.X)
        loss = self.loss_fn(Y_hat, self.Y)

        # Return to meta-parameters
        before_optim = copy.deepcopy(self.model.state_dict())
        self.model.init_params()
        meta_optim = torch.optim.SGD(self.model.parameters(), lr=1e-3)
        # self.optim.zero_grad()
        meta_optim.zero_grad()
        loss.backward()
        # self.optim.step()
        meta_optim.step()

        # Check that the meta-parameters were updated
        for key, val in self.model.named_parameters():
            self.assertTrue((val != before_optim[key]).any())
        # Check adapted-parameters still same

    @unittest.skip("For debugging purpose")
    def test_parameter_name(self):
        print(self.model.state_dict().keys())
        print(len(self.model.state_dict().keys()))
        print(self.model.meta_learner.state_dict().keys())
        print(len(self.model.meta_learner.state_dict().keys()))
        # state_dict() tensors share storage with the live parameters, so a
        # mutation through one view is visible through the other
        self.model.meta_learner.state_dict()['fc.bias'][0] = 0
        print(self.model.state_dict()['meta_learner.fc.bias'])
        print(self.model.meta_learner.state_dict()['fc.bias'])
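
# A minimal sketch (not the repo's code) of the inner-loop update pattern
# that test_adapt_and_init_params verifies: build the adapted weights
# out-of-place with torch.autograd.grad so the graph from the
# meta-parameters to the adapted parameters survives for the outer (meta)
# backward pass; optimizer.step() mutates parameters in-place and would
# discard that graph. The helper name and lr value are illustrative.
def _inner_loop_update_sketch(model, loss, lr=1e-2):
    # create_graph=True keeps the gradient computation differentiable,
    # which is what lets loss.backward() later reach the meta-parameters
    grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    return {
        key: val - lr * grad  # out-of-place: val itself is untouched
        for (key, val), grad in zip(model.named_parameters(), grads)
    }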