def train_iter(self, dataloader, max_batches=500, is_classification_task=False):
    """Runs one epoch of meta-training on the given dataset, yielding
    training statistics in batches. Exactly `max_batches` batches of
    tasks will be processed.
    """
    if self.optimizer is None:
        raise RuntimeError('Trying to call `train_iter`, while the '
            'optimizer is `None`. In order to train `{0}`, you must '
            'specify a Pytorch optimizer as the argument of `{0}` '
            '(eg. `{0}(model, optimizer=torch.optim.SGD(model.'
            'parameters(), lr=0.01), ...).'.format(__class__.__name__))
    num_batches = 0
    self.model.train()
    while num_batches < max_batches:
        for batch in dataloader:
            if num_batches >= max_batches:
                break

            if self.scheduler is not None:
                self.scheduler.step(epoch=num_batches)

            self.optimizer.zero_grad()

            batch = tensors_to_device(batch, device=self.device)
            outer_loss, results = self.get_outer_loss(
                batch, is_classification_task=is_classification_task)
            yield results

            outer_loss.backward()
            self.optimizer.step()

            num_batches += 1
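# `train_iter` is a generator: the backward pass and optimizer step for each
# batch only run when the iterator is advanced, so it must be consumed for any
# training to happen. A minimal driving loop; `metalearner` and
# `meta_train_dataloader` are illustrative names, not objects defined above:
for results in metalearner.train_iter(meta_train_dataloader, max_batches=500):
    print(results)  # per-batch statistics, yielded just before the gradient step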
def train_iter(self, dataloader, max_batches=500):
    if self.optimizer is None:
        raise RuntimeError(
            'Trying to call `train_iter`, while the '
            'optimizer is `None`. In order to train `{0}`, you must '
            'specify a Pytorch optimizer as the argument of `{0}` '
            '(eg. `{0}(model, optimizer=torch.optim.SGD(model.'
            'parameters(), lr=0.01), ...).'.format(__class__.__name__))
    num_batches = 0
    self.model.train()
    while num_batches < max_batches:
        for batch in dataloader:
            if num_batches >= max_batches:
                break

            if self.scheduler is not None:
                self.scheduler.step(epoch=num_batches)

            self.optimizer.zero_grad()

            batch = tensors_to_device(batch, device=self.device)
            outer_loss, results = self.get_outer_loss(batch)
            yield results

            outer_loss.backward()
            self.optimizer.step()

            num_batches += 1
def train_iter(self, dataloader, accumulation_steps=1, max_batches=500,
               is_classification_task=False):
    """Runs one epoch of meta-training on the given dataset, yielding
    training statistics in batches. Exactly `max_batches` batches of
    tasks will be processed.
    """
    if self.optimizer is None:
        raise RuntimeError(
            'Trying to call `train_iter`, while the '
            'optimizer is `None`. In order to train `{0}`, you must '
            'specify a Pytorch optimizer as the argument of `{0}` '
            '(eg. `{0}(model, optimizer=torch.optim.SGD(model.'
            'parameters(), lr=0.01), ...).'.format(__class__.__name__))
    num_batches = 0
    self.model.train()
    all_results = {}
    while num_batches < max_batches:
        for batch in dataloader:
            if num_batches >= max_batches:
                self.optimizer.step()
                self.optimizer.zero_grad()
                break

            if self.scheduler is not None:
                self.scheduler.step(epoch=num_batches)

            batch = tensors_to_device(batch, device=self.device)
            outer_loss, results = self.get_outer_loss(
                batch, is_classification_task=is_classification_task)
            outer_loss /= accumulation_steps
            outer_loss.backward()

            # Add the results from this batch
            if not all_results:
                all_results = dict(results)
            else:
                for key in all_results:
                    all_results[key] += results[key]

            # Take a step with the accumulated gradients, and reset optimizer & results
            if (num_batches + 1) % accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
                all_results = {
                    key: val / accumulation_steps
                    for key, val in all_results.items()
                    if key not in ('num_tasks',)
                }
                yield all_results
                all_results = {}

            num_batches += 1
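# A minimal usage sketch for the gradient-accumulation variant above: with
# `accumulation_steps=4`, four meta-batches contribute to each optimizer step,
# so the effective meta-batch size is four times larger and statistics are
# yielded once per optimizer step rather than once per batch. `metalearner`
# and `meta_train_dataloader` are illustrative names.
for averaged_results in metalearner.train_iter(meta_train_dataloader,
                                               accumulation_steps=4,
                                               max_batches=500):
    print(averaged_results)  # per-step averages over the 4 accumulated batches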
def evaluate_iter(self, dataloader, max_batches=500):
    num_batches = 0
    self.model.eval()
    while num_batches < max_batches:
        for batch in dataloader:
            if num_batches >= max_batches:
                break

            batch = tensors_to_device(batch, device=self.device)
            _, results = self.get_outer_loss(batch)
            yield results

            num_batches += 1
def evaluate_iter(self, dataloader, max_batches=500):
    num_batches = 0
    self.model.eval()
    if self.warp_model is not None:
        self.warp_model.eval()
    while num_batches < max_batches:
        for batch in dataloader:
            if num_batches >= max_batches:
                break

            batch = tensors_to_device(batch, device=self.device)
            _, results = self.get_outer_loss(batch, eval_mode=True,
                                             write_params=False)
            yield results

            num_batches += 1
def evaluate_iter(self, dataloader, max_batches=500, is_classification_task=False):
    """Yields the validation loss (outer loss) in batches without
    performing any gradient steps.
    """
    num_batches = 0
    self.model.eval()
    while num_batches < max_batches:
        for batch in dataloader:
            if num_batches >= max_batches:
                break

            batch = tensors_to_device(batch, device=self.device)
            _, results = self.get_outer_loss(
                batch, is_classification_task=is_classification_task)
            yield results

            num_batches += 1
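# A minimal sketch of consuming `evaluate_iter` and averaging the yielded
# statistics across batches. The result keys ('mean_outer_loss',
# 'accuracies_after') are assumptions following the torchmeta-style results
# dict; adapt them to whatever `get_outer_loss` actually returns.
import numpy as np

mean_outer_loss, mean_accuracy, count = 0., 0., 0
for results in metalearner.evaluate_iter(meta_val_dataloader, max_batches=100):
    count += 1
    mean_outer_loss += (results['mean_outer_loss'] - mean_outer_loss) / count
    if 'accuracies_after' in results:
        mean_accuracy += (np.mean(results['accuracies_after']) - mean_accuracy) / count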
def train_iter(self, dataloader, max_batches=500):
    # Note: this variant expects `dataloader` to be a pair of dataloaders and
    # requires `from itertools import cycle` at module level.
    if self.optimizer is None:
        raise RuntimeError(
            'Trying to call `train_iter`, while the '
            'optimizer is `None`. In order to train `{0}`, you must '
            'specify a Pytorch optimizer as the argument of `{0}` '
            '(eg. `{0}(model, optimizer=torch.optim.SGD(model.'
            'parameters(), lr=0.01), ...).'.format(__class__.__name__))
    num_batches = 0
    self.model.train()
    while num_batches < max_batches:
        for batches in zip(dataloader[0], cycle(dataloader[1])):
            if num_batches >= max_batches:
                break

            if self.scheduler is not None:
                self.scheduler.step(epoch=num_batches)

            self.optimizer.zero_grad()

            batches = [tensors_to_device(batch, device=self.device)
                       for batch in batches]
            outer_loss, results, outer_losses = self.get_outer_loss(batches)
            yield results

            if self.PCGrad:
                # Compute per-task gradients with PCGrad's conflict-projection,
                # then apply the modified gradients.
                self.optimizer.pc_backward(outer_losses)
                self.optimizer.step()
            else:
                outer_loss.backward()
                self.optimizer.step()

            num_batches += 1
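# The PCGrad variant above pairs batches from two dataloaders:
# `zip(dataloader[0], cycle(dataloader[1]))` walks the first loader once while
# repeating the second indefinitely, so the pairing never exhausts the shorter
# loader. A self-contained illustration with plain lists:
from itertools import cycle

primary = ['A1', 'A2', 'A3', 'A4']
auxiliary = ['B1', 'B2']
print(list(zip(primary, cycle(auxiliary))))
# [('A1', 'B1'), ('A2', 'B2'), ('A3', 'B1'), ('A4', 'B2')]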
def train_iter(self, dataloader, task_weighting: TaskWeightingBase,
               weight_normalizer: WeightNormalizer, epoch, max_batches):
    if self.optimizer is None:
        raise RuntimeError('Trying to call `train_iter`, while the '
            'optimizer is `None`. In order to train `{0}`, you must '
            'specify a Pytorch optimizer as the argument of `{0}` '
            '(eg. `{0}(model, optimizer=torch.optim.SGD(model.'
            'parameters(), lr=0.01), ...).'.format(__class__.__name__))
    num_batches = 0
    self.model.train()
    while num_batches < max_batches:
        for batch in dataloader:
            if num_batches >= max_batches:
                break

            if self.scheduler is not None:
                self.scheduler.step(epoch=num_batches)

            iteration = epoch * max_batches + num_batches

            self.optimizer.zero_grad()
            task_weighting.before_gradient_step(iteration, batch)

            batch = tensors_to_device(batch, device=self.device)
            outer_losses, results = self.get_outer_losses(task_weighting, batch)
            yield results

            loss = task_weighting.compute_weighted_loss(iteration, outer_losses)
            loss.backward()
            self.optimizer.step()

            task_weighting.update_inner_weights(iteration, outer_losses)

            if hasattr(task_weighting, 'weights'):
                task_weighting.weights = weight_normalizer.normalize(
                    iteration, task_weighting.weights)

            num_batches += 1
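# A minimal sketch of the interface the task-weighting variant above relies on.
# This uniform-weighting stand-in is an illustration only, not the project's
# actual `TaskWeightingBase`; it implements just the methods `train_iter` calls,
# and it assumes `outer_losses` is a tensor or a list of scalar loss tensors.
import torch

class UniformTaskWeighting:
    def before_gradient_step(self, iteration, batch):
        # Hook invoked before the meta-gradient step; a no-op for uniform weights.
        pass

    def compute_weighted_loss(self, iteration, outer_losses):
        # Uniform weighting reduces to the mean of the per-task outer losses.
        if isinstance(outer_losses, (list, tuple)):
            outer_losses = torch.stack(outer_losses)
        return outer_losses.mean()

    def update_inner_weights(self, iteration, outer_losses):
        # Adaptive schemes would update per-task weights here.
        pass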
def train_iter(self, dataloader, max_batches=500):
    """
    if self.optimizer is None:
        raise RuntimeError('Trying to call `train_iter`, while the '
            'optimizer is `None`. In order to train `{0}`, you must '
            'specify a Pytorch optimizer as the argument of `{0}` '
            '(eg. `{0}(model, optimizer=torch.optim.SGD(model.'
            'parameters(), lr=0.01), ...).'.format(__class__.__name__))
    """
    num_batches = 0
    self.model.train()
    if self.warp_model is not None:
        self.warp_model.train()
    while num_batches < max_batches:
        for batch in dataloader:
            if num_batches >= max_batches:
                break
            self.num_steps += 1

            if self.optimizer is not None:
                self.optimizer.zero_grad()
            if self.warp_optimizer is not None:
                self.warp_optimizer.zero_grad()
            if self.ensembler_optimizer is not None:
                self.ensembler_optimizer.zero_grad()

            batch = tensors_to_device(batch, device=self.device)
            outer_loss, results = self.get_outer_loss(batch)
            yield results

            outer_loss.backward()

            # Only step the MAML optimizer while within the allotted number of
            # MAML steps (num_maml_steps <= 0 means no limit).
            if self.optimizer is not None and (self.num_maml_steps <= 0
                    or self.num_steps < self.num_maml_steps):
                self.optimizer.step()
            if self.warp_optimizer is not None:
                self.warp_optimizer.step()
            if self.ensembler_optimizer is not None:
                self.ensembler_optimizer.step()

            # Log the current learning rates (requires `import wandb` at module
            # level); only the first param_group of each optimizer is logged.
            if self.optimizer is not None:
                for param_group in self.optimizer.param_groups:
                    wandb.log({"maml_lr": param_group['lr']}, commit=False)
                    break
            if self.warp_optimizer is not None:
                for param_group in self.warp_optimizer.param_groups:
                    wandb.log({"warp_lr": param_group['lr']}, commit=False)
                    break
            if self.ensembler_optimizer is not None:
                for param_group in self.ensembler_optimizer.param_groups:
                    wandb.log({"ensembler_lr": param_group['lr']}, commit=False)
                    break

            if self.scheduler is not None:
                self.scheduler.step()
            if self.warp_scheduler is not None:
                self.warp_scheduler.step()
            if self.ensembler_scheduler is not None:
                self.ensembler_scheduler.step()

            """
            for module in self.model.modules():
                if isinstance(module, BatchLinear):
                    module.reset_parameters()
            """
            num_batches += 1
meta_optimizer = torch.optim.Adam(benchmark.model.parameters(), lr=args.meta_lr)
metalearner = ModelAgnosticMetaLearning(benchmark.model,
                                        meta_optimizer,
                                        first_order=args.first_order,
                                        num_adaptation_steps=args.num_steps,
                                        step_size=args.step_size,
                                        loss_function=benchmark.loss_function,
                                        device=device)
best_value = None

from maml.utils import tensors_to_device, compute_accuracy

# Sanity check: fetch a single batch, move it to the GPU, and restore a
# previously saved model before training.
out = next(iter(meta_train_dataloader))
out_c = tensors_to_device(out, device='cuda')
benchmark.model.load_state_dict(torch.load('model.th'))

# Training loop
epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 + int(math.log10(args.num_epochs)))
for epoch in range(args.num_epochs):
    print(epoch)
    metalearner.train(meta_train_dataloader,
                      max_batches=args.num_batches,
                      verbose=args.verbose,
                      desc='Training',
                      leave=False)
    results = metalearner.evaluate(meta_val_dataloader,
                                   max_batches=args.num_batches,
                                   verbose=args.verbose,
                                   desc=epoch_desc.format(epoch + 1))