def evaluate_anil(params, features, head, test_tasks, device, features_path, head_path):
    """Load ANIL weights from disk and evaluate the model on ``test_tasks``.

    The weights are loaded in place into the caller's ``features`` and ``head``
    modules, which are also moved to ``device``.

    Args:
        params: Experiment settings; ``params['inner_lr']`` is the MAML
            adaptation learning rate used if ``head`` is not yet MAML-wrapped.
        features: Feature-extractor module that receives ``features_path``.
        head: Classifier head, either a plain module or an already
            ``MAML``-wrapped one; receives ``head_path``.
        test_tasks: Tasks forwarded unchanged to ``evaluate``.
        device: Target device for the modules and the loaded tensors.
        features_path: Path to the saved ``features`` state dict.
        head_path: Path to the saved ``head`` state dict.

    Returns:
        Whatever ``evaluate`` returns for this model.
    """
    # map_location lets checkpoints saved on a GPU be restored on a CPU-only host.
    features.load_state_dict(torch.load(features_path, map_location=device))
    features.to(device)
    head.load_state_dict(torch.load(head_path, map_location=device))
    # run_anil passes a head that is already MAML-wrapped; wrapping it again
    # would nest MAML(MAML(...)), so only wrap a plain module.
    if not isinstance(head, MAML):
        head = MAML(head, lr=params['inner_lr'])
    head.to(device)
    loss = torch.nn.CrossEntropyLoss(reduction='mean')
    return evaluate(params, test_tasks, head, loss, device, features=features)
def _load_anil_checkpoint(features, head, features_path, head_path, device):
    """Load saved state dicts into ``features`` and ``head`` and move both to ``device``."""
    # map_location lets checkpoints saved on a GPU be restored on a CPU-only host.
    features.load_state_dict(torch.load(features_path, map_location=device))
    features.to(device)
    head.load_state_dict(torch.load(head_path, map_location=device))
    head.to(device)


def run_anil(params, test_tasks, device):
    """Build the ANIL model for ``params['dataset']`` and run the enabled evaluations.

    Depending on flags that are read but not assigned here (presumably
    module-level settings defined elsewhere in the file: ``eval_iters``,
    ``meta_test``, ``cl_exp``, ``rep_exp``, plus ``base_path``, ``cl_params``
    and ``rep_params``), this:

    * evaluates every ``*features*.pt`` checkpoint in
      ``<base_path>/model_checkpoints`` together with its matching ``head``
      file and writes the results to ``<base_path>/ckpnt_results.json``;
    * evaluates the final ``features.pt`` / ``head.pt`` weights;
    * runs the continual-learning experiment;
    * runs the representation experiment restricted to the last layer.

    Args:
        params: Experiment settings; reads ``'dataset'``, ``'ways'``,
            ``'shots'`` and ``'inner_lr'``.
        test_tasks: Tasks forwarded to the evaluation / experiment routines.
        device: Device the models are loaded onto.
    """
    # ANIL: a shared feature extractor plus a MAML-adapted linear head.
    if 'omni' == params['dataset']:
        print('Loading Omniglot model')
        fc_neurons = 128
        features = ConvBase(output_size=64, hidden=32, channels=1, max_pool=False)
    else:
        print('Loading Mini-ImageNet model')
        fc_neurons = 1600
        features = ConvBase(output_size=64, channels=3, max_pool=True)
    # Flatten the conv output so the linear head receives (N, fc_neurons).
    features = torch.nn.Sequential(features, Lambda(lambda x: x.view(-1, fc_neurons)))
    head = torch.nn.Linear(fc_neurons, params['ways'])
    head = MAML(head, lr=params['inner_lr'])

    # Evaluate the model at every checkpoint
    if eval_iters:
        ckpnt_dir = os.path.join(base_path, "model_checkpoints")
        model_ckpnt_results = {}
        # The context manager closes the directory handle; sorting gives a
        # deterministic evaluation order.
        with os.scandir(ckpnt_dir) as entries:
            ckpnt_entries = sorted(entries, key=lambda e: e.name)
        for entry in ckpnt_entries:
            # Match on the file name only, so a "features" substring elsewhere
            # in base_path cannot misclassify head files or corrupt head_path.
            if not entry.name.endswith(".pt") or "features" not in entry.name:
                continue
            features_path = entry.path
            head_path = os.path.join(ckpnt_dir, entry.name.replace("features", "head"))
            print(f'Testing {features_path}')
            res = evaluate_anil(params, features, head, test_tasks, device, features_path, head_path)
            model_ckpnt_results[features_path] = res
        with open(os.path.join(base_path, 'ckpnt_results.json'), 'w') as fp:
            json.dump(model_ckpnt_results, fp, sort_keys=True, indent=4)

    final_features = os.path.join(base_path, 'features.pt')
    final_head = os.path.join(base_path, 'head.pt')

    if meta_test:
        evaluate_anil(params, features, head, test_tasks, device, final_features, final_head)

    if cl_exp:
        print("Running Continual Learning experiment...")
        _load_anil_checkpoint(features, head, final_features, final_head, device)
        loss = torch.nn.CrossEntropyLoss(reduction='mean')
        run_cl_exp(base_path, head, loss, test_tasks, device,
                   params['ways'], params['shots'], cl_params=cl_params, features=features)

    if rep_exp:
        _load_anil_checkpoint(features, head, final_features, final_head, device)
        loss = torch.nn.CrossEntropyLoss(reduction='mean')
        # Only check head change. Use a copy so the shared rep_params dict
        # is not mutated as a side effect of this run.
        head_rep_params = dict(rep_params, layers=[-1])
        print("Running Representation experiment...")
        run_rep_exp(base_path, head, loss, test_tasks, device,
                    params['ways'], params['shots'], rep_params=head_rep_params, features=features)