def create_model_and_load_chkpt(args, dataset_name, checkpoint_path):
    """Build a backbone model from ``args`` and load weights from a checkpoint.

    Args:
        args: parsed CLI namespace; reads ``model_type``, ``avg_pool``,
            ``num_classes_train``, ``classifier_type``, ``projection``,
            ``img_side_len`` and ``device_number``.
        dataset_name: dataset identifier; for resnet-12 it selects the
            dropblock size (5 for miniImagenet/CUB, 2 otherwise).
        checkpoint_path: path to a torch checkpoint containing a ``'model'``
            state dict (possibly saved from a ``DataParallel`` wrapper).

    Returns:
        The model wrapped in ``torch.nn.DataParallel`` and moved to CUDA.

    Raises:
        ValueError: if ``args.model_type`` is not recognized.
    """
    print("\n", "--" * 20, "MODEL", "--" * 20)
    if args.model_type == 'resnet_12':
        # Larger dropblock for the higher-resolution datasets.
        if 'miniImagenet' in dataset_name or 'CUB' in dataset_name:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=5,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection))
        else:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=2,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection))
    elif args.model_type in ['conv64', 'conv48', 'conv32']:
        dim = int(args.model_type[-2:])
        # BUGFIX: the original referenced a free variable ``image_size`` that
        # is not defined in this function; derive it from args instead.
        model = shallow_conv.ShallowConv(z_dim=dim,
                                         h_dim=dim,
                                         num_classes=args.num_classes_train,
                                         x_width=args.img_side_len,
                                         classifier_type=args.classifier_type,
                                         projection=str2bool(args.projection))
    elif args.model_type == 'wide_resnet28_10':
        model = wide_resnet.wrn28_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type)
    elif args.model_type == 'wide_resnet16_10':
        model = wide_resnet.wrn16_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type)
    else:
        raise ValueError('Unrecognized model type {}'.format(args.model_type))
    print("Model\n" + "==" * 27)
    print(model)

    print(f"loading model from {checkpoint_path}")
    model_dict = model.state_dict()
    chkpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    chkpt_state_dict = chkpt['model']
    # Remove "module." from keys, possibly present because the checkpoint was
    # dumped by DataParallel. Iterate over a copy since we mutate the dict.
    chkpt_state_dict_cpy = chkpt_state_dict.copy()
    for key in chkpt_state_dict_cpy.keys():
        if 'module.' in key:
            # raw string fixes the invalid '\.' escape of the original
            new_key = re.sub(r'module\.', '', key)
            chkpt_state_dict[new_key] = chkpt_state_dict.pop(key)
    # Keep only the keys that exist in this model's state dict.
    chkpt_state_dict = {
        k: v
        for k, v in chkpt_state_dict.items() if k in model_dict
    }
    model_dict.update(chkpt_state_dict)
    updated_keys = set(model_dict).intersection(set(chkpt_state_dict))
    print(f"Updated {len(updated_keys)} keys using chkpt")
    print("Following keys updated :", "\n".join(sorted(updated_keys)))
    missed_keys = set(model_dict).difference(set(chkpt_state_dict))
    print(f"Missed {len(missed_keys)} keys")
    print("Following keys missed :", "\n".join(sorted(missed_keys)))
    model.load_state_dict(model_dict)

    # Multi-gpu support and device setup
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device_number
    print('Using GPUs: ', os.environ["CUDA_VISIBLE_DEVICES"])
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()
    return model
def main(args):
    """Train a few-shot / supervised-baseline model end to end.

    Builds dataloaders, backbone, optimizer, lr scheduler and algorithm
    trainer from ``args``, optionally restores a checkpoint, then runs the
    epoch loop with periodic validation/testing and tensorboard logging.
    """
    ####################################################
    #                LOGGING AND SAVING                #
    ####################################################
    args.output_folder = ensure_path('./runs/{0}'.format(args.output_folder))
    writer = SummaryWriter(args.output_folder)
    # Dump the full run configuration for reproducibility.
    with open(f'{args.output_folder}/config.txt', 'w') as config_txt:
        for k, v in sorted(vars(args).items()):
            config_txt.write(f'{k}: {v}\n')
    save_folder = args.output_folder

    ####################################################
    #         DATASET AND DATALOADER CREATION          #
    ####################################################
    # json paths
    dataset_name = args.dataset_path.split('/')[-1]
    image_size = args.img_side_len
    # Needed when the same train config is used for both 5w5s and 5w1s
    # evaluations (e.g. SVM uses 5w15s5q for both).
    all_n_shot_vals = [args.n_shot_val, 1] if str2bool(
        args.do_one_shot_eval_too) else [args.n_shot_val]
    train_file = os.path.join(args.dataset_path, 'base.json')
    val_file = os.path.join(args.dataset_path, 'val.json')
    test_file = os.path.join(args.dataset_path, 'novel.json')
    print("Dataset name", dataset_name, "image_size", image_size)
    if args.algorithm != 'SupervisedBaseline':
        print("all_n_shot_vals", all_n_shot_vals)

    # 1. Create a FedDataset, which handles preloading of images for every
    #    single client.
    # 2. Create a FedDataLoader from it, which samples a batch of clients.
    print("\n", "--" * 20, "TRAIN", "--" * 20)
    if args.algorithm in ["SupervisedBaseline", "TransferLearning"]:
        # For transfer learning we create a flat SimpleFedDataset; the
        # augmentation is decided by the query_aug flag.
        train_dataset = SimpleFedDataset(
            json_path=train_file,
            image_size=(image_size, image_size),  # has to be a (h, w) tuple
            preload=str2bool(args.preload_train))
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size_train,
            shuffle=True,
            num_workers=6)
    else:
        train_meta_dataset = FedDataset(
            json_path=train_file,
            n_shot_per_class=args.n_shot_train,
            n_query_per_class=args.n_query_train,
            image_size=(image_size, image_size),  # has to be a (h, w) tuple
            randomize_query=str2bool(args.randomize_query),
            preload=str2bool(args.preload_train),
            fixed_sq=str2bool(args.fixed_sq))
        train_loader = FedDataLoader(dataset=train_meta_dataset,
                                     n_batches=args.n_iters_per_epoch,
                                     batch_size=args.batch_size_train)

    print("\n", "--" * 20, "VAL", "--" * 20)
    if args.algorithm in ["SupervisedBaseline", "TransferLearning"]:
        val_dataset = SimpleFedDataset(
            json_path=val_file,
            image_size=(image_size, image_size),  # has to be a (h, w) tuple
            preload=True)
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size_val,
            shuffle=False,
            drop_last=False,
            num_workers=6)
    else:
        # One loader per evaluation shot count.
        val_meta_datasets = {}
        val_loaders = {}
        for ns_val in all_n_shot_vals:
            print("====", f"n_shots_val {ns_val}", "====")
            val_meta_datasets[ns_val] = FedDataset(
                json_path=val_file,
                n_shot_per_class=ns_val,
                n_query_per_class=args.n_query_val,
                image_size=(image_size, image_size),
                randomize_query=False,
                preload=True,
                fixed_sq=str2bool(args.fixed_sq))
            val_loaders[ns_val] = FedDataLoader(
                dataset=val_meta_datasets[ns_val],
                n_batches=args.n_iterations_val,
                batch_size=args.batch_size_val)

    print("\n", "--" * 20, "TEST", "--" * 20)
    if args.algorithm in ["SupervisedBaseline", "TransferLearning"]:
        test_dataset = SimpleFedDataset(
            json_path=test_file,
            image_size=(image_size, image_size),  # has to be a (h, w) tuple
            preload=True)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=args.batch_size_val,
            shuffle=False,
            drop_last=False,
            num_workers=6)
    else:
        test_meta_datasets = {}
        test_loaders = {}
        for ns_val in all_n_shot_vals:
            print("====", f"n_shots_val {ns_val}", "====")
            test_meta_datasets[ns_val] = FedDataset(
                json_path=test_file,
                n_shot_per_class=ns_val,
                n_query_per_class=args.n_query_val,
                image_size=(image_size, image_size),
                randomize_query=False,
                preload=True,
                fixed_sq=str2bool(args.fixed_sq))
            test_loaders[ns_val] = FedDataLoader(
                dataset=test_meta_datasets[ns_val],
                n_batches=args.n_iterations_val,
                batch_size=args.batch_size_val)

    ####################################################
    #             MODEL/BACKBONE CREATION              #
    ####################################################
    print("\n", "--" * 20, "MODEL", "--" * 20)
    if args.model_type == 'resnet_12':
        # Larger dropblock for the higher-resolution datasets.
        if ('miniImagenet' in dataset_name or 'CUB' in dataset_name
                or 'celeba' in dataset_name):
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=5,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection),
                                       learnable_scale=str2bool(
                                           args.learnable_scale))
        else:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=2,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection),
                                       learnable_scale=str2bool(
                                           args.learnable_scale))
    elif args.model_type in ['conv64', 'conv48', 'conv32']:
        dim = int(args.model_type[-2:])
        model = shallow_conv.ShallowConv(z_dim=dim,
                                         h_dim=dim,
                                         num_classes=args.num_classes_train,
                                         classifier_type=args.classifier_type,
                                         projection=str2bool(args.projection),
                                         learnable_scale=str2bool(
                                             args.learnable_scale))
    elif args.model_type == 'wide_resnet28_10':
        model = wide_resnet.wrn28_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type,
                                     learnable_scale=str2bool(
                                         args.learnable_scale))
    elif args.model_type == 'wide_resnet16_10':
        model = wide_resnet.wrn16_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type,
                                     learnable_scale=str2bool(
                                         args.learnable_scale))
    else:
        raise ValueError('Unrecognized model type {}'.format(args.model_type))
    print("Model\n" + "==" * 27)
    print(model)

    ####################################################
    #               OPTIMIZER CREATION                 #
    ####################################################
    print("\n", "--" * 20, "OPTIMIZER", "--" * 20)
    print("Optimzer", args.optimizer_type)
    if args.optimizer_type == 'adam':
        optimizer = torch.optim.Adam([{
            'params': model.parameters(),
            'lr': args.lr,
            'weight_decay': args.weight_decay
        }])
    else:
        optimizer = modified_sgd.SGD([
            {
                'params': model.parameters(),
                'lr': args.lr,
                'weight_decay': args.weight_decay,
                'momentum': 0.9,
                'nesterov': True
            },
        ])
    print("Total n_epochs: ", args.n_epochs)

    # learning rate scheduler creation
    if args.lr_scheduler_type == 'deterministic':
        drop_eps = [int(x) for x in args.drop_lr_epoch.split(',')]
        if args.drop_factors != '':
            drop_factors = [float(x) for x in args.drop_factors.split(',')]
        else:
            drop_factors = [0.06, 0.012, 0.0024]
        # BUGFIX: typo "No ennough" in the original assertion message.
        assert len(drop_factors) >= len(drop_eps), "Not enough drop factors"
        print("Drop lr at epochs", drop_eps)
        print("Drop factors", drop_factors[:len(drop_eps)])
        assert len(drop_eps) <= 3, \
            "Must give less than or equal to three epochs to drop lr"
        # Piecewise-constant multiplier on the base lr.
        if len(drop_eps) == 3:
            lambda_epoch = lambda e: 1.0 if e < drop_eps[0] else (
                drop_factors[0] if e < drop_eps[1] else drop_factors[1]
                if e < drop_eps[2] else (drop_factors[2]))
        elif len(drop_eps) == 2:
            lambda_epoch = lambda e: 1.0 if e < drop_eps[0] else (
                drop_factors[0] if e < drop_eps[1] else drop_factors[1])
        else:
            lambda_epoch = lambda e: 1.0 if e < drop_eps[0] else drop_factors[0]
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda_epoch, last_epoch=-1)
        # Fast-forward the scheduler when restarting mid-training.
        for _ in range(args.restart_iter):
            lr_scheduler.step()
    elif args.lr_scheduler_type == 'val_based':
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            patience=5,
            factor=0.1,
            min_lr=5e-6,
            threshold=0.5)
    else:
        raise ValueError("Unimplemented lr scheduler")
    print("LR scheduler ", args.lr_scheduler_type)

    ####################################################
    #               LOAD FROM CHECKPOINT               #
    ####################################################
    if args.checkpoint != '':
        print(f"loading model from {args.checkpoint}")
        model_dict = model.state_dict()
        chkpt = torch.load(args.checkpoint, map_location=torch.device('cpu'))
        # Best-effort optimizer restore; narrow except instead of a bare one.
        try:
            print(f"loading optimizer from {args.checkpoint}")
            optimizer.state = chkpt['optimizer'].state
            print("Successfully loaded optimizer")
        except Exception:
            print("Failed to load optimizer")
        chkpt_state_dict = chkpt['model']
        # Remove "module." from keys, possibly present because the checkpoint
        # was dumped by DataParallel; iterate a copy since we mutate the dict.
        chkpt_state_dict_cpy = chkpt_state_dict.copy()
        for key in chkpt_state_dict_cpy.keys():
            if 'module.' in key:
                # raw string fixes the invalid '\.' escape of the original
                new_key = re.sub(r'module\.', '', key)
                chkpt_state_dict[new_key] = chkpt_state_dict.pop(key)
        chkpt_state_dict = {
            k: v
            for k, v in chkpt_state_dict.items() if k in model_dict
        }
        model_dict.update(chkpt_state_dict)
        updated_keys = set(model_dict).intersection(set(chkpt_state_dict))
        print(f"Updated {len(updated_keys)} keys using chkpt")
        print("Following keys updated :", "\n".join(sorted(updated_keys)))
        print()
        missed_keys = set(model_dict).difference(set(chkpt_state_dict))
        print(f"Missed {len(missed_keys)} keys")
        print("Following keys missed :", "\n".join(sorted(missed_keys)))
        model.load_state_dict(model_dict)

    # Multi-gpu support and device setup
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device_number
    print('Using GPUs: ', os.environ["CUDA_VISIBLE_DEVICES"])
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()

    ####################################################
    #        ALGORITHM AND ALGORITHM TRAINER           #
    ####################################################
    # start tboard from restart iter
    init_global_iteration = 0
    if args.restart_iter:
        init_global_iteration = args.restart_iter * args.n_iters_per_epoch

    # algorithm
    if args.algorithm == 'InitBasedAlgorithm':
        algorithm = InitBasedAlgorithm(
            model=model,
            loss_func=torch.nn.CrossEntropyLoss(),
            method=args.init_meta_algorithm,
            alpha=args.alpha,
            inner_loop_grad_clip=args.grad_clip_inner,
            inner_update_method=args.inner_update_method,
            device='cuda')
    elif args.algorithm == 'ProtoNet':
        algorithm = ProtoNet(model=model,
                             inner_loss_func=torch.nn.CrossEntropyLoss(),
                             device='cuda',
                             scale=args.scale_factor,
                             metric=args.classifier_metric)
    elif args.algorithm == 'SVM':
        algorithm = SVM(model=model,
                        inner_loss_func=torch.nn.CrossEntropyLoss(),
                        scale=args.scale_factor,
                        device='cuda')
    elif args.algorithm == 'Ridge':
        algorithm = Ridge(model=model,
                          inner_loss_func=torch.nn.CrossEntropyLoss(),
                          scale=args.scale_factor,
                          device='cuda')
    elif args.algorithm in ['TransferLearning', 'SupervisedBaseline']:
        # We use the ProtoNet algorithm at test time.
        algorithm = ProtoNet(model=model,
                             inner_loss_func=torch.nn.CrossEntropyLoss(),
                             device='cuda',
                             scale=args.scale_factor,
                             metric=args.classifier_metric)
    else:
        raise ValueError('Unrecognized algorithm {}'.format(args.algorithm))

    if args.algorithm == 'InitBasedAlgorithm':
        trainer = Init_algorithm_trainer(
            algorithm=algorithm,
            optimizer=optimizer,
            writer=writer,
            log_interval=args.log_interval,
            save_folder=save_folder,
            grad_clip=args.grad_clip,
            num_updates_inner_train=args.num_updates_inner_train,
            num_updates_inner_val=args.num_updates_inner_val,
            init_global_iteration=init_global_iteration)
    elif args.algorithm in ['TransferLearning', 'SupervisedBaseline']:
        trainer = TL_algorithm_trainer(
            algorithm=algorithm,
            optimizer=optimizer,
            writer=writer,
            log_interval=args.log_interval,
            save_folder=save_folder,
            grad_clip=args.grad_clip,
            init_global_iteration=init_global_iteration)
    else:
        trainer = Meta_algorithm_trainer(
            algorithm=algorithm,
            optimizer=optimizer,
            writer=writer,
            log_interval=args.log_interval,
            save_folder=save_folder,
            grad_clip=args.grad_clip,
            init_global_iteration=init_global_iteration)

    ####################################################
    #                  TRAINER LOOP                    #
    ####################################################
    print("\n", "--" * 20, "BEGIN TRAINING", "--" * 20)
    # iterate over training epochs
    for iter_start in range(args.restart_iter, args.n_epochs):
        # training
        for param_group in optimizer.param_groups:
            print('\n\nlearning rate:', param_group['lr'])
        if args.algorithm in ['SupervisedBaseline']:
            trainer.run(mt_loader=train_loader,
                        evaluate_supervised_classification=True,
                        is_training=True,
                        epoch=iter_start + 1)
        else:
            trainer.run(mt_loader=train_loader,
                        is_training=True,
                        epoch=iter_start + 1)

        if iter_start % args.val_frequency == 0:
            # validation
            if args.algorithm in ["SupervisedBaseline"]:
                print("Validation")
                results = trainer.run(mt_loader=val_loader,
                                      is_training=False,
                                      evaluate_supervised_classification=True)
                writer.add_scalar("val_acc",
                                  results['test_loss_after']['accu'],
                                  iter_start + 1)
                writer.add_scalar("val_loss",
                                  results['test_loss_after']['loss'],
                                  iter_start + 1)
                val_accu = results['test_loss_after']['accu']
            else:
                val_accus = {}
                for ns_val in all_n_shot_vals:
                    print("Validation ", f"n_shots_val {ns_val}")
                    results = trainer.run(mt_loader=val_loaders[ns_val],
                                          is_training=False)
                    writer.add_scalar(f"val_acc_{args.n_way_val}w{ns_val}s",
                                      results['test_loss_after']['accu'],
                                      iter_start + 1)
                    writer.add_scalar(f"val_loss_{args.n_way_val}w{ns_val}s",
                                      results['test_loss_after']['loss'],
                                      iter_start + 1)
                    val_accus[ns_val] = results['test_loss_after']['accu']
                    pp = pprint.PrettyPrinter(indent=4)
                    pp.pprint(results)

            # test
            if args.algorithm in ["SupervisedBaseline"]:
                print("Test")
                results = trainer.run(mt_loader=test_loader,
                                      is_training=False,
                                      evaluate_supervised_classification=True)
                writer.add_scalar("test_acc",
                                  results['test_loss_after']['accu'],
                                  iter_start + 1)
                writer.add_scalar("test_loss",
                                  results['test_loss_after']['loss'],
                                  iter_start + 1)
                # NOTE: kept for downstream diagnostics parity.
                novel_test_loss = results['test_loss_after']['loss']
            else:
                # BUGFIX: ``novel_test_losses`` was read below but never
                # initialized, and only the stale ``ns_val`` left over from
                # the validation loop was tested; evaluate every shot
                # setting and collect losses in a dict.
                novel_test_losses = {}
                for ns_val in all_n_shot_vals:
                    print("Test ", f"n_shots_val {ns_val}")
                    results = trainer.run(mt_loader=test_loaders[ns_val],
                                          is_training=False)
                    writer.add_scalar(f"test_acc_{args.n_way_val}w{ns_val}s",
                                      results['test_loss_after']['accu'],
                                      iter_start + 1)
                    writer.add_scalar(f"test_loss_{args.n_way_val}w{ns_val}s",
                                      results['test_loss_after']['loss'],
                                      iter_start + 1)
                    novel_test_losses[ns_val] = \
                        results['test_loss_after']['loss']
                    pp = pprint.PrettyPrinter(indent=4)
                    pp.pprint(results)

            if args.algorithm != "SupervisedBaseline":
                # stick with the main shot count for model selection
                val_accu = val_accus[args.n_shot_val]
                novel_test_loss = novel_test_losses[args.n_shot_val]

        # scheduler step
        if args.lr_scheduler_type == 'val_based':
            assert args.val_frequency == 1, \
                "eval after every epoch is mandatory for val based lr scheduler"
            lr_scheduler.step(val_accu)
        else:
            lr_scheduler.step()
def main(args):
    """Evaluate a trained checkpoint on train/val/novel (and base+novel mixes).

    Builds episodic dataloaders for every requested shot count, restores the
    model from ``args.checkpoint`` (mandatory), and writes per-split losses
    and accuracies to an ``eval_results`` text file in the output folder.
    """
    ####################################################
    #                LOGGING AND SAVING                #
    ####################################################
    args.output_folder = ensure_path('./runs/{0}'.format(args.output_folder))
    if str2bool(args.eot_model):
        eval_results = f'{args.output_folder}/evaleot_results.txt'
    else:
        eval_results = f'{args.output_folder}/eval_results.txt'
    with open(eval_results, 'w') as f:
        f.write("--" * 20 + "EVALUATION RESULTS" + "--" * 20 + '\n')

    ####################################################
    #         DATASET AND DATALOADER CREATION          #
    ####################################################
    # json paths (duplicate dataset_name assignment of the original removed)
    dataset_name = args.dataset_path.split('/')[-1]
    image_size = args.img_side_len
    # Needed when the same train config is used for both 5w5s and 5w1s
    # evaluations (e.g. SVM uses 5w15s5q for both).
    all_n_shot_vals = [args.n_shot_val, 1] if str2bool(
        args.do_one_shot_eval_too) else [args.n_shot_val]
    base_class_generalization = dataset_name.lower() in [
        'miniimagenet', 'fc100-base', 'cifar-fs-base', 'tieredimagenet-base'
    ]
    train_file = os.path.join(args.dataset_path, 'base.json')
    val_file = os.path.join(args.dataset_path, 'val.json')
    test_file = os.path.join(args.dataset_path, 'novel.json')
    if base_class_generalization:
        base_test_file = os.path.join(args.dataset_path, 'base_test.json')
    print("Dataset name", dataset_name, "image_size", image_size,
          "all_n_shot_vals", all_n_shot_vals)
    print("base_class_generalization:", base_class_generalization)

    # 1. Create ClassImagesSet object, which handles preloading of images.
    # 2. Pass ClassImagesSet to MetaDataset which handles nshot, nquery and
    #    fixSupport.
    # 3. Create Dataloader object from MetaDataset.
    print("\n", "--" * 20, "BASE", "--" * 20)
    train_classes = ClassImagesSet(train_file,
                                   preload=str2bool(args.preload_train))
    # create a dataloader that has no fixed support
    no_fixS_train_meta_dataset = MetaDataset(
        dataset_name=dataset_name,
        support_class_images_set=train_classes,
        query_class_images_set=train_classes,
        image_size=image_size,
        support_aug=False,
        query_aug=False,
        fix_support=0,  # no fixed support
        save_folder='',
        verbose=False)
    no_fixS_train_loader = MetaDataLoader(dataset=no_fixS_train_meta_dataset,
                                          n_batches=args.n_iterations_val,
                                          batch_size=args.batch_size_val,
                                          n_way=args.n_way_val,
                                          n_shot=args.n_shot_val,
                                          n_query=args.n_query_val,
                                          randomize_query=False)

    print("\n", "--" * 20, "VAL", "--" * 20)
    val_classes = ClassImagesSet(val_file,
                                 preload=str2bool(args.preload_train))
    val_meta_datasets = {}
    val_loaders = {}
    for ns_val in all_n_shot_vals:
        print("====", f"n_shots_val {ns_val}", "====")
        val_meta_datasets[ns_val] = MetaDataset(
            dataset_name=dataset_name,
            support_class_images_set=val_classes,
            query_class_images_set=val_classes,
            image_size=image_size,
            support_aug=False,
            query_aug=False,
            fix_support=0,
            save_folder='')
        val_loaders[ns_val] = MetaDataLoader(
            dataset=val_meta_datasets[ns_val],
            n_batches=args.n_iterations_val,
            batch_size=args.batch_size_val,
            n_way=args.n_way_val,
            n_shot=ns_val,
            n_query=args.n_query_val,
            randomize_query=False)

    print("\n", "--" * 20, "NOVEL", "--" * 20)
    test_classes = ClassImagesSet(test_file,
                                  preload=str2bool(args.preload_train))
    test_meta_datasets = {}
    test_loaders = {}
    for ns_val in all_n_shot_vals:
        print("====", f"n_shots_val {ns_val}", "====")
        test_meta_datasets[ns_val] = MetaDataset(
            dataset_name=dataset_name,
            support_class_images_set=test_classes,
            query_class_images_set=test_classes,
            image_size=image_size,
            support_aug=False,
            query_aug=False,
            fix_support=0,
            save_folder='')
        test_loaders[ns_val] = MetaDataLoader(
            dataset=test_meta_datasets[ns_val],
            n_batches=args.n_iterations_val,
            batch_size=args.batch_size_val,
            n_way=args.n_way_val,
            n_shot=ns_val,
            n_query=args.n_query_val,
            randomize_query=False,
        )

    if base_class_generalization:
        # can only do this if there is only one type of evaluation
        print("\n", "--" * 20, "BASE TEST", "--" * 20)
        base_test_classes = ClassImagesSet(base_test_file,
                                           preload=str2bool(
                                               args.preload_train))

        print("\n", "--" * 20, "BASE + NOVEL TEST", "--" * 20)
        # BUGFIX: the original message contained a stray 'f' inside the
        # f-string ("novel: f{...}"); fixed so ids are actually interpolated.
        assert len(set(base_test_classes.keys()).intersection(
            set(test_classes.keys()))) == 0, \
            f"the base and novel classes must have different ids, " \
            f"base:{set(base_test_classes.keys())}, " \
            f"novel: {set(test_classes.keys())}"
        # combine both base and novel classes
        base_novel_test_classes = ClassImagesSet(base_test_file, test_file)
        base_novel_test_meta_dataset = MetaDataset(
            dataset_name=dataset_name,
            support_class_images_set=base_novel_test_classes,
            query_class_images_set=base_novel_test_classes,
            image_size=image_size,
            support_aug=False,
            query_aug=False,
            fix_support=0,
            save_folder='')
        # Sample classes from base and novel with mixing probability lambd:
        # lambd is the weight on novel classes, (1 - lambd) on base classes.
        base_novel_test_loaders_dict = {}
        for lambd in np.arange(0., 1.1, 0.1):
            base_novel_test_loaders_dict[lambd] = MetaDataLoader(
                dataset=base_novel_test_meta_dataset,
                n_batches=args.n_iterations_val,
                batch_size=args.batch_size_val,
                n_way=args.n_way_val,
                n_shot=args.n_shot_val,
                n_query=args.n_query_val,
                randomize_query=False,
                p_dict={
                    k: ((1 - lambd) /
                        len(base_test_classes) if k in base_test_classes else
                        lambd / len(test_classes))
                    for k in list(base_test_classes) + list(test_classes)
                })

    ####################################################
    #             MODEL/BACKBONE CREATION              #
    ####################################################
    print("\n", "--" * 20, "MODEL", "--" * 20)
    if args.model_type == 'resnet_12':
        # Larger dropblock for the higher-resolution datasets.
        if 'miniImagenet' in dataset_name or 'CUB' in dataset_name:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=5,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection),
                                       learnable_scale=str2bool(
                                           args.learnable_scale))
        else:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=2,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection),
                                       learnable_scale=str2bool(
                                           args.learnable_scale))
    elif args.model_type in ['conv64', 'conv48', 'conv32']:
        dim = int(args.model_type[-2:])
        model = shallow_conv.ShallowConv(z_dim=dim,
                                         h_dim=dim,
                                         num_classes=args.num_classes_train,
                                         x_width=image_size,
                                         classifier_type=args.classifier_type,
                                         projection=str2bool(args.projection),
                                         learnable_scale=str2bool(
                                             args.learnable_scale))
    elif args.model_type == 'wide_resnet28_10':
        model = wide_resnet.wrn28_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type,
                                     learnable_scale=str2bool(
                                         args.learnable_scale))
    elif args.model_type == 'wide_resnet16_10':
        model = wide_resnet.wrn16_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type,
                                     learnable_scale=str2bool(
                                         args.learnable_scale))
    else:
        raise ValueError('Unrecognized model type {}'.format(args.model_type))
    print("Model\n" + "==" * 27)
    print(model)

    ####################################################
    #               LOAD FROM CHECKPOINT               #
    ####################################################
    assert args.checkpoint != '', "Must provide checkpoint"
    print(f"loading model from {args.checkpoint}")
    model_dict = model.state_dict()
    chkpt = torch.load(args.checkpoint, map_location=torch.device('cpu'))

    ### load model
    chkpt_state_dict = chkpt['model']
    chkpt_state_dict_old_keys = list(chkpt_state_dict.keys())
    # Remove "module." from keys, possibly present because the checkpoint was
    # dumped by DataParallel.
    for key in chkpt_state_dict_old_keys:
        if 'module.' in key:
            # raw string fixes the invalid '\.' escape of the original
            new_key = re.sub(r'module\.', '', key)
            chkpt_state_dict[new_key] = chkpt_state_dict.pop(key)
    load_model_state_dict = {
        k: v
        for k, v in chkpt_state_dict.items() if k in model_dict
    }
    model_dict.update(load_model_state_dict)
    print(f"Updated {len(load_model_state_dict.keys())} keys using chkpt")
    print("Following keys updated :",
          "\n".join(sorted(load_model_state_dict.keys())))
    missed_keys = set(model_dict).difference(set(load_model_state_dict))
    print(f"Missed {len(missed_keys)} keys")
    print("Following keys missed :", "\n".join(sorted(missed_keys)))
    model.load_state_dict(model_dict)

    # Multi-gpu support and device setup
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device_number
    print('Using GPUs: ', os.environ["CUDA_VISIBLE_DEVICES"])
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()

    ####################################################
    #        ALGORITHM AND ALGORITHM TRAINER           #
    ####################################################
    # algorithm
    if args.algorithm == 'InitBasedAlgorithm':
        algorithm = InitBasedAlgorithm(
            model=model,
            loss_func=torch.nn.CrossEntropyLoss(),
            method=args.init_meta_algorithm,
            alpha=args.alpha,
            inner_loop_grad_clip=args.grad_clip_inner,
            inner_update_method=args.inner_update_method,
            device='cuda')
    elif args.algorithm == 'ProtoNet':
        algorithm = ProtoNet(model=model,
                             inner_loss_func=torch.nn.CrossEntropyLoss(),
                             device='cuda',
                             scale=args.scale_factor,
                             metric=args.classifier_metric)
    elif args.algorithm == 'SVM':
        algorithm = SVM(model=model,
                        inner_loss_func=torch.nn.CrossEntropyLoss(),
                        scale=args.scale_factor,
                        device='cuda')
    elif args.algorithm == 'Ridge':
        algorithm = Ridge(model=model,
                          inner_loss_func=torch.nn.CrossEntropyLoss(),
                          scale=args.scale_factor,
                          device='cuda')
    else:
        raise ValueError('Unrecognized algorithm {}'.format(args.algorithm))

    # Evaluation-only trainers: no optimizer, writer or save folder.
    if args.algorithm != 'InitBasedAlgorithm':
        trainer = Meta_algorithm_trainer(algorithm=algorithm,
                                         optimizer=None,
                                         writer=None,
                                         log_interval=args.log_interval,
                                         save_folder='',
                                         grad_clip=None,
                                         init_global_iteration=None)
    else:
        trainer = Init_algorithm_trainer(
            algorithm=algorithm,
            optimizer=None,
            writer=None,
            log_interval=args.log_interval,
            save_folder='',
            grad_clip=None,
            num_updates_inner_train=args.num_updates_inner_train,
            num_updates_inner_val=args.num_updates_inner_val,
            init_global_iteration=None)

    ####################################################
    #                   EVALUATION                     #
    ####################################################
    print("\n", "--" * 20, "BEGIN EVALUATION", "--" * 20)

    # On ML train objective
    print("Train Loss on ML objective")
    results = trainer.run(mt_loader=no_fixS_train_loader, is_training=False)
    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(results)
    with open(eval_results, 'a') as f:
        # BUGFIX: the original label used the stale loop variable ``ns_val``
        # although this loader is built with n_shot=args.n_shot_val.
        f.write(
            f"TrainLossOnML{args.n_way_val}w{args.n_shot_val}s: "
            f"Loss {round(results['test_loss_after']['loss'], 3)} "
            f"Acc {round(results['test_loss_after']['accu'], 3)}" + '\n')

    # validation/test
    for ns_val in all_n_shot_vals:
        print("Validation ", f"n_shots_val {ns_val}")
        results = trainer.run(mt_loader=val_loaders[ns_val],
                              is_training=False)
        pp = pprint.PrettyPrinter(indent=4)
        pp.pprint(results)
        with open(eval_results, 'a') as f:
            f.write(
                f"Val{args.n_way_val}w{ns_val}s: "
                f"Loss {round(results['test_loss_after']['loss'], 3)} "
                f"Acc {round(results['test_loss_after']['accu'], 3)}" + '\n')

        print("Test ", f"n_shots_val {ns_val}")
        results = trainer.run(mt_loader=test_loaders[ns_val],
                              is_training=False)
        pp = pprint.PrettyPrinter(indent=4)
        pp.pprint(results)
        with open(eval_results, 'a') as f:
            f.write(
                f"Test{args.n_way_val}w{ns_val}s: "
                f"Loss {round(results['test_loss_after']['loss'], 3)} "
                f"Acc {round(results['test_loss_after']['accu'], 3)}" + '\n')

    # base class generalization
    if base_class_generalization:
        # can only do this if there is only one type of evaluation
        print("Base Test")
        for lambd, base_novel_test_loader in \
                base_novel_test_loaders_dict.items():
            print(f"Base + Novel Test lambda={round(lambd, 2)} "
                  f"Novel {round(1-lambd, 2)} Base")
            results = trainer.run(mt_loader=base_novel_test_loader,
                                  is_training=False)
            pp = pprint.PrettyPrinter(indent=4)
            pp.pprint(results)
            with open(eval_results, 'a') as f:
                f.write(
                    f"Base+NovelTestLambda={round(lambd, 2)}"
                    f"Novel{round(1-lambd, 2)}Base: "
                    f"Loss {round(results['test_loss_after']['loss'], 3)} "
                    f"Acc {round(results['test_loss_after']['accu'], 3)}" +
                    '\n')
def compute_norm(args, p=2):
    """Build the backbone specified by ``args``, optionally load weights from
    ``args.checkpoint`` (CPU-mapped), and return the model's parameter norm.

    Args:
        args: parsed command-line namespace. Fields read here:
            ``dataset_name``, ``model_type``, ``avg_pool``,
            ``num_classes_train``, ``classifier_type``, ``projection``,
            ``checkpoint``.
        p: order of the norm; only ``p == 2`` is implemented.

    Returns:
        The value of ``l2_norm(model)``.

    Raises:
        ValueError: for an unrecognized ``args.model_type`` or ``p != 2``.
    """
    ####################################################
    #             MODEL/BACKBONE CREATION              #
    ####################################################
    print("\n", "--" * 20, "MODEL", "--" * 20)
    dataset_name = args.dataset_name
    if args.model_type == 'resnet_12':
        # miniImagenet/CUB use a larger dropblock (5); other datasets
        # (smaller images, e.g. CIFAR variants) use dropblock_size=2.
        if 'miniImagenet' in dataset_name or 'CUB' in dataset_name:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=5,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection))
        else:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=2,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection))
    elif args.model_type in ['conv64', 'conv48', 'conv32']:
        # last two characters of the model type encode the channel width
        dim = int(args.model_type[-2:])
        model = shallow_conv.ShallowConv(z_dim=dim,
                                         h_dim=dim,
                                         num_classes=args.num_classes_train,
                                         classifier_type=args.classifier_type,
                                         projection=str2bool(args.projection))
    elif args.model_type == 'wide_resnet28_10':
        model = wide_resnet.wrn28_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type)
    elif args.model_type == 'wide_resnet16_10':
        model = wide_resnet.wrn16_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type)
    else:
        raise ValueError('Unrecognized model type {}'.format(args.model_type))
    print("Model\n" + "==" * 27)
    print(model)

    ####################################################
    #             LOAD FROM CHECKPOINT                 #
    ####################################################
    if args.checkpoint != '':
        print(f"loading model from {args.checkpoint}")
        model_dict = model.state_dict()
        chkpt = torch.load(args.checkpoint, map_location=torch.device('cpu'))
        # BUGFIX: the previous version tried to restore an optimizer here,
        # but no optimizer exists in this function, so the attempt always
        # raised NameError and a bare `except:` silently swallowed it.
        # Computing a weight norm only needs the model parameters.
        chkpt_state_dict = chkpt['model']
        # Strip "module." from keys, possibly present because the checkpoint
        # was dumped by DataParallel.  Raw string avoids the invalid '\.'
        # escape the old code used; list() snapshots keys before mutation.
        for key in list(chkpt_state_dict.keys()):
            if 'module.' in key:
                new_key = re.sub(r'module\.', '', key)
                chkpt_state_dict[new_key] = chkpt_state_dict.pop(key)
        # Keep only the keys the freshly-built model actually has.
        chkpt_state_dict = {
            k: v
            for k, v in chkpt_state_dict.items() if k in model_dict
        }
        model_dict.update(chkpt_state_dict)
        updated_keys = set(model_dict).intersection(set(chkpt_state_dict))
        print(f"Updated {len(updated_keys)} keys using chkpt")
        print("Following keys updated :", "\n".join(sorted(updated_keys)))
        missed_keys = set(model_dict).difference(set(chkpt_state_dict))
        print(f"Missed {len(missed_keys)} keys")
        print("Following keys missed :", "\n".join(sorted(missed_keys)))
        model.load_state_dict(model_dict)

    ####################################################
    #              COMPUTE NORM HERE                   #
    ####################################################
    if p == 2:
        return l2_norm(model)
    raise ValueError('Unspecified norm type')
def main(args):
    """Meta-train a few-shot model and periodically evaluate it.

    Pipeline: set up the output folder / TensorBoard writer / log redirection,
    build train/val/test (and optionally base-test) dataloaders from
    ``args.dataset_path``, construct the backbone and optimizer + LR
    scheduler, optionally restore a checkpoint, wrap the model in
    ``DataParallel``, then run the epoch loop with periodic validation,
    novel-class testing, and (for supported datasets) base-class
    generalization measurement.

    Args:
        args: parsed command-line namespace; consumed fields cover dataset,
            model, optimizer, scheduler, checkpointing, and evaluation
            settings (see the argparse setup elsewhere in this file).
    """
    ####################################################
    #             LOGGING AND SAVING                   #
    ####################################################
    if args.checkpoint != '':
        # if we are reloading, we don't need to timestamp and create a new folder
        # instead keep writing to the original output_folder
        assert os.path.exists(f'./runs/{args.output_folder}')
        args.output_folder = f'./runs/{args.output_folder}'
        print(f'resume training and will write to {args.output_folder}')
    else:
        args.output_folder = ensure_path('./runs/{0}'.format(
            args.output_folder))
    writer = SummaryWriter(args.output_folder)
    time_now = datetime.now(
        pytz.timezone("America/New_York")).strftime("%d:%b:%Y:%H:%M:%S")
    # dump the full run configuration next to the logs for reproducibility
    with open(f'{args.output_folder}/config_{time_now}.txt',
              'w') as config_txt:
        for k, v in sorted(vars(args).items()):
            config_txt.write(f'{k}: {v}\n')
    save_folder = args.output_folder
    # replace stdout with Logger; the original sys.stdout is saved in src.logger.stdout
    sys.stdout = src.logger.Logger(
        log_filename=f'{args.output_folder}/train_{time_now}.log')
    # smoke-test write through the untouched original stdout handle
    src.logger.stdout.write('hi!')

    ####################################################
    #         DATASET AND DATALOADER CREATION          #
    ####################################################
    # json paths
    dataset_name = args.dataset_path.split('/')[-1]
    image_size = args.img_side_len
    # NOTE(review): duplicate of the assignment two lines above — harmless,
    # could be removed.
    dataset_name = args.dataset_path.split('/')[-1]
    # Following is needed when same train config is used for both 5w5s and 5w1s evaluations.
    # This is the case in the case of SVM when 5w15s5q is used for both 5w5s and 5w1s evaluations.
    all_n_shot_vals = [args.n_shot_val, 1] if str2bool(
        args.do_one_shot_eval_too) else [args.n_shot_val]
    # base-class generalization is only measurable for datasets that ship a
    # held-out base_test.json split
    base_class_generalization = dataset_name.lower() in [
        'miniimagenet', 'fc100-base', 'cifar-fs-base', 'tieredimagenet-base'
    ]
    train_file = os.path.join(args.dataset_path, 'base.json')
    val_file = os.path.join(args.dataset_path, 'val.json')
    test_file = os.path.join(args.dataset_path, 'novel.json')
    if base_class_generalization:
        base_test_file = os.path.join(args.dataset_path, 'base_test.json')
    print("Dataset name", dataset_name, "image_size", image_size,
          "all_n_shot_vals", all_n_shot_vals)
    print("base_class_generalization:", base_class_generalization)
    """
    1. Create ClassImagesSet object, which handles preloading of images
    2. Pass ClassImagesSet to MetaDataset which handles nshot, nquery and fixSupport
    3. Create Dataloader object from MetaDataset
    """
    print("\n", "--" * 20, "TRAIN", "--" * 20)
    train_classes = ClassImagesSet(train_file,
                                   preload=str2bool(args.preload_train))
    if args.algorithm == 'TransferLearning':
        """
        For Transfer Learning we create a SimpleDataset.
        The augmentation is decided by query_aug flag.
        """
        train_dataset = SimpleDataset(dataset_name=dataset_name,
                                      class_images_set=train_classes,
                                      image_size=image_size,
                                      aug=str2bool(args.query_aug))
        # NOTE(review): num_workers is hard-coded to 6 — consider making it
        # configurable.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size_train,
            shuffle=True,
            num_workers=6)
    else:
        # episodic (n-way k-shot) training loader; fix_support pins the
        # support pool when args.fix_support > 0
        train_meta_dataset = MetaDataset(
            dataset_name=dataset_name,
            support_class_images_set=train_classes,
            query_class_images_set=train_classes,
            image_size=image_size,
            support_aug=str2bool(args.support_aug),
            query_aug=str2bool(args.query_aug),
            fix_support=args.fix_support,
            save_folder=save_folder,
            fix_support_path=args.fix_support_path)
        train_loader = MetaDataLoader(dataset=train_meta_dataset,
                                      batch_size=args.batch_size_train,
                                      n_batches=args.n_iters_per_epoch,
                                      n_way=args.n_way_train,
                                      n_shot=args.n_shot_train,
                                      n_query=args.n_query_train,
                                      randomize_query=str2bool(
                                          args.randomize_query))
    # create a dataloader that has no fixed support
    # (used to report the train loss on the plain ML objective)
    no_fixS_train_meta_dataset = MetaDataset(
        dataset_name=dataset_name,
        support_class_images_set=train_classes,
        query_class_images_set=train_classes,
        image_size=image_size,
        support_aug=False,
        query_aug=False,
        fix_support=0,  # no fixed support
        save_folder='',
        verbose=False)
    no_fixS_train_loader = MetaDataLoader(dataset=no_fixS_train_meta_dataset,
                                          n_batches=args.n_iterations_val,
                                          batch_size=args.batch_size_val,
                                          n_way=args.n_way_val,
                                          n_shot=args.n_shot_val,
                                          n_query=args.n_query_val,
                                          randomize_query=False)
    print("\n", "--" * 20, "VAL", "--" * 20)
    val_classes = ClassImagesSet(val_file, preload=False)
    # one val loader per shot count (e.g. 5-shot and optionally 1-shot)
    val_meta_datasets = {}
    val_loaders = {}
    for ns_val in all_n_shot_vals:
        print("====", f"n_shots_val {ns_val}", "====")
        val_meta_datasets[ns_val] = MetaDataset(
            dataset_name=dataset_name,
            support_class_images_set=val_classes,
            query_class_images_set=val_classes,
            image_size=image_size,
            support_aug=False,
            query_aug=False,
            fix_support=0,
            save_folder='')
        val_loaders[ns_val] = MetaDataLoader(dataset=val_meta_datasets[ns_val],
                                             n_batches=args.n_iterations_val,
                                             batch_size=args.batch_size_val,
                                             n_way=args.n_way_val,
                                             n_shot=ns_val,
                                             n_query=args.n_query_val,
                                             randomize_query=False)
    print("\n", "--" * 20, "TEST", "--" * 20)
    test_classes = ClassImagesSet(test_file)
    # novel-class test loaders, mirroring the val loaders per shot count
    test_meta_datasets = {}
    test_loaders = {}
    for ns_val in all_n_shot_vals:
        print("====", f"n_shots_val {ns_val}", "====")
        test_meta_datasets[ns_val] = MetaDataset(
            dataset_name=dataset_name,
            support_class_images_set=test_classes,
            query_class_images_set=test_classes,
            image_size=image_size,
            support_aug=False,
            query_aug=False,
            fix_support=0,
            save_folder='')
        test_loaders[ns_val] = MetaDataLoader(
            dataset=test_meta_datasets[ns_val],
            n_batches=args.n_iterations_val,
            batch_size=args.batch_size_val,
            n_way=args.n_way_val,
            n_shot=ns_val,
            n_query=args.n_query_val,
            randomize_query=False,
        )
    if base_class_generalization:
        # can only do this if there is only one type of evaluation
        print("\n", "--" * 20, "BASE TEST", "--" * 20)
        base_test_classes = ClassImagesSet(base_test_file)
        base_test_meta_dataset = MetaDataset(
            dataset_name=dataset_name,
            support_class_images_set=base_test_classes,
            query_class_images_set=base_test_classes,
            image_size=image_size,
            support_aug=False,
            query_aug=False,
            fix_support=0,
            save_folder=save_folder)
        base_test_loader = MetaDataLoader(dataset=base_test_meta_dataset,
                                          n_batches=args.n_iterations_val,
                                          batch_size=args.batch_size_val,
                                          n_way=args.n_way_val,
                                          n_shot=args.n_shot_val,
                                          n_query=args.n_query_val,
                                          randomize_query=False)
        if args.fix_support > 0:
            # evaluate on base-test queries while drawing the support from
            # the SAME fixed pool used during training (train/test matching)
            base_test_meta_dataset_using_fixS = MetaDataset(
                dataset_name=dataset_name,
                support_class_images_set=train_classes,
                query_class_images_set=base_test_classes,
                image_size=image_size,
                support_aug=False,
                query_aug=False,
                fix_support=0,
                save_folder=save_folder,
                fix_support_path=os.path.join(save_folder,
                                              "fixed_support_pool.pkl"))
            base_test_loader_using_fixS = MetaDataLoader(
                dataset=base_test_meta_dataset_using_fixS,
                n_batches=args.n_iterations_val,
                batch_size=args.batch_size_val,
                n_way=args.n_way_val,
                n_shot=args.n_shot_val,
                n_query=args.n_query_val,
                randomize_query=False,
            )

    ####################################################
    #             MODEL/BACKBONE CREATION              #
    ####################################################
    print("\n", "--" * 20, "MODEL", "--" * 20)
    if args.model_type == 'resnet_12':
        # technically tieredimagenet should also have dropblock size of 5
        if 'miniImagenet' in dataset_name or 'CUB' in dataset_name:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=5,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection),
                                       learnable_scale=str2bool(
                                           args.learnable_scale))
        else:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=2,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection),
                                       learnable_scale=str2bool(
                                           args.learnable_scale))
    elif args.model_type in ['conv64', 'conv48', 'conv32']:
        # last two characters of the model type encode the channel width
        dim = int(args.model_type[-2:])
        model = shallow_conv.ShallowConv(z_dim=dim,
                                         h_dim=dim,
                                         num_classes=args.num_classes_train,
                                         x_width=image_size,
                                         classifier_type=args.classifier_type,
                                         projection=str2bool(args.projection),
                                         learnable_scale=str2bool(
                                             args.learnable_scale))
    elif args.model_type == 'wide_resnet28_10':
        model = wide_resnet.wrn28_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type,
                                     learnable_scale=str2bool(
                                         args.learnable_scale))
    elif args.model_type == 'wide_resnet16_10':
        model = wide_resnet.wrn16_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type,
                                     learnable_scale=str2bool(
                                         args.learnable_scale))
    else:
        raise ValueError('Unrecognized model type {}'.format(args.model_type))
    print("Model\n" + "==" * 27)
    print(model)

    ####################################################
    #                OPTIMIZER CREATION                #
    ####################################################
    # optimizer construction
    print("\n", "--" * 20, "OPTIMIZER", "--" * 20)
    print("Optimzer", args.optimizer_type)
    if args.optimizer_type == 'adam':
        optimizer = torch.optim.Adam([{
            'params': model.parameters(),
            'lr': args.lr,
            'weight_decay': args.weight_decay
        }])
    else:
        # default: SGD with Nesterov momentum (project-local modified_sgd)
        optimizer = modified_sgd.SGD([
            {
                'params': model.parameters(),
                'lr': args.lr,
                'weight_decay': args.weight_decay,
                'momentum': 0.9,
                'nesterov': True
            },
        ])
    print("Total n_epochs: ", args.n_epochs)
    # learning rate scheduler creation
    if args.lr_scheduler_type == 'deterministic':
        # drop the lr by fixed multiplicative factors at fixed epochs
        drop_eps = [int(x) for x in args.drop_lr_epoch.split(',')]
        if args.drop_factors != '':
            drop_factors = [float(x) for x in args.drop_factors.split(',')]
        else:
            drop_factors = [0.06, 0.012, 0.0024]
        print("Drop lr at epochs", drop_eps)
        print("Drop factors", drop_factors[:len(drop_eps)])
        assert len(drop_factors) >= len(drop_eps), "No enough drop factors"
        # assert len(drop_eps) <= 3, "Must give less than or equal to three epochs to drop lr"
        '''
        if len(drop_eps) == 3:
            lambda_epoch = lambda e: 1.0 if e < drop_eps[0] else (drop_factors[0] if e < drop_eps[1] else drop_factors[1] if e < drop_eps[2] else (drop_factors[2]))
        elif len(drop_eps) == 2:
            lambda_epoch = lambda e: 1.0 if e < drop_eps[0] else (drop_factors[0] if e < drop_eps[1] else drop_factors[1])
        else:
            lambda_epoch = lambda e: 1.0 if e < drop_eps[0] else drop_factors[0]
        '''

        def lr_lambda(x):
            '''
            x is an epoch number.
            drop_eps is assumed to be a list of strictly increasing epoch
            numbers, and we require len(drop_factors) >= len(drop_eps);
            ideally they are of the same length, but technically the code
            can just not use the additional factors.
            Returns the multiplicative lr factor for epoch x.
            '''
            for i in range(len(drop_eps)):
                if x >= drop_eps[i]:
                    continue
                else:
                    if i == 0:
                        return 1.0
                    else:
                        return drop_factors[i - 1]
            return drop_factors[len(drop_eps) - 1]

        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                         lr_lambda=lr_lambda,
                                                         last_epoch=-1)
        # fast-forward the scheduler to the restart epoch when resuming
        for _ in range(args.restart_iter):
            lr_scheduler.step()
    elif args.lr_scheduler_type == 'val_based':
        # reduce lr when validation accuracy (mode='max') plateaus
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            patience=5,
            factor=0.1,
            min_lr=5e-6,
            threshold=0.5)
    else:
        raise ValueError("Unimplemented lr scheduler")
    print("LR scheduler ", args.lr_scheduler_type)

    ####################################################
    #             LOAD FROM CHECKPOINT                 #
    ####################################################
    if args.checkpoint != '':
        print(f"loading model from {args.checkpoint}")
        model_dict = model.state_dict()  # new model's state dict
        chkpt = torch.load(args.checkpoint, map_location=torch.device('cpu'))

        ### load model
        chkpt_state_dict = chkpt['model']
        chkpt_state_dict_old_keys = list(chkpt_state_dict.keys())
        # remove "module." from key, possibly present as it was dumped by data-parallel
        # NOTE(review): non-raw 'module\.' is an invalid escape sequence
        # (DeprecationWarning; SyntaxError in future Python) — should be
        # r'module\.'
        for key in chkpt_state_dict_old_keys:
            if 'module.' in key:
                new_key = re.sub('module\.', '', key)
                chkpt_state_dict[new_key] = chkpt_state_dict.pop(key)
        # keep only the keys the freshly-built model actually has
        load_model_state_dict = {
            k: v
            for k, v in chkpt_state_dict.items() if k in model_dict
        }
        model_dict.update(load_model_state_dict)
        # updated_keys = set(model_dict).intersection(set(chkpt_state_dict))
        print(f"Updated {len(load_model_state_dict.keys())} keys using chkpt")
        print("Following keys updated :",
              "\n".join(sorted(load_model_state_dict.keys())))
        missed_keys = set(model_dict).difference(set(load_model_state_dict))
        print(f"Missed {len(missed_keys)} keys")
        print("Following keys missed :", "\n".join(sorted(missed_keys)))
        model.load_state_dict(model_dict)

        ### load optimizer
        # NOTE(review): bare `except:` swallows all errors (including
        # KeyboardInterrupt); best-effort restore is intentional, but a
        # narrower `except Exception` would be safer.
        try:
            print(f"loading optimizer from {args.checkpoint}")
            optimizer.load_state_dict(chkpt['optimizer'].state_dict())
            print("Successfully loaded optimizer")
        except:
            print("Failed to load optimizer")

    ### Multi-gpu support and device setup
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device_number
    print('Using GPUs: ', os.environ["CUDA_VISIBLE_DEVICES"])

    # move model to cuda
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()
    print("Successfully moved the model to cuda")

    # move the optimizer's states to cuda if loaded
    if args.checkpoint != '':
        # https://github.com/pytorch/pytorch/issues/2830
        # when using gpu, need to move all the statistics of the optimizer to cuda
        # in addition to the model parameters
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda()
        print("Successfully moved the optimizer's states to cuda")

    ####################################################
    #        ALGORITHM AND ALGORITHM TRAINER           #
    ####################################################
    # start tboard from restart iter
    init_global_iteration = 0
    if args.restart_iter:
        init_global_iteration = args.restart_iter * args.n_iters_per_epoch
    # algorithm
    if args.algorithm == 'InitBasedAlgorithm':
        # MAML-style initialization-based meta-learning
        algorithm = InitBasedAlgorithm(
            model=model,
            loss_func=torch.nn.CrossEntropyLoss(),
            method=args.init_meta_algorithm,
            alpha=args.alpha,
            inner_loop_grad_clip=args.grad_clip_inner,
            inner_update_method=args.inner_update_method,
            device='cuda')
    elif args.algorithm == 'ProtoNet':
        algorithm = ProtoNet(model=model,
                             inner_loss_func=torch.nn.CrossEntropyLoss(),
                             device='cuda',
                             scale=args.scale_factor,
                             metric=args.classifier_metric)
    elif args.algorithm == 'SVM':
        algorithm = SVM(model=model,
                        inner_loss_func=torch.nn.CrossEntropyLoss(),
                        scale=args.scale_factor,
                        device='cuda')
    elif args.algorithm == 'Ridge':
        algorithm = Ridge(model=model,
                          inner_loss_func=torch.nn.CrossEntropyLoss(),
                          scale=args.scale_factor,
                          device='cuda')
    elif args.algorithm == 'TransferLearning':
        """
        We use the ProtoNet algorithm at test time.
        """
        algorithm = ProtoNet(model=model,
                             inner_loss_func=torch.nn.CrossEntropyLoss(),
                             device='cuda',
                             scale=args.scale_factor,
                             metric=args.classifier_metric)
    else:
        raise ValueError('Unrecognized algorithm {}'.format(args.algorithm))
    # trainer: init-based, transfer-learning, or generic meta trainer
    if args.algorithm == 'InitBasedAlgorithm':
        trainer = Init_algorithm_trainer(
            algorithm=algorithm,
            optimizer=optimizer,
            writer=writer,
            log_interval=args.log_interval,
            save_folder=save_folder,
            grad_clip=args.grad_clip,
            num_updates_inner_train=args.num_updates_inner_train,
            num_updates_inner_val=args.num_updates_inner_val,
            init_global_iteration=init_global_iteration)
    elif args.algorithm == 'TransferLearning':
        trainer = TL_algorithm_trainer(
            algorithm=algorithm,
            optimizer=optimizer,
            writer=writer,
            log_interval=args.log_interval,
            save_folder=save_folder,
            grad_clip=args.grad_clip,
            init_global_iteration=init_global_iteration)
    else:
        trainer = Meta_algorithm_trainer(
            algorithm=algorithm,
            optimizer=optimizer,
            writer=writer,
            log_interval=args.log_interval,
            save_folder=save_folder,
            grad_clip=args.grad_clip,
            init_global_iteration=init_global_iteration)

    ####################################################
    #                  TRAINER LOOP                    #
    ####################################################
    print("\n", "--" * 20, "BEGIN TRAINING", "--" * 20)
    # iterate over training epochs
    for iter_start in range(args.restart_iter, args.n_epochs):
        # training
        for param_group in optimizer.param_groups:
            print('\n\nlearning rate:', param_group['lr'])
        trainer.run(mt_loader=train_loader,
                    is_training=True,
                    epoch=iter_start + 1)  # 1 based instead of 0 based
        if iter_start % args.val_frequency == 0:
            # On ML train objective
            print("Train Loss on ML objective")
            results = trainer.run(mt_loader=no_fixS_train_loader,
                                  is_training=False)
            print(pprint.pformat(results, indent=4))
            writer.add_scalar("train_acc_on_ml",
                              results['test_loss_after']['accu'],
                              iter_start + 1)
            writer.add_scalar("train_loss_on_ml",
                              results['test_loss_after']['loss'],
                              iter_start + 1)
            base_train_loss = results['test_loss_after']['loss']
            # validation/test
            val_accus = {}
            novel_test_losses = {}
            for ns_val in all_n_shot_vals:
                print("Validation ", f"n_shots_val {ns_val}")
                results = trainer.run(mt_loader=val_loaders[ns_val],
                                      is_training=False)
                print(pprint.pformat(results, indent=4))
                writer.add_scalar(f"val_acc_{args.n_way_val}w{ns_val}s",
                                  results['test_loss_after']['accu'],
                                  iter_start + 1)
                writer.add_scalar(f"val_loss_{args.n_way_val}w{ns_val}s",
                                  results['test_loss_after']['loss'],
                                  iter_start + 1)
                val_accus[ns_val] = results['test_loss_after']['accu']
                print("Test ", f"n_shots_val {ns_val}")
                results = trainer.run(mt_loader=test_loaders[ns_val],
                                      is_training=False)
                print(pprint.pformat(results, indent=4))
                writer.add_scalar(f"test_acc_{args.n_way_val}w{ns_val}s",
                                  results['test_loss_after']['accu'],
                                  iter_start + 1)
                writer.add_scalar(f"test_loss_{args.n_way_val}w{ns_val}s",
                                  results['test_loss_after']['loss'],
                                  iter_start + 1)
                novel_test_losses[ns_val] = results['test_loss_after']['loss']
            val_accu = val_accus[
                args.n_shot_val]  # stick with 5w5s for model selection
            novel_test_loss = novel_test_losses[
                args.n_shot_val]  # stick with 5w5s for model selection
            # base class generalization
            if base_class_generalization:
                # can only do this if there is only one type of evaluation
                print("Base Test")
                results = trainer.run(mt_loader=base_test_loader,
                                      is_training=False)
                print(pprint.pformat(results, indent=4))
                writer.add_scalar("base_test_acc",
                                  results['test_loss_after']['accu'],
                                  iter_start + 1)
                writer.add_scalar("base_test_loss",
                                  results['test_loss_after']['loss'],
                                  iter_start + 1)
                base_test_loss = results['test_loss_after']['loss']
                # generalization gaps: base/novel test loss vs train loss
                writer.add_scalar("base_gen_gap",
                                  base_test_loss - base_train_loss,
                                  iter_start + 1)
                writer.add_scalar("novel_gen_gap",
                                  novel_test_loss - base_train_loss,
                                  iter_start + 1)
                if args.fix_support > 0:
                    print(
                        "Base Test using FixSupport, matching train and test for fixml"
                    )
                    results = trainer.run(
                        mt_loader=base_test_loader_using_fixS,
                        is_training=False)
                    print(pprint.pformat(results, indent=4))
                    writer.add_scalar("base_test_acc_usingFixS",
                                      results['test_loss_after']['accu'],
                                      iter_start + 1)
                    writer.add_scalar("base_test_loss_usingFixS",
                                      results['test_loss_after']['loss'],
                                      iter_start + 1)
        # scheduler step
        # NOTE(review): `val_accu` is only bound inside the val_frequency
        # branch above; the assert (val_frequency == 1) is what guarantees
        # it exists here for the val_based scheduler.
        if args.lr_scheduler_type == 'val_based':
            assert args.val_frequency == 1, "eval after every epoch is mandatory for val based lr scheduler"
            lr_scheduler.step(val_accu)
        else:
            lr_scheduler.step()