def validate_few_shot(self):
    """Run few-shot meta-evaluation on the validation and test splits.

    Each split is evaluated twice with ``meta_test``: once embedding the
    episodes through the logit layer (the ``meta_test`` default) and once
    through the penultimate features (``use_logit=False``).  Accuracy,
    standard deviation and wall-clock time are printed for every run.

    Uses ``self.model``, ``self.meta_valloader`` and ``self.meta_testloader``;
    returns nothing.
    """
    # Validation split, logit-layer embedding.
    start = time.time()
    val_acc, val_std = meta_test(self.model, self.meta_valloader)
    val_time = time.time() - start
    print('val_acc: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(
        val_acc, val_std, val_time))

    # Validation split, penultimate-feature embedding.
    start = time.time()
    val_acc_feat, val_std_feat = meta_test(self.model,
                                           self.meta_valloader,
                                           use_logit=False)
    val_time = time.time() - start
    print('val_acc_feat: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(
        val_acc_feat, val_std_feat, val_time))

    # Test split, logit-layer embedding.
    start = time.time()
    test_acc, test_std = meta_test(self.model, self.meta_testloader)
    test_time = time.time() - start
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc, test_std, test_time))

    # Test split, penultimate-feature embedding.
    start = time.time()
    test_acc_feat, test_std_feat = meta_test(self.model,
                                             self.meta_testloader,
                                             use_logit=False)
    test_time = time.time() - start
    print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, test_time))
def main():
    """Load a pretrained checkpoint and meta-test it over 600 episodes."""
    opt = parse_option()
    opt.n_test_runs = 600
    (train_loader, val_loader, meta_testloader,
     meta_valloader, n_cls) = get_dataloaders(opt)

    # Restore the backbone weights from the checkpoint on disk.
    model = create_model(opt.model, n_cls, opt.dataset)
    checkpoint = torch.load(opt.model_path)
    model.load_state_dict(checkpoint["model"])

    if torch.cuda.is_available():
        model = model.cuda()
        cudnn.benchmark = True

    # Few-shot evaluation using the logit layer as the embedding.
    tic = time.time()
    test_acc, test_std = meta_test(model, meta_testloader)
    elapsed = time.time() - tic
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc, test_std, elapsed))

    # Few-shot evaluation using the penultimate features instead.
    tic = time.time()
    test_acc_feat, test_std_feat = meta_test(model, meta_testloader,
                                             use_logit=False)
    elapsed = time.time() - tic
    print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, elapsed))
def evaluate(meta_valloader, model, args, mode):
    """Meta-test ``model`` on one episode loader and print the result.

    ``mode`` is only used as a label in the printed report; accuracy and
    its spread are reported as percentages (mean +/- std) with wall time.
    """
    tic = time.time()
    val_acc1, val_std1 = meta_test(model,
                                   meta_valloader,
                                   only_base=args.only_base,
                                   classifier=args.cls,
                                   is_norm=True)
    val_time = time.time() - tic

    print(f'Mode: ' + mode)
    print(f'Partition: {args.partition} Accuracy: {round(val_acc1 * 100, 2)}'
          + u" \u00B1 "
          + f'{round(val_std1 * 100, 2)}, Time: {val_time}')
def main():
    """Meta-test a checkpoint that may have been saved from nn.DataParallel."""
    opt = parse_option()
    opt.n_test_runs = 600
    (train_loader, val_loader, meta_testloader,
     meta_valloader, n_cls, _) = get_dataloaders(opt)

    # load model
    model = create_model(opt.model, n_cls, opt.dataset)
    ckpt = torch.load(opt.model_path)["model"]

    # Checkpoints written from nn.DataParallel prefix every key with
    # "module.".  Strip only that leading prefix -- the previous
    # str.replace("module.", "") would also remove a "module." that
    # happens to occur in the middle of a key.
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in ckpt.items():
        name = k[len("module."):] if k.startswith("module.") else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)

    if torch.cuda.is_available():
        model = model.cuda()
        cudnn.benchmark = True

    # Few-shot evaluation: logit-layer embedding, then penultimate features.
    start = time.time()
    test_acc, test_std = meta_test(model, meta_testloader)
    test_time = time.time() - start
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc, test_std, test_time))

    start = time.time()
    test_acc_feat, test_std_feat = meta_test(model, meta_testloader,
                                             use_logit=False)
    test_time = time.time() - start
    print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, test_time))
def main():
    """Seed all RNGs, build miniImageNet meta loaders and run one evaluation."""
    # Fix every relevant RNG for reproducible episode sampling.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    opt = parse_option()
    train_trans, test_trans = transforms_test_options[opt.transform]

    meta_testloader = DataLoader(
        MetaImageNet(args=opt, partition='test',
                     train_transform=train_trans,
                     test_transform=test_trans,
                     fix_seed=False),
        batch_size=opt.test_batch_size,
        shuffle=False, drop_last=False,
        num_workers=opt.num_workers)
    meta_valloader = DataLoader(
        MetaImageNet(args=opt, partition='val',
                     train_transform=train_trans,
                     test_transform=test_trans,
                     fix_seed=False),
        batch_size=opt.test_batch_size,
        shuffle=False, drop_last=False,
        num_workers=opt.num_workers)

    n_cls = 64  # miniImageNet base-class count

    model = create_model(opt.model, n_cls, opt.dataset)
    ckpt = torch.load(opt.model_path)
    model.load_state_dict(ckpt['model'])

    if torch.cuda.is_available():
        model = model.cuda()
        cudnn.benchmark = True

    # Single evaluation pass with feature embeddings and the
    # 'original_avrithis' classifier head.
    tic = time.time()
    test_acc, test_std = meta_test(model, meta_testloader,
                                   use_logit=False,
                                   classifier='original_avrithis')
    elapsed = time.time() - tic
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc, test_std, elapsed))
def main():
    """Build meta val/test loaders for the chosen dataset and evaluate a model."""
    opt = parse_option()
    args = opt

    # test loader
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(
            MetaImageNet(args=opt, partition='test',
                         train_transform=train_trans,
                         test_transform=test_trans,
                         fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaImageNet(args=opt, partition='val',
                         train_transform=train_trans,
                         test_transform=test_trans,
                         fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        n_cls = 80 if opt.use_trainval else 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='test',
                               train_transform=train_trans,
                               test_transform=test_trans,
                               fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='val',
                               train_transform=train_trans,
                               test_transform=test_trans,
                               fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        n_cls = 448 if opt.use_trainval else 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_test_options['D']
        meta_testloader = DataLoader(
            MetaCIFAR100(args=opt, partition='test',
                         train_transform=train_trans,
                         test_transform=test_trans,
                         fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaCIFAR100(args=opt, partition='val',
                         train_transform=train_trans,
                         test_transform=test_trans,
                         fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        elif opt.dataset == 'CIFAR-FS':
            n_cls = 64
        elif opt.dataset == 'FC100':
            n_cls = 60
        else:
            raise NotImplementedError('dataset not supported: {}'.format(
                opt.dataset))
    else:
        raise NotImplementedError(opt.dataset)

    # load model
    model = create_model(opt.model, n_cls, opt.dataset)
    ckpt = torch.load(opt.model_path)
    model.load_state_dict(ckpt['model'])

    if torch.cuda.is_available():
        model = model.cuda()
        cudnn.benchmark = True

    # evaluation: each split twice (logit embedding, then features)
    start = time.time()
    val_acc, val_std = meta_test(model, meta_valloader)
    val_time = time.time() - start
    print('val_acc: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(
        val_acc, val_std, val_time))

    start = time.time()
    val_acc_feat, val_std_feat = meta_test(model, meta_valloader,
                                           use_logit=False)
    val_time = time.time() - start
    print('val_acc_feat: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(
        val_acc_feat, val_std_feat, val_time))

    start = time.time()
    test_acc, test_std = meta_test(model, meta_testloader)
    test_time = time.time() - start
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc, test_std, test_time))

    start = time.time()
    test_acc_feat, test_std_feat = meta_test(model, meta_testloader,
                                             use_logit=False)
    test_time = time.time() - start
    print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, test_time))
def main():
    """Supervised pre-training routine followed by a final few-shot evaluation."""
    opt = parse_option()

    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            ImageNet(args=opt, partition=train_partition,
                     transform=train_trans),
            batch_size=opt.batch_size, shuffle=True,
            drop_last=True, num_workers=opt.num_workers)
        val_loader = DataLoader(
            ImageNet(args=opt, partition='val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False,
            drop_last=False, num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaImageNet(args=opt, partition='test',
                         train_transform=train_trans,
                         test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaImageNet(args=opt, partition='val',
                         train_transform=train_trans,
                         test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        n_cls = 80 if opt.use_trainval else 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            TieredImageNet(args=opt, partition=train_partition,
                           transform=train_trans),
            batch_size=opt.batch_size, shuffle=True,
            drop_last=True, num_workers=opt.num_workers)
        # NOTE: validation uses the held-out part of the training classes.
        val_loader = DataLoader(
            TieredImageNet(args=opt, partition='train_phase_val',
                           transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False,
            drop_last=False, num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='test',
                               train_transform=train_trans,
                               test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='val',
                               train_transform=train_trans,
                               test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        n_cls = 448 if opt.use_trainval else 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']
        train_loader = DataLoader(
            CIFAR100(args=opt, partition=train_partition,
                     transform=train_trans),
            batch_size=opt.batch_size, shuffle=True,
            drop_last=True, num_workers=opt.num_workers)
        val_loader = DataLoader(
            CIFAR100(args=opt, partition='train', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False,
            drop_last=False, num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaCIFAR100(args=opt, partition='test',
                         train_transform=train_trans,
                         test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaCIFAR100(args=opt, partition='val',
                         train_transform=train_trans,
                         test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        elif opt.dataset == 'CIFAR-FS':
            n_cls = 64
        elif opt.dataset == 'FC100':
            n_cls = 60
        else:
            raise NotImplementedError('dataset not supported: {}'.format(
                opt.dataset))
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model = create_model(opt.model, n_cls, opt.dataset)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=opt.learning_rate,
                                     weight_decay=0.0005)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=opt.learning_rate,
                              momentum=opt.momentum,
                              weight_decay=opt.weight_decay)
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)

    # routine: supervised pre-training
    for epoch in range(1, opt.epochs + 1):
        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        t_start = time.time()
        train_acc, train_loss = train(epoch, train_loader, model,
                                      criterion, optimizer, opt)
        t_end = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, t_end - t_start))

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_loss', train_loss, epoch)

        test_acc, test_acc_top5, test_loss = validate(val_loader, model,
                                                      criterion, opt)
        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_acc_top5', test_acc_top5, epoch)
        logger.log_value('test_loss', test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model.state_dict()
                if opt.n_gpu <= 1 else model.module.state_dict(),
            }
            save_file = os.path.join(
                opt.save_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

    # few-shot evaluation after training: each split with both embeddings
    start = time.time()
    val_acc, val_std = meta_test(model, meta_valloader, use_logit=True)
    val_time = time.time() - start
    print('val_acc: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(
        val_acc, val_std, val_time))

    start = time.time()
    val_acc_feat, val_std_feat = meta_test(model, meta_valloader,
                                           use_logit=False)
    val_time = time.time() - start
    print('val_acc_feat: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(
        val_acc_feat, val_std_feat, val_time))

    start = time.time()
    test_acc, test_std = meta_test(model, meta_testloader, use_logit=True)
    test_time = time.time() - start
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc, test_std, test_time))

    start = time.time()
    test_acc_feat, test_std_feat = meta_test(model, meta_testloader,
                                             use_logit=False)
    test_time = time.time() - start
    print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, test_time))

    logger.log_value('meta_val_acc', val_acc, opt.epochs)
    logger.log_value('meta_val_acc_feat', val_acc_feat, opt.epochs)
    logger.log_value('meta_test_acc', test_acc, opt.epochs)
    logger.log_value('meta_test_acc_feat', test_acc_feat, opt.epochs)

    # save the last model
    state = {
        'opt': opt,
        'model': model.state_dict()
        if opt.n_gpu <= 1 else model.module.state_dict(),
    }
    save_file = os.path.join(opt.save_folder,
                             '{}_last.pth'.format(opt.model))
    torch.save(state, save_file)
def main():
    """Supervised pre-training with optional LSL language supervision.

    Logs to wandb, tracks the best validation accuracy, saves the final
    checkpoint into the wandb run directory, and finishes with a few-shot
    evaluation on the test episodes (top-1 and top-5, with both logit and
    feature embeddings).

    Bug fixed: ``best_val_acc`` was never updated inside the improvement
    check, so 'best_val_acc' in the wandb summary tracked the latest epoch
    above zero instead of the actual best; it is now kept as a running best.
    """
    opt = parse_option()

    if opt.name is not None:
        wandb.init(name=opt.name)
    else:
        wandb.init()
    wandb.config.update(opt)

    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            ImageNet(args=opt, partition=train_partition,
                     transform=train_trans),
            batch_size=opt.batch_size, shuffle=True,
            drop_last=True, num_workers=opt.num_workers)
        val_loader = DataLoader(
            ImageNet(args=opt, partition='val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False,
            drop_last=False, num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaImageNet(args=opt, partition='test',
                         train_transform=train_trans,
                         test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaImageNet(args=opt, partition='val',
                         train_transform=train_trans,
                         test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            TieredImageNet(args=opt, partition=train_partition,
                           transform=train_trans),
            batch_size=opt.batch_size, shuffle=True,
            drop_last=True, num_workers=opt.num_workers)
        val_loader = DataLoader(
            TieredImageNet(args=opt, partition='train_phase_val',
                           transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False,
            drop_last=False, num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='test',
                               train_transform=train_trans,
                               test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='val',
                               train_transform=train_trans,
                               test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']
        train_loader = DataLoader(
            CIFAR100(args=opt, partition=train_partition,
                     transform=train_trans),
            batch_size=opt.batch_size, shuffle=True,
            drop_last=True, num_workers=opt.num_workers)
        val_loader = DataLoader(
            CIFAR100(args=opt, partition='train', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False,
            drop_last=False, num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaCIFAR100(args=opt, partition='test',
                         train_transform=train_trans,
                         test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaCIFAR100(args=opt, partition='val',
                         train_transform=train_trans,
                         test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError(
                    'dataset not supported: {}'.format(opt.dataset))
    elif opt.dataset == 'CUB_200_2011':
        train_trans, test_trans = transforms_options['C']
        # vocab/devocab only exist for CUB + LSL; other branches leave them
        # undefined, so LSL is effectively CUB-only here.
        vocab = lang_utils.load_vocab(opt.lang_dir) if opt.lsl else None
        devocab = {v: k for k, v in vocab.items()} if opt.lsl else None
        train_loader = DataLoader(
            CUB2011(args=opt, partition=train_partition,
                    transform=train_trans, vocab=vocab),
            batch_size=opt.batch_size, shuffle=True,
            drop_last=True, num_workers=opt.num_workers)
        val_loader = DataLoader(
            CUB2011(args=opt, partition='val', transform=test_trans,
                    vocab=vocab),
            batch_size=opt.batch_size // 2, shuffle=False,
            drop_last=False, num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaCUB2011(args=opt, partition='test',
                        train_transform=train_trans,
                        test_transform=test_trans, vocab=vocab),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaCUB2011(args=opt, partition='val',
                        train_transform=train_trans,
                        test_transform=test_trans, vocab=vocab),
            batch_size=opt.test_batch_size, shuffle=False,
            drop_last=False, num_workers=opt.num_workers)
        if opt.use_trainval:
            raise NotImplementedError(opt.dataset)  # no trainval supported yet
            n_cls = 150
        else:
            n_cls = 100
    else:
        raise NotImplementedError(opt.dataset)

    print('Amount training data: {}'.format(len(train_loader.dataset)))
    print('Amount val data: {}'.format(len(val_loader.dataset)))

    # model
    model = create_model(opt.model, n_cls, opt.dataset)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=opt.learning_rate,
                                     weight_decay=0.0005)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=opt.learning_rate,
                              momentum=opt.momentum,
                              weight_decay=opt.weight_decay)
    criterion = nn.CrossEntropyLoss()

    # lsl: language model that proposes descriptions from image embeddings
    lang_model = None
    if opt.lsl:
        if opt.glove_init:
            vecs = lang_utils.glove_init(vocab, emb_size=opt.lang_emb_size)
        embedding_model = nn.Embedding(
            len(vocab), opt.lang_emb_size,
            _weight=vecs if opt.glove_init else None
        )
        if opt.freeze_emb:
            embedding_model.weight.requires_grad = False
        lang_input_size = n_cls if opt.use_logit else 640  # 640 for resnet12
        lang_model = TextProposal(
            embedding_model,
            input_size=lang_input_size,
            hidden_size=opt.lang_hidden_size,
            project_input=lang_input_size != opt.lang_hidden_size,
            rnn=opt.rnn_type,
            num_layers=opt.rnn_num_layers,
            dropout=opt.rnn_dropout,
            vocab=vocab,
            **lang_utils.get_special_indices(vocab)
        )

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True
        if opt.lsl:
            embedding_model = embedding_model.cuda()
            lang_model = lang_model.cuda()

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)

    # routine: supervised pre-training
    best_val_acc = 0
    for epoch in range(1, opt.epochs + 1):
        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss, train_lang_loss = train(
            epoch, train_loader, model, criterion, optimizer, opt,
            lang_model, devocab=devocab if opt.lsl else None
        )
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        print("==> validating...")
        test_acc, test_acc_top5, test_loss, test_lang_loss = validate(
            val_loader, model, criterion, opt, lang_model
        )

        # wandb
        log_metrics = {
            'train_acc': train_acc,
            'train_loss': train_loss,
            'val_acc': test_acc,
            'val_acc_top5': test_acc_top5,
            'val_loss': test_loss
        }
        if opt.lsl:
            log_metrics['train_lang_loss'] = train_lang_loss
            log_metrics['val_lang_loss'] = test_lang_loss
        wandb.log(log_metrics, step=epoch)

        if test_acc > best_val_acc:
            # FIX: record the running best so the summary reflects the best
            # epoch, not merely the last epoch with non-zero accuracy.
            best_val_acc = test_acc
            wandb.run.summary['best_val_acc'] = test_acc
            wandb.run.summary['best_val_acc_epoch'] = epoch

    # save the last model
    state = {
        'opt': opt,
        'model': model.state_dict()
        if opt.n_gpu <= 1 else model.module.state_dict(),
    }
    save_file = os.path.join(wandb.run.dir, '{}_last.pth'.format(opt.model))
    torch.save(state, save_file)

    # evaluate on test set
    print("==> testing...")
    start = time.time()
    (test_acc, test_std), (test_acc5, test_std5) = meta_test(
        model, meta_testloader)
    test_time = time.time() - start
    print('Using logit layer for embedding')
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc, test_std, test_time))
    print('test_acc top 5: {:.4f}, test_std top 5: {:.4f}, time: {:.1f}'.format(
        test_acc5, test_std5, test_time))

    start = time.time()
    (test_acc_feat, test_std_feat), (test_acc5_feat, test_std5_feat) = \
        meta_test(model, meta_testloader, use_logit=False)
    test_time = time.time() - start
    print('Using layer before logits for embedding')
    print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, test_time))
    print('test_acc_feat top 5: {:.4f}, test_std top 5: {:.4f}, time: {:.1f}'.format(
        test_acc5_feat, test_std5_feat, test_time))

    wandb.run.summary['test_acc'] = test_acc
    wandb.run.summary['test_std'] = test_std
    wandb.run.summary['test_acc5'] = test_acc5
    wandb.run.summary['test_std5'] = test_std5
    wandb.run.summary['test_acc_feat'] = test_acc_feat
    wandb.run.summary['test_std_feat'] = test_std_feat
    wandb.run.summary['test_acc5_feat'] = test_acc5_feat
    wandb.run.summary['test_std5_feat'] = test_std5_feat
def main():
    """MAML-style meta-training with periodic validation and checkpointing.

    Trains with gradient accumulation over ``opt.apply_every`` sub-batches,
    optionally learns the inner-loop learning rate, validates with both a
    fine-tuning run (``test_run``) and logistic-regression ``meta_test``,
    and logs to TensorBoard and Comet.

    Bug fixed: the train-metrics Comet call logged ``ioa=tfoa`` (outer
    accuracy under the inner-accuracy key) while ``tioa`` was computed but
    unused; it now logs ``ioa=tioa``.
    """
    opt = parse_option()
    print(pp.pformat(vars(opt)))

    train_partition = "trainval" if opt.use_trainval else "train"
    # NOTE(review): only miniImageNet is handled and there is no else
    # branch visible; other datasets would leave the loaders undefined.
    if opt.dataset == "miniImageNet":
        train_trans, test_trans = transforms_options[opt.transform]
        if opt.augment == "none":
            train_train_trans = train_test_trans = test_trans
        elif opt.augment == "all":
            train_train_trans = train_test_trans = train_trans
        elif opt.augment == "spt":
            train_train_trans = train_trans
            train_test_trans = test_trans
        elif opt.augment == "qry":
            train_train_trans = test_trans
            train_test_trans = train_trans
        print("spt trans")
        print(train_train_trans)
        print("qry trans")
        print(train_test_trans)

        # The optimizer batch is split into apply_every equal sub-batches.
        sub_batch_size, rmd = divmod(opt.batch_size, opt.apply_every)
        assert rmd == 0
        print("Train sub batch-size:", sub_batch_size)

        meta_train_dataset = MetaImageNet(
            args=opt,
            partition="train",
            train_transform=train_train_trans,
            test_transform=train_test_trans,
            fname="miniImageNet_category_split_train_phase_%s.pickle",
            fix_seed=False,
            n_test_runs=10000000,  # big number to never stop
            new_labels=False,
        )
        meta_trainloader = DataLoader(
            meta_train_dataset,
            batch_size=sub_batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=opt.num_workers,
            pin_memory=True,
        )
        meta_train_dataset_qry = MetaImageNet(
            args=opt,
            partition="train",
            train_transform=train_train_trans,
            test_transform=train_test_trans,
            fname="miniImageNet_category_split_train_phase_%s.pickle",
            fix_seed=False,
            n_test_runs=10000000,  # big number to never stop
            new_labels=False,
            n_ways=opt.n_qry_way,
            n_shots=opt.n_qry_shot,
            n_queries=0,
        )
        meta_trainloader_qry = DataLoader(
            meta_train_dataset_qry,
            batch_size=sub_batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=opt.num_workers,
            pin_memory=True,
        )
        meta_val_dataset = MetaImageNet(
            args=opt,
            partition="val",
            train_transform=test_trans,
            test_transform=test_trans,
            fix_seed=False,
            n_test_runs=200,
            n_ways=5,
            n_shots=5,
            n_queries=15,
        )
        meta_valloader = DataLoader(
            meta_val_dataset,
            batch_size=opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
            pin_memory=True,
        )
        val_loader = DataLoader(
            ImageNet(args=opt, partition="val", transform=test_trans),
            batch_size=opt.sup_val_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
            pin_memory=True,
        )
        n_cls = len(meta_train_dataset.classes)
        print(n_cls)

    model = create_model(
        opt.model,
        n_cls,
        opt.dataset,
        opt.drop_rate,
        opt.dropblock,
        opt.track_stats,
        opt.initializer,
        opt.weight_norm,
        activation=opt.activation,
        normalization=opt.normalization,
    )
    print(model)

    criterion = nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name())
        device = torch.device("cuda")
        model = model.to(device)
        criterion = criterion.to(device)
        cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    print("Learning rate")
    print(opt.learning_rate)
    print("Inner Learning rate")
    print(opt.inner_lr)
    if opt.learn_lr:
        print("Optimizing learning rate")
    # inner_lr is created unconditionally (later code clamps/reads it);
    # requires_grad controls whether it is actually optimized.
    inner_lr = nn.Parameter(torch.tensor(opt.inner_lr),
                            requires_grad=opt.learn_lr)

    optimizer = torch.optim.Adam(
        list(model.parameters()) + [inner_lr]
        if opt.learn_lr else model.parameters(),
        lr=opt.learning_rate,
    )
    inner_opt = torch.optim.SGD(
        model.classifier.parameters(),
        lr=opt.inner_lr,
    )

    logger = SummaryWriter(logdir=opt.tb_folder, flush_secs=10,
                           comment=opt.model_name)
    comet_logger = Experiment(
        api_key=os.environ["COMET_API_KEY"],
        project_name=opt.comet_project_name,
        workspace=opt.comet_workspace,
        disabled=not opt.logcomet,
        auto_metric_logging=False,
    )
    comet_logger.set_name(opt.model_name)
    comet_logger.log_parameters(vars(opt))
    comet_logger.set_model_graph(str(model))

    if opt.cosine:
        eta_min = opt.learning_rate * opt.cosine_factor
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.num_steps, eta_min, -1)

    # routine: meta-training
    data_sampler = iter(meta_trainloader)
    data_sampler_qry = iter(meta_trainloader_qry)
    pbar = tqdm(
        range(1, opt.num_steps + 1),
        miniters=opt.print_freq,
        mininterval=3,
        maxinterval=30,
        ncols=0,
    )
    best_val_acc = 0.0
    for step in pbar:
        if not opt.cosine:
            adjust_learning_rate(step, opt, optimizer)

        # Accumulators for the averaged step metrics.
        foa = 0.0
        fol = 0.0
        ioa = 0.0
        iil = 0.0
        fil = 0.0
        iia = 0.0
        fia = 0.0
        for j in range(opt.apply_every):
            x_spt, y_spt, x_qry, y_qry = [
                t.to(device) for t in next(data_sampler)
            ]
            x_qry2, y_qry2, _, _ = [
                t.to(device) for t in next(data_sampler_qry)
            ]
            y_spt = y_spt.flatten(1)
            y_qry2 = y_qry2.flatten(1)
            # Extra query samples are concatenated onto the episode query set.
            x_qry = torch.cat((x_spt, x_qry, x_qry2), 1)
            y_qry = torch.cat((y_spt, y_qry, y_qry2), 1)
            if step == 1 and j == 0:
                print(x_spt.size(), y_spt.size(), x_qry.size(), y_qry.size())

            info = train_step(
                model,
                model.classifier,
                None,  # inner_opt,
                inner_lr,
                x_spt,
                y_spt,
                x_qry,
                y_qry,
                reset_head=opt.reset_head,
                num_steps=opt.num_inner_steps,
            )
            _foa = info["foa"] / opt.batch_size
            _fol = info["fol"] / opt.batch_size
            _ioa = info["ioa"] / opt.batch_size
            _iil = info["iil"] / opt.batch_size
            _fil = info["fil"] / opt.batch_size
            _iia = info["iia"] / opt.batch_size
            _fia = info["fia"] / opt.batch_size
            # Accumulate gradients across sub-batches; apply once below.
            _fol.backward()
            foa += _foa.detach()
            fol += _fol.detach()
            ioa += _ioa.detach()
            iil += _iil.detach()
            fil += _fil.detach()
            iia += _iia.detach()
            fia += _fia.detach()

        optimizer.step()
        optimizer.zero_grad()
        inner_lr.data.clamp_(min=0.001)
        if opt.cosine:
            scheduler.step()

        if (step == 1) or (step % opt.eval_freq == 0):
            val_info = test_run(
                iter(meta_valloader),
                model,
                model.classifier,
                torch.optim.SGD(model.classifier.parameters(),
                                lr=inner_lr.item()),
                num_inner_steps=opt.num_inner_steps_test,
                device=device,
            )
            val_acc_feat, val_std_feat = meta_test(
                model,
                meta_valloader,
                use_logit=False,
            )
            val_acc = val_info["outer"]["acc"].cpu()
            val_loss = val_info["outer"]["loss"].cpu()
            sup_acc, sup_acc_top5, sup_loss = validate(
                val_loader,
                model,
                criterion,
                print_freq=100000000,
            )
            sup_acc = sup_acc.item()
            sup_acc_top5 = sup_acc_top5.item()
            print(f"\nValidation step {step}")
            print(f"MAML 5-way-5-shot accuracy: {val_acc.item()}")
            print(f"LR 5-way-5-shot accuracy: {val_acc_feat}+-{val_std_feat}")
            print(
                f"Supervised accuracy: Acc@1: {sup_acc} Acc@5: {sup_acc_top5} Loss: {sup_loss}"
            )
            if val_acc_feat > best_val_acc:
                best_val_acc = val_acc_feat
                print(
                    f"New best validation accuracy {best_val_acc.item()} saving checkpoints\n"
                )
                torch.save(
                    {
                        "opt": opt,
                        "model": model.state_dict()
                        if opt.n_gpu <= 1 else model.module.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "step": step,
                        "val_acc": val_acc,
                        "val_loss": val_loss,
                        "val_acc_lr": val_acc_feat,
                        "sup_acc": sup_acc,
                        "sup_acc_top5": sup_acc_top5,
                        "sup_loss": sup_loss,
                    },
                    os.path.join(opt.save_folder,
                                 "{}_best.pth".format(opt.model)),
                )
            comet_logger.log_metrics(
                dict(
                    fol=val_loss,
                    foa=val_acc,
                    acc_lr=val_acc_feat,
                    sup_acc=sup_acc,
                    sup_acc_top5=sup_acc_top5,
                    sup_loss=sup_loss,
                ),
                step=step,
                prefix="val",
            )
            logger.add_scalar("val_acc", val_acc, step)
            logger.add_scalar("val_loss", val_loss, step)
            logger.add_scalar("val_acc_lr", val_acc_feat, step)
            logger.add_scalar("sup_acc", sup_acc, step)
            logger.add_scalar("sup_acc_top5", sup_acc_top5, step)
            logger.add_scalar("sup_loss", sup_loss, step)

        if (step == 1) or (step % opt.eval_freq == 0) \
                or (step % opt.print_freq == 0):
            tfol = fol.cpu()
            tfoa = foa.cpu()
            tioa = ioa.cpu()
            tiil = iil.cpu()
            tfil = fil.cpu()
            tiia = iia.cpu()
            tfia = fia.cpu()
            comet_logger.log_metrics(
                dict(
                    fol=tfol,
                    foa=tfoa,
                    ioa=tioa,  # FIX: was ioa=tfoa (outer acc under inner key)
                    iil=tiil,
                    fil=tfil,
                    iia=tiia,
                    fia=tfia,
                ),
                step=step,
                prefix="train",
            )
            logger.add_scalar("train_acc", tfoa.item(), step)
            logger.add_scalar("train_loss", tfol.item(), step)
            logger.add_scalar("train_ioa", tioa, step)
            logger.add_scalar("train_iil", tiil, step)
            logger.add_scalar("train_fil", tfil, step)
            logger.add_scalar("train_iia", tiia, step)
            logger.add_scalar("train_fia", tfia, step)
            pbar.set_postfix(
                fol=f"{tfol.item():.2f}",
                foa=f"{tfoa.item():.2f}",
                ioa=f"{tioa.item():.2f}",
                iia=f"{tiia.item():.2f}",
                fia=f"{tfia.item():.2f}",
                vl=f"{val_loss.item():.2f}",
                va=f"{val_acc.item():.2f}",
                valr=f"{val_acc_feat:.2f}",
                lr=f"{inner_lr.item():.4f}",
                vsa=f"{sup_acc:.2f}",
                refresh=True,
            )

    # save the last model
    state = {
        "opt": opt,
        "model": model.state_dict()
        if opt.n_gpu <= 1 else model.module.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": step,
    }
    save_file = os.path.join(opt.save_folder,
                             "{}_last.pth".format(opt.model))
    torch.save(state, save_file)
def main():
    """Distill a teacher into a student network, then meta-evaluate the student.

    Pipeline: build loaders for the chosen benchmark -> load frozen teacher and
    fresh student -> train student with classification + KD (+ optional extra
    distillation) losses -> log to tensorboard -> run few-shot meta-evaluation
    on the val/test episode loaders -> save the final student checkpoint.
    """
    best_acc = 0  # NOTE(review): never read or updated below — appears vestigial
    opt = parse_option()

    # tensorboard logger
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        # 'contrast' distillation needs per-sample negative indices (is_sample)
        if opt.distill in ['contrast']:
            train_set = ImageNet(args=opt,
                                 partition=train_partition,
                                 transform=train_trans,
                                 is_sample=True,
                                 k=opt.nce_k)
        else:
            train_set = ImageNet(args=opt,
                                 partition=train_partition,
                                 transform=train_trans)
        n_data = len(train_set)
        train_loader = DataLoader(train_set,
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(ImageNet(args=opt,
                                         partition='val',
                                         transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaImageNet(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaImageNet(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        if opt.distill in ['contrast']:
            train_set = TieredImageNet(args=opt,
                                       partition=train_partition,
                                       transform=train_trans,
                                       is_sample=True,
                                       k=opt.nce_k)
        else:
            train_set = TieredImageNet(args=opt,
                                       partition=train_partition,
                                       transform=train_trans)
        n_data = len(train_set)
        train_loader = DataLoader(train_set,
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(TieredImageNet(args=opt,
                                               partition='train_phase_val',
                                               transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='test',
            train_transform=train_trans,
            test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='val',
            train_transform=train_trans,
            test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']
        if opt.distill in ['contrast']:
            train_set = CIFAR100(args=opt,
                                 partition=train_partition,
                                 transform=train_trans,
                                 is_sample=True,
                                 k=opt.nce_k)
        else:
            train_set = CIFAR100(args=opt,
                                 partition=train_partition,
                                 transform=train_trans)
        n_data = len(train_set)
        train_loader = DataLoader(train_set,
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        # NOTE(review): unlike the other datasets, the validation loader here
        # is built on partition='train' — confirm this is intentional.
        val_loader = DataLoader(CIFAR100(args=opt,
                                         partition='train',
                                         transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaCIFAR100(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaCIFAR100(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(
                    opt.dataset))
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model_t = load_teacher(opt.path_t, n_cls, opt.dataset)
    model_s = create_model(opt.model_s, n_cls, opt.dataset)

    # Dummy forward pass to discover feature dims (consumed by the 'contrast'
    # branch below). NOTE(review): hard-codes 84x84 inputs even for the
    # CIFAR-style datasets — confirm the models accept that shape.
    data = torch.randn(2, 3, 84, 84)
    model_t.eval()
    model_s.eval()
    feat_t, _ = model_t(data, is_feat=True)
    feat_s, _ = model_s(data, is_feat=True)

    # module_list holds everything moved to GPU; trainable_list holds only
    # the parameters handed to the optimizer (teacher is appended later).
    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)
    if opt.distill == 'kd':
        criterion_kd = DistillKL(opt.kd_T)
    elif opt.distill == 'contrast':
        criterion_kd = NCELoss(opt, n_data)
        # projection heads mapping student/teacher features into a shared
        # opt.feat_dim contrastive space; both are trained
        embed_s = Embed(feat_s[-1].shape[1], opt.feat_dim)
        embed_t = Embed(feat_t[-1].shape[1], opt.feat_dim)
        module_list.append(embed_s)
        module_list.append(embed_t)
        trainable_list.append(embed_s)
        trainable_list.append(embed_t)
    elif opt.distill == 'attention':
        criterion_kd = Attention()
    elif opt.distill == 'hint':
        criterion_kd = HintLoss()
    else:
        raise NotImplementedError(opt.distill)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    criterion_list.append(
        criterion_div)  # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_kd)  # other knowledge distillation loss

    # optimizer
    optimizer = optim.SGD(trainable_list.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    # append teacher after optimizer to avoid weight_decay
    module_list.append(model_t)

    if torch.cuda.is_available():
        module_list.cuda()
        criterion_list.cuda()
        cudnn.benchmark = True

    # validate teacher accuracy
    teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
    print('teacher accuracy: ', teacher_acc)

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate**3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)

    # routine: supervised model distillation
    for epoch in range(1, opt.epochs + 1):

        if opt.cosine:
            # NOTE(review): stepping before the epoch follows the legacy
            # (pre-1.1) PyTorch pattern; kept as-is to preserve the schedule.
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, module_list,
                                      criterion_list, optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_loss', train_loss, epoch)

        # student accuracy on the supervised validation split
        # (logged under 'test_*' keys)
        test_acc, test_acc_top5, test_loss = validate(val_loader, model_s,
                                                      criterion_cls, opt)

        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_acc_top5', test_acc_top5, epoch)
        logger.log_value('test_loss', test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model_s.state_dict(),
            }
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

    # few-shot evaluation: logit-based (use_logit=True) vs feature-based
    # classifiers, on both the meta-val and meta-test episode loaders
    start = time.time()
    val_acc, val_std = meta_test(model_s, meta_valloader, use_logit=True)
    val_time = time.time() - start
    print('val_acc: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(
        val_acc, val_std, val_time))

    start = time.time()
    val_acc_feat, val_std_feat = meta_test(model_s,
                                           meta_valloader,
                                           use_logit=False)
    val_time = time.time() - start
    print('val_acc_feat: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(
        val_acc_feat, val_std_feat, val_time))

    start = time.time()
    test_acc, test_std = meta_test(model_s, meta_testloader, use_logit=True)
    test_time = time.time() - start
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc, test_std, test_time))

    start = time.time()
    test_acc_feat, test_std_feat = meta_test(model_s,
                                             meta_testloader,
                                             use_logit=False)
    test_time = time.time() - start
    print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, test_time))

    # logged once, at step opt.epochs
    logger.log_value('meta_val_acc', val_acc, opt.epochs)
    logger.log_value('meta_val_acc_feat', val_acc_feat, opt.epochs)
    logger.log_value('meta_test_acc', test_acc, opt.epochs)
    logger.log_value('meta_test_acc_feat', test_acc_feat, opt.epochs)

    # save the last model
    state = {
        'opt': opt,
        'model': model_s.state_dict(),
    }
    save_file = os.path.join(opt.save_folder,
                             '{}_last.pth'.format(opt.model_s))
    torch.save(state, save_file)
# start = time.time() # val_acc_feat, val_std_feat = meta_test(model, meta_valloader, use_logit=False, classifier=opt.classifier) # val_time = time.time() - start # print('val_acc_feat: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(val_acc_feat, val_std_feat, val_time)) # start = time.time() # test_acc, test_std = meta_test(model, meta_testloader, classifier=opt.classifier, model_list=model_list) # test_time = time.time() - start # print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc, test_std, test_time)) # start = time.time() # test_acc_feat, test_std_feat = meta_test(model, meta_testloader, use_logit=False, classifier=opt.classifier, model_list=model_list) # test_time = time.time() - start # print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc_feat, test_std_feat, test_time)) # start = time.time() # test_acc, test_std = meta_test(model, meta_testloader, is_norm=False, classifier=opt.classifier, model_list=model_list) # test_time = time.time() - start # print('test_acc (no normalization): {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc, test_std, test_time)) start = time.time() test_acc_feat, test_std_feat = meta_test(model, meta_testloader, use_logit=False, is_norm=False, classifier=opt.classifier) test_time = time.time() - start print( 'test_acc_feat (no normalization): {:.4f}, test_std: {:.4f}, time: {:.1f}' .format(test_acc_feat, test_std_feat, test_time))
def generate_final_report(model, opt, wandb):
    """Run the final 1-shot and 5-shot meta-evaluations of ``model`` and log them.

    For each shot count the loaders are rebuilt via ``get_dataloaders``, the
    model is meta-tested on the val and test episode loaders with both the
    logit-based and the feature-based (``use_logit=False``) classifiers, the
    results are printed, and everything is logged to wandb under
    ``... @{n_shots}`` keys.

    Args:
        model: trained backbone to evaluate.
        opt: options namespace; NOTE: ``opt.n_shots`` is mutated and is left
            at 5 on return (matches the original behavior).
        wandb: the wandb module/run object used for logging.
    """
    from eval.meta_eval import meta_test

    # The 1-shot and 5-shot passes were copy-pasted duplicates differing only
    # in opt.n_shots and the "@k" suffix of the wandb keys — run both through
    # a single helper instead.
    for n_shots in (1, 5):
        _report_for_shots(model, opt, wandb, meta_test, n_shots)


def _report_for_shots(model, opt, wandb, meta_test, n_shots):
    """Meta-evaluate ``model`` for one ``n_shots`` setting; print and wandb-log."""
    opt.n_shots = n_shots
    # Rebuild the loaders so the episodic samplers pick up the new shot count.
    train_loader, val_loader, meta_testloader, meta_valloader, _ = get_dataloaders(
        opt)

    # validate
    meta_val_acc, meta_val_std = meta_test(model, meta_valloader)
    meta_val_acc_feat, meta_val_std_feat = meta_test(model,
                                                     meta_valloader,
                                                     use_logit=False)

    # evaluate
    meta_test_acc, meta_test_std = meta_test(model, meta_testloader)
    meta_test_acc_feat, meta_test_std_feat = meta_test(model,
                                                       meta_testloader,
                                                       use_logit=False)

    print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}'.format(
        meta_val_acc, meta_val_std))
    print('Meta Val Acc (feat): {:.4f}, Meta Val std (feat): {:.4f}'.format(
        meta_val_acc_feat, meta_val_std_feat))
    print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}'.format(
        meta_test_acc, meta_test_std))
    print('Meta Test Acc (feat): {:.4f}, Meta Test std (feat): {:.4f}'.format(
        meta_test_acc_feat, meta_test_std_feat))

    wandb.log({
        f'Final Meta Test Acc @{n_shots}': meta_test_acc,
        f'Final Meta Test std @{n_shots}': meta_test_std,
        f'Final Meta Test Acc (feat) @{n_shots}': meta_test_acc_feat,
        f'Final Meta Test std (feat) @{n_shots}': meta_test_std_feat,
        f'Final Meta Val Acc @{n_shots}': meta_val_acc,
        f'Final Meta Val std @{n_shots}': meta_val_std,
        f'Final Meta Val Acc (feat) @{n_shots}': meta_val_acc_feat,
        f'Final Meta Val std (feat) @{n_shots}': meta_val_std_feat
    })
def main_worker(gpu, ngpus_per_node, args):
    """Per-process entry point for (distributed) ANCOR contrastive pretraining.

    Initializes the process group, builds the benchmark's train/meta-val
    datasets, wraps the model in DistributedDataParallel, optionally resumes
    from a checkpoint, then runs the train / meta-validate / checkpoint loop.

    Args:
        gpu: local GPU index for this process (may be None).
        ngpus_per_node: number of GPUs on this node.
        args: parsed command-line namespace (mutated: gpu, rank, batch_size,
            workers, start_epoch).
    """
    args.gpu = gpu
    logger = get_logger(name='log', log_dir='.')

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:

        def print_pass(*args):
            # NOTE: *args here shadows the outer `args`; harmless since unused
            pass

        builtins.print = print_pass

    if args.gpu is not None:
        logger.info("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # Build train dataset (two augmented crops per image) plus coarse and
    # fine-grained (fg) episodic validation datasets for the chosen benchmark.
    if args.dataset == 'tiered':
        train_dataset = TieredImageNet(
            root=args.data,
            partition='train',
            mode=args.mode,
            transform=ancor.loader.TwoCropsTransform(
                transforms.Compose(AUGS[f"train_{args.dataset}"])))
        val_dataset = MetaTieredImageNet(
            args=Box(data_root=args.data,
                     mode='fine',
                     n_ways=5,
                     n_shots=1,
                     n_queries=15,
                     n_test_runs=200,
                     n_aug_support_samples=5),
            partition='validation',
            train_transform=transforms.Compose(
                AUGS[f"meta_test_{args.dataset}"]),
            test_transform=transforms.Compose(AUGS[f"test_{args.dataset}"]))
        fg_val_dataset = MetaFGTieredImageNet(
            args=Box(data_root=args.data,
                     mode='fine',
                     n_ways=5,
                     n_shots=1,
                     n_queries=15,
                     n_test_runs=200,
                     n_aug_support_samples=5),
            partition='validation',
            train_transform=transforms.Compose(
                AUGS[f"meta_test_{args.dataset}"]),
            test_transform=transforms.Compose(AUGS[f"test_{args.dataset}"]))
    elif args.dataset == 'cifar100':
        # first entry of the train augmentation list is skipped here ([1:]);
        # presumably an ImageNet-only transform — confirm against AUGS
        train_transforms = transforms.Compose(
            AUGS[f"train_{args.dataset}"][1:])
        train_dataset = Cifar100(
            root=args.data,
            train=True,
            mode=args.mode,
            transform=ancor.loader.TwoCropsTransform(train_transforms))
        val_dataset = MetaCifar100(args=Box(
            data_root=args.data,
            mode='fine',
            n_ways=5,
            n_shots=1,
            n_queries=15,
            n_test_runs=200,
            n_aug_support_samples=5,
        ),
                                   partition='test',
                                   train_transform=transforms.Compose(
                                       AUGS[f"meta_test_{args.dataset}"]),
                                   test_transform=transforms.Compose(
                                       AUGS[f"test_{args.dataset}"]))
        fg_val_dataset = MetaFGCifar100(args=Box(
            data_root=args.data,
            mode='fine',
            n_ways=5,
            n_shots=1,
            n_queries=15,
            n_test_runs=200,
            n_aug_support_samples=5,
        ),
                                        partition='test',
                                        train_transform=transforms.Compose(
                                            AUGS[f"meta_test_{args.dataset}"]),
                                        test_transform=transforms.Compose(
                                            AUGS[f"test_{args.dataset}"]))
    elif args.dataset in ['living17', 'entity13', 'nonliving26', 'entity30']:
        breeds_factory = BREEDSFactory(
            info_dir=os.path.join(args.data, "BREEDS"),
            data_dir=os.path.join(args.data, "Data", "CLS-LOC"))
        train_dataset = breeds_factory.get_breeds(
            ds_name=args.dataset,
            partition='train',
            mode=args.mode,
            transforms=ancor.loader.TwoCropsTransform(
                transforms.Compose(AUGS[f"train_{args.dataset}"])),
            split=args.split)
        val_dataset = MetaDataset(
            args=Box(
                n_ways=5,
                n_shots=1,
                n_queries=15,
                n_test_runs=200,
                n_aug_support_samples=5,
            ),
            dataset=breeds_factory.get_breeds(ds_name=args.dataset,
                                              partition='val',
                                              mode='fine',
                                              transforms=None,
                                              split=args.split),
            train_transform=transforms.Compose(
                AUGS[f"meta_test_{args.dataset}"]),
            test_transform=transforms.Compose(AUGS[f"test_{args.dataset}"]))
        # identical construction to val_dataset except for fg=True
        fg_val_dataset = MetaDataset(
            args=Box(
                n_ways=5,
                n_shots=1,
                n_queries=15,
                n_test_runs=200,
                n_aug_support_samples=5,
            ),
            dataset=breeds_factory.get_breeds(ds_name=args.dataset,
                                              partition='val',
                                              mode='fine',
                                              transforms=None,
                                              split=args.split),
            train_transform=transforms.Compose(
                AUGS[f"meta_test_{args.dataset}"]),
            test_transform=transforms.Compose(AUGS[f"test_{args.dataset}"]),
            fg=True)
    else:
        raise NotImplementedError

    # create model
    model, criterions = ANCORModelGenerator().generate_ancor_model(
        arch=args.arch,
        head_type=args.head,
        dim=args.cst_dim,
        K=args.queue_k,
        m=args.encoder_m,
        T=args.cst_t,
        mlp=args.mlp,
        num_classes=train_dataset.num_classes,
        queue_type=args.queue,
        metric=args.metric,
        calc_types=args.calc_types,
        loss_types=args.loss_types,
        gpu=args.gpu)
    log(args.rank, logger, "loaded model")

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(
                model, find_unused_parameters=True)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # comment out the following line for debugging
        # raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        raise NotImplementedError("Only DistributedDataParallel is supported.")

    # define loss function (criterion) and optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            log(args.rank, logger,
                "=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            msg = model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # older checkpoints predate best-accuracy tracking
            if 'best_accs' in checkpoint:
                best_accs = checkpoint['best_accs']
            else:
                best_accs = [0.]
                log(
                    args.rank, logger,
                    " WARNING: BACKWARDS COMPATIBLE RESUME. NO BEST MODEL CHECKPOINT"
                )
            log(
                args.rank, logger,
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            log(args.rank, logger,
                "=> no checkpoint found at '{}'".format(args.resume))
            raise ValueError()
    else:
        best_accs = [0.]
    cudnn.benchmark = True

    # Data loading code
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)
    # episodic loaders: one episode per batch element
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    fg_val_loader = torch.utils.data.DataLoader(fg_val_dataset,
                                                batch_size=1,
                                                shuffle=False,
                                                num_workers=args.workers,
                                                pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs):
        best_flag = False
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterions, optimizer, epoch, logger, args)

        # meta-validation only on this node's first-rank process, every
        # save_freq epochs
        if args.rank % ngpus_per_node == 0:
            if (epoch + 1) % args.save_freq == 0 and val_dataset is not None:
                val_acc, val_std = meta_test(model.module.encoder_q,
                                             val_loader,
                                             only_base=True,
                                             is_norm=True,
                                             classifier="Cosine")
                # track a monotone history of best accuracies; also append
                # the new best to best_accs.log
                if best_accs[-1] < val_acc:
                    best_accs.append(val_acc)
                    with open("best_accs.log", 'a') as f:
                        f.write(
                            f"EPOCH {epoch}: Validation Accuracy: {round(val_acc * 100, 2)}+-{round(val_std * 100, 2)}\n"
                        )
                    best_flag = True
                log(
                    args.rank, logger,
                    f"EPOCH {epoch}: Validation Accuracy: {round(val_acc * 100, 2)}+-{round(val_std * 100, 2)}"
                )
            if (epoch + 1) % args.save_freq == 0 and fg_val_dataset is not None:
                val_acc, val_std = meta_test(model.module.encoder_q,
                                             fg_val_loader,
                                             only_base=True,
                                             is_norm=True,
                                             classifier="Cosine")
                # fine-grained validation is logged only, never drives best_flag
                log(
                    args.rank, logger,
                    f"EPOCH {epoch}: Validation FG - Accuracy: {round(val_acc * 100, 2)}+-{round(val_std * 100, 2)}"
                )
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            if (epoch + 1) % args.save_freq == 0 or best_flag:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_accs': best_accs
                    },
                    is_best=best_flag,
                    filename='checkpoint_{:04d}.pth.tar'.format(epoch))
        # prune old checkpoints, keeping the most recent args.keep_epochs
        if args.rank % ngpus_per_node == 0:
            remove_excess_epochs(args.keep_epochs)
def main():
    """Born-again (self-)distillation training loop with wandb logging.

    Loads one or more teacher checkpoints (comma-separated paths in
    ``opt.path_t``), initializes the student as a deep copy of the first
    teacher, distills for ``opt.epochs`` epochs while logging metrics to
    wandb, and finishes with the final 1-/5-shot meta-evaluation report.
    """
    opt = parse_option()

    wandb.init(project=opt.model_path.split("/")[-1], tags=opt.tags)
    wandb.config.update(opt)
    wandb.save('*.py')
    wandb.run.save()

    # dataloader
    train_loader, val_loader, meta_testloader, meta_valloader, n_cls = get_dataloaders(
        opt)

    # model: one teacher per comma-separated checkpoint path
    model_t = []
    if ("," in opt.path_t):
        for path in opt.path_t.split(","):
            model_t.append(load_teacher(path, opt.model_t, n_cls, opt.dataset))
    else:
        model_t.append(
            load_teacher(opt.path_t, opt.model_t, n_cls, opt.dataset))

    # student starts as an exact copy of the first teacher (born-again setup)
    model_s = copy.deepcopy(model_t[0])

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)
    criterion_kd = DistillKL(opt.kd_T)

    optimizer = optim.SGD(model_s.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    # BUG FIX: the loop below calls scheduler.step() whenever opt.cosine is
    # set, but the scheduler was never instantiated, so any --cosine run
    # crashed with a NameError. Create it here, mirroring the cosine setup
    # used by the other training entry point in this file.
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate**3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)

    if torch.cuda.is_available():
        for m in model_t:
            m.cuda()
        model_s.cuda()
        criterion_cls = criterion_cls.cuda()
        criterion_div = criterion_div.cuda()
        criterion_kd = criterion_kd.cuda()
        cudnn.benchmark = True

    # keep last meta-test numbers defined even before the first evaluation
    meta_test_acc = 0
    meta_test_std = 0

    # routine: supervised model distillation
    for epoch in range(1, opt.epochs + 1):

        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, model_s, model_t,
                                      criterion_cls, criterion_div,
                                      criterion_kd, optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # supervised validation and meta-val are currently skipped for speed;
        # zero placeholders keep the wandb log schema stable
        val_acc = 0
        val_loss = 0
        meta_val_acc = 0
        meta_val_std = 0

        # evaluate on the meta-test episodes (feature-based classifier)
        start = time.time()
        meta_test_acc, meta_test_std = meta_test(model_s,
                                                 meta_testloader,
                                                 use_logit=False)
        test_time = time.time() - start
        print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}, Time: {:.1f}'.
              format(meta_test_acc, meta_test_std, test_time))

        # regular saving (always save on the final epoch)
        if epoch % opt.save_freq == 0 or epoch == opt.epochs:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model_s.state_dict(),
            }
            save_file = os.path.join(opt.save_folder,
                                     'model_' + str(wandb.run.name) + '.pth')
            torch.save(state, save_file)

            # wandb saving
            torch.save(state, os.path.join(wandb.run.dir, "model.pth"))

        wandb.log({
            'epoch': epoch,
            'Train Acc': train_acc,
            'Train Loss': train_loss,
            'Val Acc': val_acc,
            'Val Loss': val_loss,
            'Meta Test Acc': meta_test_acc,
            'Meta Test std': meta_test_std,
            'Meta Val Acc': meta_val_acc,
            'Meta Val std': meta_val_std
        })

    # final report
    generate_final_report(model_s, opt, wandb)

    # remove the output.log file so it is not kept as a wandb artifact
    output_log_file = os.path.join(wandb.run.dir, "output.log")
    if os.path.isfile(output_log_file):
        os.remove(output_log_file)
    else:
        ## Show an error ##
        print("Error: %s file not found" % output_log_file)