Example #1
    def train_iter(self, dataloader, max_batches=500, is_classification_task=False):
        """ 
        Runs one epoch of meta-training on the given dataset, yielding
        training statistics in batches.

        Exactly `max_batches` batches of tasks will be processed.
        """
        if self.optimizer is None:
            raise RuntimeError('Trying to call `train_iter`, while the '
                'optimizer is `None`. In order to train `{0}`, you must '
                'specify a Pytorch optimizer as the argument of `{0}` '
                '(eg. `{0}(model, optimizer=torch.optim.SGD(model.'
                'parameters(), lr=0.01), ...).'.format(__class__.__name__))
        num_batches = 0
        self.model.train()
        while num_batches < max_batches:
            for batch in dataloader:
                if num_batches >= max_batches:
                    break

                if self.scheduler is not None:
                    self.scheduler.step(epoch=num_batches)

                self.optimizer.zero_grad()

                batch = tensors_to_device(batch, device=self.device)
                outer_loss, results = self.get_outer_loss(batch, is_classification_task=is_classification_task)
                yield results

                outer_loss.backward()
                self.optimizer.step()

                num_batches += 1
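The generator above yields one `results` dict per batch of tasks. A minimal sketch of how a caller might consume it; the `'mean_outer_loss'` and `'accuracies_after'` keys, and the `tqdm` progress bar, are assumptions about what `get_outer_loss` reports rather than something guaranteed by the snippet above:

import numpy as np
from tqdm import tqdm

def train(metalearner, dataloader, max_batches=500, verbose=True):
    # Drive the `train_iter` generator and surface per-batch statistics.
    with tqdm(total=max_batches, disable=not verbose) as pbar:
        for results in metalearner.train_iter(dataloader, max_batches=max_batches):
            pbar.update(1)
            postfix = {'loss': '{0:.4f}'.format(results['mean_outer_loss'])}
            if 'accuracies_after' in results:
                postfix['accuracy'] = '{0:.4f}'.format(
                    np.mean(results['accuracies_after']))
            pbar.set_postfix(**postfix)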
Example #2
    def train_iter(self, dataloader, max_batches=500):
        if self.optimizer is None:
            raise RuntimeError(
                'Trying to call `train_iter`, while the '
                'optimizer is `None`. In order to train `{0}`, you must '
                'specify a Pytorch optimizer as the argument of `{0}` '
                '(eg. `{0}(model, optimizer=torch.optim.SGD(model.'
                'parameters(), lr=0.01), ...).'.format(__class__.__name__))
        num_batches = 0
        self.model.train()
        while num_batches < max_batches:
            for batch in dataloader:
                if num_batches >= max_batches:
                    break

                if self.scheduler is not None:
                    self.scheduler.step(epoch=num_batches)

                self.optimizer.zero_grad()

                batch = tensors_to_device(batch, device=self.device)
                outer_loss, results = self.get_outer_loss(batch)
                yield results

                outer_loss.backward()
                self.optimizer.step()

                num_batches += 1
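Every loop in this listing calls `tensors_to_device` to move a (possibly nested) batch onto the target device, but the helper itself is not shown. A minimal sketch of a plausible implementation, assuming the batch is built from tensors, lists/tuples, and dicts:

from collections import OrderedDict
import torch

def tensors_to_device(tensors, device=torch.device('cpu')):
    # Recursively move tensors (and containers of tensors) onto `device`.
    if isinstance(tensors, torch.Tensor):
        return tensors.to(device=device)
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(tensors_to_device(t, device=device) for t in tensors)
    if isinstance(tensors, (dict, OrderedDict)):
        return type(tensors)((k, tensors_to_device(t, device=device))
                             for (k, t) in tensors.items())
    raise NotImplementedError('Unsupported batch element: {0}'.format(type(tensors)))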
Example #3
    def train_iter(self,
                   dataloader,
                   accumulation_steps=1,
                   max_batches=500,
                   is_classification_task=False):
        """ 
        Runs one epoch of meta-training on the given dataset, yielding
        training statistics in batches.

        Exactly `max_batches` batches of tasks will be processed.
        """
        if self.optimizer is None:
            raise RuntimeError(
                'Trying to call `train_iter`, while the '
                'optimizer is `None`. In order to train `{0}`, you must '
                'specify a Pytorch optimizer as the argument of `{0}` '
                '(eg. `{0}(model, optimizer=torch.optim.SGD(model.'
                'parameters(), lr=0.01), ...).'.format(__class__.__name__))
        num_batches = 0
        self.model.train()

        all_results = {}
        while num_batches < max_batches:
            for batch in dataloader:
                if num_batches >= max_batches:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    break

                if self.scheduler is not None:
                    self.scheduler.step(epoch=num_batches)

                batch = tensors_to_device(batch, device=self.device)
                outer_loss, results = self.get_outer_loss(
                    batch, is_classification_task=is_classification_task)
                outer_loss /= accumulation_steps
                outer_loss.backward()

                # Add the results from this batch
                if not all_results:
                    all_results = dict(results)
                else:
                    for key in all_results:
                        all_results[key] += results[key]

                # Take a step with the accumulated gradients, and reset optimizer & results
                if (num_batches + 1) % accumulation_steps == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                    all_results = {
                        key: val / accumulation_steps
                        for key, val in all_results.items()
if key not in ('num_tasks',)
                    }
                    yield all_results
                    all_results = {}

                num_batches += 1
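The division by `accumulation_steps` before `backward()` is what makes the accumulated update match an ordinary step on the mean loss over those batches. A small self-contained check of that scaling on toy tensors (not the meta-learner above):

import torch

torch.manual_seed(0)
w = torch.ones(3, requires_grad=True)
data = [torch.randn(3) for _ in range(4)]

# Accumulate: backward on loss / num_micro_batches, gradients sum into w.grad.
for x in data:
    ((w * x).sum() / len(data)).backward()
grad_accumulated = w.grad.clone()

# Reference: a single backward on the mean loss over the same micro-batches.
w.grad = None
torch.stack([(w * x).sum() for x in data]).mean().backward()
assert torch.allclose(grad_accumulated, w.grad)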
Example #4
    def evaluate_iter(self, dataloader, max_batches=500):
        num_batches = 0
        self.model.eval()
        while num_batches < max_batches:
            for batch in dataloader:
                if num_batches >= max_batches:
                    break

                batch = tensors_to_device(batch, device=self.device)
                _, results = self.get_outer_loss(batch)
                yield results

                num_batches += 1
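A sketch of how the statistics yielded by `evaluate_iter` might be reduced to a single number, assuming each `results` dict carries a `'mean_outer_loss'` entry; a running mean avoids storing every batch:

def evaluate(metalearner, dataloader, max_batches=500):
    # Running mean of the outer loss over all evaluation batches.
    mean_outer_loss, count = 0.0, 0
    for results in metalearner.evaluate_iter(dataloader, max_batches=max_batches):
        count += 1
        mean_outer_loss += (results['mean_outer_loss'] - mean_outer_loss) / count
    return {'mean_outer_loss': mean_outer_loss}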
Example #5
    def evaluate_iter(self, dataloader, max_batches=500):
        num_batches = 0
        self.model.eval()
        if self.warp_model is not None:
            self.warp_model.eval()
        while num_batches < max_batches:
            for batch in dataloader:
                if num_batches >= max_batches:
                    break

                batch = tensors_to_device(batch, device=self.device)
                _, results = self.get_outer_loss(batch, eval_mode=True, write_params=False)
                yield results

                num_batches += 1
Example #6
    def evaluate_iter(self, dataloader, max_batches=500, is_classification_task=False):
        """
        Yields the validation loss (outer loss) in batches without performing any
        gradient steps.
        """
        num_batches = 0
        self.model.eval()
        while num_batches < max_batches:
            for batch in dataloader:
                if num_batches >= max_batches:
                    break

                batch = tensors_to_device(batch, device=self.device)
                _, results = self.get_outer_loss(batch, is_classification_task=is_classification_task)
                yield results

                num_batches += 1
Example #7
    def train_iter(self, dataloader, max_batches=500):
        if self.optimizer is None:
            raise RuntimeError(
                'Trying to call `train_iter`, while the '
                'optimizer is `None`. In order to train `{0}`, you must '
                'specify a Pytorch optimizer as the argument of `{0}` '
                '(eg. `{0}(model, optimizer=torch.optim.SGD(model.'
                'parameters(), lr=0.01), ...).'.format(__class__.__name__))
        num_batches = 0
        self.model.train()
        while num_batches < max_batches:
            # import pdb
            # pdb.set_trace()
            for batches in zip(dataloader[0], cycle(dataloader[1])):

                if num_batches >= max_batches:
                    break

                if self.scheduler is not None:
                    self.scheduler.step(epoch=num_batches)

                self.optimizer.zero_grad()

                batches = [
                    tensors_to_device(batch, device=self.device)
                    for batch in batches
                ]
                outer_loss, results, outer_losses = self.get_outer_loss(
                    batches)
                yield results

                if self.PCGrad:
                    self.optimizer.pc_backward(
                        outer_losses
                    )  # calculate the gradient can apply gradient modification
                    self.optimizer.step()  # apply gradient step
                else:
                    outer_loss.backward()
                    self.optimizer.step()

                num_batches += 1
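The `PCGrad` branch above assumes an optimizer wrapper whose `pc_backward` applies "gradient surgery" to a list of per-task losses before the usual `step()`. The core projection rule is small; a minimal sketch (not the wrapper actually used, which may also shuffle task order and track per-parameter state):

import torch

def pc_grad_backward(losses, params):
    # Flatten each task's gradient into a single vector.
    grads = []
    for loss in losses:
        g = torch.autograd.grad(loss, params, retain_graph=True)
        grads.append(torch.cat([gi.reshape(-1) for gi in g]))

    # Project away the component of g_i that conflicts with any other g_j.
    projected = [g.clone() for g in grads]
    for i, g_i in enumerate(projected):
        for j, g_j in enumerate(grads):
            dot = torch.dot(g_i, g_j)
            if i != j and dot < 0:
                g_i -= (dot / g_j.norm() ** 2) * g_j

    # Write the combined gradient back so a following optimizer.step() uses it.
    total, offset = torch.stack(projected).sum(dim=0), 0
    for p in params:
        n = p.numel()
        p.grad = total[offset:offset + n].view_as(p).clone()
        offset += n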
Example #8
    def train_iter(self, dataloader, task_weighting: TaskWeightingBase, weight_normalizer: WeightNormalizer,
                   epoch, max_batches):
        if self.optimizer is None:
            raise RuntimeError('Trying to call `train_iter`, while the '
                               'optimizer is `None`. In order to train `{0}`, you must '
                               'specify a Pytorch optimizer as the argument of `{0}` '
                               '(eg. `{0}(model, optimizer=torch.optim.SGD(model.'
                               'parameters(), lr=0.01), ...).'.format(__class__.__name__))
        num_batches = 0
        self.model.train()

        while num_batches < max_batches:
            for batch in dataloader:
                if num_batches >= max_batches:
                    break

                if self.scheduler is not None:
                    self.scheduler.step(epoch=num_batches)

                iteration = epoch * max_batches + num_batches

                self.optimizer.zero_grad()

                task_weighting.before_gradient_step(iteration, batch)

                batch = tensors_to_device(batch, device=self.device)
                outer_losses, results = self.get_outer_losses(task_weighting, batch)
                yield results

                loss = task_weighting.compute_weighted_loss(iteration, outer_losses)

                loss.backward()
                self.optimizer.step()

                task_weighting.update_inner_weights(iteration, outer_losses)

                if hasattr(task_weighting, 'weights'):
                    task_weighting.weights = weight_normalizer.normalize(iteration, task_weighting.weights)

                num_batches += 1
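The loop above only touches `task_weighting` through three hooks plus an optional `weights` attribute. A hypothetical uniform-weighting stand-in that shows the interface the loop assumes; the names mirror the calls above, and the real `TaskWeightingBase` may differ:

class UniformTaskWeighting:
    # Hypothetical stand-in: every task's outer loss gets the same weight.

    def before_gradient_step(self, iteration, batch):
        pass  # a learned scheme could update its statistics here

    def compute_weighted_loss(self, iteration, outer_losses):
        # `outer_losses` is assumed to be a 1-D tensor with one entry per task.
        return outer_losses.mean()

    def update_inner_weights(self, iteration, outer_losses):
        pass  # nothing to adapt when the weights are fixed and uniform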
Example #9
    def train_iter(self, dataloader, max_batches=500):
        """
        if self.optimizer is None:
            raise RuntimeError('Trying to call `train_iter`, while the '
                'optimizer is `None`. In order to train `{0}`, you must '
                'specify a Pytorch optimizer as the argument of `{0}` '
                '(eg. `{0}(model, optimizer=torch.optim.SGD(model.'
                'parameters(), lr=0.01), ...).'.format(__class__.__name__))
        """
        num_batches = 0
        self.model.train()
        if self.warp_model is not None:
            self.warp_model.train()
        while num_batches < max_batches:
            for batch in dataloader:
                if num_batches >= max_batches:
                    break

                self.num_steps += 1

                if self.optimizer is not None:
                    self.optimizer.zero_grad()
                if self.warp_optimizer is not None:
                    self.warp_optimizer.zero_grad()

                if self.ensembler_optimizer is not None:
                    self.ensembler_optimizer.zero_grad()

                batch = tensors_to_device(batch, device=self.device)
                outer_loss, results = self.get_outer_loss(batch)
                yield results

                outer_loss.backward()


                if self.optimizer is not None and (self.num_maml_steps <= 0 or self.num_steps < self.num_maml_steps):
                    self.optimizer.step()

                if self.warp_optimizer is not None:
                    self.warp_optimizer.step()

                if self.ensembler_optimizer is not None:
                    self.ensembler_optimizer.step()

                if self.optimizer is not None:
                    for param_group in self.optimizer.param_groups:
                        wandb.log({"maml_lr": param_group['lr']}, commit=False)
                        break

                if self.warp_optimizer is not None:
                    for param_group in self.warp_optimizer.param_groups:
                        wandb.log({"warp_lr": param_group['lr']}, commit=False)
                        break

                if self.ensembler_optimizer is not None:
                    for param_group in self.ensembler_optimizer.param_groups:
                        wandb.log({"ensembler_lr": param_group['lr']}, commit=False)
                        break

                if self.scheduler is not None:
                    self.scheduler.step()

                if self.warp_scheduler is not None:
                    self.warp_scheduler.step()

                if self.ensembler_scheduler is not None:
                    self.ensembler_scheduler.step()

                """
                for module in self.model.modules():
                    if isinstance(module, BatchLinear):
                        module.reset_parameters()
                """

                num_batches += 1
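The `wandb.log(..., commit=False)` calls above buffer each learning rate without advancing the step counter, so they land in the same logged step as whatever call eventually commits. A minimal illustration, assuming `wandb.init()` has already been called elsewhere:

import wandb

wandb.log({"maml_lr": 1e-3}, commit=False)   # buffered, step not advanced
wandb.log({"warp_lr": 1e-4}, commit=False)   # buffered into the same step
wandb.log({"outer_loss": 0.42})              # commits all three values as one step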
Example #10
meta_optimizer = torch.optim.Adam(benchmark.model.parameters(),
                                  lr=args.meta_lr)

metalearner = ModelAgnosticMetaLearning(benchmark.model,
                                        meta_optimizer,
                                        first_order=args.first_order,
                                        num_adaptation_steps=args.num_steps,
                                        step_size=args.step_size,
                                        loss_function=benchmark.loss_function,
                                        device=device)

best_value = None

from maml.utils import tensors_to_device, compute_accuracy
out = next(iter(meta_train_dataloader))
out_c = tensors_to_device(out, device='cuda')

benchmark.model.load_state_dict(torch.load('model.th'))

# Training loop
epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 + int(math.log10(args.num_epochs)))
for epoch in range(args.num_epochs):
    print(epoch)
    metalearner.train(meta_train_dataloader,
                      max_batches=args.num_batches,
                      verbose=args.verbose,
                      desc='Training',
                      leave=False)
    results = metalearner.evaluate(meta_val_dataloader,
                                   max_batches=args.num_batches,
                                   verbose=args.verbose,