def run(self, train_tasks, valid_tasks, test_tasks, input_shape, device):
    """Meta-train a MAML classifier head on top of a shared conv feature body.

    Builds a ConvBase feature extractor plus a linear head wrapped in MAML,
    runs the outer meta-training loop (accumulating gradients over a meta
    batch of tasks), periodically checkpoints, and finally saves the models,
    logs elapsed time, and meta-tests on the unseen test tasks.

    Args:
        train_tasks / valid_tasks / test_tasks: task samplers exposing
            ``.sample()`` that yields an adaptation batch.
        input_shape: input shape passed to ``self.log_model`` for the
            feature extractor (dataset-specific).
        device: torch device the models are moved to.

    Side effects: logs metrics via ``self.log_metrics``, writes checkpoints
    and final models, updates ``self.logger``, and writes logs to file.
    """
    # Create model
    # NOTE(review): `dataset` and `fc_neurons` are free names not defined in
    # this method — presumably module-level globals; verify against the file.
    if dataset == "omni":
        # Omniglot-style input: single channel, no max-pooling.
        features = ConvBase(output_size=64, hidden=32, channels=1, max_pool=False)
    else:
        # RGB input (e.g. mini-ImageNet-style), with max-pooling.
        features = ConvBase(output_size=64, channels=3, max_pool=True)
    # Flatten conv output to (batch, fc_neurons) before the linear head.
    features = torch.nn.Sequential(features, Lambda(lambda x: x.view(-1, fc_neurons)))
    features.to(device)
    head = torch.nn.Linear(fc_neurons, self.params['ways'])
    # Only the head is MAML-wrapped; the feature body is shared across tasks
    # and updated solely by the outer-loop optimizer.
    head = MAML(head, lr=self.params['inner_lr'])
    head.to(device)

    # Setup optimization: one outer-loop optimizer over body + head.
    all_parameters = list(features.parameters()) + list(head.parameters())
    optimizer = torch.optim.Adam(all_parameters, lr=self.params['outer_lr'])
    loss = torch.nn.CrossEntropyLoss(reduction='mean')

    self.log_model(features, device, input_shape=input_shape, name='features')  # Input shape is specific to dataset
    head_input_shape = (self.params['ways'], fc_neurons)
    self.log_model(head, device, input_shape=head_input_shape, name='head')  # Input shape is specific to dataset

    t = trange(self.params['num_iterations'])
    try:
        for iteration in t:
            # Fresh outer-loop gradients each meta-iteration; the inner task
            # loop below accumulates into .grad via eval_loss.backward().
            optimizer.zero_grad()
            meta_train_loss = 0.0
            meta_train_accuracy = 0.0
            meta_valid_loss = 0.0
            meta_valid_accuracy = 0.0
            for task in range(self.params['meta_batch_size']):
                # Compute meta-training loss
                learner = head.clone()
                batch = train_tasks.sample()
                eval_loss, eval_acc = fast_adapt(batch, learner, loss,
                                                 self.params['adapt_steps'],
                                                 self.params['shots'],
                                                 self.params['ways'],
                                                 device, features=features)
                # Accumulate gradients across all tasks in the meta-batch.
                eval_loss.backward()
                meta_train_loss += eval_loss.item()
                meta_train_accuracy += eval_acc.item()

                # Compute meta-validation loss (no backward — monitoring only).
                learner = head.clone()
                batch = valid_tasks.sample()
                eval_loss, eval_acc = fast_adapt(batch, learner, loss,
                                                 self.params['adapt_steps'],
                                                 self.params['shots'],
                                                 self.params['ways'],
                                                 device, features=features)
                meta_valid_loss += eval_loss.item()
                meta_valid_accuracy += eval_acc.item()

            # Per-iteration averages over the meta-batch.
            meta_train_loss = meta_train_loss / self.params['meta_batch_size']
            meta_valid_loss = meta_valid_loss / self.params['meta_batch_size']
            meta_train_accuracy = meta_train_accuracy / self.params['meta_batch_size']
            meta_valid_accuracy = meta_valid_accuracy / self.params['meta_batch_size']
            metrics = {'train_loss': meta_train_loss,
                       'train_acc': meta_train_accuracy,
                       'valid_loss': meta_valid_loss,
                       'valid_acc': meta_valid_accuracy}
            t.set_postfix(metrics)
            self.log_metrics(metrics)

            # Average the accumulated gradients and optimize
            for p in all_parameters:
                p.grad.data.mul_(1.0 / self.params['meta_batch_size'])
            optimizer.step()

            if iteration % self.params['save_every'] == 0:
                # NOTE(review): checkpoints are named iteration + 1 while the
                # save condition tests iteration % save_every == 0 — the
                # sibling run() below names them str(iteration); confirm which
                # naming is intended.
                self.save_model_checkpoint(features, 'features_' + str(iteration + 1))
                self.save_model_checkpoint(head, 'head_' + str(iteration + 1))
    # Support safely manually interrupt training
    except KeyboardInterrupt:
        print('\nManually stopped training! Start evaluation & saving...\n')
        self.logger['manually_stopped'] = True
        # Record how far training actually got before the interrupt.
        self.params['num_iterations'] = iteration

    # Final saves + evaluation run both on normal completion and after an
    # interrupt (the except clause falls through to here).
    self.save_model(features, name='features')
    self.save_model(head, name='head')

    self.logger['elapsed_time'] = str(round(t.format_dict['elapsed'], 2)) + ' sec'

    # Meta-testing on unseen tasks
    self.logger['test_acc'] = evaluate(self.params, test_tasks, head, loss, device, features=features)
    self.log_metrics({'test_acc': self.logger['test_acc']})
    self.save_logs_to_file()
evaluation_error = fast_adapt_est_h(batch, learner, loss_fn, update_step, shots, h_est_force_z) meta_valid_error += evaluation_error.item() # Print some metrics print('\n') print('Iteration', iteration) print('Meta Train Loss', meta_train_error / batch_size) print('Meta Valid Loss', meta_valid_error / batch_size) # Average the accumulated gradients and optimize for p in maml.parameters(): p.grad.data.mul_(1.0 / batch_size) optimizer.step() if iteration % 50 == 0: maml_bf_gains_val = [] scratch_bf_gains_val = [] dft_bf_gains_val = [] for test_iter in range(nval): dataset.change_cluster() sample_idc_train = dataset.sample() x_train = h_concat_scaled[sample_idc_train,:] y_train = egc_gain_scaled[sample_idc_train] # model_maml = maml.module.clone()
def run(self, train_tasks, valid_tasks, test_tasks, model, input_shape, device):
    """Meta-train the given model with MAML, then meta-test on unseen tasks.

    Wraps ``model`` in a second-order MAML learner, runs the outer
    meta-training loop (accumulating gradients over a meta batch of tasks),
    periodically checkpoints, and finally saves the model, logs elapsed
    time, and evaluates on ``test_tasks``.

    Args:
        train_tasks / valid_tasks / test_tasks: task samplers exposing
            ``.sample()`` that yields an adaptation batch.
        model: the torch module to meta-train; saved via
            ``self.save_model(_checkpoint)``.
        input_shape: input shape passed to ``self.log_model``
            (dataset-specific).
        device: torch device the model is moved to.

    Side effects: logs metrics via ``self.log_metrics``, writes checkpoints
    and the final model, updates ``self.logger``, and writes logs to file.
    """
    model.to(device)
    # first_order=False: second-order MAML (gradients flow through the
    # inner-loop adaptation steps).
    maml = MAML(model, lr=self.params['inner_lr'], first_order=False)
    opt = torch.optim.Adam(maml.parameters(), self.params['outer_lr'])
    loss = torch.nn.CrossEntropyLoss(reduction='mean')

    self.log_model(maml, device, input_shape=input_shape)  # Input shape is specific to dataset

    t = trange(self.params['num_iterations'])
    try:
        for iteration in t:
            # Clear the gradients after successfully back-propagating through the whole network
            opt.zero_grad()

            # Initialize iteration's metrics
            meta_train_loss = 0.0
            meta_train_accuracy = 0.0
            meta_valid_loss = 0.0
            meta_valid_accuracy = 0.0

            # Inner (Adaptation) loop
            for task in range(self.params['meta_batch_size']):
                # Compute meta-training loss
                learner = maml.clone()
                batch = train_tasks.sample()
                eval_loss, eval_acc = fast_adapt(batch, learner, loss,
                                                 self.params['adapt_steps'],
                                                 self.params['shots'],
                                                 self.params['ways'],
                                                 device)
                # Calculate the gradients of the now updated parameters of the
                # model using the evaluation loss! Gradients accumulate in
                # maml's .grad across the whole meta-batch.
                eval_loss.backward()
                meta_train_loss += eval_loss.item()
                meta_train_accuracy += eval_acc.item()

                # Compute meta-validation loss (no backward — monitoring only).
                learner = maml.clone()
                batch = valid_tasks.sample()
                eval_loss, eval_acc = fast_adapt(batch, learner, loss,
                                                 self.params['adapt_steps'],
                                                 self.params['shots'],
                                                 self.params['ways'],
                                                 device)
                meta_valid_loss += eval_loss.item()
                meta_valid_accuracy += eval_acc.item()

            # Per-iteration averages over the meta-batch.
            meta_train_loss = meta_train_loss / self.params['meta_batch_size']
            meta_valid_loss = meta_valid_loss / self.params['meta_batch_size']
            meta_train_accuracy = meta_train_accuracy / self.params['meta_batch_size']
            meta_valid_accuracy = meta_valid_accuracy / self.params['meta_batch_size']
            metrics = {'train_loss': meta_train_loss,
                       'train_acc': meta_train_accuracy,
                       'valid_loss': meta_valid_loss,
                       'valid_acc': meta_valid_accuracy}
            t.set_postfix(metrics)
            self.log_metrics(metrics)

            # Average the accumulated gradients and optimize
            for p in maml.parameters():
                p.grad.data.mul_(1.0 / self.params['meta_batch_size'])
            opt.step()

            if iteration % self.params['save_every'] == 0:
                self.save_model_checkpoint(model, str(iteration))
    # Support safely manually interrupt training
    except KeyboardInterrupt:
        print('\nManually stopped training! Start evaluation & saving...\n')
        self.logger['manually_stopped'] = True
        # Record how far training actually got before the interrupt.
        self.params['num_iterations'] = iteration

    # Final save + evaluation run both on normal completion and after an
    # interrupt (the except clause falls through to here).
    self.save_model(model)

    self.logger['elapsed_time'] = str(round(t.format_dict['elapsed'], 2)) + ' sec'

    # Meta-testing on unseen tasks
    self.logger['test_acc'] = evaluate(self.params, test_tasks, maml, loss, device)
    self.log_metrics({'test_acc': self.logger['test_acc']})
    self.save_logs_to_file()