def main():
    """Meta-train an OML-style learner on Omniglot (first 963 classes).

    Fix: the original called ``p.parse_known_args()`` four separate times to
    read the same namespace; the CLI is now parsed once and reused.  The
    unused ``total_seeds`` local was removed.
    """
    p = class_parser.Parser()
    # Parse the command line once; the original re-parsed it on every access.
    parsed = p.parse_known_args()[0]
    rank = parsed.rank
    all_args = vars(parsed)
    print("All args = ", all_args)

    # Select this rank's hyper-parameter combination from the sweep.
    args = utils.get_run(all_args, rank)
    utils.set_seed(args['seed'])

    # NOTE(review): the experiment record is created with a fixed seed=1 even
    # though args['seed'] seeds the run itself -- preserved as-is.
    my_experiment = experiment(args['name'], args, "../results/",
                               commit_changes=False, rank=0, seed=1)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    logger = logging.getLogger('experiment')

    # Using first 963 classes of the omniglot as the meta-training set;
    # trajectories are drawn only from the second half of those classes.
    args['classes'] = list(range(963))
    args['traj_classes'] = list(range(int(963 / 2), 963))

    dataset = df.DatasetFactory.get_dataset(args['dataset'], background=True,
                                            train=True, path=args["path"],
                                            all=True)
    dataset_test = df.DatasetFactory.get_dataset(args['dataset'],
                                                 background=True, train=False,
                                                 path=args["path"], all=True)

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(dataset_test, batch_size=5,
                                                shuffle=True, num_workers=1)
    iterator_train = torch.utils.data.DataLoader(dataset, batch_size=5,
                                                 shuffle=True, num_workers=1)

    sampler = ts.SamplerFactory.get_sampler(args['dataset'], args['classes'],
                                            dataset, dataset_test)

    config = mf.ModelFactory.get_model("na", args['dataset'],
                                       output_dimension=1000)

    # Spread parallel runs across the available GPUs by rank.
    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(gpu_to_use))
        logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
    else:
        device = torch.device('cpu')

    maml = MetaLearingClassification(args, config).to(device)

    for step in range(args['steps']):
        # Sample `tasks` trajectory classes for this meta-step.
        t1 = np.random.choice(args['traj_classes'], args['tasks'],
                              replace=False)

        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))
        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators, d_rand_iterator, steps=args['update_step'],
            reset=not args['no_reset'])
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = (x_spt.to(device), y_spt.to(device),
                                          x_qry.to(device), y_qry.to(device))

        # Meta-update (inner-loop adaptation on support, outer loss on query).
        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)
def main(args):
    """Meta-train an OML-style learner; log accuracy to TensorBoard.

    Same behaviour as before: seed, experiment bookkeeping, dataset/sampler
    construction, then the meta-training loop with periodic evaluation.
    """
    utils.set_seed(args.seed)

    my_experiment = experiment(args.name, args, "results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    logger = logging.getLogger('experiment')

    # Using first 963 classes of the omniglot as the meta-training set;
    # trajectory classes come from the second half of that range.
    args.classes = list(range(963))
    args.traj_classes = list(range(int(963 / 2), 963))

    train_set = df.DatasetFactory.get_dataset(args.dataset, background=True,
                                              train=True, all=True)
    test_set = df.DatasetFactory.get_dataset(args.dataset, background=True,
                                             train=False, all=True)

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(test_set, batch_size=5,
                                                shuffle=True, num_workers=1)
    iterator_train = torch.utils.data.DataLoader(train_set, batch_size=5,
                                                 shuffle=True, num_workers=1)

    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            train_set, test_set)
    config = mf.ModelFactory.get_model("na", args.dataset)

    device = torch.device('cuda') if torch.cuda.is_available() \
        else torch.device('cpu')

    maml = MetaLearingClassification(args, config).to(device)
    utils.freeze_layers(args.rln, maml)

    for step in range(args.steps):
        chosen = np.random.choice(args.traj_classes, args.tasks,
                                  replace=False)

        # One trajectory iterator per sampled class, plus an iid iterator
        # over the full dataset for the random stream.
        traj_iterators = [sampler.sample_task([c]) for c in chosen]
        rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            traj_iterators, rand_iterator,
            steps=args.update_step, reset=not args.no_reset)
        if torch.cuda.is_available():
            x_spt = x_spt.cuda()
            y_spt = y_spt.cuda()
            x_qry = x_qry.cuda()
            y_qry = y_qry.cuda()

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        # Evaluation during training for sanity checks
        if step % 40 == 39:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
        if step % 300 == 299:
            utils.log_accuracy(maml, my_experiment, iterator_test, device,
                               writer, step)
            utils.log_accuracy(maml, my_experiment, iterator_train, device,
                               writer, step)
def main():
    """Meta-train using the paper protocol with separate support/query
    samplers (per-class augmentation can differ between the two streams).

    Fixes: (1) the meta-update ``accs, loss = maml(...)`` had been commented
    out even though the logging block below still reads ``accs``/``loss``,
    which raised a NameError the first time ``step % 40 == 5`` -- the call is
    restored; (2) the CLI was parsed four separate times -- it is now parsed
    once and reused.
    """
    p = class_parser.Parser()
    # Parse the command line once and reuse the namespace.
    parsed = p.parse_known_args()[0]
    rank = parsed.rank
    all_args = vars(parsed)
    print("All args = ", all_args)

    args = utils.get_run(all_args, rank)
    utils.set_seed(args["seed"])

    # Optional sub-directory under ./results for this run's artifacts.
    if args["log_root"]:
        log_root = osp.join("./results", args["log_root"]) + "/"
    else:
        log_root = osp.join("./results/")
    my_experiment = experiment(
        args["name"],
        args,
        log_root,
        commit_changes=False,
        rank=0,
        seed=args["seed"],
    )
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    logger = logging.getLogger("experiment")

    # Meta-training classes are configurable (originally the first 963
    # Omniglot classes).
    args["classes"] = list(range(args["num_classes"]))
    print("Using classes:", args["num_classes"])

    if torch.cuda.is_available():
        device = torch.device("cuda")
        use_cuda = True
    else:
        device = torch.device("cpu")
        use_cuda = False

    # Separate support/query datasets so each stream can carry its own
    # augmentation pipeline.
    dataset_spt = df.DatasetFactory.get_dataset(
        args["dataset"],
        background=True,
        train=True,
        path=args["path"],
        all=args["all"],
        prefetch_gpu=args["prefetch_gpu"],
        device=device,
        resize=args["resize"],
        augment=args["augment_spt"],
    )
    dataset_qry = df.DatasetFactory.get_dataset(
        args["dataset"],
        background=True,
        train=True,
        path=args["path"],
        all=args["all"],
        prefetch_gpu=args["prefetch_gpu"],
        device=device,
        resize=args["resize"],
        augment=args["augment_qry"],
    )
    dataset_test = df.DatasetFactory.get_dataset(
        args["dataset"],
        background=True,
        train=False,
        path=args["path"],
        all=args["all"],
        resize=args["resize"],
    )
    logger.info(
        f"Support size: {len(dataset_spt)}, Query size: {len(dataset_qry)}, test size: {len(dataset_test)}"
    )

    # NOTE(review): pin_memory/num_workers are computed here but the loaders
    # and samplers below hard-code num_workers=0 (worker processes are
    # incompatible with GPU-prefetched tensors) -- preserved as-is.
    pin_memory = use_cuda
    if args["prefetch_gpu"]:
        num_workers = 0
        pin_memory = False
    else:
        num_workers = args["num_workers"]

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=5,
        shuffle=True,
        num_workers=0,
    )
    iterator_train = torch.utils.data.DataLoader(
        dataset_spt,
        batch_size=5,
        shuffle=True,
        num_workers=0,
    )

    logger.info("Support sampler:")
    sampler_spt = ts.SamplerFactory.get_sampler(
        args["dataset"],
        args["classes"],
        dataset_spt,
        dataset_test,
        prefetch_gpu=args["prefetch_gpu"],
        use_cuda=use_cuda,
        num_workers=0,
    )
    logger.info("Query sampler:")
    sampler_qry = ts.SamplerFactory.get_sampler(
        args["dataset"],
        args["classes"],
        dataset_qry,
        dataset_test,
        prefetch_gpu=args["prefetch_gpu"],
        use_cuda=use_cuda,
        num_workers=0,
    )

    config = mf.ModelFactory.get_model(
        "na",
        args["dataset"],
        output_dimension=1000,
        resize=args["resize"],
    )

    # Spread parallel runs across the available GPUs by rank.
    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device("cuda:" + str(gpu_to_use))
        logger.info("Using gpu : %s", "cuda:" + str(gpu_to_use))
    else:
        device = torch.device("cpu")

    maml = MetaLearingClassification(args, config).to(device)

    for step in range(args["steps"]):
        # Sample `tasks` classes; build one trajectory iterator per class for
        # each of the support and query streams.
        t1 = np.random.choice(args["classes"], args["tasks"], replace=False)
        d_traj_iterators_spt = []
        d_traj_iterators_qry = []
        for t in t1:
            d_traj_iterators_spt.append(sampler_spt.sample_task([t]))
            d_traj_iterators_qry.append(sampler_qry.sample_task([t]))
        d_rand_iterator = sampler_spt.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data_paper(
            d_traj_iterators_spt,
            d_traj_iterators_qry,
            d_rand_iterator,
            steps=args["update_step"],
            reset=not args["no_reset"],
        )
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = (
                x_spt.to(device),
                y_spt.to(device),
                x_qry.to(device),
                y_qry.to(device),
            )

        # Restored: this meta-update produces the accs/loss consumed by the
        # logging below; it had been commented out, causing a NameError.
        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        # Evaluation during training for sanity checks
        if step % 40 == 5:
            writer.add_scalar("/metatrain/train/accuracy", accs[-1], step)
            writer.add_scalar("/metatrain/train/loss", loss[-1], step)
            writer.add_scalar("/metatrain/train/accuracy0", accs[0], step)
            writer.add_scalar("/metatrain/train/loss0", loss[0], step)
            logger.info("step: %d \t training acc %s", step, str(accs))
            logger.info("step: %d \t training loss %s", step, str(loss))

        # Periodic checkpoint, plus a guaranteed save on the final step.
        if (step % 300 == 3) or ((step + 1) == args["steps"]):
            torch.save(maml.net, my_experiment.path + "learner.model")
def main(args):
    """Meta-train a classifier with optional Oja/Hebb plasticity and learned
    feedback connections.

    Seeds all RNGs, builds the dataset + sampler (Omniglot or mini-ImageNet),
    constructs the meta-learner, optionally restores a saved model and
    rebuilds its feedback/plasticity parameter lists, then runs the
    meta-training loop with periodic train/test accuracy evaluation.

    NOTE(review): reconstructed from a whitespace-mangled source; the nesting
    of some `if` blocks (flagged below) should be confirmed against the
    original repository.
    """
    # Seed every RNG in use so runs are reproducible.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    my_experiment = experiment(args.name, args,
                               "/data5/jlindsey/continual/results",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    logger = logging.getLogger('experiment')

    # Omniglot: first 963 classes; mini-ImageNet overrides this to 64 below.
    args.classes = list(range(963))

    print('dataset', args.dataset, args.dataset == "imagenet")
    if args.dataset != "imagenet":
        dataset = df.DatasetFactory.get_dataset(args.dataset, background=True,
                                                train=True, all=True)
        dataset_test = df.DatasetFactory.get_dataset(args.dataset,
                                                     background=True,
                                                     train=False, all=True)
    else:
        args.classes = list(range(64))
        dataset = imgnet.MiniImagenet(args.imagenet_path, mode='train')
        dataset_test = imgnet.MiniImagenet(args.imagenet_path, mode='test')

    # Plain loaders used only by the periodic evaluation further down.
    iterator_test = torch.utils.data.DataLoader(dataset_test, batch_size=5,
                                                shuffle=True, num_workers=1)
    iterator_train = torch.utils.data.DataLoader(dataset, batch_size=5,
                                                 shuffle=True, num_workers=1)
    logger.info("Train set length = %d", len(iterator_train) * 5)
    logger.info("Test set length = %d", len(iterator_test) * 5)

    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset_test)

    config = mf.ModelFactory.get_model(
        args.model_type,
        args.dataset,
        width=args.width,
        num_extra_dense_layers=args.num_extra_dense_layers)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Oja/Hebb runs use the plasticity-aware meta-learner variant.
    if args.oja or args.hebb:
        maml = OjaMetaLearingClassification(args, config).to(device)
    else:
        print('starting up')
        maml = MetaLearingClassification(args, config).to(device)

    import sys

    if args.from_saved:
        # Restore a previously meta-trained network in place of the fresh one.
        maml.net = torch.load(args.model)

        if args.use_derivative:
            maml.net.use_derivative = True

        maml.net.optimize_out = args.optimize_out
        if maml.net.optimize_out:
            # Extra feedback-strength parameter for the output layer.
            maml.net.feedback_strength_vars.append(
                torch.nn.Parameter(maml.net.init_feedback_strength *
                                   torch.ones(1).cuda()))

        if args.reset_feedback_strength:
            # Overwrite every feedback-strength value with the CLI constant.
            for fv in maml.net.feedback_strength_vars:
                w = nn.Parameter(torch.ones_like(fv) * args.feedback_strength)
                fv.data = w

        if args.reset_feedback_vars:
            # Rebuild the feedback weights and all plasticity parameter lists
            # from scratch, walking the network config layer by layer.
            print('howdy', maml.net.num_feedback_layers)
            maml.net.feedback_vars = nn.ParameterList()
            maml.net.feedback_vars_bundled = []
            maml.net.vars_plasticity = nn.ParameterList()
            maml.net.plasticity = nn.ParameterList()
            maml.net.neuron_plasticity = nn.ParameterList()
            maml.net.layer_plasticity = nn.ParameterList()

            # Spatial width of the input image; updated per conv layer.
            starting_width = 84
            cur_width = starting_width
            # Output dimension of the network's final (classifier) layer.
            num_outputs = maml.net.config[-1][1][0]

            for i, (name, param) in enumerate(maml.net.config):
                print('yo', i, name, param)
                if name == 'conv2d':
                    print('in conv2d')
                    stride = param[4]
                    padding = param[5]
                    #print('cur_width', cur_width, param[3])
                    # Standard conv output-size formula:
                    # floor((W + 2P - K) / S) + 1, written as one floordiv.
                    cur_width = (cur_width + 2 * padding - param[3] +
                                 stride) // stride
                    maml.net.vars_plasticity.append(
                        nn.Parameter(torch.ones(*param[:4]).cuda()))
                    maml.net.vars_plasticity.append(
                        nn.Parameter(torch.ones(param[0]).cuda()))
                    #self.activations_list.append([])
                    maml.net.plasticity.append(
                        nn.Parameter(
                            maml.net.init_plasticity *
                            torch.ones(param[0],
                                       param[1] * param[2] * param[3]).cuda()))  #not implemented
                    maml.net.neuron_plasticity.append(
                        nn.Parameter(torch.zeros(1).cuda()))  #not implemented
                    maml.net.layer_plasticity.append(
                        nn.Parameter(maml.net.init_plasticity *
                                     torch.ones(1).cuda()))  #not implemented

                    # Build the per-layer feedback MLP: hidden layers of size
                    # maml.net.width; first layer maps from the network
                    # output, last layer maps to this conv layer's flattened
                    # activation size.
                    feedback_var = []
                    for fl in range(maml.net.num_feedback_layers):
                        print('doing fl')
                        in_dim = maml.net.width
                        out_dim = maml.net.width
                        if fl == maml.net.num_feedback_layers - 1:
                            out_dim = param[0] * cur_width * cur_width
                        if fl == 0:
                            in_dim = num_outputs
                        feedback_w_shape = [out_dim, in_dim]
                        feedback_w = nn.Parameter(
                            torch.ones(feedback_w_shape).cuda())
                        feedback_b = nn.Parameter(torch.zeros(out_dim).cuda())
                        torch.nn.init.kaiming_normal_(feedback_w)
                        feedback_var.append((feedback_w, feedback_b))
                        print('adding')
                        maml.net.feedback_vars.append(feedback_w)
                        maml.net.feedback_vars.append(feedback_b)
                    #maml.net.feedback_vars_bundled.append(feedback_var)
                    #maml.net.feedback_vars_bundled.append(None)#bias feedback -- not implemented
                    #'''
                    maml.net.feedback_vars_bundled.append(
                        nn.Parameter(torch.zeros(
                            1)))  #weight feedback -- not implemented
                    maml.net.feedback_vars_bundled.append(
                        nn.Parameter(
                            torch.zeros(1)))  #bias feedback -- not implemented
                elif name == 'linear':
                    maml.net.vars_plasticity.append(
                        nn.Parameter(torch.ones(*param).cuda()))
                    maml.net.vars_plasticity.append(
                        nn.Parameter(torch.ones(param[0]).cuda()))
                    #self.activations_list.append([])
                    maml.net.plasticity.append(
                        nn.Parameter(maml.net.init_plasticity *
                                     torch.ones(*param).cuda()))
                    maml.net.neuron_plasticity.append(
                        nn.Parameter(maml.net.init_plasticity *
                                     torch.ones(param[0]).cuda()))
                    maml.net.layer_plasticity.append(
                        nn.Parameter(maml.net.init_plasticity *
                                     torch.ones(1).cuda()))

                    # Same feedback-MLP construction as the conv branch, but
                    # the final layer maps to this linear layer's output dim.
                    feedback_var = []
                    for fl in range(maml.net.num_feedback_layers):
                        in_dim = maml.net.width
                        out_dim = maml.net.width
                        if fl == maml.net.num_feedback_layers - 1:
                            out_dim = param[0]
                        if fl == 0:
                            in_dim = num_outputs
                        feedback_w_shape = [out_dim, in_dim]
                        feedback_w = nn.Parameter(
                            torch.ones(feedback_w_shape).cuda())
                        feedback_b = nn.Parameter(torch.zeros(out_dim).cuda())
                        torch.nn.init.kaiming_normal_(feedback_w)
                        feedback_var.append((feedback_w, feedback_b))
                        maml.net.feedback_vars.append(feedback_w)
                        maml.net.feedback_vars.append(feedback_b)
                    maml.net.feedback_vars_bundled.append(feedback_var)
                    maml.net.feedback_vars_bundled.append(
                        None)  #bias feedback -- not implemented

        # Re-initialize learner bookkeeping for the restored network.
        maml.init_stuff(args)

        maml.net.optimize_out = args.optimize_out
        if maml.net.optimize_out:
            maml.net.feedback_strength_vars.append(
                torch.nn.Parameter(maml.net.init_feedback_strength *
                                   torch.ones(1).cuda()))

    #I recently un-indented this until the maml.init_opt() line.
    #If stuff stops working, try re-indenting this block
    if args.zero_non_output_plasticity:
        # Zero the per-weight plasticity of every layer except the output.
        for index in range(len(maml.net.vars_plasticity) - 2):
            maml.net.vars_plasticity[index] = torch.nn.Parameter(
                maml.net.vars_plasticity[index] * 0)
        if args.oja or args.hebb:
            # Oja/Hebb variants also carry separate plasticity lists.
            for index in range(len(maml.net.plasticity) - 1):
                if args.plasticity_rank1:
                    maml.net.plasticity[index] = torch.nn.Parameter(
                        torch.zeros(1).cuda())
                else:
                    maml.net.plasticity[index] = torch.nn.Parameter(
                        maml.net.plasticity[index] * 0)
                maml.net.layer_plasticity[index] = torch.nn.Parameter(
                    maml.net.layer_plasticity[index] * 0)
                maml.net.neuron_plasticity[index] = torch.nn.Parameter(
                    maml.net.neuron_plasticity[index] * 0)
        if args.oja or args.hebb:
            # NOTE(review): appears to repeat the vars_plasticity zeroing
            # above -- confirm against the original repository.
            for index in range(len(maml.net.vars_plasticity) - 2):
                maml.net.vars_plasticity[index] = torch.nn.Parameter(
                    maml.net.vars_plasticity[index] * 0)

    if args.zero_all_plasticity:
        # Zero every plasticity parameter, output layer included.
        print('zeroing plasticity')
        for index in range(len(maml.net.vars_plasticity)):
            maml.net.vars_plasticity[index] = torch.nn.Parameter(
                maml.net.vars_plasticity[index] * 0)
        for index in range(len(maml.net.plasticity)):
            if args.plasticity_rank1:
                maml.net.plasticity[index] = torch.nn.Parameter(
                    torch.zeros(1).cuda())
            else:
                maml.net.plasticity[index] = torch.nn.Parameter(
                    maml.net.plasticity[index] * 0)
            maml.net.layer_plasticity[index] = torch.nn.Parameter(
                maml.net.layer_plasticity[index] * 0)
            maml.net.neuron_plasticity[index] = torch.nn.Parameter(
                maml.net.neuron_plasticity[index] * 0)

    print('heyy', maml.net.feedback_vars)

    # (Re)build the meta-optimizer now that parameter lists are final.
    maml.init_opt()

    # Mark every parameter trainable, then freeze the RLN layers below.
    for name, param in maml.named_parameters():
        param.learn = True
    for name, param in maml.net.named_parameters():
        param.learn = True

    if args.freeze_out_plasticity:
        maml.net.plasticity[-1].requires_grad = False

    # Total number of feed-forward vars (weight+bias per layer): 6 conv
    # layers + 2 dense layers + any extra dense layers.
    total_ff_vars = 2 * (6 + 2 + args.num_extra_dense_layers)

    # Freeze `rln` layers from the front and `rln_end` layers from the back.
    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("net.vars." + str(temp))
    for temp in range(args.rln_end * 2):
        frozen_layers.append("net.vars." + str(total_ff_vars - 1 - temp))

    for name, param in maml.named_parameters():
        # logger.info(name)
        if name in frozen_layers:
            logger.info("RLN layer %s", str(name))
            param.learn = False

    # Update the classifier
    list_of_params = list(filter(lambda x: x.learn, maml.parameters()))
    list_of_names = list(filter(lambda x: x[1].learn,
                                maml.named_parameters()))
    for a in list_of_names:
        logger.info("TLN layer = %s", a[0])

    for step in range(args.steps):
        '''
        print('plasticity')
        for p in maml.net.plasticity:
            print(p.size(), torch.sum(p), p)
        '''
        # Sample `tasks` distinct classes for this meta-step.
        t1 = np.random.choice(
            args.classes, args.tasks, replace=False
        )  #np.random.randint(1, args.tasks + 1), replace=False)

        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))
        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators, d_rand_iterator, steps=args.update_step,
            iid=args.iid)

        # Randomly permute the label identities so the learner cannot
        # memorize a fixed class->label mapping; `old` records the order in
        # which original labels first appear in the support set.
        perm = np.random.permutation(args.tasks)
        old = []
        for i in range(y_spt.size()[0]):
            num = int(y_spt[i].cpu().numpy())
            if num not in old:
                old.append(num)
            y_spt[i] = torch.tensor(perm[old.index(num)])
        for i in range(y_qry.size()[1]):
            num = int(y_qry[0][i].cpu().numpy())
            y_qry[0][i] = torch.tensor(perm[old.index(num)])

        #print('hi', y_qry.size())
        #print('y_spt', y_spt)
        #print('y_qry', y_qry)

        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(
            ), x_qry.cuda(), y_qry.cuda()

        #print('heyyyy', x_spt.size(), y_spt.size(), x_qry.size(), y_qry.size())
        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        # Log training accuracy every step.
        if step % 1 == 0:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))

        # Periodic checkpoint plus full train/test classifier evaluation.
        if step % 300 == 0:
            correct = 0
            torch.save(maml.net, my_experiment.path + "learner.model")
            for img, target in iterator_test:
                with torch.no_grad():
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml.net(img, vars=None, bn_training=False,
                                        feature=False)
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    # Per-batch accuracy; averaged over batches below.
                    correct += torch.eq(pred_q,
                                        target).sum().item() / len(img)
            writer.add_scalar('/metatrain/test/classifier/accuracy',
                              correct / len(iterator_test), step)
            logger.info("Test Accuracy = %s",
                        str(correct / len(iterator_test)))

            correct = 0
            for img, target in iterator_train:
                with torch.no_grad():
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml.net(img, vars=None, bn_training=False,
                                        feature=False)
                    pred_q = (logits_q).argmax(dim=1)
                    correct += torch.eq(pred_q,
                                        target).sum().item() / len(img)
            logger.info("Train Accuracy = %s",
                        str(correct / len(iterator_train)))
            writer.add_scalar('/metatrain/train/classifier/accuracy',
                              correct / len(iterator_train), step)
def main(args):
    """Meta-train a fully-connected learner for few-shot Omniglot, with a
    periodic meta-test evaluation on held-out classes via `finetune`.
    """
    utils.set_seed(args.seed)

    my_experiment = experiment(args.name, args, "../results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    logger = logging.getLogger('experiment')

    # Using first 963 classes of the omniglot as the meta-training set;
    # here trajectory classes span the full range.
    args.classes = list(range(963))
    args.traj_classes = list(range(963))

    dataset = df.DatasetFactory.get_dataset(args.dataset, background=True,
                                            train=True, all=True)
    # Meta-test classes come from the evaluation alphabet (background=False).
    dataset_test = df.DatasetFactory.get_dataset(args.dataset,
                                                 background=False, train=True,
                                                 all=True)

    # NOTE(review): the train sampler receives the same dataset for both
    # slots (no separate test split) -- presumably intentional; confirm.
    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset)
    sampler_test = ts.SamplerFactory.get_sampler(args.dataset,
                                                 list(range(600)),
                                                 dataset_test, dataset_test)

    # Fully-connected Omniglot architecture.
    config = mf.ModelFactory.get_model("na", "omniglot-fc")

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    maml = MetaLearingClassification(args, config).to(device)
    utils.freeze_layers(args.rln, maml)

    for step in range(args.steps):
        # Sample `tasks` trajectory classes for this meta-step.
        t1 = np.random.choice(args.traj_classes, args.tasks, replace=False)

        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))
        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_few_shot_training_data(
            d_traj_iterators, d_rand_iterator, steps=args.update_step,
            reset=not args.no_reset)
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(
            ), x_qry.cuda(), y_qry.cuda()

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        # Evaluation during training for sanity checks
        if step % 20 == 0:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))
            logger.info("Loss = %s", str(loss[-1].item()))

        # Periodic meta-test: checkpoint, then average finetuning accuracy
        # over 40 freshly sampled few-shot tasks from the held-out classes.
        if step % 600 == 599:
            torch.save(maml.net, my_experiment.path + "learner.model")
            accs_avg = None
            for temp_temp in range(0, 40):
                t1_test = np.random.choice(list(range(600)), args.tasks,
                                           replace=False)
                d_traj_test_iterators = []
                for t in t1_test:
                    d_traj_test_iterators.append(
                        sampler_test.sample_task([t]))

                x_spt, y_spt, x_qry, y_qry = maml.sample_few_shot_training_data(
                    d_traj_test_iterators, None, steps=args.update_step,
                    reset=not args.no_reset)
                if torch.cuda.is_available():
                    x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(
                    ), x_qry.cuda(), y_qry.cuda()

                accs, loss = maml.finetune(x_spt, y_spt, x_qry, y_qry)
                # Accumulate elementwise; assumes `accs` is array-like
                # (supports `+=` and scalar division) -- TODO confirm.
                if accs_avg is None:
                    accs_avg = accs
                else:
                    accs_avg += accs
                logger.info("Loss = %s", str(loss[-1].item()))

            writer.add_scalar('/metatest/train/accuracy', accs_avg[-1] / 40,
                              step)
            logger.info('TEST: step: %d \t testing acc %s', step,
                        str(accs_avg / 40))
def main(args):
    """Meta-train a learner built for the requested treatment, optionally
    warm-started from a checkpoint; behaviour matches the original.
    """
    utils.set_seed(args.seed)

    my_experiment = experiment(args.name, args, "../results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    logger = logging.getLogger("experiment")

    # Using first 963 classes of the omniglot as the meta-training set
    args.classes = list(range(963))

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda") if use_cuda else torch.device("cpu")

    train_set = df.DatasetFactory.get_dataset(
        args.dataset,
        background=True,
        train=True,
        all=True,
        prefetch_gpu=args.prefetch_gpu,
        device=device,
    )
    # The training dataset doubles as the evaluation set here.
    eval_set = train_set

    # Iterators used for evaluation
    iterator_test = torch.utils.data.DataLoader(eval_set, batch_size=5,
                                                shuffle=True, num_workers=1)
    iterator_train = torch.utils.data.DataLoader(train_set, batch_size=5,
                                                 shuffle=True, num_workers=1)

    sampler = ts.SamplerFactory.get_sampler(
        args.dataset, args.classes, train_set, eval_set,
        prefetch_gpu=args.prefetch_gpu, use_cuda=use_cuda)

    config = mf.ModelFactory.get_model(args.treatment, args.dataset)
    maml = MetaLearingClassification(args, config, args.treatment).to(device)

    if args.checkpoint:
        # Warm-start: copy saved weights into the fresh learner one
        # parameter at a time.
        checkpoint = torch.load(args.saved_model, map_location="cpu")
        for idx in range(len(checkpoint)):
            maml.net.parameters()[idx].data = checkpoint.parameters()[idx].data

    maml = maml.to(device)
    utils.freeze_layers(args.rln, maml)

    for step in range(args.steps):
        picked = np.random.choice(args.classes, args.tasks, replace=False)

        traj_iterators = [sampler.sample_task([c]) for c in picked]
        rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            traj_iterators,
            rand_iterator,
            steps=args.update_step,
            reset=not args.no_reset,
        )
        if torch.cuda.is_available():
            x_spt = x_spt.cuda()
            y_spt = y_spt.cuda()
            x_qry = x_qry.cuda()
            y_qry = y_qry.cuda()

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)  # , args.tasks)

        # Evaluation during training for sanity checks
        if step % 40 == 0:
            logger.info("step: %d \t training acc %s", step, str(accs))
        if step % 100 == 0 or step == 19999:
            torch.save(maml.net, args.model_name)
        if step % 2000 == 0 and step != 0:
            utils.log_accuracy(maml, my_experiment, iterator_test, device,
                               writer, step)
            utils.log_accuracy(maml, my_experiment, iterator_train, device,
                               writer, step)
def main(args):
    """Meta-train on Omniglot with a variable number of tasks per step,
    freezing the first `rln` representation layers, and periodically
    evaluating the classifier head on the train and test splits.
    """
    # Seed every RNG in use so runs are reproducible.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    my_experiment = experiment(args.name, args, "../results/",
                               commit_changes=args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")
    logger = logging.getLogger('experiment')

    # First 963 Omniglot classes form the meta-training set.
    args.classes = list(range(963))

    dataset = df.DatasetFactory.get_dataset(args.dataset, background=True,
                                            train=True)
    dataset_test = df.DatasetFactory.get_dataset(args.dataset,
                                                 background=True, train=False)

    # Plain loaders used only by the periodic evaluation below.
    iterator_test = torch.utils.data.DataLoader(dataset_test, batch_size=5,
                                                shuffle=True, num_workers=1)
    iterator_train = torch.utils.data.DataLoader(dataset, batch_size=5,
                                                 shuffle=True, num_workers=1)
    logger.info("Train set length = %d", len(iterator_train) * 5)
    logger.info("Test set length = %d", len(iterator_test) * 5)

    sampler = ts.SamplerFactory.get_sampler(args.dataset, args.classes,
                                            dataset, dataset_test)
    config = mf.ModelFactory.get_model("na", args.dataset)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    maml = MetaLearingClassification(args, config).to(device)

    # Mark everything trainable, then freeze the named RLN layers below.
    for name, param in maml.named_parameters():
        param.learn = True
    for name, param in maml.net.named_parameters():
        param.learn = True

    # Freeze the first `rln` layers (weight + bias each => rln * 2 vars).
    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("net.vars." + str(temp))

    for name, param in maml.named_parameters():
        logger.info(name)
        if name in frozen_layers:
            logger.info("Freeezing name %s", str(name))
            param.learn = False

    # Update the classifier
    list_of_params = list(filter(lambda x: x.learn, maml.parameters()))
    list_of_names = list(filter(lambda x: x[1].learn,
                                maml.named_parameters()))
    for a in list_of_names:
        logger.info("Unfrozen layers for rep learning = %s", a[0])

    for step in range(args.epoch):
        # Variable task count: between 1 and args.tasks classes per step.
        t1 = np.random.choice(args.classes,
                              np.random.randint(1, args.tasks + 1),
                              replace=False)

        d_traj_iterators = []
        for t in t1:
            d_traj_iterators.append(sampler.sample_task([t]))
        d_rand_iterator = sampler.get_complete_iterator()

        x_spt, y_spt, x_qry, y_qry = maml.sample_training_data(
            d_traj_iterators, d_rand_iterator, steps=args.update_step)
        if torch.cuda.is_available():
            x_spt, y_spt, x_qry, y_qry = x_spt.cuda(), y_spt.cuda(
            ), x_qry.cuda(), y_qry.cuda()

        accs, loss = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 40 == 0:
            writer.add_scalar('/metatrain/train/accuracy', accs[-1], step)
            logger.info('step: %d \t training acc %s', step, str(accs))

        # Periodic checkpoint plus full train/test classifier evaluation.
        if step % 300 == 0:
            correct = 0
            torch.save(maml.net, my_experiment.path + "learner.model")
            for img, target in iterator_test:
                with torch.no_grad():
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml.net(img, vars=None, bn_training=False,
                                        feature=False)
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    # Per-batch accuracy; averaged over batches below.
                    correct += torch.eq(pred_q,
                                        target).sum().item() / len(img)
            writer.add_scalar('/metatrain/test/classifier/accuracy',
                              correct / len(iterator_test), step)
            logger.info("Test Accuracy = %s",
                        str(correct / len(iterator_test)))

            correct = 0
            for img, target in iterator_train:
                with torch.no_grad():
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml.net(img, vars=None, bn_training=False,
                                        feature=False)
                    pred_q = (logits_q).argmax(dim=1)
                    correct += torch.eq(pred_q,
                                        target).sum().item() / len(img)
            logger.info("Train Accuracy = %s",
                        str(correct / len(iterator_train)))
            writer.add_scalar('/metatrain/train/classifier/accuracy',
                              correct / len(iterator_train), step)