def compute_outer_grads(
        self,
        task: Tuple[torch.Tensor, ...],
        n_tasks: int,
        meta_train: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute gradients on the query set."""
    support_input, support_target, query_input, query_target = task

    with higher.innerloop_ctx(self.head, self.in_optim,
                              copy_initial_weights=False,
                              track_higher_grads=meta_train) as (fhead, diffopt):
        # The encoder is frozen; only the head is adapted.
        with torch.no_grad():
            support_feature = self.encoder(support_input)

        # inner loop (adapt)
        self.inner_loop(fhead, diffopt, support_feature, support_target)

        # evaluate on the query set
        with torch.set_grad_enabled(meta_train):
            query_feature = self.encoder(query_input)
            query_output = fhead(query_feature)
            query_loss = self.loss_function(query_output, query_target)
            query_loss /= len(query_input)

    # compute gradients when in the meta-training stage
    if meta_train:
        (query_loss / n_tasks).backward()

    return query_output, query_loss
def train(self, niter=3):
    if self.jit:
        raise NotImplementedError()
    net, _ = self.get_module()
    net.train()
    x_spt, y_spt, x_qry, y_qry = self.meta_inputs
    meta_opt = optim.Adam(net.parameters(), lr=1e-3)

    for _ in range(niter):
        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        meta_opt.zero_grad()
        for i in range(task_num):
            with higher.innerloop_ctx(
                    net, inner_opt, copy_initial_weights=False) as (fnet, diffopt):
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                qry_logits = fnet(x_qry[i])
                qry_loss = F.cross_entropy(qry_logits, y_qry[i])
                qry_loss.backward()
        meta_opt.step()
def fit(self, dataset_train):
    """
    The learner's fit function over the train set of a task.

    Args:
        dataset_train : a tf.data.Dataset object. Iterates over the training
            examples (support set).
    Returns:
        predictor : An instance of MyPredictor that is initialized with the
            fine-tuned learner's weights in this case.
    """
    self.learner.train()
    for images, labels in dataset_train:
        images, labels = self.process_task(images, labels)
        with higher.innerloop_ctx(self.learner, self.optimizer,
                                  track_higher_grads=False) as (fnet, diffopt):
            # Optimize the likelihood of the support set by taking
            # gradient steps w.r.t. the model's parameters.
            # This adapts the model's meta-parameters to the task.
            for _ in range(self.n_inner_iter):
                spt_logits = fnet(images)
                spt_loss = F.cross_entropy(spt_logits, labels)
                diffopt.step(spt_loss)
            predictor = MyPredictor(fnet)
        # Only the first batch (the full support set) is used.
        break
    return predictor
def evaluate(self, dataloader, updates, mini_batch_size):
    support_set = []
    for _ in range(updates):
        text, labels = self.memory.read_batch(batch_size=mini_batch_size)
        support_set.append((text, labels))

    with higher.innerloop_ctx(self.pn, self.inner_optimizer,
                              copy_initial_weights=False,
                              track_higher_grads=False) as (fpn, diffopt):
        # Inner loop
        task_predictions, task_labels = [], []
        support_loss = []
        for text, labels in support_set:
            labels = torch.tensor(labels).to(self.device)
            input_dict = self.pn.encode_text(text)
            representation = fpn(input_dict, out_from='transformers')
            modulation = self.nm(input_dict)
            output = fpn(representation * modulation, out_from='linear')
            loss = self.loss_fn(output, labels)
            diffopt.step(loss)
            pred = models.utils.make_prediction(output.detach())
            support_loss.append(loss.item())
            task_predictions.extend(pred.tolist())
            task_labels.extend(labels.tolist())

        acc, prec, rec, f1 = models.utils.calculate_metrics(
            task_predictions, task_labels)
        logger.info(
            'Support set metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, '
            'recall = {:.4f}, F1 score = {:.4f}'.format(
                np.mean(support_loss), acc, prec, rec, f1))

        all_losses, all_predictions, all_labels = [], [], []
        for text, labels in dataloader:
            labels = torch.tensor(labels).to(self.device)
            input_dict = self.pn.encode_text(text)
            with torch.no_grad():
                representation = fpn(input_dict, out_from='transformers')
                modulation = self.nm(input_dict)
                output = fpn(representation * modulation, out_from='linear')
                loss = self.loss_fn(output, labels)
            loss = loss.item()
            pred = models.utils.make_prediction(output.detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())

    acc, prec, rec, f1 = models.utils.calculate_metrics(
        all_predictions, all_labels)
    logger.info(
        'Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
        'F1 score = {:.4f}'.format(np.mean(all_losses), acc, prec, rec, f1))
    return acc, prec, rec, f1
def train(step_idx, data, net, inner_opt_builder, meta_opt, n_inner_iter):
    """Main meta-training step."""
    x_spt, y_spt, x_qry, y_qry = data
    task_num = x_spt.size()[0]
    querysz = x_qry.size(1)
    inner_opt = inner_opt_builder.inner_opt

    qry_losses = []
    meta_opt.zero_grad()
    for i in range(task_num):
        with higher.innerloop_ctx(
            net,
            inner_opt,
            copy_initial_weights=False,
            override=inner_opt_builder.overrides,
        ) as (fnet, diffopt):
            for _ in range(n_inner_iter):
                spt_pred = fnet(x_spt[i])
                spt_loss = F.mse_loss(spt_pred, y_spt[i])
                diffopt.step(spt_loss)
            qry_pred = fnet(x_qry[i])
            qry_loss = F.mse_loss(qry_pred, y_qry[i])
            qry_losses.append(qry_loss.detach().cpu().numpy())
            qry_loss.backward()
    metrics = {"train_loss": np.mean(qry_losses)}
    wandb.log(metrics, step=step_idx)
    meta_opt.step()
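# The `train` step above (and the matching `test` step later in this section)
# assumes an `inner_opt_builder` object exposing a plain inner optimizer plus
# an `overrides` dict for higher's `override` argument. A minimal sketch of
# that interface (an assumption; the real class may differ), here with an
# optionally meta-learnable inner learning rate:
import torch

class InnerOptBuilder:
    def __init__(self, net, init_lr=0.1, learn_lr=False):
        # Plain SGD; higher wraps it into a differentiable optimizer.
        self.inner_opt = torch.optim.SGD(net.parameters(), lr=init_lr)
        # `override` maps hyperparameter names to one tensor per parameter
        # group; making the tensor require grad turns the inner learning
        # rate into a meta-learned quantity.
        self.lr_tensor = torch.tensor(init_lr, requires_grad=learn_lr)
        self.overrides = {"lr": [self.lr_tensor]}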
def run_batches(self, batches, train=True, meta_train=True):
    metrics = []
    device = next(self.model.parameters()).device
    shufflers = {
        key: shuffler() for key, shuffler in self._shuffler_factory.items()
    }
    with torch.backends.cudnn.flags(enabled=False):
        with higher.innerloop_ctx(self.model, self.inner_optimizer,
                                  copy_initial_weights=False) as (fmodel, diffopt):
            for n, inputs in enumerate(batches[:-1]):
                inputs = self.shuffle_labels(inputs, shufflers)
                inputs = move_to_device(inputs, device)
                output_dict = fmodel(**inputs, **self.forward_kwargs(n))
                loss = output_dict["loss"]
                metric = output_dict["metric"]
                diffopt.step(loss)
                metrics.append({"loss": loss.item(), "metric": metric})

            inputs = self.shuffle_labels(batches[-1], shufflers)
            inputs = move_to_device(inputs, device)
            output_dict = fmodel(**inputs, **self.forward_kwargs(len(batches) - 1))
            loss = output_dict["loss"]
            metric = output_dict["metric"]
            loss.backward()
            metrics.append({"loss": loss.item(), "metric": metric})
    return metrics
def maml_val(tepoch, model, inner_criterion, outer_criterion,
             inner_optimizer, num_adapt_steps):
    model.train()
    test_losses = []
    for batch in tepoch:
        tepoch.set_description("Validation")
        batch['train'][0] = batch['train'][0].view(1, -1, 6, 36, 36)
        batch['test'][0] = batch['test'][0].view(1, -1, 6, 36, 36)
        batch['train'][1] = batch['train'][1].view(1, -1, 1)
        batch['test'][1] = batch['test'][1].view(1, -1, 1)
        inputs, targets = batch['train']
        test_inputs, test_targets = batch['test']
        for task_idx, (input, target, test_input, test_target) in enumerate(
                zip(inputs, targets, test_inputs, test_targets)):
            with higher.innerloop_ctx(model, inner_optimizer,
                                      copy_initial_weights=False) as (fmodel, diffopt):
                for step in range(num_adapt_steps):
                    inner_loss = inner_criterion(fmodel(input), target)
                    diffopt.step(inner_loss)
                test_logit = fmodel(test_input).detach()
                test_loss = outer_criterion(test_logit, test_target)
                test_losses.append(test_loss.detach())
        losses = sum(test_losses) / len(tepoch)
        tepoch.set_postfix(loss=losses)
def test(step_idx, data, net, inner_opt_builder, n_inner_iter):
    """Main meta-testing step."""
    x_spt, y_spt, x_qry, y_qry = data
    task_num = x_spt.size()[0]
    querysz = x_qry.size(1)
    inner_opt = inner_opt_builder.inner_opt

    qry_losses = []
    for i in range(task_num):
        with higher.innerloop_ctx(
            net,
            inner_opt,
            track_higher_grads=False,
            override=inner_opt_builder.overrides,
        ) as (fnet, diffopt):
            for _ in range(n_inner_iter):
                spt_pred = fnet(x_spt[i])
                spt_loss = F.mse_loss(spt_pred, y_spt[i])
                diffopt.step(spt_loss)
            qry_pred = fnet(x_qry[i])
            qry_loss = F.mse_loss(qry_pred, y_qry[i])
            qry_losses.append(qry_loss.detach().cpu().numpy())
    avg_qry_loss = np.mean(qry_losses)
    _low, high = st.t.interval(
        0.95, len(qry_losses) - 1, loc=avg_qry_loss, scale=st.sem(qry_losses)
    )
    test_metrics = {"test_loss": avg_qry_loss, "test_err": high - avg_qry_loss}
    wandb.log(test_metrics, step=step_idx)
    return avg_qry_loss
def meta_forward(self, data):
    self.learner.train()
    self.optimizer.zero_grad()
    self.optimizer_meta.zero_grad()
    # copy_initial_weights should be False because the learner is regarded
    # as a parameter of the model.
    with higher.innerloop_ctx(self.model, self.optimizer,
                              copy_initial_weights=False) as (fmodel, diffopt):
        loss_dict = fmodel(data)
        # aggregate loss
        loss = sum([v for k, v in loss_dict.items() if 'loss_cls' in k])
        diffopt.step(loss)
        meta_data = next(self._meta_data_loader_iter)
        meta_loss_dict = fmodel(meta_data)
        # aggregate meta loss
        meta_loss = sum(
            [v for k, v in meta_loss_dict.items() if 'loss_sigmoid' in k])
        meta_loss.backward()
    self.optimizer_meta.step()
    self.learner.eval()
    self.log_meta_info()
def get_meta_grad(self, features, edges, labels, train_iters):
    model = self.gcn
    loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    with higher.innerloop_ctx(model, optimizer) as (fmodel, diffopt):
        for i in range(train_iters):
            pre = fmodel(features, edges)
            idx = select_index(labels, -1, same=False)
            pre, Y = pre[idx], labels[idx]
            cost = loss(pre, Y)
            diffopt.step(cost)
        pre = fmodel(features, edges)
        idx = select_index(labels, -1, same=False)
        sudo_idx = select_index(labels, -1, same=True)
        cost = 0
        if self.lambda_ > 0:
            cost += self.lambda_ * loss(pre[idx], labels[idx])
        if (1 - self.lambda_) > 0:
            cost += (1 - self.lambda_) * loss(
                pre[sudo_idx], self.labels_self_training[sudo_idx])
    return torch.autograd.grad(cost, self.adj_changes, retain_graph=False)[0]
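# `select_index` is referenced above but not defined in this section. A
# hypothetical sketch consistent with its usage (pick the indices whose label
# equals, or differs from, a sentinel value such as -1 for unlabeled nodes):
import torch

def select_index(labels, value, same=True):
    mask = labels == value if same else labels != value
    return torch.nonzero(mask, as_tuple=False).view(-1)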
def testSameInitialWeightsPostPatch(self):
    """Verify fast weight alignment/equality after monkey patching."""
    ref_named_params = list(self.reference_net.get_fast_weights().items())
    ref_params = [p for (_, p) in ref_named_params]
    with higher.innerloop_ctx(self.target_net, self.opt) as (fnet, _):
        target_named_params = list(fnet.named_parameters())
        target_params = fnet.parameters()
        self.assertEqual(
            len(ref_named_params),
            len(target_named_params),
            msg=("Length mismatched between reference net parameter count "
                 "({}) and target ({}).".format(
                     len(ref_named_params), len(target_named_params))))
        for ref, target in zip(ref_named_params, target_named_params):
            ref_name, ref_p = ref
            target_name, target_p = target
            self.assertEqual(
                ref_name,
                target_name,
                msg="Name mismatch or parameter misalignment ('{}' vs '{}')"
                .format(ref_name, target_name))
            self.assertTrue(
                torch.equal(ref_p, target_p),
                msg="Parameter value inequality for {}".format(ref_name))
        zipped = zip(ref_params, target_params)
        for i, (ref_p, target_p) in enumerate(zipped):
            self.assertTrue(
                torch.equal(ref_p, target_p),
                msg="Parameter misalignment in position {}.".format(i))
def compute_outer_grads(
        self,
        task: Tuple[torch.Tensor, ...],
        n_tasks: int,
        meta_train: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute gradients on the query set."""
    support_input, support_target, query_input, query_target = task

    with higher.innerloop_ctx(self.model, self.in_optim,
                              track_higher_grads=False) as (fmodel, diffopt):
        # inner loop (adapt)
        self.inner_loop(fmodel, diffopt, support_input, support_target)

        # evaluate on the query set
        with torch.set_grad_enabled(meta_train):
            query_output = fmodel(query_input)
            query_loss = self.loss_function(query_output, query_target)
            query_loss /= len(query_input)

        # compute gradients when in the meta-training stage
        if meta_train:
            # First-order update: the meta-gradient is approximated by the
            # difference between the initial and the adapted parameters.
            outer_grad = []
            for p, fast_p in zip(self.model.parameters(), fmodel.parameters()):
                outer_grad.append((p.data - fast_p.data) / n_tasks)
            self.grad_list.append(outer_grad)

    return query_output, query_loss
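# A hypothetical companion step (an assumption, not from the source): once
# `compute_outer_grads` has been called for every task in the meta-batch, the
# accumulated first-order gradients in `self.grad_list` could be summed into
# the meta-parameters' `.grad` fields and applied with an outer optimizer,
# here called `self.out_optim`:
def apply_outer_grads(self):
    self.out_optim.zero_grad()
    for p in self.model.parameters():
        p.grad = torch.zeros_like(p)
    for outer_grad in self.grad_list:
        for p, g in zip(self.model.parameters(), outer_grad):
            p.grad.add_(g)
    self.out_optim.step()
    self.grad_list.clear()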
def update_fn(
    self,
    images: torch.FloatTensor,
    labels: torch.FloatTensor,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
) -> None:
    """One optimization step."""
    self.task_weights.requires_grad_()  # make task weights require grad
    with higher.innerloop_ctx(model, optimizer) as (fmodel, doptimizer):
        task_losses = fmodel(images, labels)
        weighted_loss = torch.sum(task_losses * self.task_weights)
        doptimizer.step(weighted_loss)
        new_task_losses = fmodel(images, labels) * self.preference_weights
        normalized_losses = new_task_losses / torch.sum(new_task_losses)
        kl_divergence = torch.sum(
            normalized_losses * torch.log(normalized_losses * len(self.task_weights)))
        task_weight_grads = torch.autograd.grad(kl_divergence, self.task_weights)[0]

        # gradient step on task weights
        with torch.no_grad():
            self.task_weights = torch.clamp(
                self.task_weights - self.flags.meta_lr * task_weight_grads, min=0)
            self.task_weights = self.task_weights / torch.sum(self.task_weights)

        # compute gradients using new task weights
        new_weighted_loss = torch.sum(task_losses * self.task_weights)
        param_grads = torch.autograd.grad(new_weighted_loss, fmodel.parameters(time=0))

    optimizer.zero_grad()
    for index, param in enumerate(model.parameters()):
        param.grad = param_grads[index]
    optimizer.step()
def forward(self, target_repr, init):
    """
    Conceptually, DSPN simply turns the target_repr feature vector into a set.

    target_repr: Representation that the predicted set should match.
        FloatTensor of size (batch_size, repr_channels). This can come from
        a set processed with the same encoder as self.encoder (auto-encoder),
        or from a completely different input (normal supervised learning),
        such as an image encoded into a feature vector.
    """
    # copy the same initial set over the batch
    current_set = nn.Parameter(init)
    inner_set = InnerSet(current_set)
    # info used for loss computation
    intermediate_sets = [current_set]
    # info used for debugging
    repr_losses = []
    grad_norms = []
    # optimise repr_loss for a fixed number of steps
    with torch.enable_grad():
        opt = torch.optim.SGD(inner_set.parameters(), lr=self.lr, momentum=0.5)
        with higher.innerloop_ctx(inner_set, opt) as (fset, diffopt):
            for i in range(self.iters):
                predicted_repr = self.encoder(fset())
                # how well the representation matches the target
                repr_loss = ((predicted_repr - target_repr) ** 2).sum()
                diffopt.step(repr_loss)
                intermediate_sets.append(fset.mask)
                repr_losses.append(repr_loss)
                grad_norms.append(())
    return intermediate_sets, repr_losses, grad_norms
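# `InnerSet` is used above but not defined in this section. A hedged sketch
# of what it plausibly looks like: a module whose only parameter is the
# current set, so that higher can differentiate through the SGD updates
# applied to it. The `mask` attribute logged in the loop presumably tracks a
# mask component in the real implementation; here it simply aliases the set.
import torch.nn as nn

class InnerSet(nn.Module):
    def __init__(self, initial_set):
        super().__init__()
        self.current_set = initial_set  # an nn.Parameter

    def forward(self):
        return self.current_set

    @property
    def mask(self):
        return self.current_set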
def closure(model, optimizer, *args):
    """This function will be evaluated on all GPUs."""  # noqa: D401
    # Note: `inputs`, `targets`, `labels`, `intended_classes`, `criterion`
    # and `self` are closed over from the enclosing scope.
    # If the feature extractor is frozen, only put the final layer in
    # training mode.
    if model.frozen:
        list(model.children())[-1].train()
    else:
        model.train()
    batch_size = inputs.shape[0]
    data = torch.cat((inputs, targets), dim=0)

    # Wrap the model into a meta-object that allows for meta-learning steps
    # via monkeypatching:
    with higher.innerloop_ctx(model, optimizer,
                              copy_initial_weights=False) as (fmodel, fopt):
        for _ in range(self.args.nadapt):
            outputs = fmodel(data)
            poison_loss = criterion(outputs[:batch_size], labels)
            fopt.step(poison_loss)
        prediction = (outputs[:batch_size].data.argmax(dim=1) == labels).sum()
        target_loss = criterion(outputs[batch_size:], intended_classes)
        target_loss.backward(retain_graph=self.retain)
    return target_loss.detach().cpu(), prediction.detach().cpu()
def learn(self, timesteps, update_interval=None, track_higher_grads=False,
          lr_scheduler=None, step_callback=None, interval_callback=None,
          reward_aggregation='episodic'):
    if update_interval is None:
        update_interval = self.update_interval
    state = self.env.reset()
    memory = Memory()
    episodic_rewards = [0.]
    interval_rewards = [0.]
    # This context wraps the policy and optimizer to track parameter updates
    # over time such that d Params(time=t) / d Params(time=t-n) can be
    # calculated. If not tracking higher gradients, a dummy context is used
    # which does nothing.
    with innerloop_ctx(self.policy, self.optimizer,
                       track_higher_grads=track_higher_grads,
                       copy_initial_weights=False) as (policy, optimizer):
        for t in trange(1, int(timesteps) + 1, leave=False):
            # Running policy:
            memory.states.append(state)
            action, logprob = policy.predict(state)
            state, reward, done, info = self.env.step(action)
            episodic_rewards[-1] += reward
            interval_rewards[-1] += reward
            if done:
                state = self.env.reset()
                episodic_rewards.append(0.)
            memory.actions.append(action)
            memory.logprobs.append(logprob)
            memory.rewards.append(reward)
            memory.is_terminals.append(done)
            if step_callback is not None:
                step_callback(locals())

            # update if it's time
            if t % update_interval == 0:
                interval_rewards.append(0.)
                loss = self.update(policy, memory, self.epochs, optimizer,
                                   self.summary)
                if interval_callback is not None:
                    interval_callback(locals())
                if lr_scheduler is not None:
                    lr_scheduler()
                memory.clear()

        self.meta_policy = policy if track_higher_grads else None
        self.policy.load_state_dict(policy.state_dict())

    if reward_aggregation == 'episodic':
        return episodic_rewards[:-1 if len(episodic_rewards) > 1 else None]
    else:
        return interval_rewards
def evaluate(self, dataloader, updates, mini_batch_size):
    self.pn.train()
    support_set = []
    for _ in range(updates):
        text, label, candidates = self.memory.read_batch(batch_size=mini_batch_size)
        support_set.append((text, label, candidates))

    with higher.innerloop_ctx(self.pn, self.inner_optimizer,
                              copy_initial_weights=False,
                              track_higher_grads=False) as (fpn, diffopt):
        # Inner loop
        task_predictions, task_labels = [], []
        support_loss = []
        for text, label, candidates in support_set:
            replicated_text, replicated_relations, ranking_label = \
                datasets.utils.replicate_rel_data(text, label, candidates)
            input_dict = self.pn.encode_text(list(zip(replicated_text, replicated_relations)))
            output = fpn(input_dict)
            targets = torch.tensor(ranking_label).float().unsqueeze(1).to(self.device)
            loss = self.loss_fn(output, targets)
            diffopt.step(loss)
            pred, true_labels = models.utils.make_rel_prediction(output, ranking_label)
            support_loss.append(loss.item())
            task_predictions.extend(pred.tolist())
            task_labels.extend(true_labels.tolist())

        acc = models.utils.calculate_accuracy(task_predictions, task_labels)
        logger.info('Support set metrics: Loss = {:.4f}, accuracy = {:.4f}'.format(
            np.mean(support_loss), acc))

        all_losses, all_predictions, all_labels = [], [], []
        for text, label, candidates in dataloader:
            replicated_text, replicated_relations, ranking_label = \
                datasets.utils.replicate_rel_data(text, label, candidates)
            with torch.no_grad():
                input_dict = self.pn.encode_text(list(zip(replicated_text, replicated_relations)))
                output = fpn(input_dict)
                targets = torch.tensor(ranking_label).float().unsqueeze(1).to(self.device)
                loss = self.loss_fn(output, targets)
            loss = loss.item()
            pred, true_labels = models.utils.make_rel_prediction(output, ranking_label)
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(true_labels.tolist())

    acc = models.utils.calculate_accuracy(all_predictions, all_labels)
    logger.info('Test metrics: Loss = {:.4f}, accuracy = {:.4f}'.format(
        np.mean(all_losses), acc))
    return acc
def few_shot_testing(self, train_dataset, eval_dataset,
                     increment_counters=False, split="test"):
    """
    Allow the model to train on a small amount of datapoints at a time.

    After every training step, evaluate on many samples that haven't been
    seen yet. Results are saved in the learner's `metrics` attribute.

    Parameters
    ---
    train_dataset: Dataset
        Contains examples on which the model is trained before being evaluated
    eval_dataset: Dataset
        Contains examples on which the model is evaluated
    increment_counters: bool
        If True, update online metrics and current iteration counters.
    """
    self.logger.info(
        f"few shot testing on dataset {self.config.testing.eval_dataset} "
        f"with {len(train_dataset)} samples")
    train_dataloader, eval_dataloader = self.few_shot_preparation(
        train_dataset, eval_dataset, split=split)
    all_predictions, all_labels = [], []

    with higher.innerloop_ctx(self.pln, self.inner_optimizer,
                              copy_initial_weights=False,
                              track_higher_grads=False) as (fpln, diffopt):
        self.pln.train()
        self.rln.eval()
        # Inner loop
        for i, (text, labels, datasets) in enumerate(train_dataloader):
            labels = torch.tensor(labels).to(self.device)
            output = self.forward(text, labels, fpln)
            loss = self.loss_fn(output["logits"], labels)
            diffopt.step(loss)

            predictions = model_utils.make_prediction(output["logits"].detach())
            all_predictions.extend(predictions.tolist())
            all_labels.extend(labels.tolist())
            dataset_results = self.evaluate(dataloader=eval_dataloader,
                                            prediction_network=fpln)
            self.log_few_shot(all_predictions, all_labels, datasets,
                              dataset_results, increment_counters, text, i,
                              split=split)
            if (i * self.config.testing.few_shot_batch_size) % self.mini_batch_size == 0 and i > 0:
                all_predictions, all_labels = [], []
    self.few_shot_end()
def test_maml(db, net, device, lr_finetune, epoch, log):
    # Crucially, in our testing procedure here we do *not* fine-tune the
    # model during testing, for simplicity. Most research papers using MAML
    # for this task do an extra stage of fine-tuning here that should be
    # added if you are adapting this code for research.
    net.train()
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next('test')
        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=lr_finetune)
        for i in range(task_num):
            with higher.innerloop_ctx(net, inner_opt,
                                      track_higher_grads=False) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The query loss and acc induced by these parameters.
                qry_logits = fnet(x_qry[i]).detach()
                qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none')
                qry_losses.append(qry_loss.detach())
                qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
    print(
        f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
    )
    log.append({
        'epoch': epoch + 1,
        'loss': qry_losses,
        'acc': qry_accs,
        'mode': 'test',
        'time': time.time(),
    })
def outer_loop(self, batch, is_train):
    train_inputs, train_targets, test_inputs, test_targets = self.unpack_batch(batch)

    loss_log = 0
    acc_log = 0
    grad_list = []
    loss_list = []

    for (train_input, train_target, test_input, test_target) in zip(
            train_inputs, train_targets, test_inputs, test_targets):
        with higher.innerloop_ctx(self.network, self.inner_optimizer,
                                  track_higher_grads=False) as (fmodel, diffopt):
            for step in range(self.args.n_inner):
                self.inner_loop(fmodel, diffopt, train_input, train_target)

            train_logit = fmodel(train_input)
            in_loss = F.cross_entropy(train_logit, train_target)

            test_logit = fmodel(test_input)
            outer_loss = F.cross_entropy(test_logit, test_target)
            loss_log += outer_loss.item() / self.batch_size

            with torch.no_grad():
                acc_log += get_accuracy(test_logit, test_target).item() / self.batch_size

            if is_train:
                params = list(fmodel.parameters(time=-1))
                in_grad = torch.nn.utils.parameters_to_vector(
                    torch.autograd.grad(in_loss, params, create_graph=True))
                outer_grad = torch.nn.utils.parameters_to_vector(
                    torch.autograd.grad(outer_loss, params))
                implicit_grad = self.neumann_approx(in_grad, outer_grad, params)
                grad_list.append(implicit_grad)
                loss_list.append(outer_loss.item())

    if is_train:
        self.outer_optimizer.zero_grad()
        weight = torch.ones(len(grad_list))
        weight = weight / torch.sum(weight)
        grad = mix_grad(grad_list, weight)
        grad_log = apply_grad(self.network, grad)
        self.outer_optimizer.step()
        return loss_log, acc_log, grad_log
    else:
        return loss_log, acc_log
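# `mix_grad` and `apply_grad` are referenced above but not defined in this
# section. Hedged sketches (assumptions; the originals may differ), written
# for the flattened convention used in the implicit-gradient variant above,
# where each entry of `grad_list` is a single vector from
# `parameters_to_vector`:
import torch

def mix_grad(grad_list, weight):
    """Weighted combination of flattened per-task gradients."""
    return sum(w * g for w, g in zip(weight, grad_list))

def apply_grad(network, grad):
    """Unflatten `grad` into the parameters' .grad fields; return its norm."""
    offset = 0
    for p in network.parameters():
        numel = p.numel()
        p.grad = grad[offset:offset + numel].view_as(p).detach().clone()
        offset += numel
    return grad.norm().item()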
def train_on_loader(self, loader):
    # For now, one task at a time.
    task_num = 1
    qry_losses = []
    qry_accs = []
    self.backbone.train()
    for episode in tqdm(loader):
        episode = episode[0]  # undo collate
        self.optimizer.zero_grad()
        support_set = episode["support_set"].cuda(non_blocking=False)
        query_set = episode["query_set"].cuda(non_blocking=False)
        ss, nclasses, c, h, w = support_set.size()
        qs, nclasses, c, h, w = query_set.size()
        absolute_labels = episode["targets"]
        relative_labels = absolute_labels.clone()
        # TODO: use episode['targets']
        support_relative_labels = torch.arange(episode['nclasses']).view(1, -1).repeat(
            episode['support_size'], 1).cuda().view(-1)
        query_relative_labels = torch.arange(episode['nclasses']).view(1, -1).repeat(
            episode['query_size'], 1).cuda().view(-1)
        inner_opt = torch.optim.SGD(self.backbone.parameters(), lr=self.inner_lr)
        querysz = query_relative_labels.size()[0]
        for i in range(task_num):
            with higher.innerloop_ctx(self.backbone, inner_opt,
                                      copy_initial_weights=False) as (fnet, diffopt):
                for _ in range(self.n_inner_iter):
                    spt_logits = fnet(support_set.view(ss * nclasses, c, h, w)).view(
                        ss * nclasses, -1)
                    spt_loss = F.cross_entropy(spt_logits, support_relative_labels)
                    diffopt.step(spt_loss)
                qry_logits = fnet(query_set.view(qs * nclasses, c, h, w))
                qry_loss = F.cross_entropy(qry_logits, query_relative_labels)
                qry_losses.append(qry_loss.detach())
                qry_acc = (qry_logits.argmax(dim=1) == query_relative_labels).sum().item() / querysz
                qry_accs.append(qry_acc)
                qry_loss.backward()
        self.optimizer.step()

    qry_losses = sum(qry_losses) / len(qry_losses)
    qry_accs = 100. * sum(qry_accs) / len(qry_accs)
    return {"train_loss": qry_losses.item(), "train_acc": qry_accs}
def meta_train_mbrl_reacher(policy, ml3_loss, dmodel, env, task_loss_fn, goals,
                            n_outer_iter, n_inner_iter, time_horizon, exp_folder):
    goals = torch.Tensor(goals)
    meta_opt = torch.optim.Adam(ml3_loss.parameters(), lr=ml3_loss.learning_rate)

    for outer_i in range(n_outer_iter):
        # set gradient with respect to meta loss parameters to 0
        meta_opt.zero_grad()
        all_loss = 0
        for goal in goals:
            goal = torch.Tensor(goal)
            policy.reset()
            inner_opt = torch.optim.SGD(policy.parameters(), lr=policy.learning_rate)
            for _ in range(n_inner_iter):
                inner_opt.zero_grad()
                with higher.innerloop_ctx(
                        policy, inner_opt,
                        copy_initial_weights=False) as (fpolicy, diffopt):
                    # use the current meta loss to update the model
                    s_tr, a_tr, g_tr = fpolicy.roll_out(goal, time_horizon, dmodel, env)
                    meta_input = torch.cat(
                        [s_tr[:-1].detach(), a_tr, g_tr.detach()], dim=1)
                    pred_task_loss = ml3_loss(meta_input).mean()
                    diffopt.step(pred_task_loss)

                    # compute task loss
                    s, a, g = fpolicy.roll_out(goal, time_horizon, dmodel, env)
                    task_loss = task_loss_fn(a, s[:], goal).mean()
                    # collect losses for logging
                    all_loss += task_loss
                    # backprop grad wrt task loss
                    task_loss.backward()

                    if outer_i % 100 == 0:
                        # roll out in the real environment, to monitor training
                        # and to collect data for the dynamics model update
                        states, actions, _ = fpolicy.roll_out(
                            goal, time_horizon, dmodel, env, real_rollout=True)
                        print("meta iter: {} loss: {}".format(
                            outer_i, torch.mean((states[-1, :2] - goal[:2]) ** 2)))
                        if outer_i % 300 == 0 and outer_i < 3001:
                            # update the dynamics model under the current policy
                            dmodel.train(torch.Tensor(states), torch.Tensor(actions))

        # step the optimizer to update the meta loss network
        meta_opt.step()
    torch.save(ml3_loss.state_dict(), f'{exp_folder}/ml3_loss_reacher.pt')
def tmp():
    # Scratch snippet: `student`, `teacher`, `step`, the data tensors, the
    # loss functions and `supervised_teacher_update_freq` come from the
    # enclosing scope.
    teacher_optimizer = ...
    student_optimizer = ...

    with higher.innerloop_ctx(student, student_optimizer) as (smodel, sdiffopt):
        student.train()
        teacher.train()

        teacher_logits = teacher(unsupervised_batch)
        student_logits = smodel(unsupervised_batch)
        distillation_loss = soft_ce(
            torch.log_softmax(student_logits, dim=1),
            torch.softmax(teacher_logits, dim=1),
        )
        print("Distillation loss:", distillation_loss.item())
        sdiffopt.step(distillation_loss)

        student_logits = smodel(student_data)
        student_logits.squeeze_(dim=1)
        student_loss = ce(student_logits, student_labels)
        print("Student loss:", student_loss.item())
        student_loss.backward()

        print("Teacher grad: {} +- {}".format(teacher.fc1.weight.grad.mean(),
                                              teacher.fc1.weight.grad.std()))
        teacher_optimizer.step()

        if step % supervised_teacher_update_freq == 0:
            supervise_teacher(teacher_data, teacher_labels)

        clear_output(wait=True)
        with torch.no_grad():
            student.eval()
            teacher.eval()
            plot_predictions(
                (
                    unsupervised_data.numpy(),
                    student(unsupervised_data).numpy().argmax(1),
                    "Student",
                ),
                (
                    unsupervised_data.numpy(),
                    teacher(unsupervised_data).numpy().argmax(1),
                    "Teacher",
                ),
            )
        with torch.no_grad():
            for old_p, new_p in zip(student.parameters(), smodel.parameters()):
                old_p.copy_(new_p)
def testGradientCorrectness(self, _, model_builder, opt_builder, kwargs=None):
    kwargs = {} if kwargs is None else kwargs
    lr = .1
    model = model_builder(self)
    eps = 1e-3
    tests = 10
    count = 0
    threshold = .6  # proportion of tests that should pass

    for _ in range(tests):
        xs = [torch.rand(10, 4) for _ in range(2)]

        def closure():
            cmodel = copy.deepcopy(model)
            opt = opt_builder(cmodel.parameters(), lr=lr, **kwargs)
            for x in xs[:-1]:
                opt.zero_grad()
                cmodel(x).pow(2).sum().backward()
                opt.step()
            loss = cmodel(xs[-1]).pow(2).sum()
            return loss

        fd_grads = finite_difference(model, closure, eps)

        opt = opt_builder(model.parameters(), lr=lr, **kwargs)
        with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
            for x in xs[:-1]:
                loss = fmodel(x).pow(2).sum()
                diffopt.step(loss)
            loss = fmodel(xs[-1]).pow(2).sum()
            grads = torch.autograd.grad(loss, fmodel.parameters(time=0),
                                        allow_unused=True)

        close = []
        for g, fg in zip(grads, fd_grads):
            if g is None:
                # trusting that the tensor shouldn't have been used...
                close.append(True)
            else:
                self.assertFalse(torch.any(torch.isnan(g)),
                                 "NaNs found in gradient.")
                close.append(torch.allclose(g, fg, 1e-1, 1e-1))
        if all(close):
            count += 1

    self.assertTrue(
        count / tests >= threshold,
        msg="Proportion of successful finite gradient checks below {:.0f}% "
        "threshold ({:.0f}%).".format(threshold * 100, 100 * count / tests))
def test_episodic_loader_inner_loop_per_task_good_accumulator(debug_test=True):
    import torch
    import torch.optim as optim
    from automl.child_models.learner_from_opt_as_few_shot_paper import Learner
    from automl.child_models.learner_from_opt_as_few_shot_paper import higher

    ## get args for test
    args = get_args_for_mini_imagenet()
    ## get the base model that meta-lstm/maml use
    base_model = Learner(image_size=args.image_size, bn_eps=args.bn_eps,
                         bn_momentum=args.bn_momentum,
                         n_classes=args.n_classes).to(args.device)
    ## get meta-set
    meta_train_loader, _, _ = get_meta_set_loaders_miniImagenet(args)
    ## start episodic training
    meta_params = base_model.parameters()
    outer_opt = optim.Adam(meta_params, lr=1e-2)
    base_model.train()
    for episode, (spt_x, spt_y, qry_x, qry_y) in enumerate(meta_train_loader):
        assert spt_x.size(1) == args.k_shot * args.n_classes
        assert qry_x.size(1) == args.k_eval * args.n_classes
        ## Get inner optimizer (for maml)
        inner_opt = torch.optim.SGD(base_model.parameters(), lr=1e-1)
        ## Accumulate gradient of meta-loss wrt fmodel.param(t=0)
        nb_tasks = spt_x.size(0)
        meta_losses, meta_accs = [], []
        assert nb_tasks == args.meta_batch_size_train
        for t in range(nb_tasks):
            ## Get support & query set for the current task
            spt_x_t, spt_y_t, qry_x_t, qry_y_t = spt_x[t], spt_y[t], qry_x[t], qry_y[t]
            ## Inner-loop adaptation
            with higher.innerloop_ctx(base_model, inner_opt,
                                      copy_initial_weights=args.copy_initial_weights,
                                      track_higher_grads=args.track_higher_grads) as (fmodel, diffopt):
                for i_inner in range(args.nb_inner_train_steps):
                    fmodel.train()
                    # base/child model forward pass
                    spt_logits_t = fmodel(spt_x_t)
                    inner_loss = args.criterion(spt_logits_t, spt_y_t)
                    # inner-opt update
                    diffopt.step(inner_loss)
                ## Evaluate on the query set for the current task
                qry_logits_t = fmodel(qry_x_t)
                qry_loss_t = args.criterion(qry_logits_t, qry_y_t)
                ## Accumulate gradients wrt meta-params for each task
                qry_loss_t.backward()  # note: this is memory efficient
                ## collect losses & accs for logging/debugging
                meta_losses.append(qry_loss_t.detach())  # drop history so it is memory efficient and stats can be printed
        ## do outer step
        outer_opt.step()
        outer_opt.zero_grad()
        print(f'[episode={episode}] meta_loss = {sum(meta_losses)/len(meta_losses)}')
def maml_train(self, model, inner_batch, outer_batches):
    assert model.training
    ret_dic = {}
    with higher.innerloop_ctx(
            model, self.inner_opt, copy_initial_weights=False,
            device=self.device) as (fmodel, diffopt), \
            torch.backends.cudnn.flags(enabled=False):
        for _step in range(self.inner_steps):
            inner_ret_dic = fmodel(inner_batch)
            inner_loss = inner_ret_dic["loss"]
            # use this snippet for sanity-checking higher:
            # def test(params):
            #     params = [p for p in params if p.requires_grad]
            #     all_grads = torch.autograd.grad(
            #         loss,
            #         params,
            #         retain_graph=True,
            #         allow_unused=True,
            #     )
            #     print(len(params), sum(p is not None for p in all_grads))
            # import pdb; pdb.set_trace()
            # test(model.parameters())
            # test(fmodel.fast_params)
            diffopt.step(inner_loss)
        logger.info(f"Inner loss: {inner_loss.item()}")

        mean_outer_loss = torch.Tensor([0.0]).to(self.device)
        with torch.set_grad_enabled(model.training):
            for batch_id, outer_batch in enumerate(outer_batches):
                outer_ret_dic = fmodel(outer_batch)
                mean_outer_loss += outer_ret_dic["loss"]
        mean_outer_loss.div_(len(outer_batches))
        logger.info(f"Outer loss: {mean_outer_loss.item()}")

        final_loss = inner_loss + mean_outer_loss
        final_loss.backward()

    # not sure if it helps
    del fmodel
    import gc
    gc.collect()

    ret_dic["loss"] = final_loss.item()
    return ret_dic
def train(dataset: AudioNShot, net: TorchModule, device,
          meta_optimizer: Optimizer, epoch_num, log):
    net.train()
    iterations = dataset.x_train.shape[0]  # batch size

    for batch_index in range(iterations):
        start_time = time.time()
        # Get support and query sets
        x_support, y_support, x_query, y_query = dataset.next()
        task_num, set_size, c, sample_size = x_support.size()
        query_size = x_query.size(1)

        # Set inner optimizer
        inner_iterations = INNER_ITERATIONS
        optimizer = SGD(net.parameters(), lr=1e-1)

        query_losses = []
        query_accuracies = []
        meta_optimizer.zero_grad()
        for i in range(task_num):
            with higher.innerloop_ctx(net, optimizer,
                                      copy_initial_weights=False) as (fnet, diffopt):
                for _ in range(inner_iterations):
                    support_outputs = fnet(x_support[i])
                    support_loss = F.cross_entropy(support_outputs, y_support[i])
                    diffopt.step(support_loss)

                query_outputs = fnet(x_query[i])
                query_loss = F.cross_entropy(query_outputs, y_query[i])
                query_losses.append(query_loss.detach())
                query_accuracy = (query_outputs.argmax(dim=1) == y_query[i]).sum().item() / query_size
                query_accuracies.append(query_accuracy)
                query_loss.backward()
        meta_optimizer.step()

        query_losses = sum(query_losses) / task_num
        query_accuracies = 100. * sum(query_accuracies) / task_num
        i = epoch_num + float(batch_index) / iterations
        iteration_time = time.time() - start_time
        if batch_index % 4 == 0:
            print(f'[Epoch {i:.2f}] Train Loss: {query_losses:.2f} | '
                  f'Acc: {query_accuracies:.2f} | Time: {iteration_time:.2f}')
        log.append({
            'epoch': i,
            'loss': query_losses,
            'acc': query_accuracies,
            'mode': 'train',
            'time': time.time(),
        })
def train_on_batch(self, tasks_batch):
    # sprt should never intersect with qry! So only shuffle the task
    # at creation!
    # For each task in the batch
    inner_losses, meta_losses, accuracies = [], [], []
    self.meta_opt.zero_grad()
    for i, task in enumerate(tasks_batch):
        with higher.innerloop_ctx(
            self.learner, self.inner_opt, copy_initial_weights=False
        ) as (f_learner, diff_opt):
            meta_loss, inner_loss, task_accuracies = 0, 0, []
            sprt, qry = task
            f_learner.train()
            for s in range(self.inner_steps):
                step_loss = 0
                for x, y in sprt:  # sprt is an iterator returning batches
                    y_pred = f_learner(x)
                    step_loss += self.inner_loss(y_pred, y)
                inner_loss += step_loss.detach()
                diff_opt.step(step_loss)

            f_learner.eval()
            for x, y in qry:
                y_pred = f_learner(x)  # Use the updated model for that task
                # Accumulate the loss over all tasks in the meta-testing set
                meta_loss += self.meta_loss(y_pred, y)
                if self._compute_accuracy:
                    scores, indices = y_pred.max(dim=1)
                    acc = (y == indices).sum() / y.size(0)  # Mean accuracy per batch
                    task_accuracies.append(acc)

            # Divide by the number of samples because reduction is set to
            # 'sum' so that the meta-objective can be computed correctly.
            meta_losses.append(meta_loss.detach().div_(self.inner_steps * len(sprt.dataset)))
            inner_losses.append(inner_loss.mean().div_(len(qry.dataset)))
            if self._compute_accuracy:
                accuracies.append(torch.tensor(task_accuracies).mean())

            # Update the model's meta-parameters to optimize the query
            # losses across all of the tasks sampled in this batch.
            # This unrolls through the gradient steps.
            meta_loss.backward()
    self.meta_opt.step()

    avg_inner_loss = torch.tensor(inner_losses).mean().item()
    avg_meta_loss = torch.tensor(meta_losses).mean().item()
    avg_accuracy = torch.tensor(accuracies).mean().item()
    return avg_inner_loss, avg_meta_loss, avg_accuracy
def testCtxManager(self, _, model_builder):
    model = model_builder(self)
    opt = torch.optim.SGD(model.parameters(), lr=self.lr)

    with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
        for _ in range(10):
            inputs = torch.rand(8, 4)
            loss = fmodel(inputs).pow(2).sum()
            diffopt.step(loss)
        param_sum = sum(p.sum() for p in fmodel.parameters())
        final_grads = torch.autograd.grad(param_sum, fmodel.parameters(time=0))

    for grad in final_grads:
        self.assertIsNotNone(grad)
def outer_loop(self, batch, is_train):
    self.network.zero_grad()
    train_inputs, train_targets, test_inputs, test_targets = self.unpack_batch(batch)

    loss_log = 0
    acc_log = 0
    grad_list = []
    loss_list = []

    for (train_input, train_target, test_input, test_target) in zip(
            train_inputs, train_targets, test_inputs, test_targets):
        with higher.innerloop_ctx(
                self.network, self.inner_optimizer,
                track_higher_grads=is_train) as (fmodel, diffopt):
            for step in range(self.args.n_inner):
                self.inner_loop(fmodel, diffopt, train_input, train_target)

            test_logit = fmodel(test_input)
            outer_loss = F.cross_entropy(test_logit, test_target)
            loss_log += outer_loss.item() / self.batch_size

            with torch.no_grad():
                acc_log += get_accuracy(test_logit, test_target).item() / self.batch_size

            if is_train:
                outer_grad = torch.autograd.grad(outer_loss, fmodel.parameters(time=0))
                grad_list.append(outer_grad)
                loss_list.append(outer_loss.item())

    if is_train:
        weight = torch.ones(len(grad_list))
        weight = weight / torch.sum(weight)
        grad = mix_grad(grad_list, weight)
        grad_log = apply_grad(self.network, grad)
        self.outer_optimizer.step()
        return loss_log, acc_log, grad_log
    else:
        return loss_log, acc_log