from torch import optim

from leap import Leap  # import path assumed; adjust to the repo layout

# `BaseWrapper` is assumed to be defined elsewhere in this module: it owns
# the inner-loop optimizer and calls `_partial_meta_update` after every
# inner step and `_final_meta_update` at the end of a meta-batch.


class LeapWrapper(BaseWrapper):

    """Wrapper around the Leap meta-learner.

    Arguments:
        model (nn.Module): classifier.
        optimizer_cls: optimizer class.
        meta_optimizer_cls: meta optimizer class.
        optimizer_kwargs (dict): kwargs to pass to optimizer upon
            construction.
        meta_optimizer_kwargs (dict): kwargs to pass to meta optimizer upon
            construction.
        meta_kwargs (dict): kwargs to pass to meta-learner upon construction.
        criterion (func): loss criterion to use.
    """

    def __init__(self, model, optimizer_cls, meta_optimizer_cls,
                 optimizer_kwargs, meta_optimizer_kwargs, meta_kwargs,
                 criterion):
        super(LeapWrapper, self).__init__(criterion, model, optimizer_cls,
                                          optimizer_kwargs)
        self.meta = Leap(model, **meta_kwargs)
        self.meta_optimizer_cls = \
            optim.SGD if meta_optimizer_cls.lower() == 'sgd' else optim.Adam
        self.meta_optimizer = self.meta_optimizer_cls(
            self.meta.parameters(), **meta_optimizer_kwargs)

    def _partial_meta_update(self, loss, final):
        # Record the current (loss, parameter) point on the task's descent
        # trajectory; Leap accumulates its meta-gradient from these points.
        self.meta.update(loss, self.model)

    def _final_meta_update(self):
        # Normalize the accumulated meta-gradient over tasks, then take one
        # step on the shared initialization.
        self.meta.normalize()
        self.meta_optimizer.step()
        self.meta_optimizer.zero_grad()

    def run_task(self, task, train, meta_train):
        if meta_train:
            # Start a fresh trajectory for this task.
            self.meta.init_task()
        if train:
            # Load the current meta-initialization into the model.
            self.meta.to(self.model)
        return super(LeapWrapper, self).run_task(task, train, meta_train)
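# A minimal construction sketch for the wrapper above. Everything here is
# illustrative: the toy model and all kwargs values are placeholders, not
# values taken from this repo. The super() call above shows BaseWrapper's
# constructor signature; `meta_optimizer_cls` is matched as a string, with
# 'sgd' mapping to optim.SGD and anything else to optim.Adam.

import torch
import torch.nn as nn

wrapper = LeapWrapper(
    model=nn.Linear(10, 2),              # stand-in classifier
    optimizer_cls=torch.optim.SGD,       # inner-loop optimizer class
    meta_optimizer_cls='sgd',            # string, per the constructor above
    optimizer_kwargs={'lr': 0.1},
    meta_optimizer_kwargs={'lr': 0.01},
    meta_kwargs={},
    criterion=nn.functional.cross_entropy,
)
# BaseWrapper.run_task is assumed to drive the inner loop, calling
# _partial_meta_update after each step; _final_meta_update then applies one
# meta-gradient step to the shared initialization:
#
#     wrapper.run_task(task, train=True, meta_train=True)   # per task
#     wrapper._final_meta_update()                          # per meta-batch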
import os

import numpy as np
import torch
import torch.nn as nn

from leap import Leap  # import path assumed; adjust to the repo layout


def leap_adapt(model, source_corpus, target_corpus, char2idx, args, device,
               lang_model_n_words=0):  # lang_model_n_words unused here
    """Meta-train `model` with Leap on the source and target corpora."""
    model = model.to(device)
    leap = Leap(model)
    meta_optimizer = torch.optim.Adam(leap.parameters(),
                                      lr=args.leap_meta_lr_init)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        meta_optimizer, factor=args.lr_decay, patience=args.patience,
        threshold=args.threshold)
    # Scores are mean negative cosine similarities in [-1, 1], so 3 is a
    # safe upper bound for initializing the best score.
    best_score = 3
    for meta_epoch in range(args.n_meta_epochs):
        source_valid_cosine = []
        target_valid_cosine = []
        model.train()
        for meta_batch in range(args.n_meta_batch):
            meta_optimizer.zero_grad()

            # Source task: run the inner loop from the current
            # meta-initialization, letting Leap record the trajectory.
            leap.init_task()
            leap.to(model)
            inner_optimizer = torch.optim.Adam(model.parameters(),
                                               lr=args.leap_inner_lr_init)
            for inner_batch in range(args.n_task_steps):
                inner_optimizer.zero_grad()
                source_train_contexts, source_train_targets, source_train_vocabs = \
                    source_corpus.get_batch(
                        args.meta_batch_size, args.n_shot, char2idx, device,
                        fixed=args.fixed_shot)
                pred_emb = model(source_train_contexts, source_train_vocabs)
                loss = -nn.functional.cosine_similarity(
                    pred_emb, source_train_targets).mean()
                loss.backward()
                # Called after backward() and before step(): Leap reads the
                # current gradients to accumulate its meta-gradient.
                leap.update(loss, model)
                inner_optimizer.step()

            # Target task: same inner loop on the target corpus.
            leap.init_task()
            leap.to(model)
            inner_optimizer = torch.optim.Adam(model.parameters(),
                                               lr=args.leap_inner_lr_init)
            for inner_batch in range(args.n_task_steps):
                inner_optimizer.zero_grad()
                target_train_contexts, target_train_targets, target_train_vocabs = \
                    target_corpus.get_batch(
                        args.meta_batch_size, args.n_shot, char2idx, device,
                        fixed=args.fixed_shot,
                        repeat_ctxs=args.meta_repeat_ctxs)
                pred_emb = model(target_train_contexts, target_train_vocabs)
                loss = -nn.functional.cosine_similarity(
                    pred_emb, target_train_targets).mean()
                loss.backward()
                leap.update(loss, model)
                inner_optimizer.step()

            # One meta-step on the shared initialization.
            leap.normalize()
            meta_optimizer.step()

        # Validate from the (updated) meta-initialization.
        leap.to(model)
        model.eval()
        with torch.no_grad():
            for batch in range(args.n_batch):
                source_valid_contexts, source_valid_targets, source_valid_vocabs = \
                    source_corpus.get_batch(
                        args.meta_batch_size, args.n_shot, char2idx, device,
                        use_valid=True, fixed=args.fixed_shot)
                pred_emb = model(source_valid_contexts, source_valid_vocabs)
                loss = -nn.functional.cosine_similarity(
                    pred_emb, source_valid_targets).mean()
                source_valid_cosine.append(loss.cpu().numpy())

                target_valid_contexts, target_valid_targets, target_valid_vocabs = \
                    target_corpus.get_batch(
                        args.meta_batch_size, args.n_shot, char2idx, device,
                        use_valid=True, fixed=args.fixed_shot,
                        repeat_ctxs=args.meta_repeat_ctxs)
                pred_emb = model(target_valid_contexts, target_valid_vocabs)
                loss = -nn.functional.cosine_similarity(
                    pred_emb, target_valid_targets).mean()
                target_valid_cosine.append(loss.cpu().numpy())

        avg_source_valid = np.average(source_valid_cosine)
        avg_target_valid = np.average(target_valid_cosine)
        # Model selection and LR scheduling track the target-domain loss.
        score = avg_target_valid
        lr_scheduler.step(score)
        print(f"Average source cosine loss: {avg_source_valid}; "
              f"Average target cosine loss: {avg_target_valid}")
        if score < best_score:
            best_score = score
            torch.save(model.state_dict(),
                       os.path.join(args.save_dir, 'leap_model.pt'))
        if meta_optimizer.param_groups[0]['lr'] < args.leap_lr_early_stop:
            print('LR early stop')
            break
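# For reference, a sketch of the argparse namespace `leap_adapt` expects.
# The attribute names are exactly those read by the function above; every
# default value is an illustrative assumption, not one taken from this repo.

import argparse

parser = argparse.ArgumentParser(description='Leap adaptation (sketch)')
parser.add_argument('--leap_meta_lr_init', type=float, default=1e-3,
                    help='initial LR of the Adam meta-optimizer')
parser.add_argument('--leap_inner_lr_init', type=float, default=1e-3,
                    help='initial LR of the per-task Adam optimizer')
parser.add_argument('--lr_decay', type=float, default=0.5,
                    help='ReduceLROnPlateau factor')
parser.add_argument('--patience', type=int, default=2)
parser.add_argument('--threshold', type=float, default=1e-4)
parser.add_argument('--n_meta_epochs', type=int, default=50)
parser.add_argument('--n_meta_batch', type=int, default=100,
                    help='meta-batches per meta-epoch')
parser.add_argument('--n_task_steps', type=int, default=5,
                    help='inner steps per task')
parser.add_argument('--meta_batch_size', type=int, default=64)
parser.add_argument('--n_shot', type=int, default=4)
parser.add_argument('--n_batch', type=int, default=10,
                    help='validation batches per meta-epoch')
parser.add_argument('--fixed_shot', action='store_true')
parser.add_argument('--meta_repeat_ctxs', action='store_true')
parser.add_argument('--save_dir', type=str, default='checkpoints')
parser.add_argument('--leap_lr_early_stop', type=float, default=1e-5,
                    help='stop once the meta LR decays below this')
args = parser.parse_args()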