def main():
    # Retrieve config file
    p = create_config(args.config_env, args.config_exp)
    print(colored(p, 'red'))

    # Model
    print(colored('Retrieve model', 'blue'))
    model = get_model(p)
    print('Model is {}'.format(model.__class__.__name__))
    print(model)
    model = torch.nn.DataParallel(model)
    model = model.cuda()

    # CUDNN
    print(colored('Set CuDNN benchmark', 'blue'))
    torch.backends.cudnn.benchmark = True

    # Dataset
    print(colored('Retrieve dataset', 'blue'))
    transforms = get_val_transformations(p)
    train_dataset = get_train_dataset(p, transforms)
    val_dataset = get_val_dataset(p, transforms)
    train_dataloader = get_val_dataloader(p, train_dataset)
    val_dataloader = get_val_dataloader(p, val_dataset)
    print('Dataset contains {}/{} train/val samples'.format(len(train_dataset), len(val_dataset)))

    # Memory Bank
    print(colored('Build MemoryBank', 'blue'))
    memory_bank_train = MemoryBank(len(train_dataset), 2048, p['num_classes'], p['temperature'])
    memory_bank_train.cuda()
    memory_bank_val = MemoryBank(len(val_dataset), 2048, p['num_classes'], p['temperature'])
    memory_bank_val.cuda()

    # Load the official MoCoV2 checkpoint
    print(colored('Downloading moco v2 checkpoint', 'blue'))
    os.system('wget -L https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar')
    moco_state = torch.load('moco_v2_800ep_pretrain.pth.tar', map_location='cpu')

    # Transfer moco weights
    print(colored('Transfer MoCo weights to model', 'blue'))
    new_state_dict = {}
    state_dict = moco_state['state_dict']
    for k in list(state_dict.keys()):
        # Copy backbone weights
        if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
            new_k = 'module.backbone.' + k[len('module.encoder_q.'):]
            new_state_dict[new_k] = state_dict[k]

        # Copy mlp weights
        elif k.startswith('module.encoder_q.fc'):
            new_k = 'module.contrastive_head.' + k[len('module.encoder_q.fc.'):]
            new_state_dict[new_k] = state_dict[k]

        else:
            raise ValueError('Unexpected key {}'.format(k))

    model.load_state_dict(new_state_dict)
    os.system('rm -rf moco_v2_800ep_pretrain.pth.tar')

    # Save final model
    print(colored('Save pretext model', 'blue'))
    torch.save(model.module.state_dict(), p['pretext_model'])
    model.module.contrastive_head = torch.nn.Identity()  # In this case, we mine the neighbors before the MLP.

    # Mine the topk nearest neighbors (Train)
    # These will be used for training with the SCAN-Loss.
    topk = 50
    print(colored('Mine the nearest neighbors (Train)(Top-%d)' % (topk), 'blue'))
    transforms = get_val_transformations(p)
    train_dataset = get_train_dataset(p, transforms)
    fill_memory_bank(train_dataloader, model, memory_bank_train)
    indices, acc = memory_bank_train.mine_nearest_neighbors(topk)
    print('Accuracy of top-%d nearest neighbors on train set is %.2f' % (topk, 100 * acc))
    np.save(p['topk_neighbors_train_path'], indices)

    # Mine the topk nearest neighbors (Validation)
    # These will be used for validation.
    topk = 5
    print(colored('Mine the nearest neighbors (Val)(Top-%d)' % (topk), 'blue'))
    fill_memory_bank(val_dataloader, model, memory_bank_val)
    indices, acc = memory_bank_val.mine_nearest_neighbors(topk)
    print('Accuracy of top-%d nearest neighbors on val set is %.2f' % (topk, 100 * acc))
    np.save(p['topk_neighbors_val_path'], indices)
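
# The mining steps above delegate to MemoryBank.mine_nearest_neighbors. Below
# is a minimal, torch-only sketch of that operation, assuming the bank holds
# L2-normalized features plus integer targets; the function name and body are
# illustrative, not the repo's actual implementation.
import torch

def mine_nearest_neighbors_sketch(features, targets, topk):
    """features: [N, D] L2-normalized tensor, targets: [N] long tensor."""
    similarity = torch.mm(features, features.t())    # pairwise cosine similarity
    _, indices = similarity.topk(k=topk + 1, dim=1)  # +1: each sample matches itself first
    indices = indices[:, 1:]                         # drop the self-match
    neighbor_targets = targets[indices]              # [N, topk] labels of mined neighbors
    anchor_targets = targets.unsqueeze(1).expand_as(neighbor_targets)
    acc = (neighbor_targets == anchor_targets).float().mean().item()
    return indices, acc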
def main():
    # Retrieve config file
    p = create_config(args.config_env, args.config_exp)
    print(colored(p, 'red'))

    # Model
    print(colored('Retrieve model', 'blue'))
    model = get_model(p)
    print('Model is {}'.format(model.__class__.__name__))
    print('Model parameters: {:.2f}M'.format(sum(p.numel() for p in model.parameters()) / 1e6))
    print(model)
    model = model.cuda()

    # CUDNN
    print(colored('Set CuDNN benchmark', 'blue'))
    torch.backends.cudnn.benchmark = True

    # Dataset
    print(colored('Retrieve dataset', 'blue'))
    train_transforms = get_train_transformations(p)
    print('Train transforms:', train_transforms)
    val_transforms = get_val_transformations(p)
    print('Validation transforms:', val_transforms)
    train_dataset = get_train_dataset(p, train_transforms, to_augmented_dataset=True,
                                      split='train+unlabeled')  # Split is for stl-10
    val_dataset = get_val_dataset(p, val_transforms)
    train_dataloader = get_train_dataloader(p, train_dataset)
    val_dataloader = get_val_dataloader(p, val_dataset)
    print('Dataset contains {}/{} train/val samples'.format(len(train_dataset), len(val_dataset)))

    # Memory Bank
    print(colored('Build MemoryBank', 'blue'))
    base_dataset = get_train_dataset(p, val_transforms, split='train')  # Dataset w/o augs for knn eval
    base_dataloader = get_val_dataloader(p, base_dataset)
    memory_bank_base = MemoryBank(len(base_dataset),
                                  p['model_kwargs']['features_dim'],
                                  p['num_classes'], p['criterion_kwargs']['temperature'])
    memory_bank_base.cuda()
    memory_bank_val = MemoryBank(len(val_dataset),
                                 p['model_kwargs']['features_dim'],
                                 p['num_classes'], p['criterion_kwargs']['temperature'])
    memory_bank_val.cuda()

    # Criterion
    print(colored('Retrieve criterion', 'blue'))
    criterion = get_criterion(p)
    print('Criterion is {}'.format(criterion.__class__.__name__))
    criterion = criterion.cuda()

    # Optimizer and scheduler
    print(colored('Retrieve optimizer', 'blue'))
    optimizer = get_optimizer(p, model)
    print(optimizer)

    # Checkpoint
    if os.path.exists(p['pretext_checkpoint']):
        print(colored('Restart from checkpoint {}'.format(p['pretext_checkpoint']), 'blue'))
        checkpoint = torch.load(p['pretext_checkpoint'], map_location='cpu')
        optimizer.load_state_dict(checkpoint['optimizer'])
        model.load_state_dict(checkpoint['model'])
        model.cuda()
        start_epoch = checkpoint['epoch']
    else:
        print(colored('No checkpoint file at {}'.format(p['pretext_checkpoint']), 'blue'))
        start_epoch = 0
        model = model.cuda()

    # Training
    print(colored('Starting main loop', 'blue'))
    for epoch in range(start_epoch, p['epochs']):
        print(colored('Epoch %d/%d' % (epoch, p['epochs']), 'yellow'))
        print(colored('-' * 15, 'yellow'))

        # Adjust lr
        lr = adjust_learning_rate(p, optimizer, epoch)
        print('Adjusted learning rate to {:.5f}'.format(lr))

        # Train
        print('Train ...')
        simclr_train(train_dataloader, model, criterion, optimizer, epoch)

        # Fill memory bank
        print('Fill memory bank for kNN...')
        fill_memory_bank(base_dataloader, model, memory_bank_base)

        # Evaluate (To monitor progress - Not for validation)
        print('Evaluate ...')
        top1 = contrastive_evaluate(val_dataloader, model, memory_bank_base)
        print('Result of kNN evaluation is %.2f' % (top1))

        # Checkpoint
        print('Checkpoint ...')
        torch.save({'optimizer': optimizer.state_dict(), 'model': model.state_dict(),
                    'epoch': epoch + 1}, p['pretext_checkpoint'])

        if epoch in [50, 75]:
            # Save intermediate model
            # torch.save(model.state_dict(), p['pretext_model'])

            # Mine the topk nearest neighbors at selected epochs (Train)
            # These will serve as input to the SCAN loss.
            print(colored('Fill memory bank for mining the nearest neighbors (train) ...', 'blue'))
            fill_memory_bank(base_dataloader, model, memory_bank_base)
            topk = 20
            print('Mine the nearest neighbors (Top-%d)' % (topk))
            indices, acc = memory_bank_base.mine_nearest_neighbors(topk)
            print('Accuracy of top-%d nearest neighbors on train set is %.2f' % (topk, 100 * acc))
            # np.save(p['topk_neighbors_train_path'], indices)

            # Mine the topk nearest neighbors at selected epochs (Val)
            # These will be used for validation.
            print(colored('Fill memory bank for mining the nearest neighbors (val) ...', 'blue'))
            fill_memory_bank(val_dataloader, model, memory_bank_val)
            topk = 5
            print('Mine the nearest neighbors (Top-%d)' % (topk))
            indices, acc = memory_bank_val.mine_nearest_neighbors(topk)
            print('Accuracy of top-%d nearest neighbors on val set is %.2f' % (topk, 100 * acc))
            # np.save(p['topk_neighbors_val_path'], indices)

    # Save final model
    torch.save(model.state_dict(), p['pretext_model'])

    # Mine the topk nearest neighbors at the very end (Train)
    # These will serve as input to the SCAN loss.
    print(colored('Fill memory bank for mining the nearest neighbors (train) ...', 'blue'))
    fill_memory_bank(base_dataloader, model, memory_bank_base)
    topk = 20
    print('Mine the nearest neighbors (Top-%d)' % (topk))
    indices, acc = memory_bank_base.mine_nearest_neighbors(topk)
    print('Accuracy of top-%d nearest neighbors on train set is %.2f' % (topk, 100 * acc))
    np.save(p['topk_neighbors_train_path'], indices)

    # Mine the topk nearest neighbors at the very end (Val)
    # These will be used for validation.
    print(colored('Fill memory bank for mining the nearest neighbors (val) ...', 'blue'))
    fill_memory_bank(val_dataloader, model, memory_bank_val)
    topk = 5
    print('Mine the nearest neighbors (Top-%d)' % (topk))
    indices, acc = memory_bank_val.mine_nearest_neighbors(topk)
    print('Accuracy of top-%d nearest neighbors on val set is %.2f' % (topk, 100 * acc))
    np.save(p['topk_neighbors_val_path'], indices)
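
# Both mains above call fill_memory_bank before every mining step. A hedged
# sketch of what such a routine does, assuming MemoryBank exposes reset() and
# update(features, targets) methods that write into pre-allocated buffers
# (assumed API, shown for clarity rather than taken verbatim from the repo):
import torch

@torch.no_grad()
def fill_memory_bank_sketch(loader, model, memory_bank):
    model.eval()
    memory_bank.reset()  # assumed: rewind the bank's write pointer
    for batch in loader:
        images = batch['image'].cuda(non_blocking=True)
        targets = batch['target'].cuda(non_blocking=True)
        output = model(images)               # forward pass through backbone (+ head)
        memory_bank.update(output, targets)  # assumed: append this batch's features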
def main():
    # Retrieve config file
    p = create_config(args.config_env, args.config_exp)
    print(colored(p, 'red'))

    # Model
    print(colored('Retrieve model', 'green'))
    model = get_model(p)
    print('Model is {}'.format(model.__class__.__name__))
    print('Model parameters: {:.2f}M'.format(sum(p.numel() for p in model.parameters()) / 1e6))
    print(model)
    model = model.to(device)

    # CUDNN
    print(colored('Set CuDNN benchmark', 'green'))
    torch.backends.cudnn.benchmark = True

    # Dataset
    print(colored('Retrieve dataset', 'green'))
    train_transforms = get_train_transformations(p)
    print('Train transforms:', train_transforms)
    val_transforms = get_val_transformations(p)
    print('Validation transforms:', val_transforms)
    train_dataset = get_train_dataset(p, train_transforms, to_augmented_dataset=True,
                                      split='train')  # Split is for stl-10
    val_dataset = get_val_dataset(p, val_transforms)
    train_dataloader = get_val_dataloader(p, train_dataset)
    val_dataloader = get_val_dataloader(p, val_dataset)
    print('Dataset contains {}/{} train/val samples'.format(len(train_dataset), len(val_dataset)))

    # Memory Bank
    print(colored('Build MemoryBank', 'green'))
    base_dataset = get_train_dataset(p, val_transforms, split='train')  # Dataset w/o augs for knn eval
    base_dataloader = get_val_dataloader(p, base_dataset)
    memory_bank_base = MemoryBank(len(base_dataset),
                                  p['model_kwargs']['features_dim'],
                                  p['num_classes'], p['criterion_kwargs']['temperature'])
    memory_bank_base.to(device)
    memory_bank_val = MemoryBank(len(val_dataset),
                                 p['model_kwargs']['features_dim'],
                                 p['num_classes'], p['criterion_kwargs']['temperature'])
    memory_bank_val.to(device)

    # Checkpoint
    if os.path.exists(p['pretext_checkpoint']):
        print(colored('Restart from checkpoint {}'.format(p['pretext_checkpoint']), 'green'))
        checkpoint = torch.load(p['pretext_checkpoint'], map_location='cpu')
        # optimizer.load_state_dict(checkpoint['optimizer'])
        model.load_state_dict(checkpoint['model'])
        model.to(device)
        # start_epoch = checkpoint['epoch']
    else:
        print(colored('No checkpoint file at {}'.format(p['pretext_checkpoint']), 'green'))
        start_epoch = 0
        model = model.to(device)

    # # Training
    # print(colored('Starting main loop', 'green'))
    # with torch.no_grad():
    #     model.eval()
    #     total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
    #
    #     # progress_bar = tqdm(train_dataloader)
    #     for idx, batch in enumerate(train_dataloader):
    #         images = batch['image'].to(device, non_blocking=True)
    #         # target = batch['target'].to(device, non_blocking=True)
    #         output = model(images)
    #         feature = F.normalize(output, dim=1)
    #         feature_bank.append(feature)
    #
    #         if idx % 25 == 0:
    #             print("Feature bank building : {} / {}".format(idx, len(train_dataset) / p["batch_size"]))
    #
    #     # [D, N]
    #     feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
    #     print(colored("Feature bank created. Similarity index starts now", "green"))
    #     print(feature_bank.size())
    #
    #     for idx, batch in enumerate(train_dataloader):
    #         images = batch['image'].to(device, non_blocking=True)
    #         target = batch['target'].to(device, non_blocking=True)
    #         output = model(images)
    #         feature = F.normalize(output, dim=1)
    #
    #         sim_indices = knn_predict(feature, feature_bank, "", "", 10, 0.1)
    #         print(sim_indices)
    #
    #         if idx == 10:
    #             break

    # Mine the topk nearest neighbors at the very end (Train)
    # These will serve as input to the SCAN loss.
    # print(colored('Fill memory bank for mining the nearest neighbors (train) ...', 'green'))
    # fill_memory_bank(base_dataloader, model, memory_bank_base)
    # topk = 20
    # print('Mine the nearest neighbors (Top-%d)' % (topk))
    # indices, acc = memory_bank_base.mine_nearest_neighbors(topk)
    # print('Accuracy of top-%d nearest neighbors on train set is %.2f' % (topk, 100 * acc))
    # np.save(p['topk_neighbors_train_path'], indices)

    # Mine the topk nearest neighbors at the very end (Val)
    # These will be used for validation.
    print(colored('Fill memory bank for mining the nearest neighbors (val) ...', 'green'))
    fill_memory_bank(val_dataloader, model, memory_bank_val)
    topk = 5
    print('Mine the nearest neighbors (Top-%d)' % (topk))
    indices, acc = memory_bank_val.mine_nearest_neighbors(topk)
    print('Accuracy of top-%d nearest neighbors on val set is %.2f' % (topk, 100 * acc))
    np.save(p['topk_neighbors_val_path'], indices)
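
# The commented-out loop above calls knn_predict(feature, feature_bank, ...).
# Below is a hedged sketch of the weighted-kNN lookup that name usually refers
# to in SimCLR-style reference code; the signature and temperature weighting
# are assumptions, not necessarily this repo's implementation.
import torch

def knn_predict_sketch(feature, feature_bank, feature_labels, num_classes, k, t):
    """feature: [B, D], feature_bank: [D, N], feature_labels: [N]; features L2-normalized."""
    sim_matrix = torch.mm(feature, feature_bank)            # [B, N] cosine similarities
    sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)  # k most similar bank entries
    sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1),
                              dim=-1, index=sim_indices)    # labels of those neighbors
    sim_weight = (sim_weight / t).exp()                     # temperature-scaled vote weights
    one_hot = torch.zeros(feature.size(0) * k, num_classes, device=feature.device)
    one_hot.scatter_(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    pred_scores = torch.sum(one_hot.view(feature.size(0), -1, num_classes)
                            * sim_weight.unsqueeze(dim=-1), dim=1)
    return pred_scores.argsort(dim=-1, descending=True)     # class ranking per sample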
def main():
    # Retrieve config file
    p = create_config(args.config_env, args.config_exp)
    print(colored(p, 'red'))

    # Model
    print(colored('Retrieve model', 'blue'))
    model = get_model(p)
    print('Model is {}'.format(model.__class__.__name__))
    print('Model parameters: {:.2f}M'.format(sum(p.numel() for p in model.parameters()) / 1e6))
    print(model)
    model = model.cuda()

    # CUDNN
    print(colored('Set CuDNN benchmark', 'blue'))
    torch.backends.cudnn.benchmark = True

    # Dataset
    val_transforms = get_val_transformations(p)
    print('Validation transforms:', val_transforms)
    val_dataset = get_val_dataset(p, val_transforms)
    val_dataloader = get_val_dataloader(p, val_dataset)
    print('Dataset contains {} val samples'.format(len(val_dataset)))

    # Memory Bank
    print(colored('Build MemoryBank', 'blue'))
    base_dataset = get_train_dataset(p, val_transforms, split='train')  # Dataset w/o augs for knn eval
    base_dataloader = get_val_dataloader(p, base_dataset)
    memory_bank_base = MemoryBank(len(base_dataset),
                                  p['model_kwargs']['features_dim'],
                                  p['num_classes'], p['criterion_kwargs']['temperature'])
    memory_bank_base.cuda()
    memory_bank_val = MemoryBank(len(val_dataset),
                                 p['model_kwargs']['features_dim'],
                                 p['num_classes'], p['criterion_kwargs']['temperature'])
    memory_bank_val.cuda()

    # Checkpoint
    assert os.path.exists(p['pretext_checkpoint'])
    print(colored('Restart from checkpoint {}'.format(p['pretext_checkpoint']), 'blue'))
    checkpoint = torch.load(p['pretext_checkpoint'], map_location='cpu')
    model.load_state_dict(checkpoint)
    model.cuda()

    # Save model
    torch.save(model.state_dict(), p['pretext_model'])

    # Mine the topk nearest neighbors at the very end (Train)
    # These will serve as input to the SCAN loss.
    print(colored('Fill memory bank for mining the nearest neighbors (train) ...', 'blue'))
    fill_memory_bank(base_dataloader, model, memory_bank_base)
    topk = 20
    print('Mine the nearest neighbors (Top-%d)' % (topk))
    indices, acc = memory_bank_base.mine_nearest_neighbors(topk)
    print('Accuracy of top-%d nearest neighbors on train set is %.2f' % (topk, 100 * acc))
    np.save(p['topk_neighbors_train_path'], indices)

    # Mine the topk nearest neighbors at the very end (Val)
    # These will be used for validation.
    print(colored('Fill memory bank for mining the nearest neighbors (val) ...', 'blue'))
    fill_memory_bank(val_dataloader, model, memory_bank_val)
    topk = 5
    print('Mine the nearest neighbors (Top-%d)' % (topk))
    indices, acc = memory_bank_val.mine_nearest_neighbors(topk)
    print('Accuracy of top-%d nearest neighbors on val set is %.2f' % (topk, 100 * acc))
    np.save(p['topk_neighbors_val_path'], indices)
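
# None of the mains in this file define `args`; in scripts like these it is
# parsed at module level. A minimal sketch consistent with the attribute
# accesses above (args.config_env, args.config_exp; the evaluation script
# additionally reads args.model and args.visualize_prototypes). Help strings
# are assumptions:
import argparse

parser = argparse.ArgumentParser(description='Pretext training / neighbor mining')
parser.add_argument('--config_env', help='Environment config (paths, output dirs)')
parser.add_argument('--config_exp', help='Experiment config (model, dataset, criterion)')
args = parser.parse_args()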
def main():
    # Retrieve config file
    p = create_config(args.config_env, args.config_exp)
    print(colored(p, 'red'))

    # Model
    print(colored('Retrieve model', 'green'))
    model = get_model(p)
    print('Model is {}'.format(model.__class__.__name__))
    print(model)
    # model = torch.nn.DataParallel(model)
    model = model.to(device)

    # CUDNN
    print(colored('Set CuDNN benchmark', 'green'))
    torch.backends.cudnn.benchmark = True

    # Dataset
    print(colored('Retrieve dataset', 'green'))
    transforms = get_val_transformations(p)
    train_dataset = get_train_dataset(p, transforms)
    val_dataset = get_val_dataset(p, transforms)
    train_dataloader = get_val_dataloader(p, train_dataset)
    val_dataloader = get_val_dataloader(p, val_dataset)
    print('Dataset contains {}/{} train/val samples'.format(len(train_dataset), len(val_dataset)))

    # Memory Bank
    print(colored('Build MemoryBank', 'green'))
    memory_bank_train = MemoryBank(len(train_dataset), 2048, p['num_classes'], p['temperature'])
    memory_bank_train.to(device)
    memory_bank_val = MemoryBank(len(val_dataset), 2048, p['num_classes'], p['temperature'])
    memory_bank_val.to(device)

    # Load the official MoCoV2 checkpoint
    print(colored('Downloading moco v2 checkpoint', 'green'))
    # os.system('wget -L https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar')
    # Uploaded the model to Mist : Johan
    moco_state = torch.load(main_dir + model_dir + 'moco_v2_800ep_pretrain.pth.tar',
                            map_location=device)

    # Transfer moco weights
    print(colored('Transfer MoCo weights to model', 'green'))
    new_state_dict = {}
    state_dict = moco_state['state_dict']
    # for k in list(state_dict.keys()):
    #     # Copy backbone weights
    #     if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
    #         new_k = 'module.backbone.' + k[len('module.encoder_q.'):]
    #         new_state_dict[new_k] = state_dict[k]
    #
    #     # Copy mlp weights
    #     elif k.startswith('module.encoder_q.fc'):
    #         new_k = 'module.contrastive_head.' + k[len('module.encoder_q.fc.'):]
    #         new_state_dict[new_k] = state_dict[k]
    #
    #     else:
    #         raise ValueError('Unexpected key {}'.format(k))

    # Changed by Johan
    for k, v in state_dict.items():
        if "conv" in k or "bn" in k or "layer" in k:
            new_k = "backbone." + k.split("module.encoder_q.")[1]
            new_state_dict[new_k] = v
        else:
            new_k = "contrastive_head." + k.split("module.encoder_q.fc.")[1]
            new_state_dict[new_k] = v

    model.load_state_dict(new_state_dict)
    # os.system('rm -rf moco_v2_800ep_pretrain.pth.tar')

    # Save final model
    print(colored('Save pretext model', 'green'))
    torch.save(model.state_dict(), p['pretext_model'])

    # model.contrastive_head = torch.nn.Identity()  # In this case, we mine the neighbors before the MLP.
    model.contrastive_head = Identity()
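
# `Identity` is referenced above but not defined in this file. A minimal
# definition consistent with its use (replacing the contrastive head so the
# forward pass returns raw backbone features); torch.nn.Identity would serve
# the same purpose:
import torch.nn as nn

class Identity(nn.Module):
    """Pass-through module: forward(x) returns x unchanged."""
    def forward(self, x):
        return x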
def main():
    # Read config file
    print(colored('Read config file {} ...'.format(args.config_exp), 'blue'))
    with open(args.config_exp, 'r') as stream:
        config = yaml.safe_load(stream)
    config['batch_size'] = 512  # To make sure we can evaluate on a single 1080ti
    print(config)

    # Get dataset
    print(colored('Get validation dataset ...', 'blue'))
    transforms = get_val_transformations(config)
    dataset = get_val_dataset(config, transforms)
    dataloader = get_val_dataloader(config, dataset)
    print('Number of samples: {}'.format(len(dataset)))

    # Get model
    print(colored('Get model ...', 'blue'))
    model = get_model(config)
    print(model)

    # Read model weights
    print(colored('Load model weights ...', 'blue'))
    state_dict = torch.load(args.model, map_location='cpu')
    if config['setup'] in ['simclr', 'moco', 'selflabel']:
        model.load_state_dict(state_dict)
    elif config['setup'] == 'scan':
        model.load_state_dict(state_dict['model'])
    else:
        raise NotImplementedError

    # CUDA
    model.cuda()

    # Perform evaluation
    if config['setup'] in ['simclr', 'moco']:
        print(colored('Perform evaluation of the pretext task (setup={}).'.format(config['setup']), 'blue'))
        print('Create Memory Bank')
        if config['setup'] == 'simclr':  # Mine neighbors after MLP
            memory_bank = MemoryBank(len(dataset), config['model_kwargs']['features_dim'],
                                     config['num_classes'], config['criterion_kwargs']['temperature'])
        else:  # Mine neighbors before MLP
            memory_bank = MemoryBank(len(dataset), config['model_kwargs']['features_dim'],
                                     config['num_classes'], config['temperature'])
        memory_bank.cuda()

        print('Fill Memory Bank')
        fill_memory_bank(dataloader, model, memory_bank)

        print('Mine the nearest neighbors')
        for topk in [1, 5, 20]:  # Similar to Fig 2 in paper
            _, acc = memory_bank.mine_nearest_neighbors(topk)
            print('Accuracy of top-{} nearest neighbors on validation set is {:.2f}'.format(topk, 100 * acc))

    elif config['setup'] in ['scan', 'selflabel']:
        print(colored('Perform evaluation of the clustering model (setup={}).'.format(config['setup']), 'blue'))
        head = state_dict['head'] if config['setup'] == 'scan' else 0
        predictions, features = get_predictions(config, dataloader, model, return_features=True)
        clustering_stats = hungarian_evaluate(head, predictions, dataset.classes,
                                              compute_confusion_matrix=True)
        print(clustering_stats)
        if args.visualize_prototypes:
            prototype_indices = get_prototypes(config, predictions[head], features, model)
            visualize_indices(prototype_indices, dataset, clustering_stats['hungarian_match'])
    else:
        raise NotImplementedError
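
# hungarian_evaluate above matches predicted clusters to ground-truth classes
# before computing accuracy. A hedged sketch of that core step via the
# Hungarian algorithm (function name and return value are illustrative; the
# repo's version also reports NMI/ARI and a confusion matrix), assuming the
# number of clusters equals the number of classes:
import numpy as np
from scipy.optimize import linear_sum_assignment

def cluster_accuracy_sketch(predictions, targets, num_classes):
    """predictions, targets: 1-D integer arrays of equal length."""
    # Contingency matrix: rows = predicted cluster, cols = true class
    confusion = np.zeros((num_classes, num_classes), dtype=np.int64)
    for pred, target in zip(predictions, targets):
        confusion[pred, target] += 1
    # Maximize matched counts by minimizing the negated matrix
    row_ind, col_ind = linear_sum_assignment(-confusion)
    return confusion[row_ind, col_ind].sum() / len(predictions)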