# NOTE: shared imports for the entry points in this file. The module paths are
# assumptions based on the surrounding repo layout (datasets/, model/,
# experiment/, replay utilities) and may need adjusting.
import logging
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from tqdm import tqdm

import datasets.datasetfactory as df
import datasets.miniimagenet as imgnet
import datasets.task_sampler as ts
import model.learner as learner
import model.modelfactory as mf
import replay as rep
import utils
from experiment.experiment import experiment
from model.meta_learner import MetaLearingClassification
from model.oja_meta_learner import OjaMetaLearingClassification

logger = logging.getLogger('experiment')


def main(args):
    # Seed everything for reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    my_experiment = experiment(args.name, args, "./results/")

    # The Mini-ImageNet meta-training split has 64 classes
    args.classes = list(range(64))
    # args.traj_classes = list(range(int(64 / 2), 963))

    dataset = imgnet.MiniImagenet(args.dataset_path, mode='train')
    dataset_test = imgnet.MiniImagenet(args.dataset_path, mode='test')

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test, batch_size=5, shuffle=True, num_workers=1)
    iterator = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True, num_workers=1)

    config = mf.ModelFactory.get_model("na", args.dataset)
    maml = learner.Learner(config).to(device)
    opt = torch.optim.Adam(maml.parameters(), lr=args.lr)

    for e in range(args.epoch):
        correct = 0
        # Decay the learning rate once, at epoch 50 (moved out of the batch
        # loop so the optimizer is not rebuilt on every iteration)
        if e == 50:
            opt = torch.optim.Adam(maml.parameters(), lr=0.00001)
            logger.info("Changing LR from %f to %f", args.lr, 0.00001)
        for img, y in tqdm(iterator):
            img = img.to(device)
            y = y.to(device)

            pred = maml(img)
            feature = maml(img, feature=True)
            loss_rep = torch.abs(feature).sum()  # L1 norm of the representation (logging only)

            opt.zero_grad()
            loss = F.cross_entropy(pred, y)
            # loss_rep.backward(retain_graph=True)
            # logger.info("L1 norm = %s", str(loss_rep.item()))
            loss.backward()
            opt.step()
            correct += (pred.argmax(1) == y).sum().float() / len(y)

        logger.info("Accuracy at epoch %d = %s", e, str(correct / len(iterator)))

        # (disabled) per-epoch evaluation on the test iterator:
        # correct = 0
        # with torch.no_grad():
        #     for img, y in tqdm(iterator_test):
        #         img = img.to(device)
        #         y = y.to(device)
        #         pred = maml(img)
        #         correct += (pred.argmax(1) == y).sum().float() / len(y)
        # logger.info("Accuracy Test at epoch %d = %s", e, str(correct / len(iterator_test)))

        torch.save(maml, my_experiment.path + "baseline_pretraining_imagenet.net")
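# Example invocation sketch for the pretraining entry point above. The flag
# names mirror the `args` attributes the function reads (seed, name, dataset,
# dataset_path, lr, epoch); the repo presumably defines its real parser
# elsewhere, so treat this as an illustrative assumption, not the canonical CLI.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description="Baseline Mini-ImageNet pretraining")
    parser.add_argument('--seed', type=int, default=90)
    parser.add_argument('--name', type=str, default='baseline_pretraining')
    parser.add_argument('--dataset', type=str, default='imagenet')
    parser.add_argument('--dataset_path', type=str, required=True)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--epoch', type=int, default=60)
    main(parser.parse_args())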
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    my_experiment = experiment(args.name, args, "/data5/jlindsey/continual/results",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    logger = logging.getLogger('experiment')

    args.classes = list(range(963))
    logger.info("dataset = %s", args.dataset)
    if args.dataset != "imagenet":
        dataset = df.DatasetFactory.get_dataset(args.dataset, background=True, train=True, all=True)
        dataset_test = df.DatasetFactory.get_dataset(args.dataset, background=True, train=False, all=True)
    else:
        args.classes = list(range(64))
        dataset = imgnet.MiniImagenet(args.imagenet_path, mode='train')
        dataset_test = imgnet.MiniImagenet(args.imagenet_path, mode='test')

    iterator_test = torch.utils.data.DataLoader(dataset_test, batch_size=5, shuffle=True, num_workers=1)
    iterator_train = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=True, num_workers=1)

    logger.info("Train set length = %d", len(iterator_train) * 5)
    logger.info("Test set length = %d", len(iterator_test) * 5)

    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes, dataset, dataset_test)

    config = mf.ModelFactory.get_model(args.model_type, args.dataset, width=args.width,
                                       num_extra_dense_layers=args.num_extra_dense_layers)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    if args.oja or args.hebb:
        maml = OjaMetaLearingClassification(args, config).to(device)
    else:
        maml = MetaLearingClassification(args, config).to(device)

    if args.from_saved:
        maml.net = torch.load(args.model)

        if args.use_derivative:
            maml.net.use_derivative = True

        maml.net.optimize_out = args.optimize_out
        if maml.net.optimize_out:
            maml.net.feedback_strength_vars.append(
                torch.nn.Parameter(maml.net.init_feedback_strength * torch.ones(1).cuda()))

        if args.reset_feedback_strength:
            for fv in maml.net.feedback_strength_vars:
                w = nn.Parameter(torch.ones_like(fv) * args.feedback_strength)
                fv.data = w

        if args.reset_feedback_vars:
            # Rebuild the feedback pathway and plasticity parameters from scratch
            logger.debug("num_feedback_layers = %d", maml.net.num_feedback_layers)
            maml.net.feedback_vars = nn.ParameterList()
            maml.net.feedback_vars_bundled = []
            maml.net.vars_plasticity = nn.ParameterList()
            maml.net.plasticity = nn.ParameterList()
            maml.net.neuron_plasticity = nn.ParameterList()
            maml.net.layer_plasticity = nn.ParameterList()

            starting_width = 84
            cur_width = starting_width
            num_outputs = maml.net.config[-1][1][0]

            for i, (name, param) in enumerate(maml.net.config):
                logger.debug("config layer %d: %s %s", i, name, str(param))
                if name == 'conv2d':
                    stride = param[4]
                    padding = param[5]
                    # Conv output width: floor((W + 2P - K) / S) + 1
                    cur_width = (cur_width + 2 * padding - param[3] + stride) // stride

                    maml.net.vars_plasticity.append(nn.Parameter(torch.ones(*param[:4]).cuda()))
                    maml.net.vars_plasticity.append(nn.Parameter(torch.ones(param[0]).cuda()))

                    maml.net.plasticity.append(nn.Parameter(
                        maml.net.init_plasticity
                        * torch.ones(param[0], param[1] * param[2] * param[3]).cuda()))  # not implemented
                    maml.net.neuron_plasticity.append(
                        nn.Parameter(torch.zeros(1).cuda()))  # not implemented
                    maml.net.layer_plasticity.append(
                        nn.Parameter(maml.net.init_plasticity * torch.ones(1).cuda()))  # not implemented

                    feedback_var = []
                    for fl in range(maml.net.num_feedback_layers):
                        in_dim = maml.net.width
                        out_dim = maml.net.width
                        if fl == maml.net.num_feedback_layers - 1:
                            out_dim = param[0] * cur_width * cur_width
                        if fl == 0:
                            in_dim = num_outputs
                        feedback_w_shape = [out_dim, in_dim]
                        feedback_w = nn.Parameter(torch.ones(feedback_w_shape).cuda())
                        feedback_b = nn.Parameter(torch.zeros(out_dim).cuda())
                        torch.nn.init.kaiming_normal_(feedback_w)
                        feedback_var.append((feedback_w, feedback_b))
                        maml.net.feedback_vars.append(feedback_w)
                        maml.net.feedback_vars.append(feedback_b)

                    # maml.net.feedback_vars_bundled.append(feedback_var)
                    maml.net.feedback_vars_bundled.append(
                        nn.Parameter(torch.zeros(1)))  # weight feedback -- not implemented
                    maml.net.feedback_vars_bundled.append(
                        nn.Parameter(torch.zeros(1)))  # bias feedback -- not implemented

                elif name == 'linear':
                    maml.net.vars_plasticity.append(nn.Parameter(torch.ones(*param).cuda()))
                    maml.net.vars_plasticity.append(nn.Parameter(torch.ones(param[0]).cuda()))

                    maml.net.plasticity.append(
                        nn.Parameter(maml.net.init_plasticity * torch.ones(*param).cuda()))
                    maml.net.neuron_plasticity.append(
                        nn.Parameter(maml.net.init_plasticity * torch.ones(param[0]).cuda()))
                    maml.net.layer_plasticity.append(
                        nn.Parameter(maml.net.init_plasticity * torch.ones(1).cuda()))

                    feedback_var = []
                    for fl in range(maml.net.num_feedback_layers):
                        in_dim = maml.net.width
                        out_dim = maml.net.width
                        if fl == maml.net.num_feedback_layers - 1:
                            out_dim = param[0]
                        if fl == 0:
                            in_dim = num_outputs
                        feedback_w_shape = [out_dim, in_dim]
                        feedback_w = nn.Parameter(torch.ones(feedback_w_shape).cuda())
                        feedback_b = nn.Parameter(torch.zeros(out_dim).cuda())
                        torch.nn.init.kaiming_normal_(feedback_w)
                        feedback_var.append((feedback_w, feedback_b))
                        maml.net.feedback_vars.append(feedback_w)
                        maml.net.feedback_vars.append(feedback_b)

                    maml.net.feedback_vars_bundled.append(feedback_var)
                    maml.net.feedback_vars_bundled.append(None)  # bias feedback -- not implemented

        maml.init_stuff(args)
        maml.net.optimize_out = args.optimize_out
        if maml.net.optimize_out:
            maml.net.feedback_strength_vars.append(
                torch.nn.Parameter(maml.net.init_feedback_strength * torch.ones(1).cuda()))

    # NOTE (original author): this block was recently un-indented down to the
    # maml.init_opt() call; if things stop working, try re-indenting it.
    if args.zero_non_output_plasticity:
        for index in range(len(maml.net.vars_plasticity) - 2):
            maml.net.vars_plasticity[index] = torch.nn.Parameter(maml.net.vars_plasticity[index] * 0)
        if args.oja or args.hebb:
            for index in range(len(maml.net.plasticity) - 1):
                if args.plasticity_rank1:
                    maml.net.plasticity[index] = torch.nn.Parameter(torch.zeros(1).cuda())
                else:
                    maml.net.plasticity[index] = torch.nn.Parameter(maml.net.plasticity[index] * 0)
                maml.net.layer_plasticity[index] = torch.nn.Parameter(maml.net.layer_plasticity[index] * 0)
                maml.net.neuron_plasticity[index] = torch.nn.Parameter(maml.net.neuron_plasticity[index] * 0)
        if args.oja or args.hebb:
            for index in range(len(maml.net.vars_plasticity) - 2):
                maml.net.vars_plasticity[index] = torch.nn.Parameter(maml.net.vars_plasticity[index] * 0)

    if args.zero_all_plasticity:
        logger.info("Zeroing all plasticity")
        for index in range(len(maml.net.vars_plasticity)):
            maml.net.vars_plasticity[index] = torch.nn.Parameter(maml.net.vars_plasticity[index] * 0)
        for index in range(len(maml.net.plasticity)):
            if args.plasticity_rank1:
                maml.net.plasticity[index] = torch.nn.Parameter(torch.zeros(1).cuda())
            else:
                maml.net.plasticity[index] = torch.nn.Parameter(maml.net.plasticity[index] * 0)
            maml.net.layer_plasticity[index] = torch.nn.Parameter(maml.net.layer_plasticity[index] * 0)
            maml.net.neuron_plasticity[index] = torch.nn.Parameter(maml.net.neuron_plasticity[index] * 0)

    logger.debug("feedback_vars = %s", str(maml.net.feedback_vars))
    maml.init_opt()

    for name, param in maml.named_parameters():
        param.learn = True
    for name, param in maml.net.named_parameters():
        param.learn = True

    if args.freeze_out_plasticity:
        maml.net.plasticity[-1].requires_grad = False

    # Two vars (weight + bias) per layer: 6 conv + 2 dense + any extra dense layers
    total_ff_vars = 2 * (6 + 2 + args.num_extra_dense_layers)

    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("net.vars." + str(temp))
    for temp in range(args.rln_end * 2):
        frozen_layers.append("net.vars." + str(total_ff_vars - 1 - temp))

    for name, param in maml.named_parameters():
        if name in frozen_layers:
            logger.info("RLN layer %s", str(name))
            param.learn = False

    # Update the classifier
    list_of_params = list(filter(lambda x: x.learn, maml.parameters()))
    list_of_names = list(filter(lambda x: x[1].learn, maml.named_parameters()))
    for a in list_of_names:
        logger.info("TLN layer = %s", a[0])

    for step in range(args.steps):
        # Sample a trajectory of classes for this meta-step
        t1 = np.random.choice(args.classes, args.tasks, replace=False)

        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))

        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators, d_rand_iterator, steps=args.update_step, iid=args.iid)

        # Remap the sampled class ids, in order of first appearance in the
        # support set, onto a random permutation of [0, args.tasks)
        perm = np.random.permutation(args.tasks)
        old = []
        for i in range(y_spt.size()[0]):
            num = int(y_spt[i].cpu().numpy())
            if num not in old:
                old.append(num)
            y_spt[i] = torch.tensor(perm[old.index(num)])
        for i in range(y_qry.size()[1]):
            num = int(y_qry[0][i].cpu().numpy())
            y_qry[0][i] = torch.tensor(perm[old.index(num)])

        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(), x_qry.cuda(), y_qry.cuda()

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
        logger.info('step: %d \t training acc %s', step, str(accs))

        if step % 300 == 0:
            correct = 0
            torch.save(maml.net, my_experiment.path + "learner.model")
            for img, target in iterator_test:
                with torch.no_grad():
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml.net(img, vars=None, bn_training=False, feature=False)
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct += torch.eq(pred_q, target).sum().item() / len(img)
            writer.add_scalar('/metatrain/test/classifier/accuracy', correct / len(iterator_test), step)
            logger.info("Test Accuracy = %s", str(correct / len(iterator_test)))

            correct = 0
            for img, target in iterator_train:
                with torch.no_grad():
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml.net(img, vars=None, bn_training=False, feature=False)
                    pred_q = logits_q.argmax(dim=1)
                    correct += torch.eq(pred_q, target).sum().item() / len(img)
            logger.info("Train Accuracy = %s", str(correct / len(iterator_train)))
            writer.add_scalar('/metatrain/train/classifier/accuracy', correct / len(iterator_train), step)
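# The support/query label remapping inside the step loop above is easy to get
# wrong, so the same logic is factored into a standalone helper below as a
# reference sketch (the training loop keeps its inline version). Class ids are
# remapped, in order of first appearance in the support set, onto a random
# permutation of task ids. Assumes the same tensor shapes as above:
# y_spt is 1-D, y_qry has shape [1, N].
def remap_labels(y_spt, y_qry, num_tasks):
    perm = np.random.permutation(num_tasks)
    first_seen = []  # class ids in order of first appearance in the support set
    for i in range(y_spt.size(0)):
        num = int(y_spt[i])
        if num not in first_seen:
            first_seen.append(num)
        y_spt[i] = int(perm[first_seen.index(num)])
    for i in range(y_qry.size(1)):
        num = int(y_qry[0][i])
        y_qry[0][i] = int(perm[first_seen.index(num)])
    return y_spt, y_qry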
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    my_experiment = experiment(args.name, args, "../results/", args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')
    logger.setLevel(logging.INFO)

    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("vars." + str(temp))
    logger.info("Frozen layers = %s", " ".join(frozen_layers))

    final_results_all = []
    total_classes = args.schedule

    for tot_class in total_classes:
        lr_list = [0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001, 0.000003]
        for aoo in range(0, args.runs):
            keep = np.random.choice(list(range(20)), tot_class, replace=False)

            dataset = imgnet.MiniImagenet(args.dataset_path, mode='test', elem_per_class=30,
                                          classes=keep, seed=aoo)
            dataset_test = imgnet.MiniImagenet(args.dataset_path, mode='test', elem_per_class=30,
                                               test=args.test, classes=keep, seed=aoo)

            # Iterators used for evaluation
            iterator = torch.utils.data.DataLoader(dataset_test, batch_size=128, shuffle=True, num_workers=1)
            iterator_sorted = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=args.iid, num_workers=1)

            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')

            results_mem_size = {}

            for mem_size in [args.memory]:
                max_acc = -10
                max_lr = -10
                for lr in lr_list:
                    logger.info("lr = %f", lr)
                    maml = torch.load(args.model, map_location='cpu')

                    if args.scratch:
                        config = mf.ModelFactory.get_model("na", args.dataset)
                        maml = learner.Learner(config)
                        # maml = MetaLearingClassification(args, config).to(device).net

                    maml = maml.to(device)

                    for name, param in maml.named_parameters():
                        param.learn = True

                    for name, param in maml.named_parameters():
                        if name in frozen_layers:
                            param.learn = False
                        else:
                            if args.reset:
                                w = nn.Parameter(torch.ones_like(param))
                                if len(w.shape) > 1:
                                    torch.nn.init.kaiming_normal_(w)
                                else:
                                    w = nn.Parameter(torch.zeros_like(param))
                                param.data = w
                                param.learn = True

                    frozen_layers = []
                    for temp in range(args.rln * 2):
                        frozen_layers.append("vars." + str(temp))

                    # Learner.parameters() returns the indexable `vars`
                    # ParameterList, so the output layer can be re-initialized
                    # by position
                    torch.nn.init.kaiming_normal_(maml.parameters()[-2])
                    w = nn.Parameter(torch.zeros_like(maml.parameters()[-1]))
                    maml.parameters()[-1].data = w

                    for n, a in maml.named_parameters():
                        n = n.replace(".", "_")
                        if n == "vars_14":
                            w = nn.Parameter(torch.ones_like(a))
                            torch.nn.init.kaiming_normal_(w)
                            a.data = w
                        if n == "vars_15":
                            w = nn.Parameter(torch.zeros_like(a))
                            a.data = w

                    correct = 0
                    for img, target in iterator:
                        with torch.no_grad():
                            img = img.to(device)
                            target = target.to(device)
                            logits_q = maml(img, vars=None, bn_training=False, feature=False)
                            pred_q = logits_q.argmax(dim=1)
                            correct += torch.eq(pred_q, target).sum().item() / len(img)

                    logger.info("Pre-epoch accuracy %s", str(correct / len(iterator)))

                    filter_list = ["vars.0", "vars.1", "vars.2", "vars.3", "vars.4", "vars.5"]
                    logger.info("Filter list = %s", ",".join(filter_list))

                    list_of_names = list(
                        map(lambda x: x[1],
                            list(filter(lambda x: x[0] not in filter_list, maml.named_parameters()))))
                    list_of_params = list(filter(lambda x: x.learn, maml.parameters()))
                    list_of_names = list(filter(lambda x: x[1].learn, maml.named_parameters()))

                    if args.scratch or args.no_freeze:
                        logger.info("Empty filter list")
                        list_of_params = maml.parameters()

                    # for x in list_of_names:
                    #     logger.info("Unfrozen layer = %s", str(x[0]))

                    opt = torch.optim.Adam(list_of_params, lr=lr)

                    res_sampler = rep.ReservoirSampler(mem_size)
                    for _ in range(0, args.epoch):
                        for img, y in iterator_sorted:
                            if mem_size > 0:
                                # Mix the incoming example with a replayed batch
                                res_sampler.update_buffer(zip(img, y))
                                res_sampler.update_observations(len(img))
                                img = img.to(device)
                                y = y.to(device)
                                img2, y2 = res_sampler.sample_buffer(8)
                                img2 = img2.to(device)
                                y2 = y2.to(device)
                                img = torch.cat([img, img2], dim=0)
                                y = torch.cat([y, y2], dim=0)
                            else:
                                img = img.to(device)
                                y = y.to(device)

                            pred = maml(img)
                            opt.zero_grad()
                            loss = F.cross_entropy(pred, y)
                            loss.backward()
                            opt.step()

                    logger.info("Result after %d epoch(s) for LR = %f", args.epoch, lr)
                    correct = 0
                    for img, target in iterator:
                        img = img.to(device)
                        target = target.to(device)
                        logits_q = maml(img, vars=None, bn_training=False, feature=False)
                        pred_q = logits_q.argmax(dim=1)
                        correct += torch.eq(pred_q, target).sum().item() / len(img)

                    logger.info(str(correct / len(iterator)))
                    if correct / len(iterator) > max_acc:
                        max_acc = correct / len(iterator)
                        max_lr = lr

                # Re-run remaining settings with only the best LR found so far
                lr_list = [max_lr]
                results_mem_size[mem_size] = (max_acc, max_lr)
                logger.info("Final Max Result = %s", str(max_acc))
                writer.add_scalar('/finetune/best_' + str(aoo), max_acc, tot_class)

            final_results_all.append((tot_class, results_mem_size))
            print("A= ", results_mem_size)
            logger.info("Final results = %s", str(results_mem_size))

    my_experiment.results["Final Results"] = final_results_all
    my_experiment.store_json()
    print("FINAL RESULTS = ", final_results_all)
    writer.close()
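# The replay buffer above is consumed through three calls: update_buffer(pairs),
# update_observations(n), and sample_buffer(k). Below is a minimal
# reservoir-sampling sketch with that interface, for reference only; the repo's
# rep.ReservoirSampler is the authoritative implementation and may differ.
class ReservoirSamplerSketch:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.observed = 0
        self._seen = 0

    def update_observations(self, n):
        # Kept for interface parity with rep.ReservoirSampler
        self.observed += n

    def update_buffer(self, pairs):
        for img, y in pairs:
            self._seen += 1
            if len(self.buffer) < self.capacity:
                self.buffer.append((img, y))
            else:
                # Algorithm R: each item survives with probability capacity/_seen
                j = np.random.randint(0, self._seen)
                if j < self.capacity:
                    self.buffer[j] = (img, y)

    def sample_buffer(self, k):
        idx = np.random.choice(len(self.buffer), min(k, len(self.buffer)), replace=False)
        imgs = torch.stack([self.buffer[i][0] for i in idx])
        ys = torch.stack([self.buffer[i][1] for i in idx])
        return imgs, ys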
def main(args):
    utils.set_seed(args.seed)

    my_experiment = experiment(args.name, args, "../results/", commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    logger = logging.getLogger('experiment')

    # Meta-train on the 64 classes of the Mini-ImageNet training split
    args.classes = list(range(64))

    dataset = imgnet.MiniImagenet(args.dataset_path, mode='train')
    dataset_test = imgnet.MiniImagenet(args.dataset_path, mode='test')

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test, batch_size=5, shuffle=True, num_workers=1)
    iterator_train = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=True, num_workers=1)

    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes, dataset)

    config = mf.ModelFactory.get_model("na", "imagenet")

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    maml = MetaLearingClassification(args, config).to(device)
    utils.freeze_layers(args.rln, maml)

    for step in range(args.steps):
        t1 = np.random.choice(args.classes, args.tasks, replace=False)

        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))

        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators, d_rand_iterator, steps=args.update_step, reset=not args.no_reset)
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(), x_qry.cuda(), y_qry.cuda()

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        # Evaluation during training for sanity checks
        if step % 40 == 39:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
        if step % 300 == 299:
            utils.log_accuracy(maml, my_experiment, iterator_test, device, writer, step)
            utils.log_accuracy(maml, my_experiment, iterator_train, device, writer, step)
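# `utils.log_accuracy` is called above but defined elsewhere in the repo. A
# plausible minimal version, consistent with how the other entry points in this
# file compute accuracy, is sketched here; the real helper and its tensorboard
# tag are assumptions, not the repo's actual implementation.
def log_accuracy_sketch(maml, my_experiment, iterator, device, writer, step):
    correct = 0
    for img, target in iterator:
        with torch.no_grad():
            img, target = img.to(device), target.to(device)
            logits = maml.net(img, vars=None, bn_training=False, feature=False)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, target).sum().item() / len(img)
    accuracy = correct / len(iterator)
    writer.add_scalar('/metatrain/eval/classifier/accuracy', accuracy, step)
    logging.getLogger('experiment').info("Eval accuracy = %s", str(accuracy))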
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    my_experiment = experiment(args.name, args, "./evals/", args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    # Resolve the most recent versioned run directory (assumes at least one
    # "<modelX>_<n>" directory exists)
    ver = 0
    while os.path.exists(args.modelX + "_" + str(ver)):
        ver += 1
    args.modelX = args.modelX + "_" + str(ver - 1) + "/learner.model"

    logger = logging.getLogger('experiment')
    logger.setLevel(logging.INFO)

    # Two vars (weight + bias) per layer: 6 conv + 2 dense + any extra dense layers
    total_ff_vars = 2 * (6 + 2 + args.num_extra_dense_layers)

    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("vars." + str(temp))
    for temp in range(args.rln_end * 2):
        frozen_layers.append("net.vars." + str(total_ff_vars - 1 - temp))

    final_results_all = []

    total_classes = [5]
    if args.twentyclass:
        total_classes = [20]
    if args.twotask:
        total_classes = [2, 10]
    if args.fiftyclass:
        total_classes = [50]
    if args.tenclass:
        total_classes = [10]
    if args.fiveclass:
        total_classes = [5]
    logger.info("Class counts to evaluate: %s", str(total_classes))

    for tot_class in total_classes:
        avg_perf = 0.0
        # Plasticity updates replace SGD here, so the LR sweep collapses to a
        # single dummy value
        lr_list = [0]  # [0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001]
        for aoo in range(0, args.runs):
            keep = np.random.choice(list(range(650)), tot_class, replace=False)

            if args.dataset == "imagenet":
                keep = np.random.choice(list(range(20)), tot_class, replace=False)
                dataset = imgnet.MiniImagenet(args.imagenet_path, mode='test', elem_per_class=30,
                                              classes=keep, seed=aoo)
                dataset_test = imgnet.MiniImagenet(args.imagenet_path, mode='test', elem_per_class=30,
                                                   classes=keep, test=args.test, seed=aoo)
                iterator = torch.utils.data.DataLoader(dataset_test, batch_size=128, shuffle=True, num_workers=1)
                iterator_sorted = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

            if args.dataset == "omniglot":
                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot", train=True, background=False), keep)
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter_omni(dataset, False, classes=total_classes),
                    batch_size=1, shuffle=False, num_workers=2)
                dataset = utils.remove_classes_omni(
                    df.DatasetFactory.get_dataset("omniglot", train=not args.test, background=False), keep)
                iterator = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False, num_workers=1)
            elif args.dataset == "CIFAR100":
                keep = np.random.choice(list(range(50, 100)), tot_class)
                dataset = utils.remove_classes(df.DatasetFactory.get_dataset(args.dataset, train=True), keep)
                iterator_sorted = torch.utils.data.DataLoader(
                    utils.iterator_sorter(dataset, False, classes=tot_class),
                    batch_size=16, shuffle=False, num_workers=2)
                dataset = utils.remove_classes(df.DatasetFactory.get_dataset(args.dataset, train=False), keep)
                iterator = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=False, num_workers=1)

            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')

            results_mem_size = {}

            for mem_size in [args.memory]:
                max_acc = -10
                max_lr = -10
                for lr in lr_list:
                    maml = torch.load(args.modelX, map_location='cpu')

                    if args.scratch:
                        config = mf.ModelFactory.get_model(args.model_type, args.dataset)
                        maml = learner.Learner(config, lr)

                    maml = maml.to(device)

                    for name, param in maml.named_parameters():
                        param.learn = True

                    for name, param in maml.named_parameters():
                        if name in frozen_layers:
                            param.learn = False
                        else:
                            if args.reset:
                                w = nn.Parameter(torch.ones_like(param))
                                if len(w.shape) > 1:
                                    torch.nn.init.kaiming_normal_(w)
                                else:
                                    w = nn.Parameter(torch.zeros_like(param))
                                param.data = w
                                param.learn = True

                    frozen_layers = []
                    for temp in range(args.rln * 2):
                        frozen_layers.append("vars." + str(temp))

                    # Re-initialization of the output layer (vars_14/15, offset by
                    # any extra dense layers) and of the feedback vars is disabled
                    for n, a in maml.named_parameters():
                        n = n.replace(".", "_")
                        if n == "vars_" + str(14 + 2 * args.num_extra_dense_layers):
                            pass
                        if n == "vars_" + str(15 + 2 * args.num_extra_dense_layers):
                            pass

                    correct = 0
                    for img, target in iterator:
                        # Map raw class ids onto [0, tot_class)
                        target = torch.tensor(np.array(
                            [list(keep).index(int(target.cpu().numpy()[i])) for i in range(target.size()[0])]))
                        with torch.no_grad():
                            img = img.to(device)
                            target = target.to(device)
                            logits_q = maml(img, vars=None, bn_training=False, feature=False)
                            pred_q = logits_q.argmax(dim=1)
                            correct += torch.eq(pred_q, target).sum().item() / len(img)
                    # logger.info("Pre-epoch accuracy %s", str(correct / len(iterator)))

                    filter_list = ["vars.0", "vars.1", "vars.2", "vars.3", "vars.4", "vars.5"]
                    list_of_names = list(
                        map(lambda x: x[1],
                            list(filter(lambda x: x[0] not in filter_list, maml.named_parameters()))))
                    list_of_params = list(filter(lambda x: x.learn, maml.parameters()))
                    list_of_names = list(filter(lambda x: x[1].learn, maml.named_parameters()))

                    if args.scratch or args.no_freeze:
                        logger.info("Empty filter list")
                        list_of_params = maml.parameters()

                    # Unused: the updates below come from the plasticity rule, not SGD
                    opt = torch.optim.Adam(list_of_params, lr=lr)

                    fast_weights = None
                    if args.randomize_plastic_weights:
                        maml.randomize_plastic_weights()
                    if args.zero_plastic_weights:
                        maml.zero_plastic_weights()

                    res_sampler = rep.ReservoirSampler(mem_size)

                    # Keep only args.shots examples per class (15 per class in the
                    # sorted stream), remapping labels onto [0, tot_class)
                    iterator_sorted_new = []
                    iter_count = 0
                    for img, y in iterator_sorted:
                        y = torch.tensor(np.array(
                            [list(keep).index(int(y.cpu().numpy()[i])) for i in range(y.size()[0])]))
                        if iter_count % 15 >= args.shots:
                            iter_count += 1
                            continue
                        iterator_sorted_new.append((img, y))
                        iter_count += 1

                    iterator_sorted = []
                    perm = np.random.permutation(len(iterator_sorted_new))
                    for i in range(len(iterator_sorted_new)):
                        if args.iid:
                            iterator_sorted.append(iterator_sorted_new[perm[i]])
                        else:
                            iterator_sorted.append(iterator_sorted_new[i])

                    for ep in range(0, args.epoch):
                        imgs = []
                        ys = []
                        for img, y in iterator_sorted:
                            if args.memory == 0:
                                img = img.to(device)
                                y = y.to(device)
                            else:
                                # Mix the incoming example with a replayed batch
                                res_sampler.update_buffer(zip(img, y))
                                res_sampler.update_observations(len(img))
                                img = img.to(device)
                                y = y.to(device)
                                img2, y2 = res_sampler.sample_buffer(8)
                                img2 = img2.to(device)
                                y2 = y2.to(device)
                                img = torch.cat([img, img2], dim=0)
                                y = torch.cat([y, y2], dim=0)

                            imgs.append(img)
                            ys.append(y)

                            if not args.batch_learning:
                                logits = maml(img, vars=fast_weights)
                                fast_weights = maml.getOjaUpdate(y, logits, fast_weights, hebbian=args.hebb)

                        if args.batch_learning:
                            y = torch.cat(ys, 0)
                            img = torch.cat(imgs, 0)
                            logits = maml(img, vars=fast_weights)
                            fast_weights = maml.getOjaUpdate(y, logits, fast_weights, hebbian=args.hebb)

                    correct = 0
                    for img, target in iterator:
                        target = torch.tensor(np.array(
                            [list(keep).index(int(target.cpu().numpy()[i])) for i in range(target.size()[0])]))
                        img = img.to(device)
                        target = target.to(device)
                        logits_q = maml(img, vars=fast_weights, bn_training=False, feature=False)
                        pred_q = logits_q.argmax(dim=1)
                        correct += torch.eq(pred_q, target).sum().item() / len(img)

                    if correct / len(iterator) > max_acc:
                        max_acc = correct / len(iterator)
                        max_lr = lr

                    del maml

                lr_list = [max_lr]
                results_mem_size[mem_size] = (max_acc, max_lr)
                writer.add_scalar('/finetune/best_' + str(aoo), max_acc, tot_class)
                avg_perf += max_acc / args.runs
                # TODO: change this if/when memory is used -- the best memory size
                # can't be chosen differently for each run!
                logger.info("Running average best accuracy = %f", avg_perf * args.runs / (1 + aoo))

            final_results_all.append((tot_class, results_mem_size))

    my_experiment.results["Final Results"] = final_results_all
    my_experiment.store_json()
    np.save('evals/final_results_' + args.orig_name + '.npy', final_results_all)
    writer.close()
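# The final results above are saved with np.save as an object array (a list of
# (tot_class, results_mem_size) tuples), so reloading them requires
# allow_pickle. A small loading sketch, assuming the same file-name scheme:
#
#     results = np.load('evals/final_results_<orig_name>.npy', allow_pickle=True)
#     for tot_class, per_mem_results in results:
#         print(tot_class, per_mem_results)  # {mem_size: (max_acc, max_lr)}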