def test_module_clone_shared_params(self):
    # Verify clone_module's memo handles submodules referenced twice in the tree:
    # `seq` and `head` appear both as attributes and inside `net`.
    class TestModule(torch.nn.Module):

        def __init__(self):
            super(TestModule, self).__init__()
            self.seq = torch.nn.Sequential(
                torch.nn.Conv2d(3, 32, 3, 2, 1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(32, 32, 3, 2, 1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(32, 32, 3, 2, 1),
                torch.nn.ReLU(),
            )
            self.head = torch.nn.Sequential(
                torch.nn.Conv2d(32, 32, 3, 2, 1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(32, 100, 3, 2, 1),
            )
            # Shares seq and head a second time.
            self.net = torch.nn.Sequential(self.seq, self.head)

        def forward(self, x):
            return self.net(x)

    original = TestModule()
    clone = l2l.clone_module(original)

    # Same parameter count despite the shared references.
    self.assertTrue(
        len(list(clone.parameters())) == len(list(original.parameters())),
        'clone and original do not have same number of parameters.',
    )
    # No cloned parameter may alias the original's storage.
    original_ptrs = {p.data_ptr() for p in original.parameters()}
    aliased = [p.data_ptr() in original_ptrs for p in clone.parameters()]
    self.assertTrue(not any(aliased), 'clone() forgot some parameters.')
def do_eval(model, train_dl):
    """Evaluate `model` by adapting a clone on each task's support set and
    measuring accuracy on the corresponding query set; returns the mean accuracy."""
    n_adapt_steps = 5
    inner_lr = 0.1
    model.eval()
    task_accuracies = []
    for batch in train_dl:
        for support, query, support_targets, query_targets in batch:
            # Move the task's data to the module-level device.
            support = support.to(device)  # presumably B,C,W,H — per original note
            query = query.to(device)
            support_targets = support_targets.to(device)
            query_targets = query_targets.to(device)
            # Adapt a differentiable clone so the meta-model stays untouched.
            learner = clone_module(model)
            learner = adapt(
                learner,
                support,
                support_targets,
                adaptation_steps=n_adapt_steps,
                step_size=inner_lr,
            )
            task_accuracies.append(accuracy(learner(query), query_targets))
    return sum(task_accuracies) / len(task_accuracies)
def clone(
    self,
    first_order=None,
    allow_unused=None,
    allow_nograd=None,
    adapt_transform=None,
):
    """
    **Description**

    Similar to `MAML.clone()`.

    **Arguments**

    * **first_order** (bool, *optional*, default=None) - Whether the clone uses first-
        or second-order updates. Defaults to self.first_order.
    * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation
        of unused parameters. Defaults to self.allow_unused.
    * **allow_nograd** (bool, *optional*, default=False) - Whether to allow adaptation with
        parameters that have `requires_grad = False`. Defaults to self.allow_nograd.
    """
    # Each flag falls back to the instance's current setting when unspecified.
    first_order = self.first_order if first_order is None else first_order
    allow_unused = self.allow_unused if allow_unused is None else allow_unused
    allow_nograd = self.allow_nograd if allow_nograd is None else allow_nograd
    adapt_transform = self.adapt_transform if adapt_transform is None else adapt_transform
    # Both the wrapped module and the learned update are deep-cloned so the
    # returned GBML can be adapted without touching this instance.
    return GBML(
        module=l2l.clone_module(self.module),
        transform=self.transform,
        lr=self.lr,
        adapt_transform=adapt_transform,
        first_order=first_order,
        allow_unused=allow_unused,
        allow_nograd=allow_nograd,
        compute_update=l2l.clone_module(self.compute_update),
    )
def meta_surrogate_loss(iteration_replays, iteration_policies, policy, baseline,
                        tau, gamma, adapt_lr):
    """Compute the TRPO surrogate loss and KL, averaged over all tasks in the meta-batch."""
    total_loss = 0.0
    total_kl = 0.0
    n_tasks = len(iteration_replays)
    progress = tqdm(
        zip(iteration_replays, iteration_policies),
        total=n_tasks,
        desc='Surrogate Loss',
        leave=False,
    )
    for task_replays, old_policy in progress:
        policy.reset_context()
        # The final replay holds validation episodes; earlier ones drove adaptation.
        train_replays, valid_episodes = task_replays[:-1], task_replays[-1]

        # Re-run fast adaptation (second order) starting from the meta-policy.
        new_policy = l2l.clone_module(policy)
        for train_episodes in train_replays:
            new_policy = fast_adapt_a2c(new_policy, train_episodes, adapt_lr,
                                        baseline, gamma, tau, first_order=False)

        # Unpack the validation transitions.
        states = valid_episodes.state()
        actions = valid_episodes.action()
        next_states = valid_episodes.next_state()
        rewards = valid_episodes.reward()
        dones = valid_episodes.done()

        # KL between the re-adapted policy and the behavior policy that collected the data.
        old_densities = old_policy.density(states)
        new_densities = new_policy.density(states)
        total_kl += kl_divergence(new_densities, old_densities).mean()

        # Importance-weighted surrogate loss with normalized advantages.
        advantages = compute_advantages(baseline, tau, gamma, rewards,
                                        dones, states, next_states)
        advantages = ch.normalize(advantages).detach()
        old_log_probs = old_densities.log_prob(actions).mean(dim=1, keepdim=True).detach()
        new_log_probs = new_densities.log_prob(actions).mean(dim=1, keepdim=True)
        total_loss += trpo.policy_loss(new_log_probs, old_log_probs, advantages)
    return total_loss / n_tasks, total_kl / n_tasks
def test_clone_module_nomodule(self):
    # clone_module must pass through ModuleList entries that are not modules (None).
    class TrickyModule(torch.nn.Module):

        def __init__(self):
            super(TrickyModule, self).__init__()
            self.tricky_modules = torch.nn.ModuleList(
                [torch.nn.Linear(2, 1), None, torch.nn.Linear(1, 1)]
            )

    cloned = l2l.clone_module(TrickyModule())
    for index, entry in enumerate(cloned.tricky_modules):
        if index % 2 == 0:
            # Even slots held real Linear layers and must survive cloning.
            self.assertTrue(entry is not None)
        else:
            # The middle slot was None and must stay None.
            self.assertTrue(entry is None)
def meta_surrogate_loss(iter_replays, iter_policies, policy, baseline, params, anil):
    """Average the TRPO surrogate loss and KL across the meta-batch of tasks."""
    loss_sum = 0.0
    kl_sum = 0.0
    for task_replays, old_policy in zip(iter_replays, iter_policies):
        # The last replay is the validation set; the preceding ones drove adaptation.
        train_replays = task_replays[:-1]
        valid_episodes = task_replays[-1]

        # Replay fast adaptation (second order) starting from the meta-policy.
        new_policy = clone_module(policy)
        for train_episodes in train_replays:
            new_policy = trpo_update(train_episodes, new_policy, baseline,
                                     params['inner_lr'], params['gamma'],
                                     params['tau'], anil=anil, first_order=False)

        states, actions, rewards, dones, next_states = get_episode_values(valid_episodes)

        # KL between the adapted policy and the behavior policy that produced the data.
        old_densities = old_policy.density(states)
        new_densities = new_policy.density(states)
        kl_sum += kl_divergence(new_densities, old_densities).mean()

        # Surrogate loss with normalized, detached advantages.
        advantages = compute_advantages(baseline, params['tau'], params['gamma'],
                                        rewards, dones, states, next_states)
        advantages = ch.normalize(advantages).detach()
        old_log_probs = old_densities.log_prob(actions).mean(dim=1, keepdim=True).detach()
        new_log_probs = new_densities.log_prob(actions).mean(dim=1, keepdim=True)
        loss_sum += trpo.policy_loss(new_log_probs, old_log_probs, advantages)
    return loss_sum / len(iter_replays), kl_sum / len(iter_replays)
def test_module_update_shared_params(self):
    # Verify update_module's memo: shared submodules must be updated exactly once.
    class TestModule(torch.nn.Module):

        def __init__(self):
            super(TestModule, self).__init__()
            self.seq = torch.nn.Sequential(
                torch.nn.Conv2d(3, 32, 3, 2, 1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(32, 32, 3, 2, 1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(32, 32, 3, 2, 1),
                torch.nn.ReLU(),
            )
            self.head = torch.nn.Sequential(
                torch.nn.Conv2d(32, 32, 3, 2, 1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(32, 100, 3, 2, 1),
            )
            # seq and head appear a second time inside net.
            self.net = torch.nn.Sequential(self.seq, self.head)

        def forward(self, x):
            return self.net(x)

    original = TestModule()
    clone = l2l.clone_module(original)
    updates = [torch.randn_like(w) for w in clone.parameters()]
    l2l.update_module(clone, updates)

    # Parameter count is unchanged by the in-graph update.
    self.assertTrue(
        len(list(original.parameters())) == len(list(clone.parameters())),
        'clone and original do not have same number of parameters.',
    )
    # Each cloned parameter equals original + its update, up to EPSILON.
    for w, c, u in zip(original.parameters(), clone.parameters(), updates):
        self.assertTrue(
            torch.norm(w + u - c, p=2) <= EPSILON,
            'clone is not original + update.')
    # The clone must not alias the original's parameter storage.
    original_ptrs = {w.data_ptr() for w in original.parameters()}
    aliased = [c.data_ptr() in original_ptrs for c in clone.parameters()]
    self.assertTrue(not any(aliased), 'clone() forgot some parameters.')
def test_clone_module_models(self):
    # l2l.clone_module must produce the same gradients and final weights as the
    # reference implementation when used inside a small training loop.
    ref_models = [
        l2l.vision.models.OmniglotCNN(10),
        l2l.vision.models.MiniImagenetCNN(10),
    ]
    l2l_models = [copy.deepcopy(m) for m in ref_models]
    inputs = [torch.randn(5, 1, 28, 28), torch.randn(5, 3, 84, 84)]

    def train_and_collect(models, clone_fn):
        # Run 10 steps per model, backprop through a clone each step,
        # and return the final per-parameter gradients.
        all_grads = []
        for model, X in zip(models, inputs):
            for _ in range(10):
                model.zero_grad()
                out = clone_fn(model)(X)
                out.norm(p=2).backward()
                self.optimizer_step(model, [p.grad for p in model.parameters()])
            all_grads.append([p.grad.clone().detach() for p in model.parameters()])
        return all_grads

    # Reference run first, then the l2l run (same order as the inputs above).
    ref_grads = train_and_collect(ref_models, ref_clone_module)
    l2l_grads = train_and_collect(l2l_models, l2l.clone_module)

    # Gradients must match exactly...
    for ref_g, l2l_g in zip(ref_grads, l2l_grads):
        for r_g, l_g in zip(ref_g, l2l_g):
            self.assertTrue(torch.equal(r_g, l_g))
    # ...and so must the resulting parameters.
    for ref_model, l2l_model in zip(ref_models, l2l_models):
        for ref_p, l2l_p in zip(ref_model.parameters(), l2l_model.parameters()):
            self.assertTrue(torch.equal(ref_p, l2l_p))
def do_train(model, train_dl):
    """One meta-training pass over `train_dl`: adapt a clone per task on the support
    set, backprop the query loss into the meta-model, and step Adam per batch."""
    n_adapt_steps = 5
    inner_lr = 0.1
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()
    model.train()
    for batch_idx, batch in enumerate(train_dl):
        optimizer.zero_grad()
        accs = []
        losses = []
        for support, query, support_targets, query_targets in batch:
            support = support.to(device)  # presumably B,C,W,H — per original note
            query = query.to(device)
            support_targets = support_targets.to(device)
            query_targets = query_targets.to(device)
            # Adapt a differentiable clone; outer gradients flow back to `model`.
            learner = clone_module(model)
            learner = adapt(learner, support, support_targets,
                            adaptation_steps=n_adapt_steps, step_size=inner_lr)
            logits = learner(query)
            outer_loss = criterion(logits, query_targets)
            outer_loss.backward()  # accumulates across tasks in the batch
            accs.append(accuracy(logits, query_targets))
            losses.append(outer_loss.item())
        if batch_idx % 50 == 0:
            print('@ {} | {:.2f} {:.4f}'.format(
                batch_idx,
                sum(accs) / len(accs),
                sum(losses) / len(losses)))
        optimizer.step()
def test_clone_module_basics(self):
    # Stepping the original and its clone with their own gradients
    # must leave their parameters identical.
    target = torch.tensor([[0., 0.]])

    # Original: compute differentiable gradients, clone BEFORE stepping.
    out = self.model(self.input)
    loss = self.loss_func(out, target)
    grads = torch.autograd.grad(
        loss,
        self.model.parameters(),
        retain_graph=True,
        create_graph=True,
    )
    cloned_model = l2l.clone_module(self.model)
    self.optimizer_step(self.model, grads)

    # Clone: same loss, same kind of step, on its own parameters.
    cloned_out = cloned_model(self.input)
    cloned_loss = self.loss_func(cloned_out, target)
    cloned_grads = torch.autograd.grad(
        cloned_loss,
        cloned_model.parameters(),
        retain_graph=True,
        create_graph=True,
    )
    self.optimizer_step(cloned_model, cloned_grads)

    for a, b in zip(self.model.parameters(), cloned_model.parameters()):
        self.assertTrue(torch.equal(a, b))
def main(
        env_name='Particles2D-v1',
        adapt_lr=0.1,
        meta_lr=1.0,
        adapt_steps=1,
        num_iterations=1000,
        meta_bsz=20,
        adapt_bsz=20,
        tau=1.00,
        gamma=0.95,
        seed=42,
        num_workers=10,
        cuda=0,
        num_context_params=2,
):
    """Meta-train a CAVIA policy with TRPO meta-updates on a vectorized gym env.

    Per iteration: for each sampled task, fast-adapt a clone of the policy with
    A2C for `adapt_steps` steps, collect validation episodes, then perform one
    TRPO step (conjugate gradient + backtracking line search) on the meta-policy.

    Fix: the line-search used the deprecated `Tensor.add_(Number, Tensor)`
    overload (`p.data.add_(-stepsize, u.data)`), which is removed in recent
    PyTorch; it now uses the `alpha=` keyword form with identical semantics.
    """
    cuda = bool(cuda)
    # Seed every RNG we rely on for reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device_name = 'cpu'
    if cuda:
        torch.cuda.manual_seed(seed)
        device_name = 'cuda'
    device = torch.device(device_name)

    def make_env():
        # One worker env; actions rescaled to the scaler's canonical range.
        env = gym.make(env_name)
        env = ch.envs.ActionSpaceScaler(env)
        return env

    env = l2l.gym.AsyncVectorEnv([make_env for _ in range(num_workers)])
    env.seed(seed)
    env.set_task(env.sample_tasks(1)[0])
    env = ch.envs.Torch(env)
    policy = CaviaDiagNormalPolicy(env.state_size,
                                   env.action_size,
                                   num_context_params=num_context_params,
                                   device=device)
    baseline = LinearValue(env.state_size, env.action_size)

    for iteration in range(num_iterations):
        iteration_reward = 0.0
        iteration_replays = []
        iteration_policies = []

        for task_config in tqdm(env.sample_tasks(meta_bsz),
                                leave=False,
                                desc='Data'):  # Samples a new config
            # deepcopy is not working here, hence clone_module.
            clone = l2l.clone_module(policy)
            env.set_task(task_config)
            env.reset()
            policy.reset_context()
            task = ch.envs.Runner(env)
            task_replay = []

            # Fast adapt: first-order A2C updates on freshly collected episodes.
            for _ in range(adapt_steps):
                train_episodes = task.run(clone, episodes=adapt_bsz)
                if cuda:
                    train_episodes = train_episodes.to(device, non_blocking=True)
                clone = fast_adapt_a2c(clone, train_episodes, adapt_lr,
                                       baseline, gamma, tau, first_order=True)
                task_replay.append(train_episodes)

            # Validation episodes (stored last in the replay) score the adaptation.
            valid_episodes = task.run(clone, episodes=adapt_bsz)
            task_replay.append(valid_episodes)
            iteration_reward += valid_episodes.reward().sum().item() / adapt_bsz
            iteration_replays.append(task_replay)
            iteration_policies.append(clone)

        # Print statistics
        print('\nIteration', iteration)
        adaptation_reward = iteration_reward / meta_bsz
        print('adaptation_reward', adaptation_reward)

        # TRPO meta-optimization
        backtrack_factor = 0.8
        ls_max_steps = 15
        max_kl = 0.01
        if cuda:
            baseline = baseline.to(device, non_blocking=True)
            iteration_replays = [[
                r.to(device, non_blocking=True) for r in task_replays
            ] for task_replays in iteration_replays]

        # Compute the natural-gradient step direction via conjugate gradient.
        old_loss, old_kl = meta_surrogate_loss(iteration_replays,
                                               iteration_policies, policy,
                                               baseline, tau, gamma, adapt_lr)
        grad = autograd.grad(old_loss, policy.parameters(), retain_graph=True)
        grad = parameters_to_vector([g.detach() for g in grad])
        Fvp = trpo.hessian_vector_product(old_kl, policy.parameters())
        step = trpo.conjugate_gradient(Fvp, grad)
        # Scale the step so the expected KL stays within the trust region.
        shs = 0.5 * torch.dot(step, Fvp(step))
        lagrange_multiplier = torch.sqrt(shs / max_kl)
        step = step / lagrange_multiplier
        step_ = [torch.zeros_like(p.data) for p in policy.parameters()]
        vector_to_parameters(step, step_)
        step = step_
        del old_kl, Fvp, grad
        old_loss.detach_()

        # Backtracking line search: shrink the step until loss improves and KL is safe.
        for ls_step in range(ls_max_steps):
            stepsize = backtrack_factor ** ls_step * meta_lr
            clone = l2l.clone_module(policy)
            for p, u in zip(clone.parameters(), step):
                # `alpha=` form replaces the deprecated add_(Number, Tensor) overload.
                p.data.add_(u.data, alpha=-stepsize)
            new_loss, kl = meta_surrogate_loss(iteration_replays,
                                               iteration_policies, clone,
                                               baseline, tau, gamma, adapt_lr)
            if new_loss < old_loss and kl < max_kl:
                for p, u in zip(policy.parameters(), step):
                    p.data.add_(u.data, alpha=-stepsize)
                break
def main(
    fast_lr: float = 0.1,
    meta_lr: float = 0.003,
    num_iterations: int = 40000,
    meta_batch_size: int = 16,
    adaptation_steps: int = 5,
    dataset: str = 'cifarfs',
    layers: int = 4,
    shots: int = 5,
    ways: int = 5,
    cuda: int = 1,  # 0 or 1 only
    seed: int = 1234,
):
    """Meta-train a CNN whose inner-loop update is a learnable Kronecker-factored
    transform (KFO), MAML-style, logging metrics to Weights & Biases.

    `fast_lr` is the inner-loop (adaptation) learning rate used by the
    differentiable SGD; `meta_lr` is the outer-loop Adam learning rate;
    `shots`/`ways` size the few-shot tasks sampled from `dataset`.
    """
    # Snapshot call arguments before other locals exist, for the wandb config.
    args = dict(locals())
    wandb.init(
        project='kfo',
        group='MAML',
        name='maml-' + dataset,
        config=args,
    )
    # Seed all RNGs for reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create Tasksets using the benchmark interface.
    # train_samples = 2 * shots: half for adaptation, half for evaluation.
    tasksets = l2l.vision.benchmarks.get_tasksets(
        name=dataset,
        train_samples=2 * shots,
        train_ways=ways,
        test_samples=2 * shots,
        test_ways=ways,
        root='~/data',
    )

    # Create model and learnable update.
    # NOTE(review): `model` is only bound for 'cifarfs' / 'mini-imagenet';
    # any other `dataset` value raises NameError at `model.to(device)` below.
    if dataset == 'cifarfs':
        model = CIFARCNN(output_size=ways, hidden_size=32, layers=layers)
    elif dataset == 'mini-imagenet':
        model = l2l.vision.models.CNN4(
            output_size=ways,
            hidden_size=32,
            layers=layers,
        )
    model.to(device)
    # The learnable update: a PSD Kronecker-factored transform of the gradients.
    kfo_transform = l2l.optim.transforms.KroneckerTransform(
        kronecker_cls=l2l.nn.KroneckerLinear,
        psd=True,
    )
    fast_update = l2l.optim.ParameterUpdate(
        parameters=model.parameters(),
        transform=kfo_transform,
    )
    fast_update.to(device)
    diff_sgd = l2l.optim.DifferentiableSGD(lr=fast_lr)
    # The meta-optimizer trains both the model and the update transform.
    all_parameters = list(model.parameters()) + list(fast_update.parameters())
    opt = torch.optim.Adam(all_parameters, meta_lr)
    loss = torch.nn.CrossEntropyLoss(reduction='mean')

    for iteration in tqdm.trange(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        for task in range(meta_batch_size):
            # Compute meta-training loss on a clone so the meta-params stay intact.
            task_model = l2l.clone_module(model)
            task_update = l2l.clone_module(fast_update)
            batch = tasksets.train.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch,
                task_model,
                task_update,
                diff_sgd,
                loss,
                adaptation_steps,
                shots,
                ways,
                device,
            )
            # Outer-loop gradients accumulate across the meta-batch.
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Compute meta-validation loss (no backward: monitoring only).
            task_model = l2l.clone_module(model)
            task_update = l2l.clone_module(fast_update)
            batch = tasksets.validation.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch,
                task_model,
                task_update,
                diff_sgd,
                loss,
                adaptation_steps,
                shots,
                ways,
                device,
            )
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

        # log some metrics
        wandb.log(
            {
                'Train Error': meta_train_error / meta_batch_size,
                'Train Accuracy': meta_train_accuracy / meta_batch_size,
                'Validation Error': meta_valid_error / meta_batch_size,
                'Validation Accuracy': meta_valid_accuracy / meta_batch_size,
            },
            step=iteration)

        # Average the accumulated gradients and optimize.
        for p in model.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        for p in fast_update.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        opt.step()

    # Final meta-test evaluation after training completes.
    meta_test_error = 0.0
    meta_test_accuracy = 0.0
    for task in range(meta_batch_size):
        # Compute meta-testing loss.
        task_model = l2l.clone_module(model)
        task_update = l2l.clone_module(fast_update)
        batch = tasksets.test.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(
            batch,
            task_model,
            task_update,
            diff_sgd,
            loss,
            adaptation_steps,
            shots,
            ways,
            device,
        )
        meta_test_error += evaluation_error.item()
        meta_test_accuracy += evaluation_accuracy.item()
    # Logged at the last training iteration's step.
    wandb.log(
        {
            'Test Error': meta_test_error / meta_batch_size,
            'Test Accuracy': meta_test_accuracy / meta_batch_size,
        },
        step=iteration)
def main(
    fast_lr=0.1,
    meta_lr=0.003,
    num_iterations=10000,
    meta_batch_size=16,
    adaptation_steps=5,
    shots=5,
    ways=5,
    cuda=1,
    seed=1234
):
    """Meta-train a CIFAR-FS CNN, MAML-style, where only the classifier head is
    adapted in the inner loop via a learnable Kronecker-factored update (KFO).

    `fast_lr` is the inner-loop learning rate for the differentiable SGD;
    `meta_lr` is the outer-loop Adam learning rate; `shots`/`ways` size the
    few-shot tasks.
    """
    # Seed all RNGs for reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create Tasksets using the benchmark interface.
    # train_samples = 2*shots: half for adaptation, half for evaluation.
    tasksets = l2l.vision.benchmarks.get_tasksets(
        name='cifarfs',
        train_samples=2*shots,
        train_ways=ways,
        test_samples=2*shots,
        test_ways=ways,
        root='~/data',
    )

    # Create model and learnable update.
    # Features and classifier are split: only the classifier adapts in the inner loop.
    model = CifarCNN(output_size=ways)
    model.to(device)
    features = model.features
    classifier = model.linear
    kfo_transform = l2l.optim.transforms.KroneckerTransform(l2l.nn.KroneckerLinear)
    fast_update = l2l.optim.ParameterUpdate(
        parameters=classifier.parameters(),
        transform=kfo_transform,
    )
    fast_update.to(device)
    diff_sgd = l2l.optim.DifferentiableSGD(lr=fast_lr)
    # The meta-optimizer trains the full model and the update transform.
    all_parameters = list(model.parameters()) + list(fast_update.parameters())
    opt = torch.optim.Adam(all_parameters, meta_lr)
    loss = torch.nn.CrossEntropyLoss(reduction='mean')

    for iteration in range(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        for task in range(meta_batch_size):
            # Compute meta-training loss on clones so the meta-params stay intact.
            task_features = l2l.clone_module(features)
            task_classifier = l2l.clone_module(classifier)
            task_update = l2l.clone_module(fast_update)
            batch = tasksets.train.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                               task_features,
                                                               task_classifier,
                                                               task_update,
                                                               diff_sgd,
                                                               loss,
                                                               adaptation_steps,
                                                               shots,
                                                               ways,
                                                               device)
            # Outer-loop gradients accumulate across the meta-batch.
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Compute meta-validation loss (monitoring only — no backward).
            task_features = l2l.clone_module(features)
            task_classifier = l2l.clone_module(classifier)
            task_update = l2l.clone_module(fast_update)
            batch = tasksets.validation.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                               task_features,
                                                               task_classifier,
                                                               task_update,
                                                               diff_sgd,
                                                               loss,
                                                               adaptation_steps,
                                                               shots,
                                                               ways,
                                                               device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_batch_size)
        print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
        print('Meta Valid Error', meta_valid_error / meta_batch_size)
        print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)

        # Average the accumulated gradients and optimize.
        for p in model.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        for p in fast_update.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        opt.step()

    # Final meta-test evaluation after training completes.
    meta_test_error = 0.0
    meta_test_accuracy = 0.0
    for task in range(meta_batch_size):
        # Compute meta-testing loss.
        task_features = l2l.clone_module(features)
        task_classifier = l2l.clone_module(classifier)
        task_update = l2l.clone_module(fast_update)
        batch = tasksets.test.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                           task_features,
                                                           task_classifier,
                                                           task_update,
                                                           diff_sgd,
                                                           loss,
                                                           adaptation_steps,
                                                           shots,
                                                           ways,
                                                           device)
        meta_test_error += evaluation_error.item()
        meta_test_accuracy += evaluation_accuracy.item()
    print('Meta Test Error', meta_test_error / meta_batch_size)
    print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)