Exemplo n.º 1
0
    def __init__(self, args, model, optimizer_cls, meta_optimizer_cls, optimizer_kwargs,
                 meta_optimizer_kwargs, optimizer=None, scheduler=None):
        """Initialise the wrapper and build (or adopt) its meta-optimizer.

        Args:
            args, model, optimizer_cls, optimizer_kwargs: forwarded unchanged
                to the base class ``__init__``.
            meta_optimizer_cls: meta-optimizer name. ``'sgd'`` (compared
                case-insensitively) selects ``optim.SGD``; any other string
                silently selects ``optim.Adam``.
            meta_optimizer_kwargs: keyword arguments used to construct the
                meta-optimizer; also kept on ``self.meta_optimizer_kwargs``.
            optimizer: a pre-built meta-optimizer. When truthy it is used
                as-is instead of constructing a new one.
            scheduler: stored unchanged on ``self.scheduler``.
        """
        super(_FOWrapper, self).__init__(args, model, optimizer_cls, optimizer_kwargs)

        # NOTE(review): anything other than 'sgd' (e.g. a typo) falls back to Adam.
        if meta_optimizer_cls.lower() == 'sgd':
            self.meta_optimizer_cls = optim.SGD
        else:
            self.meta_optimizer_cls = optim.Adam
        self.meta_optimizer_kwargs = meta_optimizer_kwargs

        self._counter = 0
        self._updates = None
        # Snapshot of the model state; keep_vars=True keeps the tensors as
        # autograd variables (Parameters) instead of detached data.
        # self.model is presumably assigned by the base-class __init__.
        self._original = clone_state_dict(self.model.state_dict(keep_vars=True))

        # Only entries that require gradients are optimised by the meta-optimizer;
        # they come from the snapshot above, not from the live model.
        trainable = [
            tensor for tensor in self._original.values()
            if getattr(tensor, 'requires_grad', False)
        ]
        self.meta_optimizer = (
            optimizer if optimizer
            else self.meta_optimizer_cls(trainable, **meta_optimizer_kwargs)
        )
        self.scheduler = scheduler
Exemplo n.º 2
0
    def run_tasks(self, tasks, meta_train):
        """Train on a set of tasks, then evaluate on them.

        Args:
            tasks: tasks to run; passed to ``gen_multitask_batches`` for both
                the training and the evaluation pass.
            meta_train: when False, the model state is snapshotted first,
                modules exposing ``reset_running_stats`` have those stats
                reset, and the snapshot is restored afterwards -- also when
                training or evaluation raises.

        Returns:
            ``AggRes`` built from the pairwise zip of the training and
            evaluation results.
        """
        original = None
        if not meta_train:
            original = clone_state_dict(self.model.state_dict(keep_vars=True))

        try:
            if not meta_train:
                # Non-transductive task evaluation for fair comparison
                for module in self.model.modules():
                    if hasattr(module, 'reset_running_stats'):
                        module.reset_running_stats()

            # Training #
            all_batches = self.gen_multitask_batches(tasks, train=True)
            trainres = self.run_multitask(all_batches, train=True)

            # Eval #
            all_batches = self.gen_multitask_batches(tasks, train=False)
            valres = self.run_multitask(all_batches, train=False)

            return AggRes(zip(trainres, valres))
        finally:
            # Restore in `finally` so a failing evaluation episode cannot leave
            # the model with adapted weights / reset running statistics.
            if not meta_train:
                self.model.load_state_dict(original)
Exemplo n.º 3
0
 def __init__(self, model, optimizer_cls, optimizer_kwargs, criterion):
     """Initialise the plain (non-meta-learning) wrapper.

     Forwards ``criterion``, ``model``, ``optimizer_cls`` and
     ``optimizer_kwargs`` positionally to the base-class ``__init__``, then
     stores a cloned snapshot of the model's state in ``self._original``.
     NOTE(review): ``criterion`` occupies the base class's first positional
     slot -- confirm against the base-class signature.
     """
     super(NoWrapper, self).__init__(
         criterion, model, optimizer_cls, optimizer_kwargs)
     # keep_vars=True keeps the tensors as autograd variables before cloning.
     model_state = model.state_dict(keep_vars=True)
     self._original = clone_state_dict(model_state)