def __init__(self, args, model, optimizer_cls, meta_optimizer_cls,
             optimizer_kwargs, meta_optimizer_kwargs,
             optimizer=None, scheduler=None):
    super(_FOWrapper, self).__init__(args, model, optimizer_cls, optimizer_kwargs)

    # Resolve the meta-optimizer class from its string name.
    self.meta_optimizer_cls = optim.SGD if meta_optimizer_cls.lower() == 'sgd' else optim.Adam
    self.meta_optimizer_kwargs = meta_optimizer_kwargs
    self._counter = 0
    self._updates = None

    # Snapshot the initialization; keep_vars=True retains the Parameters
    # (and their requires_grad flags) rather than detached tensors.
    self._original = clone_state_dict(self.model.state_dict(keep_vars=True))

    # The meta-optimizer steps on the snapshot's trainable parameters.
    params = [p for p in self._original.values() if getattr(p, 'requires_grad', False)]
    if optimizer is not None:
        self.meta_optimizer = optimizer
    else:
        self.meta_optimizer = self.meta_optimizer_cls(params, **meta_optimizer_kwargs)
    self.scheduler = scheduler
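# Minimal sketch of the clone_state_dict helper assumed above (the real one
# lives elsewhere in the repo): it deep-copies every tensor in a state dict so
# later in-place updates to the live model cannot mutate the snapshot. The
# detach/clone/requires_grad_ chain is an assumption about how the utility
# preserves leaf status for the meta-optimizer.
from collections import OrderedDict

import torch


def clone_state_dict(state_dict):
    """Return an order-preserving deep copy of a model state dict."""
    cloned = OrderedDict()
    for name, tensor in state_dict.items():
        if isinstance(tensor, torch.Tensor):
            copy = tensor.detach().clone()
            # Keep grad tracking so the clones can serve as meta-parameters.
            copy.requires_grad_(tensor.requires_grad)
            cloned[name] = copy
        else:
            cloned[name] = tensor
    return cloned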
def run_tasks(self, tasks, meta_train):
    original = None
    if not meta_train:
        # Snapshot the current weights and reset any running statistics
        # (e.g. BatchNorm buffers) accumulated on earlier tasks.
        # Non-transductive task evaluation for fair comparison.
        original = clone_state_dict(self.model.state_dict(keep_vars=True))
        for module in self.model.modules():
            if hasattr(module, 'reset_running_stats'):
                module.reset_running_stats()

    # Training #
    all_batches = self.gen_multitask_batches(tasks, train=True)
    trainres = self.run_multitask(all_batches, train=True)

    # Eval #
    all_batches = self.gen_multitask_batches(tasks, train=False)
    valres = self.run_multitask(all_batches, train=False)

    # Pair each task's training results with its evaluation results.
    results = AggRes(zip(trainres, valres))

    if not meta_train:
        # Restore the pre-task weights after evaluation.
        self.model.load_state_dict(original)

    return results
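# Illustrative sketch, not part of the wrapper: BatchNorm layers are the
# typical modules caught by the hasattr(module, 'reset_running_stats') check
# in run_tasks. Resetting them discards statistics accumulated on earlier
# tasks, which is what makes the evaluation above non-transductive.
import torch
import torch.nn as nn


def _demo_reset_running_stats():
    bn = nn.BatchNorm1d(8)
    bn.running_mean.add_(1.0)   # pretend statistics from an earlier task
    bn.reset_running_stats()    # running_mean back to its initial zeros
    assert torch.allclose(bn.running_mean, torch.zeros(8))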
def __init__(self, model, optimizer_cls, optimizer_kwargs, criterion):
    super(NoWrapper, self).__init__(criterion, model, optimizer_cls, optimizer_kwargs)
    self._original = clone_state_dict(model.state_dict(keep_vars=True))
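# Hedged usage sketch for the baseline wrapper above. The model, criterion,
# and hyperparameters are placeholders, and passing a torch.optim class for
# optimizer_cls (rather than a string name) is an assumption about what the
# base wrapper accepts.
import torch.nn as nn
from torch import optim

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 5))
wrapper = NoWrapper(
    model=model,
    optimizer_cls=optim.SGD,           # assumed: a class, not the string 'sgd'
    optimizer_kwargs={'lr': 0.1},      # illustrative hyperparameters
    criterion=nn.CrossEntropyLoss(),
)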