def inner_loop_adapt(self, task, num_updates=None, analysis=False, iter=None):
    # "adapt" means running the complete inner-loop update
    measurements_trajectory = defaultdict(list)
    if analysis:
        grad_norm_by_step = []  # records the gradient norm at every inner-loop step
        grad_quantiles_by_step = defaultdict(list)

    # parameters adapted in the inner loop: only the classifier head
    adapted_param_dict = OrderedDict()
    adapted_param_dict['classifier.fully_connected.weight'] = \
        self._model.classifier.fully_connected.weight
    adapted_param_dict['classifier.fully_connected.bias'] = \
        self._model.classifier.fully_connected.bias

    modulation = self._embedding_model(task, return_task_embedding=False)

    if num_updates is None:
        # if num_updates is not specified, apply the inner-loop update
        # self._num_updates times
        num_updates = self._num_updates

    for i in range(num_updates):
        # here the model is just a functional template;
        # all of the parameters are passed in through params and embeddings
        adapted_param_dict, measurements, grad_list = \
            self.inner_loop_one_step_gradient_descent(
                task=task,
                adapted_param_dict=adapted_param_dict,
                modulation=modulation,
                return_grad_list=analysis)

        # add this step's measurements to the trajectory
        for key in measurements.keys():
            measurements_trajectory[key].append(measurements[key])

        if analysis:
            grad_norm_by_step.append(get_grad_norm(grad_list))
            grad_quantiles_by_step[i + 1].extend(get_grad_quantiles(grad_list))

    with torch.no_grad():  # compute the train loss after the last adaptation step
        preds = self._model(task.x, modulation=modulation,
                            update_params=adapted_param_dict)
        loss = self._inner_loss_func(preds, task.y)
        measurements_trajectory['loss'].append(loss.item())
        if self.is_classification:
            measurements_trajectory['accu'].append(accuracy(preds, task.y))

    info_dict = None
    if analysis:
        info_dict = {}
        info_dict['grad_norm_by_step'] = grad_norm_by_step
        info_dict['grad_quantiles_by_step'] = grad_quantiles_by_step
        info_dict['modulation'] = modulation
    return adapted_param_dict, measurements_trajectory, info_dict
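
# The analysis branch above relies on two gradient-diagnostics helpers,
# get_grad_norm and get_grad_quantiles, defined elsewhere in the repo. Below is
# a minimal, self-contained sketch of what such helpers might look like; the
# names (suffixed with _sketch), the global-L2-norm convention, and the choice
# of quantiles are assumptions for illustration, not the repo's actual code.
import torch

def get_grad_norm_sketch(grad_list):
    # global L2 norm over a list of gradient tensors
    return torch.sqrt(sum((g.detach() ** 2).sum() for g in grad_list)).item()

def get_grad_quantiles_sketch(grad_list, quantiles=(0.25, 0.5, 0.75)):
    # quantiles of the absolute gradient entries, flattened across all tensors,
    # returned as a plain list so it can be extend()-ed into a trajectory
    flat = torch.cat([g.detach().abs().reshape(-1) for g in grad_list])
    return [torch.quantile(flat, q).item() for q in quantiles]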
def inner_loop_adapt(self, task, num_updates=None, analysis=False, iter=None):
    # "adapt" means running the complete inner-loop update
    measurements_trajectory = defaultdict(list)
    if analysis:
        grad_norm_by_step = []  # records the gradient norm at every inner-loop step
        grad_quantiles_by_step = defaultdict(list)

    # parameters to be updated in the inner loop
    adapted_param_dict = self._model.param_dict

    if isinstance(self._embedding_model, LSTMAttentionEmbeddingModel):
        layer_modulations = self._embedding_model(
            task, return_task_embedding=False, iter=iter)
    else:
        layer_modulations = self._embedding_model(
            task, return_task_embedding=False)

    if num_updates is None:
        # apply the inner-loop update self._num_updates times
        num_updates = self._num_updates

    for i in range(num_updates):
        # here the model is just a functional template;
        # all of the parameters are passed in through params and embeddings
        adapted_param_dict, measurements, grad_list = \
            self.inner_loop_one_step_gradient_descent(
                task=task,
                layer_modulations=layer_modulations,
                param_dict=adapted_param_dict,
                return_grad_list=analysis)

        # add this step's measurements to the trajectory
        for key in measurements.keys():
            measurements_trajectory[key].append(measurements[key])

        if analysis:
            grad_norm_by_step.append(get_grad_norm(grad_list))
            grad_quantiles_by_step[i + 1].extend(get_grad_quantiles(grad_list))

    with torch.no_grad():  # compute the train loss after the last adaptation step
        preds = self._model(task.x, params=adapted_param_dict,
                            layer_modulations=layer_modulations)
        loss = self._inner_loss_func(preds, task.y)
        measurements_trajectory['loss'].append(loss.item())
        if self.is_classification:
            measurements_trajectory['accu'].append(accuracy(preds, task.y))

    info_dict = None
    if analysis:
        info_dict = {}
        info_dict['grad_norm_by_step'] = grad_norm_by_step
        info_dict['grad_quantiles_by_step'] = grad_quantiles_by_step
        info_dict['layer_modulations'] = layer_modulations
    return adapted_param_dict, measurements_trajectory, info_dict
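
# Hedged usage sketch (not part of the repo): how the adapted parameters
# returned by the second inner_loop_adapt variant above might be used to score
# a task's query set. The function name, `algorithm` (an instance of this
# class), and the (train_task, test_task) pair are placeholders; the forward
# signature (params=, layer_modulations=) matches the variant directly above.
# analysis=True is requested only so info_dict carries the layer modulations.
def evaluate_adapted_sketch(algorithm, train_task, test_task, num_updates=5):
    adapted_params, trajectory, info_dict = algorithm.inner_loop_adapt(
        train_task, num_updates=num_updates, analysis=True)
    with torch.no_grad():
        preds = algorithm._model(test_task.x, params=adapted_params,
                                 layer_modulations=info_dict['layer_modulations'])
        query_loss = algorithm._inner_loss_func(preds, test_task.y)
    return query_loss.item(), trajectory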
def inner_loop_one_step_gradient_descent(self, task, adapted_param_dict, modulation, return_grad_list=False):
    """Apply one step of gradient descent on self._inner_loss_func, computed on
    the data of the single `task`, with respect to the parameters in
    `adapted_param_dict` and with step size `self._fast_lr`.

    Returns:
        the updated parameters,
        the measurements (loss, and accuracy for classification) before this step,
        the list of gradients if return_grad_list=True, otherwise None.
    """
    preds = self._model(task.x, modulation=modulation,
                        update_params=adapted_param_dict)
    loss = self._inner_loss_func(preds, task.y)

    measurements = {}
    measurements['loss'] = loss.item()
    if self.is_classification:
        measurements['accu'] = accuracy(preds, task.y)

    # keep the graph for second-order gradients unless running first-order MAML
    create_graph = not self._first_order
    grad_list = torch.autograd.grad(loss, adapted_param_dict.values(),
                                    create_graph=create_graph,
                                    allow_unused=False)
    # allow_unused=False (the default): specifying inputs that were not used when
    # computing the output (and therefore whose grad is always zero) is an error.

    clip_grad = (self._inner_loop_grad_clip > 0)
    if clip_grad:
        clip_grad_list = []

    for (name, param), grad in zip(adapted_param_dict.items(), grad_list):
        assert grad is not None  # grad is a torch.Tensor
        if clip_grad:
            grad = soft_clip(grad,
                             clip_value=self._inner_loop_grad_clip,
                             slope=self._inner_loop_soft_clip_slope)
            clip_grad_list.append(grad)
        adapted_param_dict[name] = param - self._fast_lr * grad

    if return_grad_list:
        if clip_grad:
            grad_list = clip_grad_list
    else:
        grad_list = None

    return adapted_param_dict, measurements, grad_list
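
# The update above passes each gradient through soft_clip, which is defined
# elsewhere in the repo. As a rough, hedged sketch of the idea (the actual
# implementation and its exact shape are assumptions): the gradient is left
# unchanged inside [-clip_value, clip_value], and beyond that it grows only
# with a small residual slope, so clipping stays smooth rather than a hard cut.
def soft_clip_sketch(grad, clip_value, slope=0.01):
    excess = torch.clamp(grad.abs() - clip_value, min=0.0)
    return torch.sign(grad) * (torch.clamp(grad.abs(), max=clip_value) + slope * excess)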
def _update_measurements(self, task, loss, preds):
    self._count_iters += 1.0
    self._cum_loss += loss.data.cpu().numpy()
    if self._collect_accuracies:
        self._cum_accuracy += accuracy(preds, task.y).data.cpu().numpy()
def run(self, dataset_iterator, is_training=False, start=1, stop=1):
    # loop through the entire meta-dataset once
    sum_train_measurements_trajectory_over_meta_set = defaultdict(float)
    sum_test_measurements_before_adapt_over_meta_set = defaultdict(float)
    sum_test_measurements_after_adapt_over_meta_set = defaultdict(float)
    n_tasks = 0

    for i, (train_task_batch, test_task_batch) in tqdm(
            enumerate(dataset_iterator, start=start if is_training else 1)):

        if is_training and i == stop:
            return {
                'train_loss_trajectory': divide_measurements(
                    sum_train_measurements_trajectory_over_meta_set, n=n_tasks),
                'test_loss_before': divide_measurements(
                    sum_test_measurements_before_adapt_over_meta_set, n=n_tasks),
                'test_loss_after': divide_measurements(
                    sum_test_measurements_after_adapt_over_meta_set, n=n_tasks)
            }

        # each step of the meta-dataset yields a batch of (train, test) task pairs
        train_measurements_trajectory_over_batch = defaultdict(list)
        test_measurements_before_adapt_over_batch = defaultdict(list)
        test_measurements_after_adapt_over_batch = defaultdict(list)
        analysis = (i % self._log_interval == 0 or i == 1)

        modulation_analysis = hasattr(self._algorithm, '_embedding_model') and \
            isinstance(self._algorithm._embedding_model, LSTMAttentionEmbeddingModel)

        if analysis and is_training:
            grad_norm_by_step_over_batch = []
            grad_quantiles_by_step_over_batch = defaultdict(list)
            if modulation_analysis:
                task_modulation_params_over_batch = []

        batch_size = len(train_task_batch)
        sum_test_loss_after_adapt = 0.0

        for train_task, test_task in zip(train_task_batch, test_task_batch):
            # evaluate the test loss before adapting on the train task
            with torch.no_grad():
                test_pred_before_adapt = self._algorithm.predict_without_adapt(
                    train_task, test_task.x)
                test_loss_before_adapt = self._outer_loss_func(
                    test_pred_before_adapt, test_task.y)
                test_measurements_before_adapt_over_batch['loss'].append(
                    test_loss_before_adapt.item())
                if self._algorithm.is_classification:
                    test_measurements_before_adapt_over_batch['accu'].append(
                        accuracy(test_pred_before_adapt, test_task.y))

            # adapt on train_task
            adapted_param_dict, train_measurements_trajectory, info_dict = \
                self._algorithm.inner_loop_adapt(
                    train_task, analysis=analysis and is_training, iter=i)

            for key, measurements in train_measurements_trajectory.items():
                train_measurements_trajectory_over_batch[key].append(measurements)

            if analysis and is_training:
                grad_norm_by_step = info_dict['grad_norm_by_step']
                grad_quantiles_by_step = info_dict['grad_quantiles_by_step']
                grad_norm_by_step_over_batch.append(grad_norm_by_step)
                for step, quantiles in grad_quantiles_by_step.items():
                    grad_quantiles_by_step_over_batch[step].append(quantiles)
                if modulation_analysis:
                    task_modulation_params = info_dict['layer_modulations']
                    task_modulation_params_over_batch.append(task_modulation_params)

            test_pred_after_adapt = self._algorithm.predict_without_adapt(
                train_task, test_task.x, param_dict=adapted_param_dict)
            test_loss_after_adapt = self._outer_loss_func(
                test_pred_after_adapt, test_task.y)

            sum_test_loss_after_adapt += test_loss_after_adapt
            test_measurements_after_adapt_over_batch['loss'].append(
                test_loss_after_adapt.item())
            if self._algorithm.is_classification:
                test_measurements_after_adapt_over_batch['accu'].append(
                    accuracy(test_pred_after_adapt, test_task.y))

        update_sum_measurements_trajectory(
            sum_train_measurements_trajectory_over_meta_set,
            train_measurements_trajectory_over_batch)
        update_sum_measurements(
            sum_test_measurements_before_adapt_over_meta_set,
            test_measurements_before_adapt_over_batch)
        update_sum_measurements(
            sum_test_measurements_after_adapt_over_meta_set,
            test_measurements_after_adapt_over_batch)
        n_tasks += batch_size

        if is_training:
            # average the post-adaptation test loss over the task batch
            avg_test_loss_after_adapt = sum_test_loss_after_adapt / batch_size
            self._outer_optimizer.zero_grad()
            # backprop propagates all the way to the initialization parameters
            avg_test_loss_after_adapt.backward()
            outer_grad_norm_before_clip = get_grad_norm_from_parameters(
                self._algorithm._model.parameters())
            self._writer.add_scalar('outer_grad/model_norm/before_clip',
                                    outer_grad_norm_before_clip, i)
            if self._outer_loop_grad_norm > 0.:
                clip_grad_norm_(self._algorithm._model.parameters(),
                                self._outer_loop_grad_norm)
                # clip_grad_norm_(self._algorithm._embedding_model.parameters(),
                #                 self._outer_loop_grad_norm)
            self._outer_optimizer.step()

        # logging (every self._log_interval iterations and at iteration 1)
        if analysis and is_training:
            self.log_output(i,
                            train_measurements_trajectory_over_batch,
                            test_measurements_before_adapt_over_batch,
                            test_measurements_after_adapt_over_batch,
                            write_tensorboard=is_training)
            self.write_gradient_info_to_board(
                i, grad_norm_by_step_over_batch,
                grad_quantiles_by_step_over_batch)
            if modulation_analysis:
                metadata = [str(t.task_info['task_id']) for t in train_task_batch]
                self.write_embeddings_output_to_board(
                    task_modulation_params_over_batch, metadata, i)

        # save a model checkpoint
        if (i % self._save_interval == 0 or i == 1) and is_training:
            save_name = 'maml_{0}_{1}.pt'.format(self._model_type, i)
            save_path = os.path.join(self._save_folder, save_name)
            with open(save_path, 'wb') as f:
                torch.save(self._algorithm.state_dict(), f)

    results = {
        'train_loss_trajectory': divide_measurements(
            sum_train_measurements_trajectory_over_meta_set, n=n_tasks),
        'test_loss_before': divide_measurements(
            sum_test_measurements_before_adapt_over_meta_set, n=n_tasks),
        'test_loss_after': divide_measurements(
            sum_test_measurements_after_adapt_over_meta_set, n=n_tasks)
    }

    if not is_training:
        self.log_output(start,
                        results['train_loss_trajectory'],
                        results['test_loss_before'],
                        results['test_loss_after'],
                        write_tensorboard=True,
                        meta_val=True)

    return results
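
# Hedged usage sketch (not part of the repo): a minimal driver for the run()
# method above, alternating meta-training and meta-validation. The function
# name and the `trainer`, `train_iterator`, and `val_iterator` arguments are
# placeholders; the keyword arguments match the signature of run() as written,
# where training returns early once the outer iteration counter reaches `stop`.
def meta_train_sketch(trainer, train_iterator, val_iterator,
                      total_iters=10000, val_every=500):
    start = 1
    val_results = None
    while start < total_iters:
        stop = min(start + val_every, total_iters)
        # meta-train for val_every outer iterations (run() returns at i == stop)
        train_results = trainer.run(train_iterator, is_training=True,
                                    start=start, stop=stop)
        # one full pass over the meta-validation set, with no outer updates
        val_results = trainer.run(val_iterator, is_training=False, start=stop)
        print(stop, train_results['test_loss_after'], val_results['test_loss_after'])
        start = stop
    return val_results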