def run_maml(params, test_tasks, device):
    if 'min' == params['dataset']:
        print('Loading Mini-ImageNet model')
        model = MiniImagenetCNN(params['ways'])
    else:
        print('Loading Omniglot model')
        model = OmniglotCNN(params['ways'])

    # Evaluate the model at every checkpoint
    if eval_iters:
        ckpnt = base_path + "/model_checkpoints/"
        model_ckpnt_results = {}
        for model_ckpnt in os.scandir(ckpnt):
            if model_ckpnt.path.endswith(".pt"):
                print(f'Testing {model_ckpnt.path}')
                res = evaluate_maml(params, model, test_tasks, device, model_ckpnt.path)
                model_ckpnt_results[model_ckpnt.path] = res

        with open(base_path + '/ckpnt_results.json', 'w') as fp:
            json.dump(model_ckpnt_results, fp, sort_keys=True, indent=4)

    final_model = base_path + '/model.pt'
    if meta_test:
        evaluate_maml(params, model, test_tasks, device, final_model)

    # Run a Continual Learning experiment
    if cl_exp:
        print("Running Continual Learning experiment...")
        model.load_state_dict(torch.load(final_model))
        model.to(device)

        maml = MAML(model, lr=cl_params['inner_lr'], first_order=False)
        loss = torch.nn.CrossEntropyLoss(reduction='mean')

        run_cl_exp(base_path, maml, loss, test_tasks, device,
                   params['ways'], params['shots'], cl_params=cl_params)

    # Run a Representation change experiment
    if rep_exp:
        model.load_state_dict(torch.load(final_model))
        model.to(device)

        maml = MAML(model, lr=rep_params['inner_lr'], first_order=False)
        loss = torch.nn.CrossEntropyLoss(reduction='mean')

        print("Running Representation experiment...")
        run_rep_exp(base_path, maml, loss, test_tasks, device,
                    params['ways'], params['shots'], rep_params=rep_params)
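# Hypothetical driver sketch for run_maml. The keys mirror the ones run_maml and
# evaluate_maml read (dataset, ways, shots, inner_lr); the values, and the module-level
# flags run_maml also expects to find (eval_iters, meta_test, cl_exp, rep_exp, base_path,
# cl_params, rep_params), are illustrative assumptions rather than the repo's actual
# experiment settings.
def demo_run_maml(test_tasks, device):
    params = dict(dataset='min',  # 'min' selects MiniImagenetCNN, anything else Omniglot
                  ways=5,
                  shots=1,
                  inner_lr=0.5)
    run_maml(params, test_tasks, device)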
def evaluate_anil(params, features, head, test_tasks, device, features_path, head_path):
    features.load_state_dict(torch.load(features_path))
    features.to(device)

    head.load_state_dict(torch.load(head_path))
    head = MAML(head, lr=params['inner_lr'])
    head.to(device)

    loss = torch.nn.CrossEntropyLoss(reduction='mean')
    return evaluate(params, test_tasks, head, loss, device, features=features)
def run(self, env, device):
    baseline = ch.models.robotics.LinearValue(env.state_size, env.action_size)
    policy = DiagNormalPolicy(env.state_size, env.action_size)

    self.log_model(policy, device, input_shape=(1, env.state_size))

    t = trange(self.params['num_iterations'], desc='Iteration', position=0)
    try:
        for iteration in t:
            iter_reward = 0.0

            task_list = env.sample_tasks(self.params['batch_size'])
            for task_i in trange(len(task_list), leave=False, desc='Task', position=0):
                task = task_list[task_i]

                env.set_task(task)
                env.reset()
                task = Runner(env)

                episodes = task.run(policy, episodes=self.params['n_episodes'])
                task_reward = episodes.reward().sum().item() / self.params['n_episodes']
                iter_reward += task_reward

            # Log
            average_return = iter_reward / self.params['batch_size']
            metrics = {'average_return': average_return}

            t.set_postfix(metrics)
            self.log_metrics(metrics)

            if iteration % self.params['save_every'] == 0:
                self.save_model_checkpoint(policy, str(iteration + 1))
                self.save_model_checkpoint(baseline, 'baseline_' + str(iteration + 1))

    # Support safely interrupting training manually
    except KeyboardInterrupt:
        print('\nManually stopped training! Start evaluation & saving...\n')
        self.logger['manually_stopped'] = True
        self.params['num_iterations'] = iteration

    self.save_model(policy)
    self.save_model(baseline, name='baseline')

    self.logger['elapsed_time'] = str(round(t.format_dict['elapsed'], 2)) + ' sec'

    # Evaluate on new test tasks
    policy = MAML(policy, lr=self.params['inner_lr'])
    self.logger['test_reward'] = evaluate_ppo(env_name, policy, baseline, self.params)
    self.log_metrics({'test_reward': self.logger['test_reward']})
    self.save_logs_to_file()
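# evaluate_ppo (called at the end of run above) is assumed to fast-adapt the MAML-wrapped
# policy on each held-out task before measuring its return. A minimal sketch of that
# adaptation pattern with learn2learn's clone()/adapt() API; `inner_loss_fn` is a
# placeholder argument standing in for whatever policy-gradient loss the actual
# implementation computes from the collected episodes.
def adapt_policy_on_task(maml_policy, env, inner_loss_fn, adapt_steps, n_episodes):
    learner = maml_policy.clone()  # differentiable copy, so adaptation stays in the graph
    task = Runner(env)
    for _ in range(adapt_steps):
        episodes = task.run(learner, episodes=n_episodes)
        adaptation_loss = inner_loss_fn(episodes, learner)  # assumption: caller supplies the loss
        learner.adapt(adaptation_loss)  # one inner-loop gradient step on the clone
    return learner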
def evaluate_maml(params, model, test_tasks, device, path):
    model.load_state_dict(torch.load(path))
    model.to(device)

    maml = MAML(model, lr=params['inner_lr'], first_order=False)
    loss = torch.nn.CrossEntropyLoss(reduction='mean')

    return evaluate(params, test_tasks, maml, loss, device)
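# Both evaluate_maml and evaluate_anil defer to an `evaluate` helper that is not shown
# in this section. Below is a minimal sketch of what such a helper could look like,
# assuming each sampled task interleaves two samples per class so that every other
# sample forms the adaptation (support) set; the `num_test_tasks` key and the 50/50
# support/query split are assumptions, not the repo's actual implementation.
def evaluate_sketch(params, test_tasks, learner_wrapper, loss, device, features=None):
    meta_test_accuracy = 0.0
    n_tasks = params.get('num_test_tasks', 100)  # assumed parameter name and default
    for _ in range(n_tasks):
        learner = learner_wrapper.clone()  # per-task copy of the MAML-wrapped module
        data, labels = test_tasks.sample()
        data, labels = data.to(device), labels.to(device)
        if features is not None:
            data = features(data)  # ANIL: frozen feature extractor, adapt only the head

        # Mark every other sample as adaptation data, the rest as evaluation data
        adapt_mask = torch.zeros(data.size(0), dtype=torch.bool, device=data.device)
        adapt_mask[torch.arange(params['shots'] * params['ways']) * 2] = True
        eval_mask = ~adapt_mask

        for _ in range(params['adapt_steps']):
            adapt_error = loss(learner(data[adapt_mask]), labels[adapt_mask])
            learner.adapt(adapt_error)  # inner-loop gradient step

        predictions = learner(data[eval_mask])
        task_accuracy = (predictions.argmax(dim=1) == labels[eval_mask]).float().mean()
        meta_test_accuracy += task_accuracy.item()
    return meta_test_accuracy / n_tasks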
def run_anil(params, test_tasks, device):
    # ANIL
    if 'omni' == params['dataset']:
        print('Loading Omniglot model')
        fc_neurons = 128
        features = ConvBase(output_size=64, hidden=32, channels=1, max_pool=False)
    else:
        print('Loading Mini-ImageNet model')
        fc_neurons = 1600
        features = ConvBase(output_size=64, channels=3, max_pool=True)

    features = torch.nn.Sequential(features, Lambda(lambda x: x.view(-1, fc_neurons)))
    head = torch.nn.Linear(fc_neurons, params['ways'])
    head = MAML(head, lr=params['inner_lr'])

    # Evaluate the model at every checkpoint
    if eval_iters:
        ckpnt = base_path + "/model_checkpoints/"
        model_ckpnt_results = {}
        for model_ckpnt in os.scandir(ckpnt):
            if model_ckpnt.path.endswith(".pt"):
                if "features" in model_ckpnt.path:
                    features_path = model_ckpnt.path
                    head_path = str.replace(features_path, "features", "head")

                    print(f'Testing {model_ckpnt.path}')
                    res = evaluate_anil(params, features, head, test_tasks, device,
                                        features_path, head_path)
                    model_ckpnt_results[model_ckpnt.path] = res

        with open(base_path + '/ckpnt_results.json', 'w') as fp:
            json.dump(model_ckpnt_results, fp, sort_keys=True, indent=4)

    final_features = base_path + '/features.pt'
    final_head = base_path + '/head.pt'

    if meta_test:
        evaluate_anil(params, features, head, test_tasks, device, final_features, final_head)

    if cl_exp:
        print("Running Continual Learning experiment...")
        features.load_state_dict(torch.load(final_features))
        features.to(device)

        head.load_state_dict(torch.load(final_head))
        head.to(device)

        loss = torch.nn.CrossEntropyLoss(reduction='mean')

        run_cl_exp(base_path, head, loss, test_tasks, device,
                   params['ways'], params['shots'], cl_params=cl_params, features=features)

    if rep_exp:
        features.load_state_dict(torch.load(final_features))
        features.to(device)

        head.load_state_dict(torch.load(final_head))
        head.to(device)

        loss = torch.nn.CrossEntropyLoss(reduction='mean')

        # Only check head change
        rep_params['layers'] = [-1]

        print("Running Representation experiment...")
        run_rep_exp(base_path, head, loss, test_tasks, device,
                    params['ways'], params['shots'], rep_params=rep_params, features=features)
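# Quick sanity check for the ANIL feature extractor built above: the flattened ConvBase
# output must match fc_neurons (128 for Omniglot, 1600 for Mini-ImageNet). The input
# resolutions implied here (28x28 Omniglot, 84x84 Mini-ImageNet) are assumptions about
# how the datasets are preprocessed; adjust them if the pipeline resizes differently.
def check_anil_feature_shape(features, fc_neurons, image_size, channels):
    dummy = torch.zeros(2, channels, image_size, image_size)  # batch of two blank images
    out = features(dummy)
    assert out.shape == (2, fc_neurons), f'expected (2, {fc_neurons}), got {tuple(out.shape)}'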
def run(self, train_tasks, valid_tasks, test_tasks, model, input_shape, device):
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), self.params['lr'])
    loss = torch.nn.CrossEntropyLoss(reduction='mean')

    self.log_model(model, device, input_shape=input_shape)  # Input shape is specific to dataset

    n_batch_iter = int(320 / self.params['meta_batch_size'])

    t = trange(self.params['num_iterations'])
    try:
        for iteration in t:
            # Initialize iteration's metrics
            train_loss = 0.0
            train_accuracy = 0.0
            valid_loss = 0.0
            valid_accuracy = 0.0

            for task in range(n_batch_iter):
                data, labels = train_tasks.sample()
                data, labels = data.to(device), labels.to(device)

                optimizer.zero_grad()
                predictions = model(data)
                batch_loss = loss(predictions, labels)
                batch_accuracy = accuracy(predictions, labels)
                batch_loss.backward()
                optimizer.step()

                train_loss += batch_loss.item()
                train_accuracy += batch_accuracy.item()

            if valid_tasks is not None:
                with torch.no_grad():
                    for task in range(n_batch_iter):
                        valid_data, valid_labels = valid_tasks.sample()
                        valid_data, valid_labels = valid_data.to(device), valid_labels.to(device)

                        predictions = model(valid_data)
                        valid_loss += loss(predictions, valid_labels).item()
                        valid_accuracy += accuracy(predictions, valid_labels).item()

            train_loss = train_loss / n_batch_iter
            valid_loss = valid_loss / n_batch_iter
            train_accuracy = train_accuracy / n_batch_iter
            valid_accuracy = valid_accuracy / n_batch_iter

            metrics = {'train_loss': train_loss,
                       'train_acc': train_accuracy,
                       'valid_loss': valid_loss,
                       'valid_acc': valid_accuracy}
            t.set_postfix(metrics)
            self.log_metrics(metrics)

            if iteration % self.params['save_every'] == 0:
                self.save_model_checkpoint(model, str(iteration))

    # Support safely interrupting training manually
    except KeyboardInterrupt:
        print('\nManually stopped training! Start evaluation & saving...\n')
        self.logger['manually_stopped'] = True
        self.params['num_iterations'] = iteration

    self.save_model(model)

    self.logger['elapsed_time'] = str(round(t.format_dict['elapsed'], 2)) + ' sec'

    # Testing on unseen tasks
    model = MAML(model, lr=self.params['lr'])
    self.params['adapt_steps'] = 1  # Placeholder value for evaluation
    self.logger['test_acc'] = evaluate(self.params, test_tasks, model, loss, device)
    self.log_metrics({'test_acc': self.logger['test_acc']})
    self.save_logs_to_file()
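# The `accuracy` helper used in the training loop above is not shown in this section.
# A minimal sketch of the usual implementation: fraction of argmax predictions that
# match the labels, returned as a tensor so callers can take .item() on it.
def accuracy_sketch(predictions, labels):
    predictions = predictions.argmax(dim=1)  # predicted class index per sample
    return (predictions == labels).float().mean()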