def run_batches(self, batches, train=True, meta_train=True):
    metrics = []
    device = next(self.model.parameters()).device
    shufflers = {
        key: shuffler() for key, shuffler in self._shuffler_factory.items()
    }

    with torch.backends.cudnn.flags(enabled=False):
        with higher.innerloop_ctx(
                self.model, self.inner_optimizer,
                copy_initial_weights=False) as (fmodel, diffopt):
            # Inner-loop adaptation on all but the last batch
            for n, inputs in enumerate(batches[:-1]):
                inputs = self.shuffle_labels(inputs, shufflers)
                inputs = move_to_device(inputs, device)
                output_dict = fmodel(**inputs, **self.forward_kwargs(n))
                loss = output_dict["loss"]
                metric = output_dict["metric"]
                diffopt.step(loss)
                metrics.append({"loss": loss.item(), "metric": metric})

            # Meta (outer-loop) loss on the held-out final batch
            inputs = self.shuffle_labels(batches[-1], shufflers)
            inputs = move_to_device(inputs, device)
            output_dict = fmodel(**inputs, **self.forward_kwargs(len(batches) - 1))
            loss = output_dict["loss"]
            metric = output_dict["metric"]
            loss.backward()
            metrics.append({"loss": loss.item(), "metric": metric})

    return metrics
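# A minimal sketch of how the higher-based run_batches above might be driven
# from an outer meta-training loop. `trainer`, `sample_tasks`, `meta_batch_size`,
# and `num_meta_steps` are assumptions for illustration, not part of the listing.
import torch

meta_optimizer = torch.optim.Adam(trainer.model.parameters(), lr=1e-4)

for step in range(num_meta_steps):
    meta_optimizer.zero_grad()
    for task_batches in sample_tasks(meta_batch_size):
        # With copy_initial_weights=False, loss.backward() on the final batch
        # accumulates meta-gradients directly into trainer.model.parameters().
        trainer.run_batches(task_batches, train=True, meta_train=True)
    meta_optimizer.step()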
def maml_task(data_inner, data_outer, model, optimizer, create_graph):
    """Adapt model parameters to a task and use the adapted parameters to
    predict new samples.

    Arguments:
        data_inner (iterable): list of input-output batches for task adaptation.
        data_outer (iterable): list of input-output batches for task validation.
        model (torch.nn.Module): task learner.
        optimizer (maml.optim): optimizer for the inner loop.
        create_graph (bool): create a graph through the gradient step.
    """
    metrics = []
    original_parameters = model.state_dict(keep_vars=True)
    device = next(model.parameters()).device

    # Adaptation of parameters to the task
    for inputs in data_inner:
        inputs = move_to_device(inputs, device)
        loss, new_params, metric = maml_inner_step(
            inputs, model, optimizer, create_graph)
        metrics.append(metric)
        if create_graph:
            load_state_dict(
                model,
                build_dict([n for n, _ in model.named_parameters()], new_params))

    for p in original_parameters.values():
        p.grad = None

    # Run with adapted parameters on the validation batches
    for inputs in data_outer:
        inputs = move_to_device(inputs, device)
        with torch.backends.cudnn.flags(enabled=False):
            output_dict = model(**inputs, return_metric=True)
        loss += output_dict["loss"]
        metrics.append(output_dict["metric"])

    load_state_dict(model, original_parameters)
    return loss, metrics
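# maml_task relies on a maml_inner_step helper that is not shown here. The
# sketch below is an assumption about its contract (one differentiable gradient
# step returning the loss, the updated parameter tensors, and the task metric);
# the real helper presumably delegates the update rule to `optimizer`
# (maml.optim) rather than the placeholder SGD step used here.
import torch

def maml_inner_step(inputs, model, optimizer, create_graph):
    output_dict = model(**inputs, return_metric=True)
    loss = output_dict["loss"]
    grads = torch.autograd.grad(
        loss, list(model.parameters()), create_graph=create_graph)
    # Placeholder update: plain SGD with a fixed step size of 0.01.
    new_params = [p - 0.01 * g for p, g in zip(model.parameters(), grads)]
    return loss, new_params, output_dict["metric"]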
def get_hidden_states(task):
    hidden_states = []
    masks = []
    labels = []
    device = next(generator.parameters()).device
    for inputs in task[:self.first_n_states]:
        inputs = move_to_device(inputs, device)
        output_dict = generator(**inputs)
        hidden_states.append(output_dict["hidden_state"])
        masks.append(output_dict["mask"])
        labels.append(inputs["langs"])
    return pad_batched_tensors(hidden_states), \
        torch.cat(labels, dim=0), \
        pad_batched_tensors(masks)
def get_hidden_states(task):
    kl_losses = []
    kl_divs = []
    kl_div2s = []
    device = next(generator.parameters()).device
    for inputs in task:
        inputs = move_to_device(inputs, device)
        output_dict = generator(**inputs, variational=True)
        kl_losses.append(output_dict["kl_loss"])
        kl_divs.append(output_dict["kl_div"])
        kl_div2s.append(output_dict["kl_div2"])
    return kl_losses, kl_divs, kl_div2s
def run_batches(self, batches, train=False):
    metrics = []
    device = next(self.model.parameters()).device
    for n, inputs in enumerate(batches):
        inputs = move_to_device(inputs, device)
        output_dict = self.model(**inputs, **self.forward_kwargs(n))
        loss = output_dict["loss"]
        metric = output_dict["metric"]
        if torch.isnan(loss):
            raise ValueError("nan loss encountered")
        metrics.append({"loss": loss.item(), "metric": metric})
        # Scale the loss for gradient accumulation over the task's batches
        # and over the number of tasks tracked by self._counter.
        loss = loss / len(batches)
        loss = loss / self._counter
        loss.backward()
    return metrics
def run_batches(self, batches, optimizer, train=False, meta_train=False):
    """Iterate over task-specific batches.

    Arguments:
        batches (torch.utils.data.DataLoader): task-specific dataloader.
        optimizer (torch.optim.Optimizer): optimizer instance, used if train is True.
        train (bool): whether to train on the task.
        meta_train (bool): whether to meta-train on the task.
    """
    metrics = []
    N = len(batches)
    device = next(self._container.parameters()).device
    for n, inputs in enumerate(batches):
        # Task-specific optimizer
        optimizer.zero_grad()
        inputs = move_to_device(inputs, device)

        # Evaluate model
        output_dict = self._container(**inputs, **self.forward_kwargs(n))
        loss = output_dict["loss"]
        metric = output_dict["metric"]
        if torch.isnan(loss):
            raise ValueError("nan loss encountered")
        metrics.append({"loss": loss.item(), "metric": metric})

        # TRAINING #
        if not train:
            continue

        loss.backward()
        grad_norm = self.rescale_gradients()
        optimizer.step()

        if meta_train:
            self._partial_meta_update(n, N)

    return metrics
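# A hypothetical driver showing how this run_batches variant might be called;
# `trainer`, `make_task_optimizer`, and `task_loader` are assumptions used
# only for illustration.
task_optimizer = make_task_optimizer(trainer._container.parameters())

# Meta-training: adapt on the task and apply a partial meta-update per step.
train_metrics = trainer.run_batches(
    task_loader, task_optimizer, train=True, meta_train=True)

# Evaluation: with train=False the loop only records losses and metrics.
eval_metrics = trainer.run_batches(task_loader, task_optimizer, train=False)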