import typing

import higher
import torch


def evaluate_fn_vampire(
    x: torch.Tensor,
    f_hyper_net: higher.patch._MonkeyPatchBase,
    f_base_net: higher.patch._MonkeyPatchBase
) -> typing.List[torch.Tensor]:
    """Evaluate on the validation subset with Monte Carlo sampling

    Args:
        x: the data in the validation subset
        f_hyper_net: the adapted meta-parameter
        f_base_net: the functional form of the base net

    Return: a list of logits, one per sampled base network
    """
    logits = [None] * config['num_models']
    for model_id in range(config['num_models']):
        # sample one set of base-network parameters from the hyper-net
        base_net_params = f_hyper_net.forward()
        logits_temp = f_base_net.forward(x, params=base_net_params)
        logits[model_id] = logits_temp

    return logits

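# Hedged usage sketch (not part of the original module): the list returned by
# `evaluate_fn_vampire` is typically reduced to one predictive distribution by
# averaging the softmax outputs of the sampled models. The episode variables
# `x_query`, `f_hyper_net`, and `f_base_net` are assumptions here.
def _example_average_vampire_logits(x_query, f_hyper_net, f_base_net):
    logits_list = evaluate_fn_vampire(x=x_query, f_hyper_net=f_hyper_net, f_base_net=f_base_net)
    # stack to (num_models, batch_size, num_classes), then average the probabilities
    probs = torch.stack([torch.softmax(logits, dim=1) for logits in logits_list], dim=0)
    mean_probs = probs.mean(dim=0)
    return mean_probs.argmax(dim=1)  # hard prediction per example
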
def evaluate_fn_maml(
    x: torch.Tensor,
    f_hyper_net: higher.patch._MonkeyPatchBase,
    f_base_net: higher.patch._MonkeyPatchBase
) -> torch.Tensor:
    """A base function to evaluate on the validation subset

    Args:
        x: the data in the validation subset
        f_hyper_net: the adapted meta-parameter
        f_base_net: the functional form of the base net

    Return: the logits of the prediction
    """
    base_net_params = f_hyper_net.forward()
    logits = f_base_net.forward(x, params=base_net_params)

    return logits

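# Hedged sketch of how `evaluate_fn_maml` could feed the outer (meta) update:
# compute the query-set loss through the adapted hyper-net and backpropagate to
# the meta-parameters. `x_query`, `y_query`, and `meta_optimizer` are assumed
# to come from the surrounding training loop; they are not defined here.
def _example_maml_outer_step(x_query, y_query, f_hyper_net, f_base_net, meta_optimizer):
    logits = evaluate_fn_maml(x=x_query, f_hyper_net=f_hyper_net, f_base_net=f_base_net)
    meta_loss = torch.nn.functional.cross_entropy(input=logits, target=y_query)

    meta_optimizer.zero_grad()
    meta_loss.backward()  # flows through the inner loop when create_graph=True
    meta_optimizer.step()
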
def predict(
    self,
    x: torch.Tensor,
    f_hyper_net: higher.patch._MonkeyPatchBase,
    f_base_net: higher.patch._MonkeyPatchBase
) -> typing.List[torch.Tensor]:
    """Make prediction

    Args:
        x: input data
        f_hyper_net: task-specific meta-model
        f_base_net: functional form of the base neural network

    Returns: a list of logits predicted by the task-specific meta-model
    """
    logits = [None] * self.config['num_models']
    for model_id in range(self.config['num_models']):
        base_net_params = f_hyper_net.forward()
        logits_temp = f_base_net.forward(x, params=base_net_params)
        logits[model_id] = logits_temp

    return logits

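# Hedged sketch (illustrative only): because `predict` returns one logit tensor
# per Monte Carlo sample, the averaged predictive distribution also yields a
# simple per-example uncertainty estimate. `logits_list` is assumed to be the
# output of `predict` on query data.
def _example_predictive_entropy(logits_list):
    probs = torch.stack([torch.softmax(logits, dim=1) for logits in logits_list], dim=0)
    mean_probs = probs.mean(dim=0)
    # entropy of the averaged predictive distribution, one value per example
    return -(mean_probs * torch.log(mean_probs + 1e-12)).sum(dim=1)
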
def adapt_to_episode_innerloop(
    x: torch.Tensor,
    y: torch.Tensor,
    hyper_net: torch.nn.Module,
    f_base_net: higher.patch._MonkeyPatchBase,
    kl_div_fn: typing.Callable
) -> higher.patch._MonkeyPatchBase:
    """Also known as the inner loop

    Args:
        x, y: training data and labels
        hyper_net: the model of interest
        f_base_net: the functional form of the base model
        kl_div_fn: function calculating the KL divergence between the
            variational posterior and the prior

    Return: a MonkeyPatch module
    """
    f_hyper_net = higher.patch.monkeypatch(
        module=hyper_net,
        copy_initial_weights=False,
        track_higher_grads=config['train_flag']
    )
    hyper_net_params = [p for p in hyper_net.parameters()]

    for _ in range(config['num_inner_updates']):
        q_params = f_hyper_net.fast_params  # list of parameters/tensors
        # initialize the accumulator once per inner update, not per Monte Carlo
        # sample, so the gradients of all sampled models are averaged
        grad_accum = [0] * len(q_params)

        # KL divergence between the variational posterior and the prior
        kl_div = kl_div_fn(p=hyper_net_params, q=q_params)

        for _ in range(config['num_models']):
            base_net_params = f_hyper_net.forward()
            y_logits = f_base_net.forward(x, params=base_net_params)
            cls_loss = torch.nn.functional.cross_entropy(input=y_logits, target=y)

            loss = cls_loss + kl_div * config['KL_weight']

            if config['first_order']:
                # retain the graph: the KL term is shared across Monte Carlo samples
                all_grads = torch.autograd.grad(
                    outputs=loss,
                    inputs=q_params,
                    retain_graph=True
                )
            else:
                all_grads = torch.autograd.grad(
                    outputs=loss,
                    inputs=q_params,
                    create_graph=config['train_flag']
                )

            for i in range(len(all_grads)):
                grad_accum[i] = grad_accum[i] + all_grads[i] / config['num_models']

        new_q_params = []
        for param, grad in zip(q_params, grad_accum):
            new_q_params.append(higher.optim._add(tensor=param, a1=-config['inner_lr'], a2=grad))

        f_hyper_net.update_params(new_q_params)

    return f_hyper_net

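# Hedged sketch of a `kl_div_fn` that could be passed to the inner loop above.
# It assumes the hyper-net stores its parameters as alternating (mean, log_std)
# tensors of diagonal Gaussians; the repository's actual KL function and
# parameterization may differ.
def _example_kl_divergence_gaussian(p, q):
    """Closed-form KL(q || p) between two lists of diagonal Gaussians."""
    kl_div = 0.
    for i in range(0, len(p), 2):
        p_mean, p_log_std = p[i], p[i + 1]
        q_mean, q_log_std = q[i], q[i + 1]
        kl_div = kl_div + torch.sum(
            p_log_std - q_log_std
            + (torch.exp(2 * q_log_std) + (q_mean - p_mean) ** 2)
            / (2 * torch.exp(2 * p_log_std))
            - 0.5
        )
    return kl_div
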
def adapt_to_episode(
    self,
    x: torch.Tensor,
    y: torch.Tensor,
    hyper_net: torch.nn.Module,
    f_base_net: higher.patch._MonkeyPatchBase,
    train_flag: bool = True
) -> higher.patch._MonkeyPatchBase:
    """Inner loop for a MAML-like algorithm

    Args:
        x, y: training data and corresponding labels
        hyper_net: the meta-model
        f_base_net: the functional form of the base neural network
        train_flag: if True, track the inner-loop gradients for the meta-update

    Returns: the task-specific meta-model
    """
    # convert hyper_net to its functional form
    f_hyper_net = higher.patch.monkeypatch(
        module=hyper_net,
        copy_initial_weights=False,
        track_higher_grads=train_flag
    )
    hyper_net_params = [p for p in hyper_net.parameters()]

    for _ in range(self.config['num_inner_updates']):
        grads_accum = [0] * len(hyper_net_params)  # accumulate gradients of Monte Carlo sampling
        q_params = f_hyper_net.fast_params  # parameters of the task-specific hyper_net

        # KL divergence between the variational posterior and the prior
        KL_div = self.KL_divergence(p=hyper_net_params, q=q_params)

        for _ in range(self.config['num_models']):
            base_net_params = f_hyper_net.forward()
            y_logits = f_base_net.forward(x, params=base_net_params)
            cls_loss = torch.nn.functional.cross_entropy(input=y_logits, target=y)

            loss = cls_loss + self.config['KL_weight'] * KL_div

            if self.config['first_order']:
                # retain the graph: the KL term is shared across Monte Carlo samples
                grads = torch.autograd.grad(
                    outputs=loss,
                    inputs=q_params,
                    retain_graph=True
                )
            else:
                grads = torch.autograd.grad(
                    outputs=loss,
                    inputs=q_params,
                    create_graph=True
                )

            # accumulate gradients
            for i in range(len(grads)):
                grads_accum[i] = grads_accum[i] + grads[i] / self.config['num_models']

        new_q_params = []
        for param, grad in zip(q_params, grads_accum):
            new_q_params.append(higher.optim._add(tensor=param, a1=-self.config['inner_lr'], a2=grad))

        f_hyper_net.update_params(new_q_params)

    return f_hyper_net

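# Hedged end-to-end sketch of one training episode with the methods above.
# `model` is assumed to be the object exposing `adapt_to_episode` and `predict`
# (with an assumed `model.hyper_net` attribute); `base_net`, the support/query
# tensors, and `meta_optimizer` are assumptions about the surrounding trainer.
def _example_train_episode(model, base_net, x_support, y_support, x_query, y_query, meta_optimizer):
    f_base_net = higher.patch.monkeypatch(
        module=base_net, copy_initial_weights=False, track_higher_grads=False)

    f_hyper_net = model.adapt_to_episode(
        x=x_support, y=y_support, hyper_net=model.hyper_net, f_base_net=f_base_net, train_flag=True)

    # average the query loss over the Monte Carlo sampled models
    logits_list = model.predict(x=x_query, f_hyper_net=f_hyper_net, f_base_net=f_base_net)
    meta_loss = sum(
        torch.nn.functional.cross_entropy(input=logits, target=y_query)
        for logits in logits_list
    ) / len(logits_list)

    meta_optimizer.zero_grad()
    meta_loss.backward()
    meta_optimizer.step()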