def main():
    opt = parse_option()

    # test loader
    args = opt

    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(
            MetaImageNet(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans,
                         fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaImageNet(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans,
                         fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='test',
                               train_transform=train_trans, test_transform=test_trans,
                               fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='val',
                               train_transform=train_trans, test_transform=test_trans,
                               fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_test_options['D']
        meta_testloader = DataLoader(
            MetaCIFAR100(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans,
                         fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaCIFAR100(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans,
                         fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(opt.dataset))
    else:
        raise NotImplementedError(opt.dataset)

    # load model
    model = create_model(opt.model, n_cls, opt.dataset)
    ckpt = torch.load(opt.model_path)
    model.load_state_dict(ckpt['model'])

    if torch.cuda.is_available():
        model = model.cuda()
        cudnn.benchmark = True

    # evaluation: compare logit-layer vs penultimate-feature embeddings
    start = time.time()
    val_acc, val_std = meta_test(model, meta_valloader)
    val_time = time.time() - start
    print('val_acc: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(
        val_acc, val_std, val_time))

    start = time.time()
    val_acc_feat, val_std_feat = meta_test(model, meta_valloader, use_logit=False)
    val_time = time.time() - start
    print('val_acc_feat: {:.4f}, val_std_feat: {:.4f}, time: {:.1f}'.format(
        val_acc_feat, val_std_feat, val_time))

    start = time.time()
    test_acc, test_std = meta_test(model, meta_testloader)
    test_time = time.time() - start
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc, test_std, test_time))

    start = time.time()
    test_acc_feat, test_std_feat = meta_test(model, meta_testloader, use_logit=False)
    test_time = time.time() - start
    print('test_acc_feat: {:.4f}, test_std_feat: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, test_time))
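
# ---------------------------------------------------------------------------
# Hypothetical sketch (not part of the original script): what the use_logit
# flag toggles in meta_test above. use_logit=True embeds episodes with the
# classifier output, use_logit=False with the penultimate feature, assuming
# the model(x, is_feat=True) interface used in the distillation script below.
def _extract_episode_embedding(model, x, use_logit=True):
    with torch.no_grad():
        feats, logits = model(x, is_feat=True)
    return logits if use_logit else feats[-1]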
def main():
    opt = parse_option()

    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'

    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            ImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            ImageNet(args=opt, partition='val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaImageNet(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaImageNet(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            TieredImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='test',
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='val',
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']
        train_loader = DataLoader(
            CIFAR100(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            CIFAR100(args=opt, partition='train', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaCIFAR100(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaCIFAR100(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(opt.dataset))
    else:
        raise NotImplementedError(opt.dataset)

    # model: either start fresh or resume from the latest checkpoint
    if not opt.load_latest:
        model = create_model(opt.model, n_cls, opt.dataset)
    else:
        latest_file = os.path.join(opt.save_folder, 'latest.pth')
        model = load_teacher(latest_file, n_cls, opt.dataset)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate,
                                     weight_decay=0.0005)
    else:
        optimizer = optim.SGD(model.parameters(), lr=opt.learning_rate,
                              momentum=opt.momentum, weight_decay=opt.weight_decay)

    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)

    # routine: supervised pre-training
    for epoch in range(1, opt.epochs + 1):
        if opt.cosine:
            # note: stepping before train() follows the pre-1.1 PyTorch convention
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, model, criterion,
                                      optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_loss', train_loss, epoch)

        test_acc, test_acc_top5, test_loss = validate(val_loader, model,
                                                      criterion, opt)

        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_acc_top5', test_acc_top5, epoch)
        logger.log_value('test_loss', test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
            }
            save_file = os.path.join(opt.save_folder,
                                     'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
            # refresh the 'latest' symlink; os.symlink raises if the link exists
            latest_file = os.path.join(opt.save_folder, 'latest.pth')
            if os.path.lexists(latest_file):
                os.remove(latest_file)
            os.symlink(save_file, latest_file)

    # save the last model
    state = {
        'opt': opt,
        'model': model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
    }
    save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model))
    torch.save(state, save_file)
def main():
    opt = parse_option()

    # dump the run config alongside the tensorboard logs
    with open(f"{opt.tb_folder}/config.json", "w") as fo:
        fo.write(json.dumps(vars(opt), indent=4))

    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'

    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            ImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            ImageNet(args=opt, partition='val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaImageNet(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaImageNet(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            TieredImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='test',
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='val',
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']
        train_loader = DataLoader(
            CIFAR100(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            CIFAR100(args=opt, partition='train', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaCIFAR100(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaCIFAR100(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(opt.dataset))
    elif opt.dataset == "imagenet":
        train_trans, test_trans = transforms_options["A"]
        train_dataset = ImagenetFolder(root=os.path.join(opt.data_root, "train"),
                                       transform=train_trans)
        val_dataset = ImagenetFolder(root=os.path.join(opt.data_root, "val"),
                                     transform=test_trans)
        train_loader = DataLoader(train_dataset, batch_size=opt.batch_size,
                                  shuffle=True, drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(val_dataset, batch_size=opt.batch_size // 2,
                                shuffle=False, drop_last=False,
                                num_workers=opt.num_workers // 2)
        n_cls = 1000
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model = create_model(opt.model, n_cls, opt.dataset, use_srl=opt.srl)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate,
                                     weight_decay=0.0005)
    else:
        optimizer = optim.SGD(model.parameters(), lr=opt.learning_rate,
                              momentum=opt.momentum, weight_decay=opt.weight_decay)

    # main criterion: plain CE, label smoothing, or guided complement entropy
    if opt.label_smoothing:
        criterion = LabelSmoothing(smoothing=opt.smoothing_ratio)
    elif opt.gce:
        criterion = GuidedComplementEntropy(alpha=opt.gce_alpha, classes=n_cls)
    else:
        criterion = nn.CrossEntropyLoss()

    # optional auxiliary loss: orthogonal projection variants
    if opt.opl:
        auxiliary_loss = OrthogonalProjectionLoss(use_attention=True)
    elif opt.popl:
        auxiliary_loss = PerpetualOrthogonalProjectionLoss(feat_dim=640)
    else:
        auxiliary_loss = None

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        if auxiliary_loss is not None:
            auxiliary_loss = auxiliary_loss.cuda()
        cudnn.benchmark = True

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)
    else:
        scheduler = None

    # routine: supervised pre-training
    for epoch in range(1, opt.epochs + 1):
        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        if auxiliary_loss is not None:
            train_acc, train_loss, [train_cel, train_opl] = train(
                epoch=epoch, train_loader=train_loader, model=model,
                criterion=criterion, optimizer=optimizer, opt=opt,
                auxiliary=auxiliary_loss)
        else:
            train_acc, train_loss = train(epoch=epoch, train_loader=train_loader,
                                          model=model, criterion=criterion,
                                          optimizer=optimizer, opt=opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        logger.log_value('accuracy/train_acc', train_acc, epoch)
        logger.log_value('train_losses/loss', train_loss, epoch)
        if auxiliary_loss is not None:
            logger.log_value('train_losses/cel', train_cel, epoch)
            logger.log_value('train_losses/opl', train_opl, epoch)
        else:
            logger.log_value('train_losses/cel', train_loss, epoch)

        if auxiliary_loss is not None:
            test_acc, test_acc_top5, test_loss, [test_cel, test_opl] = \
                validate(val_loader, model, criterion, opt, auxiliary=auxiliary_loss)
        else:
            test_acc, test_acc_top5, test_loss = validate(val_loader, model,
                                                          criterion, opt)

        logger.log_value('accuracy/test_acc', test_acc, epoch)
        logger.log_value('accuracy/test_acc_top5', test_acc_top5, epoch)
        logger.log_value('test_losses/loss', test_loss, epoch)
        if auxiliary_loss is not None:
            logger.log_value('test_losses/cel', test_cel, epoch)
            logger.log_value('test_losses/opl', test_opl, epoch)
        else:
            logger.log_value('test_losses/cel', test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
            }
            save_file = os.path.join(opt.save_folder,
                                     'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

    # save the last model
    state = {
        'opt': opt,
        'model': model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
    }
    save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model))
    torch.save(state, save_file)
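
# ---------------------------------------------------------------------------
# Hypothetical sketch (the repo's LabelSmoothing is defined elsewhere and may
# differ): the standard smoothed cross-entropy such a criterion typically
# computes, mixing the one-hot target with a uniform prior. Assumes
# torch.nn.functional is imported as F.
class _LabelSmoothingSketch(nn.Module):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits, target):
        logp = F.log_softmax(logits, dim=-1)  # (batch, classes) log-probs
        nll = -logp.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        uniform = -logp.mean(dim=-1)          # cross-entropy against a uniform prior
        return ((1.0 - self.smoothing) * nll + self.smoothing * uniform).mean()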
def get_eval_loaders(opt):
    """Create the evaluation dataloaders."""
    train_trans, test_trans = transforms_options[opt.transform]

    # ImageNet derivatives - miniImageNet
    if opt.dataset == 'miniImageNet':
        assert opt.transform == "A"
        meta_testloader = DataLoader(
            MetaImageNet(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans,
                         fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaImageNet(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans,
                         fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    # ImageNet derivatives - tieredImageNet
    elif opt.dataset == 'tieredImageNet':
        assert opt.transform == "A"
        meta_testloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='test',
                               train_transform=train_trans, test_transform=test_trans,
                               fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='val',
                               train_transform=train_trans, test_transform=test_trans,
                               fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    # CIFAR-100 derivatives - both CIFAR-FS & FC100
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        assert opt.transform == "D"
        meta_testloader = DataLoader(
            MetaCIFAR100(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans,
                         fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaCIFAR100(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans,
                         fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(opt.dataset))
    # For cross-domain - we evaluate on a new dataset / domain
    elif opt.dataset in ['cub', 'cars', 'places', 'plantae']:
        train_classes = {'cub': 100, 'cars': 98, 'places': 183, 'plantae': 100}
        assert opt.transform == "C"
        assert not opt.use_trainval, \
            f"Train val option not possible for dataset {opt.dataset}"
        meta_testloader = DataLoader(
            MetaCUB(args=opt, partition='novel',
                    train_transform=train_trans, test_transform=test_trans,
                    fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaCUB(args=opt, partition='val',
                    train_transform=train_trans, test_transform=test_trans,
                    fix_seed=False),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        n_cls = train_classes[opt.dataset]
    else:
        raise NotImplementedError(opt.dataset)

    return meta_testloader, meta_valloader, n_cls
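
# ---------------------------------------------------------------------------
# Hypothetical refactor sketch (not part of the repo): every branch above
# builds its meta loaders with identical DataLoader arguments, so a small
# factory could collapse the repetition; _make_meta_loader is illustrative.
def _make_meta_loader(dataset_cls, opt, partition, train_trans, test_trans):
    return DataLoader(
        dataset_cls(args=opt, partition=partition,
                    train_transform=train_trans, test_transform=test_trans,
                    fix_seed=False),
        batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
        num_workers=opt.num_workers)

# e.g. meta_testloader = _make_meta_loader(MetaImageNet, opt, 'test',
#                                          train_trans, test_trans)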
def main():
    best_acc = 0

    opt = parse_option()

    # tensorboard logger
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'

    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        if opt.distill in ['contrast']:
            train_set = ImageNet(args=opt, partition=train_partition,
                                 transform=train_trans, is_sample=True, k=opt.nce_k)
        else:
            train_set = ImageNet(args=opt, partition=train_partition,
                                 transform=train_trans)
        n_data = len(train_set)
        train_loader = DataLoader(train_set, batch_size=opt.batch_size,
                                  shuffle=True, drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(
            ImageNet(args=opt, partition='val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaImageNet(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaImageNet(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        if opt.distill in ['contrast']:
            train_set = TieredImageNet(args=opt, partition=train_partition,
                                       transform=train_trans, is_sample=True,
                                       k=opt.nce_k)
        else:
            train_set = TieredImageNet(args=opt, partition=train_partition,
                                       transform=train_trans)
        n_data = len(train_set)
        train_loader = DataLoader(train_set, batch_size=opt.batch_size,
                                  shuffle=True, drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(
            TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='test',
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='val',
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']
        if opt.distill in ['contrast']:
            train_set = CIFAR100(args=opt, partition=train_partition,
                                 transform=train_trans, is_sample=True, k=opt.nce_k)
        else:
            train_set = CIFAR100(args=opt, partition=train_partition,
                                 transform=train_trans)
        n_data = len(train_set)
        train_loader = DataLoader(train_set, batch_size=opt.batch_size,
                                  shuffle=True, drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(
            CIFAR100(args=opt, partition='train', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaCIFAR100(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaCIFAR100(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(opt.dataset))
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model_t = load_teacher(opt.path_t, n_cls, opt.dataset)
    model_s = create_model(opt.model_s, n_cls, opt.dataset)

    # probe feature shapes with a dummy forward pass
    data = torch.randn(2, 3, 84, 84)
    model_t.eval()
    model_s.eval()
    feat_t, _ = model_t(data, is_feat=True)
    feat_s, _ = model_s(data, is_feat=True)

    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)
    if opt.distill == 'kd':
        criterion_kd = DistillKL(opt.kd_T)
    elif opt.distill == 'contrast':
        criterion_kd = NCELoss(opt, n_data)
        embed_s = Embed(feat_s[-1].shape[1], opt.feat_dim)
        embed_t = Embed(feat_t[-1].shape[1], opt.feat_dim)
        module_list.append(embed_s)
        module_list.append(embed_t)
        trainable_list.append(embed_s)
        trainable_list.append(embed_t)
    elif opt.distill == 'attention':
        criterion_kd = Attention()
    elif opt.distill == 'hint':
        criterion_kd = HintLoss()
    else:
        raise NotImplementedError(opt.distill)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    criterion_list.append(criterion_div)  # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_kd)   # other knowledge distillation loss

    # optimizer
    optimizer = optim.SGD(trainable_list.parameters(), lr=opt.learning_rate,
                          momentum=opt.momentum, weight_decay=opt.weight_decay)

    # append teacher after optimizer to avoid weight_decay
    module_list.append(model_t)

    if torch.cuda.is_available():
        module_list.cuda()
        criterion_list.cuda()
        cudnn.benchmark = True

    # validate teacher accuracy
    teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
    print('teacher accuracy: ', teacher_acc)

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)

    # routine: supervised model distillation
    for epoch in range(1, opt.epochs + 1):
        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, module_list,
                                      criterion_list, optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_loss', train_loss, epoch)

        test_acc, test_acc_top5, test_loss = validate(val_loader, model_s,
                                                      criterion_cls, opt)

        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_acc_top5', test_acc_top5, epoch)
        logger.log_value('test_loss', test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model_s.state_dict(),
            }
            save_file = os.path.join(opt.save_folder,
                                     'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

    # save the last model
    state = {
        'opt': opt,
        'model': model_s.state_dict(),
    }
    save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model_s))
    torch.save(state, save_file)
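
# ---------------------------------------------------------------------------
# Hypothetical sketch (the repo's DistillKL is defined elsewhere and may
# differ): the classic Hinton-style KD objective that criterion_div computes,
# i.e. temperature-scaled KL between student and teacher logits. Assumes
# torch.nn.functional is imported as F.
class _DistillKLSketch(nn.Module):
    def __init__(self, T):
        super().__init__()
        self.T = T  # softmax temperature

    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s / self.T, dim=1)  # student log-probs
        p_t = F.softmax(y_t / self.T, dim=1)      # teacher probs
        # the T**2 factor keeps gradient magnitudes comparable across temperatures
        return F.kl_div(p_s, p_t, reduction='batchmean') * (self.T ** 2)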
def main():
    opt = parse_option()

    if opt.name is not None:
        wandb.init(name=opt.name)
    else:
        wandb.init()
    wandb.config.update(opt)

    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'

    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            ImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            ImageNet(args=opt, partition='val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaImageNet(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaImageNet(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            TieredImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='test',
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='val',
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']
        train_loader = DataLoader(
            CIFAR100(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            CIFAR100(args=opt, partition='train', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaCIFAR100(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaCIFAR100(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(opt.dataset))
    elif opt.dataset == 'CUB_200_2011':
        train_trans, test_trans = transforms_options['C']
        # language supervision: vocab maps tokens to indices, devocab inverts it
        vocab = lang_utils.load_vocab(opt.lang_dir) if opt.lsl else None
        devocab = {v: k for k, v in vocab.items()} if opt.lsl else None
        train_loader = DataLoader(
            CUB2011(args=opt, partition=train_partition, transform=train_trans,
                    vocab=vocab),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            CUB2011(args=opt, partition='val', transform=test_trans, vocab=vocab),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(
            MetaCUB2011(args=opt, partition='test',
                        train_transform=train_trans, test_transform=test_trans,
                        vocab=vocab),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaCUB2011(args=opt, partition='val',
                        train_transform=train_trans, test_transform=test_trans,
                        vocab=vocab),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            raise NotImplementedError(opt.dataset)  # no trainval supported yet
            n_cls = 150
        else:
            n_cls = 100
    else:
        raise NotImplementedError(opt.dataset)

    print('Amount of training data: {}'.format(len(train_loader.dataset)))
    print('Amount of val data: {}'.format(len(val_loader.dataset)))

    # model
    model = create_model(opt.model, n_cls, opt.dataset)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate,
                                     weight_decay=0.0005)
    else:
        optimizer = optim.SGD(model.parameters(), lr=opt.learning_rate,
                              momentum=opt.momentum, weight_decay=opt.weight_decay)

    criterion = nn.CrossEntropyLoss()

    # lsl
    lang_model = None
    if opt.lsl:
        if opt.glove_init:
            vecs = lang_utils.glove_init(vocab, emb_size=opt.lang_emb_size)
        embedding_model = nn.Embedding(
            len(vocab), opt.lang_emb_size,
            _weight=vecs if opt.glove_init else None)
        if opt.freeze_emb:
            embedding_model.weight.requires_grad = False
        lang_input_size = n_cls if opt.use_logit else 640  # 640 for resnet12
        lang_model = TextProposal(
            embedding_model,
            input_size=lang_input_size,
            hidden_size=opt.lang_hidden_size,
            project_input=lang_input_size != opt.lang_hidden_size,
            rnn=opt.rnn_type,
            num_layers=opt.rnn_num_layers,
            dropout=opt.rnn_dropout,
            vocab=vocab,
            **lang_utils.get_special_indices(vocab))

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True
        if opt.lsl:
            embedding_model = embedding_model.cuda()
            lang_model = lang_model.cuda()

    # tensorboard
    # logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)

    # routine: supervised pre-training
    best_val_acc = 0
    for epoch in range(1, opt.epochs + 1):
        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss, train_lang_loss = train(
            epoch, train_loader, model, criterion, optimizer, opt, lang_model,
            devocab=devocab if opt.lsl else None)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        print("==> validating...")
        test_acc, test_acc_top5, test_loss, test_lang_loss = validate(
            val_loader, model, criterion, opt, lang_model)

        # wandb
        log_metrics = {
            'train_acc': train_acc,
            'train_loss': train_loss,
            'val_acc': test_acc,
            'val_acc_top5': test_acc_top5,
            'val_loss': test_loss
        }
        if opt.lsl:
            log_metrics['train_lang_loss'] = train_lang_loss
            log_metrics['val_lang_loss'] = test_lang_loss
        wandb.log(log_metrics, step=epoch)

        # # regular saving
        # if epoch % opt.save_freq == 0 and not opt.dryrun:
        #     print('==> Saving...')
        #     state = {
        #         'epoch': epoch,
        #         'model': model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
        #     }
        #     save_file = os.path.join(opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
        #     torch.save(state, save_file)

        if test_acc > best_val_acc:
            best_val_acc = test_acc  # track the running best, not just the first improvement
            wandb.run.summary['best_val_acc'] = test_acc
            wandb.run.summary['best_val_acc_epoch'] = epoch

    # save the last model
    state = {
        'opt': opt,
        'model': model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
    }
    save_file = os.path.join(wandb.run.dir, '{}_last.pth'.format(opt.model))
    torch.save(state, save_file)

    # evaluate on test set
    print("==> testing...")

    start = time.time()
    (test_acc, test_std), (test_acc5, test_std5) = meta_test(model, meta_testloader)
    test_time = time.time() - start
    print('Using logit layer for embedding')
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc, test_std, test_time))
    print('test_acc top 5: {:.4f}, test_std top 5: {:.4f}, time: {:.1f}'.format(
        test_acc5, test_std5, test_time))

    start = time.time()
    (test_acc_feat, test_std_feat), (test_acc5_feat, test_std5_feat) = \
        meta_test(model, meta_testloader, use_logit=False)
    test_time = time.time() - start
    print('Using layer before logits for embedding')
    print('test_acc_feat: {:.4f}, test_std_feat: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, test_time))
    print('test_acc_feat top 5: {:.4f}, test_std_feat top 5: {:.4f}, time: {:.1f}'.format(
        test_acc5_feat, test_std5_feat, test_time))

    wandb.run.summary['test_acc'] = test_acc
    wandb.run.summary['test_std'] = test_std
    wandb.run.summary['test_acc5'] = test_acc5
    wandb.run.summary['test_std5'] = test_std5
    wandb.run.summary['test_acc_feat'] = test_acc_feat
    wandb.run.summary['test_std_feat'] = test_std_feat
    wandb.run.summary['test_acc5_feat'] = test_acc5_feat
    wandb.run.summary['test_std5_feat'] = test_std5_feat
def get_dataloaders(opt):
    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'

    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            ImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            ImageNet(args=opt, partition='val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        # the meta loaders use the test-time transform options
        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(
            MetaImageNet(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaImageNet(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
        # dataset length is independent of the transform passed here
        no_sample = len(ImageNet(args=opt, partition=train_partition,
                                 transform=train_trans))
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            TieredImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='test',
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='val',
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
        no_sample = len(TieredImageNet(args=opt, partition=train_partition,
                                       transform=train_trans))
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']
        train_loader = DataLoader(
            CIFAR100(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            CIFAR100(args=opt, partition='train', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        train_trans, test_trans = transforms_test_options[opt.transform]
        # note: meta_trainloader is built here but not returned below
        meta_trainloader = DataLoader(
            MetaCIFAR100(args=opt, partition='train',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=1, shuffle=True, drop_last=False,
            num_workers=opt.num_workers)
        meta_testloader = DataLoader(
            MetaCIFAR100(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaCIFAR100(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(opt.dataset))
        no_sample = len(CIFAR100(args=opt, partition=train_partition,
                                 transform=train_trans))
    else:
        raise NotImplementedError(opt.dataset)

    return train_loader, val_loader, meta_testloader, meta_valloader, n_cls, no_sample
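
# Hypothetical usage sketch for this variant of get_dataloaders (note the
# extra no_sample return value):
# train_loader, val_loader, meta_testloader, meta_valloader, n_cls, no_sample = \
#     get_dataloaders(opt)
# print('train classes: {}, train samples: {}'.format(n_cls, no_sample))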
def main():
    opt = parse_option()

    # dataloader
    train_partition = "trainval" if opt.use_trainval else "train"

    if opt.dataset == "miniImageNet":
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            ImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers,
        )
        val_loader = DataLoader(
            ImageNet(args=opt, partition="val", transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2,
        )
        # meta_testloader = DataLoader(
        #     MetaImageNet(args=opt, partition="test",
        #                  train_transform=train_trans, test_transform=test_trans),
        #     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
        #     num_workers=opt.num_workers,
        # )
        # meta_valloader = DataLoader(
        #     MetaImageNet(args=opt, partition="val",
        #                  train_transform=train_trans, test_transform=test_trans),
        #     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
        #     num_workers=opt.num_workers,
        # )
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == "tieredImageNet":
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            TieredImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers,
        )
        val_loader = DataLoader(
            TieredImageNet(args=opt, partition="train_phase_val", transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2,
        )
        meta_testloader = DataLoader(
            MetaTieredImageNet(args=opt, partition="test",
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers,
        )
        meta_valloader = DataLoader(
            MetaTieredImageNet(args=opt, partition="val",
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers,
        )
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == "CIFAR-FS" or opt.dataset == "FC100":
        train_trans, test_trans = transforms_options["D"]
        train_loader = DataLoader(
            CIFAR100(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers,
        )
        val_loader = DataLoader(
            CIFAR100(args=opt, partition="train", transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2,
        )
        meta_testloader = DataLoader(
            MetaCIFAR100(args=opt, partition="test",
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers,
        )
        meta_valloader = DataLoader(
            MetaCIFAR100(args=opt, partition="val",
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers,
        )
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == "CIFAR-FS":
                n_cls = 64
            elif opt.dataset == "FC100":
                n_cls = 60
            else:
                raise NotImplementedError(
                    "dataset not supported: {}".format(opt.dataset)
                )
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model = create_model(opt.model, n_cls, opt.dataset, opt.drop_rate, opt.dropblock)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(
            model.parameters(), lr=opt.learning_rate, weight_decay=0.0005
        )
    else:
        optimizer = optim.SGD(
            model.parameters(), lr=opt.learning_rate,
            momentum=opt.momentum, weight_decay=opt.weight_decay,
        )

    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # comet.ml (disabled unless --logcomet is set)
    comet_logger = Experiment(
        api_key=os.environ["COMET_API_KEY"],
        project_name=opt.comet_project_name,
        workspace=opt.comet_workspace,
        disabled=not opt.logcomet,
    )
    comet_logger.set_name(opt.model_name)
    comet_logger.log_parameters(vars(opt))

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** opt.cosine_factor)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1
        )

    # routine: supervised pre-training
    for epoch in range(1, opt.epochs + 1):
        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        with comet_logger.train():
            train_acc, train_loss = train(
                epoch, train_loader, model, criterion, optimizer, opt
            )
            comet_logger.log_metrics(
                {"acc": train_acc.cpu(), "loss_epoch": train_loss}, epoch=epoch
            )
        time2 = time.time()
        print("epoch {}, total time {:.2f}".format(epoch, time2 - time1))

        logger.log_value("train_acc", train_acc, epoch)
        logger.log_value("train_loss", train_loss, epoch)

        with comet_logger.validate():
            test_acc, test_acc_top5, test_loss = validate(
                val_loader, model, criterion, opt
            )
            comet_logger.log_metrics(
                {"acc": test_acc.cpu(), "acc_top5": test_acc_top5.cpu(),
                 "loss": test_loss},
                epoch=epoch,
            )

        logger.log_value("test_acc", test_acc, epoch)
        logger.log_value("test_acc_top5", test_acc_top5, epoch)
        logger.log_value("test_loss", test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print("==> Saving...")
            state = {
                "epoch": epoch,
                "model": model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
            }
            save_file = os.path.join(
                opt.save_folder, "ckpt_epoch_{epoch}.pth".format(epoch=epoch)
            )
            torch.save(state, save_file)

    # save the last model
    state = {
        "opt": opt,
        "model": model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
    }
    save_file = os.path.join(opt.save_folder, "{}_last.pth".format(opt.model))
    torch.save(state, save_file)
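
# Illustrative arithmetic for the eta_min formula above (example values, not
# the repo's defaults): with learning_rate=0.05, lr_decay_rate=0.1 and
# cosine_factor=3, eta_min = 0.05 * 0.1**3 = 5e-5, so CosineAnnealingLR sweeps
# the learning rate from 0.05 down to 5e-5 over opt.epochs epochs.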
def get_dataloaders(opt):
    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'

    if opt.dataset == 'toy':
        train_trans, test_trans = transforms_options['D']
        train_loader = DataLoader(
            CIFAR100_toy(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            CIFAR100_toy(args=opt, partition='train', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        # train_trans, test_trans = transforms_test_options[opt.transform]
        # meta_testloader = DataLoader(
        #     MetaCIFAR100(args=opt, partition='test',
        #                  train_transform=train_trans, test_transform=test_trans),
        #     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
        #     num_workers=opt.num_workers)
        # meta_valloader = DataLoader(
        #     MetaCIFAR100(args=opt, partition='val',
        #                  train_transform=train_trans, test_transform=test_trans),
        #     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
        #     num_workers=opt.num_workers)
        n_cls = 5
        # placeholder ints stand in for the meta loaders, which are not built
        # for the toy set
        return train_loader, val_loader, 5, 5, n_cls

    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            ImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            ImageNet(args=opt, partition='val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(
            MetaImageNet(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaImageNet(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            TieredImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='test',
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='val',
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']
        train_loader = DataLoader(
            CIFAR100(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            CIFAR100(args=opt, partition='train', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        train_trans, test_trans = transforms_test_options[opt.transform]
        # ns = [opt.n_shots].copy()
        # opt.n_ways = 32
        # opt.n_shots = 5
        # opt.n_aug_support_samples = 2
        meta_trainloader = DataLoader(
            MetaCIFAR100(args=opt, partition='train',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=1, shuffle=True, drop_last=False,
            num_workers=opt.num_workers)
        # opt.n_ways = 5
        # opt.n_shots = ns[0]
        # print(opt.n_shots)
        # opt.n_aug_support_samples = 5
        meta_testloader = DataLoader(
            MetaCIFAR100(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaCIFAR100(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(opt.dataset))
        # return train_loader, val_loader, meta_trainloader, meta_testloader, meta_valloader, n_cls
    else:
        raise NotImplementedError(opt.dataset)

    return train_loader, val_loader, meta_testloader, meta_valloader, n_cls